当前位置: 首页 > news >正文

深度学习_11_softmax_图片识别代码原理解析

完整代码:

import torch
from d2l import torch as d2l"创建训练集&创建检测集合"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"每个图片长度,以及图片种类"
num_inputs = 784
num_outputs = 10"模型全局"
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)"softmax全局"
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True) # 计算行的和return X_exp / partition  # 这里应用了广播机制"输出,即传入图片输出"
def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)"交叉熵损失"
def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])"显示预测与估计相对应下标数量"
def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: # 确定长宽高都大于1y_hat = y_hat.argmax(axis=1) # 取出每行中最大值cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum()) # 返回对应下标数量"利用优化后的模型计算精度"
def evaluate_accuracy(net, data_iter):  #@saveif isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel()) # 下标相同数量 / 总下标return metric[0] / metric[1]"加法器全局"
class Accumulator:  #@savedef __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]"训练更新模型&返回训练损失与精度函数"
def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数# print(y)y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]lr = 0.1"更新模型"
def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)if __name__ == '__main__':num_epochs = 10cnt = 1print(W)print(b)print(W.shape)for i in range(num_epochs):X, Y = train_epoch_ch3(net, train_iter, cross_entropy, updater)print("训练次数: " + str(cnt))cnt += 1print("训练损失: {:.4f}".format(X)) #训练损失是用训练数据集测得print("训练精度: {:.4f}".format(evaluate_accuracy(net, test_iter))) # 训练精度是用测试数据集测得print(".................................")print(W)print(b)

代码解析如下:

Fashion MNIST数据集中,每个样本都属于10个不同的类别,代表不同类型的服装(如鞋子、衬衫、裤子等)虽然只有10种类别但是每种类别有很多个样例,足够我们训练模型以及监测模型,所以这里输出为10,每次获取256个样例(获取的256个样例中有多种类别,鞋子,衬衫等,并不是单一类别)且返回的除了样例之外,还有该样例的标号(这个标号是为了区分该样例到底是什么)后续会讲到

还有一点就是获取的图片的像素是28*28大小,而且是二维,每个像素范围是0~255表示不同的颜色,这个数据迭代器load_data_fashion_mnist会将数据预处理成784,也就是一个维度

这种处理方式的优点就是更方便计算,但是缺点就是损失了一些空间信息,所以训练出来的模型算是低级模型吧,更高级的模型还是要保留图片二维空间训练出来

"创建训练集&创建检测集合"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"每个图片长度,以及图片种类"
num_inputs = 784
num_outputs = 10

这里sgd是库函数提供的更新梯度函数,由于梯度是由python自动帮你计算,他会在updater(X.shape[0])后自动获取计算的梯度

"更新模型"
def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)

这个函数就是查看模型的识别情况,看识别对的数量有多少

"显示预测与估计相对应下标数量"
def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: # 确定长宽高都大于1y_hat = y_hat.argmax(axis=1) # 取出每行中最大值cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum()) # 返回对应下标数量

这个函数就是对前面生成的数据以及函数的调用,并检测生成的模型的情况

"训练更新模型&返回训练损失与精度函数"
def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数# print(y)y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]

至于损失函数以及softmax就不再赘述,前面的文章已经剖析得很清楚了

深度学习_8_对Softmax回归的理解


以下主要讲述对求y_hat以及梯度计算得理解:

由上述可知,获取的数据集的形状train_iter, test_iter它们的形状为(256, 784)

模型W的形状为(784, 10)

模型b的形状为(10, )

为了方便不考虑波动值b

为了模拟计算过程我们将数据进行相对缩小

784 ----> 4
256 ----> 3
10  ----> 2

那么获取的数据集合的shape就是(3, 4)设获取得数据集为X假设其具体值为

[[1, 2, 3, 4],[5, 6, 7, 8,],[9, 10, 11, 12]
]

那么模型W的形状就是(4, 2)其为

[[w1,w2],[w3,w4],[w5,w6],[w7,w8]
]

