当前位置: 首页 > news >正文

于BERT的中文问答系统12

主要改进点

日志配置:

确保日志文件按日期和时间生成,便于追踪不同运行的记录。
数据处理:

增加了对数据加载过程中错误的捕获和日志记录,确保程序能够跳过无效数据并继续运行。
模型训练:

增加了重新训练模型的功能,用户可以选择重新训练现有模型或从头开始训练。
用户交互:

增加了输入验证,确保用户输入的问题不为空。
增加了模糊匹配功能,支持部分输入问题的匹配。
错误处理:

在关键步骤增加了异常捕获和日志记录,提高了程序的健壮性。

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox
import logging
from difflib import SequenceMatcher
from datetime import datetime# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)def setup_logging():log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d/%H-%M-%S/羲和.txt'))os.makedirs(os.path.dirname(log_file), exist_ok=True)logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler(log_file),logging.StreamHandler()])# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
setup_logging()# 数据集类
class XihuaDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.tokenizer = tokenizerself.max_length = max_lengthself.data = self.load_data(file_path)def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]question = item['question']human_answer = item['human_answers'][0]chatgpt_answer = item['chatgpt_answers'][0]try:inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)except Exception as e:logging.warning(f"跳过无效项 {idx}: {e}")return self.__getitem__((idx + 1) % len(self.data))return {'input_ids': inputs['input_ids'].squeeze(),'attention_mask': inputs['attention_mask'].squeeze(),'human_input_ids': human_inputs['input_ids'].squeeze(),'human_attention_mask': human_inputs['attention_mask'].squeeze(),'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),'human_answer': human_answer,'chatgpt_answer': chatgpt_answer}# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):dataset = XihuaDataset(file_path, tokenizer, max_length)return DataLoader(dataset, batch_size=batch_size, shuffle=True)# 模型定义
class XihuaModel(torch.nn.Module):def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):super(XihuaModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model_name)self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)pooled_output = outputs.pooler_outputlogits = self.classifier(pooled_output)return logits# 训练函数
def train(model, data_loader, optimizer, criterion, device):model.train()total_loss = 0.0for batch in data_loader:try:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)human_input_ids = batch['human_input_ids'].to(device)human_attention_mask = batch['human_attention_mask'].to(device)chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)optimizer.zero_grad()human_logits = model(human_input_ids, human_attention_mask)chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)human_labels = torch.ones(human_logits.size(0), 1).to(device)chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)loss.backward()optimizer.step()total_loss += loss.item()except Exception as e:logging.warning(f"跳过无效批次: {e}")return total_loss / len(data_loader)# 主训练函数
def main_train(retrain=False):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device: {device}')tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)if retrain:model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=device, weights_only=True))optimizer = optim.Adam(model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)num_epochs = 5for epoch in range(num_epochs):train_loss = train(model, train_data_loader, optimizer, criterion, device)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")# GUI界面
class XihuaChatbotGUI:def __init__(self, root):self.root = rootself.root.title("羲和聊天机器人")self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)self.load_model()self.model.eval()# 加载训练数据集以便在获取答案时使用self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))self.create_widgets()def create_widgets(self):self.question_label = tk.Label(self.root, text="问题:")self.question_label.pack()self.question_entry = tk.Entry(self.root, width=50)self.question_entry.pack()self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)self.answer_button.pack()self.answer_label = tk.Label(self.root, text="回答:")self.answer_label.pack()self.answer_text = tk.Text(self.root, height=10, width=50)self.answer_text.pack()self.train_button = tk.Button(self.root, text="训练模型", command=self.train_model)self.train_button.pack()self.retrain_button = tk.Button(self.root, text="重新训练模型", command=lambda: self.train_model(retrain=True))self.retrain_button.pack()def get_answer(self):question = self.question_entry.get()if not question:messagebox.showwarning("输入错误", "请输入问题")returninputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)with torch.no_grad():input_ids = inputs['input_ids'].to(self.device)attention_mask = inputs['attention_mask'].to(self.device)logits = self.model(input_ids, attention_mask)if logits.item() > 0:answer_type = "人类回答"else:answer_type = "ChatGPT回答"specific_answer = self.get_specific_answer(question, answer_type)self.answer_text.delete(1.0, tk.END)self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")def get_specific_answer(self, question, answer_type):# 使用模糊匹配查找最相似的问题best_match = Nonebest_ratio = 0.0for item in self.data:ratio = SequenceMatcher(None, question, item['question']).ratio()if ratio > best_ratio:best_ratio = ratiobest_match = itemif best_match:if answer_type == "人类回答":return best_match['human_answers'][0]else:return best_match['chatgpt_answers'][0]return "未找到具体答案"def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef load_model(self):model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')if os.path.exists(model_path):self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")def train_model(self, retrain=False):file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])if not file_path:messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")returntry:dataset = XihuaDataset(file_path, self.tokenizer)data_loader = DataLoader(dataset, batch_size=8, shuffle=True)# 加载已训练的模型权重if retrain:self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device, weights_only=True))self.model.to(self.device)self.model.train()optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()num_epochs = 5for epoch in range(num_epochs):train_loss = train(self.model, data_loader, optimizer, criterion, self.device)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")messagebox.showinfo("训练完成", "模型训练完成并保存")except Exception as e:logging.error(f"模型训练失败: {e}")messagebox.showerror("训练失败", f"模型训练失败: {e}")# 主函数
if __name__ == "__main__":# 启动GUIroot = tk.Tk()app = XihuaChatbotGUI(root)root.mainloop()

