【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法

本文主要是介绍【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【ShuQiHere】

在深度学习的世界中,训练一个模型不仅需要时间,还需要大量的计算资源。比如,你已经花了几天时间训练一个模型,但突然间,电脑崩溃了,你的所有进度都丢失了。这种情况就像是在一场马拉松比赛的最后一公里摔倒,让人沮丧至极。那么,有没有什么方法可以避免这种悲剧呢?今天,我们就来聊聊如何通过保存和加载模型的权重来应对这些挑战,确保你在深度学习的旅程中不会白费功夫。

模型保存和加载的背景

训练一个深度学习模型就像建造一座摩天大楼。你需要从基础开始,一层层地搭建,最终完成一个复杂的系统。然而,建造过程中难免会遇到意外,比如断电、系统崩溃,甚至是代码错误。这些意外可能让你的努力前功尽弃。如果你不想每次意外发生后都从头开始,那么模型保存和加载就显得尤为重要。

在 TensorFlow 中,我们有两种主要的保存和加载方法:保存整个模型保存模型的权重。理解它们的区别和用法,就像学会了在建造摩天大楼时如何保存施工进度,确保即使遭遇突发事件,你的建筑工程也能顺利继续。

微调模型:从预训练到自定义任务

我们都知道,训练一个深度学习模型需要大量的数据和计算资源。幸运的是,深度学习社区里有很多预训练的模型,这些模型已经在大规模数据集上进行了训练。通过微调(Fine-Tuning),你可以利用这些预训练模型,在它们的基础上进行训练,快速适应新的任务。

场景:训练猫狗识别模型

比如,你想训练一个模型来区分猫和狗。如果从零开始训练,不仅费时费力,而且可能效果不佳。但如果你有一个在 ImageNet 上预训练的模型,就可以大大减少训练时间。你只需加载预训练的权重,并在猫狗数据集上微调模型,这样不仅节省了时间,还能获得更好的效果。

微调代码示例
import tensorflow as tf
from tensorflow.keras import layers, models# 创建基础模型结构
base_model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')  # 假设有10个类别
])# 加载预训练权重
base_model.load_weights('pretrained_weights.h5')# 冻结部分层
for layer in base_model.layers[:-1]:layer.trainable = False# 编译并微调模型
base_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 微调训练
base_model.fit(new_data, new_labels, epochs=10)

解释

  • 加载预训练权重:通过 load_weights,你可以将已有的知识应用到新的任务中,避免从零开始训练模型。
  • 冻结层:冻结部分层的目的是保留预训练模型中已经学到的通用特征,仅微调特定的几层以适应新任务。这就像是在一个已经建好的摩天大楼里装修几层,改造成你需要的样子,而不是重新建造整栋大楼。

通过这种方式,你可以在保持预训练模型中有用特征的同时,快速适应新的任务或数据集。对于刚入门的小白来说,这是一种高效且实用的策略。

应对训练中断:如何保存和恢复模型

在训练模型的过程中,意外总是难以避免。系统崩溃、断电、内存不足等问题可能随时出现,这就像是在建造摩天大楼时,突然遇到大风暴,导致施工中断。那么,如何在中断后快速恢复呢?定期保存模型的权重是一个明智的选择。

保存权重

通过 TensorFlow 的 ModelCheckpoint 回调函数,你可以定期保存模型的权重,确保即使训练中断,也能从上次保存的进度继续。

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint# 设置检查点保存路径
checkpoint_path = "training_checkpoints/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)# 创建一个 ModelCheckpoint 回调函数
cp_callback = ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,verbose=1,save_freq=600)  # 每处理 600 个输入样本保存一次# 训练模型,同时保存检查点
model.fit(train_data, train_labels, epochs=10, callbacks=[cp_callback])

解释

  • 定期保存ModelCheckpoint 可以帮助你定期保存模型的权重,就像是在建造摩天大楼时,每隔一段时间都拍张照片保存进度。这样,即使突然下雨或者停电,你也可以在天气恢复后继续施工。
恢复训练

当训练中断时,你可以从最新的检查点恢复训练,而不必从头开始。这不仅节省了时间,也让你不会因为突发事件而感到沮丧。

# 查找最新的检查点
latest = tf.train.latest_checkpoint(checkpoint_dir)# 如果有检查点,加载权重
if latest:model.load_weights(latest)print(f"Loaded weights from checkpoint: {latest}")
else:print("No checkpoint found. Starting from scratch.")# 继续训练
model.fit(train_data, train_labels, epochs=10, initial_epoch=int(latest.split('-')[-1].split('.')[0]) if latest else 0,callbacks=[cp_callback])

解释

  • 从检查点恢复:通过 tf.train.latest_checkpoint 函数,你可以找到最近的检查点,并通过 load_weights 恢复模型的状态,从中断的地方继续训练。这就像是你在大风暴后回到工地,拿出之前保存的施工进度照片,继续建造摩天大楼。
保存与加载完整模型:从开发到生产的无缝衔接

在训练完成后,我们不仅需要保存权重,还可能需要保存整个模型。这样做的目的是为了方便模型的部署和迁移。通过保存整个模型,你可以在不同的环境中无缝地加载和使用它。

保存完整模型
# 保存整个模型
model.save('my_full_model.h5')
加载完整模型
from tensorflow.keras.models import load_model# 加载完整模型
model = load_model('my_full_model.h5')# 继续训练或进行推理
predictions = model.predict(test_data)

解释

  • 保存整个模型:使用 save_model 可以将模型的结构、权重以及优化器状态一并保存,确保模型在不同环境中的一致性。
  • 加载整个模型load_model 允许你在任何支持 TensorFlow/Keras 的环境中重新加载并使用这个模型,无论是继续训练还是部署到生产环境。
