easy-Fpn源码解读(二):train

2024-04-02 05:08
文章标签 源码 解读 train easy fpn

本文主要是介绍easy-Fpn源码解读(二):train,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • easy-Fpn源码解读(二):train
    • train.py完整代码
    • 代码解析

easy-Fpn源码解读(二):train

train.py完整代码

import argparse
import os
import time
import uuid
from collections import deque
from typing import Optionalfrom tensorboardX import SummaryWriter
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoaderfrom backbone.base import Base as BackboneBase
from config.train_config import TrainConfig as Config
from dataset.base import Base as DatasetBase
from logger import Logger as Log
from model import Model
from roi.wrapper import Wrapper as ROIWrapperdef _train(dataset_name: str, backbone_name: str, path_to_data_dir: str, path_to_checkpoints_dir: str, path_to_resuming_checkpoint: Optional[str]):dataset = DatasetBase.from_name(dataset_name)(path_to_data_dir, DatasetBase.Mode.TRAIN, Config.IMAGE_MIN_SIDE, Config.IMAGE_MAX_SIDE)dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)Log.i('Found {:d} samples'.format(len(dataset)))backbone = BackboneBase.from_name(backbone_name)(pretrained=True)model = Model(backbone, dataset.num_classes(), pooling_mode=Config.POOLING_MODE,anchor_ratios=Config.ANCHOR_RATIOS, anchor_scales=Config.ANCHOR_SCALES,rpn_pre_nms_top_n=Config.RPN_PRE_NMS_TOP_N, rpn_post_nms_top_n=Config.RPN_POST_NMS_TOP_N).cuda()optimizer = optim.SGD(model.parameters(), lr=Config.LEARNING_RATE,momentum=Config.MOMENTUM, weight_decay=Config.WEIGHT_DECAY)scheduler = MultiStepLR(optimizer, milestones=Config.STEP_LR_SIZES, gamma=Config.STEP_LR_GAMMA)step = 0time_checkpoint = time.time()losses = deque(maxlen=100)summary_writer = SummaryWriter(os.path.join(path_to_checkpoints_dir, 'summaries'))should_stop = Falsenum_steps_to_display = Config.NUM_STEPS_TO_DISPLAYnum_steps_to_snapshot = Config.NUM_STEPS_TO_SNAPSHOTnum_steps_to_finish = Config.NUM_STEPS_TO_FINISHif path_to_resuming_checkpoint is not None:step = model.load(path_to_resuming_checkpoint, optimizer, scheduler)Log.i(f'Model has been restored from file: {path_to_resuming_checkpoint}')Log.i('Start training')while not should_stop:for batch_index, (_, image_batch, _, bboxes_batch, labels_batch) in enumerate(dataloader):assert image_batch.shape[0] == 1, 'only batch size of 1 is supported'image = image_batch[0].cuda()bboxes = bboxes_batch[0].cuda()labels = labels_batch[0].cuda()forward_input = Model.ForwardInput.Train(image, gt_classes=labels, gt_bboxes=bboxes)forward_output: Model.ForwardOutput.Train = model.train().forward(forward_input)anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss = forward_outputloss = anchor_objectness_loss + anchor_transformer_loss + proposal_class_loss + proposal_transformer_lossoptimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()losses.append(loss.item())summary_writer.add_scalar('train/anchor_objectness_loss'

这篇关于easy-Fpn源码解读(二):train的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/869042

相关文章

C++类和对象之默认成员函数的使用解读

《C++类和对象之默认成员函数的使用解读》:本文主要介绍C++类和对象之默认成员函数的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、默认成员函数有哪些二、各默认成员函数详解默认构造函数析构函数拷贝构造函数拷贝赋值运算符三、默认成员函数的注意事项总结一

MySQL的ALTER TABLE命令的使用解读

《MySQL的ALTERTABLE命令的使用解读》:本文主要介绍MySQL的ALTERTABLE命令的使用,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、查看所建表的编China编程码格式2、修改表的编码格式3、修改列队数据类型4、添加列5、修改列的位置5.1、把列

Linux CPU飙升排查五步法解读

《LinuxCPU飙升排查五步法解读》:本文主要介绍LinuxCPU飙升排查五步法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录排查思路-五步法1. top命令定位应用进程pid2.php top-Hp[pid]定位应用进程对应的线程tid3. printf"%

解读@ConfigurationProperties和@value的区别

《解读@ConfigurationProperties和@value的区别》:本文主要介绍@ConfigurationProperties和@value的区别及说明,具有很好的参考价值,希望对大家... 目录1. 功能对比2. 使用场景对比@ConfigurationProperties@Value3. 核

8种快速易用的Python Matplotlib数据可视化方法汇总(附源码)

《8种快速易用的PythonMatplotlib数据可视化方法汇总(附源码)》你是否曾经面对一堆复杂的数据,却不知道如何让它们变得直观易懂?别慌,Python的Matplotlib库是你数据可视化的... 目录引言1. 折线图(Line Plot)——趋势分析2. 柱状图(Bar Chart)——对比分析3

Jupyter notebook安装步骤解读

《Jupyternotebook安装步骤解读》:本文主要介绍Jupyternotebook安装步骤,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、开始安装二、更改打开文件位置和快捷启动方式总结在安装Jupyter notebook 之前,确认您已安装pytho

Java中的StringUtils.isBlank()方法解读

《Java中的StringUtils.isBlank()方法解读》:本文主要介绍Java中的StringUtils.isBlank()方法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑... 目录所在库及依赖引入方法签名方法功能示例代码代码解释与其他方法的对比总结StringUtils.isBl

对Django中时区的解读

《对Django中时区的解读》:本文主要介绍对Django中时区的解读方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录背景前端数据库中存储接口返回AI的解释问题:这样设置的作用答案获取当前时间(自动带时区)转换为北京时间显示总结背景设置时区为北京时间 TIM

Java中的内部类和常用类用法解读

《Java中的内部类和常用类用法解读》:本文主要介绍Java中的内部类和常用类用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录内部类和常用类内部类成员内部类静态内部类局部内部类匿名内部类常用类Object类包装类String类StringBuffer和Stri

使用easy connect之后,maven无法使用,原来需要配置-Djava.net.preferIPv4Stack=true问题

《使用easyconnect之后,maven无法使用,原来需要配置-Djava.net.preferIPv4Stack=true问题》:本文主要介绍使用easyconnect之后,maven无法... 目录使用easGWowCy connect之后,maven无法使用,原来需要配置-DJava.net.pr