当前位置: 首页 > news >正文

深度学习_11_softmax_图片识别代码原理解析

完整代码:

import torch
from d2l import torch as d2l"创建训练集&创建检测集合"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"每个图片长度,以及图片种类"
num_inputs = 784
num_outputs = 10"模型全局"
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)"softmax全局"
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True) # 计算行的和return X_exp / partition  # 这里应用了广播机制"输出,即传入图片输出"
def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)"交叉熵损失"
def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])"显示预测与估计相对应下标数量"
def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: # 确定长宽高都大于1y_hat = y_hat.argmax(axis=1) # 取出每行中最大值cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum()) # 返回对应下标数量"利用优化后的模型计算精度"
def evaluate_accuracy(net, data_iter):  #@saveif isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel()) # 下标相同数量 / 总下标return metric[0] / metric[1]"加法器全局"
class Accumulator:  #@savedef __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]"训练更新模型&返回训练损失与精度函数"
def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数# print(y)y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]lr = 0.1"更新模型"
def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)if __name__ == '__main__':num_epochs = 10cnt = 1print(W)print(b)print(W.shape)for i in range(num_epochs):X, Y = train_epoch_ch3(net, train_iter, cross_entropy, updater)print("训练次数: " + str(cnt))cnt += 1print("训练损失: {:.4f}".format(X)) #训练损失是用训练数据集测得print("训练精度: {:.4f}".format(evaluate_accuracy(net, test_iter))) # 训练精度是用测试数据集测得print(".................................")print(W)print(b)

代码解析如下:

Fashion MNIST数据集中,每个样本都属于10个不同的类别,代表不同类型的服装(如鞋子、衬衫、裤子等)虽然只有10种类别但是每种类别有很多个样例,足够我们训练模型以及监测模型,所以这里输出为10,每次获取256个样例(获取的256个样例中有多种类别,鞋子,衬衫等,并不是单一类别)且返回的除了样例之外,还有该样例的标号(这个标号是为了区分该样例到底是什么)后续会讲到

还有一点就是获取的图片的像素是28*28大小,而且是二维,每个像素范围是0~255表示不同的颜色,这个数据迭代器load_data_fashion_mnist会将数据预处理成784,也就是一个维度

这种处理方式的优点就是更方便计算,但是缺点就是损失了一些空间信息,所以训练出来的模型算是低级模型吧,更高级的模型还是要保留图片二维空间训练出来

"创建训练集&创建检测集合"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"每个图片长度,以及图片种类"
num_inputs = 784
num_outputs = 10

这里sgd是库函数提供的更新梯度函数,由于梯度是由python自动帮你计算,他会在updater(X.shape[0])后自动获取计算的梯度

"更新模型"
def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)

这个函数就是查看模型的识别情况,看识别对的数量有多少

"显示预测与估计相对应下标数量"
def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: # 确定长宽高都大于1y_hat = y_hat.argmax(axis=1) # 取出每行中最大值cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum()) # 返回对应下标数量

这个函数就是对前面生成的数据以及函数的调用,并检测生成的模型的情况

"训练更新模型&返回训练损失与精度函数"
def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数# print(y)y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]

至于损失函数以及softmax就不再赘述,前面的文章已经剖析得很清楚了

深度学习_8_对Softmax回归的理解


以下主要讲述对求y_hat以及梯度计算得理解:

由上述可知,获取的数据集的形状train_iter, test_iter它们的形状为(256, 784)

模型W的形状为(784, 10)

模型b的形状为(10, )

为了方便不考虑波动值b

为了模拟计算过程我们将数据进行相对缩小

784 ----> 4
256 ----> 3
10  ----> 2

那么获取的数据集合的shape就是(3, 4)设获取得数据集为X假设其具体值为

[[1, 2, 3, 4],[5, 6, 7, 8,],[9, 10, 11, 12]
]

那么模型W的形状就是(4, 2)其为

[[w1,w2],[w3,w4],[w5,w6],[w7,w8]
]

X * W是可行的其结果为:

