详解矩阵乘法中的Strassen算法

2024-06-02 16:38

本文主要是介绍详解矩阵乘法中的Strassen算法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 

机器学习中需要训练大量数据,涉及大量复杂运算,例如卷积、矩阵等。这些复杂运算不仅多,而且每次计算的数据量很大,如果能针对这些运算进行优化,可以大幅提高性能。

一、矩阵乘法

如下图所示:

Figure 1 Matrix Multiplication

二、Strassen算法

Figure 2 x^3 vs. x^2.807

三、Strassen原理详解

Strassen算法正是从这个角度出发,实现了降低算法复杂度!

实现步骤可以分为以下4步:

3.1 Strassen实现步骤

 

四、Strassen算法的代码实现

我们以MNN中关于Strassen算法源码实现来学习:https://github.com/alibaba/MNN/blob/master/source/backend/cpu/compute/StrassenMatmulComputor.cpp。

类StrassenMatrixComputor提供了3个API供调用:

_generateTrivalMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT);

普通矩阵乘法计算

_generateMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);

Strassen算法的矩阵乘法

_generateMatMulConstB(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth);

Strassen算法的矩阵乘法(和MatMul的区别在于内存Buffer是否允许复用)

我们以_generateMatMul为例来学习下Strassen算法如何实现,可以分成如下几步:

第一步:使用Strassen算法收益判断

在矩阵操作中,因为需要对矩阵的维数进行扩展,涉及大量读写操作,这些读写操作都需要大量循环,如果读写次数超出使用Strassen乘法的收益的话,就得不偿失了,那么就使用普通的矩阵乘法

    /*Compute the memory read / write cost for expandMatrix Mul need eSub*lSub*hSub*(1+1.0/CONVOLUTION_TILED_NUMBWR), Matrix Add/Sub need x*y*UNIT*3 (2 read 1 write)*/float saveCost =(eSub * lSub * hSub) * (1.0f + 1.0f / CONVOLUTION_TILED_NUMBWR) - 4 * (eSub * lSub) * 3 - 7 * (eSub * hSub * 3);if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBWR || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f) {return _generateTrivialMatMul(AT, BT, CT);}

第二步:分块

    auto aStride = AT->stride(0);auto a11     = AT->host<float>() + 0 * aUnit * eSub + 0 * aStride * lSub;auto a12     = AT->host<float>() + 0 * aUnit * eSub + 1 * aStride * lSub;auto a21     = AT->host<float>() + 1 * aUnit * eSub + 0 * aStride * lSub;auto a22     = AT->host<float>() + 1 * aUnit * eSub + 1 * aStride * lSub;auto bStride = BT->stride(0);auto b11     = BT->host<float>() + 0 * bUnit * lSub + 0 * bStride * hSub;auto b12     = BT->host<float>() + 0 * bUnit * lSub + 1 * bStride * hSub;auto b21     = BT->host<float>() + 1 * bUnit * lSub + 0 * bStride * hSub;auto b22     = BT->host<float>() + 1 * bUnit * lSub + 1 * bStride * hSub;auto cStride = CT->stride(0);auto c11     = CT->host<float>() + 0 * aUnit * eSub + 0 * cStride * hSub;auto c12     = CT->host<float>() + 0 * aUnit * eSub + 1 * cStride * hSub;auto c21     = CT->host<float>() + 1 * aUnit * eSub + 0 * cStride * hSub;auto c22     = CT->host<float>() + 1 * aUnit * eSub + 1 * cStride * hSub;

第三步:分治和递归

Strassen算法核心就是分治思想。这一步可以写成下列所示伪代码:

1. If n = 1 Output A × B
2. Else
3. Compute A11,B11, . . . ,A22,B22 % by computing m = n/2
4. P1   Strassen(A11,B12 − B22)
5. P2   Strassen(A11 + A12,B22)
6. P3   Strassen(A21 + A22,B11)
7. P4   Strassen(A22,B21 − B11)
8. P5   Strassen(A11 + A22,B11 + B22)
9. P6   Strassen(A12 − A22,B21 + B22)
10. P7   Strassen(A11 − A21,B11 + B12)
11. C11   P5 + P4 − P2 + P6
12. C12   P1 + P2
13. C21   P3 + P4
14. C22   P1 + P5 − P3 − P7
15. Output C
16. End If

