【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法

本文主要是介绍【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【ShuQiHere】

在深度学习的世界中,训练一个模型不仅需要时间,还需要大量的计算资源。比如,你已经花了几天时间训练一个模型,但突然间,电脑崩溃了,你的所有进度都丢失了。这种情况就像是在一场马拉松比赛的最后一公里摔倒,让人沮丧至极。那么,有没有什么方法可以避免这种悲剧呢?今天,我们就来聊聊如何通过保存和加载模型的权重来应对这些挑战,确保你在深度学习的旅程中不会白费功夫。

模型保存和加载的背景

训练一个深度学习模型就像建造一座摩天大楼。你需要从基础开始,一层层地搭建,最终完成一个复杂的系统。然而,建造过程中难免会遇到意外,比如断电、系统崩溃,甚至是代码错误。这些意外可能让你的努力前功尽弃。如果你不想每次意外发生后都从头开始,那么模型保存和加载就显得尤为重要。

在 TensorFlow 中,我们有两种主要的保存和加载方法:保存整个模型保存模型的权重。理解它们的区别和用法,就像学会了在建造摩天大楼时如何保存施工进度,确保即使遭遇突发事件,你的建筑工程也能顺利继续。

微调模型:从预训练到自定义任务

我们都知道,训练一个深度学习模型需要大量的数据和计算资源。幸运的是,深度学习社区里有很多预训练的模型,这些模型已经在大规模数据集上进行了训练。通过微调(Fine-Tuning),你可以利用这些预训练模型,在它们的基础上进行训练,快速适应新的任务。

场景:训练猫狗识别模型

比如,你想训练一个模型来区分猫和狗。如果从零开始训练,不仅费时费力,而且可能效果不佳。但如果你有一个在 ImageNet 上预训练的模型,就可以大大减少训练时间。你只需加载预训练的权重,并在猫狗数据集上微调模型,这样不仅节省了时间,还能获得更好的效果。

微调代码示例
import tensorflow as tf
from tensorflow.keras import layers, models# 创建基础模型结构
base_model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')  # 假设有10个类别
])# 加载预训练权重
base_model.load_weights('pretrained_weights.h5')# 冻结部分层
for layer in base_model.layers[:-1]:layer.trainable = False# 编译并微调模型
base_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 微调训练
base_model.fit(new_data, new_labels, epochs=10)

解释

  • 加载预训练权重:通过 load_weights,你可以将已有的知识应用到新的任务中,避免从零开始训练模型。
  • 冻结层:冻结部分层的目的是保留预训练模型中已经学到的通用特征,仅微调特定的几层以适应新任务。这就像是在一个已经建好的摩天大楼里装修几层,改造成你需要的样子,而不是重新建造整栋大楼。

通过这种方式,你可以在保持预训练模型中有用特征的同时,快速适应新的任务或数据集。对于刚入门的小白来说,这是一种高效且实用的策略。

应对训练中断:如何保存和恢复模型

在训练模型的过程中,意外总是难以避免。系统崩溃、断电、内存不足等问题可能随时出现,这就像是在建造摩天大楼时,突然遇到大风暴,导致施工中断。那么,如何在中断后快速恢复呢?定期保存模型的权重是一个明智的选择。

保存权重

通过 TensorFlow 的 ModelCheckpoint 回调函数,你可以定期保存模型的权重,确保即使训练中断,也能从上次保存的进度继续。

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint# 设置检查点保存路径
checkpoint_path = "training_checkpoints/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)# 创建一个 ModelCheckpoint 回调函数
cp_callback = ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,verbose=1,save_freq=600)  # 每处理 600 个输入样本保存一次# 训练模型,同时保存检查点
model.fit(train_data, train_labels, epochs=10, callbacks=[cp_callback])

解释

  • 定期保存ModelCheckpoint 可以帮助你定期保存模型的权重,就像是在建造摩天大楼时,每隔一段时间都拍张照片保存进度。这样,即使突然下雨或者停电,你也可以在天气恢复后继续施工。
恢复训练

当训练中断时,你可以从最新的检查点恢复训练,而不必从头开始。这不仅节省了时间,也让你不会因为突发事件而感到沮丧。

# 查找最新的检查点
latest = tf.train.latest_checkpoint(checkpoint_dir)# 如果有检查点,加载权重
if latest:model.load_weights(latest)print(f"Loaded weights from checkpoint: {latest}")
else:print("No checkpoint found. Starting from scratch.")# 继续训练
model.fit(train_data, train_labels, epochs=10, initial_epoch=int(latest.split('-')[-1].split('.')[0]) if latest else 0,callbacks=[cp_callback])

解释

  • 从检查点恢复:通过 tf.train.latest_checkpoint 函数,你可以找到最近的检查点,并通过 load_weights 恢复模型的状态,从中断的地方继续训练。这就像是你在大风暴后回到工地,拿出之前保存的施工进度照片,继续建造摩天大楼。
保存与加载完整模型:从开发到生产的无缝衔接

在训练完成后,我们不仅需要保存权重,还可能需要保存整个模型。这样做的目的是为了方便模型的部署和迁移。通过保存整个模型,你可以在不同的环境中无缝地加载和使用它。

保存完整模型
# 保存整个模型
model.save('my_full_model.h5')
加载完整模型
from tensorflow.keras.models import load_model# 加载完整模型
model = load_model('my_full_model.h5')# 继续训练或进行推理
predictions = model.predict(test_data)

解释

  • 保存整个模型:使用 save_model 可以将模型的结构、权重以及优化器状态一并保存,确保模型在不同环境中的一致性。
  • 加载整个模型load_model 允许你在任何支持 TensorFlow/Keras 的环境中重新加载并使用这个模型,无论是继续训练还是部署到生产环境。
