pytorch 计算点集内或矩阵内两两元素之间的距离矩阵

2024-02-07 14:08

本文主要是介绍pytorch 计算点集内或矩阵内两两元素之间的距离矩阵,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言:时间紧可直接看4,5两条。


1. 这一功能在python内或numpy内有现成的工具包

from scipy.spatial import distance
# 以下两种方式视情况选择
scipy.spatial.distance.pdist()
scipy.spatial.distance.cdist()

在神经网络的训练过程中,应用以上工具包需要把torch.tensor转变成numpy格式再计算,存在两个缺点:一是耗时,格式变来变去,而且从GPU迁移到CPU再返回到GPU;二是会造成梯度丢失。

2. pytorch中自带pdist()函数,但是这个函数输出结果为距离向量,而不是距离矩阵。距离向量是距离矩阵中上三角的元素。

import torch
import torch.tensor as tensor
import torch.nn.functional as F
a = tensor([[1., 1., 1.],[2., 2., 2.],[3., 3., 3.],[4., 4., 4.]])  #建立tensord=F.pdist(a, p=2)
print(d)
"""
输出:tensor([1.7321, 3.4641, 5.1962, 1.7321, 3.4641, 1.7321])
"""

3. 自定义pdist()函数计算欧氏距离,如下所示,但是该函数只能用来计算欧式距离(L2范数),而且对角线上的元素不是0,而是一个极小的数1e-4。

import torch
import torch.tensor as tensor
"""
自定义的距离矩阵函数
"""
def pdists(A, squared = False, eps = 1e-8):prod = torch.mm(A, A.t())norm = prod.diag().unsqueeze(1).expand_as(prod)res = (norm + norm.t() - 2 * prod).clamp(min = 0)if squared:return reselse:res = res.clamp(min = eps).sqrt()return res"""应用示例"""
a = tensor([[1., 2., 3.],[4., 5., 6.],[7., 8., 9.],[10., 11., 12.]])c=pdists(a, squared = False)
print(c)
"""打印结果
tensor([[1.0000e-04, 5.1962e+00, 1.0392e+01, 1.5588e+01],[5.1962e+00, 1.0000e-04, 5.1962e+00, 1.0392e+01],[1.0392e+01, 5.1962e+00, 1.0000e-04, 5.1962e+00],[1.5588e+01, 1.0392e+01, 5.1962e+00, 1.0000e-04]])
"""

4. pytorch中的torch.norm(input[:, None] - input, dim=2, p=p)函数可以实现该功能

    在torch.nn.functional.pdist的文档介绍中有这么一句话:

 简单翻译:计算输入中每​​对行向量之间的p范数距离。 这与torch.norm(input[:, None] - input, dim=2, p=p)的对角线以外的上部三角形部分相同。 如果行是连续的,此功能将更快。

这句话暗示:torch.norm函数可用于计算距离矩阵,而且可以选择L1、L2范数或者其他范数。
应用示例:

import torch
import torch.tensor as tensor
a = tensor([[1., 1., 1.],[2., 2., 2.],[3., 3., 3.],[4., 4., 4.]])  #建立tensor
b=torch.norm(a[:, None]-a, dim=2, p=2)
print(b)
"""
tensor([[0.0000, 1.7321, 3.4641, 5.1962],[1.7321, 0.0000, 1.7321, 3.4641],[3.4641, 1.7321, 0.0000, 1.7321],[5.1962, 3.4641, 1.7321, 0.0000]])
"""

对应的,可以把torch.norm封装成新的pdist函数:

import torch
import torch.tensor as tensor
"""函数封装"""
def pdist(a,dim=2, p=2):dist_matrix = torch.norm(a[:, None]-a, dim, p)return dist_matrix 

5. 自定义余弦距离矩阵

import torch
def cosinematrix(A):prod = torch.mm(A, A.t())#分子norm = torch.norm(A,p=2,dim=1).unsqueeze(0)#分母cos = prod.div(torch.mm(norm.t(),norm))return cos# 使用
d_matrix=cosinematrix(inputs)

文章参考:pytorch不用for循环计算一个矩阵各行之间的L1 、L2范数距离和余弦距离_小鱼的代码世界-CSDN博客_pytorch计算距离矩阵

这篇关于pytorch 计算点集内或矩阵内两两元素之间的距离矩阵的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/687969

相关文章

C# 比较两个list 之间元素差异的常用方法

《C#比较两个list之间元素差异的常用方法》:本文主要介绍C#比较两个list之间元素差异,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. 使用Except方法2. 使用Except的逆操作3. 使用LINQ的Join,GroupJoin

深度解析Java项目中包和包之间的联系

《深度解析Java项目中包和包之间的联系》文章浏览阅读850次,点赞13次,收藏8次。本文详细介绍了Java分层架构中的几个关键包:DTO、Controller、Service和Mapper。_jav... 目录前言一、各大包1.DTO1.1、DTO的核心用途1.2. DTO与实体类(Entity)的区别1

python3如何找到字典的下标index、获取list中指定元素的位置索引

《python3如何找到字典的下标index、获取list中指定元素的位置索引》:本文主要介绍python3如何找到字典的下标index、获取list中指定元素的位置索引问题,具有很好的参考价值,... 目录enumerate()找到字典的下标 index获取list中指定元素的位置索引总结enumerat

CSS实现元素撑满剩余空间的五种方法

《CSS实现元素撑满剩余空间的五种方法》在日常开发中,我们经常需要让某个元素占据容器的剩余空间,本文将介绍5种不同的方法来实现这个需求,并分析各种方法的优缺点,感兴趣的朋友一起看看吧... css实现元素撑满剩余空间的5种方法 在日常开发中,我们经常需要让某个元素占据容器的剩余空间。这是一个常见的布局需求

Python并行处理实战之如何使用ProcessPoolExecutor加速计算

《Python并行处理实战之如何使用ProcessPoolExecutor加速计算》Python提供了多种并行处理的方式,其中concurrent.futures模块的ProcessPoolExecu... 目录简介完整代码示例代码解释1. 导入必要的模块2. 定义处理函数3. 主函数4. 生成数字列表5.

java Long 与long之间的转换流程

《javaLong与long之间的转换流程》Long类提供了一些方法,用于在long和其他数据类型(如String)之间进行转换,本文将详细介绍如何在Java中实现Long和long之间的转换,感... 目录概述流程步骤1:将long转换为Long对象步骤2:将Longhttp://www.cppcns.c

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

C/C++中OpenCV 矩阵运算的实现

《C/C++中OpenCV矩阵运算的实现》本文主要介绍了C/C++中OpenCV矩阵运算的实现,包括基本算术运算(标量与矩阵)、矩阵乘法、转置、逆矩阵、行列式、迹、范数等操作,感兴趣的可以了解一下... 目录矩阵的创建与初始化创建矩阵访问矩阵元素基本的算术运算 ➕➖✖️➗矩阵与标量运算矩阵与矩阵运算 (逐元

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不