智能聊天机器人:使用PyTorch构建多轮对话系统
使用PyTorch构建多轮对话系统的示例代码。这个示例项目包括一个简单的Seq2Seq模型用于对话生成,并使用GRU作为RNN的变体。以下是代码的主要部分,包括数据预处理、模型定义和训练循环。
数据预处理
首先,准备数据并进行预处理。这部分代码假定你有一个对话数据集,格式为成对的问答句子。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random# 假设数据集是一个成对的问答列表
pairs = [["Hi, how are you?", "I'm good, thank you! How about you?"],["What is your name?", "My name is Chatbot."],# 添加更多对话数据
]# 简单的词汇表和索引映射
word2index = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
index2word = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
vocab_size = len(word2index)def tokenize(sentence):return sentence.lower().split()def build_vocab(pairs):global word2index, index2word, vocab_sizefor pair in pairs:for sentence in pair:for word in tokenize(sentence):if word not in word2index:word2index[word] = vocab_sizeindex2word[vocab_size] = wordvocab_size += 1def sentence_to_tensor(sentence):tokens = tokenize(sentence)indices = [word2index.get(word, word2index["<UNK>"]) for word in tokens]return torch.tensor(indices + [word2index["<EOS>"]], dtype=torch.long)build_vocab(pairs)
数据集和数据加载
定义一个Dataset类和DataLoader来加载数据。
class ChatDataset(Dataset):def __init__(self, pairs):self.pairs = pairsdef __len__(self):return len(self.pairs)def __getitem__(self, idx):input_tensor = sentence_to_tensor(self.pairs[idx][0])target_tensor = sentence_to_tensor(self.pairs[idx][1])return input_tensor, target_tensordef collate_fn(batch):inputs, targets = zip(*batch)input_lengths = [len(seq) for seq in inputs]target_lengths = [len(seq) for seq in targets]inputs = nn.utils.rnn.pad_sequence(inputs, padding_value=word2index["<PAD>"])targets = nn.utils.rnn.pad_sequence(targets, padding_value=word2index["<PAD>"])return inputs, targets, input_lengths, target_lengthsdataset = ChatDataset(pairs)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn, shuffle=True)
模型定义
定义一个简单的Seq2Seq模型,包括编码器和解码器。
class Encoder(nn.Module):def __init__(self, input_size, hidden_size, num_layers=1):super(Encoder, self).__init__()self.embedding = nn.Embedding(input_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size, num_layers)def forward(self, input_seq, input_lengths, hidden=None):embedded = self.embedding(input_seq)packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, enforce_sorted=False)outputs, hidden = self.gru(packed, hidden)outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)return outputs, hiddenclass Decoder(nn.Module):def __init__(self, output_size, hidden_size, num_layers=1):super(Decoder, self).__init__()self.embedding = nn.Embedding(output_size, hidden_size)self.gru = nn.GRU(hidden_size, hidden_size, num_layers)self.out = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, input_step, hidden, encoder_outputs):embedded = self.embedding(input_step)gru_output, hidden = self.gru(embedded, hidden)output = self.softmax(self.out(gru_output.squeeze(0)))return output, hiddenclass Seq2Seq(nn.Module):def __init__(self, encoder, decoder, device):super(Seq2Seq, self).__init__()self.encoder = encoderself.decoder = decoderself.device = devicedef forward(self, input_tensor, target_tensor, input_lengths, target_lengths, teacher_forcing_ratio=0.5):batch_size = input_tensor.size(1)max_target_len = max(target_lengths)vocab_size = self.decoder.out.out_featuresoutputs = torch.zeros(max_target_len, batch_size, vocab_size).to(self.device)encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)decoder_input = torch.tensor([[word2index["<SOS>"]] * batch_size]).to(self.device)decoder_hidden = encoder_hiddenfor t in range(max_target_len):decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)outputs[t] = decoder_outputtop1 = decoder_output.argmax(1)decoder_input = target_tensor[t].unsqueeze(0) if random.random() < teacher_forcing_ratio else top1.unsqueeze(0)return outputsdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder(vocab_size, hidden_size=256).to(device)
decoder = Decoder(vocab_size, hidden_size=256).to(device)
model = Seq2Seq(encoder, decoder, device).to(device)
训练循环
定义训练循环并进行模型训练。
def train(model, dataloader, num_epochs, learning_rate=0.001):criterion = nn.CrossEntropyLoss(ignore_index=word2index["<PAD>"])optimizer = optim.Adam(model.parameters(), lr=learning_rate)for epoch in range(num_epochs):model.train()total_loss = 0for inputs, targets, input_lengths, target_lengths in dataloader:inputs, targets = inputs.to(device), targets.to(device)optimizer.zero_grad()outputs = model(inputs, targets, input_lengths, target_lengths)loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader)}")train(model, dataloader, num_epochs=10)
测试与推理
定义一个简单的推理函数来进行对话生成。
def evaluate(model, sentence, max_length=10):model.eval()with torch.no_grad():input_tensor = sentence_to_tensor(sentence).unsqueeze(1).to(device)input_length = [input_tensor.size(0)]encoder_outputs, encoder_hidden = model.encoder(input_tensor, input_length)decoder_input = torch.tensor([[word2index["<SOS>"]]]).to(device)decoder_hidden = encoder_hiddendecoded_words = []for _ in range(max_length):decoder_output, decoder_hidden = model.decoder(decoder_input, decoder_hidden, encoder_outputs)top1 = decoder_output.argmax(1).item()if top1 == word2index["<EOS>"]:breakelse:decoded_words.append(index2word[top1])decoder_input = torch.tensor([[top1]]).to(device)return ' '.join(decoded_words)print(evaluate(model, "Hi, how are you?"))
总结
这只是一个简单的示例,用于展示如何使用PyTorch构建一个基本的多轮对话系统。实际应用中,可能需要更多的数据预处理、更复杂的模型(如Transformer)、更细致的训练策略和优化技术,以及更丰富的对话数据集。希望这个示例对你有所帮助!
相关文章:
智能聊天机器人:使用PyTorch构建多轮对话系统
使用PyTorch构建多轮对话系统的示例代码。这个示例项目包括一个简单的Seq2Seq模型用于对话生成,并使用GRU作为RNN的变体。以下是代码的主要部分,包括数据预处理、模型定义和训练循环。 数据预处理 首先,准备数据并进行预处理。这部分代码假…...

