【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py

2024-03-18 02:12

本文主要是介绍【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文件位置:CenterFusion/src/lib/trainer.py
run_epoch作用:CenterFusion 项目训练一轮epoch过程

  • 在 main.py 函数中,生成了训练器,然后再使用训练器训练一个 epoch
  • run_epoch()函数的定义在src\lib\trainer.py150行左右,它的主要过程如下所示:
  def run_epoch(self, phase, epoch, data_loader):model_with_loss = self.model_with_loss'''self.model_with_loss 是 ModelWithLoss 类,这个类又继承 torch.nn.Module 类'''if phase == 'train':model_with_loss.train()'''启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和 Dropout需要在训练时添加 model.train()model.train()是保证 BN 层能够用到每一批数据的均值和方差对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数'''else:if len(self.opt.gpus) > 1:model_with_loss = self.model_with_loss.modulemodel_with_loss.eval()'''不启用 Batch Normalization 和 Dropout如果模型中有 BN 层 (Batch Normalization) 和Dropout在测试时添加 model.eval()model.eval() 是保证 BN 层能够用全部训练数据的均值和方差即测试过程中要保证 BN 层的均值和方差不变对于 Dropout,model.eval() 是利用到了所有网络连接,即不进行随机舍弃神经元。'''torch.cuda.empty_cache()'''释放空间'''opt = self.optresults = {}data_time, batch_time = AverageMeter(), AverageMeter()'''新建两个 AverageMeter 对象'''avg_loss_stats = {l: AverageMeter() for l in self.loss_stats \if l == 'tot' or opt.weights[l] > 0}'''为 loss 列表的每个属性赋值一个 AverageMeter 对象'''num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters'''获取数据长度'''bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)end = time.time()'''设置进度条'''for iter_id, batch in enumerate(data_loader):if iter_id >= num_iters:break'''遍历完'''data_time.update(time.time() - end)'''更新 data_time 的值'''for k in batch:if k != 'meta':batch[k] = batch[k].to(device=opt.device, non_blocking=True)'''这里的 batch 是一个 Tensor 对象将其配置到 gpu 上'''output, loss, loss_stats = model_with_loss(batch, phase)'''运行第一阶段(模型训练)'''# backpropagate and step optimizer 反向传播和步进优化器loss = loss.mean()'''求每一层损失值的平均值'''if phase == 'train':self.optimizer.zero_grad()'''将模型的参数梯度初始化为0'''loss.backward()'''反向传播计算梯度'''self.optimizer.step()'''更新所有参数''''''根据 pytorch 中 backward() 函数的计算当网络参量进行反馈时,梯度是累积计算而不是被替换但在处理每一个 batch 时并不需要与其他 batch的梯度混合起来累积计算因此需要对每个 batch 调用一遍 zero_grad() 将参数梯度置 0.'''batch_time.update(time.time() - end)'''更新 batch_time 的值'''end = time.time()Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(epoch, iter_id, num_iters, phase=phase,total=bar.elapsed_td, eta=bar.eta_td)'''bar.elapsed_td : 经过的时间增量eta=bar.eta_td : 时间间隔'''for l in avg_loss_stats:avg_loss_stats[l].update(loss_stats[l].mean().item(), batch['image'].size(0))Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)'''更新平均损失'''Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \'|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)if opt.print_iter > 0:if iter_id % opt.print_iter == 0:print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix)) else:bar.next()'''opt.print_iter = 0 执行 else 语句,显示进度条'''if opt.debug > 0:self.debug(batch, output, iter_id, dataset=data_loader.dataset)'''debug 默认为 0,没有执行 if 语句'''if (phase == 'val' and (opt.run_dataset_eval or opt.eval)):meta = batch['meta']dets = fusion_decode(output, K=opt.K, opt=opt)'''解码器和雷达点云融合调用的这个函数位于 CenterFusion\src\lib\model\decode.py 中这个函数具体实现的功能就是将前面模型训练得到的结果,也就是一些特征图,这些特征图为多维矩阵将特征图与毫米波雷达点云进行映射,映射过程就是将特征图进行维度转换、升维等操作,然后再点乘旋转矩阵'''for k in dets:dets[k] = dets[k].detach().cpu().numpy()'''detach() 阻断反向传播,返回值仍为 tensorcpu() 将变量放在 cpu 上,仍为 tensornumpy() 将 tensor 转换为 numpy'''calib = meta['calib'].detach().numpy() if 'calib' in meta else Nonedets = generic_post_process(opt, dets, meta['c'].cpu().numpy(), meta['s'].cpu().numpy(),output['hm'].shape[2], output['hm'].shape[3], self.opt.num_classes,calib)result = []for i in range(len(dets[0])):if dets[0][i]['score'] > self.opt.out_thresh and all(dets[0][i]['dim'] > 0):result.append(dets[0][i])'''筛选结果'''img_id = batch['meta']['img_id'].numpy().astype(np.int32)[0]'''强制类型转换图片 id'''results[img_id] = resultdel output, loss, loss_statsbar.finish()ret = {k: v.avg for k, v in avg_loss_stats.items()}'''平均损失结果'''ret['time'] = bar.elapsed_td.total_seconds() / 60.return ret, results

