当前位置: 首页 > news >正文

使用无标注的数据训练Bert

文章目录

  • 1、准备用于训练的数据集
  • 2、处理数据集
  • 3、克隆代码
  • 4、运行代码
  • 5、将ckpt模型转为bin模型使其可在pytorch中运用

Bert官方仓库:https://github.com/google-research/bert

1、准备用于训练的数据集

此处准备的是BBC news的数据集,下载链接:https://www.kaggle.com/datasets/gpreda/bbc-news
原数据集格式(.csv):
在这里插入图片描述

2、处理数据集

训练Bert时需要预处理数据,将数据处理成https://github.com/google-research/bert/blob/master/sample_text.txt中所示格式,如下所示:
在这里插入图片描述
数据预处理代码参考:

import pandas as pd# 读取BBC-news数据集
df = pd.read_csv("../../bbc_news.csv")
# print(df['title'])
l1 = []
l2 = []
cnt = 0
for line in df['title']:l1.append(line)for line in df['description']:l2.append(line)
# cnt=0
f = open("test1.txt", 'w+', encoding='utf8')
for i in range(len(l1)):s = l1[i] + " " + l2[i] + '\n'f.write(s)# cnt+=1# if cnt>10: break
f.close()
# print(l1)

处理完后的BBC news数据集格式如下所示:
在这里插入图片描述

3、克隆代码

使用git克隆仓库代码
http:

git clone https://github.com/google-research/bert.git

或ssh:

git clone git@github.com:google-research/bert.git

4、运行代码

先下载Bert模型:BERT-Base, Uncased
该文件中有以下文件:
在这里插入图片描述
运行代码:
在Teminal中运行:

python create_pretraining_data.py \--input_file=./sample_text.txt(数据集地址) \--output_file=/tmp/tf_examples.tfrecord(处理后数据集保存的位置) \--vocab_file=$BERT_BASE_DIR/vocab.txt(vocab.txt文件位置) \--do_lower_case=True \--max_seq_length=128 \--max_predictions_per_seq=20 \--masked_lm_prob=0.15 \--random_seed=12345 \--dupe_factor=5

训练模型:

python run_pretraining.py \--input_file=/tmp/tf_examples.tfrecord(处理后数据集保存的位置) \--output_dir=/tmp/pretraining_output(训练后模型保存位置) \--do_train=True \--do_eval=True \--bert_config_file=$BERT_BASE_DIR/bert_config.json(bert_config.json文件位置) \--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt(如果要从头开始的预训练,则去掉这行) \--train_batch_size=32 \--max_seq_length=128 \--max_predictions_per_seq=20 \--num_train_steps=20 \--num_warmup_steps=10 \--learning_rate=2e-5

训练完成后模型输出示例:

***** Eval results *****global_step = 20loss = 0.0979674masked_lm_accuracy = 0.985479masked_lm_loss = 0.0979328next_sentence_accuracy = 1.0next_sentence_loss = 3.45724e-05

要注意应该能够在至少具有 12GB RAM 的 GPU 上运行,不然会报错显存不足。
使用未标注数据训练BERT

5、将ckpt模型转为bin模型使其可在pytorch中运用

上一步训练好后准备好训练出来的model.ckpt-20.index文件和Bert模型中的bert_config.json文件

创建python文件convert_bert_original_tf_checkpoint_to_pytorch.py:

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""import argparseimport torchfrom transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logginglogging.set_verbosity_info()def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):# Initialise PyTorch modelconfig = BertConfig.from_json_file(bert_config_file)print("Building PyTorch model from configuration: {}".format(str(config)))model = BertForPreTraining(config)# Load weights from tf checkpointload_tf_weights_in_bert(model, config, tf_checkpoint_path)# Save pytorch-modelprint("Save PyTorch model to {}".format(pytorch_dump_path))torch.save(model.state_dict(), pytorch_dump_path)if __name__ == "__main__":parser = argparse.ArgumentParser()# Required parametersparser.add_argument("--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path.")parser.add_argument("--bert_config_file",default=None,type=str,required=True,help="The config json file corresponding to the pre-trained BERT model. \n""This specifies the model architecture.",)parser.add_argument("--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model.")args = parser.parse_args()convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

