完整的端到端的中文聊天机器人
这段代码是一个完整的端到端的中文聊天机器人的实现,包括数据处理、模型训练、预测和图形用户界面(GUI),下面是对各个部分功能的详细说明:
1. 导入必要的库
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import random
import tkinter as tk
import jieba
import matplotlib.pyplot as plt
import os
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.amp import GradScaler, autocast
os: 用于设置环境变量和文件操作。
torch: PyTorch 库,用于构建和训练深度学习模型。
tkinter: 用于创建图形用户界面。
jieba: 用于中文分词。
matplotlib: 用于绘制损失曲线。
json: 用于读取 JSON 文件。
transformers: Hugging Face 的 Transformers 库,用于加载预训练模型和分词器。
torch.amp: 用于混合精度训练,提高训练速度和减少内存占用。
2. 定义特殊标记和词汇表
PAD_TOKEN = "<PAD>"
UNK_TOKEN = "<UNK>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"word2index = {PAD_TOKEN: 0, UNK_TOKEN: 1, SOS_TOKEN: 2, EOS_TOKEN: 3}
index2word = {0: PAD_TOKEN, 1: UNK_TOKEN, 2: SOS_TOKEN, 3: EOS_TOKEN}
特殊标记:定义了四个特殊标记,分别表示填充、未知词、句子开始和句子结束。
词汇表:初始化词汇表,将特殊标记映射到索引。
3. 中文分词
def tokenize_chinese(sentence):tokens = jieba.lcut(sentence)return tokens
功能:使用 jieba 对输入的中文句子进行分词,返回分词后的词汇列表。
4. 构建词汇表
def build_vocab(sentences):global word2index, index2wordvocab_size = len(word2index)for sentence in sentences:for token in tokenize_chinese(sentence):if token not in word2index:word2index[token] = vocab_sizeindex2word[vocab_size] = tokenvocab_size += 1return vocab_size
功能:遍历所有句子,构建词汇表,将每个词映射到一个唯一的索引。
5. 将句子转换为张量
def sentence_to_tensor(sentence, max_length=50):tokens = tokenize_chinese(sentence)indices = [word2index.get(token, word2index[UNK_TOKEN]) for token in tokens]indices = [word2index[SOS_TOKEN]] + indices + [word2index[EOS_TOKEN]]indices += [word2index[PAD_TOKEN]] * (max_length - len(indices))return torch.tensor(indices, dtype=torch.long), len(indices)
功能:将输入的句子转换为张量,并返回句子的实际长度。句子被加上 和 标记,并用 标记填充到指定的最大长度。
6. 读取数据
def load_data(file_path):if file_path.endswith('.jsonl'):with open(file_path, 'r', encoding='utf-8') as f:lines = [json.loads(line) for line in f.readlines()]elif file_path.endswith('.json'):with open(file_path, 'r', encoding='utf-8') as f:lines = json.load(f)else:raise ValueError("不支持的文件格式。请使用 .jsonl 或 .json。")questions = [line['question'] for line in lines]answers = [random.choice(line['human_answers'] + line['chatgpt_answers']) for line in lines]return questions, answers
功能:从指定的 JSON 或 JSONL 文件中读取数据,返回问题和答案列表。
7. 数据增强
def data_augmentation(sentence):tokens = tokenize_chinese(sentence)augmented_sentence = []if random.random() < 0.1:insert_token = random.choice(list(word2index.keys())[4:])insert_index = random.randint(0, len(tokens))tokens.insert(insert_index, insert_token)if random.random() < 0.1 and len(tokens) > 1:delete_index = random.randint(0, len(tokens) - 1)del tokens[delete_index]if len(tokens) > 1 and random.random() < 0.1:index1, index2 = random.sample(range(len(tokens)), 2)tokens[index1], tokens[index2] = tokens[index2], tokens[index1]augmented_sentence = ''.join(tokens)return augmented_sentence
功能:对输入的句子进行随机插入、删除和交换操作,以增加数据的多样性。
8. 定义数据集
class ChatDataset(Dataset):def __init__(self, questions, answers):self.questions = questionsself.answers = answersdef __len__(self):return len(self.questions)def __getitem__(self, idx):input_tensor, input_length = sentence_to_tensor(self.questions[idx])target_tensor, target_length = sentence_to_tensor(self.answers[idx])return input_tensor, target_tensor, input_length, target_length
功能:定义一个自定义的数据集类,用于存储问题和答案,并将它们转换为张量。
9. 自定义 collate 函数
def collate_fn(batch):inputs, targets, input_lengths, target_lengths = zip(*batch)inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True, padding_value=word2index[PAD_TOKEN])targets = nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=word2index[PAD_TOKEN])return inputs, targets, torch.tensor(input_lengths), torch.tensor(target_lengths)
功能:将一批数据进行填充,使其具有相同的长度,并返回填充后的输入、目标、输入长度和目标长度。
10. 创建数据集和数据加载器
def create_dataset_and_dataloader(questions_file, answers_file, batch_size=10, shuffle=True, split_ratio=0.8):questions, answers = load_data(questions_file)vocab_size = build_vocab(questions + answers)dataset = ChatDataset(questions, answers)train_size = int(split_ratio * len(dataset))val_size = len(dataset) - train_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)return train_dataset, train_dataloader, val_dataset, val_dataloader, vocab_size
功能:创建训练和验证数据集及数据加载器,并返回词汇表的大小。
11. 定义模型结构
class Encoder(nn.Module):def __init__(self, input_size, hidden_size, num_layers=1):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)def forward(self, input_seq, input_lengths, hidden=None):embedded = self.embedding(input_seq)packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True, enforce_sorted=False)outputs, hidden = self.gru(packed, hidden)outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)return outputs, hiddenclass Decoder(nn.Module):def __init__(self, output_size, hidden_size, num_layers=1):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)self.out = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, input_step, hidden, encoder_outputs):embedded = self.embedding(input_step)gru_output, hidden = self.gru(embedded, hidden)output = self.softmax(self.out(gru_output.squeeze(1)))return output, hiddenclass Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device, tokenizer):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderself.device = deviceself.tokenizer = tokenizerdef forward(self, input_tensor, target_tensor, input_lengths, target_lengths, teacher_forcing_ratio=0.5):batch_size = input_tensor.size(0)max_target_len = max(target_lengths)vocab_size = self.decoder.out.out_featuresoutputs = torch.zeros(batch_size, max_target_len, vocab_size).to(self.device)encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)decoder_input = torch.tensor([[word2index[SOS_TOKEN]] * batch_size], device=self.device).transpose(0, 1)decoder_hidden = encoder_hiddenfor t in range(max_target_len<
相关文章:

完整的端到端的中文聊天机器人
这段代码是一个完整的端到端的中文聊天机器人的实现,包括数据处理、模型训练、预测和图形用户界面(GUI),下面是对各个部分功能的详细说明: 1. 导入必要的库 import os os.environ[CUDA_LAUNCH_BLOCKING] = 1import torch import torch.nn as nn import torch.optim as o…...

【有啥问啥】Stackelberg博弈方法:概念、原理及其在AI中的应用
Stackelberg博弈方法:概念、原理及其在AI中的应用 1. 什么是Stackelberg博弈? Stackelberg博弈(Stackelberg Competition)是一种不对称的领导者-追随者(Leader-Follower)博弈模型,由德国经济学…...

【UI自动化】前言
系列文章目录 【UI自动化】前言 自动化不能代替手工测试,自动化都是以手工测试为基础,自动化测试实现的步骤要依赖手工; 文章目录 系列文章目录【UI自动化】前言 自动化测试的类型自动化解决的问题什么是UI测试测试分类一、使用UI自动化的…...

Unity对象池的高级写法 (Plus优化版)
唐老师关于对物体分类的OOD的写法确实十分好,代码也耦合度也低,但是我有个简单的写法同样能实现一样的效果,所以我就充分发挥了一下主观能动性 相较于基本功能,这一版做出了如下改动 1.限制了对象池最大数量,多出来的…...

vue3<script setup>中computed
在 Vue 3 中,<script setup> 语法糖是 Composition API 的一种简化写法,它允许你更简洁地编写组件逻辑。在 <script setup> 中使用 computed 与在普通 <script> 标签中使用 Composition API 的方式类似,但通常我们会借助 i…...

【已解决】使用JAVA语言实现递归调用-本关任务:用循环和递归算法求 n(小于 10 的正整数) 的阶乘 n!。
本关任务:用循环和递归算法求 n(小于 10 的正整数) 的阶乘 n!。 测试说明 平台会对你编写的代码进行测试,比对你输出的数值与实际正确数值,只有所有数据全部计算正确才能通过测试: 测试输入:1…...

BiRefNet 教程:基于 PyTorch 实现的双向精细化网络
BiRefNet 教程:基于 PyTorch 实现的双向精细化网络 BiRefNet 是一个图像分割网络,专注于复杂任务如背景移除、掩码生成、伪装物体检测、显著性目标检测等。该模型结合了编码器、解码器、多尺度特征提取、以及梯度监督机制,能够有效处理不同类…...

Oracle 数据库安装和配置指南(新)
目录 1. 什么是Oracle数据库? 2. 安装前的准备工作 2.1 硬件要求 2.2 软件要求 2.3 下载Oracle安装包 3. Oracle数据库的安装步骤 3.1 Windows系统安装步骤 3.2 Linux系统安装步骤 4. 配置Oracle数据库 4.1 设置环境变量(Linux) 4.…...

JavaScript的注释与常见输出方式
注释 源码中注释是不被引擎所解释的,它的作用是对代码进行解释。Javascript 提供两种注释的写法:一种是单行注释,用//起头;另一种是多行注释,放在/*和*/之间。 单行注释: //这是单行注释 多行注释: /*这是 多行 注…...

深入探索Android开发之Java核心技术学习大全
Android作为全球最流行的移动操作系统之一,其开发技能的需求日益增长。本文将为您介绍一套专为Android开发者设计的Java核心技术学习资料,包括详细的学习大纲、PDF文档、源代码以及配套视频教程,帮助您从Java基础到高级特性,再到A…...

vue3 选择字体的颜色,使用vue3-colorpicker来选择颜色
1、有的时候我们会用到颜色的选择器,像element-plus提供了,但是ant-design-vue并没有: 这个暂时没有看到: 但是Ant Design 5的版本有,应该不是vue的。 2、使用第三方提供的vue3-colorpicker:storybook/cli…...

windows C++ 并行编程-使用消息块筛选器
本文档演示了如何使用筛选器函数,使异步消息块能够根据消息的有效负载接受或拒绝消息。 创建消息块对象(例如 concurrency::unbounded_buffer、concurrency::call 或 concurrency::transformer)时,可以提供筛选器函数,用于确定消息块是接受还…...

【mysql技术内幕】
MySQL之技术内幕 1.MVCC模式2. 实现mvcc模式的基础点3.MySQL锁的类型4. 说下MySQL的索引有哪些吧?5. 谈谈分库分表6. 分表后的id咋么保证唯一性呢?7. 分表后非sharding key的查询咋么处理的? 1.MVCC模式 MVCC, 是multi-version concurrency c…...

快递物流单号识别API接口DEMO下载
单号识别API为用户提供单号识别快递公司服务,依托于快递鸟大数据平台,用户提供快递单号,即可实时返回可能的一个或多个快递公司,存在多个快递公司结果的,大数据平台根据可能性、单号量,进行智能排序。 应用…...

Jetpack——Room
概述 Room是谷歌公司推出的数据库处理框架,该框架同样基于SQLite,但它通过注解技术极大简化了数据库操作,减少了原来相当一部分编码工作量。在使用Room之前,要先修改模块的build.gradle文件,往dependencies节点添加下…...

Dynamic Connected Networks for Chinese Spelling Check(ACL2021)
Dynamic Connected Networks for Chinese Spelling Check(ACL2021) 一.概述 文中认为基于bert的非自回归语言模型依赖于输出独立性假设。不适当的独立性假设阻碍了基于bert的模型学习目标token之间的依赖关系,从而导致了不连贯的问题。为些,…...

前端vue-3种生命周期,只能在各自的领域使用
上面的表格可以简化为下面的两句话: setup是语法糖,下面的两个import导入是vue3和vue2的区别,现在的vue3直接导入,比之前vue2简单 还可以是导入两个生命周期函数...

el-upload如何自定展示上传的文件
Element UI 中,el-upload 组件支持通过插槽(slot)来自定义文件列表的展示方式。这通常是通过 file-list 插槽来实现的。下面是一个使用 el-upload 组件并通过 file-list 插槽来自定义文件列表展示的完整示例代码。 在这个示例中,…...

研1日记15
1. 文心一言生成: 在PyTorch中,nn.AdaptiveAvgPool1d(1)是一个一维自适应平均池化层。这个层的作用是将输入的特征图(或称为张量)在一维上进行自适应平均池化,使得输出特征图的大小在指定的维度上变为1。这意味着&…...

基于Nginx搭建点播直播服务器
实现直播和点播离不开服务器⽀持,可以使用开源的NGINX服务器搭建直播和点播服务。 当然,NGINX本身是不⽀持视频的,需要为NGINX增加相应的RTMP模块进行支持。 1、下载nginx和rtmp模块 # nginx wget ht tp://nginx.org/download/nginx-1.18.…...

QT LineEdit显示模式
QT LineEdit显示模式 QLineEdit 显示模式: Normal 普通模式 NoEcho 不回写,即输入内容是有的,但是显示不出来,就是不在 QLineEdit 输入框中显示,但是触发例如 textChanged 信号会将所输入的文字写出来 Password 显示密码 Pa…...

IT技术在数字化转型中的关键作用
IT技术在数字化转型中的关键作用 在当今数字化浪潮中,IT技术无疑扮演着核心角色。无论是企业的数字化转型,还是政府公共服务的智能化提升,信息技术都在推动着整个社会向更高效、更智能的方向发展。本文将探讨IT技术在数字化转型中的关键作用…...

【C++指南】C++中nullptr的深入解析
💓 博客主页:倔强的石头的CSDN主页 📝Gitee主页:倔强的石头的gitee主页 ⏩ 文章专栏:《C指南》 期待您的关注 目录 引言 一、nullptr的引入背景 二、nullptr的特点 1.类型安全 2.明确的空指针表示 3.函数重载支…...

解决启动docker desktop报The network name cannot be found的问题
现象 deploying WSL2 distributions ensuring main distro is deployed: checking if main distro is up to date: checking main distro bootstrap version: getting main distro bootstrap version: open \wsl$\docker-desktop\etc\wsl_bootstrap_version: The network name…...

Guava: 探索 Google 的 Java 核心库
Guava 是 Google 开发的一套 Java 核心库,它提供了一系列新的集合类型(例如多映射 multimap 和多集合 multiset)、不可变集合、图形库以及用于并发、I/O、哈希、原始类型、字符串等的实用工具。Guava 在 Google 的大多数 Java 项目中得到了广…...

Qt-qmake概述
概述 qmake工具为您提供了一个面向项目的系统,用于管理应用程序、库和其他组件的构建过程。这种方法使您能够控制使用的源文件,并允许简洁地描述过程中的每个步骤,通常在单个文件中。qmake将每个项目文件中的信息扩展为一个Makefile…...

【protobuf】ProtoBuf的学习与使用⸺C++
W...Y的主页 😊 代码仓库分享💕 前言:之前我们学习了Linux与windows的protobuf安装,知道protobuf是做序列化操作的应用,今天我们来学习一下protobuf。 目录 ⼀、初识ProtoBuf 步骤1:创建.proto文件 步…...

【iOS】MVC架构模式
文章目录 前言MVC架构模式基本概念通信方式简单应用 总结 前言 “MVC”,即Model(模型),View(视图),Controller(控制器),MVC模式是架构模式的一种。 关于“架构模式”&a…...

ML 系列:机器学习和深度学习的深层次总结(08)—欠拟合、过拟合,正确拟合
ML 系列赛:第 9 天 — Under、Over 和 Good Fit 文章目录 一、说明二、了解欠拟合、过拟合和实现正确的平衡三、关于泛化四、欠拟合五、过拟合六、适度拟合七、结论 一、说明 在有监督学习过程中,对于指定数据集进行训练,训练结果存在欠拟合…...

Unity-物理系统-刚体加力
一 刚体自带添加力的方法 给刚体加力的目标就是 让其有一个速度 朝向某一个方向移动 1.首先应该获取刚体组件 rigidBody this.GetComponent<Rigidbody>(); 2.添加力 //相对世界坐标 //世界坐标系 Z轴正方向加了一个里 //加力过后 对象是否停止…...