tensorRT C++使用pt转engine模型进行推理

2024-06-23 17:12

本文主要是介绍tensorRT C++使用pt转engine模型进行推理,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • 1. 前言
  • 2. 模型转换
  • 3. 修改Binding
  • 4. 修改后处理

1. 前言

本文不讲tensorRT的推理流程,因为这种文章很多,这里着重讲从标准yolov5的tensort推理代码(模型转pt->wts->engine)改造成TPH-yolov5(pt->onnx->engine)的过程。

2. 模型转换

请查看上一篇文章https://blog.csdn.net/wyw0000/article/details/139737473?spm=1001.2014.3001.5502

3. 修改Binding

如果不修改Binding,会报下图中的错误。
在这里插入图片描述
该问题是由于Binding有多个,而代码中只申请了input和output,那么如何查看engine模型有几个Bingding呢?代码如下:

int get_model_info(const string& model_path) {// 创建 loggerLogger gLogger;// 从文件中读取 enginestd::ifstream engineFile(model_path, std::ios::binary);if (!engineFile) {std::cerr << "Failed to open engine file." << std::endl;return -1;}engineFile.seekg(0, engineFile.end);long int fsize = engineFile.tellg();engineFile.seekg(0, engineFile.beg);std::vector<char> engineData(fsize);engineFile.read(engineData.data(), fsize);if (!engineFile) {std::cerr << "Failed to read engine file." << std::endl;return -1;}// 反序列化 engineauto runtime = nvinfer1::createInferRuntime(gLogger);auto engine = runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr);// 获取并打印输入和输出绑定信息for (int i = 0; i < engine->getNbBindings(); ++i) {nvinfer1::Dims dims = engine->getBindingDimensions(i);nvinfer1::DataType type = engine->getBindingDataType(i);std::cout << "Binding " << i << " (" << engine->getBindingName(i) << "):" << std::endl;std::cout << "  Type: " << (int)type << std::endl;std::cout << "  Dimensions: ";for (int j = 0; j < dims.nbDims; ++j) {std::cout << (j ? "x" : "") << dims.d[j];}std::cout << std::endl;std::cout << "  Is Input: " << (engine->bindingIsInput(i) ? "Yes" : "No") << std::endl;}// 清理资源engine->destroy();runtime->destroy();return 0;
}

下图是我的tph-yolov5的Binding,可以看到有5个Binding,因此在doInference推理之前,要给5个Binding都申请空间,同时要注意获取BindingIndex时,名称和dimension与查询出来的对应。
在这里插入图片描述

//for tph-yolov5int Sigmoid_921_index = trt->engine->getBindingIndex("onnx::Sigmoid_921");int Sigmoid_1183_index = trt->engine->getBindingIndex("onnx::Sigmoid_1183");int Sigmoid_1367_index = trt->engine->getBindingIndex("onnx::Sigmoid_1367");CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_921_index], BATCH_SIZE * 3 * 192 * 192 * 7 * sizeof(float)));CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_1183_index], BATCH_SIZE * 3 * 96 * 96 * 7 * sizeof(float)));CUDA_CHECK(cudaMalloc(&trt->buffers[Sigmoid_1367_index], BATCH_SIZE * 3 * 48 * 48 * 7 * sizeof(float)));trt->data = new float[BATCH_SIZE * 3 * INPUT_H * INPUT_W];trt->prob = new float[BATCH_SIZE * OUTPUT_SIZE];trt->inputIndex = trt->engine->getBindingIndex(INPUT_BLOB_NAME);trt->outputIndex = trt->engine->getBindingIndex(OUTPUT_BLOB_NAME);

还有推理的部分也要做修改,原来只有input和output两个Binding时,那么输出是buffers[1],而目前是有5个Binding那么输出就变成了buffers[4]

void doInference(IExecutionContext& context, cudaStream_t& stream, void **buffers, float* output, int batchSize) {// infer on the batch asynchronously, and DMA output back to hostcontext.enqueueV2(buffers, stream, nullptr);//CUDA_CHECK(cudaMemcpyAsync(output, buffers[1], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));CUDA_CHECK(cudaMemcpyAsync(output, buffers[4], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));cudaStreamSynchronize(stream);
}

4. 修改后处理

之前的yolov5推理代码是将pt模型转为wts再转为engine的,输出维度只有一维,而TPH输出维度为145152*7,因此要对原来的后处理代码进行修改。

