当前位置: 首页 > article >正文

Spider 数据集上实现nlp2sql训练任务

NLP2SQL(自然语言处理到 SQL 查询的转换)是一个重要的自然语言处理(NLP)任务,其目标是将用户的自然语言问题转换为相应的 SQL 查询。这一任务在许多场景下具有广泛的应用,尤其是在与数据库交互的场景中,例如数据分析、业务智能和问答系统。

任务目标
  • 理解自然语言: 理解用户输入的自然语言问题,包括意图、实体和上下文。
  • 生成 SQL 查询: 将理解后的信息转换为正确的 SQL 查询,以从数据库中检索所需的数据。

例如

输入: 用户的自然语言问题,“获取 Gelderland 区的总人口。”

输出: 对应的 SQL 查询,SELECT population FROM districts WHERE name = 'Gelderland';

Spider 是一个难度最大数据集

耶鲁大学在2018年新提出的一个大规模的NL2SQL(Text-to-SQL)数据集。
该数据集包含了10,181条自然语言问句、分布在200个独立数据库中的5,693条SQL,内容覆盖了138个不同的领域。
涉及的SQL语法最全面,是目前难度最大的NL2SQL数据集。

下载查看spider数据集内容

Question 1: How many singers do we have ? ||| concert_singer
SQL: select count(*) from singer

Question 2: What is the total number of singers ? ||| concert_singer
SQL: select count(*) from singer

Question 3: Show name , country , age for all singers ordered by age from the oldest to the youngest . ||| concert_singer
SQL: select name , country , age from singer order by age desc

...

首先需要转换为Spider的标准格式(参考tables.jsontrain.json):

{"db_id": "concert_singer","question": "Show name, country, age...","query": "SELECT name, country, age FROM singer ORDER BY age DESC","schema": {"table_names": ["singer"],"column_names": [[0, "name", "text"],[0, "country", "text"],[0, "age", "int"]]}
}

拆分为table.json的原因可能涉及到数据组织和重用。每个数据库的结构(表、列、外键)在多个问题中都会被重复使用。如果每个问题都附带完整的schema信息,会导致数据冗余,增加存储和处理的开销。所以,将schema单独存储为table.json,可以让不同的数据条目引用同一个数据库模式,减少重复数据。拆分后的结构需要更高效的数据管理,例如在训练模型时,根据每个问题的db_id去table.json中查找对应的schema信息。这样做的好处是当多个问题属于同一个数据库时,不需要每次都重复加载schema提高了效率。

column_names 表示数据库表中每一列的详细信息。具体来说,column_names 是一个列表,其中每个元素都是一个包含三个部分的子列表:

  1. 表索引(0):表示该列属于哪个表。在这个例子中,所有列都属于第一个表(索引为 0)。
  2. 列名("name"、"country"、"age"):表示列的名称。
  3. 数据类型("text"、"int"):表示该列的数据类型,例如文本(text)或整数(int)。

实现下面逻辑转换原始数据

def extract_columns_from_sql(sql):# 使用正则表达式匹配 SELECT 语句中的列名match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)if match:# 提取列名columns = match.group(1).split(",")# 构建 column_names 列表column_names = []for index, column in enumerate(columns):column = column.strip()  # 去除多余的空格data_type = "text"  # 默认数据类型为 text,可以根据需要修改# 添加到 column_names 列表,假设所有列类型为 textcolumn_names.append([0, column, data_type])return column_namesreturn []# 从 dev.sql 文件读取数据
def load_sql_data(file_path):data_list = []with open(file_path, 'r', encoding='utf-8') as f:  # 指定编码为 UTF-8lines = f.readlines()for i in range(0, len(lines), 3):  # 每三行一组question_line = lines[i].strip()sql_line = lines[i + 1].strip()if not question_line or not sql_line:continue# 提取问题和 SQLquestion = question_line.split(': ', 1)[1].strip()  # 获取问题内容sql = sql_line.split(': ', 1)[1].strip()  # 获取 SQL 查询# 提取表名db_id = question_line.split('|||')[-1].strip()  # 从问题行获取表名question = question.split('|||')[0].strip()target_sql = preprocess(question, db_id, sql)data_list.append({"input_text": f"Translate to SQL: {question} [SEP] Tables: {db_id}","target_sql": json.dumps(target_sql)  # 将目标 SQL 转换为 JSON 格式字符串})return data_list

