当前位置: 首页 > article >正文

Hugging Face预训练GPT微调ChatGPT(微调入门!新手友好!)

Hugging Face预训练GPT微调ChatGPT(微调入门!新手友好!)

在实战中,⼤多数情况下都不需要从0开始训练模型,⽽是使⽤“⼤⼚”或者其他研究者开源的已经训练好的⼤模型。

在各种⼤模型开源库中,最具代表性的就是Hugging FaceHugging Face是⼀家专注于NLP领域的AI公司,开发了⼀个名为Transformers的开源库,该开源库拥有许多预训练后的深度学习模型,如BERT、GPT-2、T5等。Hugging FaceTransformers开源库使研究⼈员和开发⼈员能够更轻松地使⽤这些模型进⾏各种NLP任务,例如⽂本分类、问答、⽂本⽣成等。这个库也提供了简洁、⾼效的API,有助于快速实现⾃然语⾔处理应⽤。

从Hugging Face下载⼀个GPT-2并微调成ChatGPT,需要遵循的步骤如下。

image-20250319144401279

1.安装Hugging Face Transformers库

pip install transformers

2.载入预训练GPT-2模型和分词器

import torch # 导⼊torch
from transformers import GPT2Tokenizer # 导⼊GPT-2分词器
from transformers import GPT2LMHeadModel # 导⼊GPT-2语⾔模型
model_name = "gpt2" # 也可以选择其他模型,如"gpt2-medium" "gpt2-large"等
tokenizer = GPT2Tokenizer.from_pretrained(model_name) # 加载分词器
tokenizer.pad_token = '' # 为分词器添加pad token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('')
device = "cuda" if torch.cuda.is_available() else "cpu" # 判断是否有可⽤的GPU
model = GPT2LMHeadModel.from_pretrained(model_name).to(device) # 将模型加载到设备上(CPU或GPU)
vocab = tokenizer.get_vocab() # 获取词汇表
print("模型信息:", model)
print("分词器信息:",tokenizer)
print("词汇表⼤⼩:", len(vocab))
print("部分词汇示例:", (list(vocab.keys())[8000:8005]))

3.准备微调数据集

from torch.utils.data import Dataset  # 导入PyTorch的Dataset# 自定义ChatDataset类,继承自PyTorch的Dataset类
class ChatDataset(Dataset):def __init__(self, file_path, tokenizer, vocab):self.tokenizer = tokenizer  # 分词器self.vocab = vocab  # 词汇表# 加载数据并处理,将处理后的输入数据和目标数据赋值给input_data和target_dataself.input_data, self.target_data = self.load_and_process_data(file_path)# 定义加载和处理数据的方法def load_and_process_data(self, file_path):with open(file_path, "r") as f:  # 读取文件内容lines = f.readlines()input_data, target_data = [], []for i, line in enumerate(lines):  # 遍历文件的每一行if line.startswith("User:"):  # 如以"User:"开头,移除"User: "前缀,并将张量转换为列表tokens = self.tokenizer(line.strip()[6:], return_tensors="pt")["input_ids"].tolist()[0]tokens = tokens + [self.tokenizer.eos_token_id]  # 添加结束符input_data.append(torch.tensor(tokens, dtype=torch.long))  # 添加到input_dataelif line.startswith("AI:"):  # 如以"AI:"开头,移除"AI: "前缀,并将张量转换为列表tokens = self.tokenizer(line.strip()[4:], return_tensors="pt")["input_ids"].tolist()[0]tokens = tokens + [self.tokenizer.eos_token_id]  # 添加结束符target_data.append(torch.tensor(tokens, dtype=torch.long))  # 添加到target_datareturn input_data, target_data# 定义数据集的长度,即input_data的长度def __len__(self):return len(self.input_data)# 定义获取数据集中指定索引的数据的方法def __getitem__(self, idx):return self.input_data[idx], self.target_data[idx]file_path = "/kaggle/input/hugging-face-chatgpt-chat-data/chat.txt"  # 加载chat.txt数据集
chat_dataset = ChatDataset(file_path, tokenizer, vocab)  # 创建ChatDataset对象,传入文件、分词器和词汇表# 打印数据集中前2个数据示例
for i in range(2):input_example, target_example = chat_dataset[i]print(f"示例 {i + 1}:")print("输入:", tokenizer.decode(input_example))print("输出:", tokenizer.decode(target_example))

4.准备微调数据加载器