struct BoundingBox {//bbox[0],bbox[1],bbox[2],bbox[3],conf, class_idfloat x1, y1, x2, y2, score, index;
};float iou(const BoundingBox&  box1, const BoundingBox& box2) {float max_x = max(box1.x1, box2.x1);  // 找出左上角坐标哪个大float min_x = min(box1.x2, box2.x2);  // 找出右上角坐标哪个小float max_y = max(box1.y1, box2.y1);float min_y = min(box1.y2, box2.y2);if (min_x <= max_x || min_y <= max_y) // 如果没有重叠return 0;float over_area = (min_x - max_x) * (min_y - max_y);  // 计算重叠面积float area_a = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);float area_b = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);float iou = over_area / (area_a + area_b - over_area);return iou;
}std::vector<BoundingBox> nonMaximumSuppression(std::vector<std::vector<float>>& boxes, float overlapThreshold) {std::vector<BoundingBox> convertedBoxes;// 将数据转换为BoundingBox结构体for (const auto&  box: boxes) {if (box.size() == 6) { // Assuming [x1, y1, x2, y2, score]BoundingBox bbox;bbox.x1 = box[0];bbox.y1 = box[1];bbox.x2 = box[2];bbox.y2 = box[3];bbox.score = box[4];bbox.index = box[5];convertedBoxes.push_back(bbox);}else {std::cerr << "Invalid box format!" << std::endl;}}// 对框按照分数降序排序std::sort(convertedBoxes.begin(), convertedBoxes.end(), [](const BoundingBox& a, const BoundingBox&  b) {return a.score > b.score;});// 非最大抑制std::vector<BoundingBox> result;std::vector<bool> isSuppressed(convertedBoxes.size(), false);for (size_t i = 0; i < convertedBoxes.size(); ++i) {if (!isSuppressed[i]) {result.push_back(convertedBoxes[i]);for (size_t j = i + 1; j < convertedBoxes.size(); ++j) {if (!isSuppressed[j]) {float overlap = iou(convertedBoxes[i], convertedBoxes[j]);if (overlap > overlapThreshold) {isSuppressed[j] = true;}}}}}
#if 0// 输出结果std::cout << "NMS Result:" << std::endl;for (const auto& box: result) {std::cout << "x1: " << box.x1 << ", y1: " << box.y1<< ", x2: " << box.x2 << ", y2: " << box.y2<< ", score: " << box.score << ",index:" << box.index << std::endl;}
#endif return result;
}void post_process(float *prob_model, float conf_thres, float overlapThreshold, std::vector<Yolo::Detection> & detResult)
{int cols = 7, rows = 145152;//  ========== 8. 获取推理结果 =========std::vector<std::vector<float>> prediction(rows, std::vector<float>(cols));int index = 0;for (int i = 0; i < rows; ++i) {for (int j = 0; j < cols; ++j) {prediction[i][j] = prob_model[index++];}}//  ========== 9. 大于conf_thres加入xc =========std::vector<std::vector<float>> xc;for (const auto& row : prediction) {if (row[4] > conf_thres) {xc.push_back(row);}}//  ========== 10. 置信度 = obj_conf * cls_conf =========//std::cout << xc[0].size() << std::endl;for (auto& row: xc) {for (int i = 5; i < xc[0].size(); i++) {row[i] *= row[4];}}// ========== 11. 切片取出xywh 转为xyxy=========std::vector<std::vector<float>> xywh;for (const auto& row: xc) {std::vector<float> sliced_row(row.begin(), row.begin() + 4);xywh.push_back(sliced_row);}std::vector<std::vector<float>> box(xywh.size(), std::vector<float>(4, 0.0));xywhtoxxyy(xywh, box);// ========== 12. 获取置信度最高的类别和索引=========std::size_t mi = xc[0].size();std::vector<float> conf(xc.size(), 0.0);std::vector<float> j(xc.size(), 0.0);for (std::size_t i = 0; i < xc.size(); ++i) {// 模拟切片操作 x[:, 5:mi]auto sliced_x = std::vector<float>(xc[i].begin() + 5, xc[i].begin() + mi);// 计算 maxauto max_it = std::max_element(sliced_x.begin(), sliced_x.end());// 获取 max 的索引std::size_t max_index = std::distance(sliced_x.begin(), max_it);// 将 max 的值和索引存储到相应的向量中conf[i] = *max_it;j[i] = max_index;  // 加上切片的起始索引}// ========== 13. concat x1, y1, x2, y2, score, index;======== =for (int i = 0; i < xc.size(); i++) {box[i].push_back(conf[i]);box[i].push_back(j[i]);}std::vector<std::vector<float>> output;for (int i = 0; i < xc.size(); i++) {output.push_back(box[i]); // 创建一个空的 float 向量并}// ==========14 应用非最大抑制 ==========std::vector<BoundingBox>  result = nonMaximumSuppression(output, overlapThreshold);for (const auto& r : result){Yolo::Detection det;det.bbox[0] = r.x1;det.bbox[1] = r.y1;det.bbox[2] = r.x2;det.bbox[3] = r.y2;det.conf = r.score;det.class_id = r.index;detResult.push_back(det);}}

代码参考:
https://blog.csdn.net/rooftopstars/article/details/136771496
https://blog.csdn.net/qq_73794703/article/details/132147879

这篇关于tensorRT C++使用pt转engine模型进行推理的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1087768

相关文章

Linux中压缩、网络传输与系统监控工具的使用完整指南

《Linux中压缩、网络传输与系统监控工具的使用完整指南》在Linux系统管理中,压缩与传输工具是数据备份和远程协作的桥梁,而系统监控工具则是保障服务器稳定运行的眼睛,下面小编就来和大家详细介绍一下它... 目录引言一、压缩与解压:数据存储与传输的优化核心1. zip/unzip:通用压缩格式的便捷操作2.

使用Python实现可恢复式多线程下载器

《使用Python实现可恢复式多线程下载器》在数字时代,大文件下载已成为日常操作,本文将手把手教你用Python打造专业级下载器,实现断点续传,多线程加速,速度限制等功能,感兴趣的小伙伴可以了解下... 目录一、智能续传:从崩溃边缘抢救进度二、多线程加速:榨干网络带宽三、速度控制:做网络的好邻居四、终端交互

Python中注释使用方法举例详解

《Python中注释使用方法举例详解》在Python编程语言中注释是必不可少的一部分,它有助于提高代码的可读性和维护性,:本文主要介绍Python中注释使用方法的相关资料,需要的朋友可以参考下... 目录一、前言二、什么是注释?示例:三、单行注释语法:以 China编程# 开头,后面的内容为注释内容示例:示例:四

从入门到精通C++11 <chrono> 库特性

《从入门到精通C++11<chrono>库特性》chrono库是C++11中一个非常强大和实用的库,它为时间处理提供了丰富的功能和类型安全的接口,通过本文的介绍,我们了解了chrono库的基本概念... 目录一、引言1.1 为什么需要<chrono>库1.2<chrono>库的基本概念二、时间段(Durat

C++20管道运算符的实现示例

《C++20管道运算符的实现示例》本文简要介绍C++20管道运算符的使用与实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧... 目录标准库的管道运算符使用自己实现类似的管道运算符我们不打算介绍太多,因为它实际属于c++20最为重要的

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

Go语言数据库编程GORM 的基本使用详解

《Go语言数据库编程GORM的基本使用详解》GORM是Go语言流行的ORM框架,封装database/sql,支持自动迁移、关联、事务等,提供CRUD、条件查询、钩子函数、日志等功能,简化数据库操作... 目录一、安装与初始化1. 安装 GORM 及数据库驱动2. 建立数据库连接二、定义模型结构体三、自动迁

ModelMapper基本使用和常见场景示例详解

《ModelMapper基本使用和常见场景示例详解》ModelMapper是Java对象映射库,支持自动映射、自定义规则、集合转换及高级配置(如匹配策略、转换器),可集成SpringBoot,减少样板... 目录1. 添加依赖2. 基本用法示例:简单对象映射3. 自定义映射规则4. 集合映射5. 高级配置匹

Spring 框架之Springfox使用详解

《Spring框架之Springfox使用详解》Springfox是Spring框架的API文档工具,集成Swagger规范,自动生成文档并支持多语言/版本,模块化设计便于扩展,但存在版本兼容性、性... 目录核心功能工作原理模块化设计使用示例注意事项优缺点优点缺点总结适用场景建议总结Springfox 是

嵌入式数据库SQLite 3配置使用讲解

《嵌入式数据库SQLite3配置使用讲解》本文强调嵌入式项目中SQLite3数据库的重要性,因其零配置、轻量级、跨平台及事务处理特性,可保障数据溯源与责任明确,详细讲解安装配置、基础语法及SQLit... 目录0、惨痛教训1、SQLite3环境配置(1)、下载安装SQLite库(2)、解压下载的文件(3)、