【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别
【Pytorch】学习记录分享5——PyTorch经典网络 ResNet
- 1. ResNet (残差网络)基础知识
- 2. 感受野
- 3. 手写体数字识别
- 3. 0 数据集(训练与测试集)
- 3. 1 数据加载
- 3. 2 函数实现:
- 3. 3 训练及其测试:
1. ResNet (残差网络)基础知识
图1 56层error比20层error高,提出ResNet (残差网络)的方案
网络效果:
网络结构:
2. 感受野
3. 手写体数字识别
3. 0 数据集(训练与测试集)
mnist 用于手写体训练与测试,这里包含完整的链接
3. 1 数据加载
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
### 首先读取数据
# - 分别构建训练集和测试集(验证集)
# - DataLoader来迭代取数据# 定义超参数
input_size = 28 #图像的总尺寸28*28
num_classes = 10 #标签的种类数
num_epochs = 3 #训练的总循环周期
batch_size = 64 #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
3. 2 函数实现:
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential( # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1, # 灰度图out_channels=16, # 要得到几多少个特征图kernel_size=5, # 卷积核大小stride=1, # 步长padding=2, # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1), # 输出的特征图为 (16, 28, 28)nn.ReLU(), # relu层nn.MaxPool2d(kernel_size=2), # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential( # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2), # 输出 (32, 14, 14)nn.ReLU(), # relu层nn.MaxPool2d(2), # 输出 (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10) # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1) # flatten操作,结果为:(batch_size, 32 * 7 * 7) output = self.out(x)return output# 准确率作为评估标准
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels)
3. 3 训练及其测试:
# 训练网络模型
# 实例化
net = CNN()
#损失函数
criterion = nn.CrossEntropyLoss()
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = []for batch_idx, (data, target) in enumerate(train_loader): #针对容器中的每一个批进行循环net.train() # 将模型设置为训练模式output = net(data) # 使用模型进行前向传播loss = criterion(output, target) # 计算损失optimizer.zero_grad() # 梯度清零loss.backward() # 反向传播计算梯度optimizer.step() # 更新参数right = accuracy(output, target) # 计算当前批次的准确率train_rights.append(right) # 将准确率保存起来if batch_idx % 500 == 0: # 每500个批次进行一次验证net.eval() # 将模型设置为评估模式val_rights = [] # 存储验证集的准确率for (data, target) in test_loader: # 在测试集上进行验证output = net(data) # 使用模型进行前向传播right = accuracy(output, target) # 计算验证集上的准确率val_rights.append(right) # 将准确率保存起来#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights])) # 计算训练集准确率的分子和分母val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights])) # 计算验证集准确率的分子和分母print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1],100. * val_r[0].numpy() / val_r[1])) # 打印当前进度和准确率信息
相关文章:

【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别
【Pytorch】学习记录分享5——PyTorch经典网络 ResNet 1. ResNet (残差网络)基础知识2. 感受野3. 手写体数字识别3. 0 数据集(训练与测试集)3. 1 数据加载3. 2 函数实现:3. 3 训练及其测试: 1. ResNet &…...

Flink1.17实战教程(第三篇:时间和窗口)
系列文章目录 Flink1.17实战教程(第一篇:概念、部署、架构) Flink1.17实战教程(第二篇:DataStream API) Flink1.17实战教程(第三篇:时间和窗口) Flink1.17实战教程&…...

CSS 纵向扩展动画
上干货 <template><!-- mouseenter"startAnimation" 表示在鼠标进入元素时触发 startAnimation 方法。mouseleave"stopAnimation" 表示在鼠标离开元素时触发 stopAnimation 方法。 --><!-- 容器元素 --><div class"container&q…...
Android 12 Token 机制
一、前言 在 android framework 框架中 activity 和 window 是相互关联的,而他们的管理者 AMS 和 WMS 是怎么来实现这种关联关系的,答案就是通过 token。 首先大家需要了解一下 LayoutParams,当然属性很多,简单了解即可…...
TCP与UDP是流式传输协议吗?
TCP(传输控制协议)和UDP(用户数据报协议)是两种主要的传输层协议,它们用于在网络中传输数据。它们不是流式传输协议,而是提供了不同的数据传输特性: 1. TCP(传输控制协议࿰…...
61 贪心算法解救生艇问题
问题描述:第i个人的体重为peaple[i],每个船可以承载的最大重量为limit。每艘船最多可以同时载两人,但条件是这些人的重量之和最多为limit,返回载到每一个人多虚的最小船数,(保证每个人被船载)。 贪心算法求解:先将数组…...

C#高级 01.Net多线程
一.基本概念 1.什么是线程? 线程是操作系统中能独立运行的最小单位,也是程序中能并发执行的一段指令序列线程是进程的一部分,一个进程可以包含多个线程,这些线程共享进程资源进程有线程入口,也可以创建更多的线程 2.…...

Java---泛型讲解
文章目录 1. 泛型类2. 泛型方法3. 泛型接口4. 类型通配符5. 可变参数6. 可变参数的使用 1. 泛型类 1. 格式:修饰符 class 类名 <类型>{ }。例如:public class Generic <T>{ }。 2. 代码块举例: public class Generic <T>{…...

【论文阅读笔记】SegVol: Universal and Interactive Volumetric Medical Image Segmentation
Du Y, Bai F, Huang T, et al. SegVol: Universal and Interactive Volumetric Medical Image Segmentation[J]. arXiv preprint arXiv:2311.13385, 2023.[代码开源] 【论文概述】 本文思路借鉴于自然图像分割领域的SAM,介绍了一种名为SegVol的先进医学图像分割模型…...
Unix/Linux操作系统介绍
1、Unix/Linux操作系统介绍 1.1、操作系统的作用 1)操作系统的目标 方便:使计算机系统易于使用有效:以更有效的方式使用计算机系统资源扩展:方便用户有效开发、测试、引进新功能 2)操作系统的地位 操作系统在计算…...

