Strassen矩阵乘法简要解析(第4章:分治策略)

2024-06-23 15:08

本文主要是介绍Strassen矩阵乘法简要解析(第4章:分治策略),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Strassen矩阵乘法简要解析

Strassen矩阵乘法具体描述如下:

两个n×n 阶的矩阵A与B的乘积是另一个n×n 阶矩阵C,C可表示为假如每一个C(i, j) 都用此公式计算,则计算C所需要的操作次数为n3 m+n2 (n- 1) a,其中m表示一次乘法,a 表示一次加法或减法。

为了使讨论简便,假设n 是2的幂(也就是说, n是1,2,4,8,1 6,...)。

首先,假设n= 1时是一个小问题,n> 1时为一个大问题。后面将根据需要随时修改这个假设。对于1×1阶的小矩阵,可以通过将两矩阵中的两个元素直接相乘而得到结果。

考察一个n> 1的大问题。可以将这样的矩阵分成4个n/ 2×n/ 2阶的矩阵A1,A2,A3,和A4。当n 大于1且n 是2的幂时,n/ 2也是2的幂。因此较小矩阵也满足前面对矩阵大小的假设。矩阵Bi 和Ci 的定义与此类似.

假定strassen矩阵分割方案仅用于n≥8的矩阵乘法,而对于n<8的矩阵乘法则直接利用公式进行计算。则n= 8时,8×8矩阵相乘需要7次4×4矩阵乘法和1 8次4×4矩阵加/减法。每次矩阵乘法需花费6 4m+ 4 8a次操作,每次矩阵加法或减法需花费1 6a次操作。因此总的操作次数为7 ( 6 4m+ 4 8a) + 1 8 ( 1 6a) = 4 4 8m+ 6 2 4a。而使用直接计算方法,则需要5 1 2m+ 4 4 8a次操作。要使S t r a s s e n方法比直接计算方法快,至少要求5 1 2-4 4 8次乘法的开销比6 2 4-4 4 8次加/减法的开销大。或者说一次乘法的开销应该大于近似2 . 7 5次加/减法的开销。

假定n<1 6的矩阵是一个“小”问题,strassen的分解方案仅仅用于n≥1 6的情况,对于n<1 6的矩阵相乘,直接利用公式。则当n= 1 6时使用分而治之算法需要7 ( 5 1 2m+ 4 4 8a) +1 8 ( 6 4a) = 3 5 8 4m+ 4 2 8 8a次操作。直接计算时需要4 0 9 6m+ 3 8 4 0a次操作。若一次乘法的开销与一次加/减法的开销相同,则strassen方法需要7872次操作及用于问题分解的额外时间,而直接计算方法则需要7936次操作加上程序中执行for循环以及其他语句所花费的时间。即使直接计算方法所需要的操作次数比strassen方法少,但由于直接计算方法需要更多的额外开销,因此它也不见得会比strassen方法快。

n 的值越大,Strassen 方法与直接计算方法所用的操作次数的差异就越大,因此对于足够大的n,Strassen 方法将更快。设t (n) 表示使用Strassen 分而治之方法所需的时间。因为大的矩阵会被递归地分割成小矩阵直到每个矩阵的大小小于或等于k(k至少为8,也许更大,具体值由计算机的性能决定). 用迭代方法计算,可得t(n) = (nlog27)。因为log27 ≈2 . 8 1,所以与直接计算方法的复杂性(n3)相比,分而治之矩阵乘法算法有较大的改进。

再次说明:

矩阵C = AB,可写为
C11 = A11B11 + A12B21
C12 = A11B12 + A12B22
C21 = A21B11 + A22B21
C22 = A21B12 + A22B22
如果A、B、C都是二阶矩阵,则共需要8次乘法和4次加法。如果阶大于2,可以将矩阵分块进行计算。耗费的时间是O(nE3)。

要改进算法计算时间的复杂度,必须减少乘法运算次数。按分治法的思想,Strassen提出一种新的方法,用7次乘法完成2阶矩阵的乘法,算法如下:
M1 = A11(B12 - B12)
M2 = (A11 + A12)B22
M3 = (A21 + A22)B11
M4 = A22(B21 - B11)
M5 = (A11 + A22)(B11 + B22)
M6 = (A12 - A22)(B21 + B22)
M7 = (A11 - A21)(B11 + B12)
完成了7次乘法,再做如下加法:
C11 = M5 + M4 - M2 + M6
C12 = M1 + M2
C21 = M3 + M4
C22 = M5 + M1 - M3 - M7
全部计算使用了7次乘法和18次加减法,计算时间降低到O(nE2.81)。计算复杂性得到较大改进。

 

