Image Segmentation in Practice - Tutorial Series 17: DeepLabV3+ VOC Segmentation in Practice 5 -- main.py

2024-01-20 20:04

This article covers Tutorial 17 of the Image Segmentation in Practice series: part 5 of the DeepLabV3+ VOC segmentation walkthrough, which goes through main.py. Hopefully it serves as a useful reference for developers working on similar problems.


🍁🍁🍁 Image Segmentation in Practice - Tutorial Series: Full Table of Contents

If you have any questions, feel free to leave a comment below.
All code in this article was run in PyCharm.
The companion code for this article has been uploaded.

Overview of the DeepLab family of algorithms
DeepLabV3+ VOC Segmentation in Practice 1
DeepLabV3+ VOC Segmentation in Practice 2
DeepLabV3+ VOC Segmentation in Practice 3
DeepLabV3+ VOC Segmentation in Practice 4
DeepLabV3+ VOC Segmentation in Practice 5

10. The main() function of main.py

def main():
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(
        train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=0)
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=0)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """ save current model """
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ==========#
    vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,
                                      np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(
            opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0
    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)

            if (cur_itrs) % 10 == 0:
                interval_loss = interval_loss / 10
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if (cur_itrs) % opts.val_interval == 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
                        lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                return


if __name__ == '__main__':
    main()
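Two details of the optimizer setup above are worth isolating: the pretrained backbone is trained at one tenth of the base learning rate while the freshly initialized classifier head uses the full rate, and the "poly" policy decays both groups as the iteration count approaches total_itrs. The following minimal, self-contained sketch is not the repo's utils.PolyLR; it uses tiny stand-in modules (assumed names backbone/classifier and made-up hyperparameter values) and reproduces the same behaviour with PyTorch's built-in LambdaLR, which multiplies each group's initial lr by the factor (1 - iter/total_itrs)**0.9.

import torch
import torch.nn as nn

backbone = nn.Conv2d(3, 8, 3)      # stand-in for model.backbone
classifier = nn.Conv2d(8, 21, 1)   # stand-in for model.classifier

base_lr, total_itrs, power = 0.01, 30000, 0.9   # illustrative values only

# Two parameter groups: backbone at 0.1x the base lr, classifier at the base lr
optimizer = torch.optim.SGD(
    [{'params': backbone.parameters(), 'lr': 0.1 * base_lr},
     {'params': classifier.parameters(), 'lr': base_lr}],
    lr=base_lr, momentum=0.9, weight_decay=1e-4)

# Poly-style decay: each group's initial lr is scaled by (1 - it/total_itrs)**power
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda it: (1 - it / total_itrs) ** power)

for it in range(3):
    optimizer.step()      # in real training: forward pass, loss, backward first
    scheduler.step()
    print([round(g['lr'], 6) for g in optimizer.param_groups])

Running this prints both learning rates shrinking toward zero while keeping the 0.1x ratio between backbone and classifier, which is the usual choice when fine-tuning an ImageNet-pretrained backbone under a newly initialized segmentation head.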


That concludes this article on Image Segmentation in Practice - Tutorial Series 17: DeepLabV3+ VOC Segmentation in Practice 5 (main.py). We hope it is helpful to fellow developers!



