【深度学习】使用tensorflow实现VGG19网络

2024-02-17 21:12

本文主要是介绍【深度学习】使用tensorflow实现VGG19网络,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

【深度学习】使用tensorflow实现VGG19网络

 

 

 

本文章向大家介绍【深度学习】使用tensorflow实现VGG19网络,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

 

 

 

 

VGG网络与AlexNet类似,也是一种CNN,VGG在2014年的 ILSVRC localization and classification 两个问题上分别取得了第一名和第二名。VGG网络非常深,通常有16-19层,卷积核大小为 3 x 3,16和19层的区别主要在于后面三个卷积部分卷积层的数量。第二个用tensorflow独立完成的小玩意儿......

 

 

模型结构

可以看到VGG的前几层为卷积和maxpool的交替,每个卷积包含多个卷积层,后面紧跟三个全连接层。激活函数采用Relu,训练采用了dropout,但并没有像AlexNet一样采用LRN(论文给出的理由是加LRN实验效果不好)。

模型定义

def maxPoolLayer(x, kHeight, kWidth, strideX, strideY, name, padding = "SAME"):"""max-pooling"""return tf.nn.max_pool(x, ksize = [1, kHeight, kWidth, 1],strides = [1, strideX, strideY, 1], padding = padding, name = name)def dropout(x, keepPro, name = None):"""dropout"""return tf.nn.dropout(x, keepPro, name)def fcLayer(x, inputD, outputD, reluFlag, name):"""fully-connect"""with tf.variable_scope(name) as scope:w = tf.get_variable("w", shape = [inputD, outputD], dtype = "float")b = tf.get_variable("b", [outputD], dtype = "float")out = tf.nn.xw_plus_b(x, w, b, name = scope.name)if reluFlag:return tf.nn.relu(out)else:return outdef convLayer(x, kHeight, kWidth, strideX, strideY,featureNum, name, padding = "SAME"):"""convlutional"""channel = int(x.get_shape()[-1]) #获取channel数with tf.variable_scope(name) as scope:w = tf.get_variable("w", shape = [kHeight, kWidth, channel, featureNum])b = tf.get_variable("b", shape = [featureNum])featureMap = tf.nn.conv2d(x, w, strides = [1, strideY, strideX, 1], padding = padding)out = tf.nn.bias_add(featureMap, b)return tf.nn.relu(tf.reshape(out, featureMap.get_shape().as_list()), name = scope.name)

定义了卷积、pooling、dropout、全连接五个模块,使用了上一篇AlexNet中的代码,其中卷积模块去除了group参数,因为网络没有像AlexNet一样分成两部分。接下来定义VGG19。