这篇关于【CenterFusion】run_epoch()函数-训练一轮epoch-CenterFusion/src/lib/trainer.py的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/820868

相关文章

Python中help()和dir()函数的使用

《Python中help()和dir()函数的使用》我们经常需要查看某个对象(如模块、类、函数等)的属性和方法,Python提供了两个内置函数help()和dir(),它们可以帮助我们快速了解代... 目录1. 引言2. help() 函数2.1 作用2.2 使用方法2.3 示例(1) 查看内置函数的帮助(

C++ 函数 strftime 和时间格式示例详解

《C++函数strftime和时间格式示例详解》strftime是C/C++标准库中用于格式化日期和时间的函数,定义在ctime头文件中,它将tm结构体中的时间信息转换为指定格式的字符串,是处理... 目录C++ 函数 strftipythonme 详解一、函数原型二、功能描述三、格式字符串说明四、返回值五

Python中bisect_left 函数实现高效插入与有序列表管理

《Python中bisect_left函数实现高效插入与有序列表管理》Python的bisect_left函数通过二分查找高效定位有序列表插入位置,与bisect_right的区别在于处理重复元素时... 目录一、bisect_left 基本介绍1.1 函数定义1.2 核心功能二、bisect_left 与

java中BigDecimal里面的subtract函数介绍及实现方法

《java中BigDecimal里面的subtract函数介绍及实现方法》在Java中实现减法操作需要根据数据类型选择不同方法,主要分为数值型减法和字符串减法两种场景,本文给大家介绍java中BigD... 目录Java中BigDecimal里面的subtract函数的意思?一、数值型减法(高精度计算)1.

C++/类与对象/默认成员函数@构造函数的用法

《C++/类与对象/默认成员函数@构造函数的用法》:本文主要介绍C++/类与对象/默认成员函数@构造函数的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录名词概念默认成员函数构造函数概念函数特征显示构造函数隐式构造函数总结名词概念默认构造函数:不用传参就可以

C++类和对象之默认成员函数的使用解读

《C++类和对象之默认成员函数的使用解读》:本文主要介绍C++类和对象之默认成员函数的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、默认成员函数有哪些二、各默认成员函数详解默认构造函数析构函数拷贝构造函数拷贝赋值运算符三、默认成员函数的注意事项总结一

Python函数返回多个值的多种方法小结

《Python函数返回多个值的多种方法小结》在Python中,函数通常用于封装一段代码,使其可以重复调用,有时,我们希望一个函数能够返回多个值,Python提供了几种不同的方法来实现这一点,需要的朋友... 目录一、使用元组(Tuple):二、使用列表(list)三、使用字典(Dictionary)四、 使

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

MySQL 字符串截取函数及用法详解

《MySQL字符串截取函数及用法详解》在MySQL中,字符串截取是常见的操作,主要用于从字符串中提取特定部分,MySQL提供了多种函数来实现这一功能,包括LEFT()、RIGHT()、SUBST... 目录mysql 字符串截取函数详解RIGHT(str, length):从右侧截取指定长度的字符SUBST

Kotlin运算符重载函数及作用场景

《Kotlin运算符重载函数及作用场景》在Kotlin里,运算符重载函数允许为自定义类型重新定义现有的运算符(如+-…)行为,从而让自定义类型能像内置类型那样使用运算符,本文给大家介绍Kotlin运算... 目录基本语法作用场景类对象数据类型接口注意事项在 Kotlin 里,运算符重载函数允许为自定义类型重