from torch.utils.data import DataLoader  # 导入DataLoadertokenizer.pad_token = ''  # 为分词器添加pad token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('')# 定义pad_sequence函数,用于将一批序列补齐到相同长度
def pad_sequence(sequences, padding_value=0, length=None):# 计算最大序列长度,如果length参数未提供,则使用输入序列中的最大长度max_length = max(len(seq) for seq in sequences) if length is None else length# 创建一个具有适当形状的全零张量,用于存储补齐后的序列result = torch.full((len(sequences), max_length), padding_value, dtype=torch.long)# 遍历序列,将每个序列的内容复制到张量result中for i, seq in enumerate(sequences):end = len(seq)result[i, :end] = seq[:end]return result# 定义collate_fn函数,用于将一个批次的数据整理成适当的形状
def collate_fn(batch):# 从批次中分离源序列和目标序列sources, targets = zip(*batch)# 计算批次中的最大序列长度max_length = max(max(len(s) for s in sources), max(len(t) for t in targets))# 使用pad_sequence函数补齐源序列和目标序列sources = pad_sequence(sources, padding_value=tokenizer.pad_token_id, length=max_length)targets = pad_sequence(targets, padding_value=tokenizer.pad_token_id, length=max_length)# 返回补齐后的源序列和目标序列return sources, targets# 创建DataLoader
chat_dataloader = DataLoader(chat_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)# 检查Dataloader输出
for input_batch, target_batch in chat_dataloader:print("Input batch tensor size:", input_batch.size())print("Target batch tensor size:", target_batch.size())breakfor input_batch, target_batch in chat_dataloader:print("Input batch tensor:")print(input_batch)print("Target batch tensor:")print(target_batch)break

5.对GPT-2进行微调

import torch.nn as nn
import torch.optim as optim# 定义损失函数,忽略pad_token_id对应的损失值
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.0001)# 进行500个epoch的训练
for epoch in range(500):for batch_idx, (input_batch, target_batch) in enumerate(chat_dataloader):  # 遍历数据加载器中的批次optimizer.zero_grad()  # 梯度清零input_batch, target_batch = input_batch.to(device), target_batch.to(device)  # 输入和目标批次移至设备outputs = model(input_batch)  # 前向传播logits = outputs.logits  # 获取logits# 计算损失loss = criterion(logits.view(-1, len(vocab)), target_batch.view(-1))loss.backward()  # 反向传播optimizer.step()  # 更新参数if (epoch + 1) % 100 == 0:  # 每100个epoch打印一次损失值print(f'Epoch: {epoch + 1:04d}, cost = {loss:.6f}')

6.用约束解码函数生成回答

# 定义集束解码函数
def generate_text_beam_search(model, input_str, max_len=50, beam_width=5):model.eval()  # 将模型设置为评估模式(不计算梯度)# 对输入字符串进行编码,并将其转换为张量,然后将其移动到相应的设备上input_tokens = tokenizer.encode(input_str, return_tensors="pt").to(device)# 初始化候选序列列表,包含当前输入序列和其对数概率得分(我们从0开始)candidates = [(input_tokens, 0.0)]# 禁用梯度计算,以加速预测过程with torch.no_grad():# 迭代生成最大长度的序列for _ in range(max_len):new_candidates = []# 对于每个候选序列for candidate, candidate_score in candidates:# 使用模型进行预测outputs = model(candidate)# 获取输出logitslogits = outputs.logits[:, -1, :]# 获取对数概率得分的top-k值(即beam_width)及其对应的tokenscores, next_tokens = torch.topk(logits, beam_width, dim=-1)final_results = []# 遍历top-k token及其对应的得分for score, next_token in zip(scores.squeeze(), next_tokens.squeeze()):# 在当前候选序列中添加新的tokennew_candidate = torch.cat((candidate, next_token.unsqueeze(0).unsqueeze(0)), dim=-1)# 更新候选序列的得分new_score = candidate_score - score.item()# 如果新的token是结束符(eos_token),则将该候选序列添加到最终结果中if next_token.item() == tokenizer.eos_token_id:final_results.append((new_candidate, new_score))# 否则,将新的候选序列添加到新候选序列列表中else:new_candidates.append((new_candidate, new_score))# 从新候选序列列表中选择得分最⾼的top-k个序列candidates = sorted(new_candidates, key=lambda x: x[1])[:beam_width]# 选择得分最⾼的候选序列best_candidate, _ = sorted(candidates, key=lambda x: x[1])[0]# 将输出token转换回文本字符串output_str = tokenizer.decode(best_candidate[0])# 移除输入字符串并修复空格问题input_len = len(tokenizer.encode(input_str))output_str = tokenizer.decode(best_candidate.squeeze()[input_len:])return output_str# 测试模型
test_inputs = ["what is the weather like today?","can you recommend a good book?"
]# 输出测试结果
for i, input_str in enumerate(test_inputs, start=1):generated_text = generate_text_beam_search(model, input_str)print(f"测试 {i}:")print(f"User: {input_str}")print(f"AI: {generated_text}")
测试1:
User: what is the weather like today?<|endoftext|>
AI: you need an current time for now app with app app app app
测试2:
User: Can you recommend a good book?<|endoftext|>
AI: ockingbird Lee Harper Harper Taylor

