MindSpore框架学习项目-ResNet药物分类-模型优化
目录
5.模型优化
5.1模型优化
6.结语
本项目可以在华为云modelart上租一个实例进行,也可以在配置至少为单卡3060的设备上进行
https://console.huaweicloud.com/modelarts/
Ascend环境也适用,但是注意修改device_target参数
需要本地编译器的一些代码传输、修改等可以勾上ssh远程开发
说明:项目使用的数据集来自华为云的数据资源。项目以深度学习任务构建的一般流程展开(数据导入、处理 > 模型选择、构建 > 模型训练 > 模型评估 > 模型优化)。
主线为‘一般流程’,同时代码中会标注出一些要点(# 要点1-1-1:设置使用的设备
)作为支线,帮助学习mindspore框架在进行深度学习任务时一些与pytorch的差异。
可以只看目录中带数字标签的部分来快速查阅代码。
本系列
MindSpore框架学习项目-ResNet药物分类-数据增强-CSDN博客
MindSpore框架学习项目-ResNet药物分类-构建模型-CSDN博客
MindSpore框架学习项目-ResNet药物分类-模型训练-CSDN博客
MindSpore框架学习项目-ResNet药物分类-模型评估-CSDN博客
MindSpore框架学习项目-ResNet药物分类-模型优化-CSDN博客
5.模型优化
5.1模型优化
要求:
通过调整超参数,使得模型在测试集上评价指标acc高出超参调整之前(要点4-1-3输出结果)的5%及以上
此环节一般为深度学习任务在构建模型、探索可行性的最后阶段,用于尽可能地发掘模型适配任务的潜能,为落地部署做准备。需要往上复盘并结合从‘模型构建’、‘模型训练’到‘模型推理’等环节的代码过程,进行参数的调优(优先从超参数入手)。
# 超参数
num_epochs = 10 # up
patience = 5
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,
step_per_epoch=step_size_train, decay_epoch=num_epochs)
# 要点3-1-1:定义优化器为Momentum优化器, 动量因子设置为0.9
# opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
opt = nn.Adam(params=network.trainable_params(),learning_rate=lr)
# 要点3-1-2:定义损失函数为SoftmaxCrossEntropyWithLogits损失函数,sparse=True, reduction='mean'
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = ms.Model(network, loss_fn, opt, metrics={'acc'})best_acc = 0
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"# train
def train_loop(model, dataset, loss_fn, optimizer):# 要点3-1-3:模型编译:利用函数式编程实现loss的计算,并返回loss和模型预测值logitsdef forward_fn(data, label):
logits = model(data)
loss = loss_fn(logits,label)return loss, logits# 要点3-1-4:利用value_and_grad API定义反向传播函数
grad_fn = ms.ops.value_and_grad(forward_fn, None, opt.parameters, has_aux=True)def train_step(data, label):(loss, _), grads = grad_fn(data, label)
loss = ops.depend(loss, optimizer(grads))return loss
size = dataset.get_dataset_size()
model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
loss = train_step(data, label)if batch % 100 == 0 or batch == step_size_train - 1:
loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")# test
def test_loop(model, dataset, loss_fn):
num_batches = dataset.get_dataset_size()# 要点3-1-5:设置模型为预测模式
model.set_train(False)
total, test_loss, correct = 0, 0, 0
y_true = []
y_pred = []for data, label in dataset.create_tuple_iterator():
y_true.extend(label.asnumpy().tolist())
pred = model(data)
total += len(data)
test_loss += loss_fn(pred, label).asnumpy()
y_pred.extend(pred.argmax(1).asnumpy().tolist())
correct += (pred.argmax(1) == label).asnumpy().sum()
test_loss /= num_batches
correct /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")print(classification_report(y_true,y_pred,target_names= list(index_label_dict.values()),digits=3))return correct,test_loss# 重新训练
no_improvement_count = 0
acc_list = []
loss_list = []
stop_epoch = num_epochs
for t in range(num_epochs):print(f"Epoch {t+1}\n-------------------------------")
train_loop(network, dataset_train, loss_fn, opt)
acc,loss = test_loop(network, dataset_val, loss_fn)
acc_list.append(acc)
loss_list.append(loss)# 要点3-2-1:设置条件:利用计算的acc指标,得到训练中得到的最优模型权重if best_acc < acc:
best_acc = accif not os.path.exists(best_ckpt_dir):
os.mkdir(best_ckpt_dir)# 要点3-2-2:利用save_checkpoint API对模型进行保存, 保存的路径为best_ckpt_path
ms.save_checkpoint(network,best_ckpt_path)
no_improvement_count = 0else:
no_improvement_count += 1if no_improvement_count > patience:print('Early stopping triggered. Restoring best weights...')
stop_epoch = tbreak
print("Done!")
说明
对于模型调优,先从超参数入手,比如epoch、batch_size等,可以初步判断数据集的质量;再一定程度上acc有所提升后,如果遇到性能瓶颈(通过超参数已经不能让模型精度进一步提高,同时还达不到预期,那就考虑参数--网络结构、激活函数、损失函数等)
这里将epoch从3->10,新一轮训练后的第十轮结果:
模型在性能上得到一定提升
复用前面的推理代码
# 重新加载模型 ‘BestCheckpoint/resnet50-best.ckpt’
num_class = 12 #
# 题目4-1-1:实例化resnet50 预测模型
net = resnet50(num_classes=num_class)
best_ckpt_path = 'BestCheckpoint/resnet50-best.ckpt'
# 题目4-1-2:加载模型参数
# 将最优的一次检查点信息(模型-网络权重参数)加载到参数字典
param_dict = ms.load_checkpoint(best_ckpt_path)
# 将网络权重载入网络结构--模型网络结构里
ms.load_param_into_net(net,param_dict)
model = ms.Model(net)
image_size = 224
workers = 1
# acc
test_acc, _ = test_loop(net, dataset_test, loss_fn)
print(f'Test Accuracy:{test_acc*100:.2f}%')
本次:
较上次:
精度提升>5%
6.结语
通过这个用ResNet50进行对中药材的种类及品阶进行12分类的项目,学习mindspore AI框架的使用和深度学习任务的一般流程,熟悉如何通过深度学习的方式来拟合数据,处理生产生活中的问题,为AI赋能的时代贡献点滴实践。
相关文章:

