当前位置: 首页 > news >正文

对话模型Demo解读(使用代码解读原理)

文章目录

  • 前言
  • 一、数据加工
  • 二、模型搭建
  • 三、模型训练
    • 1、构建模型
    • 2、优化器与损失函数定义
    • 3、模型训练
  • 四、模型推理
  • 五、所有Demo源码


前言

对话模型是一种人工智能技术,旨在使计算机能够像人类一样进行对话和交流。这种模型通常基于深度学习和自然语言处理技术,能够理解自然语言并做出相应的回应。然而现有博客很少介绍对话模型内容,也很少用一个简单代码带领大家理解其原理。因此,我创建一个简单的对话模型,在不适用Hugging Face或LSTM结构,旨在使用一个简单的全连接神经网络来实现这个模型,且代码基于PyTorch框架搭建,意在帮助读者构建对话模型知识。当然,模型仅是一个简单模型,旨在帮助理解原理,不具备很好效果能力。


一、数据加工

文本数据最终都是转为对应字典索引代表其文本内容,输入模型加工,实现nlp任务,对话模型也不列外。因此,我们需要构建一个字典映射(可参考:点击这里)与文本数据,并按照字典映射转换为对应索引id,其代码如下:

# 定义一个简单的对话数据集
data = [("hi", "hello"),("how are you?", "I'm fine, thank you."),("what's your name?", "I'm a chatbot.")
]# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}# 将对话数据集转换为索引序列
def to_idx_seq(sentence):return [word_to_idx[word] for word in sentence]data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]

其字典内容如下:
在这里插入图片描述

二、模型搭建

这里,也是最重要内容,如何搭建对话模型,我是使用transformer结构搭建(之前模型使用LSTM模型搭建),创建一个简单的对话生成模型,也使用一个基于全连接层的神经网络来实现这个模型,其代码如下:


# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):super(ChatbotTransformer, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.embedding = nn.Embedding(len(vocab), input_dim)self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)self.fc = nn.Linear(input_dim, output_dim)def forward(self, src, tgt):src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度output = self.transformer(src, tgt)output = self.fc(output)return output

从代码上看,我们需要输入src为前面提问数据,而答案生成输入,是每个tgt字输入,按照顺序输出预测。

三、模型训练

1、构建模型

我大概试了一下,使用更多层对于简单数据反而效果不佳,我的数据又比较简单,我构建了较少的层来预测模型。其模型构建代码如下:

# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)

2、优化器与损失函数定义

优化器定义我将不考虑介绍,只说下文本预测是交叉熵方式,实际文本基本都采用该方法作为loss计算。其代码如下:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

3、模型训练

接下来,我们解读如何训练对话模型,我们获得对应输入数据与生成预测数据,我们开始训练模型,其代码如下:

# 模型训练
epochs = 800
for epoch in range(epochs):total_loss = 0for i in range(len(data_x)):optimizer.zero_grad()input_seq = torch.tensor(data_x[i])target_seq = torch.tensor(data_y[i])for j in range(len(target_seq)-1):if random.random() < 0.5:j = random.randint(0, len(target_seq)-2)output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维loss.backward()optimizer.step()total_loss += loss.item()if (epoch+1) % 10 == 0:print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))

假设input_seq =tensor([17, 6, 0, 15, 1, 8, 9, 15, 7, 6, 11, 4]),target_seq=tensor([20, 12, 18, 13, 15, 5, 16, 10, 9, 2, 15, 3, 17, 1, 10, 19, 15, 7, 6, 11, 14, 21]),vocab为字典映射,共22个映射。我将试图模型连续2轮解释一下,训练时候相关变化。
第二轮输出结果:
在这里插入图片描述
第三轮输出结果:
在这里插入图片描述
后面以此类推迭代。

很明显,提问每次都是全部输入,而输出则是第一个20开始与输入共同进模型,分别是模型src与tgt,不断重复与预测完成训练。而loss计算都是往后取一个target文本,也发现并不会计算20索引""文本。

四、模型推理

最后,让我们使用训练好的模型进行推理,实际和上面训练讲到方法类似,我们开始""文本开始,不断给出生成对应文本,也就是对话内容。

# 进行推理
def generate_response(input_sentence):model.eval()input_seq = torch.tensor(to_idx_seq(input_sentence))target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记with torch.no_grad():for i in range(20):  # 限制生成的句子长度为20个词output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))output_token = output.argmax(2)[-1].item()print(idx_to_word[output_token], end=" ")target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)res = [idx_to_word[int(k)] for k in target_seq]print(res[1:-1])return res[1:-1]# 进行对话生成
generate_response("hi")

