当前位置: 首页 > news >正文

对话模型Demo解读(使用代码解读原理)

文章目录

  • 前言
  • 一、数据加工
  • 二、模型搭建
  • 三、模型训练
    • 1、构建模型
    • 2、优化器与损失函数定义
    • 3、模型训练
  • 四、模型推理
  • 五、所有Demo源码


前言

对话模型是一种人工智能技术,旨在使计算机能够像人类一样进行对话和交流。这种模型通常基于深度学习和自然语言处理技术,能够理解自然语言并做出相应的回应。然而现有博客很少介绍对话模型内容,也很少用一个简单代码带领大家理解其原理。因此,我创建一个简单的对话模型,在不适用Hugging Face或LSTM结构,旨在使用一个简单的全连接神经网络来实现这个模型,且代码基于PyTorch框架搭建,意在帮助读者构建对话模型知识。当然,模型仅是一个简单模型,旨在帮助理解原理,不具备很好效果能力。


一、数据加工

文本数据最终都是转为对应字典索引代表其文本内容,输入模型加工,实现nlp任务,对话模型也不列外。因此,我们需要构建一个字典映射(可参考:点击这里)与文本数据,并按照字典映射转换为对应索引id,其代码如下:

# 定义一个简单的对话数据集
data = [("hi", "hello"),("how are you?", "I'm fine, thank you."),("what's your name?", "I'm a chatbot.")
]# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}# 将对话数据集转换为索引序列
def to_idx_seq(sentence):return [word_to_idx[word] for word in sentence]data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]

其字典内容如下:
在这里插入图片描述

二、模型搭建

这里,也是最重要内容,如何搭建对话模型,我是使用transformer结构搭建(之前模型使用LSTM模型搭建),创建一个简单的对话生成模型,也使用一个基于全连接层的神经网络来实现这个模型,其代码如下:


# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):super(ChatbotTransformer, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.embedding = nn.Embedding(len(vocab), input_dim)self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)self.fc = nn.Linear(input_dim, output_dim)def forward(self, src, tgt):src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度output = self.transformer(src, tgt)output = self.fc(output)return output

从代码上看,我们需要输入src为前面提问数据,而答案生成输入,是每个tgt字输入,按照顺序输出预测。

三、模型训练

1、构建模型

我大概试了一下,使用更多层对于简单数据反而效果不佳,我的数据又比较简单,我构建了较少的层来预测模型。其模型构建代码如下:

# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)

2、优化器与损失函数定义

优化器定义我将不考虑介绍,只说下文本预测是交叉熵方式,实际文本基本都采用该方法作为loss计算。其代码如下:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

3、模型训练

接下来,我们解读如何训练对话模型,我们获得对应输入数据与生成预测数据,我们开始训练模型,其代码如下:

# 模型训练
epochs = 800
for epoch in range(epochs):total_loss = 0for i in range(len(data_x)):optimizer.zero_grad()input_seq = torch.tensor(data_x[i])target_seq = torch.tensor(data_y[i])for j in range(len(target_seq)-1):if random.random() < 0.5:j = random.randint(0, len(target_seq)-2)output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维loss.backward()optimizer.step()total_loss += loss.item()if (epoch+1) % 10 == 0:print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))

假设input_seq =tensor([17, 6, 0, 15, 1, 8, 9, 15, 7, 6, 11, 4]),target_seq=tensor([20, 12, 18, 13, 15, 5, 16, 10, 9, 2, 15, 3, 17, 1, 10, 19, 15, 7, 6, 11, 14, 21]),vocab为字典映射,共22个映射。我将试图模型连续2轮解释一下,训练时候相关变化。
第二轮输出结果:
在这里插入图片描述
第三轮输出结果:
在这里插入图片描述
后面以此类推迭代。

很明显,提问每次都是全部输入,而输出则是第一个20开始与输入共同进模型,分别是模型src与tgt,不断重复与预测完成训练。而loss计算都是往后取一个target文本,也发现并不会计算20索引""文本。

四、模型推理

