当前位置: 首页 > news >正文

Gemma

Gemma

  • 1.使用
  • 2.RAG
  • 3.LoRA
    • 3.1LoRA分类任务
    • 3.2LoRA中文建模任务

1.使用

首先是去HF下载模型,但一直下载不了,所以去了HF镜像网站,下载gemma需要HF的Token,按照步骤就可以下载。代码主要是Kaggle论坛里面的分享内容。

huggingface-cli download --token hf_XXX --resume-download google/gemma-7b --local-dir gemma-7b-mirror

这里我有时是2b有时是7b,换着用。

from transformers import AutoTokenizer, AutoModelForCausalLM  
tokenizer = AutoTokenizer.from_pretrained("D:/Gemma/gemma-2b-int-mirror2")
Gemma = AutoModelForCausalLM.from_pretrained("D:/Gemma/gemma-2b-int-mirror2")
def answer_the_question(question):input_ids = tokenizer(question, return_tensors="pt")generated_text = Gemma.generate(**input_ids,max_length=256)answer = tokenizer.decode(generated_text[0], skip_special_tokens=True)return answer
question = "给我写一首优美的诗歌?"
answer = answer_the_question(question)
print(answer)

2.RAG

参考

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
##2.1 根据question检索sentence chunk
import os
def get_all_pdfs(directory):pdf_files = []for root, dirs, files in os.walk(directory):for file in files:if file.endswith(".pdf"):pdf_files.append(os.path.join(root, file))return pdf_filesclass RAG:def __init__(self, num_retrieved_docs=5, pdf_folder_path='D:/Gemma/PDF'):pdf_files = get_all_pdfs(pdf_folder_path)print("Documents used", pdf_files)loaders = [PyPDFLoader(pdf_file) for pdf_file in pdf_files]all_documents = []for loader in loaders:raw_documents = loader.load()text_splitter = CharacterTextSplitter(separator="\n\n",chunk_size=10,chunk_overlap=1,# length_function=len,)documents = text_splitter.split_documents(raw_documents)all_documents.extend(documents)embeddings = HuggingFaceEmbeddings(model_name="D:/Projects/model/m3e-base")    self.db = FAISS.from_documents(all_documents, embeddings)self.retriever = self.db.as_retriever(search_kwargs={"k": num_retrieved_docs})def search(self, query):docs = self.retriever.get_relevant_documents(query)return docs
retriever = RAG()
##2.2根据sentence chunk和question去回答
class Assistant:def __init__(self):self.tokenizer = AutoTokenizer.from_pretrained("D:/Gemma/gemma-2b-int-mirror2")self.Gemma = AutoModelForCausalLM.from_pretrained("D:/Gemma/gemma-2b-int-mirror2")def create_prompt(self, query, retrieved_info):prompt = f"""你是人工智能助手,需要根据Relevant information里面的相关内容回答用户的Instruction,其中相关信息如下:Instruction: {query}Relevant information: {retrieved_info}Output:"""print(prompt)return promptdef reply(self, query, retrieved_info):prompt = self.create_prompt(query, retrieved_info)input_ids = self.tokenizer(query, return_tensors="pt").input_ids# Generate text with a focus on factual responsesgenerated_text = self.Gemma.generate(input_ids,do_sample=True,max_length=500,temperature=0.7, # Adjust temperature according to the task, for code generation it can be 0.9)# Decode and return the answeranswer = self.tokenizer.decode(generated_text[0], skip_special_tokens=True)return answer
chatbot = Assistant()
## 2.3开始使用RAG
def generate_reply(query):related_docs = retriever.search(query)#print('related docs', related_docs)reply = chatbot.reply(query, related_docs)return reply
reply = generate_reply("存在的不足及后续的优化工作")
for s in reply.split('\n'):print(s)

3.LoRA

3.1LoRA分类任务

参考
使用nlp-getting-started数据集训练模型做二分类任务。首先拿到源model

from datasets import load_dataset
from transformers import AutoTokenizer,AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments,pipeline
from peft import prepare_model_for_int8_training,LoraConfig, TaskType, get_peft_model
import numpy as np
NUM_CLASSES = 2#模型输出分类的类别数
BATCH_SIZE,EPOCHS,R,LORA_ALPHA,LORA_DROPOUT = 8,5,64,32,0.1#LoRA训练的参数
MODEL_PATH="D:/Gemma/gemma-2b-int-mirror2"#模型地址
# 1.源model,设置输出二分类
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH,num_labels=NUM_CLASSES)
print(model)

处理csv数据,将输入文字经过tokenizer编码处理

