如何微调Meta Llama-3 8B
发布时间:2024年09月12日
Meta
推出了 Meta
Llama 3 系列 LLM,包括 8 和 70B 大小的预训练和指令调整的生成文本模型。这些指令调整模型针对对话进行了优化,在行业基准测试中优于许多开源聊天模型。在开发过程中,我们特别注意优化实用性和安全性。
目录概览:
微调微调是机器学习中使用的一种技术,尤其是大型语言模型 (LLM)。这是一种利用现有模型的知识并针对特定任务进行定制的方法。需要资源情况T4-14.7/16GB。
开始微调Llama-3 8B
第 1 步:安装库
o
pip
:这将安装两个库:用于从 Hugging Face Hub 访问模型和用于交互式编码。
install huggingface_hub ipythonhuggingface_hubipython
o
"unsloth[colab]
:这将从 GitHub 安装 Unsloth 库,为 Google Colab() 和 conda 环境()指定不同的选项。
@ git+https://github.com/unslothai/unsloth.git" "unsloth[conda]
@git+https://github.com/unslothai/unsloth.git"[colab][conda]
o
export
:可能用于为 Hugging Face
HF_TOKEN=xxxxxxxxxxxxx
Hub 设置身份验证令牌,但出于安全原因,实际令牌值是隐藏的。
o
1. pip install huggingface_hub ipython "unsloth[colab] @ git+https://github.com/unslothai/unsloth.git" "unsloth[conda] @ git+https://github.com/unslothai/unsloth.git"
2. export HF_TOKEN=xxxxxxxxxxxxx
安装Wandb
1、 安装 Wandb 库:安装与 Wandb 交互所需的库。pip install wandb
2、 登录:提示您输入
Wandb 凭据(可能是 API 密钥),以便您可以使用该服务。wandb login
·
1. pip install
2. wandbwandb logio
导入库
·
1. import os
2. from unsloth import FastLanguageModel
3. import torch
4. from trl import SFTTrainer
5. from transformers import TrainingArguments
6. from datasets import load_dataset
加载数据集
1.
设置最大序列长度:定义每个训练示例中允许的最大标记数。这有助于在训练期间管理内存和计算资源。max_seq_length = 2048
2.
定义数据 URL:以 JSONL 格式存储数据集的 Web 地址,可能包含文本数据。url
- 加载数据集:使用库从提供的 URL 加载数据。
dataset
= load_dataset("json", data_files = {"train" : url},
split = "train")datasets
load_dataset("json")
将数据格式指定为 JSON。
data_files
dictionary 使用键“train”和 URL 作为其值定义训练数据位置。
split="train"
表示我们正在加载数据集的训练部分。
·
1. max_seq_length=2048
2. url="https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
3. dataset = load_dataset("json", data_files = {"train" : url}, split = "train")
加载 Llama-3-8B
·
1. # 2. Load Llama3 model
2. model, tokenizer = FastLanguageModel.from_pretrained(
3. model_name = "unsloth/llama-3-8b-bnb-4bit", # 指定 Unsloth 库中的确切模型。“Llama3”可能是型号名称,“8b”表示 80 亿个参数,“bnb”可能是指特定的架构,“4bit”表示使用内存效率高的格式。
4. max_seq_length = max_seq_length, # 设置最大序列长度(前面定义)以限制模型可以处理的输入长度。
5. dtype = None, # (假设它设置为 None)允许库选择最合适的数据类型
6. load_in_4bit = True, # 允许以内存高效的 4 位格式加载模型(如果模型和硬件支持)
7. )
generate_text
·
·
1. def generate_text(text):
2. inputs = tokenizer(
3. [
4. text
5. ], return_tensors="pt").to("cuda")
6. outputs = model.generate(**inputs, max_new_tokens=20, use_cache=True)
7. tokenizer.batch_decode(outputs)
8. print("Before training\n")
进行模型参数设置和快速 LoRA 权重和训练
·
1. model = FastLanguageModel.get_peft_model(
2. model,
3. r = 16,
4. target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
5. "gate_proj", "up_proj", "down_proj",],
6. lora_alpha = 16,
7. lora_dropout = 0, # Supports any, but = 0 is optimized
8. bias = "none", # Supports any, but = "none" is optimized
9. use_gradient_checkpointing = True,
10. random_state = 3407,
11. max_seq_length = max_seq_length,
12. use_rslora = False, # Rank stabilized LoRA
13. loftq_config = None, # LoftQ
14. )
开始训练
·
trainer = SFTTrainer( model = model, train_dataset = dataset, dataset_text_field = "text", max_seq_length = max_seq_length, tokenizer = tokenizer, args = TrainingArguments( per_device_train_batch_size = 2, gradient_accumulation_steps = 4, warmup_steps = 10, max_steps = 60, fp16 = not torch.cuda.is_bf16_supported(), bf16 = torch.cuda.is_bf16_supported(), logging_steps = 1, output_dir = "outputs", optim = "adamw_8bit", weight_decay = 0.01, lr_scheduler_type = "linear", seed = 3407, ),)trainer.train()
测试模型
·
1. print("\n ######## \nAfter training\n")
2. generate_text("<human>: List the top 5 most popular movies of all time.\n<bot>: ")
保存模型
·
1. model.save_pretrained("lora_model")
2. model.save_pretrained_merged("outputs", tokenizer, save_method = "merged_16bit",)
3. model.push_to_hub_merged("YOURUSERNAME/llama3-8b-oig-unsloth-merged", tokenizer, save_method = "merged_16bit", token = os.environ.get("HF_TOKEN"))
4. model.push_to_hub("YOURUSERNAME/llama3-8b-oig-unsloth", tokenizer, save_method = "lora", token = os.environ.get("HF_TOKEN"))
出自:https://mp.weixin.qq.com/s/mwaCtibKkFjQzPhDRKtCOw
如果你想要了解关于智能工具类的内容,可以查看 智汇宝库,这是一个提供智能工具的网站。
在这你可以找到各种智能工具的相关信息,了解智能工具的用法以及最新动态。
B12是一个基于人工智能的建站工具,帮助企业提供专业网站建设服务,简化业务运营并提升客户参与度。