图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)

2024-03-26 12:20

本文主要是介绍图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

概述

DIS(Dichotomous Image Segmentation)是一种新的图像分割任务,旨在从自然图像中分割出高精度的物体。与传统的图像分割任务相比,DIS更侧重于具有单个或几个目标的图像,因此可以提供更丰富准确的细节。

为了研究DIS任务,研究人员创建了一个名为DIS5K的大规模、可扩展的数据集。DIS5K数据集包含了5,470张高分辨率图像,每张图像都配有高精度的二值分割掩码。这个数据集的建立有助于推动多个应用方向的发展,如图像去背景、艺术设计、模拟视图运动、基于图像的增强现实(AR)应用、基于视频的AR应用、3D视频制作等。

通过研究DIS任务和使用DIS5K数据集,研究人员可以探索新的图像分割方法,并为各种应用领域提供更精确、更可靠的图像分割技术,从而推动分割技术在更广泛的领域中的应用。

官网:https://xuebinqin.github.io/dis/index.html
Github:https://github.com/xuebinqin/DIS

数据集

图像二类分割是将图像分割成两个主要区域:前景和背景。在这种情况下,前景代表图像中的某个类别的物体,而背景则是除了该物体之外的所有内容。
官方公布了算所使用的数据集DIS5K, DIS5K数据集中的每张图像都经过了像素级别的手工标注,标注的真值掩码非常精确,每张图像的标记时间相当长。这种高精度的标注使得数据集中的每个像素都与其相应的类别关联起来,从而为模型提供了可靠的训练数据。这种高精度的标注是实现图像二类分割的关键,因为模型需要能够准确地识别和分割出前景物体。

在DIS5K数据集中,标注对象的类型多样,包括透明和半透明的物体,标注使用单个像素的二值掩码进行。这种精确的标注确保了模型训练的有效性和准确性,并且使得模型能够预测出高精度的物体分割结果。

DIS5K数据集网盘地址:https://pan.baidu.com/s/1umNk2AeBG5aB5kXlHTHdIg
提取码:7qfs

模型训练

模型训练可参考git上的官方的文档

模型推理

模型C++使用onnxruntime进行推理

