【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法

本文主要是介绍【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【ShuQiHere】

在深度学习的世界中,训练一个模型不仅需要时间,还需要大量的计算资源。比如,你已经花了几天时间训练一个模型,但突然间,电脑崩溃了,你的所有进度都丢失了。这种情况就像是在一场马拉松比赛的最后一公里摔倒,让人沮丧至极。那么,有没有什么方法可以避免这种悲剧呢?今天,我们就来聊聊如何通过保存和加载模型的权重来应对这些挑战,确保你在深度学习的旅程中不会白费功夫。

模型保存和加载的背景

训练一个深度学习模型就像建造一座摩天大楼。你需要从基础开始,一层层地搭建,最终完成一个复杂的系统。然而,建造过程中难免会遇到意外,比如断电、系统崩溃,甚至是代码错误。这些意外可能让你的努力前功尽弃。如果你不想每次意外发生后都从头开始,那么模型保存和加载就显得尤为重要。

在 TensorFlow 中,我们有两种主要的保存和加载方法:保存整个模型保存模型的权重。理解它们的区别和用法,就像学会了在建造摩天大楼时如何保存施工进度,确保即使遭遇突发事件,你的建筑工程也能顺利继续。

微调模型:从预训练到自定义任务

我们都知道,训练一个深度学习模型需要大量的数据和计算资源。幸运的是,深度学习社区里有很多预训练的模型,这些模型已经在大规模数据集上进行了训练。通过微调(Fine-Tuning),你可以利用这些预训练模型,在它们的基础上进行训练,快速适应新的任务。

场景:训练猫狗识别模型

比如,你想训练一个模型来区分猫和狗。如果从零开始训练,不仅费时费力,而且可能效果不佳。但如果你有一个在 ImageNet 上预训练的模型,就可以大大减少训练时间。你只需加载预训练的权重,并在猫狗数据集上微调模型,这样不仅节省了时间,还能获得更好的效果。

微调代码示例
import tensorflow as tf
from tensorflow.keras import layers, models# 创建基础模型结构
base_model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(128, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')  # 假设有10个类别
])# 加载预训练权重
base_model.load_weights('pretrained_weights.h5')# 冻结部分层
for layer in base_model.layers[:-1]:layer.trainable = False# 编译并微调模型
base_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 微调训练
base_model.fit(new_data, new_labels, epochs=10)

解释

  • 加载预训练权重:通过 load_weights,你可以将已有的知识应用到新的任务中,避免从零开始训练模型。
  • 冻结层:冻结部分层的目的是保留预训练模型中已经学到的通用特征,仅微调特定的几层以适应新任务。这就像是在一个已经建好的摩天大楼里装修几层,改造成你需要的样子,而不是重新建造整栋大楼。

通过这种方式,你可以在保持预训练模型中有用特征的同时,快速适应新的任务或数据集。对于刚入门的小白来说,这是一种高效且实用的策略。

应对训练中断:如何保存和恢复模型

在训练模型的过程中,意外总是难以避免。系统崩溃、断电、内存不足等问题可能随时出现,这就像是在建造摩天大楼时,突然遇到大风暴,导致施工中断。那么,如何在中断后快速恢复呢?定期保存模型的权重是一个明智的选择。

保存权重

通过 TensorFlow 的 ModelCheckpoint 回调函数,你可以定期保存模型的权重,确保即使训练中断,也能从上次保存的进度继续。

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint# 设置检查点保存路径
checkpoint_path = "training_checkpoints/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)# 创建一个 ModelCheckpoint 回调函数
cp_callback = ModelCheckpoint(filepath=checkpoint_path,save_weights_only=True,verbose=1,save_freq=600)  # 每处理 600 个输入样本保存一次# 训练模型,同时保存检查点
model.fit(train_data, train_labels, epochs=10, callbacks=[cp_callback])

解释

  • 定期保存ModelCheckpoint 可以帮助你定期保存模型的权重,就像是在建造摩天大楼时,每隔一段时间都拍张照片保存进度。这样,即使突然下雨或者停电,你也可以在天气恢复后继续施工。
恢复训练

当训练中断时,你可以从最新的检查点恢复训练,而不必从头开始。这不仅节省了时间,也让你不会因为突发事件而感到沮丧。

# 查找最新的检查点
latest = tf.train.latest_checkpoint(checkpoint_dir)# 如果有检查点,加载权重
if latest:model.load_weights(latest)print(f"Loaded weights from checkpoint: {latest}")
else:print("No checkpoint found. Starting from scratch.")# 继续训练
model.fit(train_data, train_labels, epochs=10, initial_epoch=int(latest.split('-')[-1].split('.')[0]) if latest else 0,callbacks=[cp_callback])

解释

  • 从检查点恢复:通过 tf.train.latest_checkpoint 函数,你可以找到最近的检查点,并通过 load_weights 恢复模型的状态,从中断的地方继续训练。这就像是你在大风暴后回到工地,拿出之前保存的施工进度照片,继续建造摩天大楼。