选择Tokenizer.from_pretrained("t5-base") 是用于加载 T5(Text-to-Text Transfer Transformer)模型的分词器。T5 是一个强大的自然语言处理模型,能够处理各种文本任务(如翻译、摘要、问答等),并且将所有任务视为文本到文本的转换。

from transformers import T5Tokenizertokenizer = T5Tokenizer.from_pretrained("t5-base")def preprocess(question, db_id, sql):# 提取列名column_names = extract_columns_from_sql(sql)# 构建目标格式target_sql = {"db_id": db_id,"question": question,"query": sql,"schema": {"table_names": [db_id],"column_names": column_names}}return target_sql# 示例数据
question = "Show name, country, age for all singers ordered by age from the oldest to the youngest."
schema = "singer(name, country, age)"
sql = "SELECT name, country, age FROM singer ORDER BY age DESC"input_text, target_sql = preprocess(question, schema, sql)
# input_text = "Translate to SQL: Show name... [SEP] Tables: singer(name, country, age)"
# target_sql = "select name, country, age from singer order by age desc"
print('input_text', input_text)
print('target_sql', target_sql)

所有nlp任务都涉及的需要token化,使用t5-base 做tokenize

def tokenize_function(examples):model_inputs = tokenizer(examples["input_text"],max_length=512,truncation=True,padding="max_length")with tokenizer.as_target_tokenizer():labels = tokenizer(examples["target_sql"],max_length=512,truncation=True,padding="max_length")model_inputs["labels"] = labels["input_ids"]return model_inputs

使用 tokenizer.as_target_tokenizer() 上下文管理器,确保目标文本(即 SQL 查询)被正确处理。目标文本也经过编码,转换为 token IDs,并同样进行填充和截断。将目标文本的编码结果(token IDs)存储在 model_inputs["labels"] 中。这是模型在训练时需要的输出,用于计算损失。最终返回一个字典 model_inputs,它包含了模型的输入和对应的标签。这种结构使得模型在训练时可以直接使用。

最后组织下训练代码

tokenized_datasets = dataset.map(tokenize_function, batched=True)# 加载模型
model = T5ForConditionalGeneration.from_pretrained("t5-base")# 训练参数
training_args = Seq2SeqTrainingArguments(output_dir="./results",evaluation_strategy="epoch",learning_rate=3e-5,per_device_train_batch_size=8,per_device_eval_batch_size=8,num_train_epochs=100,predict_with_generate=True,run_name="spider"
)# 开始训练
trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=tokenized_datasets["train"] if 'train' in tokenized_datasets else tokenized_datasets,eval_dataset=tokenized_datasets["test"] if 'test' in tokenized_datasets else None,data_collator=DataCollatorForSeq2Seq(tokenizer)
)trainer.train()

这里使用的是Seq2SeqTrainer, 它是 Hugging Face 的 transformers 库中用于序列到序列(Seq2Seq)任务的训练器。它为处理诸如翻译、文本生成和问答等任务提供了一个高层次的接口,简化了训练过程。以下是 Seq2SeqTrainer 的主要功能和特点:

  1. 简化训练流程Seq2SeqTrainer 封装了许多常见的训练步骤,如数据加载、模型训练、评估和预测,使得用户可以更专注于模型和数据,而不必处理繁琐的训练细节。

  2. 支持多种训练参数: 通过 Seq2SeqTrainingArguments 类,可以灵活配置训练参数,如学习率、批量大小、训练轮数、评估策略等。

  3. 自动处理填充和截断: 在处理输入和输出序列时,Seq2SeqTrainer 可以自动填充和截断序列,以确保它们适应模型的输入要求。

  4. 集成评估和监控: 支持在训练过程中进行模型评估,并可以根据评估指标(如损失)监控训练进度。用户可以设置评估频率和评估数据集

