从零构建属于自己的GPT系列3:模型训练2(训练函数解读、模型训练函数解读、代码逐行解读)
🚩🚩🚩Hugging Face 实战系列 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传
从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练
3 数据加载函数
def load_dataset(logger, args):"""加载训练集"""logger.info("loading training dataset")train_path = args.train_pathwith open(train_path, "rb") as f:train_list = pickle.load(f)# test# train_list = train_list[:24]train_dataset = CPMDataset(train_list, args.max_len)return train_dataset
- List item
4 训练函数
def train(model, logger, train_dataset, args):train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,drop_last=True)logger.info("total_steps:{}".format(len(train_dataloader)* args.epochs))t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochsoptimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)# 设置warmuplogger.info('start training')train_losses = [] # 记录每个epoch的平均loss# ========== start training ========== #for epoch in range(args.epochs):train_loss = train_epoch(model=model, train_dataloader=train_dataloader,optimizer=optimizer, scheduler=scheduler,logger=logger, epoch=epoch, args=args)train_losses.append(round(train_loss, 4))logger.info("train loss list:{}".format(train_losses))logger.info('training finished')logger.info("train_losses:{}".format(train_losses))
5 迭代训练函数
def train_epoch(model, train_dataloader, optimizer, scheduler, logger,epoch, args):model.train()device = args.deviceignore_index = args.ignore_indexepoch_start_time = datetime.now()total_loss = 0 # 记录下整个epoch的loss的总和epoch_correct_num = 0 # 每个epoch中,预测正确的word的数量epoch_total_num = 0 # 每个epoch中,预测的word的总数量for batch_idx, (input_ids, labels) in enumerate(train_dataloader):# 捕获cuda out of memory exceptiontry:input_ids = input_ids.to(device)labels = labels.to(device)outputs = model.forward(input_ids, labels=labels)logits = outputs.logitsloss = outputs.lossloss = loss.mean()# 统计该batch的预测token的正确数与总数batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)# 统计该epoch的预测token的正确数与总数epoch_correct_num += batch_correct_numepoch_total_num += batch_total_num# 计算该batch的accuracybatch_acc = batch_correct_num / batch_total_numtotal_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)# 进行一定step的梯度累计之后,更新参数if (batch_idx + 1) % args.gradient_accumulation_steps == 0:# 更新参数optimizer.step()# 更新学习率scheduler.step()# 清空梯度信息optimizer.zero_grad()if (batch_idx + 1) % args.log_step == 0:logger.info("batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))del input_ids, outputsexcept RuntimeError as exception:if "out of memory" in str(exception):logger.info("WARNING: ran out of memory")if hasattr(torch.cuda, 'empty_cache'):torch.cuda.empty_cache()else:logger.info(str(exception))raise exception# 记录当前epoch的平均loss与accuracyepoch_mean_loss = total_loss / len(train_dataloader)epoch_mean_acc = epoch_correct_num / epoch_total_numlogger.info("epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))# save modellogger.info('saving model for epoch {}'.format(epoch + 1))model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)model_to_save = model.module if hasattr(model, 'module') else modelmodel_to_save.save_pretrained(model_path)logger.info('epoch {} finished'.format(epoch + 1))epoch_finish_time = datetime.now()logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss
从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练
相关文章:
从零构建属于自己的GPT系列3:模型训练2(训练函数解读、模型训练函数解读、代码逐行解读)
🚩🚩🚩Hugging Face 实战系列 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在PyCharm中进行 本篇文章配套的代码资源已经上传 从零构建属于自己的GPT系列1:文本数据预处理 从零构建属于自己的GPT系列2:语…...
Python词频统计(数据整理)
请编写程序,对一段英文文本,统计其中所有不同单词的个数,以及词频最大的前10%的单词。 输入格式: 输入给出一段非空文本,最后以符号#结尾。输入保证存在至少10个不同的单词。 输出格式: 在第一行中输出文本中所有不同单词的个数…...
基本面选股的方法
基本面选股是一种投资策略,主要关注公司的财务状况、盈利能力、行业地位等因素,以判断公司的价值并做出投资决策。以下是基本面选股的具体分析方法和重点: 财务状况分析: 利润表分析:关注公司的净利润、毛利率、营业…...

