深度学习_11_softmax_图片识别代码原理解析
完整代码:
import torch
from d2l import torch as d2l"创建训练集&创建检测集合"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"每个图片长度,以及图片种类"
num_inputs = 784
num_outputs = 10"模型全局"
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)"softmax全局"
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True) # 计算行的和return X_exp / partition # 这里应用了广播机制"输出,即传入图片输出"
def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)"交叉熵损失"
def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])"显示预测与估计相对应下标数量"
def accuracy(y_hat, y): #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: # 确定长宽高都大于1y_hat = y_hat.argmax(axis=1) # 取出每行中最大值cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum()) # 返回对应下标数量"利用优化后的模型计算精度"
def evaluate_accuracy(net, data_iter): #@saveif isinstance(net, torch.nn.Module):net.eval() # 将模型设置为评估模式metric = Accumulator(2) # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel()) # 下标相同数量 / 总下标return metric[0] / metric[1]"加法器全局"
class Accumulator: #@savedef __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]"训练更新模型&返回训练损失与精度函数"
def train_epoch_ch3(net, train_iter, loss, updater): #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数# print(y)y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]lr = 0.1"更新模型"
def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)if __name__ == '__main__':num_epochs = 10cnt = 1print(W)print(b)print(W.shape)for i in range(num_epochs):X, Y = train_epoch_ch3(net, train_iter, cross_entropy, updater)print("训练次数: " + str(cnt))cnt += 1print("训练损失: {:.4f}".format(X)) #训练损失是用训练数据集测得print("训练精度: {:.4f}".format(evaluate_accuracy(net, test_iter))) # 训练精度是用测试数据集测得print(".................................")print(W)print(b)
代码解析如下:
Fashion MNIST数据集中,每个样本都属于10个不同的类别,代表不同类型的服装(如鞋子、衬衫、裤子等)虽然只有10种类别但是每种类别有很多个样例,足够我们训练模型以及监测模型,所以这里输出为10,每次获取256个样例(获取的256个样例中有多种类别,鞋子,衬衫等,并不是单一类别)且返回的除了样例之外,还有该样例的标号(这个标号是为了区分该样例到底是什么)后续会讲到
还有一点就是获取的图片的像素是28*28大小,而且是二维,每个像素范围是0~255表示不同的颜色,这个数据迭代器load_data_fashion_mnist会将数据预处理成784,也就是一个维度
这种处理方式的优点就是更方便计算,但是缺点就是损失了一些空间信息,所以训练出来的模型算是低级模型吧,更高级的模型还是要保留图片二维空间训练出来
"创建训练集&创建检测集合"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"每个图片长度,以及图片种类"
num_inputs = 784
num_outputs = 10
这里sgd是库函数提供的更新梯度函数,由于梯度是由python自动帮你计算,他会在updater(X.shape[0])后自动获取计算的梯度
"更新模型"
def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)
这个函数就是查看模型的识别情况,看识别对的数量有多少
"显示预测与估计相对应下标数量"
def accuracy(y_hat, y): #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: # 确定长宽高都大于1y_hat = y_hat.argmax(axis=1) # 取出每行中最大值cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum()) # 返回对应下标数量
这个函数就是对前面生成的数据以及函数的调用,并检测生成的模型的情况
"训练更新模型&返回训练损失与精度函数"
def train_epoch_ch3(net, train_iter, loss, updater): #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数# print(y)y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]
至于损失函数以及softmax就不再赘述,前面的文章已经剖析得很清楚了
深度学习_8_对Softmax回归的理解
以下主要讲述对求y_hat以及梯度计算得理解:
由上述可知,获取的数据集的形状train_iter, test_iter它们的形状为(256, 784)
模型W的形状为(784, 10)
模型b的形状为(10, )
为了方便不考虑波动值b
为了模拟计算过程我们将数据进行相对缩小
784 ----> 4
256 ----> 3
10 ----> 2
那么获取的数据集合的shape就是(3, 4)设获取得数据集为X假设其具体值为
[[1, 2, 3, 4],[5, 6, 7, 8,],[9, 10, 11, 12]
]
那么模型W的形状就是(4, 2)其为
[[w1,w2],[w3,w4],[w5,w6],[w7,w8]
]
X * W是可行的其结果为:
[[(1*w1+2*w3+3*w5+4*w7), (1*w2+2*w4+3*w6+4*w8)],[(5*w1+6*w3+7*w5+8*w7), (5*w2+6*w4+7*w6+8*w8)],[(9*w1+10*w3+11*w5+12*w7), (9*w2+10*w4+11*w6+12*w8)],
]
其中w1,w2…有对应数值,并不是单纯未知数
这里X*W的每个元素值对应Oi,要将Oi转成y_hat,进行一次softmax操作即可
展示这个的原因只是想证实y_hat由W中的元素w1,w2,w3…构成,那么后面求得的损失函数也是由w1,w2,w3…构成
而求梯度的操作就是对损失函数求每个wi偏导的操作
W的shape是(4,2)其中的每个元素都是wi,所以最后的求得得梯度也是(4, 2),最后有W和梯度以及学习率就可以更新新的W了
最后理解一下偏导:
偏导将偏导得变量当作未知数,其他变量当作常数,是因为本次求导最关心得是本变量对整个函数的数值影响,而忽略其他变量对函数的影响
相关文章:
深度学习_11_softmax_图片识别代码原理解析
完整代码: import torch from d2l import torch as d2l"创建训练集&创建检测集合" batch_size 256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size)"每个图片长度,以及图片种类" num_inputs 784 num_output…...