#2.处理dataset,输入过长进行truncation(tokenizer处理后)
dataset = load_dataset('csv', data_files='D:/Gemma/nlp-getting-started/train.csv')
dataset['test'] = dataset['train']
dataset = dataset.remove_columns(['id', 'keyword', 'location'])
dataset = dataset.rename_column("target", "label")#csv最后只保留了text列和label列
tokenized_dataset = {}#train和test
for split in dataset.keys():tokenized_dataset[split] = dataset[split].map(lambda x: tokenizer(x["text"], truncation=True), batched=True)
print(tokenized_dataset["train"])
print(tokenized_dataset["train"][1])

在源model基础上配置LoRA的参数,形成lora_model

#3.LoRA模型参数设置
model = prepare_model_for_int8_training(model)
lora_config = LoraConfig(r=R,lora_alpha=LORA_ALPHA,lora_dropout=LORA_DROPOUT,task_type=TaskType.SEQ_CLS,#SEQ_CLS:序列分类任务;TOKEN_CLS命名实体识别;SEQ2SEQ机器翻译;LM语言建模任务target_modules='all-linear'#all-linear所有线性层;embeddings嵌入层;convs卷积层
)
lora_model = get_peft_model(model, lora_config)
print(lora_model)
print(lora_model.print_trainable_parameters())#LoRA模型要训练的参数

配置lora_model的训练参数

#4.LoRA训练参数设置(损失计算等)
def compute_metrics(eval_pred):predictions, labels = eval_predpredictions = np.argmax(predictions, axis=1)return {"accuracy": (predictions == labels).mean()}trainer = Trainer(model=lora_model,args=TrainingArguments(output_dir="./LoAR_data/",learning_rate=2e-5,per_device_train_batch_size=BATCH_SIZE,per_device_eval_batch_size=BATCH_SIZE,evaluation_strategy="epoch",save_strategy="epoch",num_train_epochs=EPOCHS,weight_decay=0.01,load_best_model_at_end=True,logging_steps=10,report_to="none"),train_dataset=tokenized_dataset["train"],eval_dataset=tokenized_dataset["test"],tokenizer=tokenizer,data_collator=DataCollatorWithPadding(tokenizer=tokenizer),compute_metrics=compute_metrics,
)

开始训练并保存使用模型

#5.训练并评估
print("Evaluating the Model Before Training!")
trainer.evaluate()
print("Training the Model")
trainer.train()
print("Evaluating the trained model")
trainer.evaluate()
#6.保存并使用
lora_model.save_pretrained('fine-tuned-model')
clf = pipeline("text-classification", lora_model, tokenizer=MODEL_PATH)#LoRA训练后的模型

3.2LoRA中文建模任务

参考
首先拿到源model和config

from transformers import AutoConfig,AutoTokenizer,AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model,prepare_model_for_kbit_training,PeftModel
import torch
import  datasets
from tqdm import tqdm
import json
BATCH_SIZE,EPOCHS,R,LORA_ALPHA,LORA_DROPOUT = 8,5,64,32,0.1#LoRA训练的参数
MODEL_PATH="D:/Gemma/gemma-2b-int-mirror2"#模型地址
device = torch.device('cuda:0')
# 1.源model和model的config
config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
config.is_causal = True  #确保模型在生成文本时只能看到左侧的上下文
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH,device_map="auto", config=config,trust_remote_code=True)

根据模型和config处理json数据

#2.根据model的config处理dataset(tokenizer处理后),并保存加载
def preprocess(tokenizer: PreTrainedTokenizer, config, file_path, max_seq_length, prompt_key, target_key, skip_overlength=False):  # 数据预处理  pad_token_id = tokenizer.pad_token_id  # 获取填充标记的ID  with open(file_path, "r", encoding="utf8") as f:  for line in tqdm(f.readlines()):  example = json.loads(line)  prompt_ids = tokenizer.encode(example[prompt_key], max_length=max_seq_length, truncation=True)  target_ids = tokenizer.encode(example[target_key], max_length=max_seq_length, truncation=True)  # 检查prompt和target连接后是否超出最大长度,并在需要时跳过  total_length = len(prompt_ids) + len(target_ids) + (1 if config.eos_token_id is not None else 0)  if skip_overlength and total_length > max_seq_length:  continue  # 连接prompt和target,并添加EOS标记(如果提供)  input_ids = prompt_ids + target_ids  if config.eos_token_id is not None:  input_ids.append(config.eos_token_id)  # 截断序列到最大长度  input_ids = input_ids[:max_seq_length]  # 填充序列到最大长度  input_ids.extend([pad_token_id] * (max_seq_length - len(input_ids)))  assert len(input_ids) == max_seq_length, "序列长度必须等于max_seq_length"  yield {  "input_ids": input_ids,  "seq_len": len(prompt_ids)  # 注意:这里提供的seq_len是原始prompt的长度,不包括填充  }
dataset = datasets.Dataset.from_generator(lambda: preprocess(tokenizer, config, "D:/Gemma/try/hc3_chatgpt_zh_specific_qa.json", max_seq_length=2000, prompt_key="q",target_key="a",))dataset.save_to_disk("h3c-chinese")  # 保存处理后的数据集
train_set = datasets.load_from_disk("h3c-chinese")#加载处理后的数据集