[[(1*w1+2*w3+3*w5+4*w7), (1*w2+2*w4+3*w6+4*w8)],[(5*w1+6*w3+7*w5+8*w7), (5*w2+6*w4+7*w6+8*w8)],[(9*w1+10*w3+11*w5+12*w7), (9*w2+10*w4+11*w6+12*w8)],
]

其中w1,w2…有对应数值,并不是单纯未知数

这里X*W的每个元素值对应Oi,要将Oi转成y_hat,进行一次softmax操作即可

展示这个的原因只是想证实y_hat由W中的元素w1,w2,w3…构成,那么后面求得的损失函数也是由w1,w2,w3…构成

而求梯度的操作就是对损失函数求每个wi偏导的操作

W的shape是(4,2)其中的每个元素都是wi,所以最后的求得得梯度也是(4, 2),最后有W和梯度以及学习率就可以更新新的W了

最后理解一下偏导:

偏导将偏导得变量当作未知数,其他变量当作常数,是因为本次求导最关心得是本变量对整个函数的数值影响,而忽略其他变量对函数的影响

相关文章:

深度学习_11_softmax_图片识别代码原理解析

完整代码: import torch from d2l import torch as d2l"创建训练集&创建检测集合" batch_size 256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size)"每个图片长度,以及图片种类" num_inputs 784 num_output…...

Java Web——前端HTML入门

目录 HTML&CSS3&JavaScript简述 1. HTML概念 2. 超文本 3. 标记语言 4. HTML基础结构 5. HTML基础词汇 6. HTML语法规则 7. VS Code 推荐使用的插件 8. 在线帮助文档 HTML&CSS3&JavaScript简述 HTML 主要用于网页主体结构的搭建,像一个毛坯…...

华为ensp:为vlan配置ip

配置对应vlan的ip vlan1 interface Vlanif 1 进入vlan1 ip address 192.168.1.254 24配置IP为192.168.1.254 子网掩码为24位 这样就配置上ip了 vlan2 interface Vlanif 2 ip address 192.168.2.254 24 vlan3 interface Vlanif 3 ip address 192.168.3.254 24 查看结果 …...

laravel8-rabbitmq消息队列-实时监听跨服务器消息

使用场景介绍: 1)用于实时监听远程服务器发出的消息(json格式消息),接受并更新消息状态,存储到本地服务器 2)环境:lNMP(laravel8) 3)服务器需要开…...

git清除历史提交记录保持本地文件不变

https://www.cnblogs.com/langkyeSir/p/14528857.html git删除历史版本,保留当前状态。 有时候,我们误提交了某些隐私文件,使用git rm xxx删除后,其实版本库中是有历史记录的,想要删除这些记录,但是又不想…...

SOME/IP学习笔记2

1. SOME/IP 协议 SOME/IP目前支持UDP(用户传输协议)和TCP(传输控制协议), PS:UDP和TCP区别如下 TCP面向连接的,可靠的数据传输服务;UDP面向无连接的,尽最大努力的数据传输服务&…...

python实现FINS协议的TCP服务端(篇一)

python实现FINS协议的TCP服务端是一件稍微麻烦点的事情。它不像modbusTCP那样,可以使用现成的pymodbus模块去实现。但是,我们可以根据协议帧进行组包,自己去实现帧的格式,而这一切可以基于socket模块。本文为第一篇。 一、了解FI…...

利用uni-app 开发的iOS app 发布到App Store全流程

1.0.3 20200927 更新官方对应用审核流程的状态。 注:最新审核后续将同步社区另一篇记录 AppStore 审核被拒原因记录及解决措施 :苹果开发上架常见问题 | appuploader使用教程 1.0.2 20200925 新增首次驳回拒绝邮件解决措施。 1.0.1 20200922 首次…...

5个高质量的实用办公软件,每一款都是良心推荐

在现代办公环境中,高效的办公软件可以极大地提升工作效率,简化工作流程,帮助我们更好地完成工作。今天就给大家分享5个高质量的实用办公软件,每一款都是良心推荐。 01、FastStone Capture(截图工具) FastSt…...