模型的回答虽然称不上完美,但是,我们⾄少能够看出,微调数据集中的信息起到了⼀定的作⽤。第⼀个问题问及天⽓,模型敏锐地指向“app”(应⽤)这个存在于训练语料库中的信息,⽽查看“应⽤”确实是我们希望模型给出的答案。回答第⼆个问题时,模型给出了语料库中所推荐图书的作者的名字“Lee Harper”,⽽书名“To kill a Mockingbird”中的mockingbird是⼀个未知token,模型把它拆解成了三个token。具体信息如下。

tokenizer.encode('Mockingbird')[44/76, 8629, 16944]
tokenizer.decode(44)'M'
tokenizer.decode(8629)'ocking'
tokenizer.decode(16944)'bird'

因此,在解码时,出现了ockingbird这样的不完整信息,但是其中也的确包含了⼀定的语料库内部的知识。

⽽微调则针对特定任务进⾏优化。这⼀模式的优势在于,微调过程通常需要较少的训练数据和计算资源,同时仍能获得良好的性能。

相关文章:

Hugging Face预训练GPT微调ChatGPT(微调入门!新手友好!)

Hugging Face预训练GPT微调ChatGPT&#xff08;微调入门&#xff01;新手友好&#xff01;&#xff09; 在实战中&#xff0c;⼤多数情况下都不需要从0开始训练模型&#xff0c;⽽是使⽤“⼤⼚”或者其他研究者开源的已经训练好的⼤模型。 在各种⼤模型开源库中&#xff0c;最…...

【CSS3】化神篇

目录 平面转换平移旋转改变旋转原点多重转换缩放倾斜 渐变线性渐变径向渐变 空间转换平移视距旋转立体呈现缩放 动画使现步骤animation 复合属性animation 属性拆分逐帧动画多组动画 平面转换 作用&#xff1a;为元素添加动态效果&#xff0c;一般与过渡配合使用 概念&#x…...

Unity音频混合器如何暴露参数

音频混合器是Unity推荐管理音效混音的工具&#xff0c;那么如何使用代码对它进行管理呢&#xff1f; 首先我在AudioMixer的Master组中创建了BGM和SFX的分组&#xff0c;你也可以直接用Master没有问题。 这里我以BGM为例&#xff0c;如果要在代码中进行使用就需要将参数暴露出去…...

Vue keepalive学习用法

在Vue中&#xff0c;<keep-alive>的include属性用于指定需要缓存的组件&#xff0c;其实现方式如下&#xff1a; 1. 基本用法 • 字符串形式&#xff1a;通过逗号分隔组件名称&#xff0c;匹配到的组件会被缓存。 <keep-alive include"ComponentA,ComponentB&…...

5-1 使用ECharts将MySQL数据库中的数据可视化

方法一&#xff1a;使用Python Flask框架搭建API 对于技术小白来说&#xff0c;使用ECharts将MySQL数据库中的数据可视化需要分步骤完成。以下是详细的实现流程&#xff1a; 一、技术架构‌ 后端服务‌&#xff1a;使用Python Flask框架搭建API&#xff08;简单易学&#xff…...

构建下一代AI Agent:自动化开发与行业落地全解析

1. 下一代AI Agent&#xff1a;概念与核心能力 核心能力描述技术支撑应用价值自主性独立规划与执行任务&#xff0c;无需持续人工干预决策树、强化学习、目标导向规划减少人工干预&#xff0c;提高任务执行效率决策能力评估多种方案并选择最优解决方案贝叶斯决策、多目标优化、…...

如何理解分布式光纤传感器?

关键词&#xff1a;OFDR、分布式光纤传感、光纤传感器 分布式光纤传感器是近年来备受关注的前沿技术&#xff0c;其核心在于将光纤本身作为传感介质和信号传输介质&#xff0c;通过解析光信号在光纤中的散射效应&#xff0c;实现对温度、应变、振动等物理量的连续、无盲区、高…...

四、小白学JAVA-石头剪刀布游戏

