Object_Detection_API: Inference

2024-02-12 04:32
Tags: object detection api inference

This post walks through running inference with the TensorFlow Object Detection API. Hopefully it offers a useful reference for developers working on the same problem.

Import packages

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile

from distutils.version import StrictVersion
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("../../")
from object_detection.utils import ops as utils_ops

if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
    raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')

%matplotlib inline
from utils import label_map_util
from utils import visualization_utils as vis_util
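The imports above target the TF 1.x API (tf.GraphDef, tf.gfile, tf.Session). If only TensorFlow 2.x is installed, one common workaround, not part of the original post, is to route the rest of the code through the v1 compatibility layer:

# Hypothetical alternative for TensorFlow 2.x installs: use the v1 compatibility
# layer so the graph/session code below keeps working unchanged.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()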

Set the model directory

MODEL_DIR = './model_from_200000_steps'
PATH_TO_FROZEN_GRAPH = MODEL_DIR + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('./dataset', 'traffic_light_label_map.pbtxt')
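Before loading anything, it is worth confirming that both files actually exist. A minimal sanity check (not part of the original post, assuming the paths above) could be:

# Optional sanity check: fail early if the exported graph or label map is missing.
for path in (PATH_TO_FROZEN_GRAPH, PATH_TO_LABELS):
    if not os.path.isfile(path):
        raise IOError('Expected file not found: {}'.format(path))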

Restore the computation graph

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
    ops = tf.get_default_graph().get_operations()
    all_tensor_names = {output.name for op in ops for output in op.outputs}
    # print(all_tensor_names)
    print(len(all_tensor_names))

# Print only the tensors whose names contain 'bottleneck'.
for name in all_tensor_names:
    # if name.startswith('SecondStage'):
    #     print(name)
    if name.find('bottleneck') == -1:
        pass
    else:
        print(name)
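The printout above is mainly for exploring the graph. As a small additional sketch (assuming the standard tensor names produced by the Object Detection API exporter), the input and output tensors that the inference code relies on can also be checked directly:

# The frozen graphs exported by the Object Detection API normally expose these
# tensors; an empty "missing" list means the graph can be used as-is below.
expected_tensors = ['image_tensor:0', 'num_detections:0', 'detection_boxes:0',
                    'detection_scores:0', 'detection_classes:0']
missing = [name for name in expected_tensors if name not in all_tensor_names]
print('missing tensors:', missing)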

Read the label map

category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS)
print(category_index)

def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)
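image.getdata() is fairly slow on large images. A minimal alternative sketch (the helper name is my own, not from the post), assuming the test images are plain RGB files, lets NumPy do the conversion directly:

def load_image_into_numpy_array_fast(image):
    # A PIL image converts straight to an (H, W, 3) uint8 array for RGB inputs.
    return np.asarray(image.convert('RGB'), dtype=np.uint8)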

Define the inference function

def run_inference_for_single_image(image, graph):
    with graph.as_default():
        with tf.Session() as sess:
            # Get handles to input and output tensors
            ops = tf.get_default_graph().get_operations()
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in ['num_detections', 'detection_boxes', 'detection_scores',
                        'detection_classes', 'detection_masks']:
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)
            if 'detection_masks' in tensor_dict:
                # The following processing is only for single image
                detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
                detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
                # Reframe is required to translate mask from box coordinates to
                # image coordinates and fit the image size.
                real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
                detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
                detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
                detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
                    detection_masks, detection_boxes, image.shape[0], image.shape[1])
                detection_masks_reframed = tf.cast(
                    tf.greater(detection_masks_reframed, 0.5), tf.uint8)
                # Follow the convention by adding back the batch dimension
                tensor_dict['detection_masks'] = tf.expand_dims(detection_masks_reframed, 0)
            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

            # Run inference
            output_dict = sess.run(tensor_dict,
                                   feed_dict={image_tensor: np.expand_dims(image, 0)})

            # all outputs are float32 numpy arrays, so convert types as appropriate
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.uint8)
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
            if 'detection_masks' in output_dict:
                output_dict['detection_masks'] = output_dict['detection_masks'][0]
    return output_dict
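Note that run_inference_for_single_image opens a new tf.Session for every image, which is slow when looping over many files. A hypothetical variant (the helper name is my own, not from the original post) that builds the tensor handles once and reuses a single session could look like this:

def make_inference_fn(graph):
    """Build tensor handles once and reuse a single Session for all images."""
    sess = tf.Session(graph=graph)
    image_tensor = graph.get_tensor_by_name('image_tensor:0')
    tensor_dict = {key: graph.get_tensor_by_name(key + ':0')
                   for key in ['num_detections', 'detection_boxes',
                               'detection_scores', 'detection_classes']}

    def infer(image_np):
        out = sess.run(tensor_dict,
                       feed_dict={image_tensor: np.expand_dims(image_np, 0)})
        # Convert the float32 batch outputs to per-image values, as above.
        out['num_detections'] = int(out['num_detections'][0])
        out['detection_classes'] = out['detection_classes'][0].astype(np.uint8)
        out['detection_boxes'] = out['detection_boxes'][0]
        out['detection_scores'] = out['detection_scores'][0]
        return out

    return infer

With this, infer = make_inference_fn(detection_graph) would be created once and output_dict = infer(image_np) called per image; the loop below still uses the original run_inference_for_single_image to stay faithful to the post.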

Run inference

import random
import cv2
import glob
PATH_TO_TEST_IMAGES_DIR = './images_for_test'
TEST_IMAGE_PATHS = glob.glob(os.path.join(PATH_TO_TEST_IMAGES_DIR, "*.jpg"))
print(len(TEST_IMAGE_PATHS))
IMAGE_SIZE=(12,8)
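The inference loop in the next block picks images with random.randint, so the same file can be drawn more than once. If distinct files are preferred, a small alternative sketch (not from the original post) is:

# Hypothetical alternative: sample up to 10 distinct test images instead.
sampled_paths = random.sample(TEST_IMAGE_PATHS, min(10, len(TEST_IMAGE_PATHS)))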

Inference loop

count = 0
saved_dir_for_predicted_images = "./images_predicted"
# Create the output directory if it does not exist yet.
if not os.path.exists(saved_dir_for_predicted_images):
    os.makedirs(saved_dir_for_predicted_images)
for count in range(10):
    chx = random.randint(0, len(TEST_IMAGE_PATHS) - 1)
    image_path = TEST_IMAGE_PATHS[chx]
    image = Image.open(image_path)
    # the array based representation of the image will be used later in order to prepare the
    # result image with boxes and labels on it.
    image_np = load_image_into_numpy_array(image)
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(image_np, axis=0)
    # Actual detection.
    output_dict = run_inference_for_single_image(image_np, detection_graph)
    # print(output_dict)
    # Visualization of the results of a detection.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        output_dict['detection_boxes'],
        output_dict['detection_classes'],
        output_dict['detection_scores'],
        category_index,
        instance_masks=output_dict.get('detection_masks'),
        use_normalized_coordinates=True,
        line_thickness=4)
    im = Image.fromarray(image_np.astype('uint8'))
    saved_name = os.path.basename(image_path).split('.')[0]
    im.save(os.path.join(saved_dir_for_predicted_images, saved_name + '.jpg'))
#   image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
#   cv2.imshow('image',image_np)
#   cv2.waitKey(10)
#   cv2.destroyAllWindows()
#   if cv2.waitKey(1000)&0xff == 113:
#     cv2.destroyAllWindows()
#   plt.figure(figsize=IMAGE_SIZE)
#   plt.imshow(image_np)
#   plt.show()
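The loop only saves annotated images. As a small follow-up sketch (the 0.5 cutoff is an arbitrary illustration, not a value from the post), the detections of the last processed image can also be listed with their class names from category_index:

# Print detections from the last processed image that clear a score threshold.
score_threshold = 0.5  # arbitrary cutoff for illustration
for box, cls, score in zip(output_dict['detection_boxes'],
                           output_dict['detection_classes'],
                           output_dict['detection_scores']):
    if score >= score_threshold:
        print(category_index[int(cls)]['name'], round(float(score), 3), box)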

That wraps up this post on inference with the Object Detection API. Hopefully it is of some help.