基于GPTs个性化定制SCI论文专业翻译器

1. 什么是GPTs GPTs是OpenAI在2023年11月6日开发者大会上发布的重要功能更新,允许用户根据特定需求定制自己的ChatGPT模型。 Introducing GPTs 官方介绍页面https://openai.com/blog/introducing-gpts 在原有自定义ChatGPT的流程中,首先需要自己编制p…...

Final Cut Pro X for Mac:打造专业级视频剪辑的终极利器

随着数字媒体技术的不断发展,视频剪辑已经成为各行各业不可或缺的一部分。Final Cut Pro X for Mac作为一款专业的视频剪辑软件,凭借其强大的功能和易用性,已经成为Mac用户的首选。本文将向您详细介绍Final Cut Pro X for Mac的优势、功能以及…...

c++分割路径的字符串,得到 目录 文件名 扩展名

简单的做一个c小代码片的记录 c分割了图片的 路径字符串&#xff0c;得到 目录 文件名 扩展名 #include <iostream> using namespace std;int main() {std::string path "E:\\set1_seg\\32.jpg";//index:"\\"在字符串中的位置int index path.find…...

ABAP OpenSQL 分页处理

功能实现 在 ABAP 中&#xff0c;可以使用 OpenSQL 来实现分页功能。下面是一种实现分页的示例方法&#xff1a; 首先&#xff0c;定义一个内部表来存储查询结果数据&#xff1a; DATA lt_data TYPE TABLE OF your_data_type.然后&#xff0c;使用 SELECT 语句将数据查询到内…...

kubeasz一键部署k8s集群

下载程序 部署说明 部署文档 rootiZj6cd9joygowsf7am5hryZ:~# apt-get update rootiZj6cd9joygowsf7am5hryZ:~# apt-get upgrade rootiZj6cd9joygowsf7am5hryZ:~# export release3.6.2 rootiZj6cd9joygowsf7am5hryZ:~# wget https://github.com/easzlab/kubeasz/releases/…...

高性能图表库LightningChart JS v5.0 - 轻松实现图表自定义布局

LightningChart JS是Web上性能最高的图表库具有出色的执行性能 - 使用高数据速率同时监控数十个数据源。 GPU加速和WebGL渲染确保您的设备的图形处理器得到有效利用&#xff0c;从而实现高刷新率和流畅的动画。 点击获取LightningChart JS v5.0正式版下载 LightningChart JS …...

深度学习的集体智慧:最新发展综述

一、说明 我们调查了来自复杂系统的想法&#xff0c;如群体智能、自组织和紧急行为&#xff0c;这些想法在机器学习中越来越受欢迎。人工神经网络正在影响我们的日常生活&#xff0c;从执行预测性任务&#xff08;如推荐、面部识别和对象分类&#xff09;到生成任务&#xff08…...

Java之“数字困境”:资产管理项目中的Bug追踪与启示

目录 1 前言2 问题的发现3 调试的开始4 深入调试5 调试心得与反思6 结语 1 前言 在程序员的日常工作中&#xff0c;我们时常面对各种令人头疼的问题&#xff0c;其中最令人崩溃的瞬间之一&#xff0c;就是当我们花费大量时间追踪一个看似复杂的bug&#xff0c;最终发现问题的根…...

小程序微信登录授权突然没反应的原因和解决方案

之前的小程序微信授权一直用的很好 今天突然点击没反应了 马上在开发工具试一试 返现点击授权返回错误信息 排除所有代码问题&#xff08;之前一直用的好好的&#xff09;和服务器承载问题&#xff08;就几个人点击&#xff09; 第一反应就是小程序有啥政策改变的问题&#x…...

文本提交时如何使用PHP替换回车为br

1、使用PHP内置的nl2br()函数 nl2br()函数是PHP内置的函数&#xff0c;可以将任何字符串中的回车符&#xff08;\n&#xff09;替换为HTML中的换行符&#xff08;br&#xff09;。具体使用方法如下&#xff1a; $string "这里有一个\n换行符"; $string nl2br($str…...

