当前位置: 首页 > news >正文

《昇思25天学习打卡营第6天|ResNet50图像分类》

写在前面

从本次开始,接触一些上层应用。
本次通过经典的模型,开始本次任务。这里开始学习resnet50网络模型,应该也会有resnet18,估计18的模型速度会更快一些。

resnet

通过对论文的结论进行展示,说明了模型的功能,解决了卷积网络层数加大后模型的退化问题。20层和56层相比,层数越大,模型效果越差,因此resnet主要解决这种问题。hekaiming是真的强呀。

基本流程

  1. 整理模型数据
  2. 构建模型网络核心逻辑(ResidualBlockBase/ResidualBlock)
  3. 创建模型一层

构建网络的代码

from typing import Type, Union, List, Optional
import mindspore.nn as nn
from mindspore.common.initializer import Normal# 初始化卷积层与BatchNorm的参数
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class ResidualBlockBase(nn.Cell):expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等def __init__(self, in_channel: int, out_channel: int,stride: int = 1, norm: Optional[nn.Cell] = None,down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:self.norm = nn.BatchNorm2d(out_channel)else:self.norm = normself.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.conv2 = nn.Conv2d(in_channel, out_channel,kernel_size=3, weight_init=weight_init)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""identity = x  # shortcuts分支out = self.conv1(x)  # 主分支第一层:3*3卷积层out = self.norm(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out

创建模型一层

def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],channel: int, block_nums: int, stride: int = 1):down_sample = None  # shortcuts分支if stride != 1 or last_out_channel != channel * block.expansion:down_sample = nn.SequentialCell([nn.Conv2d(last_out_channel, channel * block.expansion,kernel_size=1, stride=stride, weight_init=weight_init),nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)])layers = []layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))in_channel = channel * block.expansion# 堆叠残差网络for _ in range(1, block_nums):layers.append(block(in_channel, channel))return nn.SequentialCell(layers)

创建模型

搭建一个4层的网络。

from mindspore import load_checkpoint, load_param_into_netclass ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return x

接下来,连接数据和模型网络,开始构建容易使用的网络。在这里设置了,模型残差的方法和每个block。

def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型download(url=model_url, path=pretrained_ckpt, replace=True)param_dict = load_checkpoint(pretrained_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"""ResNet50模型"""resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

模型训练和评估

并没有完全训练,使用了预训练的方法,下载了预训练的模型。

# 定义ResNet50网络
network = resnet50(pretrained=True)# 全连接层输入层的大小
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=10)
# 重置全连接层
network.fc = fc

有了模型网络,接下来需要进行模型训练。训练的过程要设置学习率、优化器和损失函数。

# 设置学习率
num_epochs = 1
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,step_per_epoch=step_size_train, decay_epoch=num_epochs)
# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')def forward_fn(inputs, targets):logits = network(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss

之后进行多个epoch的迭代,实现模型训练的目标。

import mindspore.ops as opsdef train(data_loader, epoch):"""模型训练"""losses = []network.set_train(True)for i, (images, labels) in enumerate(data_loader):loss = train_step(images, labels)if i % 100 == 0 or i == step_size_train - 1:print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' %(epoch + 1, num_epochs, i + 1, step_size_train, loss))losses.append(loss)return sum(losses) / len(losses)def evaluate(data_loader):"""模型验证"""network.set_train(False)correct_num = 0.0  # 预测正确个数total_num = 0.0  # 预测总数for images, labels in data_loader:logits = network(images)pred = logits.argmax(axis=1)  # 预测结果correct = ops.equal(pred, labels).reshape((-1, ))correct_num += correct.sum().asnumpy()total_num += correct.shape[0]acc = correct_num / total_num  # 准确率return acc# 开始循环训练
print("Start Training Loop ...")for epoch in range(num_epochs):curr_loss = train(data_loader_train, epoch)curr_acc = evaluate(data_loader_val)print("-" * 50)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, curr_loss, curr_acc))print("-" * 50)# 保存当前预测准确率最高的模型if curr_acc > best_acc:best_acc = curr_accms.save_checkpoint(network, best_ckpt_path)print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)

进行多轮训练之后,达到训练的目的,模型开始进行收敛,并且能够获取到最终的结果。

最后进行评估,这个并不复杂。

打开

请添加图片描述

相关文章:

《昇思25天学习打卡营第6天|ResNet50图像分类》

写在前面 从本次开始,接触一些上层应用。 本次通过经典的模型,开始本次任务。这里开始学习resnet50网络模型,应该也会有resnet18,估计18的模型速度会更快一些。 resnet 通过对论文的结论进行展示,说明了模型的功能&…...

Activiti 6 兼容openGauss数据库bytes类型不匹配