例如其中的一步代码如下所示:

   {// S1=A21+A22, T1=B12-B11, P5=S1T1auto f = [a22, a21, b11, b12, xAddr, yAddr, eSub, lSub, hSub, aStride, bStride]() {MNNMatrixAdd(xAddr, a21, a22, eSub * aUnit / 4, eSub * aUnit, aStride, aStride, lSub);MNNMatrixSub(yAddr, b12, b11, lSub * bUnit / 4, lSub * bUnit, bStride, bStride, hSub);};mFunctions.emplace_back(f);auto code = _generateMatMul(X.get(), Y.get(), C22.get(), currentDepth);if (code != NO_ERROR) {return code;}}

递归执行,得到最终结果!

这篇关于详解矩阵乘法中的Strassen算法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1024517

相关文章

Spring StateMachine实现状态机使用示例详解

《SpringStateMachine实现状态机使用示例详解》本文介绍SpringStateMachine实现状态机的步骤,包括依赖导入、枚举定义、状态转移规则配置、上下文管理及服务调用示例,重点解... 目录什么是状态机使用示例什么是状态机状态机是计算机科学中的​​核心建模工具​​,用于描述对象在其生命

Java JDK1.8 安装和环境配置教程详解

《JavaJDK1.8安装和环境配置教程详解》文章简要介绍了JDK1.8的安装流程,包括官网下载对应系统版本、安装时选择非系统盘路径、配置JAVA_HOME、CLASSPATH和Path环境变量,... 目录1.下载JDK2.安装JDK3.配置环境变量4.检验JDK官网下载地址:Java Downloads

使用Python删除Excel中的行列和单元格示例详解

《使用Python删除Excel中的行列和单元格示例详解》在处理Excel数据时,删除不需要的行、列或单元格是一项常见且必要的操作,本文将使用Python脚本实现对Excel表格的高效自动化处理,感兴... 目录开发环境准备使用 python 删除 Excphpel 表格中的行删除特定行删除空白行删除含指定

MySQL中的LENGTH()函数用法详解与实例分析

《MySQL中的LENGTH()函数用法详解与实例分析》MySQLLENGTH()函数用于计算字符串的字节长度,区别于CHAR_LENGTH()的字符长度,适用于多字节字符集(如UTF-8)的数据验证... 目录1. LENGTH()函数的基本语法2. LENGTH()函数的返回值2.1 示例1:计算字符串

Spring Boot spring-boot-maven-plugin 参数配置详解(最新推荐)

《SpringBootspring-boot-maven-plugin参数配置详解(最新推荐)》文章介绍了SpringBootMaven插件的5个核心目标(repackage、run、start... 目录一 spring-boot-maven-plugin 插件的5个Goals二 应用场景1 重新打包应用

mybatis执行insert返回id实现详解

《mybatis执行insert返回id实现详解》MyBatis插入操作默认返回受影响行数,需通过useGeneratedKeys+keyProperty或selectKey获取主键ID,确保主键为自... 目录 两种方式获取自增 ID:1. ​​useGeneratedKeys+keyProperty(推

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

Linux系统性能检测命令详解

《Linux系统性能检测命令详解》本文介绍了Linux系统常用的监控命令(如top、vmstat、iostat、htop等)及其参数功能,涵盖进程状态、内存使用、磁盘I/O、系统负载等多维度资源监控,... 目录toppsuptimevmstatIOStatiotopslabtophtopdstatnmon

java使用protobuf-maven-plugin的插件编译proto文件详解

《java使用protobuf-maven-plugin的插件编译proto文件详解》:本文主要介绍java使用protobuf-maven-plugin的插件编译proto文件,具有很好的参考价... 目录protobuf文件作为数据传输和存储的协议主要介绍在Java使用maven编译proto文件的插件

Android ClassLoader加载机制详解

《AndroidClassLoader加载机制详解》Android的ClassLoader负责加载.dex文件,基于双亲委派模型,支持热修复和插件化,需注意类冲突、内存泄漏和兼容性问题,本文给大家介... 目录一、ClassLoader概述1.1 类加载的基本概念1.2 android与Java Class