基于Pytorch实现图像分类——基于jupyter
分类任务
- 网络基本构建与训练方法,常用函数解
- torch.nn.functional模块
- nn.Module模块
MNIST数据集下载
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)
解压数据集
import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
查阅数据
from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
from matplotlib import pyplot
import numpy as np
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)

网络模型搭建

import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

常用函数介绍
import torch.nn.functional as Floss_func = F.cross_entropydef model(xb):return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs] # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float, requires_grad = True)
bs = 64
bias = torch.zeros(10, requires_grad=True)print(loss_func(model(xb), yb))
模型搭建
from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
print(net)
Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)
for name, parameter in net.named_parameters():print(name, parameter,parameter.size())
dataset数据接口
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)
- 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
- 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
from torch import optim
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
相关文章:
基于Pytorch实现图像分类——基于jupyter
分类任务 网络基本构建与训练方法,常用函数解torch.nn.functional模块nn.Module模块 MNIST数据集下载 from pathlib import Path import requestsDATA_PATH Path("data") PATH DATA_PATH / "mnist"PATH.mkdir(parentsTrue, exist_okTrue)U…...
如何将CSDN的文章以PDF文件形式保存到本地
1.F12 打开开发者工具窗口 2.console下输入命令 (function(){$("#side").remove();$("#comment_title, #comment_list, #comment_bar, #comment_form, .announce, #ad_cen, #ad_bot").remove();$(".nav_top_2011, #header, #navigator").remove…...
面试经典150题——删除有序数组中的重复项
面试经典150题 day3 题目来源我的题解方法一 双指针 题目来源 力扣每日一题;题序:26 我的题解 方法一 双指针 使用两个指针分别指向相同元素的左右边界,再利用一个count记录最终需要的数组长度。 时间复杂度:O(n) 空间复杂度&a…...
Unity3D知识点精华浓缩
一、细节 1、类与组件的关系 2、Time.deltaTime的含义 3、怎么表示一帧的移动距离 4、Update和LateUpdate的区别和适用场景 5、找游戏对象的方式(别的对象 / 当前对象的子对象) 6、组件1调用组件2中方法的方式 7、在面板中获取外部数据的方法 8、序列化属…...
HTML的文档说明
1.告诉浏览器当前网页的版本 2.写法: !以前的写法:要依据网页的HTML的版本去确定,紫萼发油很多很多。 具体的写法可以参考:W3C官网的文档说明 !新写法:W3C都推荐用h5的写法 <DOCTYPE ht…...
ubuntu 更新或更改GCC/G++
最近遇到一些问题,需要用到gcc-9/g-9,但是我自带的ubuntu18.04是gcc-7.5/g-7.5,所以升级一下,奈何文章太多而且很多无效,所以在此记录一下: 参考:https://stackoverflow.com/questions/19836858…...
Java --- Java语言基础
这个Java可是个好东西,是一门面对对象的程序设计语言,其语法很类似C,所以学过C的伙伴们就很好上手,另外Java对C进行了简化与提高,这个在后期学习会感受到,Java还有很多的类库API文档以及第三方开发包。 这…...
【C++算法竞赛 · 图论】图的存储
前言 图的存储 邻接矩阵 方法 复杂度 应用 例题 题解 邻接表 方法 复杂度 应用 前言 上一篇文章中(【C算法竞赛 图论】图论基础),介绍了图论相关的概念和一种图的存储的方法,这篇文章将会介绍剩下的两种方法ÿ…...
Spring AOP IOC
spring的优缺点 IOC集中管理对象,对象之间解耦,方便维护对象AOP在不修改原代码的情况下,实现一些拦截提供众多辅助类,方便开发方便集成各种优秀框架 紧耦合和松耦合 松耦合可以使用单一职责原则、接口分离原则、依赖倒置原则 …...
Linux ARM平台开发系列讲解(QEMU篇) 1.1 编译QEMU 构建RISC-V64架构 运行Linux kernel
1. 概述 QEMU可以模拟很多架构的CPU(ARM,RISC-V等),重点是免费,用来学Linux简直太适合不过了,所以,我打算开一章节来教QEMU的使用,这样也方便环境统一调试,本章节就讲解如何在Ubuntu搭建QEMU,我的环境是全新的Ubuntu22,QEMU下载的9.0,kernel下载的6.0. 2. 源码下载…...
Day19-【Java SE进阶】网络编程
一、网络编程 1.概述 可以让设备中的程序与网络上其他设备中的程序进行数据交互(实现网络通信的)。java.net,*包下提供了网络编程的解决方案! 基本的通信架构 基本的通信架构有2种形式:CS架构(Client客户端/Server服务端)、BS架构(Browser浏览器/Server服务端)。 网络通信的…...
pyqt写个星三角降压启动方式2
星三角降压启动用可以用类进行封装,就像博图FB块那样。把逻辑都在类里完成,和外界需要交互的暴露出接口。测试过程中,发现类中直接用定时器QTimer会出现问题。然后就把定时器放到外面了。然后测试功能正常。 from PySide6.QtWidgets import …...
js可视化爬取数据生成当前热点词汇图
功能 可以爬取到很多数据,并且生成当前的热点词汇图,词越大越热门(词云图) 这里以b站某个评论区的数据为例,爬取63448条数据生成这样的图片 让我们能够更加直观的看到当前的热点 git地址 可以直接使用,中文…...
研发岗-面临统信UOS系统配置总结
第一步 获取root权限 配置环境等都需要用到root权限,所以我们先获取到root权限,方便下面的操作 下载软件 在UOS应用商店下载的所需应用 版本都比较低 安装node 官网下载了【arm64】的包,解压到指定文件夹,设置链接࿰…...
【STL详解 —— list的介绍及使用】
STL详解 —— list的介绍及使用 list的介绍list的介绍使用list的构造list iterator的使用list capacitylist element accesslist modifiers 示例list的迭代器失效 list的介绍 list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭…...
cocos creator开发中遇到的问题和解决方案
前言 总结一下使用cocos开发遇到的坑,不定期更新。 问题汇总 代码修改Position坐标不生效 首先要通过打log或者断点排除下是不是逻辑上的问题,还有是不是有动画相关把位置修改了。我遇到的问题是坐标修改被widget组件覆盖了。 纹理压缩包体变大 co…...
10分钟带你学会配置DNS服务正反向解析
正向解析 服务端IP客户端IP网址192.168.160.134192.168.160.135www.openlab.com 一、首先做准备工作: 关闭安全软件,关闭防火墙,下载bind软件 [rootserver ~]# setenforce 0 [rootserver ~]# systemctl stop firewalld [rootserver ~]# y…...
【vim 学习系列文章 19 -- 映射快捷键调用两个函数 A 和B】
请阅读【嵌入式开发学习必备专栏 之 Vim】 文章目录 映射快捷键调用两个函数 映射快捷键调用两个函数 在 Vim 中,如果想通过按下 gcm 来调用两个函数,比如 FunctionA 和 FunctionB,需要先定义这两个函数,然后创建一个映射。这个映…...
Windows安装MongoDB结合内网穿透轻松实现公网访问本地数据库
💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...
sgg大数据全套技术链接[plus]
写在开头:感谢尚硅谷,尚硅谷万岁,我爱尚硅谷 111个技术栈43个项目,兄弟们,冲! 最近小米又又又火了一把,致敬所有造福人民的企业和伟大的企业家,致敬雷军,小米ÿ…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...
【网络安全产品大调研系列】2. 体验漏洞扫描
前言 2023 年漏洞扫描服务市场规模预计为 3.06(十亿美元)。漏洞扫描服务市场行业预计将从 2024 年的 3.48(十亿美元)增长到 2032 年的 9.54(十亿美元)。预测期内漏洞扫描服务市场 CAGR(增长率&…...
条件运算符
C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...
家政维修平台实战20:权限设计
目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系,主要是分成几个表,用户表我们是记录用户的基础信息,包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题,不同的角色…...
【Go】3、Go语言进阶与依赖管理
前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课,做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程,它的核心机制是 Goroutine 协程、Channel 通道,并基于CSP(Communicating Sequential Processes࿰…...
微信小程序云开发平台MySQL的连接方式
注:微信小程序云开发平台指的是腾讯云开发 先给结论:微信小程序云开发平台的MySQL,无法通过获取数据库连接信息的方式进行连接,连接只能通过云开发的SDK连接,具体要参考官方文档: 为什么? 因为…...
CSS设置元素的宽度根据其内容自动调整
width: fit-content 是 CSS 中的一个属性值,用于设置元素的宽度根据其内容自动调整,确保宽度刚好容纳内容而不会超出。 效果对比 默认情况(width: auto): 块级元素(如 <div>)会占满父容器…...
SQL Server 触发器调用存储过程实现发送 HTTP 请求
文章目录 需求分析解决第 1 步:前置条件,启用 OLE 自动化方式 1:使用 SQL 实现启用 OLE 自动化方式 2:Sql Server 2005启动OLE自动化方式 3:Sql Server 2008启动OLE自动化第 2 步:创建存储过程第 3 步:创建触发器扩展 - 如何调试?第 1 步:登录 SQL Server 2008第 2 步…...
沙箱虚拟化技术虚拟机容器之间的关系详解
问题 沙箱、虚拟化、容器三者分开一一介绍的话我知道他们各自都是什么东西,但是如果把三者放在一起,它们之间到底什么关系?又有什么联系呢?我不是很明白!!! 就比如说: 沙箱&#…...
Sklearn 机器学习 缺失值处理 获取填充失值的统计值
💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 使用 Scikit-learn 处理缺失值并提取填充统计信息的完整指南 在机器学习项目中,数据清…...