超大型语言模型的微调

当我们面对像 LLaMA3 这样超大型的语言模型时,微调过程就更加复杂。由于这些模型的参数量巨大,通常我们不会直接微调所有参数,而是使用**参数高效微调(PEFT)**技术,如 LoRA 或 Adapter。这些技术允许我们通过调整少量参数来适应新的任务,既降低了计算资源的需求,又保证了模型的性能。

LoRA 微调示例
from peft import LoraConfig, get_peft_model# 配置LoRA
lora_config = LoraConfig(r=4,  # 低秩矩阵的秩lora_alpha=16,  # LoRA的缩放因子target_modules=["q_proj", "v_proj"],  # 在注意力层中应用LoRAlora_dropout=0.1,  # LoRA的dropout率
)# 将LoRA应用到模型
model = get_peft_model(model, lora_config)# 开始训练
trainer.train()

解释

  • LoRA 技术:LoRA 通过在模型的特定矩阵上应用低秩分解,实现参数高效微调。这就像是在摩天大楼的某些关键部位加固,从而确保建筑在更高负

载下依然稳固。

最佳实践与总结
  1. 定期保存:在训练过程中,使用 ModelCheckpoint 定期保存权重,防止因中断而丢失进度。
  2. 选择合适的保存方法:在开发过程中,可以使用 save_weights 进行频繁保存;在部署前,使用 save_model 保存整个模型。
  3. 恢复训练:通过 load_weights,你可以轻松恢复训练进度,并在中断后继续模型的优化。
  4. 参数高效微调:对于超大型模型,使用 LoRA 或 Adapter 等技术进行参数高效微调,可以大幅降低资源需求。

通过掌握这些技术,你不仅可以确保模型训练的稳健性,还能有效应对实际开发中的各种挑战。无论是微调预训练模型,还是处理不可预知的中断,load_weightssave_model 都将成为你深度学习开发中的利器。最佳实践与总结


这篇关于【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1121003

相关文章

Python中反转字符串的常见方法小结

《Python中反转字符串的常见方法小结》在Python中,字符串对象没有内置的反转方法,然而,在实际开发中,我们经常会遇到需要反转字符串的场景,比如处理回文字符串、文本加密等,因此,掌握如何在Pyt... 目录python中反转字符串的方法技术背景实现步骤1. 使用切片2. 使用 reversed() 函

Python中将嵌套列表扁平化的多种实现方法

《Python中将嵌套列表扁平化的多种实现方法》在Python编程中,我们常常会遇到需要将嵌套列表(即列表中包含列表)转换为一个一维的扁平列表的需求,本文将给大家介绍了多种实现这一目标的方法,需要的朋... 目录python中将嵌套列表扁平化的方法技术背景实现步骤1. 使用嵌套列表推导式2. 使用itert

Python使用pip工具实现包自动更新的多种方法

《Python使用pip工具实现包自动更新的多种方法》本文深入探讨了使用Python的pip工具实现包自动更新的各种方法和技术,我们将从基础概念开始,逐步介绍手动更新方法、自动化脚本编写、结合CI/C... 目录1. 背景介绍1.1 目的和范围1.2 预期读者1.3 文档结构概述1.4 术语表1.4.1 核

在Linux中改变echo输出颜色的实现方法

《在Linux中改变echo输出颜色的实现方法》在Linux系统的命令行环境下,为了使输出信息更加清晰、突出,便于用户快速识别和区分不同类型的信息,常常需要改变echo命令的输出颜色,所以本文给大家介... 目python录在linux中改变echo输出颜色的方法技术背景实现步骤使用ANSI转义码使用tpu

Conda与Python venv虚拟环境的区别与使用方法详解

《Conda与Pythonvenv虚拟环境的区别与使用方法详解》随着Python社区的成长,虚拟环境的概念和技术也在不断发展,:本文主要介绍Conda与Pythonvenv虚拟环境的区别与使用... 目录前言一、Conda 与 python venv 的核心区别1. Conda 的特点2. Python v

Spring Boot中WebSocket常用使用方法详解

《SpringBoot中WebSocket常用使用方法详解》本文从WebSocket的基础概念出发,详细介绍了SpringBoot集成WebSocket的步骤,并重点讲解了常用的使用方法,包括简单消... 目录一、WebSocket基础概念1.1 什么是WebSocket1.2 WebSocket与HTTP

SQL Server配置管理器无法打开的四种解决方法

《SQLServer配置管理器无法打开的四种解决方法》本文总结了SQLServer配置管理器无法打开的四种解决方法,文中通过图文示例介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录方法一:桌面图标进入方法二:运行窗口进入检查版本号对照表php方法三:查找文件路径方法四:检查 S

MyBatis-Plus 中 nested() 与 and() 方法详解(最佳实践场景)

《MyBatis-Plus中nested()与and()方法详解(最佳实践场景)》在MyBatis-Plus的条件构造器中,nested()和and()都是用于构建复杂查询条件的关键方法,但... 目录MyBATis-Plus 中nested()与and()方法详解一、核心区别对比二、方法详解1.and()

golang中reflect包的常用方法

《golang中reflect包的常用方法》Go反射reflect包提供类型和值方法,用于获取类型信息、访问字段、调用方法等,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值... 目录reflect包方法总结类型 (Type) 方法值 (Value) 方法reflect包方法总结

C# 比较两个list 之间元素差异的常用方法

《C#比较两个list之间元素差异的常用方法》:本文主要介绍C#比较两个list之间元素差异,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. 使用Except方法2. 使用Except的逆操作3. 使用LINQ的Join,GroupJoin