最后,让我们使用训练好的模型进行推理,实际和上面训练讲到方法类似,我们开始""文本开始,不断给出生成对应文本,也就是对话内容。

# 进行推理
def generate_response(input_sentence):model.eval()input_seq = torch.tensor(to_idx_seq(input_sentence))target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记with torch.no_grad():for i in range(20):  # 限制生成的句子长度为20个词output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))output_token = output.argmax(2)[-1].item()print(idx_to_word[output_token], end=" ")target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)res = [idx_to_word[int(k)] for k in target_seq]print(res[1:-1])return res[1:-1]# 进行对话生成
generate_response("hi")

其结果如下:
在这里插入图片描述


五、所有Demo源码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
# 定义一个简单的对话数据集
data = [# ("hi", "hello"),("how are you?", "I'm fine, thank you."),# ("what's your name?", "I'm a chatbot.")
]# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}# 将对话数据集转换为索引序列
def to_idx_seq(sentence):return [word_to_idx[word] for word in sentence]data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):super(ChatbotTransformer, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.embedding = nn.Embedding(len(vocab), input_dim)self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)self.fc = nn.Linear(input_dim, output_dim)def forward(self, src, tgt):src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度output = self.transformer(src, tgt)output = self.fc(output)return output# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 模型训练
epochs = 800
for epoch in range(epochs):total_loss = 0for i in range(len(data_x)):optimizer.zero_grad()input_seq = torch.tensor(data_x[i])target_seq = torch.tensor(data_y[i])for j in range(len(target_seq)-1):if random.random() < 0.5:j = random.randint(0, len(target_seq)-2)output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维loss.backward()optimizer.step()total_loss += loss.item()if (epoch+1) % 10 == 0:print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))# 进行推理
def generate_response(input_sentence):model.eval()input_seq = torch.tensor(to_idx_seq(input_sentence))target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记with torch.no_grad():for i in range(20):  # 限制生成的句子长度为20个词output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))output_token = output.argmax(2)[-1].item()print(idx_to_word[output_token], end=" ")target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)res = [idx_to_word[int(k)] for k in target_seq]print(res[1:-1])return res[1:-1]# 进行对话生成
generate_response("how are you?")

相关文章:

对话模型Demo解读(使用代码解读原理)

文章目录 前言一、数据加工二、模型搭建三、模型训练1、构建模型2、优化器与损失函数定义3、模型训练 四、模型推理五、所有Demo源码 前言 对话模型是一种人工智能技术&#xff0c;旨在使计算机能够像人类一样进行对话和交流。这种模型通常基于深度学习和自然语言处理技术&…...

Android 自定义BaseFragment

直接上代码&#xff1a; BaseFragment代码&#xff1a; package com.example.custom.fragment;import android.content.Context; import android.os.Bundle; import android.view.LayoutInflater; import android.view.View; import android.view.ViewGroup; import androidx…...