安全框架SpringSecurity-1(认证入门数据库授权)

一、Spring Security ①&#xff1a;什么是Spring Security Spring Security是一个能够为基于Spring的企业应用系统提供声明式&#xff08;注解&#xff09;的安全访问控制解决方案的安全框架。它提供了一组可以在Spring应用上下文中配置的Bean&#xff0c;充分利用了Spring …...

MODBUS TCP转CANopen 技术赋能高效协同作业

在现代工业自动化领域&#xff0c;MODBUS TCP和CANopen两种通讯协议因其稳定性和高效性被广泛应用于各种设备和系统中。而随着科技的不断进步&#xff0c;这两种通讯协议也正在被逐步融合&#xff0c;形成了一种新型的通讯方式——开疆智能MODBUS TCP转CANopen网关KJ-TCPC-CANP…...

在Ubuntu中设置开机自动运行(sudo)指令的指南

在Ubuntu系统中&#xff0c;有时需要在系统启动时自动执行某些命令&#xff0c;特别是需要 sudo权限的指令。为了实现这一功能&#xff0c;可以使用多种方法&#xff0c;包括编写Systemd服务、配置 rc.local文件或使用 cron任务计划。本文将详细介绍这些方法&#xff0c;并提供…...

Spring Boot面试题精选汇总

&#x1f91f;致敬读者 &#x1f7e9;感谢阅读&#x1f7e6;笑口常开&#x1f7ea;生日快乐⬛早点睡觉 &#x1f4d8;博主相关 &#x1f7e7;博主信息&#x1f7e8;博客首页&#x1f7eb;专栏推荐&#x1f7e5;活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...

Spring AI 入门:Java 开发者的生成式 AI 实践之路

一、Spring AI 简介 在人工智能技术快速迭代的今天&#xff0c;Spring AI 作为 Spring 生态系统的新生力量&#xff0c;正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务&#xff08;如 OpenAI、Anthropic&#xff09;的无缝对接&…...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战

“&#x1f916;手搓TuyaAI语音指令 &#x1f60d;秒变表情包大师&#xff0c;让萌系Otto机器人&#x1f525;玩出智能新花样&#xff01;开整&#xff01;” &#x1f916; Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制&#xff08;TuyaAI…...

[Java恶补day16] 238.除自身以外数组的乘积

给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法&#xff0c;且在 O(n) 时间复杂度…...

使用Spring AI和MCP协议构建图片搜索服务

目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式&#xff08;本地调用&#xff09; SSE模式&#xff08;远程调用&#xff09; 4. 注册工具提…...

【 java 虚拟机知识 第一篇 】

目录 1.内存模型 1.1.JVM内存模型的介绍 1.2.堆和栈的区别 1.3.栈的存储细节 1.4.堆的部分 1.5.程序计数器的作用 1.6.方法区的内容 1.7.字符串池 1.8.引用类型 1.9.内存泄漏与内存溢出 1.10.会出现内存溢出的结构 1.内存模型 1.1.JVM内存模型的介绍 内存模型主要分…...

毫米波雷达基础理论(3D+4D)

3D、4D毫米波雷达基础知识及厂商选型 PreView : https://mp.weixin.qq.com/s/bQkju4r6med7I3TBGJI_bQ 1. FMCW毫米波雷达基础知识 主要参考博文&#xff1a; 一文入门汽车毫米波雷达基本原理 &#xff1a;https://mp.weixin.qq.com/s/_EN7A5lKcz2Eh8dLnjE19w 毫米波雷达基础…...

Vue ③-生命周期 || 脚手架

生命周期 思考&#xff1a;什么时候可以发送初始化渲染请求&#xff1f;&#xff08;越早越好&#xff09; 什么时候可以开始操作dom&#xff1f;&#xff08;至少dom得渲染出来&#xff09; Vue生命周期&#xff1a; 一个Vue实例从 创建 到 销毁 的整个过程。 生命周期四个…...