【Tensor'flow】第一个FCN网络

2024-02-11 21:38
文章标签 第一个 网络 tensor flow fcn

本文主要是介绍【Tensor'flow】第一个FCN网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

学习Tensorflow,写一个超级简单的全卷积,效果没有,只是能跑通,没有dropout。


#!/usr/bin/env python
#coding:utf-8
from __future__ import absolute_import
from __future__ import divisionimport os,cv2
import numpy as np
import time
import tensorflow as tf
def weight_variable(shape):# 使用截断的正态分布初始权重initial = tf.truncated_normal(shape, stddev = 0.01)return tf.Variable(initial)def bias_variable(shape):return tf.Variable(tf.constant(0.0, shape = shape))def conv_layer(x, W, b):# W的尺寸是[ksize, ksize, input, output]conv = tf.nn.conv2d(x, W, strides = [1, 1, 1, 1], padding = 'SAME')conv_b = tf.nn.bias_add(conv, b)conv_relu = tf.nn.relu(conv_b)return conv_reludef max_pool_layer(x):return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'SAME')def deconv_layer(x, W, output_shape, b):# strides = 2 两倍上卷积# output_shape = [batch_size, output_width, output_height, output_channel],注意第一个是batch_size# 权重W = [ksize, ksize, output, input]后两位和卷积相反deconv = tf.nn.conv2d_transpose(x, W, output_shape,  strides = [1, 2, 2, 1], padding = 'SAME')return tf.nn.bias_add(deconv, b)# 获取数据
def get_data(image_path, label_path):image_list = os.listdir(image_path)label_list = os.listdir(label_path)image_list_arr = []label_list_arr = []for file in image_list:if file[-3:] == 'png':# cv2.imread('', -1)保持原始数据读入;如果没有-1会以图片形式读入,变成三通道image = cv2.imread(os.path.join(image_path,file),-1)#image = transform.resize(image, (512,512))image_list_arr.append(image)for file in label_list:if file[-3:] == 'png':label = cv2.imread(os.path.join(label_path,file), -1)label_list_arr.append(label)return (image_list_arr, label_list_arr)# 读取下一个batch数据
def next_batch(images, labels, batch_size, shuffle = False):assert len(images) == len(labels)if shuffle:indices = np.arange(len(images))np.random.shuffle(indices)for start_idx in range(0, len(images) - batch_size + 1, batch_size):if shuffle:exceprt = indices[start_idx : start_idx + batch_size]else:exceprt = slice(start_idx, start_idx + batch_size)yield np.array(images)[exceprt], np.array(labels)[exceprt]def main():# 尽量写相对路径image_path = './data/mri'label_path =  './data/labels'# 如果内存耗尽可以考虑将batch减小batch_size = 4n_epoch = 2lr = 0.01images, labels = get_data(image_path, label_path)ratio = 0.8length = len(images)s = np.int(length * ratio)x_train = images[: s]y_train = labels[: s]x_val = images[s: ]y_val = labels[s:]keep_prob = tf.placeholder(tf.float32)# None代表样本数量不固定x = tf.placeholder(tf.float32, shape = [None, 256, 256, 3])y = tf.placeholder(tf.float32, shape = [None, 256, 256, 3])# input 256*256# weight([ksize, ksize, input, output])weight1 = weight_variable([3, 3, 3, 64])bias1 = bias_variable([64])conv1 = conv_layer(x, weight1, bias1)# input 256*256# output 128*128weight2 = weight_variable([3, 3, 64, 128])bias2 = bias_variable([128])conv2 = conv_layer(conv1, weight2, bias2)pool1 = max_pool_layer(conv2)# input 128*128# output 64*64weight3 = weight_variable([3, 3, 128, 256])bias3 = bias_variable([256])conv3 = conv_layer(pool1, weight3, bias3)pool2 = max_pool_layer(conv3)# deconv1# weight([ksize, ksize, output, input])# 64*64->128*128(pool1)deconv_weight1 = weight_variable([3, 3, 128, 256])deconv_b1 = bias_variable([128])deconv1 = deconv_layer(pool2, deconv_weight1, [batch_size, 128, 128, 128], deconv_b1)# 与pool1融合,使用add的话deconv和pool的output channel要一致fuse_pool1 = tf.add(deconv1, pool1)# deconv2# 128*128->256*256(input)deconv_weight2 = weight_variable([3, 3, 64, 128])deconv_b2 = bias_variable([64])deconv2 = deconv_layer(fuse_pool1, deconv_weight2, [batch_size, 256, 256, 64], deconv_b2)# 转换成与输入标签相同的size,获得最后结果weight16 = weight_variable([3, 3, 64, 3])bias16 = bias_variable([3])conv16 = tf.nn.conv2d(deconv2, weight16, strides = [1, 1, 1, 1], padding = 'SAME')conv16_b = tf.nn.bias_add(conv16, bias16)logits16 = conv16_b# lossloss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits16, labels=y))opt = tf.train.AdamOptimizer(1e-4).minimize(loss)sess = tf.Session()sess.run(tf.global_variables_initializer())for epoch in range(n_epoch):# trainfor x_train_batch, y_train_batch in next_batch(x_train, y_train, batch_size, shuffle = True):_, train_loss = sess.run([opt, loss], feed_dict = {x: x_train_batch, y: y_train_batch})print ("------trian loss: %f" % train_loss)# valval_loss = 0for x_val_batch, y_val_batch in next_batch(x_val, y_val, batch_size, shuffle = True):val_loss = sess.run([loss], feed_dict={x: x_val_batch, y: y_val_batch})print("------val loss : %f" % val_loss)sess.close()if __name__ == '__main__':main()