当前有个项目需要做国产调研,需要适配高斯数据库,项目启动的时候,提示column "bytes_" is type bytea but expression is of type blob byte_字段是act_ge_bytearray表的,openGauss里的类型是bytea,类型是匹…...

缓存技术:提升性能与效率的利器

在当今数字化时代,软件应用的性能与响应速度成为了衡量其成功与否的重要标准之一。随着数据量的爆炸性增长和用户需求的日益多样化,如何高效地处理这些数据并快速响应用户请求成为了软件开发中亟待解决的问题。缓存技术,作为提升系统性能、优…...

LeetCode 637, 67, 399

文章目录 637. 二叉树的层平均值题目链接标签思路代码 67. 二进制求和题目链接标签思路代码 399. 除法求值题目链接标签思路导入value 属性find() 方法union() 方法query() 方法 代码 637. 二叉树的层平均值 题目链接 637. 二叉树的层平均值 标签 树 深度优先搜索 广度优先…...

如何压缩视频大小不改变画质?这5个视频压缩免费软件超好用!

如何压缩视频大小不改变画质?随着生活的水平逐步提高,视频流媒体服务越来越受欢迎。提供简短而引人注目的视频来展示您的产品或服务已成为一种出色的营销手段。然而,当您要准备导出最终视频时,可能会面临一个常见问题:…...

深入理解 Java 虚拟机第三版(周志明)

这次社招选的这本作为 JVM 资料查阅,记录一些重点 1. 虚拟机历史 Sun Classic VM :已退休 HotSpot VM:主流虚拟机,热点代码探测技术 Mobile / Embedded VM :移动端、嵌入式使用的虚拟机 2.2 运行时数据区域 程序计…...

算法 定长按组翻转链表

