计算上理解LayerNorm,为何泄露信息,知识追踪

2023-10-17 01:30

本文主要是介绍计算上理解LayerNorm,为何泄露信息,知识追踪,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch 中layernorm 的使用

首先给出官网的解释,
在这里插入图片描述

torch.nn.LayerNorm(normalized_shape: Union[int, List[int], torch.Size],eps: float = 1e-05,elementwise_affine: bool = True)

其中注意:LayerNorm中不会像BatchNorm那样跟踪统计全局的均值方差,因此train()和eval()对LayerNorm没有影响。


如何计算:训练样本a:batch=2,seq_len=2,dims=3
pytorch

a = torch.tensor([[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[1.0,2.0,3.0],[4.0,5.0,6.0]]])
ln = torch.nn.LayerNorm([2,3],elementwise_affine=False)
ln_out = ln(a)

numpy

mean = np.mean(a.numpy(), axis=(1,2))
var = np.var(a.numpy(), axis=(1,2))
div = np.sqrt(var+1e-05)
ln_out = (a-mean[:,None,None])/div[:,None,None] # None的作用是增加维度

所以layernorm中的normalized_shape是算矩阵中的后面几维,这里的[2,3] 表示倒数第二维和倒数第一维,最后针对每个样本都有只有一个均值和方差。

带参数的layernorm,由于是面向最后两个维度,那么weigth和bias与最后两维形状一样[2,3]。那么每一个样本都会重复使用,进行仿射变换,(仿射变换即乘以weight中对应的数字后,然后加bias中对应的数字),并会在反向传播时得到学习。

ln=torch.nn.LayerNorm([2,3],elementwise_affine=True)
ln.state_dict()
#OrderedDict([('weight', tensor([[1., 1., 1.],[1., 1., 1.]])),('bias', tensor([[0., 0., 0.],[0., 0., 0.]]))])

pytorch LayerNorm参数详解,计算过程

tensor = torch.FloatTensor([[1, 2, 4, 1],[6, 3, 2, 4],[2, 4, 6, 1]])

在这里插入图片描述
在使用LayerNorm时,通常只需要指定normalized_shape就可以了。

pytorch常用normalization函数

与batch normalization和instance normalization不同,batch normalization使用affine选项为每个通道/平面应用标量尺度γ和偏差β,而layer normalization使用elementwise_affine参数为每个元素应用尺度和偏差。
在这里插入图片描述

知识追踪领域

数据形状:【batch,sentence,feature】
SAKT中,或者Transformer中,我们的LayerNorm定义为:
self.layer_norm = nn.LayerNorm(d_model)
所以是对最后一维,特征维度进行归一化。.
而某一篇文字,(在没证实、没充分证据说明它的结果有问题,计算过程有泄露信息前,先不透漏文章)
self.layer_norm = nn.LayerNorm(normalized_shape = [length, d_model])
他是对最后两维,因为知识追踪,第t个时间步,是不能看到第t+1个时间步的信息的。问题是归一化只涉及到数值上面的放大缩小(scale),如何泄露还确实不知道模型怎么做到的。只能确定的是,第t个时间步看到了后面的信息。

这篇关于计算上理解LayerNorm,为何泄露信息,知识追踪的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/222007

相关文章

Go语言开发实现查询IP信息的MCP服务器

《Go语言开发实现查询IP信息的MCP服务器》随着MCP的快速普及和广泛应用,MCP服务器也层出不穷,本文将详细介绍如何在Go语言中使用go-mcp库来开发一个查询IP信息的MCP... 目录前言mcp-ip-geo 服务器目录结构说明查询 IP 信息功能实现工具实现工具管理查询单个 IP 信息工具的实现服

一文详解Java异常处理你都了解哪些知识

《一文详解Java异常处理你都了解哪些知识》:本文主要介绍Java异常处理的相关资料,包括异常的分类、捕获和处理异常的语法、常见的异常类型以及自定义异常的实现,文中通过代码介绍的非常详细,需要的朋... 目录前言一、什么是异常二、异常的分类2.1 受检异常2.2 非受检异常三、异常处理的语法3.1 try-

使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)

《使用Python从PPT文档中提取图片和图片信息(如坐标、宽度和高度等)》PPT是一种高效的信息展示工具,广泛应用于教育、商务和设计等多个领域,PPT文档中常常包含丰富的图片内容,这些图片不仅提升了... 目录一、引言二、环境与工具三、python 提取PPT背景图片3.1 提取幻灯片背景图片3.2 提取

Linux下如何使用C++获取硬件信息

《Linux下如何使用C++获取硬件信息》这篇文章主要为大家详细介绍了如何使用C++实现获取CPU,主板,磁盘,BIOS信息等硬件信息,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下... 目录方法获取CPU信息:读取"/proc/cpuinfo"文件获取磁盘信息:读取"/proc/diskstats"文

深入理解Apache Kafka(分布式流处理平台)

《深入理解ApacheKafka(分布式流处理平台)》ApacheKafka作为现代分布式系统中的核心中间件,为构建高吞吐量、低延迟的数据管道提供了强大支持,本文将深入探讨Kafka的核心概念、架构... 目录引言一、Apache Kafka概述1.1 什么是Kafka?1.2 Kafka的核心概念二、Ka

一文详解SQL Server如何跟踪自动统计信息更新

《一文详解SQLServer如何跟踪自动统计信息更新》SQLServer数据库中,我们都清楚统计信息对于优化器来说非常重要,所以本文就来和大家简单聊一聊SQLServer如何跟踪自动统计信息更新吧... SQL Server数据库中,我们都清楚统计信息对于优化器来说非常重要。一般情况下,我们会开启"自动更新

Python如何获取域名的SSL证书信息和到期时间

《Python如何获取域名的SSL证书信息和到期时间》在当今互联网时代,SSL证书的重要性不言而喻,它不仅为用户提供了安全的连接,还能提高网站的搜索引擎排名,那我们怎么才能通过Python获取域名的S... 目录了解SSL证书的基本概念使用python库来抓取SSL证书信息安装必要的库编写获取SSL证书信息

Win32下C++实现快速获取硬盘分区信息

《Win32下C++实现快速获取硬盘分区信息》这篇文章主要为大家详细介绍了Win32下C++如何实现快速获取硬盘分区信息,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 实现代码CDiskDriveUtils.h#pragma once #include <wtypesbase

国内环境搭建私有知识问答库踩坑记录(ollama+deepseek+ragflow)

《国内环境搭建私有知识问答库踩坑记录(ollama+deepseek+ragflow)》本文给大家利用deepseek模型搭建私有知识问答库的详细步骤和遇到的问题及解决办法,感兴趣的朋友一起看看吧... 目录1. 第1步大家在安装完ollama后,需要到系统环境变量中添加两个变量2. 第3步 “在cmd中

SpringBoot项目注入 traceId 追踪整个请求的日志链路(过程详解)

《SpringBoot项目注入traceId追踪整个请求的日志链路(过程详解)》本文介绍了如何在单体SpringBoot项目中通过手动实现过滤器或拦截器来注入traceId,以追踪整个请求的日志链... SpringBoot项目注入 traceId 来追踪整个请求的日志链路,有了 traceId, 我们在排