《昇思25天学习打卡营第6天|ResNet50图像分类》
写在前面
从本次开始,接触一些上层应用。
本次通过经典的模型,开始本次任务。这里开始学习resnet50网络模型,应该也会有resnet18,估计18的模型速度会更快一些。
resnet
通过对论文的结论进行展示,说明了模型的功能,解决了卷积网络层数加大后模型的退化问题。20层和56层相比,层数越大,模型效果越差,因此resnet主要解决这种问题。hekaiming是真的强呀。
基本流程
- 整理模型数据
- 构建模型网络核心逻辑(ResidualBlockBase/ResidualBlock)
- 创建模型一层
构建网络的代码
from typing import Type, Union, List, Optional
import mindspore.nn as nn
from mindspore.common.initializer import Normal# 初始化卷积层与BatchNorm的参数
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class ResidualBlockBase(nn.Cell):expansion: int = 1 # 最后一个卷积核数量与第一个卷积核数量相等def __init__(self, in_channel: int, out_channel: int,stride: int = 1, norm: Optional[nn.Cell] = None,down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:self.norm = nn.BatchNorm2d(out_channel)else:self.norm = normself.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.conv2 = nn.Conv2d(in_channel, out_channel,kernel_size=3, weight_init=weight_init)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""identity = x # shortcuts分支out = self.conv1(x) # 主分支第一层:3*3卷积层out = self.norm(out)out = self.relu(out)out = self.conv2(out) # 主分支第二层:3*3卷积层out = self.norm(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity # 输出为主分支与shortcuts之和out = self.relu(out)return out
创建模型一层
def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],channel: int, block_nums: int, stride: int = 1):down_sample = None # shortcuts分支if stride != 1 or last_out_channel != channel * block.expansion:down_sample = nn.SequentialCell([nn.Conv2d(last_out_channel, channel * block.expansion,kernel_size=1, stride=stride, weight_init=weight_init),nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)])layers = []layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))in_channel = channel * block.expansion# 堆叠残差网络for _ in range(1, block_nums):layers.append(block(in_channel, channel))return nn.SequentialCell(layers)
创建模型
搭建一个4层的网络。
from mindspore import load_checkpoint, load_param_into_netclass ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return x
接下来,连接数据和模型网络,开始构建容易使用的网络。在这里设置了,模型残差的方法和每个block。
def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型download(url=model_url, path=pretrained_ckpt, replace=True)param_dict = load_checkpoint(pretrained_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"""ResNet50模型"""resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)
模型训练和评估
并没有完全训练,使用了预训练的方法,下载了预训练的模型。
# 定义ResNet50网络
network = resnet50(pretrained=True)# 全连接层输入层的大小
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=10)
# 重置全连接层
network.fc = fc
有了模型网络,接下来需要进行模型训练。训练的过程要设置学习率、优化器和损失函数。
# 设置学习率
num_epochs = 1
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,step_per_epoch=step_size_train, decay_epoch=num_epochs)
# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')def forward_fn(inputs, targets):logits = network(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss
之后进行多个epoch的迭代,实现模型训练的目标。
import mindspore.ops as opsdef train(data_loader, epoch):"""模型训练"""losses = []network.set_train(True)for i, (images, labels) in enumerate(data_loader):loss = train_step(images, labels)if i % 100 == 0 or i == step_size_train - 1:print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' %(epoch + 1, num_epochs, i + 1, step_size_train, loss))losses.append(loss)return sum(losses) / len(losses)def evaluate(data_loader):"""模型验证"""network.set_train(False)correct_num = 0.0 # 预测正确个数total_num = 0.0 # 预测总数for images, labels in data_loader:logits = network(images)pred = logits.argmax(axis=1) # 预测结果correct = ops.equal(pred, labels).reshape((-1, ))correct_num += correct.sum().asnumpy()total_num += correct.shape[0]acc = correct_num / total_num # 准确率return acc# 开始循环训练
print("Start Training Loop ...")for epoch in range(num_epochs):curr_loss = train(data_loader_train, epoch)curr_acc = evaluate(data_loader_val)print("-" * 50)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, curr_loss, curr_acc))print("-" * 50)# 保存当前预测准确率最高的模型if curr_acc > best_acc:best_acc = curr_accms.save_checkpoint(network, best_ckpt_path)print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)
进行多轮训练之后,达到训练的目的,模型开始进行收敛,并且能够获取到最终的结果。
最后进行评估,这个并不复杂。
打开

