当前位置: 首页 > news >正文

于BERT的中文问答系统12

主要改进点

日志配置:

确保日志文件按日期和时间生成,便于追踪不同运行的记录。
数据处理:

增加了对数据加载过程中错误的捕获和日志记录,确保程序能够跳过无效数据并继续运行。
模型训练:

增加了重新训练模型的功能,用户可以选择重新训练现有模型或从头开始训练。
用户交互:

增加了输入验证,确保用户输入的问题不为空。
增加了模糊匹配功能,支持部分输入问题的匹配。
错误处理:

在关键步骤增加了异常捕获和日志记录,提高了程序的健壮性。

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox
import logging
from difflib import SequenceMatcher
from datetime import datetime# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)def setup_logging():log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d/%H-%M-%S/羲和.txt'))os.makedirs(os.path.dirname(log_file), exist_ok=True)logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler(log_file),logging.StreamHandler()])# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
setup_logging()# 数据集类
class XihuaDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.tokenizer = tokenizerself.max_length = max_lengthself.data = self.load_data(file_path)def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]question = item['question']human_answer = item['human_answers'][0]chatgpt_answer = item['chatgpt_answers'][0]try:inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)except Exception as e:logging.warning(f"跳过无效项 {idx}: {e}")return self.__getitem__((idx + 1) % len(self.data))return {'input_ids': inputs['input_ids'].squeeze(),'attention_mask': inputs['attention_mask'].squeeze(),'human_input_ids': human_inputs['input_ids'].squeeze(),'human_attention_mask': human_inputs['attention_mask'].squeeze(),'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),'human_answer': human_answer,'chatgpt_answer': chatgpt_answer}# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):dataset = XihuaDataset(file_path, tokenizer, max_length)return DataLoader(dataset, batch_size=batch_size, shuffle=True)# 模型定义
class XihuaModel(torch.nn.Module):def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):super(XihuaModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model_name)self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)pooled_output = outputs.pooler_outputlogits = self.classifier(pooled_output)return logits# 训练函数
def train(model, data_loader, optimizer, criterion, device):model.train()total_loss = 0.0for batch in data_loader:try:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)human_input_ids = batch['human_input_ids'].to(device)human_attention_mask = batch['human_attention_mask'].to(device)chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)optimizer.zero_grad()human_logits = model(human_input_ids, human_attention_mask)chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)human_labels = torch.ones(human_logits.size(0), 1).to(device)chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)loss.backward()optimizer.step()total_loss += loss.item()except Exception as e:logging.warning(f"跳过无效批次: {e}")return total_loss / len(data_loader)# 主训练函数
def main_train(retrain=False):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device: {device}')tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)if retrain:model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=device, weights_only=True))optimizer = optim.Adam(model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)num_epochs = 5for epoch in range(num_epochs):train_loss = train(model, train_data_loader, optimizer, criterion, device)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")# GUI界面
class XihuaChatbotGUI:def __init__(self, root):self.root = rootself.root.title("羲和聊天机器人")self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)self.load_model()self.model.eval()# 加载训练数据集以便在获取答案时使用self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))self.create_widgets()def create_widgets(self):self.question_label = tk.Label(self.root, text="问题:")self.question_label.pack()self.question_entry = tk.Entry(self.root, width=50)self.question_entry.pack()self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)self.answer_button.pack()self.answer_label = tk.Label(self.root, text="回答:")self.answer_label.pack()self.answer_text = tk.Text(self.root, height=10, width=50)self.answer_text.pack()self.train_button = tk.Button(self.root, text="训练模型", command=self.train_model)self.train_button.pack()self.retrain_button = tk.Button(self.root, text="重新训练模型", command=lambda: self.train_model(retrain=True))self.retrain_button.pack()def get_answer(self):question = self.question_entry.get()if not question:messagebox.showwarning("输入错误", "请输入问题")returninputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)with torch.no_grad():input_ids = inputs['input_ids'].to(self.device)attention_mask = inputs['attention_mask'].to(self.device)logits = self.model(input_ids, attention_mask)if logits.item() > 0:answer_type = "人类回答"else:answer_type = "ChatGPT回答"specific_answer = self.get_specific_answer(question, answer_type)self.answer_text.delete(1.0, tk.END)self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")def get_specific_answer(self, question, answer_type):# 使用模糊匹配查找最相似的问题best_match = Nonebest_ratio = 0.0for item in self.data:ratio = SequenceMatcher(None, question, item['question']).ratio()if ratio > best_ratio:best_ratio = ratiobest_match = itemif best_match:if answer_type == "人类回答":return best_match['human_answers'][0]else:return best_match['chatgpt_answers'][0]return "未找到具体答案"def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef load_model(self):model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')if os.path.exists(model_path):self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")def train_model(self, retrain=False):file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])if not file_path:messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")returntry:dataset = XihuaDataset(file_path, self.tokenizer)data_loader = DataLoader(dataset, batch_size=8, shuffle=True)# 加载已训练的模型权重if retrain:self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device, weights_only=True))self.model.to(self.device)self.model.train()optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()num_epochs = 5for epoch in range(num_epochs):train_loss = train(self.model, data_loader, optimizer, criterion, self.device)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")messagebox.showinfo("训练完成", "模型训练完成并保存")except Exception as e:logging.error(f"模型训练失败: {e}")messagebox.showerror("训练失败", f"模型训练失败: {e}")# 主函数
if __name__ == "__main__":# 启动GUIroot = tk.Tk()app = XihuaChatbotGUI(root)root.mainloop()

