【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别
【Pytorch】学习记录分享5——PyTorch经典网络 ResNet
- 1. ResNet (残差网络)基础知识
- 2. 感受野
- 3. 手写体数字识别
- 3. 0 数据集(训练与测试集)
- 3. 1 数据加载
- 3. 2 函数实现:
- 3. 3 训练及其测试:
1. ResNet (残差网络)基础知识
图1 56层error比20层error高,提出ResNet (残差网络)的方案
网络效果:
网络结构:
2. 感受野
3. 手写体数字识别
3. 0 数据集(训练与测试集)
mnist 用于手写体训练与测试,这里包含完整的链接
3. 1 数据加载
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
### 首先读取数据
# - 分别构建训练集和测试集(验证集)
# - DataLoader来迭代取数据# 定义超参数
input_size = 28 #图像的总尺寸28*28
num_classes = 10 #标签的种类数
num_epochs = 3 #训练的总循环周期
batch_size = 64 #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
3. 2 函数实现:
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential( # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1, # 灰度图out_channels=16, # 要得到几多少个特征图kernel_size=5, # 卷积核大小stride=1, # 步长padding=2, # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1), # 输出的特征图为 (16, 28, 28)nn.ReLU(), # relu层nn.MaxPool2d(kernel_size=2), # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential( # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2), # 输出 (32, 14, 14)nn.ReLU(), # relu层nn.MaxPool2d(2), # 输出 (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10) # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1) # flatten操作,结果为:(batch_size, 32 * 7 * 7) output = self.out(x)return output# 准确率作为评估标准
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels)
3. 3 训练及其测试:
# 训练网络模型
# 实例化
net = CNN()
#损失函数
criterion = nn.CrossEntropyLoss()
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = []for batch_idx, (data, target) in enumerate(train_loader): #针对容器中的每一个批进行循环net.train() # 将模型设置为训练模式output = net(data) # 使用模型进行前向传播loss = criterion(output, target) # 计算损失optimizer.zero_grad() # 梯度清零loss.backward() # 反向传播计算梯度optimizer.step() # 更新参数right = accuracy(output, target) # 计算当前批次的准确率train_rights.append(right) # 将准确率保存起来if batch_idx % 500 == 0: # 每500个批次进行一次验证net.eval() # 将模型设置为评估模式val_rights = [] # 存储验证集的准确率for (data, target) in test_loader: # 在测试集上进行验证output = net(data) # 使用模型进行前向传播right = accuracy(output, target) # 计算验证集上的准确率val_rights.append(right) # 将准确率保存起来#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights])) # 计算训练集准确率的分子和分母val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights])) # 计算验证集准确率的分子和分母print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1],100. * val_r[0].numpy() / val_r[1])) # 打印当前进度和准确率信息
相关文章:

【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别
【Pytorch】学习记录分享5——PyTorch经典网络 ResNet 1. ResNet (残差网络)基础知识2. 感受野3. 手写体数字识别3. 0 数据集(训练与测试集)3. 1 数据加载3. 2 函数实现:3. 3 训练及其测试: 1. ResNet &…...

Flink1.17实战教程(第三篇:时间和窗口)
系列文章目录 Flink1.17实战教程(第一篇:概念、部署、架构) Flink1.17实战教程(第二篇:DataStream API) Flink1.17实战教程(第三篇:时间和窗口) Flink1.17实战教程&…...