[C#] 如何对列表,字典等进行排序?

对列表进行排序 下面是一个基于C#的列表排序的案例&#xff1a; using System; using System.Collections.Generic;class Program {static void Main(string[] args){// 创建一个列表List<int> numbers new List<int>() { 5, 2, 8, 1, 10 };// 使用Sort方法对列…...

Mac 下载安装Java、maven并配置环境变量

下载Java8 下载地址&#xff1a;https://www.oracle.com/java/technologies/downloads/ 根据操作系统选择版本 没有oracle账号需要注册、激活登录 mac直接选择.dmg文件进行下载&#xff0c;下载后安装。 默认安装路径&#xff1a;/Library/Java/JavaVirtualMachines/jdk-1…...

【多模态】27、Vary | 通过扩充图像词汇来提升多模态模型在细粒度感知任务(OCR等)上的效果

文章目录 一、背景二、方法2.1 生成 new vision vocabulary2.1.1 new vocabulary network2.1.2 Data engine in the generating phrase2.1.3 输入的格式 2.2 扩大 vision vocabulary2.2.1 Vary-base 的结构2.2.2 Data engine2.2.3 对话格式 三、效果3.1 数据集3.2 图像细粒度感…...

|Python新手小白低级教程|第二十章:函数(2)【包括石头剪刀布判断程序(模拟版)】

文章目录 前言一、复习一、函数实战之——if语句特殊系统1.判断等第分数&#xff08;函数名为mark&#xff08;参数num&#xff09;&#xff09;2.石头剪刀布判断程序 二、练习总结 前言 Hello&#xff0c;大家好&#xff0c;我是你们的BoBo仔&#xff0c;感谢你们来阅读我的文…...

vue3 之 商城项目—home

home—整体结构搭建 根据上面五个模块建目录图如下&#xff1a; home/index.vue <script setup> import HomeCategory from ./components/HomeCategory.vue import HomeBanner from ./components/HomeBanner.vue import HomeNew from ./components/HomeNew.vue import…...

git flow与分支管理

git flow与分支管理 一、git flow是什么二、分支管理1、主分支Master2、开发分支Develop3、临时性分支功能分支预发布分支修补bug分支 三、分支管理最佳实践1、分支名义规划2、环境与分支3、分支图 四、git flow缺点 一、git flow是什么 Git 作为一个源码管理系统&#xff0c;…...

【Linux】学习-进程信号

进程信号 信号入门 生活角度的信号 你在网上买了很多件商品,再等待不同商品快递的到来。但即便快递没有到来,你也知道快递来临时,你该怎么处理快递。也就是你能“识别快递”,也就是你意识里是知道如果这时候快递员送来了你的包裹,你知道该如何处理这些包裹当快递员到了你…...

webgis后端安卓系统部署攻略

目录 前言 一、将后端项目编译ARM64 二、安卓手机安装termux 1.更换为国内源 2.安装ssh远程访问 3.安装文件远程访问 三、安装postgis数据库 1.安装数据库 2.数据库配置 3.数据导入 四、后端项目部署 五、自启动设置 总结 前言 因为之前一直做的H5APP开发&#xf…...

【数据分享】1929-2023年全球站点的逐日平均风速数据(Shp\Excel\免费获取)

气象数据是在各项研究中都经常使用的数据&#xff0c;气象指标包括气温、风速、降水、能见度等指标&#xff0c;说到气象数据&#xff0c;最详细的气象数据是具体到气象监测站点的数据&#xff01; 有关气象指标的监测站点数据&#xff0c;之前我们分享过1929-2023年全球气象站…...

【多模态大模型】视觉大模型SAM:如何使模型能够处理任意图像的分割任务?

SAM&#xff1a;如何使模型能够处理任意图像的分割任务&#xff1f; 核心思想起始问题: 如何使模型能够处理任意图像的分割任务&#xff1f;5why分析5so分析 总结子问题1: 如何编码输入图像以适应分割任务&#xff1f;子问题2: 如何处理各种形式的分割提示&#xff1f;子问题3:…...

Shell之sed

sed是什么 Linux sed 命令是利用脚本来处理文本文件。 可依照脚本的指令来处理、编辑文本文件。主要用来自动编辑一个或多个文件、简化对文件的反复操作、编写转换程序等。 sed命令详解 语法 sed [-hnV][-e <script>][-f<script文件>][文本文件] sed [-nefr] [动作…...

AJAX——认识URL

1 什么是URL&#xff1f; 统一资源定位符&#xff08;英语&#xff1a;Uniform Resource Locator&#xff0c;缩写&#xff1a;URL&#xff0c;或称统一资源定位器、定位地址、URL地址&#xff09;俗称网页地址&#xff0c;简称网址&#xff0c;是因特网上标准的资源的地址&…...

《Docker极简教程》--Docker环境的搭建--在Linux上搭建Docker环境

更新系统&#xff1a;首先确保所有的包管理器都是最新的。对于基于Debian的系统&#xff08;如Ubuntu&#xff09;&#xff0c;可以使用以下命令&#xff1a;sudo apt-get update sudo apt-get upgrade安装必要的依赖项&#xff1a;安装一些必要的工具&#xff0c;比如ca-certi…...

开源微服务平台框架的特点是什么?

借助什么平台的力量&#xff0c;可以让企业实现高效率的流程化办公&#xff1f;低代码技术平台是近些年来较为流行的平台产品&#xff0c;可以帮助很多行业进入流程化办公新时代&#xff0c;做好数据管理工作&#xff0c;从而提升企业市场竞争力。流辰信息专业研发低代码技术平…...

C#系列-C#操作UDP发送接收数据(10)

在C#中&#xff0c;发送UDP数据并接收响应通常涉及创建两个UdpClient实例&#xff1a;一个用于发送数据&#xff0c;另一个用于接收响应。以下是发送UDP数据并接收响应的示例代码&#xff1a; 首先&#xff0c;我们需要定义一个方法来发送UDP数据&#xff0c;并等待接收服务器…...

突破编程_C++_面试(基础知识(10))

面试题29&#xff1a;什么是嵌套类&#xff0c;它有什么作用 嵌套类指的是在一个类的内部定义的另一个类。嵌套类可以作为外部类的一个成员&#xff0c;但它与其声明类型紧密关联&#xff0c;不应被用作通用类型。嵌套类可以访问外部类的所有成员&#xff0c;包括私有成员&…...

初步探索Pyglet库:打造轻量级多媒体与游戏开发利器

目录 pyglet库 功能特点 安装和导入 安装 导入 基本代码框架 导入模块 创建窗口 创建控件 定义事件 运行应用 程序界面 运行结果 完整代码 标签控件 常用事件 窗口事件 鼠标事件 键盘事件 文本事件 其它场景 网页标签 音乐播放 图片显示 祝大家新…...

【npm】安装全局包,使用时提示:不是内部或外部命令,也不是可运行的程序或批处理文件

问题 如图&#xff0c;明明安装Vue是全局包&#xff0c;但是使用时却提示&#xff1a; 解决办法 使用以下命令任意一种命令查看全局包的配置路径 npm root -g 然后将此路径&#xff08;不包括node_modules&#xff09;添加到环境变量中去&#xff0c;这里注意&#xff0c;原…...

Python爬虫实战:研究MechanicalSoup库相关技术

一、MechanicalSoup 库概述 1.1 库简介 MechanicalSoup 是一个 Python 库,专为自动化交互网站而设计。它结合了 requests 的 HTTP 请求能力和 BeautifulSoup 的 HTML 解析能力,提供了直观的 API,让我们可以像人类用户一样浏览网页、填写表单和提交请求。 1.2 主要功能特点…...

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具&#xff0c;该工具基于TUN接口实现其功能&#xff0c;利用反向TCP/TLS连接建立一条隐蔽的通信信道&#xff0c;支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式&#xff0c;适应复杂网…...

Android Wi-Fi 连接失败日志分析

1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分&#xff1a; 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析&#xff1a; CTR…...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted&#xff08;&#xff09;是OpenCV库中用于图像处理的函数&#xff0c;主要功能是将两个输入图像&#xff08;尺寸和类型相同&#xff09;按照指定的权重进行加权叠加&#xff08;图像融合&#xff09;&#xff0c;并添加一个标量值&#x…...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 &#xff08;部分有免费额度&#x…...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试&#xff0c;通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小&#xff0c;增大可提高计算复杂度duration: 测试持续时间&#xff08;秒&…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析&#xff1a;CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展&#xff0c;AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者&#xff0c;分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

Android15默认授权浮窗权限

我们经常有那种需求&#xff0c;客户需要定制的apk集成在ROM中&#xff0c;并且默认授予其【显示在其他应用的上层】权限&#xff0c;也就是我们常说的浮窗权限&#xff0c;那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...

k8s业务程序联调工具-KtConnect

概述 原理 工具作用是建立了一个从本地到集群的单向VPN&#xff0c;根据VPN原理&#xff0c;打通两个内网必然需要借助一个公共中继节点&#xff0c;ktconnect工具巧妙的利用k8s原生的portforward能力&#xff0c;简化了建立连接的过程&#xff0c;apiserver间接起到了中继节…...

C# 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...