相关文章:

于BERT的中文问答系统12

主要改进点 日志配置: 确保日志文件按日期和时间生成,便于追踪不同运行的记录。 数据处理: 增加了对数据加载过程中错误的捕获和日志记录,确保程序能够跳过无效数据并继续运行。 模型训练: 增加了重新训练模型的功…...

基于SpringBoot“花开富贵”花园管理系统【附源码】

效果如下: 系统注册页面 系统首页界面 植物信息详细页面 后台登录界面 管理员主界面 植物分类管理界面 植物信息管理界面 园艺记录管理界面 研究背景 随着城市化进程的加快和人们生活质量的提升,越来越多的人开始追求与自然和谐共生的生活方式&#xf…...

MySQL连接查询:自连接

先看我的表结构 emp表 自连接也就是把一个表看作是两个作用的表就好,也就是说我把emp看作员工表,也看做领导表 自连接 基本语法 select 字段列表 FROM 表A 别名A JOIN 表A 别名B ON 条件;例子1:查询员工 及其 所属领导的名字 select a.n…...

Prometheus+Grafana备忘

Grafana安装 官网 https://grafana.com/grafana/download 官网提供了几种安装方式,我用最简单的 yum install -y https://dl.grafana.com/enterprise/release/grafana-enterprise-11.2.2-1.x86_64.rpm启动 //如果需要在系统启动时自动启动Grafana,可以…...

基于ssm实现的建筑装修图纸管理平台(源码+文档)