昇思25天学习打卡营第16天 | 文本解码原理-以MindNLP为例
基于 MindSpore 实现 BERT 对话情绪识别 上几章我们学习过了基于MindSpore来实现计算机视觉的一些应用,那么从这期开始要开始一个新的领域——LLM 首先了解一下什么是LLM LLM 是 “大型语言模型”(Large Language Model)的缩写。LLM 是一种…...

Unity之Text组件换行\n没有实现+动态中英互换
前因:文本中的换行 \n没有换行而是打印出来了,解决方式 因为unity会默认把\n替换成\\n 面板中使用富文本这个选项啊 没有用 m_text.text m_text.text.Replace("\\n", "\n"); ###动态中英文互译 using System.Collections; using…...

vue3+ el-tree 展开和折叠,默认展开第一项
默认第一项展开: 展开所有项: 折叠所有项: <template><el-treestyle"max-width: 600px":data"treeData"node-key"id":default-expanded-keys"defaultExpandedKey":props"defaultProps"…...

ProFormList --复杂数据联动ProFormDependency
需求: (1)数据联动:测试数据1、2互相依赖,测试数据1<测试数据2,测试数据2>测试数据1。 (2)点击添加按钮,添加一行。 (3)自定义操作按钮。 ࿰…...
Git、Github、tortoiseGit下载安装调试全套教程
一、Git 1.下载安装Git 编辑器可默认Vim,可换成别的,此处换成VScode,换成VScode或别的都需要单独下载和调用 (1)Git安装:https://www.cnblogs.com/xiuxingzhe/p/9300905.html 超级完整的 Git的下载、安…...

