hk.LayerNorm 模块介绍

2024-01-08 04:52
文章标签 模块 介绍 layernorm hk

本文主要是介绍hk.LayerNorm 模块介绍,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

hk.LayerNorm 是 Haiku 库中用于实现 Layer Normalization(层归一化)的模块。Layer Normalization 是一种神经网络归一化的技术,旨在提高神经网络的训练稳定性和泛化性。

主要参数:

  • axis(默认为-1): 沿着哪个轴进行归一化。通常选择最后一个轴,对输入的特征进行归一化。

  • create_scale(默认为True): 是否创建可学习的缩放参数。如果为 True,则会创建一个可学习的缩放参数,用于调整归一化后的值的幅度。

  • create_offset(默认为True): 是否创建可学习的偏置参数。如果为 True,则会创建一个可学习的偏置参数,用于调整归一化后的值的偏移。

  • epsilon(默认为1e-5): 一个小的正数,用于防止除以零的情况。

import haiku as hk
import jax
import jax.numpy as jnp
import pickle### 自定义LayerNorm模块
class LayerNorm(hk.LayerNorm):"""LayerNorm module.Equivalent to hk.LayerNorm but with different parameter shapes: they arealways vectors rather than possibly higher-rank tensors. This makes it easierto change the layout whilst keep the model weight-compatible."""def __init__(self,axis,create_scale: bool,create_offset: bool,eps: float = 1e-5,scale_init=None,offset_init=None,use_fast_variance: bool = False,name=None,param_axis=None):super().__init__(axis=axis,create_scale=False,create_offset=False,eps=eps,scale_init=None,offset_init=None,use_fast_variance=use_fast_variance,name=name,param_axis=param_axis)self._temp_create_scale = create_scaleself._temp_create_offset = create_offset#self.scale_init = hk.initializers.Constant(1)#self.offset_init = hk.initializers.Constant(0)def __call__(self, x: jnp.ndarray) -> jnp.ndarray:is_bf16 = (x.dtype == jnp.bfloat16)if is_bf16:x = x.astype(jnp.float32)param_axis = self.param_axis[0] if self.param_axis else -1param_shape = (x.shape[param_axis],)param_broadcast_shape = [1] * x.ndimparam_broadcast_shape[param_axis] = x.shape[param_axis]scale = Noneoffset = None# scale,offset张量的形状必须可扩展到输入数据的形状。# 没有显式指定 self.scale_init,self.offset_init参数,# 则默认使用 Haiku 库中的默认初始化方法。同 def __init__()中注释的显式指定if self._temp_create_scale:scale = hk.get_parameter('scale', param_shape, x.dtype, init=self.scale_init)scale = scale.reshape(param_broadcast_shape)if self._temp_create_offset:offset = hk.get_parameter('offset', param_shape, x.dtype, init=self.offset_init)offset = offset.reshape(param_broadcast_shape)out = super().__call__(x, scale=scale, offset=offset)if is_bf16:out = out.astype(jnp.bfloat16)return outwith open("Human_HBB_tensor_dict_ensembled.pkl",'rb') as f:Human_HBB_tensor_dict = pickle.load(f)input_data = jnp.array(Human_HBB_tensor_dict['msa_feat'])
print(input_data.shape)# 转换为Haiku模块
# LayerNorm层,在数据最后一个维度/轴(特征)做归一化,并创建可学习的缩放参数和偏置参数
model = hk.transform(lambda x: LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model)## 获取初始化的参数,参数的形状需要输入数据的形状以及模型的结构
rng = jax.random.PRNGKey(42)
params = model.init(rng, input_data)
print(params) 
print("params scale shape:") 
#print(params['msa_feat_norm']['scale'].shape)
#print("params offset bias:")
#print(params['msa_feat_norm']['offset'].shape)output_data = model.apply(params, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data)### 使用原始的hk.LayerNorm模块
model2 = hk.transform(lambda x: hk.LayerNorm(axis=[-1], create_scale=True,create_offset=True,name='msa_feat_norm')(x))print(model2)params2 = model2.init(rng, input_data)
print(params2) 
print("params2 scale shape:") 
print(params2['msa_feat_norm']['scale'].shape)
print("params2 offset bias:")
print(params2['msa_feat_norm']['offset'].shape)output_data2 = model2.apply(params2, rng, input_data)
print("input_data shape:", input_data.shape) 
print("Output Data shape:", output_data2.shape)
#print("原始数据:", input_data)
print("经过LayerNorm后:", output_data2)