相关文章:

于BERT的中文问答系统12

主要改进点 日志配置: 确保日志文件按日期和时间生成,便于追踪不同运行的记录。 数据处理: 增加了对数据加载过程中错误的捕获和日志记录,确保程序能够跳过无效数据并继续运行。 模型训练: 增加了重新训练模型的功…...

基于SpringBoot“花开富贵”花园管理系统【附源码】

效果如下: 系统注册页面 系统首页界面 植物信息详细页面 后台登录界面 管理员主界面 植物分类管理界面 植物信息管理界面 园艺记录管理界面 研究背景 随着城市化进程的加快和人们生活质量的提升,越来越多的人开始追求与自然和谐共生的生活方式&#xf…...

MySQL连接查询:自连接

先看我的表结构 emp表 自连接也就是把一个表看作是两个作用的表就好,也就是说我把emp看作员工表,也看做领导表 自连接 基本语法 select 字段列表 FROM 表A 别名A JOIN 表A 别名B ON 条件;例子1:查询员工 及其 所属领导的名字 select a.n…...

Prometheus+Grafana备忘

Grafana安装 官网 https://grafana.com/grafana/download 官网提供了几种安装方式,我用最简单的 yum install -y https://dl.grafana.com/enterprise/release/grafana-enterprise-11.2.2-1.x86_64.rpm启动 //如果需要在系统启动时自动启动Grafana,可以…...

基于ssm实现的建筑装修图纸管理平台(源码+文档)

