使用TensorFlow Object Detection API进行红绿灯检测

2024-02-12 04:32

本文主要是介绍使用TensorFlow Object Detection API进行红绿灯检测,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

项目目录结构

本文中未明确说明的情况下,所使用的路径均在./research目录下。

  • research
    • object detection
    • datasets
      • my_traffic_light (参照Pascal VOC目录结构)
        • Annotations
        • ImageSets
        • JPEGImages
        • SegmentationClass
        • SegmentationObject
        • tfrecord
          ***.tfrecord
          ***.pbtxt
    • ssd_traffic_light_detection
      • ssd_traffic_light_detection_model
        • saved_model
          • variables
            saved_model.pb
            pipeline.config
            model.ckpt.meta / index / data-00000-of-00001
            frozen_inference_graph.pb
            checkpoint
      • train (主要存放用于启动训练的一些文件,和训练中间文件)
        • export
        • eval_0
          train_cmd.sh (存放一些会用到的训练命令等)
          model.ckpt-*****.meta
          model.ckpt-*****.index
          model.ckpt-*****.data-00000-of-00001
          graph.pbtxt
          events.out.tfevents.*****
          model_name_datasets.config
          pipeline.config
          checkpoint

数据集制作

图像采集

使用华为手机拍摄视频,存为*.mp4文件。

提取图像

extract_images_from_video
测试读取视频文件,查看文件的FPS/H/W和总帧数。

import cv2
import os
video_path = './JPEGImages/VID_20200419_122755.mp4'
output_dir = './JPEGImages/VID_20200419_122755'
if not os.path.exists(output_dir):os.mkdir(output_dir)cap = cv2.VideoCapture(video_path)
success, frame = cap.read()
fps = cap.get(cv2.CAP_PROP_FPS)
n_frame = cap.get(cv2.CAP_PROP_FRAME_COUNT)
h_frame = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
w_frame = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
print('The video propertities is: fps={}, height={}, width={}, and has {} frames.'.format(fps, h_frame, w_frame, n_frame))

提取图片到视频文件夹下,提取的图片存放到以视频文件名为名的文件夹下。

def extract_images_from_video(video_path):video_name = os.path.basename(video_path).split('.')[0] # 得到视频名字,不含后缀output_dir = os.path.join(os.path.dirname(video_path), video_name)if not os.path.exists(output_dir):os.mkdir(output_dir)cameraCapture = cv2.VideoCapture(video_path)success, frame = cameraCapture.read()idx = 0n_sels = 0while success:idx += 1if idx%45==0: # 每45张图片选取一张n_sels += 1frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)frame_name = "{0}_{1:0>5d}.jpg".format(video_name, n_sels)frame_saved_path = os.path.join(output_dir, frame_name)cv2.imwrite(frame_saved_path, frame)success, frame = cameraCapture.read()cameraCapture.release()print("Finished extract images from {}".format(video_name))import glob
video_files = "./JPEGImages/VID_20200419_*.mp4"
video_filepaths = glob.glob(video_files)
print(video_filepaths)
for path in video_filepaths:extract_images_from_video(path)

图像标注

训练

模型导出

进行推理

推理文件

导入包

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfilefrom distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("../../")
from object_detection.utils import ops as utils_opsif StrictVersion(tf.__version__) < StrictVersion('1.9.0'):raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')# This is needed to display the images.
%matplotlib inline
from utils import label_map_util
from utils import visualization_utils as vis_util
# What model to download.
MODEL_NAME = 'my_traffic_light_detection'
# MODEL_FILE = MODEL_NAME + '.tar.gz'
MODEL_DIR = './model'
# DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_FROZEN_GRAPH = MODEL_DIR + '/frozen_inference_graph.pb'# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('./dataset', 'traffic_light_label_map.pbtxt')

导入计算图

