Retinaface训练超参数调优

2024-06-22 01:20

本文主要是介绍Retinaface训练超参数调优,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

训练20遍数据集跑出的效果

from __future__ import print_functionimport argparse
import math
import osimport optuna
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as datafrom data import WiderFaceDetection, detection_collate, preproc, cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from layers.modules import MultiBoxLoss
from models.retinaface import RetinaFace# 解析命令行参数
parser = argparse.ArgumentParser(description='Retinaface Training')
parser.add_argument('--training_dataset', default='./data/lst/train/label.txt', help='训练数据集目录')
parser.add_argument('--network', default='resnet50', help='Backbone 网络选择: mobile0.25 或 resnet50')
parser.add_argument('--num_workers', default=4, type=int, help='数据加载时的工作线程数')
parser.add_argument('--resume_net', default=None, help='重新训练时的已保存模型路径')
parser.add_argument('--resume_epoch', default=0, type=int, help='重新训练时的迭代轮数')
parser.add_argument('--save_folder', default='./weights/', help='保存检查点模型的目录')# 解析参数
args = parser.parse_args()# 如果 save_folder 目录不存在,则创建它
if not os.path.exists(args.save_folder):os.mkdir(args.save_folder)# 根据选择的网络初始化配置
cfg = None
if args.network == "mobile0.25":cfg = cfg_mnet
elif args.network == "resnet50":cfg = cfg_re50# 设置 RGB 平均值、类别数、图像维度等
rgb_mean = (104, 117, 123)  # BGR 顺序
num_classes = 2
img_dim = cfg['image_size']
num_gpu = cfg['ngpu']
batch_size = cfg['batch_size']
max_epoch = cfg['epoch']
gpu_train = cfg['gpu_train']num_workers = args.num_workers
training_dataset = args.training_dataset
save_folder = args.save_folder# 超参数优化目标函数
def objective(trial):# 超参数搜索空间initial_lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)momentum = trial.suggest_float('momentum', 0.7, 0.99)weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-2, log=True)gamma = trial.suggest_float('gamma', 0.1, 0.5, log=True)# 初始化 RetinaFace 模型net = RetinaFace(cfg=cfg)# 如果指定了 resume_net,加载预训练权重if args.resume_net is not None:state_dict = torch.load(args.resume_net)from collections import OrderedDictnew_state_dict = OrderedDict()for k, v in state_dict.items():head = k[:7]if head == 'module.':name = k[7:]  # 移除 `module.`else:name = knew_state_dict[name] = vnet.load_state_dict(new_state_dict)# 如果有多个 GPU 可用,使用 DataParallel 进行并行训练if num_gpu > 1 and gpu_train:net = torch.nn.DataParallel(net).cuda()else:net = net.cuda()cudnn.benchmark = True# 定义优化器、损失函数和先验框optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=momentum, weight_decay=weight_decay)criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))with torch.no_grad():priors = priorbox.forward()priors = priors.cuda()# 训练函数def train():net.train()epoch = 0 + args.resume_epochdataset = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))epoch_size = math.ceil(len(dataset) / batch_size)max_iter = max_epoch * epoch_sizestepvalues = (cfg['decay1'] * epoch_size, cfg['decay2'] * epoch_size)step_index = 0start_iter = args.resume_epoch * epoch_size if args.resume_epoch > 0 else 0for iteration in range(start_iter, max_iter):if iteration % epoch_size == 0:batch_iterator = iter(data.DataLoader(dataset, batch_size, shuffle=True, num_workers=num_workers,collate_fn=detection_collate))epoch += 1if iteration in stepvalues:step_index += 1lr = adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size)images, targets = next(batch_iterator)images = images.cuda()targets = [anno.cuda() for anno in targets]out = net(images)optimizer.zero_grad()loss_l, loss_c, loss_landm = criterion(out, priors, targets)loss = cfg['loc_weight'] * loss_l + loss_c + loss_landmloss.backward()optimizer.step()return loss.item()# 学习率调整函数def adjust_learning_rate(optimizer, gamma, epoch, step_index, iteration, epoch_size):warmup_epoch = 5if epoch < warmup_epoch:lr = initial_lr * (iteration + 1) / (epoch_size * warmup_epoch)else:lr = initial_lr * 0.5 * (1 + math.cos(math.pi * (epoch - warmup_epoch) / (max_epoch - warmup_epoch)))for param_group in optimizer.param_groups:param_group['lr'] = lrreturn lr# 训练并返回损失final_loss = train()# 将超参数和对应的损失写入文件with open('best.txt', 'a') as f:f.write(f"Trial {trial.number} - Loss: {final_loss}\n")f.write(f"lr: {initial_lr}, momentum: {momentum}, weight_decay: {weight_decay}, gamma: {gamma}\n\n")return final_lossif __name__ == '__main__':# 使用 Optuna 进行超参数优化study = optuna.create_study(direction='minimize')study.optimize(objective, n_trials=20)print('最佳超参数:')print(study.best_params)# 将最佳超参数写入文件with open('best.txt', 'a') as f:f.write('最佳超参数:\n')f.write(str(study.best_params))f.write('\n')