项目简介 基于ssm实现的建筑装修图纸管理平台,主要功能如下: 技术栈 后端框框:spring/springmvc/mybatis 前端框架:html/JavaScript/Css/vue/elementui 运行环境:JDK1.8/MySQL5.7/idea(可选&#xff09…...

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-07

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-07 目录 文章目录 计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-07目录1. Evaluation of Large Language Models for Summarization Tasks in the Medical Domain: A Narrative Review摘要研究…...

Mahalanobis distance 马哈拉诺比斯距离

马哈拉诺比斯距离(Mahalanobis Distance)是一种衡量点与分布之间距离的度量,尤其适用于多维数据。与欧几里得距离不同,马哈拉诺比斯距离考虑了数据的协方差结构,因此在统计分析和异常值检测中非常有用。 定义 给定一…...

R语言绘制直方图

直方图是一种统计图表。它将数据分成若干区间,统计每个区间内数据的数量或频率,用矩形条高度表示。能直观展现数据分布特征,如集中趋势、离散程度等。在数据分析、质量控制、市场调研等领域广泛应用,可帮助人们快速了解数据整体形…...

论文阅读笔记-LogME: Practical Assessment of Pre-trained Models for Transfer Learning

前言 在NLP领域,预训练模型(准确的说应该是预训练语言模型)似乎已经成为各大任务必备的模块了,经常有看到文章称后BERT时代或后XXX时代,分析对比了许多主流模型的优缺点,这些相对而言有些停留在理论层面,可是有时候对于手上正在解决的任务,要用到预训练语言模型时,面…...

求二叉树的带权路径长度

二叉树的带权路径长度(WPL)是二叉树中所有叶结点的带权路径长度之和。给定一棵二叉树T,采用二叉链表存储。结点结构为: 其中叶结点的weight域保存该结点的非负权值。设root为指向T的根结点的指针,请设计求T的WPL的算法…...

Hive数仓操作(十五)

Hive 开窗函数 Hive窗口函数是一种特殊的函数,允许用户在查询中对一组行进行计算,而不仅仅是单独的行。窗口函数可以在 SQL 查询中进行聚合、排名、累积计算等。这使得窗口函数在数据分析和报告生成中非常有用。 窗口函数的基本组成部分 函数类型&…...

No.12 笔记 | 网络基础:ARP DNS TCP/IP与OSI模型

一、计算机网络:安全的基石 1. 网络的本质:数字世界的神经系统 定义:计算机的互联互通,实现资源共享和信息交换组成要素:发送者、接收者、介质、数据、协议(五大要素) 2. 网络架构&#xff1…...

OpenHarmony(鸿蒙南向开发)——轻量系统STM32F407芯片移植案例

往期知识点记录: 鸿蒙(HarmonyOS)应用层开发(北向)知识点汇总 鸿蒙(OpenHarmony)南向开发保姆级知识点汇总~ 持续更新中…… 介绍基于STM32F407IGT6芯片在拓维信息 Niobe407 开发板上移植OpenH…...

简单易懂的springboot整合Camunda 7工作流入门教程

简单易懂的Spring Boot整合Camunda7入门教程 因为关于Spring Boot结合Camunda7的教程在网上比较少,而且很多都写得有点乱,很多概念写得太散乱,讲解不清晰,导致看不懂,本人通过研究学习之后就写出了这篇教学文档。 介…...

LabVIEW提高开发效率技巧----点阵图(XY Graph)

在LabVIEW开发中,点阵图(XY Graph) 是一种强大的工具,尤其适用于需要实时展示大量数据的场景。通过使用点阵图,开发人员能够将实时数据可视化,帮助用户更直观地分析数据变化。 1. 点阵图的优势 点阵图&…...

C++-匿名空间

匿名命名空间(anonymous namespace)是 C 中的一种特性,用于将符号(如变量、函数或类)限制在定义它们的源文件的作用域内。这意味着在该源文件外部,这些符号不可见,从而避免了命名冲突。 1. 定义…...

jdk的安装和环境变量配置

1.将从官网下载好的jdk放在自己想要放的位置,这里的位置是:E:\develop 2.新建一个文件夹用来放安装的jdk,将jdk安装的此目录,这里的位置是:E:\develop\jdk17 3.jdk安装好之后,点击jdk17目录,点…...

继承、Lambda、Objective-C和Swift

继承 东风系列导弹是镇国神器。东风41不是突然就造出来的,之前有很多种东风xx导弹,每种导弹都有自己的独特之处,相同之处都具备导弹基本特点。很多工厂有量产磨具的生产线,盖房子就图纸,建筑设计建设都有参考&#xff…...

设置服务器走本地代理

勾选: 然后: git clone https://github.com/rofl0r/proxychains-ng.git./configure --prefix/home/wangguisen/usr --sysconfdir/home/wangguisen/etcmakemake install# 在最后配置成本地代理地址 vim /home/wangguisen/etc/proxychains.confsocks4 17…...

刷题 -哈希

面试面试经典 150 题 - 哈希 383. 赎金信 - 一个哈希表搞定 class Solution { public:bool canConstruct(string ransomNote, string magazine) {int hash[26] {0};for (auto& ch : magazine) {hash[ch - a];}for (auto& ch : ransomNote) {if (--hash[ch - a] < …...

云计算——弹性云计算器(ECS)

弹性云服务器&#xff1a;ECS 概述 云计算重构了ICT系统&#xff0c;云计算平台厂商推出使得厂家能够主要关注应用管理而非平台管理的云平台&#xff0c;包含如下主要概念。 ECS&#xff08;Elastic Cloud Server&#xff09;&#xff1a;即弹性云服务器&#xff0c;是云计算…...

3.3.1_1 检错编码(奇偶校验码)

从这节课开始&#xff0c;我们会探讨数据链路层的差错控制功能&#xff0c;差错控制功能的主要目标是要发现并且解决一个帧内部的位错误&#xff0c;我们需要使用特殊的编码技术去发现帧内部的位错误&#xff0c;当我们发现位错误之后&#xff0c;通常来说有两种解决方案。第一…...

YSYX学习记录(八)

C语言&#xff0c;练习0&#xff1a; 先创建一个文件夹&#xff0c;我用的是物理机&#xff1a; 安装build-essential 练习1&#xff1a; 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件&#xff0c;随机修改或删除一部分&#xff0c;之后…...

如何在看板中有效管理突发紧急任务

在看板中有效管理突发紧急任务需要&#xff1a;设立专门的紧急任务通道、重新调整任务优先级、保持适度的WIP&#xff08;Work-in-Progress&#xff09;弹性、优化任务处理流程、提高团队应对突发情况的敏捷性。其中&#xff0c;设立专门的紧急任务通道尤为重要&#xff0c;这能…...

Cinnamon修改面板小工具图标

Cinnamon开始菜单-CSDN博客 设置模块都是做好的&#xff0c;比GNOME简单得多&#xff01; 在 applet.js 里增加 const Settings imports.ui.settings;this.settings new Settings.AppletSettings(this, HTYMenusonichy, instance_id); this.settings.bind(menu-icon, menu…...

镜像里切换为普通用户

如果你登录远程虚拟机默认就是 root 用户&#xff0c;但你不希望用 root 权限运行 ns-3&#xff08;这是对的&#xff0c;ns3 工具会拒绝 root&#xff09;&#xff0c;你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案&#xff1a;创建非 roo…...

linux 下常用变更-8

1、删除普通用户 查询用户初始UID和GIDls -l /home/ ###家目录中查看UID cat /etc/group ###此文件查看GID删除用户1.编辑文件 /etc/passwd 找到对应的行&#xff0c;YW343:x:0:0::/home/YW343:/bin/bash 2.将标红的位置修改为用户对应初始UID和GID&#xff1a; YW3…...

EtherNet/IP转DeviceNet协议网关详解

一&#xff0c;设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络&#xff0c;本网关连接到EtherNet/IP总线中做为从站使用&#xff0c;连接到DeviceNet总线中做为从站使用。 在自动…...

OPENCV形态学基础之二腐蚀

一.腐蚀的原理 (图1) 数学表达式&#xff1a;dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一&#xff0c;腐蚀跟膨胀属于反向操作&#xff0c;膨胀是把图像图像变大&#xff0c;而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...

初探Service服务发现机制

1.Service简介 Service是将运行在一组Pod上的应用程序发布为网络服务的抽象方法。 主要功能&#xff1a;服务发现和负载均衡。 Service类型的包括ClusterIP类型、NodePort类型、LoadBalancer类型、ExternalName类型 2.Endpoints简介 Endpoints是一种Kubernetes资源&#xf…...