应用密码学期末复习(3)
目录 第三章 现代密码学应用案例 3.1安全电子邮件方案 3.1.1 PGP产生的背景 3.2 PGP提供了一个安全电子邮件解决方案 3.2.1 PGP加密流程 3.2.2 PGP解密流程 3.2.3 PGP整合了对称加密和公钥加密的方案 3.3 PGP数字签名和Hash函数 3.4 公钥分发与认证——去中心化模型 …...

PAD平板签约投屏-高端活动的选择
传统的现场纸质签约仪式除了缺乏仪式感之外还缺少互动性,如果要将签约的过程投放到大屏幕上更是需要额外的硬件设备成本。相比于传统的纸质签约仪式,平板现场电子签约的形式更加的新颖、更富有科技感、更具有仪式感。 平板签约投屏是应用于会议签字仪式的…...

分布式架构demo
1、外层创建pom 版本管理器 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.7.15</version><relativePath/> <!-- lookup parent from repository…...

TA-Lib学习研究笔记(二)——Overlap Studies上
TA-Lib学习研究笔记(二)——Overlap Studies 1. Overlap Studies 指标 [BBANDS, DEMA, EMA, HT_TRENDLINE, KAMA, MA, MAMA, MAVP, MIDPOINT, MIDPRICE, SAR, SAREXT, SMA, T3, TEMA, TRIMA, WMA]2.数据准备 get_data函数参数(代码&#x…...
牛客java基础考点1 标识符和变量
牛客java基础考点1 标识符和变量 标识符 字母和数字: 标识符由字母、数字、下划线(_)和美元符号($)组成。其中,标识符必须以字母、下划线或美元符号开头。大小写敏感: Java 是大小写敏感的语言…...

Qt将打印信息输出到文件
将打印信息(qDebug、qInfo、qWarning、qCritial等)输出到指定文件来以实现简单的日志功能。 #include "mainwindow.h" #include <QApplication> #include <QLoggingCategory> #include <QMutex> #include <QDateTime>…...

【risc-v】易灵思efinix FPGA sapphire_soc IP配置参数分享
系列文章目录 分享一些fpga内使用riscv软核的经验,共大家参考。后续内容比较多,会做成一个系列。 本系列会覆盖以下FPGA厂商 易灵思 efinix 赛灵思 xilinx 阿尔特拉 Altera 本文内容隶属于【易灵思efinix】系列。 前言 在efinix fpga中使用riscv是一…...

直播的种类及类型
随着网络技术和移动设备的普及,直播已经成为人们娱乐、学习、商业交流等众多领域的重要工具。 直播的种类主要有以下几种: 1.视频直播:这是最常见的直播形式,包括电商直播、婚庆直播、培训直播、家居直播等。 2.图文直播:这种直播形式包括PPT互动直播…...

时间序列数据压缩算法简述
本文简单介绍了时间序列压缩任务的来源,压缩算法的分类,并对常见压缩算法的优缺点进行了简介,爱码士们快来一探究竟呀! 引言 时间序列数据是在许多应用程序和领域中生成的一种基本数据类型,例如金融、医疗保健、交通和…...

智能锁-SI522TORC522方案资料
南京中科微这款SI522目前完全PinTOPin兼容的NXP:RC522、CV520 复旦微:FM17520、FM17522/FM17550 瑞盟:MS520、MS522 国民技术:NZ3801、NZ3802 SI522 是应用于13.56MHz 非接触式通信中高集成度读写卡系列芯片中的一员。是NXP 公司针对&quo…...
redux(4) -RTK简单使用
简单使用 1、下载 npm i reduxjs/toolkit react-redux 2、创建 1、在redux/user.js中创建模块user。从reduxjs/toolkit中引入createSlice创建模块片段,我们需要传入name、初始数据initialState、改state的reducers等。最后需要导出reducer和action。 代码如下&a…...

开源运维监控系统-Nightingale(夜莺)应用实践(未完)
一、前言 某业务系统因OS改造,原先的Zabbix监控系统推倒后未重建,本来计划用外部企业内其他监控系统接入,后又通知需要自建才能对接,考虑之前zabbix的一些不便,本次计划采用一个类Prometheus的监控系统,镜调研后发现Nightingale兼容Prometheus,又有一些其他功能增强,又…...

深入理解GMP模型
1、GMP模型的设计思想 1)、GMP模型 GMP分别代表: G:goroutine,Go协程,是参与调度与执行的最小单位M:machine,系统级线程P:processor,包含了运行goroutine的资源&#…...