CSS 纵向扩展动画
上干货 <template><!-- mouseenter"startAnimation" 表示在鼠标进入元素时触发 startAnimation 方法。mouseleave"stopAnimation" 表示在鼠标离开元素时触发 stopAnimation 方法。 --><!-- 容器元素 --><div class"container&q…...
Android 12 Token 机制
一、前言 在 android framework 框架中 activity 和 window 是相互关联的,而他们的管理者 AMS 和 WMS 是怎么来实现这种关联关系的,答案就是通过 token。 首先大家需要了解一下 LayoutParams,当然属性很多,简单了解即可…...
TCP与UDP是流式传输协议吗?
TCP(传输控制协议)和UDP(用户数据报协议)是两种主要的传输层协议,它们用于在网络中传输数据。它们不是流式传输协议,而是提供了不同的数据传输特性: 1. TCP(传输控制协议࿰…...
61 贪心算法解救生艇问题
问题描述:第i个人的体重为peaple[i],每个船可以承载的最大重量为limit。每艘船最多可以同时载两人,但条件是这些人的重量之和最多为limit,返回载到每一个人多虚的最小船数,(保证每个人被船载)。 贪心算法求解:先将数组…...

C#高级 01.Net多线程
一.基本概念 1.什么是线程? 线程是操作系统中能独立运行的最小单位,也是程序中能并发执行的一段指令序列线程是进程的一部分,一个进程可以包含多个线程,这些线程共享进程资源进程有线程入口,也可以创建更多的线程 2.…...

Java---泛型讲解
文章目录 1. 泛型类2. 泛型方法3. 泛型接口4. 类型通配符5. 可变参数6. 可变参数的使用 1. 泛型类 1. 格式:修饰符 class 类名 <类型>{ }。例如:public class Generic <T>{ }。 2. 代码块举例: public class Generic <T>{…...

【论文阅读笔记】SegVol: Universal and Interactive Volumetric Medical Image Segmentation
Du Y, Bai F, Huang T, et al. SegVol: Universal and Interactive Volumetric Medical Image Segmentation[J]. arXiv preprint arXiv:2311.13385, 2023.[代码开源] 【论文概述】 本文思路借鉴于自然图像分割领域的SAM,介绍了一种名为SegVol的先进医学图像分割模型…...
Unix/Linux操作系统介绍
1、Unix/Linux操作系统介绍 1.1、操作系统的作用 1)操作系统的目标 方便:使计算机系统易于使用有效:以更有效的方式使用计算机系统资源扩展:方便用户有效开发、测试、引进新功能 2)操作系统的地位 操作系统在计算…...

什么是https证书?
HTTPS证书,也称为SSL(Secure Sockets Layer)证书或TLS(Transport Layer Security)证书,是一种数字证书,用于在网络上建立安全的加密连接。它的主要目的是确保在互联网上进行的数据传输的安全性和…...

C++ DAY2作业
1.课堂struct练习,用class; #include <iostream>using namespace std;class Stu { private:int age;char sex;int high; public:double score;void set_values(int a,char b,int c,double d);int get_age();char get_sex();int get_high(); }; vo…...

RabbitMQ核心概念记录
本文来记录下RabbitMQ核心概念 文章目录 什么叫消息队列为何用消息队列RabbitMQ简介RabbitMQ基本概念RabbitMQ 特点具体特点包括 Rabbitmq的工作过程RabbitMQ集群RabbitMQ 的集群节点包括Rabbit 模式大概分为以下三种单一模式普通模式镜像模式 本文小结 什么叫消息队列 消息&am…...

算法时间空间复杂度计算—空间复杂度
算法时间空间复杂度计算—空间复杂度 空间复杂度定义影响空间复杂度的因素算法在运行过程中临时占用的存储空间讲解 计算方法例子1、空间算法的常数阶2、空间算法的线性阶(递归算法)3、二分查找分析方法一(迭代法)方法二ÿ…...
计算机专业校招常见面试题目总结
博主面试岗位包括:java开发、软件测试、测试开发等岗位,基于之前经历的面试总结出的一些常见题目。仅供参考,互相学习!! 八股:java开发、测试、测开岗位 Java技术栈:Java基础、JVM、数据结构、…...

网络编程『简易TCP网络程序』
🔭个人主页: 北 海 🛜所属专栏: Linux学习之旅、神奇的网络世界 💻操作环境: CentOS 7.6 阿里云远程服务器 文章目录 🌤️前言🌦️正文TCP网络程序1.字符串回响1.1.核心功能1.2.程序…...