在Terminal中运行以下命令:

python convert_bert_original_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path Models/chinese_L-12_H-768_A-12/bert_model.ckpt.index(.ckpt.index文件位置) \
--bert_config_file Models/chinese_L-12_H-768_A-12/bert_config.json(bert_config.json文件位置)  \
--pytorch_dump_path  Models/chinese_L-12_H-768_A-12/pytorch_model.bin(输出的.bin模型文件位置)

以上命令最好在一行中运行:

python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt.index --bert_config_file bert_config.json  --pytorch_dump_path  pytorch_model.bin

然后就可以得到bin文件了
在这里插入图片描述

【BERT for Tensorflow】本地ckpt文件的BERT使用

相关文章:

使用无标注的数据训练Bert

文章目录 1、准备用于训练的数据集2、处理数据集3、克隆代码4、运行代码5、将ckpt模型转为bin模型使其可在pytorch中运用 Bert官方仓库:https://github.com/google-research/bert 1、准备用于训练的数据集 此处准备的是BBC news的数据集,下载链接&…...

《Netty》从零开始学netty源码(五十二)之PoolThreadCache

PoolThreadCache Netty有一个大的公共内存容器PoolArena,用来管理从操作系统中获得的内存,在高并发下如果所有线程都去这个大容器获取内存它的压力是非常大的,所以Netty为每个线程建立了一个本地缓存,即PoolThreadCache&#xff…...

放弃40k月薪的程序员工作,选择公务员,我来分享一下看法

我有一个朋友,拒绝了我为他提供的4万薪水的工作,去了一个体制内的银行,做程序员,即使薪水减半。他之前在北京一家大公司做程序员,一个月30k。当我开始创业时,我拉他来和我一起干,但那时我们太小…...

【MybatisPlus】高级版可视化、可配置 自动生成代码

今天看别人使用了一个更加智能的生成代码工具,可视化、可配置策略,非常方便,配置一次,在哪都可以使用,也不会跟项目藕合下面简单说一下使用方式。 1、介绍mybatis-plus-generator-ui 主要是封装了mybatis-plus-gener…...

【图像分割】【深度学习】Windows10下f-BRS官方代码Pytorch实现

【图像分割】【深度学习】Windows10下f-BRS官方代码Pytorch实现 提示:最近开始在【图像分割】方面进行研究,记录相关知识点,分享学习中遇到的问题已经解决的方法。 文章目录 【图像分割】【深度学习】Windows10下f-BRS官方代码Pytorch实现前言f-BRS模型运行环境安装1.下载源码并…...

2023/5/4总结

刷题&#xff1a; 第二周任务 - Virtual Judge (vjudge.net) 这一题用到了素筛,然后穷举即可 #include<stdio.h> #define Maxsize 500000 int a[Maxsize]; long long b[Maxsize]; long long max0; int sushu() {a[0]a[1]0;int i,j,k;for(i2,k0;i<Maxsize;i){if(a[i…...

electron+vue3全家桶+vite项目搭建【17】pinia状态持久化

文章目录 引入问题演示实现效果展示、实现步骤1.封装状态初始化函数2.封装状态更新同步函数3.完整代码 引入 上一篇文章我们已经实现了electron多窗口中&#xff0c;pinia的状态同步&#xff0c;但你会发现&#xff0c;如果我们在一个窗口里面修改了状态&#xff0c;然后再打开…...

java基础入门-05-【面向对象进阶(static继承)】