什么是https证书?
HTTPS证书,也称为SSL(Secure Sockets Layer)证书或TLS(Transport Layer Security)证书,是一种数字证书,用于在网络上建立安全的加密连接。它的主要目的是确保在互联网上进行的数据传输的安全性和…...

C++ DAY2作业
1.课堂struct练习,用class; #include <iostream>using namespace std;class Stu { private:int age;char sex;int high; public:double score;void set_values(int a,char b,int c,double d);int get_age();char get_sex();int get_high(); }; vo…...

RabbitMQ核心概念记录
本文来记录下RabbitMQ核心概念 文章目录 什么叫消息队列为何用消息队列RabbitMQ简介RabbitMQ基本概念RabbitMQ 特点具体特点包括 Rabbitmq的工作过程RabbitMQ集群RabbitMQ 的集群节点包括Rabbit 模式大概分为以下三种单一模式普通模式镜像模式 本文小结 什么叫消息队列 消息&am…...

算法时间空间复杂度计算—空间复杂度
算法时间空间复杂度计算—空间复杂度 空间复杂度定义影响空间复杂度的因素算法在运行过程中临时占用的存储空间讲解 计算方法例子1、空间算法的常数阶2、空间算法的线性阶(递归算法)3、二分查找分析方法一(迭代法)方法二ÿ…...
计算机专业校招常见面试题目总结
博主面试岗位包括:java开发、软件测试、测试开发等岗位,基于之前经历的面试总结出的一些常见题目。仅供参考,互相学习!! 八股:java开发、测试、测开岗位 Java技术栈:Java基础、JVM、数据结构、…...

网络编程『简易TCP网络程序』
🔭个人主页: 北 海 🛜所属专栏: Linux学习之旅、神奇的网络世界 💻操作环境: CentOS 7.6 阿里云远程服务器 文章目录 🌤️前言🌦️正文TCP网络程序1.字符串回响1.1.核心功能1.2.程序…...

java itext5 生成PDF并填充数据导出
java itext5 生成PDF并填充数据导出 依赖**文本勾选框****页眉**,**页脚****图片**实际图 主要功能有文本勾选框,页眉,页脚,图片等功能。肯定没有专业软件画的好看,只是一点儿方法。仅供参考。 依赖 <!--pdf-->&…...

如何配置TLSv1.2版本的ssl
1、tomcat配置TLSv1.2版本的ssl 如下图所示,打开tomcat\conf\server.xml文件,进行如下配置: 注意:需要将申请的tomcat版本的ssl认证文件,如server.jks存放到tomcat\conf\ssl_file\目录下。 <Connector port"1…...
在CentOS 7上使用普通用户`minio`安装和配置MinIO
指定控制台端口号6901 以下是在CentOS 7上使用普通用户minio安装和配置MinIO的完整步骤,包括设置密码、设置开机自启动,以及使用minio用户启动和关闭服务的过程: 创建MinIO用户: sudo useradd -m minio sudo passwd minio这将创建一个可以登录…...

Vue3-27-路由-路径参数的简单使用
什么是路径参数 在路由配置中,可以将【参数】放在【路由路径】中, 从而实现,同一个 路由,同一个组件,因路径参数不同,可以渲染出不同的内容。特点 : 1、当携带不同路径参数的路由相互跳转时&am…...
《Playwright:微软的自动化测试工具详解》
Playwright 简介:声明内容来自网络,将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具,支持 Chrome、Firefox、Safari 等主流浏览器,提供多语言 API(Python、JavaScript、Java、.NET)。它的特点包括&a…...
Linux简单的操作
ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...
大语言模型如何处理长文本?常用文本分割技术详解
为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...
python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)
更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...
Spring AI与Spring Modulith核心技术解析
Spring AI核心架构解析 Spring AI(https://spring.io/projects/spring-ai)作为Spring生态中的AI集成框架,其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似,但特别为多语…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...

【从零学习JVM|第三篇】类的生命周期(高频面试题)
前言: 在Java编程中,类的生命周期是指类从被加载到内存中开始,到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期,让读者对此有深刻印象。 目录 …...
JS手写代码篇----使用Promise封装AJAX请求
15、使用Promise封装AJAX请求 promise就有reject和resolve了,就不必写成功和失败的回调函数了 const BASEURL ./手写ajax/test.jsonfunction promiseAjax() {return new Promise((resolve, reject) > {const xhr new XMLHttpRequest();xhr.open("get&quo…...
uniapp 实现腾讯云IM群文件上传下载功能
UniApp 集成腾讯云IM实现群文件上传下载功能全攻略 一、功能背景与技术选型 在团队协作场景中,群文件共享是核心需求之一。本文将介绍如何基于腾讯云IMCOS,在uniapp中实现: 群内文件上传/下载文件元数据管理下载进度追踪跨平台文件预览 二…...

自然语言处理——文本分类
文本分类 传统机器学习方法文本表示向量空间模型 特征选择文档频率互信息信息增益(IG) 分类器设计贝叶斯理论:线性判别函数 文本分类性能评估P-R曲线ROC曲线 将文本文档或句子分类为预定义的类或类别, 有单标签多类别文本分类和多…...