开始训练,进行100次epoch

训练监控在 Weights & Biases ,Seq2SeqTrainer 能够向 Weights & Biases (wandb) 传输训练监控数据,主要是因为它内置了与 wandb 的集成。以下是一些关键点,解释了这一过程:

  1. 自动集成:当你使用 Seq2SeqTrainer 时,它会自动检测 wandb 的安装并在初始化时配置相关设置。这意味着你无需手动设置 wandb。

  2. 回调功能Trainer 类提供了回调功能,可以在训练过程中记录各种指标(如损失、准确率等)。这些指标会被自动发送到 wandb。

  3. 配置管理training_args 中的参数可以指定 wandb 的项目名称、运行名称等,从而更好地组织和管理实验。

  4. 训练循环:在每个训练和评估周期结束时,Trainer 会调用相应的回调函数,将重要的训练信息(如损失、学习率等)记录到 wandb。

  5. 可视化:通过 wandb,你可以实时监控训练过程,包括损失曲线、模型性能等,帮助你更好地理解模型的训练动态。

多次试验还可以比较训练性能

训练结束, 损失收敛到0.05410315271151268

{'eval_loss': 0.008576861582696438, 'eval_runtime': 1.3883, 'eval_samples_per_second': 74.912, 'eval_steps_per_second': 5.042, 'epoch': 100.0}
{'train_runtime': 2914.0548, 'train_samples_per_second': 31.914, 'train_steps_per_second': 2.025, 'train_loss': 0.05410315271151268, 'epoch': 100.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5900/5900 [48:31<00:00,  2.03it/s]
wandb:
wandb: 🚀 View run spider at: https://wandb.ai/chenruithinking-4th-paradigm/huggingface/runs/dkccvpp4
wandb: Find logs at: wandb/run-20250207_112702-dkccvpp4/logs

测试下预测能力

import os
from transformers import T5Tokenizer, T5ForConditionalGeneration# 设置 NCCL 环境变量
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"# 加载分词器
tokenizer = T5Tokenizer.from_pretrained("t5-base")model = T5ForConditionalGeneration.from_pretrained("./results/t5-sql-model")
tokenizer.save_pretrained("./results/t5-sql-model")def generate_sql(question, db_id):input_text = f"Translate to SQL: {question} [SEP] Tables: {db_id}"input_ids = tokenizer.encode(input_text, return_tensors="pt")  # 使▒~T▒ PyTorch ▒~Z~D▒| ▒~G~O▒| ▒▒~Ooutput = model.generate(input_ids,max_length=512,num_beams=5,  # 或者尝试其他解码策略early_stopping=True)print('output', output)generated_sql = tokenizer.decode(output[0], skip_special_tokens=True)return generated_sqlquestion = "How many singers do we have ?"
db_id = "concert_singer"
evaluation_output = generate_sql(question, db_id)
print("evaluation_output:", evaluation_output)

输出结果

evaluation_output: "db_id": "concert_singer", "question": "How many singers do we have ?", "query": "select count(*) from singer", "schema": "table_names": ["concert_singer"], "column_names": [[0, "count(*)", "text"]]

相关文章:

Spider 数据集上实现nlp2sql训练任务

NLP2SQL&#xff08;自然语言处理到 SQL 查询的转换&#xff09;是一个重要的自然语言处理&#xff08;NLP&#xff09;任务&#xff0c;其目标是将用户的自然语言问题转换为相应的 SQL 查询。这一任务在许多场景下具有广泛的应用&#xff0c;尤其是在与数据库交互的场景中&…...

数据结构——【树模板】

#思路 1、 结点类&#xff1a; 属性&#xff1a;数据&#xff0c;孩子结点列表 功能1&#xff1a;认孩子&#xff1a; 前提&#xff1a;在父子都是结点的情况下 2. 树类&#xff1a; 属性&#xff1a;根节点&#xff0c;生成初始化的总结点 功能1&#xff1a;获取初始化…...

R 数组:高效数据处理的基础

R 数组&#xff1a;高效数据处理的基础 引言 在数据科学和统计分析领域&#xff0c;R 语言以其强大的数据处理和分析能力而备受推崇。R 数组是 R 语言中用于存储和操作数据的基本数据结构。本文将详细介绍 R 数组的创建、操作和优化&#xff0c;帮助读者掌握 R 数组的使用技巧…...

【DeepSeek】DeepSeek概述 | 本地部署deepseek

目录 1 -> 概述 1.1 -> 技术特点 1.2 -> 模型发布 1.3 -> 应用领域 1.4 -> 优势与影响 2 -> 本地部署 2.1 -> 安装ollama 2.2 -> 部署deepseek-r1模型 1 -> 概述 DeepSeek是由中国的深度求索公司开发的一系列人工智能模型&#xff0c;以其…...

npm link,lerna,pnmp workspace区别

npm link、Lerna 和 pnpm workspace 是三种不同的工具/功能&#xff0c;用于处理 JavaScript 项目的依赖管理和 Monorepo 场景。它们的核心区别如下&#xff1a; 1. npm link 用途 本地调试依赖&#xff1a;将本地开发的包&#xff08;Package A&#xff09;临时链接到另一个…...

ASP.NET Core 使用 WebClient 从 URL 下载

本文使用 ASP .NET Core 3.1&#xff0c;但它在.NET 5、 .NET 6和.NET 8上也同样适用。如果使用较旧的.NET Framework&#xff0c;请参阅本文&#xff0c;不过&#xff0c;变化不大。 如果想要从 URL 下载任何数据类型&#xff0c;请参阅本文&#xff1a;HttpClient 使用WebC…...

【CubeMX-HAL库】STM32F407—无刷电机学习笔记

目录 简介&#xff1a; 学习资料&#xff1a; 跳转目录&#xff1a; 一、工程创建 二、板载LED 三、用户按键 四、蜂鸣器 1.完整IO控制代码 五、TFT彩屏驱动 六、ADC多通道 1.通道确认 2.CubeMX配置 ①开启对应的ADC通道 ②选择规则组通道 ③开启DMA ④开启ADC…...

vue3 点击图标从相册选择二维码图片,并使用jsqr解析二维码(含crypto-js加密解密过程)

vue3 点击图标从相册选择二维码图片&#xff0c;并使用jsqr解析二维码&#xff08;含crypto-js加密解密过程&#xff09; 1.安装 jsqr 和 crypto-js npm install -d jsqr npm install crypto-js2.在util目录下新建encryptionHelper.js文件&#xff0c;写加密解密方法。 // e…...

kafka 3.5.0 raft协议安装

前言 最近做项目&#xff0c;需要使用kafka进行通信&#xff0c;且只能使用kafka&#xff0c;笔者没有测试集群&#xff0c;就自己搭建了kafka集群&#xff0c;实际上笔者在很早之前就搭建了&#xff0c;因为当时还是zookeeper&#xff08;简称ZK&#xff09;注册元数据&#…...

用Kibana实现Elasticsearch索引的增删改查:实战指南

在大数据时代&#xff0c;Elasticsearch&#xff08;简称 ES&#xff09;和 Kibana 作为强大的数据搜索与可视化工具&#xff0c;受到了众多开发者的青睐。Kibana 提供了一个直观的界面&#xff0c;可以方便地对 Elasticsearch 中的数据进行操作。本文将详细介绍如何使用 Kiban…...

Redis基础笔记

一、基础知识 连接方式 CLI (Command Line Interface)API (Application Programming Interface)GUI (Graphical User Interface) 启动 redis-server连接到Redis&#xff08;Redis CLI Client&#xff09; redis redis-cli telnet 127.0.0.1 6379退出 quit/exit查看过期时…...

前后端服务配置

1、安装虚拟机&#xff08;VirtualBox或者vmware&#xff09;&#xff0c;在虚拟机上配置centos(选择你需要的Linux版本)&#xff0c;配置如nginx服务器等 1.1 VMware 下载路径Sign In注册下载 1.2 VirtualBox 下载路径https://www.virtualbox.org/wiki/Downloads 2、配置服…...

springboot 事务管理

在Spring Boot中&#xff0c;事务管理是通过Spring框架的事务管理模块来实现的。Spring提供了声明式事务管理和编程式事务管理两种方式。通常&#xff0c;我们使用声明式事务管理&#xff0c;因为它更简洁且易于维护。 1. 声明式事务管理 声明式事务管理是通过注解来实现的。…...

基于Typescript,使用Vite构建融合Vue.js的Babylon.js开发环境

一、创建Vite项目 使用Vite初始化一个VueTypeScript项目&#xff1a; npm create vitelatest my-babylon-app -- --template vue-ts cd my-babylon-app npm create vitelatest my-babylon-app -- --template vue-ts 命令用于快速创建一个基于 Vite 的 Vue TypeScript 项目。…...

在阿里云ECS上一键部署DeepSeek-R1

DeepSeek-R1 是一款开源模型&#xff0c;也提供了 API(接口)调用方式。据 DeepSeek介绍&#xff0c;DeepSeek-R1 后训练阶段大规模使用了强化学习技术&#xff0c;在只有极少标注数据的情况下提升了模型推理能力&#xff0c;该模型性能对标 OpenAl o1 正式版。DeepSeek-R1 推出…...

git SourceTree 使用

Source Tree 使用原理 文件的状态 创建仓库和提交 验证 再克隆的时候发发现一个问题&#xff0c;就是有一个 这个验证&#xff0c;起始很简单 就是 gitee 的账号和密码&#xff0c;但是要搞清楚的是账号不是名称&#xff0c;我之前一直再使用名称登录老是出问题 这个很简单的…...

游戏引擎学习第94天

仓库:https://gitee.com/mrxiao_com/2d_game_2 回顾上周的渲染器工作 完成一款游戏的开发&#xff0c;完全不依赖任何库和引擎&#xff0c;这样我们能够全面掌握游戏的开发过程&#xff0c;确保没有任何细节被隐藏。我们将深入探索每一个环节&#xff0c;犹如拿着手电筒翻看床…...

win32汇编环境,结构体的使用示例二

;运行效果 ;win32汇编环境,结构体的使用示例二 ;举例说明结构体的定义&#xff0c;如何访问其中的成员&#xff0c;使用assume指令指向某个结构体&#xff0c;计算结构数组所需的偏移量得到某个成员值等 ;直接抄进RadAsm可编译运行。重要部分加备注。 ;下面为asm文件 ;>>…...

DeepSeek从入门到精通教程PDF清华大学出版

DeepSeek爆火以来&#xff0c;各种应用方式层出不穷&#xff0c;对于很多人来说&#xff0c;还是特别模糊&#xff0c;有种雾里看花水中望月的感觉。 最近&#xff0c;清华大学新闻与传播学院新媒体研究中心&#xff0c;推出了一篇DeepSeek的使用教程&#xff0c;从最基础的是…...

【PDF提取内容】如何批量提取PDF里面的文字内容,把内容到处表格或者批量给PDF文件改名,基于C++的实现方案和步骤

以下分别介绍基于 C 批量提取 PDF 里文字内容并导出到表格&#xff0c;以及批量给 PDF 文件改名的实现方案、步骤和应用场景。 批量提取 PDF 文字内容并导出到表格 应用场景 文档数据整理&#xff1a;在处理大量学术论文、报告等 PDF 文档时&#xff0c;需要提取其中的关键信…...

SSA-TCN麻雀算法优化时间卷积神经网络时间序列预测未来Matlab实现

SSA-TCN麻雀算法优化时间卷积神经网络时间序列预测未来Matlab实现 目录 SSA-TCN麻雀算法优化时间卷积神经网络时间序列预测未来Matlab实现预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现SSA-TCN麻雀算法优化时间卷积神经网络时间序列预测未来&#xff08;优…...

大模型推理——MLA实现方案

1.整体流程 先上一张图来整体理解下MLA的计算过程 2.实现代码 import math import torch import torch.nn as nn# rms归一化 class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps1e-6):super().__init__()self.weight nn.Pa…...

深度学习-神经机器翻译模型

以下为你介绍使用Python和深度学习框架Keras&#xff08;基于TensorFlow后端&#xff09;实现一个简单的神经机器翻译模型的详细步骤和代码示例&#xff0c;该示例主要处理英 - 法翻译任务。 1. 安装必要的库 首先&#xff0c;确保你已经安装了以下库&#xff1a; pip insta…...

Android Camera API 介绍

一 StreamConfigurationMap 1. StreamConfigurationMap 的作用 StreamConfigurationMap 是 Android Camera2 API 中的一个核心类&#xff0c;用于描述相机设备支持的输出流配置&#xff0c;包含以下信息&#xff1a; 支持的格式与分辨率&#xff1a;例如 YUV_420_888、JPEG、…...

大数据项目2:基于hadoop的电影推荐和分析系统设计和实现

前言 大数据项目源码资料说明&#xff1a; 大数据项目资料来自我多年工作中的开发积累与沉淀。 我分享的每个项目都有完整代码、数据、文档、效果图、部署文档及讲解视频。 可用于毕设、课设、学习、工作或者二次开发等&#xff0c;极大提升效率&#xff01; 1、项目目标 本…...

Windows逆向工程入门之汇编环境搭建

公开视频 -> 链接点击跳转公开课程博客首页 -> ​​​链接点击跳转博客主页 Visual Studio逆向工程配置 基础环境搭建 Visual Studio 官方下载地址安装配置选项(后期可随时通过VS调整) 使用C的桌面开发 拓展可选选项 MASM汇编框架 配置MASM汇编项目 创建新项目 选择空…...

gc buffer busy acquire导致的重大数据库性能故障

&#x1f4e2;&#x1f4e2;&#x1f4e2;&#x1f4e3;&#x1f4e3;&#x1f4e3; 作者&#xff1a;IT邦德 中国DBA联盟(ACDU)成员&#xff0c;10余年DBA工作经验 Oracle、PostgreSQL ACE CSDN博客专家及B站知名UP主&#xff0c;全网粉丝10万 擅长主流Oracle、MySQL、PG、高斯…...

前端学习-页面加载事件和页面滚动事件(三十二)

目录 前言 页面加载事件和页面滚动事件 页面加载事件 load事件 语法 注意 DOMContentLoaded事件 语法 总结 页面加载事件有哪两个?如何添加? load 事件 DOMContentLoaded事件 页面滚动事件 存在原因 scroll监听整个页面滚动 页面滚动事件-获取位置 scrollLef…...

C++:将函数参数定义为const T的意义

C++很多函数的参数都会定义为const T&,那么这么做的意义是什么呢? 避免拷贝:通过引用传递参数而不是值传递,可以避免对象的拷贝,从而提高性能,特别是当对象较大时。 保护数据:使用const关键字可以防止函数修改传入的参数,确保数据的安全性和一致性。 对于保护数据这…...

Formily 如何进行表单验证

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…...