超大型语言模型的微调

当我们面对像 LLaMA3 这样超大型的语言模型时,微调过程就更加复杂。由于这些模型的参数量巨大,通常我们不会直接微调所有参数,而是使用**参数高效微调(PEFT)**技术,如 LoRA 或 Adapter。这些技术允许我们通过调整少量参数来适应新的任务,既降低了计算资源的需求,又保证了模型的性能。

LoRA 微调示例
from peft import LoraConfig, get_peft_model# 配置LoRA
lora_config = LoraConfig(r=4,  # 低秩矩阵的秩lora_alpha=16,  # LoRA的缩放因子target_modules=["q_proj", "v_proj"],  # 在注意力层中应用LoRAlora_dropout=0.1,  # LoRA的dropout率
)# 将LoRA应用到模型
model = get_peft_model(model, lora_config)# 开始训练
trainer.train()

解释

  • LoRA 技术:LoRA 通过在模型的特定矩阵上应用低秩分解,实现参数高效微调。这就像是在摩天大楼的某些关键部位加固,从而确保建筑在更高负

载下依然稳固。

最佳实践与总结
  1. 定期保存:在训练过程中,使用 ModelCheckpoint 定期保存权重,防止因中断而丢失进度。
  2. 选择合适的保存方法:在开发过程中,可以使用 save_weights 进行频繁保存;在部署前,使用 save_model 保存整个模型。
  3. 恢复训练:通过 load_weights,你可以轻松恢复训练进度,并在中断后继续模型的优化。
  4. 参数高效微调:对于超大型模型,使用 LoRA 或 Adapter 等技术进行参数高效微调,可以大幅降低资源需求。

通过掌握这些技术,你不仅可以确保模型训练的稳健性,还能有效应对实际开发中的各种挑战。无论是微调预训练模型,还是处理不可预知的中断,load_weightssave_model 都将成为你深度学习开发中的利器。最佳实践与总结


这篇关于【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1121003

相关文章

Python中 try / except / else / finally 异常处理方法详解

《Python中try/except/else/finally异常处理方法详解》:本文主要介绍Python中try/except/else/finally异常处理方法的相关资料,涵... 目录1. 基本结构2. 各部分的作用tryexceptelsefinally3. 执行流程总结4. 常见用法(1)多个e

JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法

《JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法》:本文主要介绍JavaScript中比较两个数组是否有相同元素(交集)的三种常用方法,每种方法结合实例代码给大家介绍的非常... 目录引言:为什么"相等"判断如此重要?方法1:使用some()+includes()(适合小数组)方法2

504 Gateway Timeout网关超时的根源及完美解决方法

《504GatewayTimeout网关超时的根源及完美解决方法》在日常开发和运维过程中,504GatewayTimeout错误是常见的网络问题之一,尤其是在使用反向代理(如Nginx)或... 目录引言为什么会出现 504 错误?1. 探索 504 Gateway Timeout 错误的根源 1.1 后端

MySQL 表空却 ibd 文件过大的问题及解决方法

《MySQL表空却ibd文件过大的问题及解决方法》本文给大家介绍MySQL表空却ibd文件过大的问题及解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考... 目录一、问题背景:表空却 “吃满” 磁盘的怪事二、问题复现:一步步编程还原异常场景1. 准备测试源表与数据

python 线程池顺序执行的方法实现

《python线程池顺序执行的方法实现》在Python中,线程池默认是并发执行任务的,但若需要实现任务的顺序执行,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋... 目录方案一:强制单线程(伪顺序执行)方案二:按提交顺序获取结果方案三:任务间依赖控制方案四:队列顺序消

SpringBoot通过main方法启动web项目实践

《SpringBoot通过main方法启动web项目实践》SpringBoot通过SpringApplication.run()启动Web项目,自动推断应用类型,加载初始化器与监听器,配置Spring... 目录1. 启动入口:SpringApplication.run()2. SpringApplicat

从基础到进阶详解Python条件判断的实用指南

《从基础到进阶详解Python条件判断的实用指南》本文将通过15个实战案例,带你大家掌握条件判断的核心技巧,并从基础语法到高级应用一网打尽,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一... 目录​引言:条件判断为何如此重要一、基础语法:三行代码构建决策系统二、多条件分支:elif的魔法三、

使用Java读取本地文件并转换为MultipartFile对象的方法

《使用Java读取本地文件并转换为MultipartFile对象的方法》在许多JavaWeb应用中,我们经常会遇到将本地文件上传至服务器或其他系统的需求,在这种场景下,MultipartFile对象非... 目录1. 基本需求2. 自定义 MultipartFile 类3. 实现代码4. 代码解析5. 自定

Python文本相似度计算的方法大全

《Python文本相似度计算的方法大全》文本相似度是指两个文本在内容、结构或语义上的相近程度,通常用0到1之间的数值表示,0表示完全不同,1表示完全相同,本文将深入解析多种文本相似度计算方法,帮助您选... 目录前言什么是文本相似度?1. Levenshtein 距离(编辑距离)核心公式实现示例2. Jac

C#高效实现Word文档内容查找与替换的6种方法

《C#高效实现Word文档内容查找与替换的6种方法》在日常文档处理工作中,尤其是面对大型Word文档时,手动查找、替换文本往往既耗时又容易出错,本文整理了C#查找与替换Word内容的6种方法,大家可以... 目录环境准备方法一:查找文本并替换为新文本方法二:使用正则表达式查找并替换文本方法三:将文本替换为图