基于Pytorch实现图像分类——基于jupyter
分类任务
- 网络基本构建与训练方法,常用函数解
- torch.nn.functional模块
- nn.Module模块
MNIST数据集下载
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)
解压数据集
import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
查阅数据
from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
from matplotlib import pyplot
import numpy as np
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)

网络模型搭建

import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

常用函数介绍
import torch.nn.functional as Floss_func = F.cross_entropydef model(xb):return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs] # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float, requires_grad = True)
bs = 64
bias = torch.zeros(10, requires_grad=True)print(loss_func(model(xb), yb))
模型搭建
from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
print(net)
Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)
for name, parameter in net.named_parameters():print(name, parameter,parameter.size())
dataset数据接口
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)
- 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
- 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
from torch import optim
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
相关文章:
基于Pytorch实现图像分类——基于jupyter
分类任务 网络基本构建与训练方法,常用函数解torch.nn.functional模块nn.Module模块 MNIST数据集下载 from pathlib import Path import requestsDATA_PATH Path("data") PATH DATA_PATH / "mnist"PATH.mkdir(parentsTrue, exist_okTrue)U…...
如何将CSDN的文章以PDF文件形式保存到本地
1.F12 打开开发者工具窗口 2.console下输入命令 (function(){$("#side").remove();$("#comment_title, #comment_list, #comment_bar, #comment_form, .announce, #ad_cen, #ad_bot").remove();$(".nav_top_2011, #header, #navigator").remove…...
面试经典150题——删除有序数组中的重复项
面试经典150题 day3 题目来源我的题解方法一 双指针 题目来源 力扣每日一题;题序:26 我的题解 方法一 双指针 使用两个指针分别指向相同元素的左右边界,再利用一个count记录最终需要的数组长度。 时间复杂度:O(n) 空间复杂度&a…...
Unity3D知识点精华浓缩
一、细节 1、类与组件的关系 2、Time.deltaTime的含义 3、怎么表示一帧的移动距离 4、Update和LateUpdate的区别和适用场景 5、找游戏对象的方式(别的对象 / 当前对象的子对象) 6、组件1调用组件2中方法的方式 7、在面板中获取外部数据的方法 8、序列化属…...
HTML的文档说明
1.告诉浏览器当前网页的版本 2.写法: !以前的写法:要依据网页的HTML的版本去确定,紫萼发油很多很多。 具体的写法可以参考:W3C官网的文档说明 !新写法:W3C都推荐用h5的写法 <DOCTYPE ht…...
ubuntu 更新或更改GCC/G++
最近遇到一些问题,需要用到gcc-9/g-9,但是我自带的ubuntu18.04是gcc-7.5/g-7.5,所以升级一下,奈何文章太多而且很多无效,所以在此记录一下: 参考:https://stackoverflow.com/questions/19836858…...
Java --- Java语言基础
这个Java可是个好东西,是一门面对对象的程序设计语言,其语法很类似C,所以学过C的伙伴们就很好上手,另外Java对C进行了简化与提高,这个在后期学习会感受到,Java还有很多的类库API文档以及第三方开发包。 这…...
【C++算法竞赛 · 图论】图的存储
前言 图的存储 邻接矩阵 方法 复杂度 应用 例题 题解 邻接表 方法 复杂度 应用 前言 上一篇文章中(【C算法竞赛 图论】图论基础),介绍了图论相关的概念和一种图的存储的方法,这篇文章将会介绍剩下的两种方法ÿ…...
Spring AOP IOC
spring的优缺点 IOC集中管理对象,对象之间解耦,方便维护对象AOP在不修改原代码的情况下,实现一些拦截提供众多辅助类,方便开发方便集成各种优秀框架 紧耦合和松耦合 松耦合可以使用单一职责原则、接口分离原则、依赖倒置原则 …...
Linux ARM平台开发系列讲解(QEMU篇) 1.1 编译QEMU 构建RISC-V64架构 运行Linux kernel
1. 概述 QEMU可以模拟很多架构的CPU(ARM,RISC-V等),重点是免费,用来学Linux简直太适合不过了,所以,我打算开一章节来教QEMU的使用,这样也方便环境统一调试,本章节就讲解如何在Ubuntu搭建QEMU,我的环境是全新的Ubuntu22,QEMU下载的9.0,kernel下载的6.0. 2. 源码下载…...
Day19-【Java SE进阶】网络编程
一、网络编程 1.概述 可以让设备中的程序与网络上其他设备中的程序进行数据交互(实现网络通信的)。java.net,*包下提供了网络编程的解决方案! 基本的通信架构 基本的通信架构有2种形式:CS架构(Client客户端/Server服务端)、BS架构(Browser浏览器/Server服务端)。 网络通信的…...
pyqt写个星三角降压启动方式2
星三角降压启动用可以用类进行封装,就像博图FB块那样。把逻辑都在类里完成,和外界需要交互的暴露出接口。测试过程中,发现类中直接用定时器QTimer会出现问题。然后就把定时器放到外面了。然后测试功能正常。 from PySide6.QtWidgets import …...
js可视化爬取数据生成当前热点词汇图
功能 可以爬取到很多数据,并且生成当前的热点词汇图,词越大越热门(词云图) 这里以b站某个评论区的数据为例,爬取63448条数据生成这样的图片 让我们能够更加直观的看到当前的热点 git地址 可以直接使用,中文…...
研发岗-面临统信UOS系统配置总结
第一步 获取root权限 配置环境等都需要用到root权限,所以我们先获取到root权限,方便下面的操作 下载软件 在UOS应用商店下载的所需应用 版本都比较低 安装node 官网下载了【arm64】的包,解压到指定文件夹,设置链接࿰…...
【STL详解 —— list的介绍及使用】
STL详解 —— list的介绍及使用 list的介绍list的介绍使用list的构造list iterator的使用list capacitylist element accesslist modifiers 示例list的迭代器失效 list的介绍 list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭…...
cocos creator开发中遇到的问题和解决方案
前言 总结一下使用cocos开发遇到的坑,不定期更新。 问题汇总 代码修改Position坐标不生效 首先要通过打log或者断点排除下是不是逻辑上的问题,还有是不是有动画相关把位置修改了。我遇到的问题是坐标修改被widget组件覆盖了。 纹理压缩包体变大 co…...
10分钟带你学会配置DNS服务正反向解析
正向解析 服务端IP客户端IP网址192.168.160.134192.168.160.135www.openlab.com 一、首先做准备工作: 关闭安全软件,关闭防火墙,下载bind软件 [rootserver ~]# setenforce 0 [rootserver ~]# systemctl stop firewalld [rootserver ~]# y…...
【vim 学习系列文章 19 -- 映射快捷键调用两个函数 A 和B】
请阅读【嵌入式开发学习必备专栏 之 Vim】 文章目录 映射快捷键调用两个函数 映射快捷键调用两个函数 在 Vim 中,如果想通过按下 gcm 来调用两个函数,比如 FunctionA 和 FunctionB,需要先定义这两个函数,然后创建一个映射。这个映…...
Windows安装MongoDB结合内网穿透轻松实现公网访问本地数据库
💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...
sgg大数据全套技术链接[plus]
写在开头:感谢尚硅谷,尚硅谷万岁,我爱尚硅谷 111个技术栈43个项目,兄弟们,冲! 最近小米又又又火了一把,致敬所有造福人民的企业和伟大的企业家,致敬雷军,小米ÿ…...
SDXL-Turbo快速上手:AutoDL开箱即用,零配置体验实时AI绘画
SDXL-Turbo快速上手:AutoDL开箱即用,零配置体验实时AI绘画 1. 什么是SDXL-Turbo SDXL-Turbo是StabilityAI推出的新一代实时AI绘画模型,它彻底改变了传统AI绘画需要等待数秒甚至数十秒才能看到结果的工作方式。基于创新的对抗扩散蒸馏技术(A…...
2026降AI率工具红黑榜:降AI率网站怎么选?看完少走弯路
千笔AI、ThouPen、豆包位列红榜,精准适配国内高校AI率检测规范;黑榜需避开低质免费工具、无正规检测对接平台及改写痕迹明显的工具;选择时应优先匹配三维模型:降AI效果-学术合规性-使用成本。 一、红榜:10 款高分论文降…...
3大突破:重新定义Revit插件开发流程
3大突破:重新定义Revit插件开发流程 【免费下载链接】RevitAddInManager Revit AddinManager update .NET assemblies without restart Revit for developer. 项目地址: https://gitcode.com/gh_mirrors/re/RevitAddInManager 引言:Revit插件开发…...
XC6206-1.8V是什么?有哪些作用?
本文主要介绍XC6206-1.8V是什么?有哪些作用?XC6206-1.8V是一款超低功耗、高精度的固定输出低压差线性稳压器(LDO),核心作用是把较高电压转换成稳定的1.8V输出,专门为电池供电和低功耗设备设计。图文来源&am…...
aircrack-ng使用教程
aircrack-ng是一款用于无线网络安全评估的工具套件,主要用于破解WEP和WPA/WPA2-PSK加密的无线网络密码。它通过分析捕获的数据包,利用密码破解技术来获取网络密钥,是网络安全测试和渗透测试中常用的工具之一。该工具支持多种攻击模式和优化选…...
Wan2.2-I2V-A14B开源可部署:符合等保2.0要求,支持审计日志+访问控制
Wan2.2-I2V-A14B开源可部署:符合等保2.0要求,支持审计日志访问控制 1. 镜像概述与核心特性 Wan2.2-I2V-A14B是一款专为文生视频任务优化的私有部署镜像,基于RTX 4090D 24GB显存显卡和CUDA 12.4环境深度定制。本镜像不仅提供高性能的视频生成…...
统信系统下如何管理Mysql?
背景 看到标题很多朋友会打趣的问我:“你不是一直用麒麟操作系统做讲解吗?”,其实DBCS和DESK的兼容性太强了,什么操作系统都行,Windows上最容易了,所以我一般不用Windows,下次我用Ubuntu给大家…...
CasRel模型惊艳效果:同一实体对(马云-阿里巴巴)识别7种关系
CasRel模型惊艳效果:同一实体对(马云-阿里巴巴)识别7种关系 1. 关系抽取的神奇能力 你有没有遇到过这样的情况:阅读一篇关于企业家的报道时,想知道他和他的公司之间到底有哪些关系?是创始人?董…...
别再纠结在线辨识了!聊聊永磁同步电机(PMSM)离线参数自学习的完整流程与避坑指南
永磁同步电机离线参数辨识实战:从理论到工程落地的全流程解析 在电机控制领域,参数辨识一直是个让人又爱又恨的话题。尤其是当项目从实验室走向量产时,那些在仿真中运行良好的算法,往往会因为实际电机参数的偏差而表现失常。我曾亲…...
数智驱动 人才筑基——拔尖创新人才与卓越工程师培养论坛举行
3月22日,第二届高等院校新工科人才培养暨产教融合发展大会在北京举行。大会以“科技创新 智造未来”为主题,来自全国各地的本科院校、职业院校、行业企业以及媒体等1000余位嘉宾参会。22日下午,数智驱动 人才筑基——拔尖创新人才与卓越工程师…...