1、如何从控制台获取用户输入 import java.util.Scanner;public class Main {public static void main(String[] args) {// 石头剪刀布的思路// 1 2 3 石头 剪刀 布Scanner scanner new Scanner(System.in);System.out.println("请出拳&#xff1a;1.石头 2.剪刀 3.布【…...

【一起来学kubernetes】21、Secret使用详解

Secret 的详细介绍 Secret 是 Kubernetes 中用于存储和管理敏感信息&#xff08;如密码、令牌、密钥等&#xff09;的资源对象。Secret的设计目的是为了安全地存储和传输敏感信息&#xff0c;如密码、API密钥、证书等。这些信息通常不应该直接硬编码在配置文件或镜像中&#x…...

css重点知识汇总(一)

css重点知识汇总&#xff08;一&#xff09; 引入css的不同方式 link 通过src来获取相应的css资源。除了获取css之外还可以获取其他资源&#xff0c;例如js在页面载入是同步下载可以通过js对dom操作来改变css import css3引入的新方法只能引入css资源需要页面完全载入后才…...

PMP-项目运行环境

你好&#xff01;我是 Lydia-穎穎 ♥感谢你的陪伴与支持 ~~~ 欢迎一起探索未知的知识和未来&#xff0c;现在lets go go go!!! 1. 影响项目的要素 项目存在在不同的环境下&#xff0c;环境对于项目的交付产生不同的影响。需了解环境对于项目的影响&#xff0c;采取相应措施应对…...

shell 脚本搭建apache

#!/bin/bash # Set Apache version to install ## author: yuan# 检查外网连接 echo "检查外网连接..." ping www.baidu.com -c 3 > /dev/null 2>&1 if [ $? -eq 0 ]; thenecho "外网通讯良好&#xff01;" elseecho "网络连接失败&#x…...

Huawei 鲲鹏(ARM/Aarch64)服务器安装KVM虚拟机(非桌面视图)

提出问题 因需要进行ARM架构适配&#xff0c;需要在Huawei Taishan 200k&#xff08;CPU&#xff1a; Kunpeng 920 5231K&#xff09;上&#xff0c;创建几台虚拟机做为开发测试环境。 无奈好久没搞了&#xff0c;看了一下自己多年前写的文章&#xff1a;Huawei 鲲鹏&#xf…...

《Python实战进阶》No28: 使用 Paramiko 实现远程服务器管理

No28: 使用 Paramiko 实现远程服务器管理 摘要 在现代开发与运维中&#xff0c;远程服务器管理是必不可少的一环。通过 SSH 协议&#xff0c;我们可以安全地连接到远程服务器并执行各种操作。Python 的 Paramiko 模块是一个强大的工具&#xff0c;能够帮助我们实现自动化任务&…...

备赛蓝桥杯之第十六届模拟赛3期职业院校组

提示&#xff1a;本篇文章仅仅是作者自己目前在备赛蓝桥杯中&#xff0c;自己学习与刷题的学习笔记&#xff0c;写的不好&#xff0c;欢迎大家批评与建议 由于个别题目代码量与题目量偏大&#xff0c;请大家自己去蓝桥杯官网【连接高校和企业 - 蓝桥云课】去寻找原题&#xff0…...

【Kafka】深入了解Kafka

集群的成员关系 Kafka使用Zookeeper维护集群的成员信息。 每一个broker都有一个唯一的标识&#xff0c;这个标识可以在配置文件中指定&#xff0c;也可以自动生成。当broker在启动时通过创建Zookeeper的临时节点把自己的ID注册到Zookeeper中。broker、控制器和其他一些动态系…...

C++特性——RAII、智能指针

RAII 就像new一个需要delete&#xff0c;fopen之后需要fclose&#xff0c;但这样会有隐形问题&#xff08;忘记释放&#xff09;。RAII即用对象把这个过程给包起来&#xff0c;对象构造的时候&#xff0c;new或者fopen&#xff0c;析构的时候delete. 为什么需要智能指针 对于…...

C++异常处理时的异常类型抛出选择

在 C 中选择抛出哪种异常类型&#xff0c;主要取决于错误的性质以及希望传达的语义信息。以下是一些指导原则&#xff0c;帮助在可能发生异常的地方选择合适的异常类型进行抛出&#xff1a; 1. std::exception 适用场景&#xff1a;作为所有标准异常的基类&#xff0c;std::e…...

elsticsearch 通过reindex修改shards

elasticsearch reindex 索引。 背景&#xff1a; 索引test1 reindex到test2 修改sharding数量 程序是通过别名test1_alias访问索引 1、创建目标索引test2 索引需要手动提前创建自动创建可能会有mapping 不一致性的风险。 The destination should be configured as wanted …...

CentOS系类普通挂载磁盘挂载命令

检查磁盘是否有分区 lsblk如果 vdb 下面没有分区&#xff08;比如 vdb1&#xff09;&#xff0c;你需要先创建分区。 创建分区&#xff08;如果需要&#xff09; fdisk /dev/vdb然后在 fdisk 交互界面&#xff1a; 输入 n 创建新分区 选择 p 创建主分区 默认分区号和大小 输…...

Kafka自定义分区机制

文章目录 1.如何自定义分区机制2.示例 1.如何自定义分区机制 若需要使用自定义分区机制&#xff0c;需要完成两件事&#xff1a; 1)在 producer 程序中创建一个类&#xff0c;实现 org.apache.kafka.clients.producer.Partitioner 接口主要分区逻辑在 Partitioner.partition中…...

【HarmonyOS NEXT】关键资产存储开发案例

在 iOS 开发中 Keychain 是一个非常安全的存储系统&#xff0c;用于保存敏感信息&#xff0c;如密码、证书、密钥等。与文件系统不同&#xff0c;Keychain 提供了更高的安全性&#xff0c;因为它对数据进行了加密&#xff0c;并且只有经过授权的应用程序才能访问存储的数据。那…...

强化学习(赵世钰版)-学习笔记(9.策略梯度法)

本章是课程的导数第二章&#xff0c;旨在讲解策略的函数化形式。 之前的方法&#xff0c;描述一个策略都是用表格的形式&#xff0c;每一行代表一个状态&#xff0c;每一列代表一个行为&#xff0c;表格中的元素对应相关状态下执行相关行为的概率。 函数化的策略表征形式是指&a…...

ModuleNotFoundError: No module named ‘flask‘ 错误

要解决 ModuleNotFoundError: No module named ‘flask’ 错误&#xff0c;需确保已正确安装 Flask 库。以下是详细步骤&#xff1a; ‌1. 安装 Flask‌ 在终端或命令行中执行以下命令&#xff08;注意权限问题&#xff09;&#xff1a; 使用 pip 安装 pip install flask 若…...

【c++】【STL】unordered_set 底层实现(简略版)

【c】【STL】unordered_set 底层实现&#xff08;简略版&#xff09; ps:这个是我自己看的不保证正确&#xff0c;觉得太长的后面会总结整个调用逻辑 unordered_set 内部实现 template <class _Kty, class _Hasher hash<_Kty>, class _Keyeq equal_to<_Kty>…...

【Zephyr】【一】学习笔记

Zephyr RTOS 示例代码集 1. 基础示例 1.0 基础配置 每个示例都需要一个 prj.conf 文件来配置项目。以下是各个示例所需的配置&#xff1a; 基础示例 prj.conf # 控制台输出 CONFIG_PRINTKy CONFIG_SERIALy CONFIG_UART_CONSOLEy# 日志系统 CONFIG_LOGy CONFIG_LOG_DEFAULT…...

网络安全设备配置与管理-实验4-防火墙AAA服务配置

实验4-p118防火墙AAA服务配置 从这个实验开始&#xff0c;每一个实验都是长篇大论&#x1f613; 不过有好兄弟会替我出手 注意&#xff1a;1. gns3.exe必须以管理员身份打开&#xff0c;否则ping不通虚拟机。 win10虚拟机无法做本次实验&#xff0c;必须用学校给的虚拟机。首…...

后端框架模块化

后端框架的模块化设计旨在简化开发流程、提高可维护性&#xff0c;并通过分层解耦降低复杂性。以下是常见的后端模块及其在不同语言&#xff08;Node.js、Java、Python&#xff09;中的实现方式&#xff1a; 目录 1. 路由&#xff08;Routing&#xff09;2. 中间件&#xff08;…...

【论文阅读】Contrastive Clustering Learning for Multi-Behavior Recommendation

论文地址&#xff1a;Contrastive Clustering Learning for Multi-Behavior Recommendation | ACM Transactions on Information Systems 摘要 近年来&#xff0c;多行为推荐模型取得了显著成功。然而&#xff0c;许多模型未充分考虑不同行为之间的共性与差异性&#xff0c;以…...

视频转音频, 音频转文字

Ubuntu 24 环境准备 # 系统级依赖 sudo apt update && sudo apt install -y ffmpeg python3-venv git build-essential python3-dev# Python虚拟环境 python3 -m venv ~/ai_summary source ~/ai_summary/bin/activate核心工具链 工具用途安装命令Whisper语音识别pip …...