hk.LayerNorm 模块介绍

2024-01-08 04:52
文章标签 模块 介绍 layernorm hk

本文主要是介绍hk.LayerNorm 模块介绍,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

hk.LayerNorm 是 Haiku 库中用于实现 Layer Normalization(层归一化)的模块。Layer Normalization 是一种神经网络归一化的技术,旨在提高神经网络的训练稳定性和泛化性。

主要参数:

  • axis(默认为-1): 沿着哪个轴进行归一化。通常选择最后一个轴,对输入的特征进行归一化。

  • create_scale(默认为True): 是否创建可学习的缩放参数。如果为 True,则会创建一个可学习的缩放参数,用于调整归一化后的值的幅度。

  • create_offset(默认为True): 是否创建可学习的偏置参数。如果为 True,则会创建一个可学习的偏置参数,用于调整归一化后的值的偏移。

  • epsilon(默认为1e-5): 一个小的正数,用于防止除以零的情况。

import haiku as hk
import jax
import jax.numpy as jnp
import pickle### 自定义LayerNorm模块
class LayerNorm(hk.LayerNorm):"""LayerNorm module.Equivalent to hk.LayerNorm but with different parameter shapes: they arealways vectors rather than possibly higher-rank tensors. This makes it easierto change the layout whilst keep the model weight-compatible."""def __init__(self,axis,create_scale: bool,create_offset: bool,eps: float = 1e-5,scale_init=None,offset_init=None,use_fast_variance: bool = False,name=None,param_axis=None):super().__init__(axis=axis,create_scale=False,create_offset=False,eps=eps,scale_init=None,offset_init=None,use_fast_variance=use_fast_variance,name=name,param_axis=param_axis)self._temp_create_scale = create_scaleself._temp_create_offset = create_offset#self.scale_init = hk.initializers.Constant(1)#self.offset_init = hk.initializers.Constant(0)def __call__(self, x: jnp.ndarray) -> jnp.ndarray:is_bf16 = (x.dtype == jnp.bfloat16)if is_bf16:x = x.astype(jnp.float32)param_axis = self.param_axis[0] if self.param_axis else -1param_shape = (x.shape[param_axis],)param_broadcast_shape = [1] * x.ndimparam_broadcast_shape[param_axis] = x.shape[param_axis]scale = Noneoffset = None# scale,offset张量的形状必须可扩展到输入数据的形状。# 没有显式指定 self.scale_init,self.offset_init参数,# 则默认使用 Haiku 库中的默认初始化方法。同 def __init__()中注释的显式指定if self._temp_create_scale:scale = hk.get_parameter('scale', param_shape, x.dtype, init=self.scale_init)scale = scale.reshape(param_broadcast_shape)if self._temp_create_offset:offset = hk.get_parameter('offset', param_shape, x.dtype, init=self.offset_init)offset = offset.reshape(param_broadcast_shape)out = super().__call__(x, scale=scale, offset=offset)if is_bf16:out = out.astype(jnp.bfloat16)return outwith open("Human_HBB_tensor_dict_ensembled.pkl",'rb') as f:Human_HBB_tensor_dict = pickle.load(f)input_data = jnp.array(Human_HBB_tensor_dict['msa_feat'])
print(input_data.shape)# 转换为Haiku模块
# LayerNorm层,在数据最后一个维度/轴(特征)做归一化,并创建可学习的缩放参数和偏置参数
model = hk.transform(lambda x: LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model)## 获取初始化的参数,参数的形状需要输入数据的形状以及模型的结构
rng = jax.random.PRNGKey(42)
params = model.init(rng, input_data)
print(params) 
print("params scale shape:") 
#print(params['msa_feat_norm']['scale'].shape)
#print("params offset bias:")
#print(params['msa_feat_norm']['offset'].shape)output_data = model.apply(params, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data)### 使用原始的hk.LayerNorm模块
model2 = hk.transform(lambda x: hk.LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model2)params2 = model2.init(rng, input_data)
print(params2) 
print("params2 scale shape:") 
print(params2['msa_feat_norm']['scale'].shape)
print("params2 offset bias:")
print(params2['msa_feat_norm']['offset'].shape)output_data2 = model2.apply(params2, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data2.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data2)

