使用TF-slim在猫狗大战数据集训练分类器

2024-02-12 04:32

本文主要是介绍使用TF-slim在猫狗大战数据集训练分类器,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

目录

TF-slim的测试

数据集制作

数据集验证

数据集注册(引入)

注册数据集

修改slim/train_image_classifier.py文件,进行训练。


 

TF-slim的测试

运行以下命令,测试tf.contrib.slim模块是否已正确安装。从TF_MODELS/research/slim目录下运行:

python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once"

数据集制作

创建目录slim/dataset/cvd,该目录下raw_data文件夹存放下载好的猫狗大战图片,而创建好的tfrecord文件直接存放在cvd目录下。

新建数据集制作文件download_and_convert_cvd_v1_0.py,该文件直接从download_and_convert_flowers.py文件修改而来。代码如下:

r"""Downloads and converts cat_vs_dog data to TFRecords of TF-Example protos.This module downloads the cat_vs_dog data, uncompresses it, reads the files
that make up the cat_vs_dog data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label.The script should take about a minute to run."""from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport math
import os
import random
import sysimport tensorflow as tffrom datasets import dataset_utils# 数据集原始图片的下载网址,本代码使用预先下载好的图片,因此这个字段不用。
_DATA_URL = 'http://url/to/download/xxx.tgz'# 验证集的图片数,由于猫狗大战数据集的train文件夹共有25000张图片,这里取其中0.3作为验证集。
_NUM_VALIDATION = 7500# Seed for repeatability.
_RANDOM_SEED = 0# 指定数据集分成几个tfrecord文件存放
_NUM_SHARDS = 5# 这个ImageReader类,主要提供两个方法,从gfile读取的*.jpg文件数据解码成图片的tensor数据
class ImageReader(object):"""Helper class that provides TensorFlow image coding utilities."""def __init__(self):# Initializes function that decodes RGB JPEG data.self._decode_jpeg_data = tf.placeholder(dtype=tf.string)self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)def read_image_dims(self, sess, image_data):image = self.decode_jpeg(sess, image_data)return image.shape[0], image.shape[1]def decode_jpeg(self, sess, image_data):image = sess.run(self._decode_jpeg,feed_dict={self._decode_jpeg_data: image_data})assert len(image.shape) == 3assert image.shape[2] == 3return image# 从提供的dataset_dir数据集文件夹读取所有的文件和类别,返回文件的相对路径./dataset_dir/raw_data/*.jpg。
# 默认原始图片存放于dataset_dir/raw_data文件夹下
def _get_filenames_and_classes(dataset_dir):"""Returns a list of filenames and inferred class names.Args:dataset_dir: A directory containing a set of subdirectories representingclass names. Each subdirectory should contain PNG or JPG encoded images.Returns:A list of image file paths, relative to `dataset_dir` and the list ofsubdirectories, representing class names."""cvd_root = os.path.join(dataset_dir, 'raw_data')directories = []class_names = []photo_filenames = []for filename in os.listdir(cvd_root):path = os.path.join(cvd_root, filename)if os.path.isfile(path):photo_filenames.append(path)class_names = ['cat', 'dog']return photo_filenames, class_names# 生成tfrecord文件的存储名,返回其相对dataset_dir的路径
def _get_dataset_filename(dataset_dir, split_name, shard_id):output_filename = 'cat_vs_dog_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)return os.path.join(dataset_dir, output_filename)# 传入分割名'train'/'validation',相应的文件相对路径的列表,类名到类id的映射字典,tfrecord存放的目录dataset_dir
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):"""Converts the given filenames to a TFRecord dataset.Args:split_name: The name of the dataset, either 'train' or 'validation'.filenames: A list of absolute paths to png or jpg images.class_names_to_ids: A dictionary from class names (strings) to ids(integers).dataset_dir: The directory where the converted datasets are stored."""assert split_name in ['train', 'validation']num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))with tf.Graph().as_default():image_reader = ImageReader()with tf.Session('') as sess:for shard_id in range(_NUM_SHARDS):output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:start_ndx = shard_id * num_per_shardend_ndx = min((shard_id+1) * num_per_shard, len(filenames))for i in range(start_ndx, end_ndx):sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i+1, len(filenames), shard_id))sys.stdout.flush()# Read the filename:image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()height, width = image_reader.read_image_dims(sess, image_data)# 根据文件名中是以cat还是dog开始,判断其类别名# class_name = os.path.basename(os.path.dirname(filenames[i]))fname = os.path.basename(filenames[i])class_name = fname.split('.')[0]class_id = class_names_to_ids[class_name]print('\r>> File %s of class_name: %s and class_id: %d will be write into tfrecord.'%(filenames[i], class_name, class_id))example = dataset_utils.image_to_tfexample(image_data, b'jpg', height, width, class_id)tfrecord_writer.write(example.SerializeToString())sys.stdout.write('\n')sys.stdout.flush()def _clean_up_temporary_files(dataset_dir):"""Removes temporary files used to create the dataset.Args:dataset_dir: The directory where the temporary files are stored."""filename = _DATA_URL.split('/')[-1]filepath = os.path.join(dataset_dir, filename)tf.gfile.Remove(filepath)tmp_dir = os.path.join(dataset_dir, 'flower_photos')tf.gfile.DeleteRecursively(tmp_dir)def _dataset_exists(dataset_dir):for split_name in ['train', 'validation']:for shard_id in range(_NUM_SHARDS):output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)if not tf.gfile.Exists(output_filename):return Falsereturn Truedef run(dataset_dir):"""Runs the download and conversion operation.Args:dataset_dir: The dataset directory where the dataset is stored."""if not tf.gfile.Exists(dataset_dir):tf.gfile.MakeDirs(dataset_dir)if _dataset_exists(dataset_dir):print('Dataset files already exist. Exiting without re-creating them.')return# dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)class_names_to_ids = dict(zip(class_names, range(len(class_names))))# Divide into train and test:random.seed(_RANDOM_SEED)random.shuffle(photo_filenames)_NUM_VALIDATION = int(0.3*(len(photo_filenames))) # 这里按0.3分配训练和验证数training_filenames = photo_filenames[_NUM_VALIDATION:]validation_filenames = photo_filenames[:_NUM_VALIDATION]# First, convert the training and validation sets._convert_dataset('train', training_filenames, class_names_to_ids,dataset_dir)_convert_dataset('validation', validation_filenames, class_names_to_ids,dataset_dir)# Finally, write the labels file:labels_to_class_names = dict(zip(range(len(class_names)), class_names))dataset_utils.write_label_file(labels_to_class_names, dataset_dir)# _clean_up_temporary_files(dataset_dir)print('\nFinished converting the cat_vs_dog dataset!')if __name__ == "__main__":run('cvd')