老师怎么快速发布成绩?
期末考试的钟声刚刚敲响,成绩单的发放却成了老师们的一大难题。每当期末成绩揭晓,老师们便要开始一项繁琐的任务——将每一份成绩单逐一私信给家长。这不仅耗费了大量的时间和精力,也让本就忙碌的期末工作变得更加繁重。然而,随着…...

央视揭露:上百元的AI填报高考志愿真的靠谱吗?阿里云新增两位AI圈“代言人”!|AI日报
文章推荐 MiniMax闫俊杰:国内模型远不及GPT-4;OpenAI隐瞒黑客曾入侵其内部系统|AI日报 今日热点 月之暗面、智联招聘成为阿里云新“代言人”,使用阿里云强大算力和大模型服务平台提升模型推理效率 7月8日,阿里云官…...

TPM管理咨询公司甄选指南
在竞争激烈的市场环境中,TPM(全面生产维护)管理咨询公司的重要性日益凸显。然而,如何在众多咨询公司中筛选出最适合自己企业的合作伙伴,成为了许多企业决策者面临的难题。本文将从专业度、行业经验、服务质量和性价比等…...
探索 Scikit-Learn:机器学习的强大工具库
Scikit-Learn 探索 Scikit-Learn:机器学习的强大工具库主要功能模块分类(Classification)回归(Regression)聚类(Clustering)降维(Dimensionality Reduction)模型选择&…...

音视频质量评判标准
一、实时通信延时指标 通过图中表格可以看到,如果端到端延迟在200ms以内,说明整个通话是优质的,通话效果就像大家在同一个房间里聊天一样;300ms以内,大多数人很满意,400ms以内,有小部分人可以感…...

如何在vue3中使用scss
一 要使用scss首先需要下载相关的包 可以在终端使用下面的命令下载相关包 npm install -D sass 二 在src文件下新建一个文件夹叫做styles 在文件夹下创建三个文件 index.scss主要用来引用其他文件 reset.scss用来清除默认的样式 variable.scss用来配置全局属性 三 需要在v…...

Gartner发布采用美国防部模型实施零信任的方法指南:七大支柱落地方法
零信任是网络安全计划的关键要素,但制定策略可能会很困难。安全和风险管理领导者应使用美国国防部模型的七大支柱以及 Gartner 研究来设计零信任策略。 战略规划假设 到 2026 年,10% 的大型企业将拥有全面、成熟且可衡量的零信任计划,而 202…...

Flutter——最详细(Badge)使用教程
背景 主要常用于组件叠加上圆点提示; 使用场景,消息数量提示,消息红点提示 属性作用backgroundColor红点背景色smallSize设置红点大小isLabelVisible是否显示offset设置红点位置alignment设置红点位置child设置底部组件 代码块 class Badge…...

SQLServer的系统数据库用别的服务器上的系统数据库替换后做跨服务器连接时出现凭证、非对称金钥或私密金钥的资料无效
出错作业背景: 公司的某个sqlserver服务器要做迁移,由于该sqlserver服务器上数据库很多,并且做了很多的job和维护计划,重新安装的sqlserver这些都是空的,于是就想到了把系统4个系统数据库进行替换,然后也把…...
vue前端面试
一 .v-if和v-show的区别 v-if 和 v-show 是 Vue.js 中两个常用的条件渲染指令,它们都可以根据条件决定是否渲染某个元素。但是它们之间存在一些区别。 语法:v-if 和 v-show 的语法相同,都接收一个布尔值作为参数。 <div v-if"show…...

【网络安全】Host碰撞漏洞原理+工具+脚本
文章目录 漏洞原理虚拟主机配置Host头部字段Host碰撞漏洞漏洞场景工具漏洞原理 Host 碰撞漏洞,也称为主机名冲突漏洞,是一种网络攻击手段。常见危害有:绕过访问控制,通过公网访问一些未经授权的资源等。 虚拟主机配置 在Web服务器(如Nginx或Apache)上,多个网站可以共…...
unattended-upgrade进程介绍
unattended-upgrade 是一个用于自动更新 Debian 和 Ubuntu 系统的软件包。这个进程通常用于定期下载并安装安全更新,以保持系统的安全性和稳定性。 具体来说,这个命令 /usr/bin/python3 /usr/bin/unattended-upgrade --download-only 表示运行 unattend…...