detection_graph = tf.Graph()
with detection_graph.as_default():od_graph_def = tf.GraphDef()with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:serialized_graph = fid.read()od_graph_def.ParseFromString(serialized_graph)tf.import_graph_def(od_graph_def, name='')ops = tf.get_default_graph().get_operations()all_tensor_names = {output.name for op in ops for output in op.outputs}print(all_tensor_names)category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
print(category_index)def load_image_into_numpy_array(image):(im_width, im_height) = image.sizereturn np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
import glob
PATH_TO_TEST_IMAGES_DIR = './test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(0, 10) ]
# TEST_IMAGE_PATHS = glob.glob("./test_images/*.jpg")
print(TEST_IMAGE_PATHS)
# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)
def run_inference_for_single_image(image, graph):with graph.as_default():with tf.Session() as sess:# Get handles to input and output tensorsops = tf.get_default_graph().get_operations()all_tensor_names = {output.name for op in ops for output in op.outputs}tensor_dict = {}for key in ['num_detections', 'detection_boxes', 'detection_scores','detection_classes', 'detection_masks']:tensor_name = key + ':0'if tensor_name in all_tensor_names:tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)if 'detection_masks' in tensor_dict:# The following processing is only for single imagedetection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])# Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(detection_masks, detection_boxes, image.shape[0], image.shape[1])detection_masks_reframed = tf.cast(tf.greater(detection_masks_reframed, 0.5), tf.uint8)# Follow the convention by adding back the batch dimensiontensor_dict['detection_masks'] = tf.expand_dims(detection_masks_reframed, 0)image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')# Run inferenceoutput_dict = sess.run(tensor_dict,feed_dict={image_tensor: np.expand_dims(image, 0)})# all outputs are float32 numpy arrays, so convert types as appropriateoutput_dict['num_detections'] = int(output_dict['num_detections'][0])output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.uint8)output_dict['detection_boxes'] = output_dict['detection_boxes'][0]output_dict['detection_scores'] = output_dict['detection_scores'][0]if 'detection_masks' in output_dict:output_dict['detection_masks'] = output_dict['detection_masks'][0]return output_dict
import cv2
for image_path in TEST_IMAGE_PATHS:image = Image.open(image_path)# the array based representation of the image will be used later in order to prepare the# result image with boxes and labels on it.image_np = load_image_into_numpy_array(image)# Expand dimensions since the model expects images to have shape: [1, None, None, 3]image_np_expanded = np.expand_dims(image_np, axis=0)# Actual detection.output_dict = run_inference_for_single_image(image_np, detection_graph)print(output_dict)# Visualization of the results of a detection.vis_util.visualize_boxes_and_labels_on_image_array(image_np,output_dict['detection_boxes'],output_dict['detection_classes'],output_dict['detection_scores'],category_index,instance_masks=output_dict.get('detection_masks'),use_normalized_coordinates=True,line_thickness=4)image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
#   cv2.imshow('image',image_np)
#   cv2.waitKey(10)
#   cv2.destroyAllWindows()
#   if cv2.waitKey(1000)&0xff == 113:
# cv2.destroyAllWindows()
#   plt.figure(figsize=IMAGE_SIZE)
#   plt.imshow(image_np)
# plt.show()

这篇关于使用TensorFlow Object Detection API进行红绿灯检测的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/701676

相关文章

Spring IoC 容器的使用详解(最新整理)

《SpringIoC容器的使用详解(最新整理)》文章介绍了Spring框架中的应用分层思想与IoC容器原理,通过分层解耦业务逻辑、数据访问等模块,IoC容器利用@Component注解管理Bean... 目录1. 应用分层2. IoC 的介绍3. IoC 容器的使用3.1. bean 的存储3.2. 方法注

Python内置函数之classmethod函数使用详解

《Python内置函数之classmethod函数使用详解》:本文主要介绍Python内置函数之classmethod函数使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录1. 类方法定义与基本语法2. 类方法 vs 实例方法 vs 静态方法3. 核心特性与用法(1编程客

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

Python中注释使用方法举例详解

《Python中注释使用方法举例详解》在Python编程语言中注释是必不可少的一部分,它有助于提高代码的可读性和维护性,:本文主要介绍Python中注释使用方法的相关资料,需要的朋友可以参考下... 目录一、前言二、什么是注释?示例:三、单行注释语法:以 China编程# 开头,后面的内容为注释内容示例:示例:四

Go语言数据库编程GORM 的基本使用详解

《Go语言数据库编程GORM的基本使用详解》GORM是Go语言流行的ORM框架,封装database/sql,支持自动迁移、关联、事务等,提供CRUD、条件查询、钩子函数、日志等功能,简化数据库操作... 目录一、安装与初始化1. 安装 GORM 及数据库驱动2. 建立数据库连接二、定义模型结构体三、自动迁

ModelMapper基本使用和常见场景示例详解

《ModelMapper基本使用和常见场景示例详解》ModelMapper是Java对象映射库,支持自动映射、自定义规则、集合转换及高级配置(如匹配策略、转换器),可集成SpringBoot,减少样板... 目录1. 添加依赖2. 基本用法示例:简单对象映射3. 自定义映射规则4. 集合映射5. 高级配置匹

Spring 框架之Springfox使用详解

《Spring框架之Springfox使用详解》Springfox是Spring框架的API文档工具,集成Swagger规范,自动生成文档并支持多语言/版本,模块化设计便于扩展,但存在版本兼容性、性... 目录核心功能工作原理模块化设计使用示例注意事项优缺点优点缺点总结适用场景建议总结Springfox 是

嵌入式数据库SQLite 3配置使用讲解

《嵌入式数据库SQLite3配置使用讲解》本文强调嵌入式项目中SQLite3数据库的重要性,因其零配置、轻量级、跨平台及事务处理特性,可保障数据溯源与责任明确,详细讲解安装配置、基础语法及SQLit... 目录0、惨痛教训1、SQLite3环境配置(1)、下载安装SQLite库(2)、解压下载的文件(3)、

Golang如何对cron进行二次封装实现指定时间执行定时任务

《Golang如何对cron进行二次封装实现指定时间执行定时任务》:本文主要介绍Golang如何对cron进行二次封装实现指定时间执行定时任务问题,具有很好的参考价值,希望对大家有所帮助,如有错误... 目录背景cron库下载代码示例【1】结构体定义【2】定时任务开启【3】使用示例【4】控制台输出总结背景