Java基础入门-05-【面向对象进阶&#xff08;static&继承&#xff09;】 13、面向对象进阶&#xff08;static&继承&#xff09;1.1 如何定义类1.2 如何通过类创建对象1.3 封装1.3.1 封装的步骤1.3.2 封装的步骤实现 1.4 构造方法1.4.1 构造方法的作用1.4.2 构造方法的…...

day12 IP协议与ethernet协议

目录 IP包头 IP网的意义 IP数据报的格式 IP数据报分片 以太网包头&#xff08;链路层协议&#xff09; IP包头 IP网的意义 当互联网上的主机进行通信时&#xff0c;就好像在一个网络上通信一样&#xff0c;看不见互联的各具体的网络异构细节&#xff1b; 如果在这种覆盖…...

蓝牙耳机哪款性价比高?2023蓝牙耳机性价比排行

随着蓝牙耳机的使用愈发频繁&#xff0c;蓝牙耳机产品也越来越多&#xff0c;蓝牙耳机的功能、价格、外观设计等都不尽相同。接下来&#xff0c;我来给大家推荐几款性价比高的蓝牙耳机&#xff0c;感兴趣的朋友一起来看看吧。 一、南卡小音舱Lite2蓝牙耳机 参考价&#xff1a…...

关于C语言的一些笔记

文章目录 May4,2023常量问题基本数据类型补码printf的字符格式控制关于异或、异或的理解赋值运算i和i的区别关系运算符 &#xff2d;ay5,2023逻辑运算中‘非’的理解逗号运算运算符的优先级问题三目运算 摘自加工于C技能树 May4,2023 常量问题 //定义常量 const float PI; PI…...

【Python入门知识】NumPy数组迭代及连接

前言 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! 数组迭代 迭代意味着逐一遍历元素&#xff0c;当我们在 numpy 中处理多维数组时&#xff0c; 可以使用 python 的基本 for 循环来完成此操作。 如果我们对 1-D 数组进行迭代&#xff0c;它将逐一遍历每个元素。 实例 迭…...

我们公司的面试,有点不一样!

我们公司的面试&#xff0c;有点不一样&#xff01; 朋友们周末愉快&#xff0c;我是鱼皮。因为我很屑&#xff0c;所以大家也可以叫我屑老板。 自从我发了自己创业的文章和视频后&#xff0c;收到了很多小伙伴们的祝福&#xff0c;真心非常感谢&#xff01; 不得不说&#…...

C++之初识STL—vector

