PyTorch随笔 - 获取TensorRT(TRT)模型输入和输出

2024-02-29 03:20

本文主要是介绍PyTorch随笔 - 获取TensorRT(TRT)模型输入和输出,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

获取TensorRT(TRT)模型输入和输出,用于创建TRT的模型服务使用,具体参考脚本check_trt_script.py,如下:

  • 脚本输入:TRT的模型路径和输入图像尺寸
  • 脚本输出:模型的输入和输出结点信息,同时验证TRT模型是否可用
#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2021. All rights reserved.
Created by C. L. Wang on 16.9.21
"""import argparseimport numpy as npdef check_trt(model_path, image_size):"""检查TRT模型"""import pycuda.driver as cudaimport tensorrt as trt# 必须导入包,import pycuda.autoinit,否则报错import pycuda.autoinitprint('[Info] model_path: {}'.format(model_path))img_shape = (1, 3, image_size, image_size)print('[Info] img_shape: {}'.format(img_shape))trt_logger = trt.Logger(trt.Logger.WARNING)trt_path = model_path  # TRT模型路径with open(trt_path, 'rb') as f, trt.Runtime(trt_logger) as runtime:engine = runtime.deserialize_cuda_engine(f.read())for binding in engine:binding_idx = engine.get_binding_index(binding)size = engine.get_binding_shape(binding_idx)dtype = trt.nptype(engine.get_binding_dtype(binding))print("[Info] binding: {}, binding_idx: {}, size: {}, dtype: {}".format(binding, binding_idx, size, dtype))input_image = np.random.randn(*img_shape).astype(np.float32)  # 图像尺寸input_image = np.ascontiguousarray(input_image)print('[Info] input_image: {}'.format(input_image.shape))with engine.create_execution_context() as context:stream = cuda.Stream()bindings = [0] * len(engine)for binding in engine:idx = engine.get_binding_index(binding)if engine.binding_is_input(idx):input_memory = cuda.mem_alloc(input_image.nbytes)bindings[idx] = int(input_memory)cuda.memcpy_htod_async(input_memory, input_image, stream)else:dtype = trt.nptype(engine.get_binding_dtype(binding))shape = context.get_binding_shape(idx)output_buffer = np.empty(shape, dtype=dtype)output_buffer = np.ascontiguousarray(output_buffer)output_memory = cuda.mem_alloc(output_buffer.nbytes)bindings[idx] = int(output_memory)context.execute_async_v2(bindings, stream.handle)stream.synchronize()cuda.memcpy_dtoh(output_buffer, output_memory)print("[Info] output_buffer: {}".format(output_buffer))def parse_args():"""处理脚本参数"""parser = argparse.ArgumentParser(description='检查TRT模型')parser.add_argument('-m', dest='model_path', required=True, help='TRT模型路径', type=str)parser.add_argument('-s', dest='image_size', required=False, help='图像尺寸,如336', type=int, default=336)args = parser.parse_args()arg_model_path = args.model_pathprint("[Info] 模型路径: {}".format(arg_model_path))arg_image_size = args.image_sizeprint("[Info] image_size: {}".format(arg_image_size))return arg_model_path, arg_image_sizedef main():arg_model_path, arg_image_size = parse_args()check_trt(arg_model_path, arg_image_size)  # 检查TRT模型if __name__ == '__main__':main()

注意:必须导入包,import pycuda.autoinit,否则cuda.Stream()报错,如下:
image-20210916162952425

输出信息如下:

[Info] 模型路径: ../mydata/trt_models/model_best_c2_20210915_cuda.trt
[Info] image_size: 336
[Info] model_path: ../mydata/trt_models/model_best_c2_20210915_cuda.trt
[Info] img_shape: (1, 3, 336, 336)
[Info] binding: input_0, binding_idx: 0, size: (1, 3, 336, 336), dtype: <class 'numpy.float32'>
[Info] binding: output_0, binding_idx: 1, size: (1, 2), dtype: <class 'numpy.float32'>
[Info] input_image: (1, 3, 336, 336)
[Info] output_buffer: [[ 0.23275298 -0.2184143 ]]

有效信息为:

  • 输入结点binding: input_0,输入尺寸size: (1, 3, 336, 336),输入类型dtype: <class 'numpy.float32'>
  • 输出结果binding: output_0,输出尺寸size: (1, 2),输出类型dtype: <class 'numpy.float32'>

相应的json文件如下:

{"model_path": "model_best_c2_20210915_cuda.trt","model_format": "trt","quant_type": "FP32","gpu_index": 0,"inputs": {"input_0": {"shapes": [1,3,336,336],"type": "FP32"}},"outputs": {"output_0": {"shapes": [1,2],"type": "FP32"}}
}

这篇关于PyTorch随笔 - 获取TensorRT(TRT)模型输入和输出的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/757429

相关文章

python获取cmd环境变量值的实现代码

《python获取cmd环境变量值的实现代码》:本文主要介绍在Python中获取命令行(cmd)环境变量的值,可以使用标准库中的os模块,需要的朋友可以参考下... 前言全局说明在执行py过程中,总要使用到系统环境变量一、说明1.1 环境:Windows 11 家庭版 24H2 26100.4061

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

Python使用pynput模拟实现键盘自动输入工具

《Python使用pynput模拟实现键盘自动输入工具》在日常办公和软件开发中,我们经常需要处理大量重复的文本输入工作,所以本文就来和大家介绍一款使用Python的PyQt5库结合pynput键盘控制... 目录概述:当自动化遇上可视化功能全景图核心功能矩阵技术栈深度效果展示使用教程四步操作指南核心代码解析

使用Python获取JS加载的数据的多种实现方法

《使用Python获取JS加载的数据的多种实现方法》在当今的互联网时代,网页数据的动态加载已经成为一种常见的技术手段,许多现代网站通过JavaScript(JS)动态加载内容,这使得传统的静态网页爬取... 目录引言一、动态 网页与js加载数据的原理二、python爬取JS加载数据的方法(一)分析网络请求1

通过cmd获取网卡速率的代码

《通过cmd获取网卡速率的代码》今天从群里看到通过bat获取网卡速率两段代码,感觉还不错,学习bat的朋友可以参考一下... 1、本机有线网卡支持的最高速度:%v%@echo off & setlocal enabledelayedexpansionecho 代码开始echo 65001编码获取: >

使用Python实现调用API获取图片存储到本地的方法

《使用Python实现调用API获取图片存储到本地的方法》开发一个自动化工具,用于从JSON数据源中提取图像ID,通过调用指定API获取未经压缩的原始图像文件,并确保下载结果与Postman等工具直接... 目录使用python实现调用API获取图片存储到本地1、项目概述2、核心功能3、环境准备4、代码实现

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不