数学建模-基于集成学习的共享单车异常检测的研究
基于集成学习的共享单车异常检测的研究 整体求解过程概述(摘要) 近年来,共享单车的快速发展在方便了人们出行的同时,也对城市交通产生了一定的负面影响,其主要原因为单车资源配置的不合理。本文通过建立单车租赁数量的预测模型和异常检测模型…...

C语言-内存分配
内存分配 1. 引入 int nums[10] {0}; //对int len 10; int nums[len] {0}; //错是因为系统的内存分配原则导致的2. 概述 在程序运行时,系统为了 更好的管理进程中的内存,所以有了 内存分配机制。 分配原则: 2.1 静态分配 静态分配原…...
算法工程师-机器学习面试题总结(1)
目录 1-1 损失函数是什么,如何定义合理的损失函数? 1-2 回归模型和分类模型常用损失函数有哪些?各有什么优缺点 1-3 什么是结构误差和经验误差?训练模型的时候如何判断已经达到最优? 1-4 模型的“泛化”能力是指&a…...

【蓝桥杯选拔赛真题73】Scratch烟花特效 少儿编程scratch图形化编程 蓝桥杯创意编程选拔赛真题解析
目录 scratch烟花特效 一、题目要求 编程实现 二、案例分析 1、角色分析...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)
题目:3442. 奇偶频次间的最大差值 I 思路 :哈希,时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况,哈希表这里用数组即可实现。 C版本: class Solution { public:int maxDifference(string s) {int a[26]…...

国防科技大学计算机基础课程笔记02信息编码
1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...
Android Wi-Fi 连接失败日志分析
1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分: 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析: CTR…...
在Ubuntu中设置开机自动运行(sudo)指令的指南
在Ubuntu系统中,有时需要在系统启动时自动执行某些命令,特别是需要 sudo权限的指令。为了实现这一功能,可以使用多种方法,包括编写Systemd服务、配置 rc.local文件或使用 cron任务计划。本文将详细介绍这些方法,并提供…...
Linux云原生安全:零信任架构与机密计算
Linux云原生安全:零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言:云原生安全的范式革命 随着云原生技术的普及,安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测,到2025年,零信任架构将成为超…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)
宇树机器人多姿态起立控制强化学习框架论文解析 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一) 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...

Linux --进程控制
本文从以下五个方面来初步认识进程控制: 目录 进程创建 进程终止 进程等待 进程替换 模拟实现一个微型shell 进程创建 在Linux系统中我们可以在一个进程使用系统调用fork()来创建子进程,创建出来的进程就是子进程,原来的进程为父进程。…...

R语言速释制剂QBD解决方案之三
本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...
动态 Web 开发技术入门篇
一、HTTP 协议核心 1.1 HTTP 基础 协议全称 :HyperText Transfer Protocol(超文本传输协议) 默认端口 :HTTP 使用 80 端口,HTTPS 使用 443 端口。 请求方法 : GET :用于获取资源,…...
Webpack性能优化:构建速度与体积优化策略
一、构建速度优化 1、升级Webpack和Node.js 优化效果:Webpack 4比Webpack 3构建时间降低60%-98%。原因: V8引擎优化(for of替代forEach、Map/Set替代Object)。默认使用更快的md4哈希算法。AST直接从Loa…...