【深度学习实战(30)】训练框架之使用tensorboard记录loss

2024-05-02 06:36

本文主要是介绍【深度学习实战(30)】训练框架之使用tensorboard记录loss,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、 安装Tensorboard库

pip install tensorflow 
pip install tensorboardx

二、LossHistory类实现过程

1. init构造函数
传入参数log保存路径,模型,模型输入尺寸

def __init__(self, log_dir, model, input_shape):

实例化SummaryWriter对象

self.writer     = SummaryWriter(self.log_dir)
  1. tensorboard.SummaryWriter.add_graph记录model
 try:dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])self.writer.add_graph(model, dummy_input)except:pass

训练结束后查看保存的模型
在这里插入图片描述
在这里插入图片描述

  1. 记录loss
self.losses.append(loss)
self.val_loss.append(val_loss)
  1. txt文档记录loss
with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")
with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:f.write(str(val_loss))f.write("\n")

训练完后,查看epoch_loss.txtepoch_val_loss.txt
在这里插入图片描述
在这里插入图片描述

  1. tensorboard.SummaryWriter.add_scalar记录loss
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)

训练结束后查看保存的loss
在这里插入图片描述
在这里插入图片描述

  1. pyplot绘制loss曲线图
def loss_plot(self):iters = range(len(self.losses))plt.figure()plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')try:if len(self.losses) < 25:num = 5else:num = 15plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')except:passplt.grid(True)plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc="upper right")plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))plt.cla()plt.close("all")

查看loss曲线图
在这里插入图片描述

三、LossHistory类完整代码

import os
import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signalclass LossHistory():def __init__(self, log_dir, model, input_shape):self.log_dir    = log_dirself.losses     = []self.val_loss   = []os.makedirs(self.log_dir)self.writer     = SummaryWriter(self.log_dir)try:# --------- 1. tensorboard.SummaryWriter.add_graph记录model -------------#dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1],use_strict_trace=False)self.writer.add_graph(model, dummy_input)except:passdef append_loss(self, epoch, loss, val_loss):if not os.path.exists(self.log_dir):os.makedirs(self.log_dir)# --------- 2. 保存loss -------------#self.losses.append(loss)self.val_loss.append(val_loss)# --------- 3. txt记录loss -------------#with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:f.write(str(val_loss))f.write("\n")# --------- 4. tensorboard.SummaryWriter.add_scalar记录loss -------------#self.writer.add_scalar('loss', loss, epoch)self.writer.add_scalar('val_loss', val_loss, epoch)self.loss_plot()# --------- 5. pyplot绘制loss曲线图 -------------#def loss_plot(self):iters = range(len(self.losses))plt.figure()plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')try:if len(self.losses) < 25:num = 5else:num = 15plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')except:passplt.grid(True)plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc="upper right")plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))plt.cla()plt.close("all")

四、LossHistory类使用框架

import LossHistory# 构造loss_history类
loss_history = LossHistory(log_dir, model, (input_W, input_H))# 训练一轮,将训练,验证损失传进loss_history类
loss_history.append_loss(epoch + 1, train_loss / epoch_step, val_loss / epoch_step_val)# 根据loss_history类中保存的loss来保存最佳模型
if len(loss_history_.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history_.val_loss):best_ckpt = {'epoch': epoch, 'model': save_state_dict, 'optimizer': optimizer.state_dict(), 'loss':val_loss}torch.save(best_ckpt, os.path.join(save_dir, name_best_weights))# 训练一轮结束后,关闭Tensorboard.SummryWriter
loss_history.writer.close()

这篇关于【深度学习实战(30)】训练框架之使用tensorboard记录loss的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/953596

相关文章

使用animation.css库快速实现CSS3旋转动画效果

《使用animation.css库快速实现CSS3旋转动画效果》随着Web技术的不断发展,动画效果已经成为了网页设计中不可或缺的一部分,本文将深入探讨animation.css的工作原理,如何使用以及... 目录1. css3动画技术简介2. animation.css库介绍2.1 animation.cs

使用雪花算法产生id导致前端精度缺失问题解决方案

《使用雪花算法产生id导致前端精度缺失问题解决方案》雪花算法由Twitter提出,设计目的是生成唯一的、递增的ID,下面:本文主要介绍使用雪花算法产生id导致前端精度缺失问题的解决方案,文中通过代... 目录一、问题根源二、解决方案1. 全局配置Jackson序列化规则2. 实体类必须使用Long封装类3.

Python文件操作与IO流的使用方式

《Python文件操作与IO流的使用方式》:本文主要介绍Python文件操作与IO流的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、python文件操作基础1. 打开文件2. 关闭文件二、文件读写操作1.www.chinasem.cn 读取文件2. 写

SpringBoot实现接口数据加解密的三种实战方案

《SpringBoot实现接口数据加解密的三种实战方案》在金融支付、用户隐私信息传输等场景中,接口数据若以明文传输,极易被中间人攻击窃取,SpringBoot提供了多种优雅的加解密实现方案,本文将从原... 目录一、为什么需要接口数据加解密?二、核心加解密算法选择1. 对称加密(AES)2. 非对称加密(R

PyQt6中QMainWindow组件的使用详解

《PyQt6中QMainWindow组件的使用详解》QMainWindow是PyQt6中用于构建桌面应用程序的基础组件,本文主要介绍了PyQt6中QMainWindow组件的使用,具有一定的参考价值,... 目录1. QMainWindow 组php件概述2. 使用 QMainWindow3. QMainW

使用Python自动化生成PPT并结合LLM生成内容的代码解析

《使用Python自动化生成PPT并结合LLM生成内容的代码解析》PowerPoint是常用的文档工具,但手动设计和排版耗时耗力,本文将展示如何通过Python自动化提取PPT样式并生成新PPT,同时... 目录核心代码解析1. 提取 PPT 样式到 jsON关键步骤:代码片段:2. 应用 JSON 样式到

java变量内存中存储的使用方式

《java变量内存中存储的使用方式》:本文主要介绍java变量内存中存储的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍2、变量的定义3、 变量的类型4、 变量的作用域5、 内存中的存储方式总结1、介绍在 Java 中,变量是用于存储程序中数据

关于Mybatis和JDBC的使用及区别

《关于Mybatis和JDBC的使用及区别》:本文主要介绍关于Mybatis和JDBC的使用及区别,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、JDBC1.1、流程1.2、优缺点2、MyBATis2.1、执行流程2.2、使用2.3、实现方式1、XML配置文件

Maven 插件配置分层架构深度解析

《Maven插件配置分层架构深度解析》:本文主要介绍Maven插件配置分层架构深度解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录Maven 插件配置分层架构深度解析引言:当构建逻辑遇上复杂配置第一章 Maven插件配置的三重境界1.1 插件配置的拓扑

macOS Sequoia 15.5 发布: 改进邮件和屏幕使用时间功能

《macOSSequoia15.5发布:改进邮件和屏幕使用时间功能》经过常规Beta测试后,新的macOSSequoia15.5现已公开发布,但重要的新功能将被保留到WWDC和... MACOS Sequoia 15.5 正式发布!本次更新为 Mac 用户带来了一系列功能强化、错误修复和安全性提升,进一步增