项目简介 基于ssm实现的建筑装修图纸管理平台,主要功能如下: 技术栈 后端框框:spring/springmvc/mybatis 前端框架:html/JavaScript/Css/vue/elementui 运行环境:JDK1.8/MySQL5.7/idea(可选&#xff09…...

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-07

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-07 目录 文章目录 计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-07目录1. Evaluation of Large Language Models for Summarization Tasks in the Medical Domain: A Narrative Review摘要研究…...

Mahalanobis distance 马哈拉诺比斯距离

马哈拉诺比斯距离(Mahalanobis Distance)是一种衡量点与分布之间距离的度量,尤其适用于多维数据。与欧几里得距离不同,马哈拉诺比斯距离考虑了数据的协方差结构,因此在统计分析和异常值检测中非常有用。 定义 给定一…...

R语言绘制直方图

直方图是一种统计图表。它将数据分成若干区间,统计每个区间内数据的数量或频率,用矩形条高度表示。能直观展现数据分布特征,如集中趋势、离散程度等。在数据分析、质量控制、市场调研等领域广泛应用,可帮助人们快速了解数据整体形…...

论文阅读笔记-LogME: Practical Assessment of Pre-trained Models for Transfer Learning

前言 在NLP领域,预训练模型(准确的说应该是预训练语言模型)似乎已经成为各大任务必备的模块了,经常有看到文章称后BERT时代或后XXX时代,分析对比了许多主流模型的优缺点,这些相对而言有些停留在理论层面,可是有时候对于手上正在解决的任务,要用到预训练语言模型时,面…...

求二叉树的带权路径长度

二叉树的带权路径长度(WPL)是二叉树中所有叶结点的带权路径长度之和。给定一棵二叉树T,采用二叉链表存储。结点结构为: 其中叶结点的weight域保存该结点的非负权值。设root为指向T的根结点的指针,请设计求T的WPL的算法…...

Hive数仓操作(十五)

Hive 开窗函数 Hive窗口函数是一种特殊的函数,允许用户在查询中对一组行进行计算,而不仅仅是单独的行。窗口函数可以在 SQL 查询中进行聚合、排名、累积计算等。这使得窗口函数在数据分析和报告生成中非常有用。 窗口函数的基本组成部分 函数类型&…...

No.12 笔记 | 网络基础:ARP DNS TCP/IP与OSI模型

一、计算机网络:安全的基石 1. 网络的本质:数字世界的神经系统 定义:计算机的互联互通,实现资源共享和信息交换组成要素:发送者、接收者、介质、数据、协议(五大要素) 2. 网络架构&#xff1…...

OpenHarmony(鸿蒙南向开发)——轻量系统STM32F407芯片移植案例

往期知识点记录: 鸿蒙(HarmonyOS)应用层开发(北向)知识点汇总 鸿蒙(OpenHarmony)南向开发保姆级知识点汇总~ 持续更新中…… 介绍基于STM32F407IGT6芯片在拓维信息 Niobe407 开发板上移植OpenH…...

简单易懂的springboot整合Camunda 7工作流入门教程

简单易懂的Spring Boot整合Camunda7入门教程 因为关于Spring Boot结合Camunda7的教程在网上比较少,而且很多都写得有点乱,很多概念写得太散乱,讲解不清晰,导致看不懂,本人通过研究学习之后就写出了这篇教学文档。 介…...

LabVIEW提高开发效率技巧----点阵图(XY Graph)

在LabVIEW开发中,点阵图(XY Graph) 是一种强大的工具,尤其适用于需要实时展示大量数据的场景。通过使用点阵图,开发人员能够将实时数据可视化,帮助用户更直观地分析数据变化。 1. 点阵图的优势 点阵图&…...

C++-匿名空间

匿名命名空间(anonymous namespace)是 C 中的一种特性,用于将符号(如变量、函数或类)限制在定义它们的源文件的作用域内。这意味着在该源文件外部,这些符号不可见,从而避免了命名冲突。 1. 定义…...

jdk的安装和环境变量配置

1.将从官网下载好的jdk放在自己想要放的位置,这里的位置是:E:\develop 2.新建一个文件夹用来放安装的jdk,将jdk安装的此目录,这里的位置是:E:\develop\jdk17 3.jdk安装好之后,点击jdk17目录,点…...

继承、Lambda、Objective-C和Swift

继承 东风系列导弹是镇国神器。东风41不是突然就造出来的,之前有很多种东风xx导弹,每种导弹都有自己的独特之处,相同之处都具备导弹基本特点。很多工厂有量产磨具的生产线,盖房子就图纸,建筑设计建设都有参考&#xff…...

设置服务器走本地代理

勾选: 然后: git clone https://github.com/rofl0r/proxychains-ng.git./configure --prefix/home/wangguisen/usr --sysconfdir/home/wangguisen/etcmakemake install# 在最后配置成本地代理地址 vim /home/wangguisen/etc/proxychains.confsocks4 17…...

刷题 -哈希

面试面试经典 150 题 - 哈希 383. 赎金信 - 一个哈希表搞定 class Solution { public:bool canConstruct(string ransomNote, string magazine) {int hash[26] {0};for (auto& ch : magazine) {hash[ch - a];}for (auto& ch : ransomNote) {if (--hash[ch - a] < …...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

Appium+python自动化(十六)- ADB命令

简介 Android 调试桥(adb)是多种用途的工具&#xff0c;该工具可以帮助你你管理设备或模拟器 的状态。 adb ( Android Debug Bridge)是一个通用命令行工具&#xff0c;其允许您与模拟器实例或连接的 Android 设备进行通信。它可为各种设备操作提供便利&#xff0c;如安装和调试…...

mongodb源码分析session执行handleRequest命令find过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程&#xff0c;并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令&#xff0c;把数据流转换成Message&#xff0c;状态转变流程是&#xff1a;State::Created 》 St…...

CentOS下的分布式内存计算Spark环境部署

一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架&#xff0c;相比 MapReduce 具有以下核心优势&#xff1a; 内存计算&#xff1a;数据可常驻内存&#xff0c;迭代计算性能提升 10-100 倍&#xff08;文档段落&#xff1a;3-79…...

React19源码系列之 事件插件系统

事件类别 事件类型 定义 文档 Event Event 接口表示在 EventTarget 上出现的事件。 Event - Web API | MDN UIEvent UIEvent 接口表示简单的用户界面事件。 UIEvent - Web API | MDN KeyboardEvent KeyboardEvent 对象描述了用户与键盘的交互。 KeyboardEvent - Web…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站&#xff1a;https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本&#xff1a; Windows版&#xff08;推荐下载标准版&#xff09; Windows系统安装步骤 运行安装程序&#xff1a; 双击下载的.exe安装文件 如果出现安全提示&…...

三体问题详解

从物理学角度&#xff0c;三体问题之所以不稳定&#xff0c;是因为三个天体在万有引力作用下相互作用&#xff0c;形成一个非线性耦合系统。我们可以从牛顿经典力学出发&#xff0c;列出具体的运动方程&#xff0c;并说明为何这个系统本质上是混沌的&#xff0c;无法得到一般解…...

微信小程序云开发平台MySQL的连接方式

注&#xff1a;微信小程序云开发平台指的是腾讯云开发 先给结论&#xff1a;微信小程序云开发平台的MySQL&#xff0c;无法通过获取数据库连接信息的方式进行连接&#xff0c;连接只能通过云开发的SDK连接&#xff0c;具体要参考官方文档&#xff1a; 为什么&#xff1f; 因为…...

Redis的发布订阅模式与专业的 MQ(如 Kafka, RabbitMQ)相比,优缺点是什么?适用于哪些场景?

Redis 的发布订阅&#xff08;Pub/Sub&#xff09;模式与专业的 MQ&#xff08;Message Queue&#xff09;如 Kafka、RabbitMQ 进行比较&#xff0c;核心的权衡点在于&#xff1a;简单与速度 vs. 可靠与功能。 下面我们详细展开对比。 Redis Pub/Sub 的核心特点 它是一个发后…...

IP如何挑?2025年海外专线IP如何购买?

你花了时间和预算买了IP&#xff0c;结果IP质量不佳&#xff0c;项目效率低下不说&#xff0c;还可能带来莫名的网络问题&#xff0c;是不是太闹心了&#xff1f;尤其是在面对海外专线IP时&#xff0c;到底怎么才能买到适合自己的呢&#xff1f;所以&#xff0c;挑IP绝对是个技…...