caffe源码解析:卷积乘法中用到的im2col及col2im

2023-11-07 05:59

本文主要是介绍caffe源码解析:卷积乘法中用到的im2col及col2im,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

这两个函数其实完成的功能比较简单,im2col就是把矩阵按卷积乘法所需,变换成列向量,col2im是一个逆过程

从下面这张图你一眼就能看明白im2col的操作(caffe中卷积计算都是Matrix_Kernel * Matrix_Col),因为都列出来太长了,我只列出了前4个,注意这是四周围完全没有填充0的情况,

 

col2im是一个反过来的过程,那么你可能会好奇,这两个操作能完全可逆吗?

事实上,结构是可逆的,结果不是,下面这个图很好地说明了展开的计算过程(图片比较大,可下载到电脑上看),

下面是一个可单独运行的测试源码,你可以随便编译跑一跑

#include <iostream>
using namespace std;inline bool is_a_ge_zero_and_a_lt_b(int a, int b) {return static_cast<unsigned>(a) < static_cast<unsigned>(b);
}template <typename Dtype>
void caffe_set(const int N, const Dtype alpha, Dtype* Y) {if (alpha == 0) {memset(Y, 0, sizeof(Dtype) * N);  // NOLINT(caffe/alt_fn)return;}for (int i = 0; i < N; ++i) {Y[i] = alpha;}
}template <typename Dtype>
void im2col_cpu(const Dtype* data_im, const int channels,const int height, const int width, const int kernel_h, const int kernel_w,const int pad_h, const int pad_w,const int stride_h, const int stride_w,const int dilation_h, const int dilation_w,Dtype* data_col) {const int output_h = (height + 2 * pad_h -(dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;const int output_w = (width + 2 * pad_w -(dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;const int channel_size = height * width;for (int channel = channels; channel--; data_im += channel_size) {for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {int input_row = -pad_h + kernel_row * dilation_h;for (int output_rows = output_h; output_rows; output_rows--) {if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {for (int output_cols = output_w; output_cols; output_cols--) {*(data_col++) = 0;}}else {int input_col = -pad_w + kernel_col * dilation_w;for (int output_col = output_w; output_col; output_col--) {if (is_a_ge_zero_and_a_lt_b(input_col, width)) {*(data_col++) = data_im[input_row * width + input_col];}else {*(data_col++) = 0;}input_col += stride_w;}}input_row += stride_h;}}}}
}template <typename Dtype>
void col2im_cpu(const Dtype* data_col, const int channels,const int height, const int width, const int kernel_h, const int kernel_w,const int pad_h, const int pad_w,const int stride_h, const int stride_w,const int dilation_h, const int dilation_w,Dtype* data_im) {caffe_set(height * width * channels, Dtype(0), data_im);const int output_h = (height + 2 * pad_h -(dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;const int output_w = (width + 2 * pad_w -(dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;const int channel_size = height * width;for (int channel = channels; channel--; data_im += channel_size) {for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {int input_row = -pad_h + kernel_row * dilation_h;for (int output_rows = output_h; output_rows; output_rows--) {if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {data_col += output_w;}else {int input_col = -pad_w + kernel_col * dilation_w;for (int output_col = output_w; output_col; output_col--) {if (is_a_ge_zero_and_a_lt_b(input_col, width)) {data_im[input_row * width + input_col] += *data_col;}data_col++;input_col += stride_w;}}input_row += stride_h;}}}}
}// 如果想运行6x6的矩阵,请取消下面的注释,并把5X5那段注释掉
int dataim[] = {1,2,3,4,5,6,5,6,7,8,9,10,6,5,4,3,2,1,10,9,8,7,6,5,4,3,2,1,5,6,3,2,1,6,5,4,
};int datacol[1000];
int outim[50];int main()
{im2col_cpu(dataim, 1, 6, 6, 3, 3, 0, 0, 1, 1, 1, 1, datacol);col2im_cpu(datacol, 1, 6, 6, 3, 3, 0, 0, 1, 1, 1, 1, outim);return 0;
}// 如果想运行5x5的矩阵,请取消下面的注释, 并把上面那段注释掉
/* 
int dataim[] = {1,2,3,4,5,6,7,8,9,10,5,4,3,2,1,10,9,8,7,6,4,3,2,1,5,
};int datacol[1000];
int outim[50];int main()
{im2col_cpu(dataim, 1, 5, 5, 3, 3, 0, 0, 1, 1, 1, 1, datacol);col2im_cpu(datacol, 1, 5, 5, 3, 3, 0, 0, 1, 1, 1, 1, outim);return 0;
}*/

按上面源码的操作,先运行im2col,再运行col2im,结果就很有意思了,相当于每个元素都乘了一个放大系数,只是不同的位置的放大系数是不一样的,看下面的图

仔细看那个放大系数矩阵,非常有规律,有木有?

AI视像算法学习群:824991413

这篇关于caffe源码解析:卷积乘法中用到的im2col及col2im的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/361692

相关文章

深度解析Spring Security 中的 SecurityFilterChain核心功能

《深度解析SpringSecurity中的SecurityFilterChain核心功能》SecurityFilterChain通过组件化配置、类型安全路径匹配、多链协同三大特性,重构了Spri... 目录Spring Security 中的SecurityFilterChain深度解析一、Security

全面解析Golang 中的 Gorilla CORS 中间件正确用法

《全面解析Golang中的GorillaCORS中间件正确用法》Golang中使用gorilla/mux路由器配合rs/cors中间件库可以优雅地解决这个问题,然而,很多人刚开始使用时会遇到配... 目录如何让 golang 中的 Gorilla CORS 中间件正确工作一、基础依赖二、错误用法(很多人一开

Mysql中设计数据表的过程解析

《Mysql中设计数据表的过程解析》数据库约束通过NOTNULL、UNIQUE、DEFAULT、主键和外键等规则保障数据完整性,自动校验数据,减少人工错误,提升数据一致性和业务逻辑严谨性,本文介绍My... 目录1.引言2.NOT NULL——制定某列不可以存储NULL值2.UNIQUE——保证某一列的每一

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

MySQL CTE (Common Table Expressions)示例全解析

《MySQLCTE(CommonTableExpressions)示例全解析》MySQL8.0引入CTE,支持递归查询,可创建临时命名结果集,提升复杂查询的可读性与维护性,适用于层次结构数据处... 目录基本语法CTE 主要特点非递归 CTE简单 CTE 示例多 CTE 示例递归 CTE基本递归 CTE 结

Spring Boot 3.x 中 WebClient 示例详解析

《SpringBoot3.x中WebClient示例详解析》SpringBoot3.x中WebClient是响应式HTTP客户端,替代RestTemplate,支持异步非阻塞请求,涵盖GET... 目录Spring Boot 3.x 中 WebClient 全面详解及示例1. WebClient 简介2.

在MySQL中实现冷热数据分离的方法及使用场景底层原理解析

《在MySQL中实现冷热数据分离的方法及使用场景底层原理解析》MySQL冷热数据分离通过分表/分区策略、数据归档和索引优化,将频繁访问的热数据与冷数据分开存储,提升查询效率并降低存储成本,适用于高并发... 目录实现冷热数据分离1. 分表策略2. 使用分区表3. 数据归档与迁移在mysql中实现冷热数据分

C#解析JSON数据全攻略指南

《C#解析JSON数据全攻略指南》这篇文章主要为大家详细介绍了使用C#解析JSON数据全攻略指南,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录一、为什么jsON是C#开发必修课?二、四步搞定网络JSON数据1. 获取数据 - HttpClient最佳实践2. 动态解析 - 快速

Spring Boot3.0新特性全面解析与应用实战

《SpringBoot3.0新特性全面解析与应用实战》SpringBoot3.0作为Spring生态系统的一个重要里程碑,带来了众多令人兴奋的新特性和改进,本文将深入解析SpringBoot3.0的... 目录核心变化概览Java版本要求提升迁移至Jakarta EE重要新特性详解1. Native Ima

spring中的@MapperScan注解属性解析

《spring中的@MapperScan注解属性解析》@MapperScan是Spring集成MyBatis时自动扫描Mapper接口的注解,简化配置并支持多数据源,通过属性控制扫描路径和过滤条件,利... 目录一、核心功能与作用二、注解属性解析三、底层实现原理四、使用场景与最佳实践五、注意事项与常见问题六