基于Pytorch实现图像分类——基于jupyter
分类任务
- 网络基本构建与训练方法,常用函数解
- torch.nn.functional模块
- nn.Module模块
MNIST数据集下载
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)
解压数据集
import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
查阅数据
from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
from matplotlib import pyplot
import numpy as np
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
网络模型搭建
import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
常用函数介绍
import torch.nn.functional as Floss_func = F.cross_entropydef model(xb):return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs] # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float, requires_grad = True)
bs = 64
bias = torch.zeros(10, requires_grad=True)print(loss_func(model(xb), yb))
模型搭建
from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
print(net)
Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)
for name, parameter in net.named_parameters():print(name, parameter,parameter.size())
dataset数据接口
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)
- 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
- 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
from torch import optim
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
相关文章:

基于Pytorch实现图像分类——基于jupyter
分类任务 网络基本构建与训练方法,常用函数解torch.nn.functional模块nn.Module模块 MNIST数据集下载 from pathlib import Path import requestsDATA_PATH Path("data") PATH DATA_PATH / "mnist"PATH.mkdir(parentsTrue, exist_okTrue)U…...

如何将CSDN的文章以PDF文件形式保存到本地
1.F12 打开开发者工具窗口 2.console下输入命令 (function(){$("#side").remove();$("#comment_title, #comment_list, #comment_bar, #comment_form, .announce, #ad_cen, #ad_bot").remove();$(".nav_top_2011, #header, #navigator").remove…...
面试经典150题——删除有序数组中的重复项
面试经典150题 day3 题目来源我的题解方法一 双指针 题目来源 力扣每日一题;题序:26 我的题解 方法一 双指针 使用两个指针分别指向相同元素的左右边界,再利用一个count记录最终需要的数组长度。 时间复杂度:O(n) 空间复杂度&a…...
Unity3D知识点精华浓缩
一、细节 1、类与组件的关系 2、Time.deltaTime的含义 3、怎么表示一帧的移动距离 4、Update和LateUpdate的区别和适用场景 5、找游戏对象的方式(别的对象 / 当前对象的子对象) 6、组件1调用组件2中方法的方式 7、在面板中获取外部数据的方法 8、序列化属…...
HTML的文档说明
1.告诉浏览器当前网页的版本 2.写法: !以前的写法:要依据网页的HTML的版本去确定,紫萼发油很多很多。 具体的写法可以参考:W3C官网的文档说明 !新写法:W3C都推荐用h5的写法 <DOCTYPE ht…...
ubuntu 更新或更改GCC/G++
最近遇到一些问题,需要用到gcc-9/g-9,但是我自带的ubuntu18.04是gcc-7.5/g-7.5,所以升级一下,奈何文章太多而且很多无效,所以在此记录一下: 参考:https://stackoverflow.com/questions/19836858…...
Java --- Java语言基础
这个Java可是个好东西,是一门面对对象的程序设计语言,其语法很类似C,所以学过C的伙伴们就很好上手,另外Java对C进行了简化与提高,这个在后期学习会感受到,Java还有很多的类库API文档以及第三方开发包。 这…...
【C++算法竞赛 · 图论】图的存储
前言 图的存储 邻接矩阵 方法 复杂度 应用 例题 题解 邻接表 方法 复杂度 应用 前言 上一篇文章中(【C算法竞赛 图论】图论基础),介绍了图论相关的概念和一种图的存储的方法,这篇文章将会介绍剩下的两种方法ÿ…...
Spring AOP IOC
spring的优缺点 IOC集中管理对象,对象之间解耦,方便维护对象AOP在不修改原代码的情况下,实现一些拦截提供众多辅助类,方便开发方便集成各种优秀框架 紧耦合和松耦合 松耦合可以使用单一职责原则、接口分离原则、依赖倒置原则 …...
Linux ARM平台开发系列讲解(QEMU篇) 1.1 编译QEMU 构建RISC-V64架构 运行Linux kernel
1. 概述 QEMU可以模拟很多架构的CPU(ARM,RISC-V等),重点是免费,用来学Linux简直太适合不过了,所以,我打算开一章节来教QEMU的使用,这样也方便环境统一调试,本章节就讲解如何在Ubuntu搭建QEMU,我的环境是全新的Ubuntu22,QEMU下载的9.0,kernel下载的6.0. 2. 源码下载…...

Day19-【Java SE进阶】网络编程
一、网络编程 1.概述 可以让设备中的程序与网络上其他设备中的程序进行数据交互(实现网络通信的)。java.net,*包下提供了网络编程的解决方案! 基本的通信架构 基本的通信架构有2种形式:CS架构(Client客户端/Server服务端)、BS架构(Browser浏览器/Server服务端)。 网络通信的…...
pyqt写个星三角降压启动方式2
星三角降压启动用可以用类进行封装,就像博图FB块那样。把逻辑都在类里完成,和外界需要交互的暴露出接口。测试过程中,发现类中直接用定时器QTimer会出现问题。然后就把定时器放到外面了。然后测试功能正常。 from PySide6.QtWidgets import …...

js可视化爬取数据生成当前热点词汇图
功能 可以爬取到很多数据,并且生成当前的热点词汇图,词越大越热门(词云图) 这里以b站某个评论区的数据为例,爬取63448条数据生成这样的图片 让我们能够更加直观的看到当前的热点 git地址 可以直接使用,中文…...

研发岗-面临统信UOS系统配置总结
第一步 获取root权限 配置环境等都需要用到root权限,所以我们先获取到root权限,方便下面的操作 下载软件 在UOS应用商店下载的所需应用 版本都比较低 安装node 官网下载了【arm64】的包,解压到指定文件夹,设置链接࿰…...

【STL详解 —— list的介绍及使用】
STL详解 —— list的介绍及使用 list的介绍list的介绍使用list的构造list iterator的使用list capacitylist element accesslist modifiers 示例list的迭代器失效 list的介绍 list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭…...
cocos creator开发中遇到的问题和解决方案
前言 总结一下使用cocos开发遇到的坑,不定期更新。 问题汇总 代码修改Position坐标不生效 首先要通过打log或者断点排除下是不是逻辑上的问题,还有是不是有动画相关把位置修改了。我遇到的问题是坐标修改被widget组件覆盖了。 纹理压缩包体变大 co…...

10分钟带你学会配置DNS服务正反向解析
正向解析 服务端IP客户端IP网址192.168.160.134192.168.160.135www.openlab.com 一、首先做准备工作: 关闭安全软件,关闭防火墙,下载bind软件 [rootserver ~]# setenforce 0 [rootserver ~]# systemctl stop firewalld [rootserver ~]# y…...
【vim 学习系列文章 19 -- 映射快捷键调用两个函数 A 和B】
请阅读【嵌入式开发学习必备专栏 之 Vim】 文章目录 映射快捷键调用两个函数 映射快捷键调用两个函数 在 Vim 中,如果想通过按下 gcm 来调用两个函数,比如 FunctionA 和 FunctionB,需要先定义这两个函数,然后创建一个映射。这个映…...

Windows安装MongoDB结合内网穿透轻松实现公网访问本地数据库
💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...
sgg大数据全套技术链接[plus]
写在开头:感谢尚硅谷,尚硅谷万岁,我爱尚硅谷 111个技术栈43个项目,兄弟们,冲! 最近小米又又又火了一把,致敬所有造福人民的企业和伟大的企业家,致敬雷军,小米ÿ…...
[特殊字符] 智能合约中的数据是如何在区块链中保持一致的?
🧠 智能合约中的数据是如何在区块链中保持一致的? 为什么所有区块链节点都能得出相同结果?合约调用这么复杂,状态真能保持一致吗?本篇带你从底层视角理解“状态一致性”的真相。 一、智能合约的数据存储在哪里…...
质量体系的重要
质量体系是为确保产品、服务或过程质量满足规定要求,由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面: 🏛️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限,形成层级清晰的管理网络…...

srs linux
下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935,SRS管理页面端口是8080,可…...
在Ubuntu中设置开机自动运行(sudo)指令的指南
在Ubuntu系统中,有时需要在系统启动时自动执行某些命令,特别是需要 sudo权限的指令。为了实现这一功能,可以使用多种方法,包括编写Systemd服务、配置 rc.local文件或使用 cron任务计划。本文将详细介绍这些方法,并提供…...
Rapidio门铃消息FIFO溢出机制
关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系,以下是深入解析: 门铃FIFO溢出的本质 在RapidIO系统中,门铃消息FIFO是硬件控制器内部的缓冲区,用于临时存储接收到的门铃消息(Doorbell Message)。…...

安宝特案例丨Vuzix AR智能眼镜集成专业软件,助力卢森堡医院药房转型,赢得辉瑞创新奖
在Vuzix M400 AR智能眼镜的助力下,卢森堡罗伯特舒曼医院(the Robert Schuman Hospitals, HRS)凭借在无菌制剂生产流程中引入增强现实技术(AR)创新项目,荣获了2024年6月7日由卢森堡医院药剂师协会࿰…...

【分享】推荐一些办公小工具
1、PDF 在线转换 https://smallpdf.com/cn/pdf-tools 推荐理由:大部分的转换软件需要收费,要么功能不齐全,而开会员又用不了几次浪费钱,借用别人的又不安全。 这个网站它不需要登录或下载安装。而且提供的免费功能就能满足日常…...

Unity中的transform.up
2025年6月8日,周日下午 在Unity中,transform.up是Transform组件的一个属性,表示游戏对象在世界空间中的“上”方向(Y轴正方向),且会随对象旋转动态变化。以下是关键点解析: 基本定义 transfor…...

高考志愿填报管理系统---开发介绍
高考志愿填报管理系统是一款专为教育机构、学校和教师设计的学生信息管理和志愿填报辅助平台。系统基于Django框架开发,采用现代化的Web技术,为教育工作者提供高效、安全、便捷的学生管理解决方案。 ## 📋 系统概述 ### 🎯 系统定…...
跨平台商品数据接口的标准化与规范化发展路径:淘宝京东拼多多的最新实践
在电商行业蓬勃发展的当下,多平台运营已成为众多商家的必然选择。然而,不同电商平台在商品数据接口方面存在差异,导致商家在跨平台运营时面临诸多挑战,如数据对接困难、运营效率低下、用户体验不一致等。跨平台商品数据接口的标准…...