【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py

2024-03-17 23:52

本文主要是介绍【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文件内容:CenterFusion/src/lib/model/model.py
文件作用:模型的创建、导入、保存

model.py 具体内容如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport torchvision.models as models
import torch
import torch.nn as nn
import osfrom .networks.dla import DLASeg
from .networks.resdcn import PoseResDCN
from .networks.resnet import PoseResNet
from .networks.dlav0 import DLASegv0
from .networks.generic_network import GenericNetwork_network_factory = {'resdcn': PoseResDCN,'dla': DLASeg,'res': PoseResNet,'dlav0': DLASegv0,'generic': GenericNetwork
}def create_model(arch, head, head_conv, opt=None):num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0'''处理字符串 arch = dla_34 ,将下划线后半部分取出最后 num_layers = 34'''arch = arch[:arch.find('_')] if '_' in arch else arch'''将 arch = dla_34 中下划线前半部分取出最后 arch = 'dla''''model_class = _network_factory[arch]'''根据 arch = 'dla' 获取 _network_factory 中的值最后 model_class = DLASegDLASeg 类定义在 CenterFusion/src/lib/model/networks/dla.py 第 594 行'''model = model_class(num_layers, heads=head, head_convs=head_conv, opt=opt)'''配置模型'''return modeldef load_model(model, model_path, opt, optimizer=None):start_epoch = 0'''设定初始轮次 = 0'''checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))'''torch.load() 函数:用来加载 torch.save() 保存的模型文件'''state_dict_ = checkpoint['state_dict']'''获取 checkpoint 模型文件中的 state_dict 属性这个属性存放训练过程中需要学习的权重和偏执系数state_dict 作为 python 的字典对象将每一层的参数映射成 tensor 张量需要注意的是 torch.nn.Module 模块中的 state_dict 只包含卷积层和全连接层的参数'''state_dict = {}for k in state_dict_:if k.startswith('module') and not k.startswith('module_list'):state_dict[k[7:]] = state_dict_[k]else:state_dict[k] = state_dict_[k]'''startswith(str) 函数:检测字符串 str,检测到返回 True,否则返回 False这里只执行了 else 语句,相当于保存导入模型的网络参数'''model_state_dict = model.state_dict()'''浅拷贝 main.py 中创建的新模型 DLA 的网络参数'''for k in state_dict:'''遍历导入的模型中的每层网络参数'''if k in model_state_dict:'''判断新模型的网络参数中是否有导入的模型的参数是有的,因为导入的模型也是 DLA 模型'''if (state_dict[k].shape != model_state_dict[k].shape) or \(opt.reset_hm and k.startswith('hm') and (state_dict[k].shape[0] in [80, 1])):'''第一个条件为 True其余条件全部为 False'''if opt.reuse_hm:'''不执行'''print('Reusing parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))# todo: bug in next line: both sides of < are the sameif state_dict[k].shape[0] < state_dict[k].shape[0]:model_state_dict[k][:state_dict[k].shape[0]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:model_state_dict[k].shape[0]]state_dict[k] = model_state_dict[k]elif opt.warm_start_weights:'''不执行'''try:print('Partially loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))if state_dict[k].shape[1] < model_state_dict[k].shape[1]:model_state_dict[k][:,:state_dict[k].shape[1]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:,:model_state_dict[k].shape[1]]state_dict[k] = model_state_dict[k]except:print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]else:'''执行该 else 中的语句'''print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]'''将新模型的网络参数赋值给导入的模型中'''else:print('Drop parameter {}.'.format(k))for k in model_state_dict:if not (k in state_dict):print('No param {}.'.format(k))state_dict[k] = model_state_dict[k]'''给导入的模型添加没有的参数'''model.load_state_dict(state_dict, strict=False)'''使用 state_dict 反序列化模型参数字字典,用来加载模型参数将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中简述:给模型对象加载训练好的模型参数,即加载模型参数'''#冻结骨干网,没有执行if opt.freeze_backbone:for (name, module) in model.named_children():if name in opt.layers_to_freeze:for (name, layer) in module.named_children():for param in layer.parameters():param.requires_grad = False# 恢复优化器参数,没有执行if optimizer is not None and opt.resume:if 'optimizer' in checkpoint:start_epoch = checkpoint['epoch']start_lr = opt.lrfor step in opt.lr_step:if start_epoch >= step:start_lr *= 0.1for param_group in optimizer.param_groups:param_group['lr'] = start_lrprint('Resumed optimizer with start lr', start_lr)else:print('No optimizer parameters in checkpoint.')if optimizer is not None:'''执行该 if 语句'''return model, optimizer, start_epochelse:return modeldef save_model(path, epoch, model, optimizer=None):if isinstance(model, torch.nn.DataParallel):'''isinstance(object, classinfo) 判断一个函数 object 是否是一个已知的类型 classinfo是则返回 True,反之返回 False'''state_dict = model.module.state_dict()else:state_dict = model.state_dict()'''获取模型的参数矩阵'''data = {'epoch': epoch,'state_dict': state_dict}if not (optimizer is None):data['optimizer'] = optimizer.state_dict()'''获取模型的优化器'''torch.save(data, path)'''保存模型'''

这篇关于【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/820587

相关文章

Spring创建Bean的八种主要方式详解

《Spring创建Bean的八种主要方式详解》Spring(尤其是SpringBoot)提供了多种方式来让容器创建和管理Bean,@Component、@Configuration+@Bean、@En... 目录引言一、Spring 创建 Bean 的 8 种主要方式1. @Component 及其衍生注解

MySQL 数据库表操作完全指南:创建、读取、更新与删除实战

《MySQL数据库表操作完全指南:创建、读取、更新与删除实战》本文系统讲解MySQL表的增删查改(CURD)操作,涵盖创建、更新、查询、删除及插入查询结果,也是贯穿各类项目开发全流程的基础数据交互原... 目录mysql系列前言一、Create(创建)并插入数据1.1 单行数据 + 全列插入1.2 多行数据

Java实现TXT文件导入功能的详细步骤

《Java实现TXT文件导入功能的详细步骤》在实际开发中,很多应用场景需要将用户上传的TXT文件进行解析,并将文件中的数据导入到数据库或其他存储系统中,本文将演示如何用Java实现一个基本的TXT文件... 目录前言1. 项目需求分析2. 示例文件格式3. 实现步骤3.1. 准备数据库(假设使用 mysql

MySQL 临时表创建与使用详细说明

《MySQL临时表创建与使用详细说明》MySQL临时表是存储在内存或磁盘的临时数据表,会话结束时自动销毁,适合存储中间计算结果或临时数据集,其名称以#开头(如#TempTable),本文给大家介绍M... 目录mysql 临时表详细说明1.定义2.核心特性3.创建与使用4.典型应用场景5.生命周期管理6.注

MySQL的触发器全解析(创建、查看触发器)

《MySQL的触发器全解析(创建、查看触发器)》MySQL触发器是与表关联的存储程序,当INSERT/UPDATE/DELETE事件发生时自动执行,用于维护数据一致性、日志记录和校验,优点包括自动执行... 目录触发器的概念:创建触www.chinasem.cn发器:查看触发器:查看当前数据库的所有触发器的定

创建springBoot模块没有目录结构的解决方案

《创建springBoot模块没有目录结构的解决方案》2023版IntelliJIDEA创建模块时可能出现目录结构识别错误,导致文件显示异常,解决方法为选择模块后点击确认,重新校准项目结构设置,确保源... 目录创建spChina编程ringBoot模块没有目录结构解决方案总结创建springBoot模块没有目录

解决Failed to get nested archive for entry BOOT-INF/lib/xxx.jar问题

《解决FailedtogetnestedarchiveforentryBOOT-INF/lib/xxx.jar问题》解决BOOT-INF/lib/xxx.jar替换异常需确保路径正确:解... 目录Failed to get nested archive for entry BOOT-INF/lib/xxx

SpringBoot集成EasyExcel实现百万级别的数据导入导出实践指南

《SpringBoot集成EasyExcel实现百万级别的数据导入导出实践指南》本文将基于开源项目springboot-easyexcel-batch进行解析与扩展,手把手教大家如何在SpringBo... 目录项目结构概览核心依赖百万级导出实战场景核心代码效果百万级导入实战场景监听器和Service(核心

批量导入txt数据到的redis过程

《批量导入txt数据到的redis过程》用户通过将Redis命令逐行写入txt文件,利用管道模式运行客户端,成功执行批量删除以Product*匹配的Key操作,提高了数据清理效率... 目录批量导入txt数据到Redisjs把redis命令按一条 一行写到txt中管道命令运行redis客户端成功了批量删除k

flask库中sessions.py的使用小结

《flask库中sessions.py的使用小结》在Flask中Session是一种用于在不同请求之间存储用户数据的机制,Session默认是基于客户端Cookie的,但数据会经过加密签名,防止篡改,... 目录1. Flask Session 的基本使用(1) 启用 Session(2) 存储和读取 Se