配置Lora参数

#3.LoRA模型参数设置
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(r=R,lora_alpha=LORA_ALPHA,lora_dropout=LORA_DROPOUT,task_type="CAUSAL_LM",target_modules='all-linear'
)
lora_model = get_peft_model(model, lora_config)
print(lora_model)
print(lora_model.print_trainable_parameters())#LoRA模型要训练的参数

配置lora的训练参数,包括损失计算compute_metrics,并对输入的input_ids构造输入样本列表批次处理。

trainer = Trainer(model=lora_model,args=TrainingArguments(output_dir="./LoAR_data2/",learning_rate=2e-5,per_device_train_batch_size=BATCH_SIZE,save_strategy="epoch",num_train_epochs=EPOCHS,weight_decay=0.01,logging_steps=10,report_to="none"),train_dataset=train_set,tokenizer=tokenizer,data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
#     compute_metrics=compute_metrics
)
trainer.train()

相关文章:

Gemma

Gemma 1.使用2.RAG3.LoRA3.1LoRA分类任务3.2LoRA中文建模任务 1.使用 首先是去HF下载模型,但一直下载不了,所以去了HF镜像网站,下载gemma需要HF的Token,按照步骤就可以下载。代码主要是Kaggle论坛里面的分享内容。 huggingface-…...

淘宝关键词搜索API、搜索商品接口、商品价格监控

淘宝搜索引擎的工作原理: 淘宝搜索引擎的工作原理是基于搜索引擎的核心技术——爬虫和索引,通过对海量数据的抓取、分析和存储,提供给用户最准确的搜索结果。 具体来说,淘宝搜索引擎的工作流程如下: 企业级api数据…...

vue实现水印功能

目录 一、应用场景 二、实现原理 三、详细开发 1.水印的实现方式 2.防止用户通过控制台修改样式去除水印效果(可跳过,有弊端) 3.水印的使用 (1)单页面/全局使用 (2)全局使用个别页面去掉…...

记录一下我的Ruby On Rails的systemd服务脚本

自己也是一个 ROR 框架的学习者,同时也是 Ruby 的新手。对于如何让 ROR 应用随系统自动启动并不是很了解。在尝试了各种方法之后,我最终找到了一条可行的途径。虽然不确定是否完全正确,但服务已经成功启动了。因此,我决定在这里保…...

【计算机网络】传输层——TCP和UDP详解

文章目录 一. TCP和UDP简介二. UDP 协议详解1. UDP报文格式2. UDP的使用场景 三. TCP 协议详解1. TCP报文格式2. TCP协议的重要机制确认应答(保证可靠传输的最核心机制)超时重传连接管理(三次握手、四次挥手)!&#xf…...

stm32和嵌入式linux可以同步学习吗?

在开始前我有一些资料,是我根据网友给的问题精心整理了一份「stm3的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!!如果需要使用STM32,建…...

maven--->maven中的<properties>属性有什么作用?

🙌🙌🙌🙌🙌🙌 在Maven中,元素用于定义项目中可重用的属性值。这些属性值可以在项目的POM文件中被引用,以便在整个项目中统一管理和使用。通过使用元素,可以避免在POM文件…...

android 网络请求总结

1 先看下基础部分: android okhttp网络访问是基于 tcp/ip 的 最上层是应用层的封装,有http,https(加密),ftp 下面是socket套接字的封装,就是将ip和端口的封装 在下面就是tcp/udp 在下面 ip协议…...

用 Python 自动化处理无聊的事情

“编程最棒的部分就是看到机器做一些有用的事情而获得的胜利。用 Python 将无聊的事情自动化将所有编程视为这些小小的胜利;它让无聊变得有趣。” Hilary Mason,数据科学家兼 Fast Forward Labs 创始人 “我很享受打破东西然后把它们重新组合起来的乐趣…...

稀疏计算、彩票假说、MoE、SparseGPT

稀疏计算可能是未来10年内最有潜力的深度学习方向之一,稀疏计算模拟了对人脑的观察,人脑在处理信息的时候只有少数神经元在活动,多数神经元是不工作的。而稀疏计算的基本思想是:在计算过程中,将一些不重要的参数设置为…...

Git Windows安装教程

Git简介 Git是目前世界上最先进的分布式版本控制系统。它的工作原理 / 流程如下: [ Workspace:工作区 Index / Stage:暂存区 Repository:仓库区(或本地仓库) Remote:远程仓库 ] Git的下载 去 Git 官网下载对应系统的软件了,下…...

