当前位置: 首页 > news >正文

于BERT的中文问答系统12

主要改进点

日志配置:

确保日志文件按日期和时间生成,便于追踪不同运行的记录。
数据处理:

增加了对数据加载过程中错误的捕获和日志记录,确保程序能够跳过无效数据并继续运行。
模型训练:

增加了重新训练模型的功能,用户可以选择重新训练现有模型或从头开始训练。
用户交互:

增加了输入验证,确保用户输入的问题不为空。
增加了模糊匹配功能,支持部分输入问题的匹配。
错误处理:

在关键步骤增加了异常捕获和日志记录,提高了程序的健壮性。

import os
import json
import jsonlines
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
import tkinter as tk
from tkinter import filedialog, messagebox
import logging
from difflib import SequenceMatcher
from datetime import datetime# 配置日志
LOGS_DIR = os.path.join(PROJECT_ROOT, 'logs')
os.makedirs(LOGS_DIR, exist_ok=True)def setup_logging():log_file = os.path.join(LOGS_DIR, datetime.now().strftime('%Y-%m-%d/%H-%M-%S/羲和.txt'))os.makedirs(os.path.dirname(log_file), exist_ok=True)logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler(log_file),logging.StreamHandler()])# 获取项目根目录
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
setup_logging()# 数据集类
class XihuaDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.tokenizer = tokenizerself.max_length = max_lengthself.data = self.load_data(file_path)def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef __len__(self):return len(self.data)def __getitem__(self, idx):item = self.data[idx]question = item['question']human_answer = item['human_answers'][0]chatgpt_answer = item['chatgpt_answers'][0]try:inputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)human_inputs = self.tokenizer(human_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)chatgpt_inputs = self.tokenizer(chatgpt_answer, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)except Exception as e:logging.warning(f"跳过无效项 {idx}: {e}")return self.__getitem__((idx + 1) % len(self.data))return {'input_ids': inputs['input_ids'].squeeze(),'attention_mask': inputs['attention_mask'].squeeze(),'human_input_ids': human_inputs['input_ids'].squeeze(),'human_attention_mask': human_inputs['attention_mask'].squeeze(),'chatgpt_input_ids': chatgpt_inputs['input_ids'].squeeze(),'chatgpt_attention_mask': chatgpt_inputs['attention_mask'].squeeze(),'human_answer': human_answer,'chatgpt_answer': chatgpt_answer}# 获取数据加载器
def get_data_loader(file_path, tokenizer, batch_size=8, max_length=128):dataset = XihuaDataset(file_path, tokenizer, max_length)return DataLoader(dataset, batch_size=batch_size, shuffle=True)# 模型定义
class XihuaModel(torch.nn.Module):def __init__(self, pretrained_model_name='F:/models/bert-base-chinese'):super(XihuaModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model_name)self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)def forward(self, input_ids, attention_mask):outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)pooled_output = outputs.pooler_outputlogits = self.classifier(pooled_output)return logits# 训练函数
def train(model, data_loader, optimizer, criterion, device):model.train()total_loss = 0.0for batch in data_loader:try:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)human_input_ids = batch['human_input_ids'].to(device)human_attention_mask = batch['human_attention_mask'].to(device)chatgpt_input_ids = batch['chatgpt_input_ids'].to(device)chatgpt_attention_mask = batch['chatgpt_attention_mask'].to(device)optimizer.zero_grad()human_logits = model(human_input_ids, human_attention_mask)chatgpt_logits = model(chatgpt_input_ids, chatgpt_attention_mask)human_labels = torch.ones(human_logits.size(0), 1).to(device)chatgpt_labels = torch.zeros(chatgpt_logits.size(0), 1).to(device)loss = criterion(human_logits, human_labels) + criterion(chatgpt_logits, chatgpt_labels)loss.backward()optimizer.step()total_loss += loss.item()except Exception as e:logging.warning(f"跳过无效批次: {e}")return total_loss / len(data_loader)# 主训练函数
def main_train(retrain=False):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device: {device}')tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(device)if retrain:model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=device, weights_only=True))optimizer = optim.Adam(model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()train_data_loader = get_data_loader(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'), tokenizer, batch_size=8, max_length=128)num_epochs = 5for epoch in range(num_epochs):train_loss = train(model, train_data_loader, optimizer, criterion, device)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')torch.save(model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")# GUI界面
class XihuaChatbotGUI:def __init__(self, root):self.root = rootself.root.title("羲和聊天机器人")self.tokenizer = BertTokenizer.from_pretrained('F:/models/bert-base-chinese')self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')self.model = XihuaModel(pretrained_model_name='F:/models/bert-base-chinese').to(self.device)self.load_model()self.model.eval()# 加载训练数据集以便在获取答案时使用self.data = self.load_data(os.path.join(PROJECT_ROOT, 'data/train_data.jsonl'))self.create_widgets()def create_widgets(self):self.question_label = tk.Label(self.root, text="问题:")self.question_label.pack()self.question_entry = tk.Entry(self.root, width=50)self.question_entry.pack()self.answer_button = tk.Button(self.root, text="获取回答", command=self.get_answer)self.answer_button.pack()self.answer_label = tk.Label(self.root, text="回答:")self.answer_label.pack()self.answer_text = tk.Text(self.root, height=10, width=50)self.answer_text.pack()self.train_button = tk.Button(self.root, text="训练模型", command=self.train_model)self.train_button.pack()self.retrain_button = tk.Button(self.root, text="重新训练模型", command=lambda: self.train_model(retrain=True))self.retrain_button.pack()def get_answer(self):question = self.question_entry.get()if not question:messagebox.showwarning("输入错误", "请输入问题")returninputs = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation=True, max_length=128)with torch.no_grad():input_ids = inputs['input_ids'].to(self.device)attention_mask = inputs['attention_mask'].to(self.device)logits = self.model(input_ids, attention_mask)if logits.item() > 0:answer_type = "人类回答"else:answer_type = "ChatGPT回答"specific_answer = self.get_specific_answer(question, answer_type)self.answer_text.delete(1.0, tk.END)self.answer_text.insert(tk.END, f"{answer_type}\n{specific_answer}")def get_specific_answer(self, question, answer_type):# 使用模糊匹配查找最相似的问题best_match = Nonebest_ratio = 0.0for item in self.data:ratio = SequenceMatcher(None, question, item['question']).ratio()if ratio > best_ratio:best_ratio = ratiobest_match = itemif best_match:if answer_type == "人类回答":return best_match['human_answers'][0]else:return best_match['chatgpt_answers'][0]return "未找到具体答案"def load_data(self, file_path):data = []if file_path.endswith('.jsonl'):with jsonlines.open(file_path) as reader:for i, item in enumerate(reader):try:data.append(item)except jsonlines.jsonlines.InvalidLineError as e:logging.warning(f"跳过无效行 {i + 1}: {e}")elif file_path.endswith('.json'):with open(file_path, 'r') as f:try:data = json.load(f)except json.JSONDecodeError as e:logging.warning(f"跳过无效文件 {file_path}: {e}")return datadef load_model(self):model_path = os.path.join(PROJECT_ROOT, 'models/xihua_model.pth')if os.path.exists(model_path):self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=True))logging.info("加载现有模型")else:logging.info("没有找到现有模型,将使用预训练模型")def train_model(self, retrain=False):file_path = filedialog.askopenfilename(filetypes=[("JSONL files", "*.jsonl"), ("JSON files", "*.json")])if not file_path:messagebox.showwarning("文件选择错误", "请选择一个有效的数据文件")returntry:dataset = XihuaDataset(file_path, self.tokenizer)data_loader = DataLoader(dataset, batch_size=8, shuffle=True)# 加载已训练的模型权重if retrain:self.model.load_state_dict(torch.load(os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'), map_location=self.device, weights_only=True))self.model.to(self.device)self.model.train()optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)criterion = torch.nn.BCEWithLogitsLoss()num_epochs = 5for epoch in range(num_epochs):train_loss = train(self.model, data_loader, optimizer, criterion, self.device)logging.info(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')torch.save(self.model.state_dict(), os.path.join(PROJECT_ROOT, 'models/xihua_model.pth'))logging.info("模型训练完成并保存")messagebox.showinfo("训练完成", "模型训练完成并保存")except Exception as e:logging.error(f"模型训练失败: {e}")messagebox.showerror("训练失败", f"模型训练失败: {e}")# 主函数
if __name__ == "__main__":# 启动GUIroot = tk.Tk()app = XihuaChatbotGUI(root)root.mainloop()

相关文章:

于BERT的中文问答系统12

主要改进点 日志配置: 确保日志文件按日期和时间生成,便于追踪不同运行的记录。 数据处理: 增加了对数据加载过程中错误的捕获和日志记录,确保程序能够跳过无效数据并继续运行。 模型训练: 增加了重新训练模型的功…...

基于SpringBoot“花开富贵”花园管理系统【附源码】

效果如下: 系统注册页面 系统首页界面 植物信息详细页面 后台登录界面 管理员主界面 植物分类管理界面 植物信息管理界面 园艺记录管理界面 研究背景 随着城市化进程的加快和人们生活质量的提升,越来越多的人开始追求与自然和谐共生的生活方式&#xf…...

MySQL连接查询:自连接

先看我的表结构 emp表 自连接也就是把一个表看作是两个作用的表就好,也就是说我把emp看作员工表,也看做领导表 自连接 基本语法 select 字段列表 FROM 表A 别名A JOIN 表A 别名B ON 条件;例子1:查询员工 及其 所属领导的名字 select a.n…...

Prometheus+Grafana备忘

Grafana安装 官网 https://grafana.com/grafana/download 官网提供了几种安装方式,我用最简单的 yum install -y https://dl.grafana.com/enterprise/release/grafana-enterprise-11.2.2-1.x86_64.rpm启动 //如果需要在系统启动时自动启动Grafana,可以…...

基于ssm实现的建筑装修图纸管理平台(源码+文档)

项目简介 基于ssm实现的建筑装修图纸管理平台,主要功能如下: 技术栈 后端框框:spring/springmvc/mybatis 前端框架:html/JavaScript/Css/vue/elementui 运行环境:JDK1.8/MySQL5.7/idea(可选&#xff09…...

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-07

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-07 目录 文章目录 计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-07目录1. Evaluation of Large Language Models for Summarization Tasks in the Medical Domain: A Narrative Review摘要研究…...

Mahalanobis distance 马哈拉诺比斯距离

马哈拉诺比斯距离(Mahalanobis Distance)是一种衡量点与分布之间距离的度量,尤其适用于多维数据。与欧几里得距离不同,马哈拉诺比斯距离考虑了数据的协方差结构,因此在统计分析和异常值检测中非常有用。 定义 给定一…...

R语言绘制直方图

直方图是一种统计图表。它将数据分成若干区间,统计每个区间内数据的数量或频率,用矩形条高度表示。能直观展现数据分布特征,如集中趋势、离散程度等。在数据分析、质量控制、市场调研等领域广泛应用,可帮助人们快速了解数据整体形…...

论文阅读笔记-LogME: Practical Assessment of Pre-trained Models for Transfer Learning

前言 在NLP领域,预训练模型(准确的说应该是预训练语言模型)似乎已经成为各大任务必备的模块了,经常有看到文章称后BERT时代或后XXX时代,分析对比了许多主流模型的优缺点,这些相对而言有些停留在理论层面,可是有时候对于手上正在解决的任务,要用到预训练语言模型时,面…...

求二叉树的带权路径长度

二叉树的带权路径长度(WPL)是二叉树中所有叶结点的带权路径长度之和。给定一棵二叉树T,采用二叉链表存储。结点结构为: 其中叶结点的weight域保存该结点的非负权值。设root为指向T的根结点的指针,请设计求T的WPL的算法…...

Hive数仓操作(十五)

Hive 开窗函数 Hive窗口函数是一种特殊的函数,允许用户在查询中对一组行进行计算,而不仅仅是单独的行。窗口函数可以在 SQL 查询中进行聚合、排名、累积计算等。这使得窗口函数在数据分析和报告生成中非常有用。 窗口函数的基本组成部分 函数类型&…...

No.12 笔记 | 网络基础:ARP DNS TCP/IP与OSI模型

一、计算机网络:安全的基石 1. 网络的本质:数字世界的神经系统 定义:计算机的互联互通,实现资源共享和信息交换组成要素:发送者、接收者、介质、数据、协议(五大要素) 2. 网络架构&#xff1…...

OpenHarmony(鸿蒙南向开发)——轻量系统STM32F407芯片移植案例

往期知识点记录: 鸿蒙(HarmonyOS)应用层开发(北向)知识点汇总 鸿蒙(OpenHarmony)南向开发保姆级知识点汇总~ 持续更新中…… 介绍基于STM32F407IGT6芯片在拓维信息 Niobe407 开发板上移植OpenH…...

简单易懂的springboot整合Camunda 7工作流入门教程

简单易懂的Spring Boot整合Camunda7入门教程 因为关于Spring Boot结合Camunda7的教程在网上比较少,而且很多都写得有点乱,很多概念写得太散乱,讲解不清晰,导致看不懂,本人通过研究学习之后就写出了这篇教学文档。 介…...

LabVIEW提高开发效率技巧----点阵图(XY Graph)

在LabVIEW开发中,点阵图(XY Graph) 是一种强大的工具,尤其适用于需要实时展示大量数据的场景。通过使用点阵图,开发人员能够将实时数据可视化,帮助用户更直观地分析数据变化。 1. 点阵图的优势 点阵图&…...

C++-匿名空间

匿名命名空间(anonymous namespace)是 C 中的一种特性,用于将符号(如变量、函数或类)限制在定义它们的源文件的作用域内。这意味着在该源文件外部,这些符号不可见,从而避免了命名冲突。 1. 定义…...

jdk的安装和环境变量配置

1.将从官网下载好的jdk放在自己想要放的位置,这里的位置是:E:\develop 2.新建一个文件夹用来放安装的jdk,将jdk安装的此目录,这里的位置是:E:\develop\jdk17 3.jdk安装好之后,点击jdk17目录,点…...

继承、Lambda、Objective-C和Swift

继承 东风系列导弹是镇国神器。东风41不是突然就造出来的,之前有很多种东风xx导弹,每种导弹都有自己的独特之处,相同之处都具备导弹基本特点。很多工厂有量产磨具的生产线,盖房子就图纸,建筑设计建设都有参考&#xff…...

设置服务器走本地代理

勾选: 然后: git clone https://github.com/rofl0r/proxychains-ng.git./configure --prefix/home/wangguisen/usr --sysconfdir/home/wangguisen/etcmakemake install# 在最后配置成本地代理地址 vim /home/wangguisen/etc/proxychains.confsocks4 17…...

刷题 -哈希

面试面试经典 150 题 - 哈希 383. 赎金信 - 一个哈希表搞定 class Solution { public:bool canConstruct(string ransomNote, string magazine) {int hash[26] {0};for (auto& ch : magazine) {hash[ch - a];}for (auto& ch : ransomNote) {if (--hash[ch - a] < …...

深入解析NAND Flash基础操作与系统集成——从阵列结构到多Die协同

1. NAND Flash基础结构与工作原理 NAND Flash存储器是现代存储系统的核心组件&#xff0c;从U盘到企业级SSD都依赖这项技术。要理解它的强大之处&#xff0c;得先从它的物理结构说起——想象一个巨大的立体停车场&#xff0c;每个停车位就是一个存储单元&#xff0c;而控制电路…...

FanControl风扇控制软件:从噪音困扰到静音享受的完整指南

FanControl风扇控制软件&#xff1a;从噪音困扰到静音享受的完整指南 【免费下载链接】FanControl.Releases This is the release repository for Fan Control, a highly customizable fan controlling software for Windows. 项目地址: https://gitcode.com/GitHub_Trending…...

SketchUp STL插件:从数字设计到3D打印的无缝桥梁

SketchUp STL插件&#xff1a;从数字设计到3D打印的无缝桥梁 【免费下载链接】sketchup-stl A SketchUp Ruby Extension that adds STL (STereoLithography) file format import and export. 项目地址: https://gitcode.com/gh_mirrors/sk/sketchup-stl SketchUp STL插件…...

Unity URDF导入终极指南:3步快速实现机器人仿真

Unity URDF导入终极指南&#xff1a;3步快速实现机器人仿真 【免费下载链接】URDF-Importer URDF importer 项目地址: https://gitcode.com/gh_mirrors/ur/URDF-Importer Unity URDF Importer是Unity Robotics官方推出的机器人模型导入工具&#xff0c;它能够让你在Unit…...

Miniconda环境迁移实战:如何将CentOS装好的Python环境打包到其他服务器?

Miniconda环境迁移实战&#xff1a;跨服务器Python环境无缝转移指南 当你在CentOS服务器上精心配置了一个完美的Python数据分析环境&#xff0c;却需要在另一台服务器上复现时&#xff0c;难道要重新经历一遍繁琐的安装过程&#xff1f;本文将揭示两种高效可靠的Miniconda环境迁…...

SWF逆向工程标准化文档:JPEXS Free Flash Decompiler实施指南

SWF逆向工程标准化文档&#xff1a;JPEXS Free Flash Decompiler实施指南 【免费下载链接】jpexs-decompiler JPEXS Free Flash Decompiler 项目地址: https://gitcode.com/gh_mirrors/jp/jpexs-decompiler JPEXS Free Flash Decompiler是一款强大的SWF逆向工程工具&…...

LTI系统设计避坑指南:因果性与稳定性在实际工程中的5个关键检查点

LTI系统设计避坑指南&#xff1a;因果性与稳定性在实际工程中的5个关键检查点 在数字信号处理领域&#xff0c;线性时不变&#xff08;LTI&#xff09;系统的设计是工程师日常工作的核心。然而&#xff0c;理论推导与工程实践之间往往存在一道鸿沟——许多在数学上完美的系统模…...

Delphi 终极实战:将自定义控件打包成 BPL,安装到 Delphi 工具栏(组件库实战)

前面我们手写了专属 UI 组件库&#xff08;MyUIClass.pas&#xff09;&#xff0c;但如果你想在以后的项目中一键调用这些控件&#xff0c;而不是每次都复制粘贴代码&#xff0c;那就必须将它们打包成 Delphi 组件包&#xff08;BPL 文件&#xff09;。学会这篇&#xff0c;你将…...

5分钟搞定局域网IP扫描:OpUtils保姆级配置教程(附常见问题排查)

5分钟搞定局域网IP扫描&#xff1a;OpUtils保姆级配置教程&#xff08;附常见问题排查&#xff09; 办公室里突然断网了&#xff1f;打印机死活连不上&#xff1f;新同事的电脑无法接入内网&#xff1f;作为中小企业IT运维人员&#xff0c;这些场景你一定不陌生。别急着打电话求…...

KMS_VL_ALL_AIO激活工具完全指南:从问题诊断到长效管理

KMS_VL_ALL_AIO激活工具完全指南&#xff1a;从问题诊断到长效管理 【免费下载链接】KMS_VL_ALL_AIO Smart Activation Script 项目地址: https://gitcode.com/gh_mirrors/km/KMS_VL_ALL_AIO 如何诊断Windows/Office激活失败的核心原因&#xff1f; 1.1 激活失败的三大…...