nlp培训重点-5
1. LoRA微调
loader:
# -*- coding: utf-8 -*-import json
import re
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
"""
数据加载
"""class DataGenerator:def __init__(self, data_path, config):self.config = configself.path = data_pathself.index_to_label = {0: '家居', 1: '房产', 2: '股票', 3: '社会', 4: '文化',5: '国际', 6: '教育', 7: '军事', 8: '彩票', 9: '旅游',10: '体育', 11: '科技', 12: '汽车', 13: '健康',14: '娱乐', 15: '财经', 16: '时尚', 17: '游戏'}self.label_to_index = dict((y, x) for x, y in self.index_to_label.items())self.config["class_num"] = len(self.index_to_label)if self.config["model_type"] == "bert":self.tokenizer = BertTokenizer.from_pretrained(config["pretrain_model_path"])self.vocab = load_vocab(config["vocab_path"])self.config["vocab_size"] = len(self.vocab)self.load()def load(self):self.data = []with open(self.path, encoding="utf8") as f:for line in f:line = json.loads(line)tag = line["tag"]label = self.label_to_index[tag]title = line["title"]if self.config["model_type"] == "bert":input_id = self.tokenizer.encode(title, max_length=self.config["max_length"], pad_to_max_length=True)else:input_id = self.encode_sentence(title)input_id = torch.LongTensor(input_id)label_index = torch.LongTensor([label])self.data.append([input_id, label_index])returndef encode_sentence(self, text):input_id = []for char in text:input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))input_id = self.padding(input_id)return input_id#补齐或截断输入的序列,使其可以在一个batch内运算def padding(self, input_id):input_id = input_id[:self.config["max_length"]]input_id += [0] * (self.config["max_length"] - len(input_id))return input_iddef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]def load_vocab(vocab_path):token_dict = {}with open(vocab_path, encoding="utf8") as f:for index, line in enumerate(f):token = line.strip()token_dict[token] = index + 1 #0留给padding位置,所以从1开始return token_dict#用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dlif __name__ == "__main__":from config import Configdg = DataGenerator("valid_tag_news.json", Config)print(dg[1])
model:
import torch.nn as nn
from config import Config
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from torch.optim import Adam, SGDTorchModel = AutoModelForSequenceClassification.from_pretrained(Config["pretrain_model_path"])def choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)
evaluate:
# -*- coding: utf-8 -*-
import torch
from loader import load_data"""
模型效果测试
"""class Evaluator:def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, shuffle=False)self.stats_dict = {"correct":0, "wrong":0} #用于存储测试结果def eval(self, epoch):self.logger.info("开始测试第%d轮模型效果:" % epoch)self.model.eval()self.stats_dict = {"correct": 0, "wrong": 0} # 清空上一轮结果for index, batch_data in enumerate(self.valid_data):if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]input_ids, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况with torch.no_grad():pred_results = self.model(input_ids)[0]self.write_stats(labels, pred_results)acc = self.show_stats()return accdef write_stats(self, labels, pred_results):# assert len(labels) == len(pred_results)for true_label, pred_label in zip(labels, pred_results):pred_label = torch.argmax(pred_label)# print(true_label, pred_label)if int(true_label) == int(pred_label):self.stats_dict["correct"] += 1else:self.stats_dict["wrong"] += 1returndef show_stats(self):correct = self.stats_dict["correct"]wrong = self.stats_dict["wrong"]self.logger.info("预测集合条目总量:%d" % (correct +wrong))self.logger.info("预测正确条目:%d,预测错误条目:%d" % (correct, wrong))self.logger.info("预测准确率:%f" % (correct / (correct + wrong)))self.logger.info("--------------------")return correct / (correct + wrong)
main:
# -*- coding: utf-8 -*-import torch
import os
import random
import os
import numpy as np
import torch.nn as nn
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
from peft import get_peft_model, LoraConfig, \PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfig #[DEBUG, INFO, WARNING, ERROR, CRITICAL]
logging.basicConfig(level=logging.INFO, format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型训练主程序
"""seed = Config["seed"]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)def main(config):#创建保存模型的目录if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])#加载训练数据train_data = load_data(config["train_data_path"], config)#加载模型model = TorchModel#大模型微调策略tuning_tactics = config["tuning_tactics"]if tuning_tactics == "lora_tuning":peft_config = LoraConfig(r=8,lora_alpha=32,lora_dropout=0.1,target_modules=["query", "key", "value"])elif tuning_tactics == "p_tuning":peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10)elif tuning_tactics == "prompt_tuning":peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)elif tuning_tactics == "prefix_tuning":peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)model = get_peft_model(model, peft_config)# print(model.state_dict().keys())if tuning_tactics == "lora_tuning":# lora配置会冻结原始模型中的所有层的权重,不允许其反传梯度# 但是事实上我们希望最后一个线性层照常训练,只是bert部分被冻结,所以需要手动设置for param in model.get_submodule("model").get_submodule("classifier").parameters():param.requires_grad = True# 标识是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,迁移模型至gpu")model = model.cuda()#加载优化器optimizer = choose_optimizer(config, model)#加载效果测试类evaluator = Evaluator(config, model, logger)#训练for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):if cuda_flag:batch_data = [d.cuda() for d in batch_data]optimizer.zero_grad()input_ids, labels = batch_data #输入变化时这里需要修改,比如多输入,多输出的情况output = model(input_ids)[0]loss = nn.CrossEntropyLoss()(output, labels.view(-1))loss.backward()optimizer.step()train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)logger.info("epoch average loss: %f" % np.mean(train_loss))acc = evaluator.eval(epoch)model_path = os.path.join(config["model_path"], "%s.pth" % tuning_tactics)save_tunable_parameters(model, model_path) #保存模型权重return accdef save_tunable_parameters(model, path):saved_params = {k: v.to("cpu")for k, v in model.named_parameters()if v.requires_grad}torch.save(saved_params, path)if __name__ == "__main__":main(Config)
pred:
import torch
import logging
from model import TorchModel
from peft import get_peft_model, LoraConfig, PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfigfrom evaluate import Evaluator
from config import Configlogging.basicConfig(level=logging.INFO, format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)#大模型微调策略
tuning_tactics = Config["tuning_tactics"]print("正在使用 %s"%tuning_tactics)if tuning_tactics == "lora_tuning":peft_config = LoraConfig(r=8,lora_alpha=32,lora_dropout=0.1,target_modules=["query", "key", "value"])
elif tuning_tactics == "p_tuning":peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
elif tuning_tactics == "prompt_tuning":peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
elif tuning_tactics == "prefix_tuning":peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)#重建模型
model = TorchModel
# print(model.state_dict().keys())
# print("====================")model = get_peft_model(model, peft_config)
# print(model.state_dict().keys())
# print("====================")state_dict = model.state_dict()#将微调部分权重加载
if tuning_tactics == "lora_tuning":loaded_weight = torch.load('output/lora_tuning.pth')
elif tuning_tactics == "p_tuning":loaded_weight = torch.load('output/p_tuning.pth')
elif tuning_tactics == "prompt_tuning":loaded_weight = torch.load('output/prompt_tuning.pth')
elif tuning_tactics == "prefix_tuning":loaded_weight = torch.load('output/prefix_tuning.pth')print(loaded_weight.keys())
state_dict.update(loaded_weight)#权重更新后重新加载到模型
model.load_state_dict(state_dict)#进行一次测试
model = model.cuda()
evaluator = Evaluator(Config, model, logger)
evaluator.eval(0)
相关文章:
nlp培训重点-5
1. LoRA微调 loader: # -*- coding: utf-8 -*-import json import re import os import torch import numpy as np from torch.utils.data import Dataset, DataLoader from transformers import BertTokenizer """ 数据加载 """cl…...
电子学会—2024年月6青少年软件编程(图形化)四级等级考试真题——水仙花数
水仙花数 如果一个三位数等于它各个数位上的数字的立方和,那么这个数就是水仙花数,例如:153 111 555 333,153就是一个水仙花数。 1.准备工作 (1)保留默认角色小猫; (2)白色背景。 2.功能实现 (1)使用循环遍历所有三位数,把所…...
若依分页的逻辑分析
看了一些网上的感觉都是 听君一席话, 如听一席话. 下面开始简单的分析一下, 随便找一个接口, 看一下前端的请求地址: 请求方式: GET 请求地址: http://localhost/dev-api/system/role/list?pageNum1&pageSize10 后端接口: PreAuthorize("ss.hasPermi(system:role:li…...
JetBrains学生申请
目录 JetBrains学生免费授权申请 IDEA安装与使用 第一个JAVA代码 1.利用txt文件和cmd命令运行 2.使用IDEA新建项目 JetBrains学生免费授权申请 本教程采用学生校园邮箱申请,所以要先去自己的学校申请校园邮箱。 进入JetBrains官网 点击立即申请,然…...
【算法方法总结·五】链表操作的一些技巧和注意事项
【算法方法总结五】链表操作的一些技巧和注意事项 【算法方法总结一】二分法的一些技巧和注意事项【算法方法总结二】双指针的一些技巧和注意事项【算法方法总结三】滑动窗口的一些技巧和注意事项【算法方法总结四】字符串操作的一些技巧和注意事项【算法方法总结五】链表操作…...
langchain系列(终)- LangGraph 多智能体详解
目录 一、导读 二、概念原理 1、智能体 2、多智能体 3、智能体弊端 4、多智能体优点 5、多智能体架构 6、交接(Handoffs) 7、架构说明 (1)网络 (2)监督者 (3)监督者&…...
侯捷 C++ 课程学习笔记:深入理解智能指针
文章目录 每日一句正能量一、引言二、智能指针的核心概念(一)std::unique_ptr(二)std::shared_ptr(三)std::weak_ptr 三、学习心得四、实际应用案例五、总结 每日一句正能量 如果说幸福是一个悖论ÿ…...
访问不了 https://raw.githubusercontent.com 怎么办?
修改 Hosts 文件(推荐) 原理:通过手动指定域名对应的 IP 地址,绕过 DNS 污染。 步骤: 1、访问 IPAddress.com,搜索 raw.githubusercontent.com,获取当前最新的 IPv4 地址(例如 1…...
大模型工程师学习日记(十五):Hugging Face 模型微调训练(基于 BERT 的中文评价情感分析)
1. datasets 库核心方法 1.1. 列出数据集 使用 d atasets 库,你可以轻松列出所有 Hugging Face 平台上的数据集: from datasets import list_datasets# 列出所有数据集 all_datasets list_datasets()print(all_datasets)1.2. 加载数据集 你可以通过 l…...
Codeforces Round 258 (Div. 2) E. Devu and Flowers 生成函数
题目链接 题目大意 有 n n n ( 1 ≤ n ≤ 20 ) (1\leq n \leq 20) (1≤n≤20) 个花瓶,第 i i i 个花瓶里有 f i f_i fi ( 1 ≤ f i ≤ 1 0 12 ) (1\leq f_i \leq 10^{12}) (1≤fi≤1012) 朵花。现在要选择 s s s ( 1 ≤ s ≤ 1 0 14 ) (1\leq s \leq 1…...
MySQL-----SELECT语句-查询
目录 SELECT语句-查询 1.格式 2.操作 3.算数表达式 SELECT语句-查询 1.格式 📖简单查询: 格式: select 字段1,字段n from 表名; 起别名: 通过在字段后添加 as 别名 as可以省略 改变表头 eg: select username "用户名",password as "…...
子数组、子串系列(典型算法思想)—— OJ例题算法解析思路
一、53. 最大子数组和 - 力扣(LeetCode) 算法代码: class Solution { public:int maxSubArray(vector<int>& nums) {// 1. 创建 dp 表// dp[i] 表示以第 i 个元素结尾的子数组的最大和int n nums.size();vector<int> dp(n…...
Windows编程----进程的当前目录
进程的当前目录 Windows Api中有大量的函数在调用的时候,需要传递路径。比如创建文件,创建目录,删除目录,删除文件等等。拿创建文件的CreateFile函数做比喻,如果我们要创建的文件路径不是全路径,那么wind…...
AVL树的介绍及实现
文章目录 (一)AVL的概念(二)AVL树的实现1.AVL树的结构2.AVL树的插入3.AVL树的查找 (三)检查一棵树是否是AVL树 (一)AVL的概念 AVL树是一棵高度平衡的二叉搜索树,通过控制…...
hadoop第3课(hdfs shell常用命令)
一、Hadoop FS 基础操作命令 1. 查看帮助 hadoop fs -help [命令名] # 查看具体命令的帮助文档 # 示例: hadoop fs -help mkdir2. 目录操作 hadoop fs -mkdir /path # 创建目录 hadoop fs -mkdir -p /path/a/b # 递归创建多级目录 hadoop fs -rmdir …...
为什么Java不采用引用传递方式
Java不采用引用传递方式,而是统一采用值传递机制,这一设计决策背后有多种原因。 1. 语言设计的简洁性与一致性 Java的设计目标之一是保持语言的简洁性和一致性。如果同时支持值传递和引用传递,可能会导致语言复杂度增加,使得开发者难以理解和使用。通过统一采用值传递机制…...
【RAG】文本分割的粒度
文本分隔 可能存在的问题 粒度太大可能导致检索不精准粒度太小可能导致信息不全面问题的答案可能跨越两个片段 # 创建一个向量数据库对象 vector_db MyVectorDBConnector("demo_text_split", get_embeddings) # 向向量数据库中添加文档 vector_db.add_documents(p…...
Qt信号与槽机制实现原理
Qt 的信号和槽机制是其核心特性之一,用于实现对象间的松耦合通信。以下是对其实现原理的详细分析: 1. 元对象系统(Meta-Object System) Q_OBJECT 宏与 moc Qt 通过元对象系统实现反射能力。声明 Q_OBJECT 宏的类会由 moc…...
Vue3 中 Computed 用法
Computed 又被称作计算属性,用于动态的根据某个值或某些值的变化,来产生对应的变化,computed 具有缓存性,当无关值变化时,不会引起 computed 声明值的变化。 产生一个新的变量并挂载到 vue 实例上去。 vue3 中 的 com…...
《今日AI-人工智能-编程日报》
一、AI行业动态 AI模型作弊行为引发担忧 最新研究表明,AI在国际象棋对弈中表现出作弊倾向,尤其是高级推理模型如OpenAI的o1-preview和DeepSeek的R1模型。这些模型通过篡改代码、窃取棋路等手段试图扭转战局,且作弊行为与其智能水平正相关。研…...
快速生成viso流程图图片形式
我们在写详细设计文档的过程中总会不可避免的涉及到时序图或者流程图的绘制,viso这个软件大部分技术人员都会使用,但是想要画的好看,画的科学还是比较难的,现在我总结一套比较好的方法可以生成好看科学的viso图(图片格式)。主要思…...
centos7关闭与开启图形界面
centos7关闭图形界面 systemctl set-default multi-user.target rebootcentos7开启图形界面 systemctl set-default graphical.target reboot...
linux学习(十)(磁盘和文件系统(索引节点,文件系统,添加磁盘,交换,LVM公司,挂载))
Linux 磁盘文件系统 Linux 使用各种文件系统来允许我们从计算机系统的硬件(例如磁盘)存储和检索数据。文件系统定义了如何在这些存储设备上组织、存储和检索数据。流行的 Linux 文件系统示例包括 EXT4、FAT32、NTFS 和 Btrfs。 每个文件系统都有自己的…...
vulkanscenegraph显示倾斜模型(5.2)-交换链
前言 在 VulkanSceneGraph(VSG)中,vsg::Window 类对窗口进行了高层次的封装,为开发者提供了便捷的窗口管理接口。在上一篇文章中,我们探讨了 VkInstance、VkSurfaceKHR、VkPhysicalDevice 和 VkDevice 的创建过程&…...
【极光 Orbit•STC8A-8H】03. 小刀初试:点亮你的LED灯
【极光 Orbit•STC8H】03. 小刀初试:点亮你的 LED 灯 七律 点灯初探 单片方寸藏乾坤,LED明灭见真章。 端口配置定方向,寄存器值细推敲。 高低电平随心控,循环闪烁展锋芒。 嵌入式门初开启,从此代码手中扬。 摘要 …...
实现一键不同环境迁移ES模板
实现概述: 1、查询环境A模板信息 2、获取模板信息值转换 3、同步保存至环境B package com.jayce.boot.route.common.util;import com.fasterxml.jackson.databind.JsonNode; import com.google.common.collect.Lists; import com.jayce.boot.route.common.util.…...
Nacos学习笔记-占位符读取其他命名空间内容
Nacos当前命名空间下的配置文件需要跨命名空间读取其他配置文件的内容。可以先通过Nacos提供的API接口获取配置文件内容,然后解析数据将其放入环境的PropertySource中。 相关依赖包 <!-- Nacos依赖包 --> <dependency><groupId>com.alibaba.clo…...
OSPF报文分析
OSPF报文分析 组播地址 224.0.0.0~224.0.0.255为预留的组播地址(永久组地址),地址224.0.0.0保留不做分配,其它地址供路由协议使用; 224.0.1.0~238.255.255.255为用户可用的组播地址(…...
MySql性能(9)- mysql的order by的工作原理
全字段排序rowid排序全字段排序和rowid排序 3.1 联合索引优化 3.2 覆盖索引优化优先队列算法优化建议 5.1 修改系统参数 5.2 优化sql 1. 全字段排序 CREATE TABLE t ( id int(11) NOT NULL,city varchar(16) NOT NULL, name varchar(16) NOT NULL, age int(11) NOT NULL,addr v…...
死锁问题分析工具
使用 gdb 调试 gdb ./your_program (gdb) run (gdb) thread apply all bt还可以分析pthread_mutex内部,查看owen字段分析哪个线程占用的锁,一个可能的 pthread_mutex 内部结构可以大致表示为: typedef struct pthread_mutex_t {int state; …...