Java Web——前端HTML入门
目录 HTML&CSS3&JavaScript简述 1. HTML概念 2. 超文本 3. 标记语言 4. HTML基础结构 5. HTML基础词汇 6. HTML语法规则 7. VS Code 推荐使用的插件 8. 在线帮助文档 HTML&CSS3&JavaScript简述 HTML 主要用于网页主体结构的搭建,像一个毛坯…...

华为ensp:为vlan配置ip
配置对应vlan的ip vlan1 interface Vlanif 1 进入vlan1 ip address 192.168.1.254 24配置IP为192.168.1.254 子网掩码为24位 这样就配置上ip了 vlan2 interface Vlanif 2 ip address 192.168.2.254 24 vlan3 interface Vlanif 3 ip address 192.168.3.254 24 查看结果 …...
laravel8-rabbitmq消息队列-实时监听跨服务器消息
使用场景介绍: 1)用于实时监听远程服务器发出的消息(json格式消息),接受并更新消息状态,存储到本地服务器 2)环境:lNMP(laravel8) 3)服务器需要开…...
git清除历史提交记录保持本地文件不变
https://www.cnblogs.com/langkyeSir/p/14528857.html git删除历史版本,保留当前状态。 有时候,我们误提交了某些隐私文件,使用git rm xxx删除后,其实版本库中是有历史记录的,想要删除这些记录,但是又不想…...

SOME/IP学习笔记2
1. SOME/IP 协议 SOME/IP目前支持UDP(用户传输协议)和TCP(传输控制协议), PS:UDP和TCP区别如下 TCP面向连接的,可靠的数据传输服务;UDP面向无连接的,尽最大努力的数据传输服务&…...

python实现FINS协议的TCP服务端(篇一)
python实现FINS协议的TCP服务端是一件稍微麻烦点的事情。它不像modbusTCP那样,可以使用现成的pymodbus模块去实现。但是,我们可以根据协议帧进行组包,自己去实现帧的格式,而这一切可以基于socket模块。本文为第一篇。 一、了解FI…...

利用uni-app 开发的iOS app 发布到App Store全流程
1.0.3 20200927 更新官方对应用审核流程的状态。 注:最新审核后续将同步社区另一篇记录 AppStore 审核被拒原因记录及解决措施 :苹果开发上架常见问题 | appuploader使用教程 1.0.2 20200925 新增首次驳回拒绝邮件解决措施。 1.0.1 20200922 首次…...

5个高质量的实用办公软件,每一款都是良心推荐
在现代办公环境中,高效的办公软件可以极大地提升工作效率,简化工作流程,帮助我们更好地完成工作。今天就给大家分享5个高质量的实用办公软件,每一款都是良心推荐。 01、FastStone Capture(截图工具) FastSt…...

基于GPTs个性化定制SCI论文专业翻译器
1. 什么是GPTs GPTs是OpenAI在2023年11月6日开发者大会上发布的重要功能更新,允许用户根据特定需求定制自己的ChatGPT模型。 Introducing GPTs 官方介绍页面https://openai.com/blog/introducing-gpts 在原有自定义ChatGPT的流程中,首先需要自己编制p…...

Final Cut Pro X for Mac:打造专业级视频剪辑的终极利器
随着数字媒体技术的不断发展,视频剪辑已经成为各行各业不可或缺的一部分。Final Cut Pro X for Mac作为一款专业的视频剪辑软件,凭借其强大的功能和易用性,已经成为Mac用户的首选。本文将向您详细介绍Final Cut Pro X for Mac的优势、功能以及…...

c++分割路径的字符串,得到 目录 文件名 扩展名
简单的做一个c小代码片的记录 c分割了图片的 路径字符串,得到 目录 文件名 扩展名 #include <iostream> using namespace std;int main() {std::string path "E:\\set1_seg\\32.jpg";//index:"\\"在字符串中的位置int index path.find…...
ABAP OpenSQL 分页处理
功能实现 在 ABAP 中,可以使用 OpenSQL 来实现分页功能。下面是一种实现分页的示例方法: 首先,定义一个内部表来存储查询结果数据: DATA lt_data TYPE TABLE OF your_data_type.然后,使用 SELECT 语句将数据查询到内…...
kubeasz一键部署k8s集群
下载程序 部署说明 部署文档 rootiZj6cd9joygowsf7am5hryZ:~# apt-get update rootiZj6cd9joygowsf7am5hryZ:~# apt-get upgrade rootiZj6cd9joygowsf7am5hryZ:~# export release3.6.2 rootiZj6cd9joygowsf7am5hryZ:~# wget https://github.com/easzlab/kubeasz/releases/…...

