基于Pytorch框架的深度学习HRnet网络人像语义分割系统源码

本文主要是介绍基于Pytorch框架的深度学习HRnet网络人像语义分割系统源码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  第一步:准备数据

头发分割数据,总共有5711张图片,里面的像素值为0和1,所以看起来全部是黑的,不影响使用

第二步:搭建模型

计算机视觉领域有很多任务是位置敏感的,比如目标检测、语义分割、实例分割等等。为了这些任务位置信息更加精准,很容易想到的做法就是维持高分辨率的feature map,事实上HRNet之前几乎所有的网络都是这么做的,通过下采样得到强语义信息,然后再上采样恢复高分辨率恢复位置信息(如下图所示),然而这种做法,会导致大量的有效信息在不断的上下采样过程中丢失。而HRNet通过并行多个分辨率的分支,加上不断进行不同分支之间的信息交互,同时达到强语义信息和精准位置信息的目的。

recover high resolution

思路在当时来讲,不同分支的信息交互属于很老套的思路(如FPN等),我觉得最大的创新点还是能够从头到尾保持高分辨率,而不同分支的信息交互是为了补充通道数减少带来的信息损耗,这种网络架构设计对于位置敏感的任务会有奇效。

第三步:代码

1)损失函数为:交叉熵损失函数

2)网络代码:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Ffrom .backbone import BN_MOMENTUM, hrnet_classificationclass HRnet_Backbone(nn.Module):def __init__(self, backbone = 'hrnetv2_w18', pretrained = False):super(HRnet_Backbone, self).__init__()self.model    = hrnet_classification(backbone = backbone, pretrained = pretrained)del self.model.incre_modulesdel self.model.downsamp_modulesdel self.model.final_layerdel self.model.classifierdef forward(self, x):x = self.model.conv1(x)x = self.model.bn1(x)x = self.model.relu(x)x = self.model.conv2(x)x = self.model.bn2(x)x = self.model.relu(x)x = self.model.layer1(x)x_list = []for i in range(2):if self.model.transition1[i] is not None:x_list.append(self.model.transition1[i](x))else:x_list.append(x)y_list = self.model.stage2(x_list)x_list = []for i in range(3):if self.model.transition2[i] is not None:if i < 2:x_list.append(self.model.transition2[i](y_list[i]))else:x_list.append(self.model.transition2[i](y_list[-1]))else:x_list.append(y_list[i])y_list = self.model.stage3(x_list)x_list = []for i in range(4):if self.model.transition3[i] is not None:if i < 3:x_list.append(self.model.transition3[i](y_list[i]))else:x_list.append(self.model.transition3[i](y_list[-1]))else:x_list.append(y_list[i])y_list = self.model.stage4(x_list)return y_listclass HRnet(nn.Module):def __init__(self, num_classes = 21, backbone = 'hrnetv2_w18', pretrained = False):super(HRnet, self).__init__()self.backbone       = HRnet_Backbone(backbone = backbone, pretrained = pretrained)last_inp_channels   = np.int(np.sum(self.backbone.model.pre_stage_channels))self.last_layer = nn.Sequential(nn.Conv2d(in_channels=last_inp_channels, out_channels=last_inp_channels, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM),nn.ReLU(inplace=True),nn.Conv2d(in_channels=last_inp_channels, out_channels=num_classes, kernel_size=1, stride=1, padding=0))def forward(self, inputs):H, W = inputs.size(2), inputs.size(3)x = self.backbone(inputs)# Upsamplingx0_h, x0_w = x[0].size(2), x[0].size(3)x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)x = torch.cat([x[0], x1, x2, x3], 1)x = self.last_layer(x)x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)return x

第四步:统计一些指标(训练过程中的loss和miou)

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码见:基于Pytorch框架的深度学习HRnet网络人像语义分割系统源码

有问题可以私信或者留言,有问必答

这篇关于基于Pytorch框架的深度学习HRnet网络人像语义分割系统源码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1124401

相关文章

8种快速易用的Python Matplotlib数据可视化方法汇总(附源码)

《8种快速易用的PythonMatplotlib数据可视化方法汇总(附源码)》你是否曾经面对一堆复杂的数据,却不知道如何让它们变得直观易懂?别慌,Python的Matplotlib库是你数据可视化的... 目录引言1. 折线图(Line Plot)——趋势分析2. 柱状图(Bar Chart)——对比分析3

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

SpringBoot基础框架详解

《SpringBoot基础框架详解》SpringBoot开发目的是为了简化Spring应用的创建、运行、调试和部署等,使用SpringBoot可以不用或者只需要很少的Spring配置就可以让企业项目快... 目录SpringBoot基础 – 框架介绍1.SpringBoot介绍1.1 概述1.2 核心功能2

Spring Boot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)

《SpringBoot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)》:本文主要介绍SpringBoot拦截器Interceptor与过滤器Filter深度解析... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现与实

MyBatis分页插件PageHelper深度解析与实践指南

《MyBatis分页插件PageHelper深度解析与实践指南》在数据库操作中,分页查询是最常见的需求之一,传统的分页方式通常有两种内存分页和SQL分页,MyBatis作为优秀的ORM框架,本身并未提... 目录1. 为什么需要分页插件?2. PageHelper简介3. PageHelper集成与配置3.

python如何下载网络文件到本地指定文件夹

《python如何下载网络文件到本地指定文件夹》这篇文章主要为大家详细介绍了python如何实现下载网络文件到本地指定文件夹,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下...  在python中下载文件到本地指定文件夹可以通过以下步骤实现,使用requests库处理HTTP请求,并结合o

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

Maven 插件配置分层架构深度解析

《Maven插件配置分层架构深度解析》:本文主要介绍Maven插件配置分层架构深度解析,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录Maven 插件配置分层架构深度解析引言:当构建逻辑遇上复杂配置第一章 Maven插件配置的三重境界1.1 插件配置的拓扑

ubuntu20.0.4系统中安装Anaconda的超详细图文教程

《ubuntu20.0.4系统中安装Anaconda的超详细图文教程》:本文主要介绍了在Ubuntu系统中如何下载和安装Anaconda,提供了两种方法,详细内容请阅读本文,希望能对你有所帮助... 本文介绍了在Ubuntu系统中如何下载和安装Anaconda。提供了两种方法,包括通过网页手动下载和使用wg

重新对Java的类加载器的学习方式

《重新对Java的类加载器的学习方式》:本文主要介绍重新对Java的类加载器的学习方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍1.1、简介1.2、符号引用和直接引用1、符号引用2、直接引用3、符号转直接的过程2、加载流程3、类加载的分类3.1、显示