保存与加载完整模型:从开发到生产的无缝衔接

在训练完成后,我们不仅需要保存权重,还可能需要保存整个模型。这样做的目的是为了方便模型的部署和迁移。通过保存整个模型,你可以在不同的环境中无缝地加载和使用它。

保存完整模型
# 保存整个模型
model.save('my_full_model.h5')
加载完整模型
from tensorflow.keras.models import load_model# 加载完整模型
model = load_model('my_full_model.h5')# 继续训练或进行推理
predictions = model.predict(test_data)

解释

  • 保存整个模型:使用 save_model 可以将模型的结构、权重以及优化器状态一并保存,确保模型在不同环境中的一致性。
  • 加载整个模型load_model 允许你在任何支持 TensorFlow/Keras 的环境中重新加载并使用这个模型,无论是继续训练还是部署到生产环境。
超大型语言模型的微调

当我们面对像 LLaMA3 这样超大型的语言模型时,微调过程就更加复杂。由于这些模型的参数量巨大,通常我们不会直接微调所有参数,而是使用**参数高效微调(PEFT)**技术,如 LoRA 或 Adapter。这些技术允许我们通过调整少量参数来适应新的任务,既降低了计算资源的需求,又保证了模型的性能。

LoRA 微调示例
from peft import LoraConfig, get_peft_model# 配置LoRA
lora_config = LoraConfig(r=4,  # 低秩矩阵的秩lora_alpha=16,  # LoRA的缩放因子target_modules=["q_proj", "v_proj"],  # 在注意力层中应用LoRAlora_dropout=0.1,  # LoRA的dropout率
)# 将LoRA应用到模型
model = get_peft_model(model, lora_config)# 开始训练
trainer.train()

解释

  • LoRA 技术:LoRA 通过在模型的特定矩阵上应用低秩分解,实现参数高效微调。这就像是在摩天大楼的某些关键部位加固,从而确保建筑在更高负

载下依然稳固。

最佳实践与总结
  1. 定期保存:在训练过程中,使用 ModelCheckpoint 定期保存权重,防止因中断而丢失进度。
  2. 选择合适的保存方法:在开发过程中,可以使用 save_weights 进行频繁保存;在部署前,使用 save_model 保存整个模型。
  3. 恢复训练:通过 load_weights,你可以轻松恢复训练进度,并在中断后继续模型的优化。
  4. 参数高效微调:对于超大型模型,使用 LoRA 或 Adapter 等技术进行参数高效微调,可以大幅降低资源需求。

通过掌握这些技术,你不仅可以确保模型训练的稳健性,还能有效应对实际开发中的各种挑战。无论是微调预训练模型,还是处理不可预知的中断,load_weightssave_model 都将成为你深度学习开发中的利器。最佳实践与总结


这篇关于【ShuQiHere】微调与训练恢复:理解 `load_weights` 和 `save_model` 的实用方法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1121003

相关文章

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

SpringMVC 通过ajax 前后端数据交互的实现方法

《SpringMVC通过ajax前后端数据交互的实现方法》:本文主要介绍SpringMVC通过ajax前后端数据交互的实现方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价... 在前端的开发过程中,经常在html页面通过AJAX进行前后端数据的交互,SpringMVC的controll

Java中的工具类命名方法

《Java中的工具类命名方法》:本文主要介绍Java中的工具类究竟如何命名,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录Java中的工具类究竟如何命名?先来几个例子几种命名方式的比较到底如何命名 ?总结Java中的工具类究竟如何命名?先来几个例子JD

Spring Security自定义身份认证的实现方法

《SpringSecurity自定义身份认证的实现方法》:本文主要介绍SpringSecurity自定义身份认证的实现方法,下面对SpringSecurity的这三种自定义身份认证进行详细讲解,... 目录1.内存身份认证(1)创建配置类(2)验证内存身份认证2.JDBC身份认证(1)数据准备 (2)配置依

python获取网页表格的多种方法汇总

《python获取网页表格的多种方法汇总》我们在网页上看到很多的表格,如果要获取里面的数据或者转化成其他格式,就需要将表格获取下来并进行整理,在Python中,获取网页表格的方法有多种,下面就跟随小编... 目录1. 使用Pandas的read_html2. 使用BeautifulSoup和pandas3.

Spring 中的循环引用问题解决方法

《Spring中的循环引用问题解决方法》:本文主要介绍Spring中的循环引用问题解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录什么是循环引用?循环依赖三级缓存解决循环依赖二级缓存三级缓存本章来聊聊Spring 中的循环引用问题该如何解决。这里聊

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

Pandas统计每行数据中的空值的方法示例

《Pandas统计每行数据中的空值的方法示例》处理缺失数据(NaN值)是一个非常常见的问题,本文主要介绍了Pandas统计每行数据中的空值的方法示例,具有一定的参考价值,感兴趣的可以了解一下... 目录什么是空值?为什么要统计空值?准备工作创建示例数据统计每行空值数量进一步分析www.chinasem.cn处