X * W是可行的其结果为:

[[(1*w1+2*w3+3*w5+4*w7), (1*w2+2*w4+3*w6+4*w8)],[(5*w1+6*w3+7*w5+8*w7), (5*w2+6*w4+7*w6+8*w8)],[(9*w1+10*w3+11*w5+12*w7), (9*w2+10*w4+11*w6+12*w8)],
]

其中w1,w2…有对应数值,并不是单纯未知数

这里X*W的每个元素值对应Oi,要将Oi转成y_hat,进行一次softmax操作即可

展示这个的原因只是想证实y_hat由W中的元素w1,w2,w3…构成,那么后面求得的损失函数也是由w1,w2,w3…构成

而求梯度的操作就是对损失函数求每个wi偏导的操作

W的shape是(4,2)其中的每个元素都是wi,所以最后的求得得梯度也是(4, 2),最后有W和梯度以及学习率就可以更新新的W了

最后理解一下偏导:

偏导将偏导得变量当作未知数,其他变量当作常数,是因为本次求导最关心得是本变量对整个函数的数值影响,而忽略其他变量对函数的影响

相关文章:

深度学习_11_softmax_图片识别代码原理解析

完整代码: import torch from d2l import torch as d2l"创建训练集&创建检测集合" batch_size 256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size)"每个图片长度,以及图片种类" num_inputs 784 num_output…...

Java Web——前端HTML入门

目录 HTML&CSS3&JavaScript简述 1. HTML概念 2. 超文本 3. 标记语言 4. HTML基础结构 5. HTML基础词汇 6. HTML语法规则 7. VS Code 推荐使用的插件 8. 在线帮助文档 HTML&CSS3&JavaScript简述 HTML 主要用于网页主体结构的搭建,像一个毛坯…...

华为ensp:为vlan配置ip

配置对应vlan的ip vlan1 interface Vlanif 1 进入vlan1 ip address 192.168.1.254 24配置IP为192.168.1.254 子网掩码为24位 这样就配置上ip了 vlan2 interface Vlanif 2 ip address 192.168.2.254 24 vlan3 interface Vlanif 3 ip address 192.168.3.254 24 查看结果 …...

laravel8-rabbitmq消息队列-实时监听跨服务器消息

使用场景介绍: 1)用于实时监听远程服务器发出的消息(json格式消息),接受并更新消息状态,存储到本地服务器 2)环境:lNMP(laravel8) 3)服务器需要开…...

git清除历史提交记录保持本地文件不变

https://www.cnblogs.com/langkyeSir/p/14528857.html git删除历史版本,保留当前状态。 有时候,我们误提交了某些隐私文件,使用git rm xxx删除后,其实版本库中是有历史记录的,想要删除这些记录,但是又不想…...

SOME/IP学习笔记2

1. SOME/IP 协议 SOME/IP目前支持UDP(用户传输协议)和TCP(传输控制协议), PS:UDP和TCP区别如下 TCP面向连接的,可靠的数据传输服务;UDP面向无连接的,尽最大努力的数据传输服务&…...

python实现FINS协议的TCP服务端(篇一)

python实现FINS协议的TCP服务端是一件稍微麻烦点的事情。它不像modbusTCP那样,可以使用现成的pymodbus模块去实现。但是,我们可以根据协议帧进行组包,自己去实现帧的格式,而这一切可以基于socket模块。本文为第一篇。 一、了解FI…...

利用uni-app 开发的iOS app 发布到App Store全流程

1.0.3 20200927 更新官方对应用审核流程的状态。 注:最新审核后续将同步社区另一篇记录 AppStore 审核被拒原因记录及解决措施 :苹果开发上架常见问题 | appuploader使用教程 1.0.2 20200925 新增首次驳回拒绝邮件解决措施。 1.0.1 20200922 首次…...

5个高质量的实用办公软件,每一款都是良心推荐