在这里插入图片描述

Trial	Loss	Learning Rate (lr)
0	1.474661	0.006130
1	2.118352	0.009536
2	0.720860	0.001478
3	1.219791	0.000690
4	2.139611	0.000137
5	2.485054	0.000155
6	3.654128	0.000102
7	1.276526	0.002037
8	3.393638	0.000207
9	1.449489	0.001304
10	1.539526	0.000506
11	1.774117	0.000576
12	2.403089	0.002708
13	1.850937	0.000673
14	1.411137	0.000331
15	0.963467	0.001294
16	1.245503	0.002975
17	1.674727	0.001095
18	1.468113	0.001567
19	0.801777	0.004197
20	11.193202	0.003866
21	1.638056	0.004286
22	1.181070	0.002046
23	1.436263	0.005183
24	1.868375	0.000939
25	1.384036	0.007968
26	1.327896	0.001810
27	0.900618	0.002786
28	1.587448	0.002961
29	1.414236	0.006639
30	1.640772	0.003894
31	1.167393	0.000975
32	1.327571	0.002453
33	1.163059	0.001341
34	1.491638	0.005048
35	1.831493	0.000386
36	3.975567	0.000802
37	1.390656	0.009954
38	1.485421	0.001599
39	2.045896	0.002233
40	1.480700	0.003282
41	1.251927	0.001231
42	1.286666	0.001444
43	1.157723	0.001841
44	1.000185	0.001921
45	1.868337	0.003445
46	1.291534	0.006600
47	1.486465	0.002422
48	1.743561	0.005034
49	1.316136	0.000792
import pandas as pd
import matplotlib.pyplot as plt# Creating a DataFrame with the extracted data
data = {"Trial": list(range(50)),"Loss": [1.474661, 2.118352, 0.720860, 1.219791, 2.139611, 2.485054, 3.654128, 1.276526, 3.393638, 1.449489, 1.539526, 1.774117, 2.403089, 1.850937, 1.411137, 0.963467, 1.245503, 1.674727, 1.468113, 0.801777, 11.193202, 1.638056, 1.181070, 1.436263, 1.868375, 1.384036, 1.327896, 0.900618, 1.587448, 1.414236, 1.640772, 1.167393, 1.327571, 1.163059, 1.491638, 1.831493, 3.975567, 1.390656, 1.485421, 2.045896, 1.480700, 1.251927, 1.286666, 1.157723, 1.000185, 1.868337, 1.291534, 1.486465, 1.743561, 1.316136],"Learning Rate (lr)": [0.006130, 0.009536, 0.001478, 0.000690, 0.000137, 0.000155, 0.000102, 0.002037, 0.000207, 0.001304,0.000506, 0.000576, 0.002708, 0.000673, 0.000331, 0.001294, 0.002975, 0.001095, 0.001567, 0.004197,0.003866, 0.004286, 0.002046, 0.005183, 0.000939, 0.007968, 0.001810, 0.002786, 0.002961, 0.006639,0.003894, 0.000975, 0.002453, 0.001341, 0.005048, 0.000386, 0.000802, 0.009954, 0.001599, 0.002233,0.003282, 0.001231, 0.001444, 0.001841, 0.001921, 0.003445, 0.006600, 0.002422, 0.005034, 0.000792]
}df = pd.DataFrame(data)# Sorting the dataframe by learning rate for a smooth line plot
df_sorted = df.sort_values(by="Learning Rate (lr)")# Extracting sorted data
sorted_learning_rates = df_sorted["Learning Rate (lr)"]
sorted_losses = df_sorted["Loss"]# Plotting line chart
plt.figure(figsize=(10, 6))
plt.plot(sorted_learning_rates, sorted_losses, marker='o', linestyle='-', color='blue')
plt.xscale('log')  # Using logarithmic scale for learning rate to better visualize the range
plt.yscale('log')  # Using logarithmic scale for loss to better visualize the range
plt.title('Loss vs Learning Rate')
plt.xlabel('Learning Rate (log scale)')
plt.ylabel('Loss (log scale)')
plt.grid(True)
plt.show()

