基于lora技术微调Gemma(2B)代码实践

2024-04-04 01:44

本文主要是介绍基于lora技术微调Gemma(2B)代码实践,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一、前置条件

获得模型访问权,选择Colab运行时,配置训练环境。

先在Kaggle上注册,然后获得Gemma 2B 的访问权;

然后在Google colab 配置环境,主要是GPU的选择,免费的是T4,建议采用付费的A100(为了节省时间,微调训练耗时T4需要30分钟左右,A100只需要2分钟左右)

最后 在Kaggle 上的account上生成令牌文件(主要是usename 和 API Key),并将令牌文件配置到colab环境。

二、微调步骤

三、源码

# -*- coding: utf-8 -*-
"""gemma-lora微调.ipynbAutomatically generated by Colaboratory.Original file is located athttps://colab.research.google.com/drive/1_uEbuYP-vk0tCO0EA7IQ1t85jJ6ne7Qi
"""from google.colab import files
uploaded = files.upload()!mkdir ~/.kaggle
!mv kaggle.json ~/.kaggle/!chmod 600 ~/.kaggle/kaggle.json# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U keras>=3import osos.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"import keras
import keras_nlp!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonlimport json
data = []
with open("databricks-dolly-15k.jsonl") as file:for line in file:features = json.loads(line)# Filter out examples with context, to keep it simple.if features["context"]:continue# Format the entire example as a single string.template = "Instruction:\n{instruction}\n\nResponse:\n{response}"data.append(template.format(**features))# Only use 1000 training examples, to keep it fast.
data = data[:1000]gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()prompt = template.format(instruction="What should I do on a trip to Europe?",response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))prompt = template.format(instruction="Explain the process of photosynthesis in a child could understand.",response="",
)
print(gemma_lm.generate(prompt,max_length=256))# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(learning_rate=5e-5,weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])gemma_lm.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer=optimizer,weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)prompt = template.format(instruction="What should I do on a trip to Europe?",response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))prompt = template.format(instruction="Explain the process of photosynthesis in a way that a child could understand.",response="",
)
print(gemma_lm.generate(prompt, max_length=256))

微调关键代码--A100

这篇关于基于lora技术微调Gemma(2B)代码实践的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/874502

相关文章

Spring Boot集成/输出/日志级别控制/持久化开发实践

《SpringBoot集成/输出/日志级别控制/持久化开发实践》SpringBoot默认集成Logback,支持灵活日志级别配置(INFO/DEBUG等),输出包含时间戳、级别、类名等信息,并可通过... 目录一、日志概述1.1、Spring Boot日志简介1.2、日志框架与默认配置1.3、日志的核心作用

Python使用Tenacity一行代码实现自动重试详解

《Python使用Tenacity一行代码实现自动重试详解》tenacity是一个专为Python设计的通用重试库,它的核心理念就是用简单、清晰的方式,为任何可能失败的操作添加重试能力,下面我们就来看... 目录一切始于一个简单的 API 调用Tenacity 入门:一行代码实现优雅重试精细控制:让重试按我

破茧 JDBC:MyBatis 在 Spring Boot 中的轻量实践指南

《破茧JDBC:MyBatis在SpringBoot中的轻量实践指南》MyBatis是持久层框架,简化JDBC开发,通过接口+XML/注解实现数据访问,动态代理生成实现类,支持增删改查及参数... 目录一、什么是 MyBATis二、 MyBatis 入门2.1、创建项目2.2、配置数据库连接字符串2.3、入

Android Paging 分页加载库使用实践

《AndroidPaging分页加载库使用实践》AndroidPaging库是Jetpack组件的一部分,它提供了一套完整的解决方案来处理大型数据集的分页加载,本文将深入探讨Paging库... 目录前言一、Paging 库概述二、Paging 3 核心组件1. PagingSource2. Pager3.

springboot自定义注解RateLimiter限流注解技术文档详解

《springboot自定义注解RateLimiter限流注解技术文档详解》文章介绍了限流技术的概念、作用及实现方式,通过SpringAOP拦截方法、缓存存储计数器,结合注解、枚举、异常类等核心组件,... 目录什么是限流系统架构核心组件详解1. 限流注解 (@RateLimiter)2. 限流类型枚举 (

在Java中使用OpenCV实践

《在Java中使用OpenCV实践》用户分享了在Java项目中集成OpenCV4.10.0的实践经验,涵盖库简介、Windows安装、依赖配置及灰度图测试,强调其在图像处理领域的多功能性,并计划后续探... 目录前言一 、OpenCV1.简介2.下载与安装3.目录说明二、在Java项目中使用三 、测试1.测

MyBatis-Plus 自动赋值实体字段最佳实践指南

《MyBatis-Plus自动赋值实体字段最佳实践指南》MyBatis-Plus通过@TableField注解与填充策略,实现时间戳、用户信息、逻辑删除等字段的自动填充,减少手动赋值,提升开发效率与... 目录1. MyBATis-Plus 自动赋值概述1.1 适用场景1.2 自动填充的原理1.3 填充策略

Python实现PDF按页分割的技术指南

《Python实现PDF按页分割的技术指南》PDF文件处理是日常工作中的常见需求,特别是当我们需要将大型PDF文档拆分为多个部分时,下面我们就来看看如何使用Python创建一个灵活的PDF分割工具吧... 目录需求分析技术方案工具选择安装依赖完整代码实现使用说明基本用法示例命令输出示例技术亮点实际应用场景扩

Olingo分析和实践之EDM 辅助序列化器详解(最佳实践)

《Olingo分析和实践之EDM辅助序列化器详解(最佳实践)》EDM辅助序列化器是ApacheOlingoOData框架中无需完整EDM模型的智能序列化工具,通过运行时类型推断实现灵活数据转换,适用... 目录概念与定义什么是 EDM 辅助序列化器?核心概念设计目标核心特点1. EDM 信息可选2. 智能类

Olingo分析和实践之OData框架核心组件初始化(关键步骤)

《Olingo分析和实践之OData框架核心组件初始化(关键步骤)》ODataSpringBootService通过初始化OData实例和服务元数据,构建框架核心能力与数据模型结构,实现序列化、URI... 目录概述第一步:OData实例创建1.1 OData.newInstance() 详细分析1.1.1