其结果如下:
在这里插入图片描述


五、所有Demo源码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
# 定义一个简单的对话数据集
data = [# ("hi", "hello"),("how are you?", "I'm fine, thank you."),# ("what's your name?", "I'm a chatbot.")
]# 构建词汇表
vocab = list(set(" ".join([x[0] + " " + x[1] for x in data])))
vocab.append("<SOS>")
vocab.append("<EOS>")word_to_idx = {word: i for i, word in enumerate(vocab)}
idx_to_word = {i: word for i, word in enumerate(vocab)}# 将对话数据集转换为索引序列
def to_idx_seq(sentence):return [word_to_idx[word] for word in sentence]data_x = [to_idx_seq(x[0]) for x in data]
data_y = [to_idx_seq(x[1]) for x in data]data_y =[[word_to_idx["<SOS>"]]+list(x)+[word_to_idx["<EOS>"]] for x in data_y]# 定义一个简单的Transformer生成式对话模型
class ChatbotTransformer(nn.Module):def __init__(self, input_dim, output_dim, nhead, num_encoder_layers, num_decoder_layers):super(ChatbotTransformer, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.embedding = nn.Embedding(len(vocab), input_dim)self.transformer = nn.Transformer(d_model=input_dim, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)self.fc = nn.Linear(input_dim, output_dim)def forward(self, src, tgt):src = self.embedding(src).permute(1, 0, 2)  # 调整输入张量维度tgt = self.embedding(tgt).permute(1, 0, 2)  # 调整输入张量维度output = self.transformer(src, tgt)output = self.fc(output)return output# 创建模型实例
# model = ChatbotTransformer(input_dim=256, output_dim=len(vocab), nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = ChatbotTransformer(input_dim=16, output_dim=len(vocab), nhead=8, num_encoder_layers=1, num_decoder_layers=1)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)# 模型训练
epochs = 800
for epoch in range(epochs):total_loss = 0for i in range(len(data_x)):optimizer.zero_grad()input_seq = torch.tensor(data_x[i])target_seq = torch.tensor(data_y[i])for j in range(len(target_seq)-1):if random.random() < 0.5:j = random.randint(0, len(target_seq)-2)output = model(input_seq.unsqueeze(0), target_seq[:j+1].unsqueeze(0))  # 使用目标序列的前n-1个词预测后n-1个词loss = criterion(output.view(-1, len(vocab)), target_seq[1:j+2].view(-1))  # 计算损失,需要将输出形状转换成二维loss.backward()optimizer.step()total_loss += loss.item()if (epoch+1) % 10 == 0:print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, epochs, total_loss / len(data_x)))# 进行推理
def generate_response(input_sentence):model.eval()input_seq = torch.tensor(to_idx_seq(input_sentence))target_seq = torch.tensor([word_to_idx["<SOS>"]])  # 在开始时使用特殊的起始标记with torch.no_grad():for i in range(20):  # 限制生成的句子长度为20个词output = model(input_seq.unsqueeze(0), target_seq.unsqueeze(0))output_token = output.argmax(2)[-1].item()print(idx_to_word[output_token], end=" ")target_seq = torch.cat((target_seq, torch.tensor([output_token])), dim=0)res = [idx_to_word[int(k)] for k in target_seq]print(res[1:-1])return res[1:-1]# 进行对话生成
generate_response("how are you?")

相关文章:

对话模型Demo解读(使用代码解读原理)

文章目录 前言一、数据加工二、模型搭建三、模型训练1、构建模型2、优化器与损失函数定义3、模型训练 四、模型推理五、所有Demo源码 前言 对话模型是一种人工智能技术&#xff0c;旨在使计算机能够像人类一样进行对话和交流。这种模型通常基于深度学习和自然语言处理技术&…...

Android 自定义BaseFragment

直接上代码&#xff1a; BaseFragment代码&#xff1a; package com.example.custom.fragment;import android.content.Context; import android.os.Bundle; import android.view.LayoutInflater; import android.view.View; import android.view.ViewGroup; import androidx…...

[C#] 如何对列表,字典等进行排序?

对列表进行排序 下面是一个基于C#的列表排序的案例&#xff1a; using System; using System.Collections.Generic;class Program {static void Main(string[] args){// 创建一个列表List<int> numbers new List<int>() { 5, 2, 8, 1, 10 };// 使用Sort方法对列…...

Mac 下载安装Java、maven并配置环境变量

下载Java8 下载地址&#xff1a;https://www.oracle.com/java/technologies/downloads/ 根据操作系统选择版本 没有oracle账号需要注册、激活登录 mac直接选择.dmg文件进行下载&#xff0c;下载后安装。 默认安装路径&#xff1a;/Library/Java/JavaVirtualMachines/jdk-1…...

【多模态】27、Vary | 通过扩充图像词汇来提升多模态模型在细粒度感知任务(OCR等)上的效果

文章目录 一、背景二、方法2.1 生成 new vision vocabulary2.1.1 new vocabulary network2.1.2 Data engine in the generating phrase2.1.3 输入的格式 2.2 扩大 vision vocabulary2.2.1 Vary-base 的结构2.2.2 Data engine2.2.3 对话格式 三、效果3.1 数据集3.2 图像细粒度感…...

|Python新手小白低级教程|第二十章:函数(2)【包括石头剪刀布判断程序(模拟版)】

文章目录 前言一、复习一、函数实战之——if语句特殊系统1.判断等第分数&#xff08;函数名为mark&#xff08;参数num&#xff09;&#xff09;2.石头剪刀布判断程序 二、练习总结 前言 Hello&#xff0c;大家好&#xff0c;我是你们的BoBo仔&#xff0c;感谢你们来阅读我的文…...

vue3 之 商城项目—home

home—整体结构搭建 根据上面五个模块建目录图如下&#xff1a; home/index.vue <script setup> import HomeCategory from ./components/HomeCategory.vue import HomeBanner from ./components/HomeBanner.vue import HomeNew from ./components/HomeNew.vue import…...

git flow与分支管理

git flow与分支管理 一、git flow是什么二、分支管理1、主分支Master2、开发分支Develop3、临时性分支功能分支预发布分支修补bug分支 三、分支管理最佳实践1、分支名义规划2、环境与分支3、分支图 四、git flow缺点 一、git flow是什么 Git 作为一个源码管理系统&#xff0c;…...

【Linux】学习-进程信号

进程信号 信号入门 生活角度的信号 你在网上买了很多件商品,再等待不同商品快递的到来。但即便快递没有到来,你也知道快递来临时,你该怎么处理快递。也就是你能“识别快递”,也就是你意识里是知道如果这时候快递员送来了你的包裹,你知道该如何处理这些包裹当快递员到了你…...

webgis后端安卓系统部署攻略

目录 前言 一、将后端项目编译ARM64 二、安卓手机安装termux 1.更换为国内源 2.安装ssh远程访问 3.安装文件远程访问 三、安装postgis数据库 1.安装数据库 2.数据库配置 3.数据导入 四、后端项目部署 五、自启动设置 总结 前言 因为之前一直做的H5APP开发&#xf…...

【数据分享】1929-2023年全球站点的逐日平均风速数据(Shp\Excel\免费获取)

气象数据是在各项研究中都经常使用的数据&#xff0c;气象指标包括气温、风速、降水、能见度等指标&#xff0c;说到气象数据&#xff0c;最详细的气象数据是具体到气象监测站点的数据&#xff01; 有关气象指标的监测站点数据&#xff0c;之前我们分享过1929-2023年全球气象站…...

【多模态大模型】视觉大模型SAM:如何使模型能够处理任意图像的分割任务?

SAM&#xff1a;如何使模型能够处理任意图像的分割任务&#xff1f; 核心思想起始问题: 如何使模型能够处理任意图像的分割任务&#xff1f;5why分析5so分析 总结子问题1: 如何编码输入图像以适应分割任务&#xff1f;子问题2: 如何处理各种形式的分割提示&#xff1f;子问题3:…...

Shell之sed

sed是什么 Linux sed 命令是利用脚本来处理文本文件。 可依照脚本的指令来处理、编辑文本文件。主要用来自动编辑一个或多个文件、简化对文件的反复操作、编写转换程序等。 sed命令详解 语法 sed [-hnV][-e <script>][-f<script文件>][文本文件] sed [-nefr] [动作…...

AJAX——认识URL

1 什么是URL&#xff1f; 统一资源定位符&#xff08;英语&#xff1a;Uniform Resource Locator&#xff0c;缩写&#xff1a;URL&#xff0c;或称统一资源定位器、定位地址、URL地址&#xff09;俗称网页地址&#xff0c;简称网址&#xff0c;是因特网上标准的资源的地址&…...

《Docker极简教程》--Docker环境的搭建--在Linux上搭建Docker环境

更新系统&#xff1a;首先确保所有的包管理器都是最新的。对于基于Debian的系统&#xff08;如Ubuntu&#xff09;&#xff0c;可以使用以下命令&#xff1a;sudo apt-get update sudo apt-get upgrade安装必要的依赖项&#xff1a;安装一些必要的工具&#xff0c;比如ca-certi…...

开源微服务平台框架的特点是什么?

借助什么平台的力量&#xff0c;可以让企业实现高效率的流程化办公&#xff1f;低代码技术平台是近些年来较为流行的平台产品&#xff0c;可以帮助很多行业进入流程化办公新时代&#xff0c;做好数据管理工作&#xff0c;从而提升企业市场竞争力。流辰信息专业研发低代码技术平…...

C#系列-C#操作UDP发送接收数据(10)

在C#中&#xff0c;发送UDP数据并接收响应通常涉及创建两个UdpClient实例&#xff1a;一个用于发送数据&#xff0c;另一个用于接收响应。以下是发送UDP数据并接收响应的示例代码&#xff1a; 首先&#xff0c;我们需要定义一个方法来发送UDP数据&#xff0c;并等待接收服务器…...

突破编程_C++_面试(基础知识(10))

面试题29&#xff1a;什么是嵌套类&#xff0c;它有什么作用 嵌套类指的是在一个类的内部定义的另一个类。嵌套类可以作为外部类的一个成员&#xff0c;但它与其声明类型紧密关联&#xff0c;不应被用作通用类型。嵌套类可以访问外部类的所有成员&#xff0c;包括私有成员&…...

初步探索Pyglet库:打造轻量级多媒体与游戏开发利器

目录 pyglet库 功能特点 安装和导入 安装 导入 基本代码框架 导入模块 创建窗口 创建控件 定义事件 运行应用 程序界面 运行结果 完整代码 标签控件 常用事件 窗口事件 鼠标事件 键盘事件 文本事件 其它场景 网页标签 音乐播放 图片显示 祝大家新…...

【npm】安装全局包,使用时提示:不是内部或外部命令,也不是可运行的程序或批处理文件

问题 如图&#xff0c;明明安装Vue是全局包&#xff0c;但是使用时却提示&#xff1a; 解决办法 使用以下命令任意一种命令查看全局包的配置路径 npm root -g 然后将此路径&#xff08;不包括node_modules&#xff09;添加到环境变量中去&#xff0c;这里注意&#xff0c;原…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式&#xff0c;可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

【JVM】- 内存结构

引言 JVM&#xff1a;Java Virtual Machine 定义&#xff1a;Java虚拟机&#xff0c;Java二进制字节码的运行环境好处&#xff1a; 一次编写&#xff0c;到处运行自动内存管理&#xff0c;垃圾回收的功能数组下标越界检查&#xff08;会抛异常&#xff0c;不会覆盖到其他代码…...

《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》

在注意力分散、内容高度同质化的时代&#xff0c;情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现&#xff0c;消费者对内容的“有感”程度&#xff0c;正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中&#xff0…...

CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云

目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

提升移动端网页调试效率:WebDebugX 与常见工具组合实践

在日常移动端开发中&#xff0c;网页调试始终是一个高频但又极具挑战的环节。尤其在面对 iOS 与 Android 的混合技术栈、各种设备差异化行为时&#xff0c;开发者迫切需要一套高效、可靠且跨平台的调试方案。过去&#xff0c;我们或多或少使用过 Chrome DevTools、Remote Debug…...

【C++】纯虚函数类外可以写实现吗?

1. 答案 先说答案&#xff0c;可以。 2.代码测试 .h头文件 #include <iostream> #include <string>// 抽象基类 class AbstractBase { public:AbstractBase() default;virtual ~AbstractBase() default; // 默认析构函数public:virtual int PureVirtualFunct…...

区块链技术概述

区块链技术是一种去中心化、分布式账本技术&#xff0c;通过密码学、共识机制和智能合约等核心组件&#xff0c;实现数据不可篡改、透明可追溯的系统。 一、核心技术 1. 去中心化 特点&#xff1a;数据存储在网络中的多个节点&#xff08;计算机&#xff09;&#xff0c;而非…...

基于鸿蒙(HarmonyOS5)的打车小程序

1. 开发环境准备 安装DevEco Studio (鸿蒙官方IDE)配置HarmonyOS SDK申请开发者账号和必要的API密钥 2. 项目结构设计 ├── entry │ ├── src │ │ ├── main │ │ │ ├── ets │ │ │ │ ├── pages │ │ │ │ │ ├── H…...