从零构建属于自己的GPT系列3:模型训练2(训练函数解读、模型训练函数解读、代码逐行解读)
🚩🚩🚩Hugging Face 实战系列 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传
从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练
3 数据加载函数
def load_dataset(logger, args):"""加载训练集"""logger.info("loading training dataset")train_path = args.train_pathwith open(train_path, "rb") as f:train_list = pickle.load(f)# test# train_list = train_list[:24]train_dataset = CPMDataset(train_list, args.max_len)return train_dataset
- List item
4 训练函数
def train(model, logger, train_dataset, args):train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,drop_last=True)logger.info("total_steps:{}".format(len(train_dataloader)* args.epochs))t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochsoptimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)# 设置warmuplogger.info('start training')train_losses = [] # 记录每个epoch的平均loss# ========== start training ========== #for epoch in range(args.epochs):train_loss = train_epoch(model=model, train_dataloader=train_dataloader,optimizer=optimizer, scheduler=scheduler,logger=logger, epoch=epoch, args=args)train_losses.append(round(train_loss, 4))logger.info("train loss list:{}".format(train_losses))logger.info('training finished')logger.info("train_losses:{}".format(train_losses))
5 迭代训练函数
def train_epoch(model, train_dataloader, optimizer, scheduler, logger,epoch, args):model.train()device = args.deviceignore_index = args.ignore_indexepoch_start_time = datetime.now()total_loss = 0 # 记录下整个epoch的loss的总和epoch_correct_num = 0 # 每个epoch中,预测正确的word的数量epoch_total_num = 0 # 每个epoch中,预测的word的总数量for batch_idx, (input_ids, labels) in enumerate(train_dataloader):# 捕获cuda out of memory exceptiontry:input_ids = input_ids.to(device)labels = labels.to(device)outputs = model.forward(input_ids, labels=labels)logits = outputs.logitsloss = outputs.lossloss = loss.mean()# 统计该batch的预测token的正确数与总数batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)# 统计该epoch的预测token的正确数与总数epoch_correct_num += batch_correct_numepoch_total_num += batch_total_num# 计算该batch的accuracybatch_acc = batch_correct_num / batch_total_numtotal_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)# 进行一定step的梯度累计之后,更新参数if (batch_idx + 1) % args.gradient_accumulation_steps == 0:# 更新参数optimizer.step()# 更新学习率scheduler.step()# 清空梯度信息optimizer.zero_grad()if (batch_idx + 1) % args.log_step == 0:logger.info("batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))del input_ids, outputsexcept RuntimeError as exception:if "out of memory" in str(exception):logger.info("WARNING: ran out of memory")if hasattr(torch.cuda, 'empty_cache'):torch.cuda.empty_cache()else:logger.info(str(exception))raise exception# 记录当前epoch的平均loss与accuracyepoch_mean_loss = total_loss / len(train_dataloader)epoch_mean_acc = epoch_correct_num / epoch_total_numlogger.info("epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))# save modellogger.info('saving model for epoch {}'.format(epoch + 1))model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)model_to_save = model.module if hasattr(model, 'module') else modelmodel_to_save.save_pretrained(model_path)logger.info('epoch {} finished'.format(epoch + 1))epoch_finish_time = datetime.now()logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss
从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练
相关文章:
从零构建属于自己的GPT系列3:模型训练2(训练函数解读、模型训练函数解读、代码逐行解读)
🚩🚩🚩Hugging Face 实战系列 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在PyCharm中进行 本篇文章配套的代码资源已经上传 从零构建属于自己的GPT系列1:文本数据预处理 从零构建属于自己的GPT系列2:语…...
Python词频统计(数据整理)
请编写程序,对一段英文文本,统计其中所有不同单词的个数,以及词频最大的前10%的单词。 输入格式: 输入给出一段非空文本,最后以符号#结尾。输入保证存在至少10个不同的单词。 输出格式: 在第一行中输出文本中所有不同单词的个数…...
基本面选股的方法
基本面选股是一种投资策略,主要关注公司的财务状况、盈利能力、行业地位等因素,以判断公司的价值并做出投资决策。以下是基本面选股的具体分析方法和重点: 财务状况分析: 利润表分析:关注公司的净利润、毛利率、营业…...
应用密码学期末复习(3)
目录 第三章 现代密码学应用案例 3.1安全电子邮件方案 3.1.1 PGP产生的背景 3.2 PGP提供了一个安全电子邮件解决方案 3.2.1 PGP加密流程 3.2.2 PGP解密流程 3.2.3 PGP整合了对称加密和公钥加密的方案 3.3 PGP数字签名和Hash函数 3.4 公钥分发与认证——去中心化模型 …...
PAD平板签约投屏-高端活动的选择
传统的现场纸质签约仪式除了缺乏仪式感之外还缺少互动性,如果要将签约的过程投放到大屏幕上更是需要额外的硬件设备成本。相比于传统的纸质签约仪式,平板现场电子签约的形式更加的新颖、更富有科技感、更具有仪式感。 平板签约投屏是应用于会议签字仪式的…...
分布式架构demo
1、外层创建pom 版本管理器 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.7.15</version><relativePath/> <!-- lookup parent from repository…...
TA-Lib学习研究笔记(二)——Overlap Studies上
TA-Lib学习研究笔记(二)——Overlap Studies 1. Overlap Studies 指标 [BBANDS, DEMA, EMA, HT_TRENDLINE, KAMA, MA, MAMA, MAVP, MIDPOINT, MIDPRICE, SAR, SAREXT, SMA, T3, TEMA, TRIMA, WMA]2.数据准备 get_data函数参数(代码&#x…...
牛客java基础考点1 标识符和变量
牛客java基础考点1 标识符和变量 标识符 字母和数字: 标识符由字母、数字、下划线(_)和美元符号($)组成。其中,标识符必须以字母、下划线或美元符号开头。大小写敏感: Java 是大小写敏感的语言…...
Qt将打印信息输出到文件
将打印信息(qDebug、qInfo、qWarning、qCritial等)输出到指定文件来以实现简单的日志功能。 #include "mainwindow.h" #include <QApplication> #include <QLoggingCategory> #include <QMutex> #include <QDateTime>…...
【risc-v】易灵思efinix FPGA sapphire_soc IP配置参数分享
系列文章目录 分享一些fpga内使用riscv软核的经验,共大家参考。后续内容比较多,会做成一个系列。 本系列会覆盖以下FPGA厂商 易灵思 efinix 赛灵思 xilinx 阿尔特拉 Altera 本文内容隶属于【易灵思efinix】系列。 前言 在efinix fpga中使用riscv是一…...
直播的种类及类型
随着网络技术和移动设备的普及,直播已经成为人们娱乐、学习、商业交流等众多领域的重要工具。 直播的种类主要有以下几种: 1.视频直播:这是最常见的直播形式,包括电商直播、婚庆直播、培训直播、家居直播等。 2.图文直播:这种直播形式包括PPT互动直播…...
时间序列数据压缩算法简述
本文简单介绍了时间序列压缩任务的来源,压缩算法的分类,并对常见压缩算法的优缺点进行了简介,爱码士们快来一探究竟呀! 引言 时间序列数据是在许多应用程序和领域中生成的一种基本数据类型,例如金融、医疗保健、交通和…...
智能锁-SI522TORC522方案资料
南京中科微这款SI522目前完全PinTOPin兼容的NXP:RC522、CV520 复旦微:FM17520、FM17522/FM17550 瑞盟:MS520、MS522 国民技术:NZ3801、NZ3802 SI522 是应用于13.56MHz 非接触式通信中高集成度读写卡系列芯片中的一员。是NXP 公司针对&quo…...
redux(4) -RTK简单使用
简单使用 1、下载 npm i reduxjs/toolkit react-redux 2、创建 1、在redux/user.js中创建模块user。从reduxjs/toolkit中引入createSlice创建模块片段,我们需要传入name、初始数据initialState、改state的reducers等。最后需要导出reducer和action。 代码如下&a…...
开源运维监控系统-Nightingale(夜莺)应用实践(未完)
一、前言 某业务系统因OS改造,原先的Zabbix监控系统推倒后未重建,本来计划用外部企业内其他监控系统接入,后又通知需要自建才能对接,考虑之前zabbix的一些不便,本次计划采用一个类Prometheus的监控系统,镜调研后发现Nightingale兼容Prometheus,又有一些其他功能增强,又…...
深入理解GMP模型
1、GMP模型的设计思想 1)、GMP模型 GMP分别代表: G:goroutine,Go协程,是参与调度与执行的最小单位M:machine,系统级线程P:processor,包含了运行goroutine的资源&#…...
数学建模-基于集成学习的共享单车异常检测的研究
基于集成学习的共享单车异常检测的研究 整体求解过程概述(摘要) 近年来,共享单车的快速发展在方便了人们出行的同时,也对城市交通产生了一定的负面影响,其主要原因为单车资源配置的不合理。本文通过建立单车租赁数量的预测模型和异常检测模型…...
C语言-内存分配
内存分配 1. 引入 int nums[10] {0}; //对int len 10; int nums[len] {0}; //错是因为系统的内存分配原则导致的2. 概述 在程序运行时,系统为了 更好的管理进程中的内存,所以有了 内存分配机制。 分配原则: 2.1 静态分配 静态分配原…...
算法工程师-机器学习面试题总结(1)
目录 1-1 损失函数是什么,如何定义合理的损失函数? 1-2 回归模型和分类模型常用损失函数有哪些?各有什么优缺点 1-3 什么是结构误差和经验误差?训练模型的时候如何判断已经达到最优? 1-4 模型的“泛化”能力是指&a…...
【蓝桥杯选拔赛真题73】Scratch烟花特效 少儿编程scratch图形化编程 蓝桥杯创意编程选拔赛真题解析
目录 scratch烟花特效 一、题目要求 编程实现 二、案例分析 1、角色分析...
R语言AI模型部署方案:精准离线运行详解
R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...
Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)
概述 在 Swift 开发语言中,各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过,在涉及到多个子类派生于基类进行多态模拟的场景下,…...
系统设计 --- MongoDB亿级数据查询优化策略
系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log,共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题,不能使用ELK只能使用…...
电脑插入多块移动硬盘后经常出现卡顿和蓝屏
当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...
【Oracle】分区表
个人主页:Guiat 归属专栏:Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...
七、数据库的完整性
七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...
群晖NAS如何在虚拟机创建飞牛NAS
套件中心下载安装Virtual Machine Manager 创建虚拟机 配置虚拟机 飞牛官网下载 https://iso.liveupdate.fnnas.com/x86_64/trim/fnos-0.9.2-863.iso 群晖NAS如何在虚拟机创建飞牛NAS - 个人信息分享...
【从零开始学习JVM | 第四篇】类加载器和双亲委派机制(高频面试题)
前言: 双亲委派机制对于面试这块来说非常重要,在实际开发中也是经常遇见需要打破双亲委派的需求,今天我们一起来探索一下什么是双亲委派机制,在此之前我们先介绍一下类的加载器。 目录 编辑 前言: 类加载器 1. …...
《信号与系统》第 6 章 信号与系统的时域和频域特性
目录 6.0 引言 6.1 傅里叶变换的模和相位表示 6.2 线性时不变系统频率响应的模和相位表示 6.2.1 线性与非线性相位 6.2.2 群时延 6.2.3 对数模和相位图 6.3 理想频率选择性滤波器的时域特性 6.4 非理想滤波器的时域和频域特性讨论 6.5 一阶与二阶连续时间系统 6.5.1 …...
游戏开发中常见的战斗数值英文缩写对照表
游戏开发中常见的战斗数值英文缩写对照表 基础属性(Basic Attributes) 缩写英文全称中文释义常见使用场景HPHit Points / Health Points生命值角色生存状态MPMana Points / Magic Points魔法值技能释放资源SPStamina Points体力值动作消耗资源APAction…...