在现代办公环境中,高效的办公软件可以极大地提升工作效率,简化工作流程,帮助我们更好地完成工作。今天就给大家分享5个高质量的实用办公软件,每一款都是良心推荐。 01、FastStone Capture(截图工具) FastSt…...

基于GPTs个性化定制SCI论文专业翻译器

1. 什么是GPTs GPTs是OpenAI在2023年11月6日开发者大会上发布的重要功能更新,允许用户根据特定需求定制自己的ChatGPT模型。 Introducing GPTs 官方介绍页面https://openai.com/blog/introducing-gpts 在原有自定义ChatGPT的流程中,首先需要自己编制p…...

Final Cut Pro X for Mac:打造专业级视频剪辑的终极利器

随着数字媒体技术的不断发展,视频剪辑已经成为各行各业不可或缺的一部分。Final Cut Pro X for Mac作为一款专业的视频剪辑软件,凭借其强大的功能和易用性,已经成为Mac用户的首选。本文将向您详细介绍Final Cut Pro X for Mac的优势、功能以及…...

c++分割路径的字符串,得到 目录 文件名 扩展名

简单的做一个c小代码片的记录 c分割了图片的 路径字符串&#xff0c;得到 目录 文件名 扩展名 #include <iostream> using namespace std;int main() {std::string path "E:\\set1_seg\\32.jpg";//index:"\\"在字符串中的位置int index path.find…...

ABAP OpenSQL 分页处理

功能实现 在 ABAP 中&#xff0c;可以使用 OpenSQL 来实现分页功能。下面是一种实现分页的示例方法&#xff1a; 首先&#xff0c;定义一个内部表来存储查询结果数据&#xff1a; DATA lt_data TYPE TABLE OF your_data_type.然后&#xff0c;使用 SELECT 语句将数据查询到内…...

kubeasz一键部署k8s集群

下载程序 部署说明 部署文档 rootiZj6cd9joygowsf7am5hryZ:~# apt-get update rootiZj6cd9joygowsf7am5hryZ:~# apt-get upgrade rootiZj6cd9joygowsf7am5hryZ:~# export release3.6.2 rootiZj6cd9joygowsf7am5hryZ:~# wget https://github.com/easzlab/kubeasz/releases/…...

高性能图表库LightningChart JS v5.0 - 轻松实现图表自定义布局

LightningChart JS是Web上性能最高的图表库具有出色的执行性能 - 使用高数据速率同时监控数十个数据源。 GPU加速和WebGL渲染确保您的设备的图形处理器得到有效利用&#xff0c;从而实现高刷新率和流畅的动画。 点击获取LightningChart JS v5.0正式版下载 LightningChart JS …...

深度学习的集体智慧:最新发展综述

一、说明 我们调查了来自复杂系统的想法&#xff0c;如群体智能、自组织和紧急行为&#xff0c;这些想法在机器学习中越来越受欢迎。人工神经网络正在影响我们的日常生活&#xff0c;从执行预测性任务&#xff08;如推荐、面部识别和对象分类&#xff09;到生成任务&#xff08…...

Java之“数字困境”:资产管理项目中的Bug追踪与启示

目录 1 前言2 问题的发现3 调试的开始4 深入调试5 调试心得与反思6 结语 1 前言 在程序员的日常工作中&#xff0c;我们时常面对各种令人头疼的问题&#xff0c;其中最令人崩溃的瞬间之一&#xff0c;就是当我们花费大量时间追踪一个看似复杂的bug&#xff0c;最终发现问题的根…...

小程序微信登录授权突然没反应的原因和解决方案

之前的小程序微信授权一直用的很好 今天突然点击没反应了 马上在开发工具试一试 返现点击授权返回错误信息 排除所有代码问题&#xff08;之前一直用的好好的&#xff09;和服务器承载问题&#xff08;就几个人点击&#xff09; 第一反应就是小程序有啥政策改变的问题&#x…...

文本提交时如何使用PHP替换回车为br