文章目录 STL基本概念使用STL的好处容器vector1.vector容器简介2.vector对象的默认构造函数3.vector对象的带参构造函数4.vector的赋值5.vector的大小6.vector容器的访问方式7.vector的插入 STL基本概念 STL(Standard Template Library,标准模板库)STL 从广义上分为: 容器(con…...

资讯汇总230503

230503 12:21 【放松身心亲近自然 自驾露营成旅游新风尚】今年“五一”假期&#xff0c;我国旅游业的快速恢复催生自驾露营休闲游、短途游、夜游等新型旅游产品提质升级。快速发展的新兴旅游业态&#xff0c;在促进旅游消费、培育绿色健康生活方式等方面发挥了积极作用&#xf…...

C++之编程规范

目录 谷歌C风格指南&#xff1a;https://zh-google-styleguide.readthedocs.io/en/latest/google-cpp-styleguide/contents/ 编码规则&#xff1a; • 开闭原则&#xff1a;软件对扩展是开放的&#xff0c;对修改是关闭的 • 防御式编程&#xff1a;简单的说就是程序不能崩溃 •…...

ChatGPT做PPT方案,10组提示词方案!

今天我们要搞定的PPT内容是&#xff1a; 活动类型&#xff1a;节日活动、会员活动、新品活动分析类型&#xff1a;用户分析、新品立项、项目汇报内容类型&#xff1a;内容规划、品牌策划 用到的工具&#xff1a; mindshow 邀请码 6509097ChatGPT传送门&#xff08;免费使用…...

分布式夺命12连问

分布式理论 1. 说说CAP原则&#xff1f; CAP原则又称CAP定理&#xff0c;指的是在一个分布式系统中&#xff0c;Consistency&#xff08;一致性&#xff09;、 Availability&#xff08;可用性&#xff09;、Partition tolerance&#xff08;分区容错性&#xff09;这3个基本…...

sourceTree离线环境部署

目录 1、下载sourceTree安装包&#xff0c;打开之后弹出注册界面&#xff08;需要去国外网站注册&#xff09;2、使用技术手段跳过注册步骤3、打开安装包进行安装 注&#xff1a;建议提前安装好git 1、下载sourceTree安装包&#xff0c;打开之后弹出注册界面&#xff08;需要去…...

6.1.1 图:基本概念

一&#xff0c;基本概念 1.基本定义 &#xff08;1&#xff09;图的定义 顶点集不可以是空集&#xff0c;但边集可以是空集。 &#xff08;2&#xff09; 有向图的表示&#xff1a; 圆括号 无向图的表示&#xff1a; 尖括号 简单图、多重图&#xff1a; 简单图&#xff1a;…...

web vue 项目 Docker化部署

Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段&#xff1a; 构建阶段&#xff08;Build Stage&#xff09;&#xff1a…...

[2025CVPR]DeepVideo-R1:基于难度感知回归GRPO的视频强化微调框架详解

突破视频大语言模型推理瓶颈,在多个视频基准上实现SOTA性能 一、核心问题与创新亮点 1.1 GRPO在视频任务中的两大挑战 ​安全措施依赖问题​ GRPO使用min和clip函数限制策略更新幅度,导致: 梯度抑制:当新旧策略差异过大时梯度消失收敛困难:策略无法充分优化# 传统GRPO的梯…...

RocketMQ延迟消息机制

两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数&#xff0c;对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后&#xf…...

Xshell远程连接Kali(默认 | 私钥)Note版

前言:xshell远程连接&#xff0c;私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院挂号小程序

一、开发准备 ​​环境搭建​​&#xff1a; 安装DevEco Studio 3.0或更高版本配置HarmonyOS SDK申请开发者账号 ​​项目创建​​&#xff1a; File > New > Create Project > Application (选择"Empty Ability") 二、核心功能实现 1. 医院科室展示 /…...

在四层代理中还原真实客户端ngx_stream_realip_module

一、模块原理与价值 PROXY Protocol 回溯 第三方负载均衡&#xff08;如 HAProxy、AWS NLB、阿里 SLB&#xff09;发起上游连接时&#xff0c;将真实客户端 IP/Port 写入 PROXY Protocol v1/v2 头。Stream 层接收到头部后&#xff0c;ngx_stream_realip_module 从中提取原始信息…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序

一、开发环境准备 ​​工具安装​​&#xff1a; 下载安装DevEco Studio 4.0&#xff08;支持HarmonyOS 5&#xff09;配置HarmonyOS SDK 5.0确保Node.js版本≥14 ​​项目初始化​​&#xff1a; ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案

一、TRS收益互换的本质与业务逻辑 &#xff08;一&#xff09;概念解析 TRS&#xff08;Total Return Swap&#xff09;收益互换是一种金融衍生工具&#xff0c;指交易双方约定在未来一定期限内&#xff0c;基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...

SpringCloudGateway 自定义局部过滤器

场景&#xff1a; 将所有请求转化为同一路径请求&#xff08;方便穿网配置&#xff09;在请求头内标识原来路径&#xff0c;然后在将请求分发给不同服务 AllToOneGatewayFilterFactory import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; impor…...

自然语言处理——循环神经网络

自然语言处理——循环神经网络 循环神经网络应用到基于机器学习的自然语言处理任务序列到类别同步的序列到序列模式异步的序列到序列模式 参数学习和长程依赖问题基于门控的循环神经网络门控循环单元&#xff08;GRU&#xff09;长短期记忆神经网络&#xff08;LSTM&#xff09…...