easy-Fpn源码解读(二):train

2024-04-02 05:08
文章标签 源码 解读 train easy fpn

本文主要是介绍easy-Fpn源码解读(二):train,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • easy-Fpn源码解读(二):train
    • train.py完整代码
    • 代码解析

easy-Fpn源码解读(二):train

train.py完整代码

import argparse
import os
import time
import uuid
from collections import deque
from typing import Optionalfrom tensorboardX import SummaryWriter
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoaderfrom backbone.base import Base as BackboneBase
from config.train_config import TrainConfig as Config
from dataset.base import Base as DatasetBase
from logger import Logger as Log
from model import Model
from roi.wrapper import Wrapper as ROIWrapperdef _train(dataset_name: str, backbone_name: str, path_to_data_dir: str, path_to_checkpoints_dir: str, path_to_resuming_checkpoint: Optional[str]):dataset = DatasetBase.from_name(dataset_name)(path_to_data_dir, DatasetBase.Mode.TRAIN, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE)dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)Log.i('Found {:d} samples'.format(len(dataset)))backbone = BackboneBase.from_name(backbone_name)(pretrained=True)model = Model(backbone, dataset.num_classes(), pooling_mode=Config.POOLING_MODE,anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES,rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda()optimizer = optim.SGD(model.parameters(), lr=Config.LEARNING_RATE,momentum=Config.MOMENTUM, weight_decay=Config.WEIGHT_DECAY)scheduler = MultiStepLR(optimizer, milestones=Config.STEP_LR_SIZES, gamma=Config.STEP_LR_GAMMA)step = 0time_checkpoint = time.time()losses = deque(maxlen=100)summary_writer = SummaryWriter(os.path.join(path_to_checkpoints_dir, 'summaries'))should_stop = Falsenum_steps_to_display = Config.NUM_STEPS_TO_DISPLAYnum_steps_to_snapshot = Config.NUM_STEPS_TO_SNAPSHOTnum_steps_to_finish = Config.NUM_STEPS_TO_FINISHif path_to_resuming_checkpoint is not None:step = model.load(path_to_resuming_checkpoint, optimizer, scheduler)Log.i(f'Model has been restored from file: {path_to_resuming_checkpoint}')Log.i('Start training')while not should_stop:for batch_index, (_, image_batch, _, bboxes_batch, labels_batch) in enumerate(dataloader):assert image_batch.shape[0] == 1, 'only batch size of 1 is supported'image = image_batch[0].cuda()bboxes = bboxes_batch[0].cuda()labels = labels_batch[0].cuda()forward_input = Model.ForwardInput.Train(image, gt_classes=labels, gt_bboxes=bboxes)forward_output: Model.ForwardOutput.Train = model.train().forward(forward_input)anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss = forward_outputloss = anchor_objectness_loss + anchor_transformer_loss + proposal_class_loss + proposal_transformer_lossoptimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()losses.append(loss.item())summary_writer.add_scalar('train/anchor_objectness_loss'

这篇关于easy-Fpn源码解读(二):train的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/869042

相关文章

Linux jq命令的使用解读

《Linuxjq命令的使用解读》jq是一个强大的命令行工具,用于处理JSON数据,它可以用来查看、过滤、修改、格式化JSON数据,通过使用各种选项和过滤器,可以实现复杂的JSON处理任务... 目录一. 简介二. 选项2.1.2.2-c2.3-r2.4-R三. 字段提取3.1 普通字段3.2 数组字段四.

MySQL之搜索引擎使用解读

《MySQL之搜索引擎使用解读》MySQL存储引擎是数据存储和管理的核心组件,不同引擎(如InnoDB、MyISAM)采用不同机制,InnoDB支持事务与行锁,适合高并发场景;MyISAM不支持事务,... 目录mysql的存储引擎是什么MySQL存储引擎的功能MySQL的存储引擎的分类查看存储引擎1.命令

Spring的基础事务注解@Transactional作用解读

《Spring的基础事务注解@Transactional作用解读》文章介绍了Spring框架中的事务管理,核心注解@Transactional用于声明事务,支持传播机制、隔离级别等配置,结合@Tran... 目录一、事务管理基础1.1 Spring事务的核心注解1.2 注解属性详解1.3 实现原理二、事务事

Linux五种IO模型的使用解读

《Linux五种IO模型的使用解读》文章系统解析了Linux的五种IO模型(阻塞、非阻塞、IO复用、信号驱动、异步),重点区分同步与异步IO的本质差异,强调同步由用户发起,异步由内核触发,通过对比各模... 目录1.IO模型简介2.五种IO模型2.1 IO模型分析方法2.2 阻塞IO2.3 非阻塞IO2.4

MySQL8.0临时表空间的使用及解读

《MySQL8.0临时表空间的使用及解读》MySQL8.0+引入会话级(temp_N.ibt)和全局(ibtmp1)InnoDB临时表空间,用于存储临时数据及事务日志,自动创建与回收,重启释放,管理高... 目录一、核心概念:为什么需要“临时表空间”?二、InnoDB 临时表空间的两种类型1. 会话级临时表

java 恺撒加密/解密实现原理(附带源码)

《java恺撒加密/解密实现原理(附带源码)》本文介绍Java实现恺撒加密与解密,通过固定位移量对字母进行循环替换,保留大小写及非字母字符,由于其实现简单、易于理解,恺撒加密常被用作学习加密算法的入... 目录Java 恺撒加密/解密实现1. 项目背景与介绍2. 相关知识2.1 恺撒加密算法原理2.2 Ja

Nginx屏蔽服务器名称与版本信息方式(源码级修改)

《Nginx屏蔽服务器名称与版本信息方式(源码级修改)》本文详解如何通过源码修改Nginx1.25.4,移除Server响应头中的服务类型和版本信息,以增强安全性,需重新配置、编译、安装,升级时需重复... 目录一、背景与目的二、适用版本三、操作步骤修改源码文件四、后续操作提示五、注意事项六、总结一、背景与

Android实现图片浏览功能的示例详解(附带源码)

《Android实现图片浏览功能的示例详解(附带源码)》在许多应用中,都需要展示图片并支持用户进行浏览,本文主要为大家介绍了如何通过Android实现图片浏览功能,感兴趣的小伙伴可以跟随小编一起学习一... 目录一、项目背景详细介绍二、项目需求详细介绍三、相关技术详细介绍四、实现思路详细介绍五、完整实现代码

C语言自定义类型之联合和枚举解读

《C语言自定义类型之联合和枚举解读》联合体共享内存,大小由最大成员决定,遵循对齐规则;枚举类型列举可能值,提升可读性和类型安全性,两者在C语言中用于优化内存和程序效率... 目录一、联合体1.1 联合体类型的声明1.2 联合体的特点1.2.1 特点11.2.2 特点21.2.3 特点31.3 联合体的大小1

Python标准库datetime模块日期和时间数据类型解读

《Python标准库datetime模块日期和时间数据类型解读》文章介绍Python中datetime模块的date、time、datetime类,用于处理日期、时间及日期时间结合体,通过属性获取时间... 目录Datetime常用类日期date类型使用时间 time 类型使用日期和时间的结合体–日期时间(