1、使用PHP内置的nl2br()函数 nl2br()函数是PHP内置的函数&#xff0c;可以将任何字符串中的回车符&#xff08;\n&#xff09;替换为HTML中的换行符&#xff08;br&#xff09;。具体使用方法如下&#xff1a; $string "这里有一个\n换行符"; $string nl2br($str…...

安全框架SpringSecurity-1(认证入门数据库授权)

一、Spring Security ①&#xff1a;什么是Spring Security Spring Security是一个能够为基于Spring的企业应用系统提供声明式&#xff08;注解&#xff09;的安全访问控制解决方案的安全框架。它提供了一组可以在Spring应用上下文中配置的Bean&#xff0c;充分利用了Spring …...

java_网络服务相关_gateway_nacos_feign区别联系

1. spring-cloud-starter-gateway 作用&#xff1a;作为微服务架构的网关&#xff0c;统一入口&#xff0c;处理所有外部请求。 核心能力&#xff1a; 路由转发&#xff08;基于路径、服务名等&#xff09;过滤器&#xff08;鉴权、限流、日志、Header 处理&#xff09;支持负…...

Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务

通过akshare库&#xff0c;获取股票数据&#xff0c;并生成TabPFN这个模型 可以识别、处理的格式&#xff0c;写一个完整的预处理示例&#xff0c;并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务&#xff0c;进行预测并输…...

初学 pytest 记录

安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...

微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据

微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据 Power Query 具有大量专门帮助您清理和准备数据以供分析的功能。 您将了解如何简化复杂模型、更改数据类型、重命名对象和透视数据。 您还将了解如何分析列&#xff0c;以便知晓哪些列包含有价值的数据&#xff0c;…...

DBLP数据库是什么?

DBLP&#xff08;Digital Bibliography & Library Project&#xff09;Computer Science Bibliography是全球著名的计算机科学出版物的开放书目数据库。DBLP所收录的期刊和会议论文质量较高&#xff0c;数据库文献更新速度很快&#xff0c;很好地反映了国际计算机科学学术研…...

6️⃣Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙

Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙 一、前言:离区块链还有多远? 区块链听起来可能遥不可及,似乎是只有密码学专家和资深工程师才能涉足的领域。但事实上,构建一个区块链的核心并不复杂,尤其当你已经掌握了一门系统编程语言,比如 Go。 要真正理解区…...

Linux安全加固:从攻防视角构建系统免疫

Linux安全加固:从攻防视角构建系统免疫 构建坚不可摧的数字堡垒 引言:攻防对抗的新纪元 在日益复杂的网络威胁环境中,Linux系统安全已从被动防御转向主动免疫。2023年全球网络安全报告显示,高级持续性威胁(APT)攻击同比增长65%,平均入侵停留时间缩短至48小时。本章将从…...

何谓AI编程【02】AI编程官网以优雅草星云智控为例建设实践-完善顶部-建立各项子页-调整排版-优雅草卓伊凡

何谓AI编程【02】AI编程官网以优雅草星云智控为例建设实践-完善顶部-建立各项子页-调整排版-优雅草卓伊凡 背景 我们以建设星云智控官网来做AI编程实践&#xff0c;很多人以为AI已经强大到不需要程序员了&#xff0c;其实不是&#xff0c;AI更加需要程序员&#xff0c;普通人…...

VSCode 使用CMake 构建 Qt 5 窗口程序

首先,目录结构如下图: 运行效果: cmake -B build cmake --build build 运行: windeployqt.exe F:\testQt5\build\Debug\app.exe main.cpp #include "mainwindow.h"#include <QAppli...

基于Uniapp的HarmonyOS 5.0体育应用开发攻略

一、技术架构设计 1.混合开发框架选型 &#xff08;1&#xff09;使用Uniapp 3.8版本支持ArkTS编译 &#xff08;2&#xff09;通过uni-harmony插件调用原生能力 &#xff08;3&#xff09;分层架构设计&#xff1a; graph TDA[UI层] -->|Vue语法| B(Uniapp框架)B --&g…...