STRASSEN矩阵乘法算法如下:

#include <iostream.h>

const int N=8; //常量N用来定义矩阵的大小

void main()
{

    void STRASSEN(int n,float A[][N],float B[][N],float C[][N]); 
    void input(int n,float p[][N]);
    void output(int n,float C[][N]);                    //函数声明部分

    float A[N][N],B[N][N],C[N][N];  //定义三个矩阵A,B,C

    cout<<"现在录入矩阵A[N][N]:"<<endl<<endl;
    input(N,A);
    cout<<endl<<"现在录入矩阵B[N][N]:"<<endl<<endl;
    input(N,B);                         //录入数组

    STRASSEN(N,A,B,C);   //调用STRASSEN函数计算

    output(N,C);  //输出计算结果
}


void input(int n,float p[][N])  //矩阵输入函数
{
    int i,j;

    for(i=0;i<n;i++)
    {
        cout<<"请输入第"<<i+1<<"行"<<endl;
        for(j=0;j<n;j++)
            cin>>p[i][j];
    }
}

void output(int n,float C[][N]) //据矩阵输出函数
{
    int i,j;
    cout<<"输出矩阵:"<<endl;
    for(i=0;i<n;i++)
    {
        cout<<endl;
        for(j=0;j<n;j++)
            cout<<C[i][j]<<"  ";
    }
    cout<<endl<<endl;

}

void MATRIX_MULTIPLY(float A[][N],float B[][N],float C[][N])  //按通常的矩阵乘法计算C=AB的子算法(仅做2阶)
{
    int i,j,t;
    for(i=0;i<2;i++)                     //计算A*B-->C
        for(j=0;j<2;j++)
        {    
            C[i][j]=0;                   //计算完一个C[i][j],C[i][j]应重新赋值为零
            for(t=0;t<2;t++)
            C[i][j]=C[i][j]+A[i][t]*B[t][j];
        }
}

void MATRIX_ADD(int n,float X[][N],float Y[][N],float Z[][N]) //矩阵加法函数X+Y—>Z
{
    int i,j;
    for(i=0;i<n;i++)
        for(j=0;j<n;j++)
            Z[i][j]=X[i][j]+Y[i][j];
}

void MATRIX_SUB(int n,float X[][N],float Y[][N],float Z[][N]) //矩阵减法函数X-Y—>Z
{
    int i,j;
    for(i=0;i<n;i++)
        for(j=0;j<n;j++)
            Z[i][j]=X[i][j]-Y[i][j];

}


void STRASSEN(int n,float A[][N],float B[][N],float C[][N])  //STRASSEN函数(递归)
{
    float A11[N][N],A12[N][N],A21[N][N],A22[N][N];
    float B11[N][N],B12[N][N],B21[N][N],B22[N][N];
    float C11[N][N],C12[N][N],C21[N][N],C22[N][N];
    float M1[N][N],M2[N][N],M3[N][N],M4[N][N],M5[N][N],M6[N][N],M7[N][N];
    float AA[N][N],BB[N][N],MM1[N][N],MM2[N][N];

    int i,j;//,x;


    if (n==2)
        MATRIX_MULTIPLY(A,B,C);//按通常的矩阵乘法计算C=AB的子算法(仅做2阶)
    else
    {
        for(i=0;i<n/2;i++)                         
            for(j=0;j<n/2;j++)

                {
                    A11[i][j]=A[i][j];
                    A12[i][j]=A[i][j+n/2];
                    A21[i][j]=A[i+n/2][j];
                    A22[i][j]=A[i+n/2][j+n/2];
                    B11[i][j]=B[i][j];
                    B12[i][j]=B[i][j+n/2];
                    B21[i][j]=B[i+n/2][j];
                    B22[i][j]=B[i+n/2][j+n/2];
                }                                   //将矩阵A和B式分为四块




    MATRIX_SUB(n/2,B12,B22,BB);                       

    STRASSEN(n/2,A11,BB,M1);//M1=A11(B12-B22)

    MATRIX_ADD(n/2,A11,A12,AA);
    STRASSEN(n/2,AA,B22,M2);//M2=(A11+A12)B22

    MATRIX_ADD(n/2,A21,A22,AA);
    STRASSEN(n/2,AA,B11,M3);//M3=(A21+A22)B11

    MATRIX_SUB(n/2,B21,B11,BB);
    STRASSEN(n/2,A22,BB,M4);//M4=A22(B21-B11)

    MATRIX_ADD(n/2,A11,A22,AA);
    MATRIX_ADD(n/2,B11,B22,BB);
    STRASSEN(n/2,AA,BB,M5);//M5=(A11+A22)(B11+B22)

    MATRIX_SUB(n/2,A12,A22,AA);
    MATRIX_SUB(n/2,B21,B22,BB);
    STRASSEN(n/2,AA,BB,M6);//M6=(A12-A22)(B21+B22)

    MATRIX_SUB(n/2,A11,A21,AA);
    MATRIX_SUB(n/2,B11,B12,BB);
    STRASSEN(n/2,AA,BB,M7);//M7=(A11-A21)(B11+B12)
                                                    //计算M1,M2,M3,M4,M5,M6,M7(递归部分)


    MATRIX_ADD(N/2,M5,M4,MM1);                        
    MATRIX_SUB(N/2,M2,M6,MM2);
    MATRIX_SUB(N/2,MM1,MM2,C11);//C11=M5+M4-M2+M6

    MATRIX_ADD(N/2,M1,M2,C12);//C12=M1+M2

    MATRIX_ADD(N/2,M3,M4,C21);//C21=M3+M4

    MATRIX_ADD(N/2,M5,M1,MM1);
    MATRIX_ADD(N/2,M3,M7,MM2);
    MATRIX_SUB(N/2,MM1,MM2,C22);//C22=M5+M1-M3-M7

    for(i=0;i<n/2;i++)
        for(j=0;j<n/2;j++)
        {
            C[i][j]=C11[i][j];
            C[i][j+n/2]=C12[i][j];
            C[i+n/2][j]=C21[i][j];
            C[i+n/2][j+n/2]=C22[i][j];
        }                                            //计算结果送回C[N][N]

    }

}