一、题目 已知一个链表的头部head,每k个结点为一组,按组翻转。要求返回翻转后的头部 k是一个正整数,它的值小于等于链表长度。如果节点总数不是k的整数倍,则剩余的结点保留原来的顺序。示例如下: (要求不…...

安装nfs和rpcbind设置linux服务器共享磁盘

1、安装nfs和rpcbind 1.1 检查服务器是否安装nfs和rpcbind,执行下命令,检查服务器是否安装过。 rpm -qa|grep nfs rpm -qa|grep rpcbind 说明服务器以安装了,如果没有就需要自己安装 2、安装nfs和rpcbind 将rpm安装包: libtirpc-…...

物联网在电力行业的应用

作者主页: 知孤云出岫 这里写目录标题 作者主页:物联网在电力行业的应用简介主要应用领域代码案例分析1. 智能电表数据采集和分析2. 设备监控和预测性维护3. 能耗管理和优化4. 电力负载预测5. 分布式能源管理6. 电动汽车充电管理7. 电网安全与故障检测 物联网在电力行业的应用…...

Java 代码规范if嵌套

在Java编程中,过度的if嵌套会使代码难以阅读和维护。为了遵循良好的代码规范,我们应尽量减少嵌套的深度。这通常可以通过重新组织代码或使用其他结构(如switch语句,或者将逻辑封装到单独的方法中)来实现。 以下是一个…...

ASPICE如何确保汽车软件产品质量的稳固基石

ASPICE通过一系列的方法和原则来保障汽车软件产品的质量,以下是其保障产品质量的几个关键方面: 制定明确的质量方针和目标: ASPICE要求组织制定明确的质量方针和目标,这些方针和目标与客户需求和预期相一致。 开发团队需要定义软…...

【深度学习】yolov8-seg分割训练,拼接图的分割复原

文章目录 项目背景造数据训练 项目背景 在日常开发中,经常会遇到一些图片是由多个图片拼接来的,如下图就是三个图片横向拼接来的。是否可以利用yolov8-seg模型来识别出这张图片的三张子图区域呢,这是文本要做的事情。 造数据 假设拼接方式有…...

Python升级打怪—Django入门

目录 一、Django简介 二、安装Django 三、创建Dajngo项目 (一) 创建项目 (二) 项目结构介绍 (三) 运行项目 (四) 结果 一、Django简介 Django是一个高级Python web框架,鼓励快速开发和干净、实用的设计。由经验丰富的开发人员构建,它解决了web开…...

leetcode面试题17.最大子矩阵

sooooooo long没刷题了,汗颜 题目链接:leetcode面试题17 1.题目 给定一个正整数、负整数和 0 组成的 N M 矩阵,编写代码找出元素总和最大的子矩阵。 返回一个数组 [r1, c1, r2, c2],其中 r1, c1 分别代表子矩阵左上角的行号和…...

计算机网络:构建联结的基础

目录 1. 网络拓扑结构 1.1 星型拓扑 1.2 环型拓扑 1.3 总线型拓扑 1.4 网状拓扑 2. 传输介质 2.1 双绞线 2.2 同轴电缆 2.3 光纤 2.4 无线电波 3. 协议栈模型 3.1 OSI模型 3.2 TCP/IP模型 4. 网络设备 4.1 交换机 4.2 路由器 4.3 网关 4.4 防火墙 5. IP地址…...

node和npm安装;electron、 electron-builder安装

1、node和npm安装 参考: https://blog.csdn.net/sw150811426/article/details/137147783 下载: https://nodejs.org/dist/v20.15.1/ 安装: 点击下载msi直接运行安装 安装完直接cmd打开可以,默认安装就已经添加了环境变量&…...

操作系统概念(黑皮书)阅读笔记

操作系统概念(黑皮书)阅读笔记 进程和内存管理部分章节 导论: 操作系统类似于政府,其本身不能实现任何有用功能,而是提供一个方便其他程序执行有用工作的环境 ​ 个人理解:os是government的作用&#xff0…...

matlab gui下的tcp client客户端编程框架

GUI界面 函数外定义全局变量 %全局变量 global TcpClient; %matlab作为tcpip客户端 建立连接 在“连接”按钮的回调函数下添加以下代码: global TcpClient;%全局变量 TcpClient tcpip(‘192.168.1.10’, 7, ‘NetworkRole’,‘client’); %连接到服务器地址和端…...

Matplotlib : Python 的绘图库

Matplotlib 是一个 Python 的绘图库,广泛用于生成各种静态、动态、交互式的图表。它基于 NumPy,一个用于科学计算的 Python 库。Matplotlib 可以用于生成出版质量级别的图表,并且提供了丰富的定制选项,以适应不同用户的需求。以下…...

数据编织 VS 数据仓库 VS 数据湖

目录 1. 什么是数据编织?2. 数据编织的工作原理3. 代码示例4. 数据编织的优势5. 应用场景6. 数据编织 vs 数据仓库6.1 数据存储方式6.2 数据更新和实时性6.3 灵活性和可扩展性6.4 查询性能6.5 数据治理和一致性6.6 适用场景6.7 代码示例比较 7. 数据编织 vs 数据湖7.1 数据存储…...

网络六边形受到攻击

大家读完觉得有帮助记得关注和点赞!!! 抽象 现代智能交通系统 (ITS) 的一个关键要求是能够以安全、可靠和匿名的方式从互联车辆和移动设备收集地理参考数据。Nexagon 协议建立在 IETF 定位器/ID 分离协议 (…...

[2025CVPR]DeepVideo-R1:基于难度感知回归GRPO的视频强化微调框架详解

突破视频大语言模型推理瓶颈,在多个视频基准上实现SOTA性能 一、核心问题与创新亮点 1.1 GRPO在视频任务中的两大挑战 ​安全措施依赖问题​ GRPO使用min和clip函数限制策略更新幅度,导致: 梯度抑制:当新旧策略差异过大时梯度消失收敛困难:策略无法充分优化# 传统GRPO的梯…...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 (一)实时滤波与参数调整 基础滤波操作 60Hz 工频滤波:勾选界面右侧 “60Hz” 复选框,可有效抑制电网干扰(适用于北美地区,欧洲用户可调整为 50Hz)。 平滑处理&…...

Xshell远程连接Kali(默认 | 私钥)Note版

前言:xshell远程连接,私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...

服务器硬防的应用场景都有哪些?

服务器硬防是指一种通过硬件设备层面的安全措施来防御服务器系统受到网络攻击的方式,避免服务器受到各种恶意攻击和网络威胁,那么,服务器硬防通常都会应用在哪些场景当中呢? 硬防服务器中一般会配备入侵检测系统和预防系统&#x…...

VTK如何让部分单位不可见

最近遇到一个需求&#xff0c;需要让一个vtkDataSet中的部分单元不可见&#xff0c;查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行&#xff0c;是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示&#xff0c;主要是最后一个参数&#xff0c;透明度…...

什么是Ansible Jinja2

理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具&#xff0c;可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板&#xff0c;允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板&#xff0c;并通…...

C# 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同&#xff0c;结合所安装的tensorflow的目录结构修改from语句即可。 原语句&#xff1a; from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后&#xff1a; from tensorflow.python.keras.lay…...

在QWebEngineView上实现鼠标、触摸等事件捕获的解决方案

这个问题我看其他博主也写了&#xff0c;要么要会员、要么写的乱七八糟。这里我整理一下&#xff0c;把问题说清楚并且给出代码&#xff0c;拿去用就行&#xff0c;照着葫芦画瓢。 问题 在继承QWebEngineView后&#xff0c;重写mousePressEvent或event函数无法捕获鼠标按下事…...