#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>class DIS
{
public:DIS(std::string model_path);void inference(cv::Mat& cv_src, cv::Mat& cv_mask);
private:std::vector<float> input_image_;int inpWidth;int inpHeight;int outWidth;int outHeight;const float score_th = 0;Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "DIS");Ort::Session* ort_session = nullptr;Ort::SessionOptions sessionOptions = Ort::SessionOptions();std::vector<char*> input_names;std::vector<char*> output_names;std::vector<std::vector<int64_t>> input_node_dims; // >=1 outputsstd::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
};DIS::DIS(std::string model_path)
{std::wstring widestr = std::wstring(model_path.begin(), model_path.end());//OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);ort_session = new Ort::Session(env, widestr.c_str(), sessionOptions);size_t numInputNodes = ort_session->GetInputCount();size_t numOutputNodes = ort_session->GetOutputCount();Ort::AllocatorWithDefaultOptions allocator;for (int i = 0; i < numInputNodes; i++){input_names.push_back(ort_session->GetInputName(i, allocator));Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();auto input_dims = input_tensor_info.GetShape();input_node_dims.push_back(input_dims);}for (int i = 0; i < numOutputNodes; i++){output_names.push_back(ort_session->GetOutputName(i, allocator));Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();auto output_dims = output_tensor_info.GetShape();output_node_dims.push_back(output_dims);}this->inpHeight = input_node_dims[0][2];this->inpWidth = input_node_dims[0][3];this->outHeight = output_node_dims[0][2];this->outWidth = output_node_dims[0][3];
}void DIS::inference(cv::Mat& cv_src, cv::Mat& cv_mask)
{cv::Mat cv_dst;cv::resize(cv_src, cv_dst, cv::Size(this->inpWidth, this->inpHeight));this->input_image_.resize(this->inpWidth * this->inpHeight * cv_dst.channels());for (int c = 0; c < 3; c++){for (int i = 0; i < this->inpHeight; i++){for (int j = 0; j < this->inpWidth; j++){float pix = cv_dst.ptr<uchar>(i)[j * 3 + 2 - c];this->input_image_[c * this->inpHeight * this->inpWidth + i * this->inpWidth + j] = pix / 255.0 - 0.5;}}}std::array<int64_t, 4> input_shape_{ 1, 3, this->inpHeight, this->inpWidth };auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info,input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());std::vector<Ort::Value> ort_outputs = ort_session->Run(Ort::RunOptions{ nullptr }, &input_names[0],&input_tensor_, 1, output_names.data(), output_names.size());   // 开始推理float* pred = ort_outputs[0].GetTensorMutableData<float>();cv::Mat mask(outHeight, outWidth, CV_32FC1, pred);double min_value, max_value;minMaxLoc(mask, &min_value, &max_value, 0, 0);mask = (mask - min_value) / (max_value - min_value);cv::resize(mask, cv_mask, cv::Size(cv_src.cols, cv_src.rows));
}void show_img(std::string name, const cv::Mat& img)
{cv::namedWindow(name, 0);int max_rows = 500;int max_cols = 600;if (img.rows >= img.cols && img.rows > max_rows) {cv::resizeWindow(name, cv::Size(img.cols * max_rows / img.rows, max_rows));}else if (img.cols >= img.rows && img.cols > max_cols) {cv::resizeWindow(name, cv::Size(max_cols, img.rows * max_cols / img.cols));}cv::imshow(name, img);
}cv::Mat replaceBG(const cv::Mat cv_src, cv::Mat& alpha, std::vector<int>& bg_color)
{int width = cv_src.cols;int height = cv_src.rows;cv::Mat cv_matting = cv::Mat::zeros(cv::Size(width, height), CV_8UC3);float* alpha_data = (float*)alpha.data;for (int i = 0; i < height; i++){for (int j = 0; j < width; j++){float alpha_ = alpha_data[i * width + j];cv_matting.at < cv::Vec3b>(i, j)[0] = cv_src.at < cv::Vec3b>(i, j)[0] * alpha_ + (1 - alpha_) * bg_color[0];cv_matting.at < cv::Vec3b>(i, j)[1] = cv_src.at < cv::Vec3b>(i, j)[1] * alpha_ + (1 - alpha_) * bg_color[1];cv_matting.at < cv::Vec3b>(i, j)[2] = cv_src.at < cv::Vec3b>(i, j)[2] * alpha_ + (1 - alpha_) * bg_color[2];}}return cv_matting;
}int main()
{DIS dis_net("isnet_general_use_720x1280.onnx");std::string path = "images";std::vector<std::string> filenames;cv::glob(path, filenames, false);for (auto file_name : filenames){cv::Mat cv_src = cv::imread(file_name);//std::vector<cv::Mat> cv_dsts;cv::Mat cv_dst, cv_mask;dis_net.inference(cv_src, cv_mask);std::vector<int> color{255, 0, 0};cv_dst=replaceBG(cv_src, cv_mask, color);show_img("src", cv_src);show_img("mask", cv_mask);show_img("dst", cv_dst);cv::waitKey(0);}
}

python推理代码也依赖onnxruntime

import argparse
import cv2
import numpy as np
import onnxruntime
### onnxruntime load ['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx']  inference failed
class DIS():def __init__(self, modelpath, score_th=None):so = onnxruntime.SessionOptions()so.log_severity_level = 3self.net = onnxruntime.InferenceSession(modelpath, so)self.input_height = self.net.get_inputs()[0].shape[2]self.input_width = self.net.get_inputs()[0].shape[3]self.input_name = self.net.get_inputs()[0].nameself.output_name = self.net.get_outputs()[0].nameself.score_th = score_thdef detect(self, srcimg):img = cv2.resize(srcimg, dsize=(self.input_width, self.input_height))img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = img.astype(np.float32) / 255.0 - 0.5blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)outs = self.net.run([self.output_name], {self.input_name: blob})mask = np.array(outs[0]).squeeze()min_value = np.min(mask)max_value = np.max(mask)mask = (mask - min_value) / (max_value - min_value)if self.score_th is not None:mask = np.where(mask < self.score_th, 0, 1)mask *= 255mask = mask.astype('uint8')mask = cv2.resize(mask, dsize=(srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_LINEAR)return maskdef generate_overlay_image(srcimg, mask):overlay_image = np.zeros(srcimg.shape, dtype=np.uint8)overlay_image[:] = (255, 255, 255)mask = np.stack((mask,) * 3, axis=-1).astype('uint8') mask_image = np.where(mask, srcimg, overlay_image)return mask, mask_imageif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument("--imgpath", type=str, default='images/cam_image47.jpg')parser.add_argument("--modelpath", type=str, default='weights/isnet_general_use_480x640.onnx')args = parser.parse_args()mynet = DIS(args.modelpath)srcimg = cv2.imread(args.imgpath)mask = mynet.detect(srcimg)mask, overlay_image = generate_overlay_image(srcimg, mask)winName = 'Deep learning object detection in onnxruntime'cv2.namedWindow(winName, cv2.WINDOW_NORMAL)cv2.imshow(winName, np.hstack((srcimg, mask)))cv2.waitKey(0)cv2.destroyAllWindows()

推理结果
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
资源和模型下载地址:https://download.csdn.net/download/matt45m/89024664

这篇关于图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/848500

相关文章

python获取指定名字的程序的文件路径的两种方法

《python获取指定名字的程序的文件路径的两种方法》本文主要介绍了python获取指定名字的程序的文件路径的两种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要... 最近在做项目,需要用到给定一个程序名字就可以自动获取到这个程序在Windows系统下的绝对路径,以下

JavaScript中的高级调试方法全攻略指南

《JavaScript中的高级调试方法全攻略指南》什么是高级JavaScript调试技巧,它比console.log有何优势,如何使用断点调试定位问题,通过本文,我们将深入解答这些问题,带您从理论到实... 目录观点与案例结合观点1观点2观点3观点4观点5高级调试技巧详解实战案例断点调试:定位变量错误性能分

使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解

《使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解》本文详细介绍了如何使用Python通过ncmdump工具批量将.ncm音频转换为.mp3的步骤,包括安装、配置ffmpeg环... 目录1. 前言2. 安装 ncmdump3. 实现 .ncm 转 .mp34. 执行过程5. 执行结

Python实现批量CSV转Excel的高性能处理方案

《Python实现批量CSV转Excel的高性能处理方案》在日常办公中,我们经常需要将CSV格式的数据转换为Excel文件,本文将介绍一个基于Python的高性能解决方案,感兴趣的小伙伴可以跟随小编一... 目录一、场景需求二、技术方案三、核心代码四、批量处理方案五、性能优化六、使用示例完整代码七、小结一、

Python中 try / except / else / finally 异常处理方法详解

《Python中try/except/else/finally异常处理方法详解》:本文主要介绍Python中try/except/else/finally异常处理方法的相关资料,涵... 目录1. 基本结构2. 各部分的作用tryexceptelsefinally3. 执行流程总结4. 常见用法(1)多个e

C++统计函数执行时间的最佳实践

《C++统计函数执行时间的最佳实践》在软件开发过程中,性能分析是优化程序的重要环节,了解函数的执行时间分布对于识别性能瓶颈至关重要,本文将分享一个C++函数执行时间统计工具,希望对大家有所帮助... 目录前言工具特性核心设计1. 数据结构设计2. 单例模式管理器3. RAII自动计时使用方法基本用法高级用法

Python中logging模块用法示例总结

《Python中logging模块用法示例总结》在Python中logging模块是一个强大的日志记录工具,它允许用户将程序运行期间产生的日志信息输出到控制台或者写入到文件中,:本文主要介绍Pyt... 目录前言一. 基本使用1. 五种日志等级2.  设置报告等级3. 自定义格式4. C语言风格的格式化方法

Python实现精确小数计算的完全指南

《Python实现精确小数计算的完全指南》在金融计算、科学实验和工程领域,浮点数精度问题一直是开发者面临的重大挑战,本文将深入解析Python精确小数计算技术体系,感兴趣的小伙伴可以了解一下... 目录引言:小数精度问题的核心挑战一、浮点数精度问题分析1.1 浮点数精度陷阱1.2 浮点数误差来源二、基础解决

使用Python实现Word文档的自动化对比方案

《使用Python实现Word文档的自动化对比方案》我们经常需要比较两个Word文档的版本差异,无论是合同修订、论文修改还是代码文档更新,人工比对不仅效率低下,还容易遗漏关键改动,下面通过一个实际案例... 目录引言一、使用python-docx库解析文档结构二、使用difflib进行差异比对三、高级对比方

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达