java itext5 生成PDF并填充数据导出
java itext5 生成PDF并填充数据导出 依赖**文本勾选框****页眉**,**页脚****图片**实际图 主要功能有文本勾选框,页眉,页脚,图片等功能。肯定没有专业软件画的好看,只是一点儿方法。仅供参考。 依赖 <!--pdf-->&…...

如何配置TLSv1.2版本的ssl
1、tomcat配置TLSv1.2版本的ssl 如下图所示,打开tomcat\conf\server.xml文件,进行如下配置: 注意:需要将申请的tomcat版本的ssl认证文件,如server.jks存放到tomcat\conf\ssl_file\目录下。 <Connector port"1…...
在CentOS 7上使用普通用户`minio`安装和配置MinIO
指定控制台端口号6901 以下是在CentOS 7上使用普通用户minio安装和配置MinIO的完整步骤,包括设置密码、设置开机自启动,以及使用minio用户启动和关闭服务的过程: 创建MinIO用户: sudo useradd -m minio sudo passwd minio这将创建一个可以登录…...

Vue3-27-路由-路径参数的简单使用
什么是路径参数 在路由配置中,可以将【参数】放在【路由路径】中, 从而实现,同一个 路由,同一个组件,因路径参数不同,可以渲染出不同的内容。特点 : 1、当携带不同路径参数的路由相互跳转时&am…...
RestClient
什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端,它允许HTTP与Elasticsearch 集群通信,而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级ÿ…...
零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?
一、核心优势:专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发,是一款收费低廉但功能全面的Windows NAS工具,主打“无学习成本部署” 。与其他NAS软件相比,其优势在于: 无需硬件改造:将任意W…...

地震勘探——干扰波识别、井中地震时距曲线特点
目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波:可以用来解决所提出的地质任务的波;干扰波:所有妨碍辨认、追踪有效波的其他波。 地震勘探中,有效波和干扰波是相对的。例如,在反射波…...

Qt/C++开发监控GB28181系统/取流协议/同时支持udp/tcp被动/tcp主动
一、前言说明 在2011版本的gb28181协议中,拉取视频流只要求udp方式,从2016开始要求新增支持tcp被动和tcp主动两种方式,udp理论上会丢包的,所以实际使用过程可能会出现画面花屏的情况,而tcp肯定不丢包,起码…...
Admin.Net中的消息通信SignalR解释
定义集线器接口 IOnlineUserHub public interface IOnlineUserHub {/// 在线用户列表Task OnlineUserList(OnlineUserList context);/// 强制下线Task ForceOffline(object context);/// 发布站内消息Task PublicNotice(SysNotice context);/// 接收消息Task ReceiveMessage(…...
Qt Widget类解析与代码注释
#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...

python执行测试用例,allure报乱码且未成功生成报告
allure执行测试用例时显示乱码:‘allure’ �����ڲ����ⲿ���Ҳ���ǿ�&am…...
在Ubuntu24上采用Wine打开SourceInsight
1. 安装wine sudo apt install wine 2. 安装32位库支持,SourceInsight是32位程序 sudo dpkg --add-architecture i386 sudo apt update sudo apt install wine32:i386 3. 验证安装 wine --version 4. 安装必要的字体和库(解决显示问题) sudo apt install fonts-wqy…...

Ubuntu系统多网卡多相机IP设置方法
目录 1、硬件情况 2、如何设置网卡和相机IP 2.1 万兆网卡连接交换机,交换机再连相机 2.1.1 网卡设置 2.1.2 相机设置 2.3 万兆网卡直连相机 1、硬件情况 2个网卡n个相机 电脑系统信息,系统版本:Ubuntu22.04.5 LTS;内核版本…...
LLaMA-Factory 微调 Qwen2-VL 进行人脸情感识别(二)
在上一篇文章中,我们详细介绍了如何使用LLaMA-Factory框架对Qwen2-VL大模型进行微调,以实现人脸情感识别的功能。本篇文章将聚焦于微调完成后,如何调用这个模型进行人脸情感识别的具体代码实现,包括详细的步骤和注释。 模型调用步骤 环境准备:确保安装了必要的Python库。…...