【HuggingFace Transformers】LlamaMLP源码解析

2024-09-02 18:20

本文主要是介绍【HuggingFace Transformers】LlamaMLP源码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

LlamaMLP源码解析

  • 1. LlamaMLP 介绍
  • 2. LlamaMLP类 源码解析

1. LlamaMLP 介绍

LlamaMLPLLaMA 模型中的 MLP 层,主要用于对输入特征进行非线性变换。在分片预训练模式下,线性层的权重被切分,分步处理后再进行拼接和求和,而在常规模式下,直接应用线性变换和激活函数处理输入数据。其计算公式为:
o u t p u t = W d o w n ⋅ ( σ ( W g a t e ⋅ x + b g a t e ) ⊙ ( W u p ⋅ x + b u p ) ) + b d o w n output = W_{down}\cdot(\sigma(W_{gate}\cdot x+b_{gate})\odot (W_{up}\cdot x+b_{up})) +b_{down} output=Wdown(σ(Wgatex+bgate)(Wupx+bup))+bdown

2. LlamaMLP类 源码解析

源码地址:transformers/src/transformers/models/llama/modeling_llama.py

# -*- coding: utf-8 -*-
# @time: 2024/8/28 15:16import torch
import torch.nn.functional as Ffrom torch import nn
from transformers.activations import ACT2FNclass LlamaMLP(nn.Module):def __init__(self, config):super().__init__()self.config = config  # 配置参数self.hidden_size = config.hidden_size  # 隐藏层的维度self.intermediate_size = config.intermediate_size  # 中间层的维度self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)  # 定义第一个线性变换层,将隐藏层映射到中间层self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)  # 定义第二个线性变换层,将隐藏层映射到中间层self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)  # 定义第三个线性变换层,将中间层的输出映射回隐藏层self.act_fn = ACT2FN[config.hidden_act]  # 根据配置选择激活函数def forward(self, x):# 如果是分片预训练if self.config.pretraining_tp > 1:slice = self.intermediate_size // self.config.pretraining_tp  # 计算每个切片的大小gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)  # 将 gate_proj 层的权重按行切分成多个切片up_proj_slices = self.up_proj.weight.split(slice, dim=0)  # 将 up_proj 层的权重按行切分成多个切片down_proj_slices = self.down_proj.weight.split(slice, dim=1)  # 将 down_proj 层的权重按列切分成多个切片gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)  # 对输入 x 应用每个 gate_proj 切片的线性变换,并沿列拼接up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)  # 对输入 x 应用每个 up_proj 切片的线性变换,并沿列拼接intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)  # 应用激活函数后,与 up_proj 结果逐元素相乘,并沿通道切分成多个张量down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)]  # 对每个 intermediate_states 切片应用对应的 down_proj 切片的线性变换down_proj = sum(down_proj)  # 将所有 down_proj 切片的结果相加else:# 如果不是分片预训练,直接进行线性变换和激活函数处理down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))# 返回最终的输出结果return down_proj

这篇关于【HuggingFace Transformers】LlamaMLP源码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/1130717

相关文章

Mybatis Plus JSqlParser解析sql语句及JSqlParser安装步骤

《MybatisPlusJSqlParser解析sql语句及JSqlParser安装步骤》JSqlParser是一个用于解析SQL语句的Java库,它可以将SQL语句解析为一个Java对象树,允许... 目录【一】jsqlParser 是什么【二】JSqlParser 的安装步骤【三】使用场景【1】sql语

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

Java 关键字transient与注解@Transient的区别用途解析

《Java关键字transient与注解@Transient的区别用途解析》在Java中,transient是一个关键字,用于声明一个字段不会被序列化,这篇文章给大家介绍了Java关键字transi... 在Java中,transient 是一个关键字,用于声明一个字段不会被序列化。当一个对象被序列化时,被

Java JSQLParser解析SQL的使用指南

《JavaJSQLParser解析SQL的使用指南》JSQLParser是一个Java语言的SQL语句解析工具,可以将SQL语句解析成为Java类的层次结构,还支持改写SQL,下面我们就来看看它的具... 目录一、引言二、jsQLParser常见类2.1 Class Diagram2.2 Statement

python进行while遍历的常见错误解析

《python进行while遍历的常见错误解析》在Python中选择合适的遍历方式需要综合考虑可读性、性能和具体需求,本文就来和大家讲解一下python中while遍历常见错误以及所有遍历方法的优缺点... 目录一、超出数组范围问题分析错误复现解决方法关键区别二、continue使用问题分析正确写法关键点三

8种快速易用的Python Matplotlib数据可视化方法汇总(附源码)

《8种快速易用的PythonMatplotlib数据可视化方法汇总(附源码)》你是否曾经面对一堆复杂的数据,却不知道如何让它们变得直观易懂?别慌,Python的Matplotlib库是你数据可视化的... 目录引言1. 折线图(Line Plot)——趋势分析2. 柱状图(Bar Chart)——对比分析3

使用Java实现Navicat密码的加密与解密的代码解析

《使用Java实现Navicat密码的加密与解密的代码解析》:本文主要介绍使用Java实现Navicat密码的加密与解密,通过本文,我们了解了如何利用Java语言实现对Navicat保存的数据库密... 目录一、背景介绍二、环境准备三、代码解析四、核心代码展示五、总结在日常开发过程中,我们有时需要处理各种软

Python多进程、多线程、协程典型示例解析(最新推荐)

《Python多进程、多线程、协程典型示例解析(最新推荐)》:本文主要介绍Python多进程、多线程、协程典型示例解析(最新推荐),本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定... 目录一、multiprocessing(多进程)1. 模块简介2. 案例详解:并行计算平方和3. 实现逻

Spring Boot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)

《SpringBoot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)》:本文主要介绍SpringBoot拦截器Interceptor与过滤器Filter深度解析... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现与实

MyBatis分页插件PageHelper深度解析与实践指南

《MyBatis分页插件PageHelper深度解析与实践指南》在数据库操作中,分页查询是最常见的需求之一,传统的分页方式通常有两种内存分页和SQL分页,MyBatis作为优秀的ORM框架,本身并未提... 目录1. 为什么需要分页插件?2. PageHelper简介3. PageHelper集成与配置3.