这篇关于【Tensor'flow】第一个FCN网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/700870

相关文章

Linux网络配置之网桥和虚拟网络的配置指南

《Linux网络配置之网桥和虚拟网络的配置指南》这篇文章主要为大家详细介绍了Linux中配置网桥和虚拟网络的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 一、网桥的配置在linux系统中配置一个新的网桥主要涉及以下几个步骤:1.为yum仓库做准备,安装组件epel-re

python如何下载网络文件到本地指定文件夹

《python如何下载网络文件到本地指定文件夹》这篇文章主要为大家详细介绍了python如何实现下载网络文件到本地指定文件夹,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下...  在python中下载文件到本地指定文件夹可以通过以下步骤实现,使用requests库处理HTTP请求,并结合o

Linux高并发场景下的网络参数调优实战指南

《Linux高并发场景下的网络参数调优实战指南》在高并发网络服务场景中,Linux内核的默认网络参数往往无法满足需求,导致性能瓶颈、连接超时甚至服务崩溃,本文基于真实案例分析,从参数解读、问题诊断到优... 目录一、问题背景:当并发连接遇上性能瓶颈1.1 案例环境1.2 初始参数分析二、深度诊断:连接状态与

Qt实现网络数据解析的方法总结

《Qt实现网络数据解析的方法总结》在Qt中解析网络数据通常涉及接收原始字节流,并将其转换为有意义的应用层数据,这篇文章为大家介绍了详细步骤和示例,感兴趣的小伙伴可以了解下... 目录1. 网络数据接收2. 缓冲区管理(处理粘包/拆包)3. 常见数据格式解析3.1 jsON解析3.2 XML解析3.3 自定义

Linux系统配置NAT网络模式的详细步骤(附图文)

《Linux系统配置NAT网络模式的详细步骤(附图文)》本文详细指导如何在VMware环境下配置NAT网络模式,包括设置主机和虚拟机的IP地址、网关,以及针对Linux和Windows系统的具体步骤,... 目录一、配置NAT网络模式二、设置虚拟机交换机网关2.1 打开虚拟机2.2 管理员授权2.3 设置子

揭秘Python Socket网络编程的7种硬核用法

《揭秘PythonSocket网络编程的7种硬核用法》Socket不仅能做聊天室,还能干一大堆硬核操作,这篇文章就带大家看看Python网络编程的7种超实用玩法,感兴趣的小伙伴可以跟随小编一起... 目录1.端口扫描器:探测开放端口2.简易 HTTP 服务器:10 秒搭个网页3.局域网游戏:多人联机对战4.

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

Linux系统之主机网络配置方式

《Linux系统之主机网络配置方式》:本文主要介绍Linux系统之主机网络配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、查看主机的网络参数1、查看主机名2、查看IP地址3、查看网关4、查看DNS二、配置网卡1、修改网卡配置文件2、nmcli工具【通用

使用Python高效获取网络数据的操作指南

《使用Python高效获取网络数据的操作指南》网络爬虫是一种自动化程序,用于访问和提取网站上的数据,Python是进行网络爬虫开发的理想语言,拥有丰富的库和工具,使得编写和维护爬虫变得简单高效,本文将... 目录网络爬虫的基本概念常用库介绍安装库Requests和BeautifulSoup爬虫开发发送请求解

如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解

《如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别详解》:本文主要介绍如何通过海康威视设备网络SDK进行Java二次开发摄像头车牌识别的相关资料,描述了如何使用海康威视设备网络SD... 目录前言开发流程问题和解决方案dll库加载不到的问题老旧版本sdk不兼容的问题关键实现流程总结前言作为