参考:

https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=layernorm#layernorm

这篇关于hk.LayerNorm 模块介绍的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/582385

相关文章

MySQL常用字符串函数示例和场景介绍

《MySQL常用字符串函数示例和场景介绍》MySQL提供了丰富的字符串函数帮助我们高效地对字符串进行处理、转换和分析,本文我将全面且深入地介绍MySQL常用的字符串函数,并结合具体示例和场景,帮你熟练... 目录一、字符串函数概述1.1 字符串函数的作用1.2 字符串函数分类二、字符串长度与统计函数2.1

Python通用唯一标识符模块uuid使用案例详解

《Python通用唯一标识符模块uuid使用案例详解》Pythonuuid模块用于生成128位全局唯一标识符,支持UUID1-5版本,适用于分布式系统、数据库主键等场景,需注意隐私、碰撞概率及存储优... 目录简介核心功能1. UUID版本2. UUID属性3. 命名空间使用场景1. 生成唯一标识符2. 数

zookeeper端口说明及介绍

《zookeeper端口说明及介绍》:本文主要介绍zookeeper端口说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、zookeeper有三个端口(可以修改)aVNMqvZ二、3个端口的作用三、部署时注意总China编程结一、zookeeper有三个端口(可以

Python中win32包的安装及常见用途介绍

《Python中win32包的安装及常见用途介绍》在Windows环境下,PythonWin32模块通常随Python安装包一起安装,:本文主要介绍Python中win32包的安装及常见用途的相关... 目录前言主要组件安装方法常见用途1. 操作Windows注册表2. 操作Windows服务3. 窗口操作

Python中re模块结合正则表达式的实际应用案例

《Python中re模块结合正则表达式的实际应用案例》Python中的re模块是用于处理正则表达式的强大工具,正则表达式是一种用来匹配字符串的模式,它可以在文本中搜索和匹配特定的字符串模式,这篇文章主... 目录前言re模块常用函数一、查看文本中是否包含 A 或 B 字符串二、替换多个关键词为统一格式三、提

c++中的set容器介绍及操作大全

《c++中的set容器介绍及操作大全》:本文主要介绍c++中的set容器介绍及操作大全,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录​​一、核心特性​​️ ​​二、基本操作​​​​1. 初始化与赋值​​​​2. 增删查操作​​​​3. 遍历方

HTML img标签和超链接标签详细介绍

《HTMLimg标签和超链接标签详细介绍》:本文主要介绍了HTML中img标签的使用,包括src属性(指定图片路径)、相对/绝对路径区别、alt替代文本、title提示、宽高控制及边框设置等,详细内容请阅读本文,希望能对你有所帮助... 目录img 标签src 属性alt 属性title 属性width/h

一文深入详解Python的secrets模块

《一文深入详解Python的secrets模块》在构建涉及用户身份认证、权限管理、加密通信等系统时,开发者最不能忽视的一个问题就是“安全性”,Python在3.6版本中引入了专门面向安全用途的secr... 目录引言一、背景与动机:为什么需要 secrets 模块?二、secrets 模块的核心功能1. 基

MybatisPlus service接口功能介绍

《MybatisPlusservice接口功能介绍》:本文主要介绍MybatisPlusservice接口功能介绍,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友... 目录Service接口基本用法进阶用法总结:Lambda方法Service接口基本用法MyBATisP

MySQL复杂SQL之多表联查/子查询详细介绍(最新整理)

《MySQL复杂SQL之多表联查/子查询详细介绍(最新整理)》掌握多表联查(INNERJOIN,LEFTJOIN,RIGHTJOIN,FULLJOIN)和子查询(标量、列、行、表子查询、相关/非相关、... 目录第一部分:多表联查 (JOIN Operations)1. 连接的类型 (JOIN Types)