参考:

https://dm-haiku.readthedocs.io/en/latest/api.html?highlight=layernorm#layernorm

这篇关于hk.LayerNorm 模块介绍的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/582385

相关文章

C#使用StackExchange.Redis实现分布式锁的两种方式介绍

《C#使用StackExchange.Redis实现分布式锁的两种方式介绍》分布式锁在集群的架构中发挥着重要的作用,:本文主要介绍C#使用StackExchange.Redis实现分布式锁的... 目录自定义分布式锁获取锁释放锁自动续期StackExchange.Redis分布式锁获取锁释放锁自动续期分布式

Python的time模块一些常用功能(各种与时间相关的函数)

《Python的time模块一些常用功能(各种与时间相关的函数)》Python的time模块提供了各种与时间相关的函数,包括获取当前时间、处理时间间隔、执行时间测量等,:本文主要介绍Python的... 目录1. 获取当前时间2. 时间格式化3. 延时执行4. 时间戳运算5. 计算代码执行时间6. 转换为指

Python正则表达式语法及re模块中的常用函数详解

《Python正则表达式语法及re模块中的常用函数详解》这篇文章主要给大家介绍了关于Python正则表达式语法及re模块中常用函数的相关资料,正则表达式是一种强大的字符串处理工具,可以用于匹配、切分、... 目录概念、作用和步骤语法re模块中的常用函数总结 概念、作用和步骤概念: 本身也是一个字符串,其中

Python中的getopt模块用法小结

《Python中的getopt模块用法小结》getopt.getopt()函数是Python中用于解析命令行参数的标准库函数,该函数可以从命令行中提取选项和参数,并对它们进行处理,本文详细介绍了Pyt... 目录getopt模块介绍getopt.getopt函数的介绍getopt模块的常用用法getopt模

redis过期key的删除策略介绍

《redis过期key的删除策略介绍》:本文主要介绍redis过期key的删除策略,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录第一种策略:被动删除第二种策略:定期删除第三种策略:强制删除关于big key的清理UNLINK命令FLUSHALL/FLUSHDB命

python logging模块详解及其日志定时清理方式

《pythonlogging模块详解及其日志定时清理方式》:本文主要介绍pythonlogging模块详解及其日志定时清理方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地... 目录python logging模块及日志定时清理1.创建logger对象2.logging.basicCo

Qt spdlog日志模块的使用详解

《Qtspdlog日志模块的使用详解》在Qt应用程序开发中,良好的日志系统至关重要,本文将介绍如何使用spdlog1.5.0创建满足以下要求的日志系统,感兴趣的朋友一起看看吧... 目录版本摘要例子logmanager.cpp文件main.cpp文件版本spdlog版本:1.5.0采用1.5.0版本主要

Pytest多环境切换的常见方法介绍

《Pytest多环境切换的常见方法介绍》Pytest作为自动化测试的主力框架,如何实现本地、测试、预发、生产环境的灵活切换,本文总结了通过pytest框架实现自由环境切换的几种方法,大家可以根据需要进... 目录1.pytest-base-url2.hooks函数3.yml和fixture结论你是否也遇到过

Python使用date模块进行日期处理的终极指南

《Python使用date模块进行日期处理的终极指南》在处理与时间相关的数据时,Python的date模块是开发者最趁手的工具之一,本文将用通俗的语言,结合真实案例,带您掌握date模块的六大核心功能... 目录引言一、date模块的核心功能1.1 日期表示1.2 日期计算1.3 日期比较二、六大常用方法详

MySQL中慢SQL优化的不同方式介绍

《MySQL中慢SQL优化的不同方式介绍》慢SQL的优化,主要从两个方面考虑,SQL语句本身的优化,以及数据库设计的优化,下面小编就来给大家介绍一下有哪些方式可以优化慢SQL吧... 目录避免不必要的列分页优化索引优化JOIN 的优化排序优化UNION 优化慢 SQL 的优化,主要从两个方面考虑,SQL 语