SpringBoot 中多例模式的神秘世界:用法区别以及应用场景,最后的灵魂拷问会吗?- 第519篇
历史文章(文章累计500) 《国内最全的Spring Boot系列之一》 《国内最全的Spring Boot系列之二》 《国内最全的Spring Boot系列之三》 《国内最全的Spring Boot系列之四》 《国内最全的Spring Boot系列之五》 《国内最全的Spring Boot系列之六》 《…...
基于STM32设计的智能婴儿床(ESP8266局域网)_2024升级版_180
基于STM32设计的智能婴儿床(采用STM32F103C8T6)(180) 文章目录 一、设计需求【1】项目功能介绍【2】程序最终的运行逻辑【3】硬件模块组成【4】ESP8266模块配置【5】上位机开发思路【6】系统功能模块划分1.2 项目开发背景1.3 开发工具的选择1.4 系统框架图1.5 系统原理图1.6 硬…...
Cursor实现用excel数据填充word模版的方法
cursor主页:https://www.cursor.com/ 任务目标:把excel格式的数据里的单元格,按照某一个固定模版填充到word中 文章目录 注意事项逐步生成程序1. 确定格式2. 调试程序 注意事项 直接给一个excel文件和最终呈现的word文件的示例,…...

label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

安宝特方案丨XRSOP人员作业标准化管理平台:AR智慧点检验收套件
在选煤厂、化工厂、钢铁厂等过程生产型企业,其生产设备的运行效率和非计划停机对工业制造效益有较大影响。 随着企业自动化和智能化建设的推进,需提前预防假检、错检、漏检,推动智慧生产运维系统数据的流动和现场赋能应用。同时,…...

Linux-07 ubuntu 的 chrome 启动不了
文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了,报错如下四、启动不了,解决如下 总结 问题原因 在应用中可以看到chrome,但是打不开(说明:原来的ubuntu系统出问题了,这个是备用的硬盘&a…...
土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等
🔍 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术,可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势,还能有效评价重大生态工程…...
uniapp中使用aixos 报错
问题: 在uniapp中使用aixos,运行后报如下错误: AxiosError: There is no suitable adapter to dispatch the request since : - adapter xhr is not supported by the environment - adapter http is not available in the build 解决方案&…...

企业如何增强终端安全?
在数字化转型加速的今天,企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机,到工厂里的物联网设备、智能传感器,这些终端构成了企业与外部世界连接的 “神经末梢”。然而,随着远程办公的常态化和设备接入的爆炸式…...
Redis的发布订阅模式与专业的 MQ(如 Kafka, RabbitMQ)相比,优缺点是什么?适用于哪些场景?
Redis 的发布订阅(Pub/Sub)模式与专业的 MQ(Message Queue)如 Kafka、RabbitMQ 进行比较,核心的权衡点在于:简单与速度 vs. 可靠与功能。 下面我们详细展开对比。 Redis Pub/Sub 的核心特点 它是一个发后…...

【电力电子】基于STM32F103C8T6单片机双极性SPWM逆变(硬件篇)
本项目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脉宽调制)电源模块,能够生成可调频率和幅值的正弦波交流电源输出。该项目适用于逆变器、UPS电源、变频器等应用场景。 供电电源 输入电压采集 上图为本设计的电源电路,图中 D1 为二极管, 其目的是防止正负极电源反接, …...

c++第七天 继承与派生2
这一篇文章主要内容是 派生类构造函数与析构函数 在派生类中重写基类成员 以及多继承 第一部分:派生类构造函数与析构函数 当创建一个派生类对象时,基类成员是如何初始化的? 1.当派生类对象创建的时候,基类成员的初始化顺序 …...