这篇关于Strassen矩阵乘法简要解析(第4章:分治策略)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1087499

相关文章

深度解析Java DTO(最新推荐)

《深度解析JavaDTO(最新推荐)》DTO(DataTransferObject)是一种用于在不同层(如Controller层、Service层)之间传输数据的对象设计模式,其核心目的是封装数据,... 目录一、什么是DTO?DTO的核心特点:二、为什么需要DTO?(对比Entity)三、实际应用场景解析

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

Java中的雪花算法Snowflake解析与实践技巧

《Java中的雪花算法Snowflake解析与实践技巧》本文解析了雪花算法的原理、Java实现及生产实践,涵盖ID结构、位运算技巧、时钟回拨处理、WorkerId分配等关键点,并探讨了百度UidGen... 目录一、雪花算法核心原理1.1 算法起源1.2 ID结构详解1.3 核心特性二、Java实现解析2.

使用Python绘制3D堆叠条形图全解析

《使用Python绘制3D堆叠条形图全解析》在数据可视化的工具箱里,3D图表总能带来眼前一亮的效果,本文就来和大家聊聊如何使用Python实现绘制3D堆叠条形图,感兴趣的小伙伴可以了解下... 目录为什么选择 3D 堆叠条形图代码实现:从数据到 3D 世界的搭建核心代码逐行解析细节优化应用场景:3D 堆叠图

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧

解析C++11 static_assert及与Boost库的关联从入门到精通

《解析C++11static_assert及与Boost库的关联从入门到精通》static_assert是C++中强大的编译时验证工具,它能够在编译阶段拦截不符合预期的类型或值,增强代码的健壮性,通... 目录一、背景知识:传统断言方法的局限性1.1 assert宏1.2 #error指令1.3 第三方解决

全面解析MySQL索引长度限制问题与解决方案

《全面解析MySQL索引长度限制问题与解决方案》MySQL对索引长度设限是为了保持高效的数据检索性能,这个限制不是MySQL的缺陷,而是数据库设计中的权衡结果,下面我们就来看看如何解决这一问题吧... 目录引言:为什么会有索引键长度问题?一、问题根源深度解析mysql索引长度限制原理实际场景示例二、五大解决

深度解析Spring Boot拦截器Interceptor与过滤器Filter的区别与实战指南

《深度解析SpringBoot拦截器Interceptor与过滤器Filter的区别与实战指南》本文深度解析SpringBoot中拦截器与过滤器的区别,涵盖执行顺序、依赖关系、异常处理等核心差异,并... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现

深度解析Spring AOP @Aspect 原理、实战与最佳实践教程

《深度解析SpringAOP@Aspect原理、实战与最佳实践教程》文章系统讲解了SpringAOP核心概念、实现方式及原理,涵盖横切关注点分离、代理机制(JDK/CGLIB)、切入点类型、性能... 目录1. @ASPect 核心概念1.1 AOP 编程范式1.2 @Aspect 关键特性2. 完整代码实

解决未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4‘问题

《解决未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4‘问题》:本文主要介绍解决未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4... 目录未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4‘打开pom.XM