class VGG19(object):"""VGG model"""def __init__(self, x, keepPro, classNum, skip, modelPath = "vgg19.npy"):self.X = xself.KEEPPRO = keepProself.CLASSNUM = classNumself.SKIP = skipself.MODELPATH = modelPath#build CNNself.buildCNN()def buildCNN(self):"""build model"""conv1_1 = convLayer(self.X, 3, 3, 1, 1, 64, "conv1_1" )conv1_2 = convLayer(conv1_1, 3, 3, 1, 1, 64, "conv1_2")pool1 = maxPoolLayer(conv1_2, 2, 2, 2, 2, "pool1")conv2_1 = convLayer(pool1, 3, 3, 1, 1, 128, "conv2_1")conv2_2 = convLayer(conv2_1, 3, 3, 1, 1, 128, "conv2_2")pool2 = maxPoolLayer(conv2_2, 2, 2, 2, 2, "pool2")conv3_1 = convLayer(pool2, 3, 3, 1, 1, 256, "conv3_1")conv3_2 = convLayer(conv3_1, 3, 3, 1, 1, 256, "conv3_2")conv3_3 = convLayer(conv3_2, 3, 3, 1, 1, 256, "conv3_3")conv3_4 = convLayer(conv3_3, 3, 3, 1, 1, 256, "conv3_4")pool3 = maxPoolLayer(conv3_4, 2, 2, 2, 2, "pool3")conv4_1 = convLayer(pool3, 3, 3, 1, 1, 512, "conv4_1")conv4_2 = convLayer(conv4_1, 3, 3, 1, 1, 512, "conv4_2")conv4_3 = convLayer(conv4_2, 3, 3, 1, 1, 512, "conv4_3")conv4_4 = convLayer(conv4_3, 3, 3, 1, 1, 512, "conv4_4")pool4 = maxPoolLayer(conv4_4, 2, 2, 2, 2, "pool4")conv5_1 = convLayer(pool4, 3, 3, 1, 1, 512, "conv5_1")conv5_2 = convLayer(conv5_1, 3, 3, 1, 1, 512, "conv5_2")conv5_3 = convLayer(conv5_2, 3, 3, 1, 1, 512, "conv5_3")conv5_4 = convLayer(conv5_3, 3, 3, 1, 1, 512, "conv5_4")pool5 = maxPoolLayer(conv5_4, 2, 2, 2, 2, "pool5")fcIn = tf.reshape(pool5, [-1, 7*7*512])fc6 = fcLayer(fcIn, 7*7*512, 4096, True, "fc6")dropout1 = dropout(fc6, self.KEEPPRO)fc7 = fcLayer(dropout1, 4096, 4096, True, "fc7")dropout2 = dropout(fc7, self.KEEPPRO)self.fc8 = fcLayer(dropout2, 4096, self.CLASSNUM, True, "fc8")def loadModel(self, sess):"""load model"""wDict = np.load(self.MODELPATH, encoding = "bytes").item()#for layers in modelfor name in wDict:if name not in self.SKIP:with tf.variable_scope(name, reuse = True):for p in wDict[name]:if len(p.shape) == 1:#bias 只有一维sess.run(tf.get_variable('b', trainable = False).assign(p))else:#weightssess.run(tf.get_variable('w', trainable = False).assign(p)) 

buildCNN函数完全按照VGG的结构搭建网络。

loadModel函数从模型文件中读取参数,采用的模型文件见github上的readme说明。 至此,我们定义了完整的模型,下面开始测试模型。

模型测试

ImageNet训练的VGG有很多类,几乎包含所有常见的物体,因此我们随便从网上找几张图片测试。比如我直接用了之前做项目的图片,为了避免审美疲劳,我们不只用渣土车,还要用挖掘机、采沙船:

然后编写测试代码:

parser = argparse.ArgumentParser(description='Classify some images.')
parser.add_argument('mode', choices=['folder', 'url'], default='folder')
parser.add_argument('path', help='Specify a path [e.g. testModel]')
args = parser.parse_args(sys.argv[1:])if args.mode == 'folder': #测试方式为本地文件夹#get testImagewithPath = lambda f: '{}/{}'.format(args.path,f)testImg = dict((f,cv2.imread(withPath(f))) for f in os.listdir(args.path) if os.path.isfile(withPath(f)))
elif args.mode == 'url': #测试方式为URLdef url2img(url): #获取URL图像'''url to image'''resp = urllib.request.urlopen(url)image = np.asarray(bytearray(resp.read()), dtype="uint8")image = cv2.imdecode(image, cv2.IMREAD_COLOR)return imagetestImg = {args.path:url2img(args.path)}if testImg.values():#some paramsdropoutPro = 1classNum = 1000skip = []imgMean = np.array([104, 117, 124], np.float)x = tf.placeholder("float", [1, 224, 224, 3])model = vgg19.VGG19(x, dropoutPro, classNum, skip)score = model.fc8softmax = tf.nn.softmax(score)with tf.Session() as sess:sess.run(tf.global_variables_initializer())model.loadModel(sess) #加载模型for key,img in testImg.items():#img preprocessresized = cv2.resize(img.astype(np.float), (224, 224)) - imgMean #去均值maxx = np.argmax(sess.run(softmax, feed_dict = {x: resized.reshape((1, 224, 224, 3))})) #网络输入为224*224res = caffe_classes.class_names[maxx]font = cv2.FONT_HERSHEY_SIMPLEXcv2.putText(img, res, (int(img.shape[0]/3), int(img.shape[1]/3)), font, 1, (0, 255, 0), 2) #在图像上绘制结果print("{}: {}n----".format(key,res)) #输出测试结果cv2.imshow("demo", img)cv2.waitKey(0)