iOS高级理论:Runtime应用

一、遍历类的属性,快速归档 在 iOS 中,可以使用 Runtime 遍历类的属性来实现快速的归档(Archiving)操作。归档是将对象转换为数据流以便存储或传输的过程。下面是一个简单的示例,展示如何使用 Runtime 遍历类的属性进…...

php判断和过滤get或者post的html标签,防止跨站点脚本(XSS),链接注入,框架注入等攻击

大部分网站都包含搜索功能,根据用户搜索的词去执行服务端的业务逻辑。如果一些黑客在搜索参数包含链接(a)、嵌入其他网页(iframe)、前端代码(script)等html字符,再加上服务端php不加…...

PySide6实现课堂点名程序

目录 一:实现思路 二:实现代码 三:完整代码和界面 一:实现思路 为了创建一点名程序,并编写一个基本的 GUI 应用程序。新建一个窗口,展在窗口界面添加开始和停止按钮的QPushButton,和展示正在显示的人名QLabel,点击开始时随机显示人名列表中的一个名字并且展示在QLab…...

瑞_Redis_Redis命令

文章目录 1 Redis命令Redis数据结构Redis 的 key 的层级结构1.0 Redis通用命令1.0.1 KEYS1.0.2 DEL1.0.3 EXISTS1.0.4 EXPIRE1.0.5 TTL 1.1 String类型1.1.0 String类型的常见命令1.1.1 SET 和 GET1.1.2 MSET 和 MGET1.1.3 INCR和INCRBY和DECY1.1.4 SETNX1.1.5 SETEX 1.2 Hash类…...

js 算法题 在数组中找出和为目标值 target 的那 两个 整数,并返回它们的数组下标

题目:给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。 你可以…...

基于springboot接口的编写

目录 1、模糊分页查询 2、批量删除 3、新增 4、编辑 此接口非彼接口。此接口是MVC的设计模式中的Controller层,一般我们会叫Controller层里的方法为接口。他们是负责接收前端或者其它服务的传来的请求,并对请求进行相应的处理,最终再将处…...

【HarmonyOS】鸿蒙开发之Video组件——第3.7章

Video组件内VideoOptions属性简介 src:设置视频地址。currentProgressRate:设置视频播放倍速,参数说明如下: number|string:只支持 0.75 , 1.0 , 1.25 , 1.75 , 2.0 。P…...

React引入css的几种方式以及应用

1.直接引入css文件 import "./parent.css" 2.引入css模块,定义文件名[组件名.module.css];该方式可避免类名的重复,每个组件都有独立的作用域,避免了全局污染,保证了类名的唯一性 import styles from &qu…...

[算法沉淀记录] 排序算法 —— 冒泡排序

排序算法 —— 冒泡排序 基本概念 冒泡排序是一种简单的排序算法。它重复地遍历要排序的列表,一次比较两个元素,并交换它们的位置,如果它们不是按照升序排列的。这步遍历是重复进行的,直到没有再需要交换,也就是说该…...

synchronized 学习

学习源: https://www.bilibili.com/video/BV1aJ411V763?spm_id_from333.788.videopod.episodes&vd_source32e1c41a9370911ab06d12fbc36c4ebc 1.应用场景 不超卖,也要考虑性能问题(场景) 2.常见面试问题: sync出…...

Linux链表操作全解析

Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表?1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…...

ssc377d修改flash分区大小

1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted()是OpenCV库中用于图像处理的函数,主要功能是将两个输入图像(尺寸和类型相同)按照指定的权重进行加权叠加(图像融合),并添加一个标量值&#x…...

ESP32读取DHT11温湿度数据

芯片:ESP32 环境:Arduino 一、安装DHT11传感器库 红框的库,别安装错了 二、代码 注意,DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

HTML 列表、表格、表单

1 列表标签 作用:布局内容排列整齐的区域 列表分类:无序列表、有序列表、定义列表。 例如: 1.1 无序列表 标签:ul 嵌套 li,ul是无序列表,li是列表条目。 注意事项: ul 标签里面只能包裹 li…...

页面渲染流程与性能优化

页面渲染流程与性能优化详解(完整版) 一、现代浏览器渲染流程(详细说明) 1. 构建DOM树 浏览器接收到HTML文档后,会逐步解析并构建DOM(Document Object Model)树。具体过程如下: (…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日,中天合创屋面分布式光伏发电项目顺利并网发电,该项目位于内蒙古自治区鄂尔多斯市乌审旗,项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站,总装机容量为9.96MWp。 项目投运后,每年可节约标煤3670…...

如何将联系人从 iPhone 转移到 Android

从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...

12.找到字符串中所有字母异位词

🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...