相关文章:
《昇思25天学习打卡营第6天|ResNet50图像分类》
写在前面 从本次开始,接触一些上层应用。 本次通过经典的模型,开始本次任务。这里开始学习resnet50网络模型,应该也会有resnet18,估计18的模型速度会更快一些。 resnet 通过对论文的结论进行展示,说明了模型的功能&…...
Activiti 6 兼容openGauss数据库bytes类型不匹配
当前有个项目需要做国产调研,需要适配高斯数据库,项目启动的时候,提示column "bytes_" is type bytea but expression is of type blob byte_字段是act_ge_bytearray表的,openGauss里的类型是bytea,类型是匹…...
缓存技术:提升性能与效率的利器
在当今数字化时代,软件应用的性能与响应速度成为了衡量其成功与否的重要标准之一。随着数据量的爆炸性增长和用户需求的日益多样化,如何高效地处理这些数据并快速响应用户请求成为了软件开发中亟待解决的问题。缓存技术,作为提升系统性能、优…...
LeetCode 637, 67, 399
文章目录 637. 二叉树的层平均值题目链接标签思路代码 67. 二进制求和题目链接标签思路代码 399. 除法求值题目链接标签思路导入value 属性find() 方法union() 方法query() 方法 代码 637. 二叉树的层平均值 题目链接 637. 二叉树的层平均值 标签 树 深度优先搜索 广度优先…...
如何压缩视频大小不改变画质?这5个视频压缩免费软件超好用!
如何压缩视频大小不改变画质?随着生活的水平逐步提高,视频流媒体服务越来越受欢迎。提供简短而引人注目的视频来展示您的产品或服务已成为一种出色的营销手段。然而,当您要准备导出最终视频时,可能会面临一个常见问题:…...
深入理解 Java 虚拟机第三版(周志明)
这次社招选的这本作为 JVM 资料查阅,记录一些重点 1. 虚拟机历史 Sun Classic VM :已退休 HotSpot VM:主流虚拟机,热点代码探测技术 Mobile / Embedded VM :移动端、嵌入式使用的虚拟机 2.2 运行时数据区域 程序计…...
算法 定长按组翻转链表
一、题目 已知一个链表的头部head,每k个结点为一组,按组翻转。要求返回翻转后的头部 k是一个正整数,它的值小于等于链表长度。如果节点总数不是k的整数倍,则剩余的结点保留原来的顺序。示例如下: (要求不…...
安装nfs和rpcbind设置linux服务器共享磁盘
1、安装nfs和rpcbind 1.1 检查服务器是否安装nfs和rpcbind,执行下命令,检查服务器是否安装过。 rpm -qa|grep nfs rpm -qa|grep rpcbind 说明服务器以安装了,如果没有就需要自己安装 2、安装nfs和rpcbind 将rpm安装包: libtirpc-…...
物联网在电力行业的应用
作者主页: 知孤云出岫 这里写目录标题 作者主页:物联网在电力行业的应用简介主要应用领域代码案例分析1. 智能电表数据采集和分析2. 设备监控和预测性维护3. 能耗管理和优化4. 电力负载预测5. 分布式能源管理6. 电动汽车充电管理7. 电网安全与故障检测 物联网在电力行业的应用…...
Java 代码规范if嵌套
在Java编程中,过度的if嵌套会使代码难以阅读和维护。为了遵循良好的代码规范,我们应尽量减少嵌套的深度。这通常可以通过重新组织代码或使用其他结构(如switch语句,或者将逻辑封装到单独的方法中)来实现。 以下是一个…...
ASPICE如何确保汽车软件产品质量的稳固基石
ASPICE通过一系列的方法和原则来保障汽车软件产品的质量,以下是其保障产品质量的几个关键方面: 制定明确的质量方针和目标: ASPICE要求组织制定明确的质量方针和目标,这些方针和目标与客户需求和预期相一致。 开发团队需要定义软…...
【深度学习】yolov8-seg分割训练,拼接图的分割复原
文章目录 项目背景造数据训练 项目背景 在日常开发中,经常会遇到一些图片是由多个图片拼接来的,如下图就是三个图片横向拼接来的。是否可以利用yolov8-seg模型来识别出这张图片的三张子图区域呢,这是文本要做的事情。 造数据 假设拼接方式有…...
Python升级打怪—Django入门
目录 一、Django简介 二、安装Django 三、创建Dajngo项目 (一) 创建项目 (二) 项目结构介绍 (三) 运行项目 (四) 结果 一、Django简介 Django是一个高级Python web框架,鼓励快速开发和干净、实用的设计。由经验丰富的开发人员构建,它解决了web开…...
leetcode面试题17.最大子矩阵
sooooooo long没刷题了,汗颜 题目链接:leetcode面试题17 1.题目 给定一个正整数、负整数和 0 组成的 N M 矩阵,编写代码找出元素总和最大的子矩阵。 返回一个数组 [r1, c1, r2, c2],其中 r1, c1 分别代表子矩阵左上角的行号和…...
计算机网络:构建联结的基础
目录 1. 网络拓扑结构 1.1 星型拓扑 1.2 环型拓扑 1.3 总线型拓扑 1.4 网状拓扑 2. 传输介质 2.1 双绞线 2.2 同轴电缆 2.3 光纤 2.4 无线电波 3. 协议栈模型 3.1 OSI模型 3.2 TCP/IP模型 4. 网络设备 4.1 交换机 4.2 路由器 4.3 网关 4.4 防火墙 5. IP地址…...
node和npm安装;electron、 electron-builder安装
1、node和npm安装 参考: https://blog.csdn.net/sw150811426/article/details/137147783 下载: https://nodejs.org/dist/v20.15.1/ 安装: 点击下载msi直接运行安装 安装完直接cmd打开可以,默认安装就已经添加了环境变量&…...
操作系统概念(黑皮书)阅读笔记
操作系统概念(黑皮书)阅读笔记 进程和内存管理部分章节 导论: 操作系统类似于政府,其本身不能实现任何有用功能,而是提供一个方便其他程序执行有用工作的环境 个人理解:os是government的作用࿰…...
matlab gui下的tcp client客户端编程框架
GUI界面 函数外定义全局变量 %全局变量 global TcpClient; %matlab作为tcpip客户端 建立连接 在“连接”按钮的回调函数下添加以下代码: global TcpClient;%全局变量 TcpClient tcpip(‘192.168.1.10’, 7, ‘NetworkRole’,‘client’); %连接到服务器地址和端…...
Matplotlib : Python 的绘图库
Matplotlib 是一个 Python 的绘图库,广泛用于生成各种静态、动态、交互式的图表。它基于 NumPy,一个用于科学计算的 Python 库。Matplotlib 可以用于生成出版质量级别的图表,并且提供了丰富的定制选项,以适应不同用户的需求。以下…...
数据编织 VS 数据仓库 VS 数据湖
目录 1. 什么是数据编织?2. 数据编织的工作原理3. 代码示例4. 数据编织的优势5. 应用场景6. 数据编织 vs 数据仓库6.1 数据存储方式6.2 数据更新和实时性6.3 灵活性和可扩展性6.4 查询性能6.5 数据治理和一致性6.6 适用场景6.7 代码示例比较 7. 数据编织 vs 数据湖7.1 数据存储…...
硬件工程师的‘第一板’:从最小系统设计到PCB Layout的STM32实战指南
STM32最小系统设计实战:从原理到PCB的工程化思维 作为一名硬件工程师,第一次独立完成PCB设计时的忐忑至今记忆犹新。那块承载着STM32最小系统的绿色电路板,不仅是我职业生涯的"第一板",更是一次从理论到实践的完整跨越。…...
FPGA新手避坑指南:用Vivado IP核搞定AXI总线,从看懂波形开始
FPGA新手避坑指南:用Vivado IP核搞定AXI总线,从看懂波形开始 第一次在Vivado中看到AXI总线波形时,我盯着屏幕上跳动的信号线完全摸不着头脑。VALID和READY信号像在玩捉迷藏,突发传输的时序如同天书——这大概是每个FPGA初学者都会…...
从HelloWorld到第一个APK:用Android Studio 2022.3.1完整走一遍Android应用发布流程
从HelloWorld到第一个APK:Android Studio 2022.3.1全流程实战指南 当你第一次打开Android Studio,看到那只呆萌的长颈鹿图标时,可能既兴奋又迷茫。兴奋的是终于要开始Android开发之旅了,迷茫的是安装完成后该从哪里入手。本文将带…...
DB-GPT-Hub:基于大模型微调构建专属文本到SQL数据集的实践指南
1. 项目概述:当大模型遇见数据库,一场效率革命正在发生如果你是一名数据工程师、数据分析师,或者任何需要频繁与数据库打交道的开发者,那么你一定对这样的场景不陌生:面对一个陌生的数据库,你需要花大量时间…...
通过curl命令直接测试Taotoken聊天补全接口的简易方法
🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 通过curl命令直接测试Taotoken聊天补全接口的简易方法 在开发或调试过程中,有时我们希望在无需引入完整SDK的轻量级环境…...
SystemVerilog中logic数据类型:统一reg与wire的设计实践
1. 项目概述:从“reg”到“logic”的思维跃迁如果你写过Verilog,那么对reg和wire这两个数据类型一定再熟悉不过了。在RTL设计的世界里,我们习惯了用reg来描述寄存器,用wire来描述连线,这几乎成了一种肌肉记忆。但当你开…...
餐饮排烟5大误区,避开少走弯路
做餐饮这些年,见过太多后厨排烟出问题的门店。每家厨房格局、业态不同,排烟遇到的麻烦也五花八门。结合实操经验,整理出餐饮排烟最容易踩的 5 个坑,附上实用解决办法,看完能避开不少问题。一、居民区门店:大…...
阿里云百炼 + OpenClaw 打造超强自动化 AI
前置准备 已安装并可正常打开 OpenClaw Windows 版本 OpenClaw 部署包获取:https://xiake.yun/api/download/package/14?promoCodeIVD643FDE29AOpenClaw 顶部 Gateway 状态显示为在线准备好可正常登录的阿里云账号可正常访问阿里云百炼控制台地址确认账号已开通百…...
Midjourney v7艺术风格跃迁路径:从基础写实到超现实叙事的5阶能力模型,含GPT-4o协同提示链模板
更多请点击: https://intelliparadigm.com 第一章:Midjourney v7艺术风格跃迁路径总览 Midjourney v7 并非简单迭代,而是以扩散模型架构重构与多模态风格理解为内核的范式跃迁。其核心突破在于引入「语义风格锚点(Semantic Style…...
从入门到精通:trtexec命令行工具在TensorRT模型部署中的实战指南
1. trtexec工具基础入门 第一次接触trtexec时,我也被这个命令行工具的参数数量吓到了。但实际用下来发现,它就像瑞士军刀一样,虽然功能多但每个都很实用。trtexec是TensorRT安装包自带的命令行工具,主要用来做三件事:…...