如果你看完了我AlexNet的博客,那么一定会发现我这里的测试代码做了一些小的修改,增加了URL测试的功能,可以测试网上的图像 ,测试结果如下:

 

这篇关于【深度学习】使用tensorflow实现VGG19网络的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/719001

相关文章

使用Python开发一个Ditto剪贴板数据导出工具

《使用Python开发一个Ditto剪贴板数据导出工具》在日常工作中,我们经常需要处理大量的剪贴板数据,下面将介绍如何使用Python的wxPython库开发一个图形化工具,实现从Ditto数据库中读... 目录前言运行结果项目需求分析技术选型核心功能实现1. Ditto数据库结构分析2. 数据库自动定位3

Python yield与yield from的简单使用方式

《Pythonyield与yieldfrom的简单使用方式》生成器通过yield定义,可在处理I/O时暂停执行并返回部分结果,待其他任务完成后继续,yieldfrom用于将一个生成器的值传递给另一... 目录python yield与yield from的使用代码结构总结Python yield与yield

Go语言使用select监听多个channel的示例详解

《Go语言使用select监听多个channel的示例详解》本文将聚焦Go并发中的一个强力工具,select,这篇文章将通过实际案例学习如何优雅地监听多个Channel,实现多任务处理、超时控制和非阻... 目录一、前言:为什么要使用select二、实战目标三、案例代码:监听两个任务结果和超时四、运行示例五

python使用Akshare与Streamlit实现股票估值分析教程(图文代码)

《python使用Akshare与Streamlit实现股票估值分析教程(图文代码)》入职测试中的一道题,要求:从Akshare下载某一个股票近十年的财务报表包括,资产负债表,利润表,现金流量表,保存... 目录一、前言二、核心知识点梳理1、Akshare数据获取2、Pandas数据处理3、Matplotl

分布式锁在Spring Boot应用中的实现过程

《分布式锁在SpringBoot应用中的实现过程》文章介绍在SpringBoot中通过自定义Lock注解、LockAspect切面和RedisLockUtils工具类实现分布式锁,确保多实例并发操作... 目录Lock注解LockASPect切面RedisLockUtils工具类总结在现代微服务架构中,分布

Java使用Thumbnailator库实现图片处理与压缩功能

《Java使用Thumbnailator库实现图片处理与压缩功能》Thumbnailator是高性能Java图像处理库,支持缩放、旋转、水印添加、裁剪及格式转换,提供易用API和性能优化,适合Web应... 目录1. 图片处理库Thumbnailator介绍2. 基本和指定大小图片缩放功能2.1 图片缩放的

Python使用Tenacity一行代码实现自动重试详解

《Python使用Tenacity一行代码实现自动重试详解》tenacity是一个专为Python设计的通用重试库,它的核心理念就是用简单、清晰的方式,为任何可能失败的操作添加重试能力,下面我们就来看... 目录一切始于一个简单的 API 调用Tenacity 入门:一行代码实现优雅重试精细控制:让重试按我

深度解析Spring Security 中的 SecurityFilterChain核心功能

《深度解析SpringSecurity中的SecurityFilterChain核心功能》SecurityFilterChain通过组件化配置、类型安全路径匹配、多链协同三大特性,重构了Spri... 目录Spring Security 中的SecurityFilterChain深度解析一、Security

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

Redis客户端连接机制的实现方案

《Redis客户端连接机制的实现方案》本文主要介绍了Redis客户端连接机制的实现方案,包括事件驱动模型、非阻塞I/O处理、连接池应用及配置优化,具有一定的参考价值,感兴趣的可以了解一下... 目录1. Redis连接模型概述2. 连接建立过程详解2.1 连php接初始化流程2.2 关键配置参数3. 最大连