在datasets目录下运行download_and_convert_flowers.py文件,将在slim/datasets/cvd目录下生成5个train-*.tfrecord和5个validation-*.tfrecord文件。

数据集验证

注意,在上一步数据集制作过程中,

_convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir)

该函数调用了一个函数:

example = dataset_utils.image_to_tfexample(image_data, b'jpg', height, width, class_id)tfrecord_writer.write(example.SerializeToString())

该函数的定义为:

def image_to_tfexample(image_data, image_format, height, width, class_id):return tf.train.Example(features=tf.train.Features(feature={'image/encoded': bytes_feature(image_data),'image/format': bytes_feature(image_format),'image/class/label': int64_feature(class_id),'image/height': int64_feature(height),'image/width': int64_feature(width),}))

该函数明确了tfrecord中每一条tfexample记录的键值签名。后续在数据读取过程中,也应当按照此键值进行解码。

TODO:读取tfrecord,解码tfexample,并重新显示图片及其类别,以验证tfrecord文件的正确性。

数据集注册(引入)

在slim/datasets/目录下新建cvd_v1_0.py(可从flowers.py修改而来)。代码如下:

"""Provides data for the flowers dataset.The dataset scripts used to create the dataset can be found at:
tensorflow/models/research/slim/datasets/download_and_convert_flowers.py
"""from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport os
import tensorflow as tffrom datasets import dataset_utilsslim = tf.contrib.slim_FILE_PATTERN = 'cat_vs_dog_%s_*.tfrecord'SPLITS_TO_SIZES = {'train': 17500, 'validation': 7500}_NUM_CLASSES = 2_ITEMS_TO_DESCRIPTIONS = {'image': 'A color image of varying size.','label': 'A single integer between 0 and 4',
}def get_split(split_name, dataset_dir, file_pattern=None, reader=None):"""Gets a dataset tuple with instructions for reading flowers.Args:split_name: A train/validation split name.dataset_dir: The base directory of the dataset sources.file_pattern: The file pattern to use when matching the dataset sources.It is assumed that the pattern contains a '%s' string so that the splitname can be inserted.reader: The TensorFlow reader type.Returns:A `Dataset` namedtuple.Raises:ValueError: if `split_name` is not a valid train/validation split."""if split_name not in SPLITS_TO_SIZES:raise ValueError('split name %s was not recognized.' % split_name)# if not file_pattern:#   file_pattern = _FILE_PATTERNfile_pattern = _FILE_PATTERNfile_pattern = os.path.join(dataset_dir, file_pattern % split_name)# Allowing None in the signature so that dataset_factory can use the default.if reader is None:reader = tf.TFRecordReaderkeys_to_features = {'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),}items_to_handlers = {'image': slim.tfexample_decoder.Image(),'label': slim.tfexample_decoder.Tensor('image/class/label'),}decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)labels_to_names = Noneif dataset_utils.has_labels(dataset_dir):labels_to_names = dataset_utils.read_label_file(dataset_dir)return slim.dataset.Dataset(data_sources=file_pattern,reader=reader,decoder=decoder,num_samples=SPLITS_TO_SIZES[split_name],items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,num_classes=_NUM_CLASSES,labels_to_names=labels_to_names)

