【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集
课程链接
1.读取图像分类数据集
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
#读取数据集
trans=transforms.ToTensor()
mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)
mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)
print('训练数据集:',len(mnist_train),'测试数据集:',len(mnist_test))
print('训练数据集图片大小:',mnist_train[0][0].shape)#两个可视化数据集的函数
def get_fashion_mnist_labels(labels): #返回fashion_mnist数据集的文本标签text_labels=['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels ]
def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5):figsize=(num_rows*scale,num_cols*scale)_,axes=d2l.plt.subplots(num_rows,num_cols,figsize=figsize)axes=axes.flatten()for i ,(ax,img) in enumerate(zip(axes,imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
#几个样本的图像及其相应的标签
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
d2l.plt.show()#读取一小批量数据,大小为batchsize
batch_size=256
def get_dataloader_workers(): #使用4个进程来读取数据return 4
train_iter=data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
timer=d2l.Timer()
for X,y in train_iter:continue
print(f'{timer.stop():.2f}sec')
# 便于重用函数
def load_data_fasion_mnist(batch_size,resize:None):trans = [transforms.ToTensor()]if resize:trans.insert(0,transforms.Resize(resize))trans=transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return(data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test,batch_size,shuffle=False,num_workers=get_dataloader_workers()))
运行结果
2.Softmax 回归从零开始实现
import torch
from IPython import display
from d2l import torch as d2l
import matplotlib.pyplot as plt
import torchvision
from torch.utils import data
from torchvision import transforms
import numpy as npbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)
num_inputs=784 #展平图像为向量
num_outputs=10 # 有10个类所以模型输出为10
w=torch.normal(0,0.01,size=(num_inputs,num_outputs),requires_grad=True)#定义权重w
b=torch.zeros(num_outputs,requires_grad=True)# 定义softmax
def softmax(X):X_exp=torch.exp(X)#对每个元素做指数运算partition =X_exp.sum(1,keepdim=True)#按照行求和return X_exp/partition #矩阵中的各个元素/对应行元素之和
#验证一下是否是正确的
X=torch.normal(0,0.01,(2,5))# 创建均值为0方差为1的两行五列的X
X_prob=softmax(X)
print('1.验证softmax:',X_prob,X_prob.sum(1))
#实现softmax回归模型
def net(X):return softmax(torch.matmul(X.reshape((-1,w.shape[0])),w)+b) # -1,每次喂数据的量,就是batchsizey=torch.tensor([0,2])
y_hat=torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])
print('2.根据标号拿出预测值:',y_hat[[0,1],y])
# 实现交叉熵损失
def cross_entropy(y_hat,y): #给定预测和真实标号Yreturn -torch.log(y_hat[range(len(y_hat)),y])# 锁定y轴在x轴上根据labels收取预测值,交叉熵损失中除了真值=1,其他都是0,这里直接算针织对应的预测概率
print('3.交叉熵损失:',cross_entropy(y_hat,y))#将预测类别与真实元素y进行比较
def accuracy(y_hat,y):if len(y_hat.shape)>1 and y_hat.shape[1]>1: #shape和列数大于1的时候y_hat=y_hat.argmax(axis=1)#把每一行元素最大的下标存到y_hatcmp=y_hat.type(y.dtype)==y #y_hat和y的数据类型转换,作比较变成布尔return float(cmp.type(y.dtype).sum())#转换成和y一样的形状求和
print('4.预测正确的概率:',accuracy(y_hat,y)/len(y))# 预测正确的样本数除以y的长度就是预测正确的概率#计算模型在数据迭代器上的精度
def evaluate_accuracy(net,data_iter):if isinstance(net,torch.nn.Module):net.eval()#将模型设置为评估模式,输入后得出的结果用来评估模型的准确率,不做反向传播metric =Accumulator(2) # 累加器for X,y in data_iter:metric.add(accuracy(net(X),y),y.numel())return metric[0]/metric[1] #返回分类正确的样本数和总样本数# accumulator的实现
class Accumulator: #作用是累加def __init__(self,n):self.data=[0.0]*ndef add(self,*args):self.data=[a+float(b) for a,b in zip(self.data,args)]def reset(self):self.data=[0.0]*len(self.data)def __getitem__(self, idx):return self.data[idx]
if __name__=='__main__':print(evaluate_accuracy(net,test_iter))# softmax回归的训练
def train_epoch_ch3(net,train_iter,loss,updater):if isinstance(net,torch.nn.Module):net.train()metric=Accumulator(3)for X,y in train_iter:y_hat=net(X)l=loss(y_hat,y)if isinstance(updater,torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()metric.add(float(l)*len(y),accuracy(y_hat,y),y.size().numel())else:l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y),y.numel())return metric[0]/metric[2],metric[1]/metric[2]class Animator:def __init__(self,xlabel=None,ylabel=None,legend=None,xlim=None,ylim=None,xscale='linear',yscale='linear',fmts=('-','m--','g-.','r:'),nrows=1,ncols=1,figsize=(3.5,2.5)):if legend is None:legend=[]d2l.use_svg_display()self.fig,self.axes=d2l.plt.subplots(nrows,ncols,figsize=figsize)if nrows*ncols==1:self.axes=[self.axes, ]self.config_axes=lambda :d2l.set_axes(self.axes[0],xlabel,ylabel,xlim,ylim,xscale,yscale,legend)self.X,self.Y,self.fmts=None,None,fmtsdef add(self,x,y):if not hasattr(y,"__len__"):y=[y]n=len(y)if not hasattr(x, "__len__"):x=[x]*nif not self.X:self.X=[[]for _ in range(n)]if not self.Y:self.Y=[[]for _ in range(n)]for i ,(a,b) in enumerate(zip(x,y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x,y,fmt in zip(self.X,self.Y,self.fmts):self.axes[0].plot(x,y,fmt)self.config_axes()plt.draw()plt.pause(0.001)display.display(self.fig)display.clear_output(wait=True)def train_ch3(net,train_iter,test_iter,loss,num_epochs,updater):animator=Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0.3,0.9],legend=['train loss','train acc','test acc'])for epoch in range(num_epochs):train_metrics=train_epoch_ch3(net,train_iter,loss,updater)test_acc=evaluate_accuracy(net,test_iter)animator.add(epoch+1, train_metrics+(test_acc,))train_loss,train_acc=train_metricslr = 0.1
def updater(batch_size):return d2l.sgd([w,b],lr,batch_size)if __name__ == '__main__':num_epochs=10train_ch3(net,train_iter,test_iter,cross_entropy,num_epochs,updater)# 对图像进行分类的预测def predict_ch3(net,test_iter,n=6):for X,y in test_iter:breaktrues=d2l.get_fashion_mnist_labels(y)preds=d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles=[true+'\n'+pred for true,pred in zip(trues,preds)]d2l.show_images(X[0:n].reshape((n,28,28)),1,n,titles=titles[0:n])d2l.plt.show()
if __name__ == '__main__':predict_ch3(net,test_iter)
运行结果
3.Softmax 回归简洁实现
import torch
from torch import nn
from d2l import torch as d2lbatch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)# 初始化模型参数
net =nn.Sequential(nn.Flatten(),nn.Linear(784,10))def init_weights(m):if type(m)==nn.Linear:nn.init.normal_(m.weight,std=0.01)
net.apply(init_weights);loss=nn.CrossEntropyLoss(reduction='none')
trainer=torch.optim.SGD(net.parameters(),lr=0.1)
num_epochs=10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)d2l.plt.show()
运行结果
相关文章:

【代码pycharm】动手学深度学习v2-09 Softmax 回归 + 损失函数 + 图片分类数据集
课程链接 1.读取图像分类数据集 import matplotlib.pyplot as plt import torch import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2l d2l.use_svg_display() #读取数据集 transtransforms.ToTensor() mnist_…...

设计模式:24、访问者模式
目录 0、定义 1、访问者模式的五种角色 2、访问者模式的UML类图 3、示例代码 0、定义 表示一个作用于某对象结构中的各个元素的操作。它可以在不改变各个元素的类的前提下,定义作用于这些元素的新操作。 1、访问者模式的五种角色 抽象元素(Element…...

基于JAVA的旅游网站系统设计
摘要 随着信息技术和网络技术的迅速发展,人们的生活质量和观念也在发生着改变,各地争相发展旅游业,传统的 旅游社已经无法满足人们的需求,旅游网站将突破传统在时间和地域的限制,成为方便、快捷、安全、可靠的旅游 方…...
网络安全产品之认识防火墙
防火墙是一种网络安全产品,它设置在不同网络(如可信任的企业内部网和不可信的公共网)或网络安全域之间,通过监测、限制、更改跨越防火墙的数据流,尽可能地对外部屏蔽网络内部的信息、结构和运行状况,以此来…...

nginx反向代理(负载均衡)和tomcat介绍
nginx的代理 负载均衡 负载均衡的算法 负载均衡的架构 基于ip的七层代理 upstream模块要写在http模块中 七层代理的调用要写在location模块中 轮询 加权轮询 最小连接数 ip_Hash URL_HASH 基于域名的七层代理 配置主机 给其余客户机配置域名 给所有机器做域名映射 四层代理…...
Microsoft Azure 在线技术公开课:生成式 AI 基础知识
课程介绍 参加我们的生成式 AI 基础知识公开课,了解如何将最新 AI 进展应用到你的工作中。你将了解有关语言模型和生成式 AI 应用程序的基础知识。此外,你还将了解 Azure OpenAI 服务如何通过文本、代码、图像生成、自然语言摘要和语义搜索助你实现成果…...

lnmp+discuz论坛 附实验:搭建discuz论坛
Inmpdiscuz论坛 Inmp: t: linux操作系统 nr: nginx前端页面 me: mysql数据库 账号密码,等等都是保存在这个数据库里面 p: php——nginx擅长处理的是静态页面,页面登录账户,需要请求到数据库,通过php把动态请求转发到数据库 n…...

谷粒商城—分布式高级①.md
1. ELASTICSEARCH 1、安装elastic search dokcer中安装elastic search (1)下载ealastic search和kibana docker pull elasticsearch:7.6.2 docker pull kibana:7.6.2(2)配置 mkdir -p /mydata/elasticsearch/config mkdir -p /mydata/elasticsearch/data echo "h…...

Unity开发配置不足,卡顿崩溃怎么办?
在游戏开发和虚拟现实等领域,Unity 软件以其强大的功能和广泛的适用性成为了众多开发者的首选。然而,要充分发挥 Unity 的性能,一台高性能的电脑设备是必不可少的。今天,我要向大家介绍川翔云电脑,它为 Unity 开发者提…...

在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1
KubeSphere4.1安装文档 在 Kubernetes 上快速安装 KubeSphere 在 Linux 上以 All-in-One 模式安装 kubernetes v1.22.12 kubesphere v3.4.1 官方文档:在 Linux 上以 All-in-One 模式安装 KubeSphere 下载文件 KubeKey git地址Releases kubesphere/kubekey 或 …...
网络安全自学是一项需要耐心和恒心的任务
网络安全自学是一项需要耐心和恒心的任务,但只要你按照正确的学习路线图去努力,就能够逐步掌握这一领域的知识和技能。下面是一份详细的学习路线图,它将帮助你从零基础开始,逐步成为网络安全领域的专家。 第一阶段:基…...
Python+OpenCV系列:图像的几何变换
Python OpenCV 系列:图像的几何变换 引言 在图像处理领域,几何变换是一个非常重要的操作,它可以改变图像的位置、大小、方向或形状。在计算机视觉中,这些操作对于图像预处理、特征提取和图像增强至关重要。本文将介绍如何利用 …...

第P1周:Pytorch实现mnist手写数字识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 1. 实现pytorch环境配置 2. 实现mnist手写数字识别 3. 自己写几个数字识别试试具体实现 (一)环境 语言环境:Python…...

使用EventLog Analyzer进行Apache日志监控和日志分析
一、什么是Apache日志分析 Apache日志分析是网站管理和维护的重要部分,通过分析Apache服务器生成的日志文件,可以了解网站的访问情况、识别潜在的安全问题、优化网站性能等。 二、Apache日志类型 Apache日志主要有两种类型:访问日志&a…...

PaddleOCR模型ch_PP-OCRv3文本检测模型研究(二)颈部网络
上节研究了PaddleOCR文本检测v3模型的骨干网,本文接着研究其颈部网络。 文章目录 研究起点残注层颈部网络代码实验小结 研究起点 摘取开源yml配置文件,摘取网络架构Architecture中颈部网络的配置如下 Neck:name: RSEFPNout_channels: 96shortcut: True可…...

360极速浏览器不支持看PDF
360安全浏览器采用的是基于IE内核和Chrome内核的双核浏览器。360极速浏览器是源自Chromium开源项目的浏览器,不但完美融合了IE内核引擎,而且实现了双核引擎的无缝切换。因此在速度上,360极速浏览器的极速体验感更佳。 展示自己的时候要在有优…...

【深度学习】深刻理解ViT
ViT(Vision Transformer)是谷歌研究团队于2020年提出的一种新型图像识别模型,首次将Transformer架构成功应用于计算机视觉任务中。Transformer最初应用于自然语言处理(如BERT和GPT),而ViT展示了其在视觉任务…...
解决vue2中更新列表数据,页面dom没有重新渲染的问题
在 Vue 2 中,直接修改数组的某个项可能不会触发视图的更新。这是因为 Vue 不能检测到数组的索引变化或对象属性的直接赋值。为了确保 Vue 能够正确地响应数据变化,你可以使用以下几种方法: 1. 使用 Vue.set() 使用 Vue.set() 方法可以确保 …...

vscode通过ssh连接远程服务器(实习心得)
一、连接ssh服务器 1.打开Visual Studio Code,进入拓展市场(CtrlShiftX),下载拓展Remote - SSH 2. 点击远程资源管理器选项卡,并选择远程(隧道/SSH)类别 3. 点击ssh配置:输入你的账号主机ip地址 4.在弹出的选择配置文件中…...

知识图谱9:知识图谱的展示
1、知识图谱的展示有很多工具 Neo4j Browser - - - - 浏览器版本 Neo4j Desktop - - - - 桌面版本 graphX - - - - 可以集成到Neo4j Desktop Neo4j 提供的 Neo4j Bloom 是用户友好的可视化工具,适合非技术用户直观地浏览图数据。Cypher 是其核心查询语言&#x…...

XCTF-web-easyupload
试了试php,php7,pht,phtml等,都没有用 尝试.user.ini 抓包修改将.user.ini修改为jpg图片 在上传一个123.jpg 用蚁剑连接,得到flag...
零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?
一、核心优势:专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发,是一款收费低廉但功能全面的Windows NAS工具,主打“无学习成本部署” 。与其他NAS软件相比,其优势在于: 无需硬件改造:将任意W…...
条件运算符
C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...
JVM垃圾回收机制全解析
Java虚拟机(JVM)中的垃圾收集器(Garbage Collector,简称GC)是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象,从而释放内存空间,避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...
服务器硬防的应用场景都有哪些?
服务器硬防是指一种通过硬件设备层面的安全措施来防御服务器系统受到网络攻击的方式,避免服务器受到各种恶意攻击和网络威胁,那么,服务器硬防通常都会应用在哪些场景当中呢? 硬防服务器中一般会配备入侵检测系统和预防系统&#x…...
【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表
1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...
GitHub 趋势日报 (2025年06月08日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 884 cognee 566 dify 414 HumanSystemOptimization 414 omni-tools 321 note-gen …...

前端开发面试题总结-JavaScript篇(一)
文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包(Closure)?闭包有什么应用场景和潜在问题?2.解释 JavaScript 的作用域链(Scope Chain) 二、原型与继承3.原型链是什么?如何实现继承&a…...

NLP学习路线图(二十三):长短期记忆网络(LSTM)
在自然语言处理(NLP)领域,我们时刻面临着处理序列数据的核心挑战。无论是理解句子的结构、分析文本的情感,还是实现语言的翻译,都需要模型能够捕捉词语之间依时序产生的复杂依赖关系。传统的神经网络结构在处理这种序列依赖时显得力不从心,而循环神经网络(RNN) 曾被视为…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)
文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...