MindSpore框架学习项目-ResNet药物分类-模型优化
目录 5.模型优化 5.1模型优化 6.结语 参考内容: 昇思MindSpore | 全场景AI框架 | 昇思MindSpore社区官网 华为自研的国产AI框架,训推一体,支持动态图、静态图,全场景适用,有着不错的生态 本项目可以在华为云modelar…...
基于阿里云DataWorks的物流履约时效离线分析
基于阿里云DataWorks的物流履约时效离线分析2. 数仓模型构建 ORC和Parquet区别: 压缩率与查询性能 压缩率 ORC通常压缩率更高,文件体积更小,适合存储成本敏感的场景。 Parquet因支持更灵活的嵌套结构,压缩率略…...

Kubernetes(k8s)学习笔记(八)--KubeSphere定制化安装
1执行下面的命令修改上一篇中yaml文件来实现定制化安装devops kubectl edit cm -n kubesphere-system ks-installer 主要是将devops几个配置由False改为True 然后使用下面的命令查看安装日志 kubectl logs -n kubesphere-system $(kubectl get pod -n kubesphere-system -l …...

养生:为健康生活筑牢根基
养生并非遥不可及的目标,而是贯穿于日常生活的点滴之中。从饮食、运动到心态调节,每一个环节都对我们的健康有着重要意义。以下为你详细介绍养生的实用策略,助力你开启健康生活模式。 饮食养生:科学搭配,滋养生命 合…...

Linux510 ssh服务 ssh连接
arning: Permanently added ‘11.1.1.100’ (ECDSA) to the list of known hosts. rooot11.1.1.100’s password: Permission denied, please try again. rooot11.1.1.100’s password: Permission denied, please try again 还没生效 登不上了 失效了 sshcaozx26成功登录 …...

关键点检测--使用YOLOv8对Leeds Sports Pose(LSP)关键点检测
目录 1. Leeds Sports Pose数据集下载2. 数据集处理2.1 获取标签2.2 将图像文件和标签文件处理成YOLO能使用的格式 3. 用YOLOv8进行训练3.1 训练3.2 预测 1. Leeds Sports Pose数据集下载 从kaggle官网下载这个数据集,地址为link,下载好的数据集文件如下…...
Elasticsearch内存管理与JVM优化:原理剖析与最佳实践
#作者:孙德新 文章目录 一、Elasticsearch缓存分类1、Node Query Cache:2、Shard Request Cache:3、Fielddata Cache: 三、内存常见的问题案例一案例二案例三案例四 四、内参分配最佳实践1、jvm heap分配2、将机器上少于一半的内…...

独立按键控制LED
目录 1.独立按键介绍 2.原理图 3.C51数据运输 解释:<< >> 编辑 解释:& | 解释:^ ~ 编辑 4.C51基本语句 5.按键的跳动 6.独立按键控制LED亮灭代码 第一步: 第二步: 第三步࿱…...

计算机科技笔记: 容错计算机设计03 系统可信性的度量 偶发故障期 浴盆曲线 韦布尔分布
可靠性 简化表达式 偶发故障期,系统发生故障概率趋近于一个常数 浴盆曲线 MTTF和计算 韦布尔分布 马尔可夫链 可靠度...

爬虫准备前工作
1.Pycham的下载 网址:PyCharm: The only Python IDE you need 2.Python的下载 网址:python.org(python3.9版本之后都可以) 3.node.js的下载 网址:Node.js — 在任何地方运行 JavaScript(版本使用18就可…...

【PostgreSQL数据分析实战:从数据清洗到可视化全流程】7.1 主流可视化工具对比(Tableau/Matplotlib/Python库)
👉 点击关注不迷路 👉 点击关注不迷路 👉 点击关注不迷路 文章大纲 第七章 可视化工具集成:Tableau、Matplotlib与Python库深度对比7.1 主流可视化工具对比:技术选型的决策框架7.1.1 工具定位与核心能力矩阵7.1.2 数据…...

操作系统实验习题解析 上篇
孤村落日残霞,轻烟老树寒鸦,一点飞鸿影下。 青山绿水,白草红叶黄花。 ————《天净沙秋》 白朴 【元】 目录 实验一: 代码: 解析: 运行结果: 实验二: 代码解析 1. 类设计 …...
复习javascript
1.修改元素内的内容 <div>zsgh</div> <script> const box1document.querySelector("div") box1.innerText"ppp" box1.innerHtml<h1>修改</h1> </script> 2.随机点名练习 <!DOCTYPE html> <html lang…...

基于Arduino Nano的DIY示波器
基于Arduino Nano的DIY示波器:打造属于你的口袋实验室 前言 在电子爱好者的世界里,示波器是不可或缺的工具之一。它能够帮助我们观察和分析各种电子信号的波形,从而更好地理解和调试电路。然而,市面上的示波器价格往往较高&…...

渠道销售简历模板范文
模板信息 简历范文名称:渠道销售简历模板范文,所属行业:其他 | 职位,模板编号:KRZ3J3 专业的个人简历模板,逻辑清晰,排版简洁美观,让你的个人简历显得更专业,找到好工作…...

JAVA练习题(1) 卖飞机票
import java.util.Scanner; public class Main {public static void main(String[] args) {Scanner scnew Scanner(System.in);System.out.println("请输入飞机的票价:");int pricesc.nextInt();System.out.println("请输入月份:");…...

杆件的拉伸与压缩变形
杆件的拉伸与压缩 第一题 Q u e s t i o n \mathcal{Question} Question 图示拉杆沿斜截面 m − m m-m m−m由两部分胶合而成。设在胶合面上许用拉应力 [ σ ] 100 MPa [\sigma]100\text{MPa} [σ]100MPa,许用切应力 [ τ ] 50 MPa [\tau]50\text{MPa} [τ]50MP…...
深入解析Vue3中ref与reactive的区别及源码实现
深入解析Vue3中ref与reactive的区别及源码实现 前言 Vue3带来了全新的响应式系统,其中ref和reactive是最常用的两个API。本文将从基础使用、核心区别到源码实现,由浅入深地分析这两个API。 一、基础使用 1. ref import { ref } from vueconst count…...
Makefile中 链接库,同一个库的静态库与动态库都链接了,生效的是哪个库
Makefile中 链接库,同一个库的静态库与动态库都链接了,生效的是哪个库 在 Makefile 中同时链接同一个库的静态库(.a)和动态库(.so)时,具体哪个库生效取决于链接顺序和编译器行为。以下是详细分析…...

企业开发平台大变革:AI 代理 + 平台工程重构数字化转型路径
在企业数字化转型的浪潮中,开发平台正经历着前所未有的技术革命。从 AST(抽象语法树)到 AI 驱动的智能开发,从微服务架构到信创适配,这场变革不仅重塑了软件开发的底层逻辑,更催生了全新的生产力范式。本文…...

《汽车噪声控制》复习重点
题型 选择 填空 分析 计算 第一章 噪声定义 不需要的声音,妨碍正常工作、学习、生活,危害身体健康的声音,统称为噪声 噪声污染 与大气污染、水污染并称现代社会三大公害 声波基本概念 定义 媒质质点的机械振动由近及远传播&am…...

Linux——MySQL约束与查询
表的约束 真正约束字段的是数据类型,但是数据类型约束很单一,需要有一些额外的约束,更好的保证数据的合 法性,从业务逻辑角度保证数据的正确性。比如有一个字段是email,要求是唯一的。 表的约束是为了防止插入不合法的…...

Asp.Net Core IIS发布后PUT、DELETE请求错误405
一、方案1 1、IIS管理器,处理程序映射。 2、找到aspNetCore,双击。点击请求限制...按钮,并在谓词选项卡上,添加两者DELETE和PUT. 二、方案2 打开web.config文件,添加<remove name"WebDAVModule" />&…...

STL-to-ASCII-Generator 实用教程
参阅:STL-to-ASCII-Generator 使用教程 开源项目网址 下载 STL-to-ASCII-Generator-main.zip 解压到 D:\js\ index.html 如下 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta id"ascii&q…...
关于数据库查询速度优化
本人接手了一个关于项目没有任何文档信息的代码,代码也没有相关文档说明信息!所以在做数据库查询优化的时候不敢改动。 原因1: 老板需要我做一个首页的统计查询。明明才几十万条数据,而且我加了筛选条件为什么会这么慢ÿ…...
sql serve 多表联合查询,根据一个表字段值动态改变查询条件
在SQL Server中进行多表联合查询时,如果需要根据一个表的字段值动态改变查询条件,可以采用几种不同的方法来实现这一需求。这里介绍两种常用的方法:CASE表达式和动态SQL。 方法1: 使用 CASE 表达式 这种方法适合于查询条件可以在单个SQL语句…...

巡检机器人数据处理技术的创新与实践
摘要 随着科技的飞速发展,巡检机器人在各行业中逐渐取代人工巡检,展现出高效、精准、安全等显著优势。当前,巡检机器人已从单纯的数据采集阶段迈向对采集数据进行深度分析的新阶段。本文探讨了巡检机器人替代人工巡检的现状及优势,…...

国产linux系统(银河麒麟,统信uos)使用 PageOffice 在线打开Word文件,并用前端对话框实现填空填表
不管是政府机关、公司企业,还是金融行业、教育行业等单位,在办公过程中都经常需要填写各种文书和表格,比如通知、报告、登记表、计划表、申请表等。这些文书和表格往往是用Word文件制作的模板,比方说一个通知模板中经常会有“关于…...
Kubernetes应用发布方式完整流程指南
Kubernetes(K8s)作为容器编排领域的核心工具,其应用发布流程体现了自动化、弹性和可观测性的优势。本文将通过一个Tomcat应用的示例,详细讲解从配置编写到高级发布的完整流程,帮助开发者掌握Kubernetes应用部署的核心步…...
视频编解码学习8之视频历史
视频技术的发展历史可以追溯到19世纪,至今已跨越近200年。以下是视频技术发展的主要阶段和里程碑: 1. 早期探索阶段(19世纪-1920年代) 1832年:约瑟夫普拉托(Joseph Plateau)发明"费纳奇镜&…...