要确定进一步训练的最佳学习率范围,我们可以通过以下几点分析学习率与损失之间的关系:

低损失区域:找出在较低损失值对应的学习率范围。
稳定性:选择一个损失值较低且稳定的学习率范围。
从已提供的数据中,我们可以观察哪些学习率对应较低的损失值:

最低损失值为0.720860,出现在学习率为0.001478。
其他较低损失值(低于1.0)出现在以下学习率:
0.004197 对应损失 0.801777
0.002786 对应损失 0.900618
0.001294 对应损失 0.963467
综合考虑损失较低和稳定性,学习率范围可以集中在0.001到0.005之间。具体建议如下:

初步范围:0.001到0.005
更精细的范围:由于在0.001478、0.004197和0.002786处有较低的损失,可以进一步缩小到0.001到0.002和0.004到0.005之间。
进一步训练建议
在0.001到0.002之间试验更多的学习率值,例如0.0011, 0.0015, 0.0018等。
在0.004到0.005之间试验,例如0.0042, 0.0045, 0.0048等。
通过这些步骤,您可以更准确地找到最佳学习率,从而进一步降低损失,提高模型的性能

这篇关于Retinaface训练超参数调优的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1082892

相关文章

JVisualVM之Java性能监控与调优利器详解

《JVisualVM之Java性能监控与调优利器详解》本文将详细介绍JVisualVM的使用方法,并结合实际案例展示如何利用它进行性能调优,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全... 目录1. JVisualVM简介2. JVisualVM的安装与启动2.1 启动JVisualVM2

一文详解PostgreSQL复制参数

《一文详解PostgreSQL复制参数》PostgreSQL作为一款功能强大的开源关系型数据库,其复制功能对于构建高可用性系统至关重要,本文给大家详细介绍了PostgreSQL的复制参数,需要的朋友可... 目录一、复制参数基础概念二、核心复制参数深度解析1. max_wal_seChina编程nders:WAL

Linux高并发场景下的网络参数调优实战指南

《Linux高并发场景下的网络参数调优实战指南》在高并发网络服务场景中,Linux内核的默认网络参数往往无法满足需求,导致性能瓶颈、连接超时甚至服务崩溃,本文基于真实案例分析,从参数解读、问题诊断到优... 目录一、问题背景:当并发连接遇上性能瓶颈1.1 案例环境1.2 初始参数分析二、深度诊断:连接状态与

史上最全nginx详细参数配置

《史上最全nginx详细参数配置》Nginx是一个轻量级高性能的HTTP和反向代理服务器,同时也是一个通用代理服务器(TCP/UDP/IMAP/POP3/SMTP),最初由俄罗斯人IgorSyso... 目录基本命令默认配置搭建站点根据文件类型设置过期时间禁止文件缓存防盗链静态文件压缩指定定错误页面跨域问题

SpringBoot请求参数接收控制指南分享

《SpringBoot请求参数接收控制指南分享》:本文主要介绍SpringBoot请求参数接收控制指南,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Spring Boot 请求参数接收控制指南1. 概述2. 有注解时参数接收方式对比3. 无注解时接收参数默认位置

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效

SpringMVC获取请求参数的方法

《SpringMVC获取请求参数的方法》:本文主要介绍SpringMVC获取请求参数的方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下... 目录1、通过ServletAPI获取2、通过控制器方法的形参获取请求参数3、@RequestParam4、@

Spring Boot项目部署命令java -jar的各种参数及作用详解

《SpringBoot项目部署命令java-jar的各种参数及作用详解》:本文主要介绍SpringBoot项目部署命令java-jar的各种参数及作用的相关资料,包括设置内存大小、垃圾回收... 目录前言一、基础命令结构二、常见的 Java 命令参数1. 设置内存大小2. 配置垃圾回收器3. 配置线程栈大小

SpringBoot利用@Validated注解优雅实现参数校验

《SpringBoot利用@Validated注解优雅实现参数校验》在开发Web应用时,用户输入的合法性校验是保障系统稳定性的基础,​SpringBoot的@Validated注解提供了一种更优雅的解... 目录​一、为什么需要参数校验二、Validated 的核心用法​1. 基础校验2. php分组校验3