高性能图表库LightningChart JS v5.0 - 轻松实现图表自定义布局
LightningChart JS是Web上性能最高的图表库具有出色的执行性能 - 使用高数据速率同时监控数十个数据源。 GPU加速和WebGL渲染确保您的设备的图形处理器得到有效利用,从而实现高刷新率和流畅的动画。 点击获取LightningChart JS v5.0正式版下载 LightningChart JS …...

深度学习的集体智慧:最新发展综述
一、说明 我们调查了来自复杂系统的想法,如群体智能、自组织和紧急行为,这些想法在机器学习中越来越受欢迎。人工神经网络正在影响我们的日常生活,从执行预测性任务(如推荐、面部识别和对象分类)到生成任务(…...

Java之“数字困境”:资产管理项目中的Bug追踪与启示
目录 1 前言2 问题的发现3 调试的开始4 深入调试5 调试心得与反思6 结语 1 前言 在程序员的日常工作中,我们时常面对各种令人头疼的问题,其中最令人崩溃的瞬间之一,就是当我们花费大量时间追踪一个看似复杂的bug,最终发现问题的根…...
小程序微信登录授权突然没反应的原因和解决方案
之前的小程序微信授权一直用的很好 今天突然点击没反应了 马上在开发工具试一试 返现点击授权返回错误信息 排除所有代码问题(之前一直用的好好的)和服务器承载问题(就几个人点击) 第一反应就是小程序有啥政策改变的问题&#x…...
文本提交时如何使用PHP替换回车为br
1、使用PHP内置的nl2br()函数 nl2br()函数是PHP内置的函数,可以将任何字符串中的回车符(\n)替换为HTML中的换行符(br)。具体使用方法如下: $string "这里有一个\n换行符"; $string nl2br($str…...

安全框架SpringSecurity-1(认证入门数据库授权)
一、Spring Security ①:什么是Spring Security Spring Security是一个能够为基于Spring的企业应用系统提供声明式(注解)的安全访问控制解决方案的安全框架。它提供了一组可以在Spring应用上下文中配置的Bean,充分利用了Spring …...
java_网络服务相关_gateway_nacos_feign区别联系
1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...

初学 pytest 记录
安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...

微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据
微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据 Power Query 具有大量专门帮助您清理和准备数据以供分析的功能。 您将了解如何简化复杂模型、更改数据类型、重命名对象和透视数据。 您还将了解如何分析列,以便知晓哪些列包含有价值的数据,…...

DBLP数据库是什么?
DBLP(Digital Bibliography & Library Project)Computer Science Bibliography是全球著名的计算机科学出版物的开放书目数据库。DBLP所收录的期刊和会议论文质量较高,数据库文献更新速度很快,很好地反映了国际计算机科学学术研…...
6️⃣Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙
Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙 一、前言:离区块链还有多远? 区块链听起来可能遥不可及,似乎是只有密码学专家和资深工程师才能涉足的领域。但事实上,构建一个区块链的核心并不复杂,尤其当你已经掌握了一门系统编程语言,比如 Go。 要真正理解区…...
Linux安全加固:从攻防视角构建系统免疫
Linux安全加固:从攻防视角构建系统免疫 构建坚不可摧的数字堡垒 引言:攻防对抗的新纪元 在日益复杂的网络威胁环境中,Linux系统安全已从被动防御转向主动免疫。2023年全球网络安全报告显示,高级持续性威胁(APT)攻击同比增长65%,平均入侵停留时间缩短至48小时。本章将从…...

何谓AI编程【02】AI编程官网以优雅草星云智控为例建设实践-完善顶部-建立各项子页-调整排版-优雅草卓伊凡
何谓AI编程【02】AI编程官网以优雅草星云智控为例建设实践-完善顶部-建立各项子页-调整排版-优雅草卓伊凡 背景 我们以建设星云智控官网来做AI编程实践,很多人以为AI已经强大到不需要程序员了,其实不是,AI更加需要程序员,普通人…...

VSCode 使用CMake 构建 Qt 5 窗口程序
首先,目录结构如下图: 运行效果: cmake -B build cmake --build build 运行: windeployqt.exe F:\testQt5\build\Debug\app.exe main.cpp #include "mainwindow.h"#include <QAppli...
基于Uniapp的HarmonyOS 5.0体育应用开发攻略
一、技术架构设计 1.混合开发框架选型 (1)使用Uniapp 3.8版本支持ArkTS编译 (2)通过uni-harmony插件调用原生能力 (3)分层架构设计: graph TDA[UI层] -->|Vue语法| B(Uniapp框架)B --&g…...