该文件主要实现了一个函数,该函数返回一个slim.dataset.Dataset()类,供调用。

def get_split(split_name, dataset_dir, file_pattern=None, reader=None)

注册数据集

主要修改dataset_factory.py文件,添加cvd数据集从代码到文件的路线。dataset_name='cvd',并映射到相应的程序数据集slim.dataset.Dataset()生成文件cvd_v1_0.py。

"""A factory-pattern class which returns classification image/label pairs."""from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionfrom datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import cvd_v1_0datasets_map = {'cifar10': cifar10,'flowers': flowers,'imagenet': imagenet,'mnist': mnist,'cvd': cvd_v1_0,
}def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):"""Given a dataset name and a split_name returns a Dataset.Args:name: String, the name of the dataset.split_name: A train/test split name.dataset_dir: The directory where the dataset files are stored.file_pattern: The file pattern to use for matching the dataset source files.reader: The subclass of tf.ReaderBase. If left as `None`, then the defaultreader defined by each dataset is used.Returns:A `Dataset` class.Raises:ValueError: If the dataset `name` is unknown."""if name not in datasets_map:raise ValueError('Name of dataset unknown %s' % name)return datasets_map[name].get_split(split_name,dataset_dir,file_pattern,reader)

该文件主要实现了函数:

def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None)

实现从dataset_name和split_name到slim.dataset.Dataset()的映射。至此,数据集准备完毕。

修改slim/train_image_classifier.py文件,进行训练。

主要修改

这篇关于使用TF-slim在猫狗大战数据集训练分类器的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/701673

相关文章

Conda与Python venv虚拟环境的区别与使用方法详解

《Conda与Pythonvenv虚拟环境的区别与使用方法详解》随着Python社区的成长,虚拟环境的概念和技术也在不断发展,:本文主要介绍Conda与Pythonvenv虚拟环境的区别与使用... 目录前言一、Conda 与 python venv 的核心区别1. Conda 的特点2. Python v

Spring Boot中WebSocket常用使用方法详解

《SpringBoot中WebSocket常用使用方法详解》本文从WebSocket的基础概念出发,详细介绍了SpringBoot集成WebSocket的步骤,并重点讲解了常用的使用方法,包括简单消... 目录一、WebSocket基础概念1.1 什么是WebSocket1.2 WebSocket与HTTP

C#中Guid类使用小结

《C#中Guid类使用小结》本文主要介绍了C#中Guid类用于生成和操作128位的唯一标识符,用于数据库主键及分布式系统,支持通过NewGuid、Parse等方法生成,感兴趣的可以了解一下... 目录前言一、什么是 Guid二、生成 Guid1. 使用 Guid.NewGuid() 方法2. 从字符串创建

Python使用python-can实现合并BLF文件

《Python使用python-can实现合并BLF文件》python-can库是Python生态中专注于CAN总线通信与数据处理的强大工具,本文将使用python-can为BLF文件合并提供高效灵活... 目录一、python-can 库:CAN 数据处理的利器二、BLF 文件合并核心代码解析1. 基础合

Python使用OpenCV实现获取视频时长的小工具

《Python使用OpenCV实现获取视频时长的小工具》在处理视频数据时,获取视频的时长是一项常见且基础的需求,本文将详细介绍如何使用Python和OpenCV获取视频时长,并对每一行代码进行深入解析... 目录一、代码实现二、代码解析1. 导入 OpenCV 库2. 定义获取视频时长的函数3. 打开视频文

Spring IoC 容器的使用详解(最新整理)

《SpringIoC容器的使用详解(最新整理)》文章介绍了Spring框架中的应用分层思想与IoC容器原理,通过分层解耦业务逻辑、数据访问等模块,IoC容器利用@Component注解管理Bean... 目录1. 应用分层2. IoC 的介绍3. IoC 容器的使用3.1. bean 的存储3.2. 方法注

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互