当前位置: 首页 > news >正文

从0开始深度学习(9)——softmax回归的逐步实现

文章使用Fashion-MNIST数据集,做一次分类识别任务
Fashion-MNIST中包含的10个类别,分别为:
t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)
sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)、ankle boot(短靴)

0 图像数据

0.1 读取展示数据

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt# 下载数据集 ,60,000 个训练样本和 10,000 个测试样本,每个样本包含一张28*28的灰度图和一个标签trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="D/DL_Data/Fashion-MNIST", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="D/DL_Data/Fashion-MNIST", train=False, transform=trans, download=True)print("test:",len(mnist_test))
print("train:",len(mnist_train))# 获取第一个样本的图像和标签
image, label = mnist_train[0]
print("图像的形状:", image.shape)
print("标签:", label)

在这里插入图片描述

0.2 可视化图像

# 可视化
def show_img():class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']# 可视化前5张图片fig, axes = plt.subplots(1, 10, figsize=(15, 3))for i in range(10):# 获取第 i 个样本的图像和标签image, label = mnist_train[i]# 将图像从 Tensor 转换回 numpy 数组,并移除通道维度image_np = image.squeeze().numpy()# 在子图中显示图像axes[i].imshow(image_np, cmap='gray')axes[i].set_title(f'Label: {class_names[label]}')axes[i].axis('off')  # 关闭坐标轴plt.tight_layout()plt.show()show_img()

在这里插入图片描述

0.3 整合为数据加载模块

def load_data_fashion_mnist(batch_size, resize=None):  #@savetrans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))train_iter, test_iter = load_data_fashion_mnist(256, resize=28)

1 初始化参数模型

我们选择把 28 ∗ 28 28*28 2828的图片展开成 1 ∗ 784 1*784 1784的向量,认为每个像素位置都是一个特征,所以输入是784维,输出是10个类别标签,所以输出是10维

因为softmaxhi回归类似于线性回归,所以权重 w w w应该是 784 ∗ 10 784*10 78410 的矩阵,偏置是 1 ∗ 10 1*10 110 的行向量,接下来如同线性回归中一样,使用正太分布初始化权重,偏置初始化为0:

num_inputs = 784
num_outputs = 10W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

2 定义softmax操作

回顾一下softmax的公式:
在这里插入图片描述
由三个步骤组成:

  1. 对每个项目求幂
  2. 将每一行求和(小批量样本中,每个样本是一行),得到每个样本的规范化常数。
  3. 将每一行除以其规范化常数,确保结果的和为1。
# 定义softmax操作
def softmax(x):x_exp=torch.exp(x)x_exp_sum=x_exp.sum(1,keepdim=True)return x_exp/x_exp_sum

3 定义模型

# 定义模型
def net(x):x = x.reshape(-1, w.shape[0])  # 将图片重塑为 [batch_size, 784]temp = torch.matmul(x, w)temp = temp + breturn softmax(temp)

4 定义损失函数

使用从0开始深度学习(8)——softmax回归提到的交叉熵损失函数

# 定义损失函数
def cross_entropy(y_hat, y): # 预测值、真实值return - torch.log(y_hat[range(len(y_hat)), y]) # 计算负对数似然cross_entropy(y_hat, y)

5 分类精度

分类精度即正确预测数量与总预测数量之比。

def compute_accuracy(y_hat, y):  # 预测值、真实值if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)  # 找到一个样本中,对应的最大概率的类别cmp = y_hat.type(y.dtype) == y  # 将预测值 y_hat 与真实标签 y 进行比较,生成一个布尔张量 cmpreturn float(cmp.type(y.dtype).sum())# 计算在指定数据集上模型的准确率
def evaluate_accuracy(net, data_iter):  if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 累加多个变量的总和。这里初始化了一个包含两个元素的累加器,分别用来存储正确预测的数量和总的预测数量。with torch.no_grad():for X, y in data_iter:metric.add(compute_accuracy(net(X), y), y.numel())return metric[0] / metric[1]class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]# 评估模型
accuracy = evaluate_accuracy(net, test_iter)
print(f"Test Accuracy: {accuracy:.4f}")

在这里插入图片描述

6 定义优化器

# 定义优化器
def sgd(params, lr, batch_size):with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

7 训练

# 训练模型
def train_epoch(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train()  # 将模型设置为训练模式metric = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()else:l.backward()updater([w, b], lr, batch_size)metric.add(float(l) * y.numel(), compute_accuracy(y_hat, y), y.numel())return metric[0] / metric[2], metric[1] / metric[2]def train(net, train_iter, test_iter, loss, num_epochs, updater):for epoch in range(num_epochs):train_metrics = train_epoch(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)print(f'Epoch {epoch + 1}: Train Loss {train_metrics[0]:.3f}, Train Acc {train_metrics[1]:.3f}, Test Acc {test_acc:.3f}')# 训练模型
updater = lambda params, lr, batch_size: sgd(params, lr, batch_size)
train(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

在这里插入图片描述

8 预测

# 定义 Fashion-MNIST 标签的文本描述
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 预测并显示结果
def predict(net, test_iter, n=6):for X, y in test_iter:break  # 只取一个批次的数据trues = get_fashion_mnist_labels(y)preds = get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + '\n' + pred for true, pred in zip(trues, preds)]n = min(n, X.shape[0])fig, axs = plt.subplots(1, n, figsize=(12, 3))for i in range(n):axs[i].imshow(X[i].permute(1, 2, 0).squeeze().numpy(), cmap='gray')axs[i].set_title(titles[i])axs[i].axis('off')plt.show()# 调用预测函数
predict(net, test_iter, n=6)

在这里插入图片描述

相关文章:

从0开始深度学习(9)——softmax回归的逐步实现

文章使用Fashion-MNIST数据集,做一次分类识别任务 Fashion-MNIST中包含的10个类别,分别为: t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙&…...

Cannot inspect org.apache.hadoop.hive.serde2.io.HiveDecimalWritable 问题分析处理

报错; org.apache.hadoop.hive.ql.metadata.HiveException: java.lang.UnsupportedOperationException: Cannot inspect org.apache.hadoop.hive.serde2.io.HiveDecimalWritable 该问题常见于parquet格式hive表查询时,一般原因为hive表对应数据文件元数据对应格式与…...

电子取证新视角:USB键盘流量提取密码方法研究与实现

0x01 引言 在当今数字化时代,USB设备的广泛使用使得信息安全和电子取证领域面临着新的挑战与机遇。特别是USB键盘,作为一种常见的输入设备,其流量中可能包含用户输入的敏感信息,如密码和其他私人数据。因此,研究USB键…...

Tongweb7049m4+THS6010-6012配置故障轉移+重試机制(by lqw)

使用场景 1.ths代理tongweb多套后端,假如有其中一套tongweb因为服务器重启或者宕机后没有及时启动,导致ths一直轮询在这个出故障的节点上。 2.即使在tongweb重启了,有的应用启动也需要一定的时间,这个时候只是启动了应用端口&…...

在线客服系统网站源码-网页聊天客服实现代码

源码简介 在线客服系统 – 网上客服系统,在线客服系统网站源码。 消息预知功能就是别人在聊天框打字你都能看到 1.新增客服坐席消息互动,客服之间可以互相接收消息,可以智能分配 2.新增消息预知功能,可提前预知访客已输入未发…...

JioNLP:一款实用的中文NLP预处理工具包

一、什么是 JioNLP? JioNLP是一个面向NLP开发者的工具包,提供了常见的中文文本预处理、解析等功能,使用简单、高效准确、无需配置,可极大加快NLP项目的开发进度。 主要特点包括: 代码开源,使用MIT协议功能丰富,涵盖多个NLP预处理需求使用简单,无需复杂配置即可调用准确高效…...

GR-ConvNet论文 学习笔记

GR-ConvNet 文章目录 GR-ConvNet前言一、引言二、相关研究三、问题阐述四、方法A.推理模块B.控制模块C.模型结构D.训练方法E.损失函数 五、评估A.数据集B.抓取评判标准 六、实验A.设置B.家庭测试物体C.对抗性测试物体D.混合物体 七、结果A.康奈尔数据集B.Jacquard数据集C.抓取新…...

windows环境批量删除指定目录下的全部指定文件

写在开头: 1. 涉及文件删除,先在小范围内测试(更改D:\扫描文件路径) 2. 命令会递归该目录下的所有文件 命令: forfiles /p D:\ /s /m _maven.repositories /c "cmd /c del path"解释: /p D:\ …...

水深探测仪的作用和使用方法

在水域救援的行动里,救援人员时刻面临着复杂多变、充满未知的水域状况。当接到救援任务奔赴现场,那片需要涉足的水域就像一个神秘莫测的异世界,挑战着所有人的认知与勇气。 水深探测仪作为一种专用于测量水域深度的设备,通过声波和…...

Leetcode 搜索插入位置

这段代码的核心思想是 二分查找,用于在一个已经排序的数组中查找目标值的位置。如果目标值存在于数组中,返回它的索引;如果目标值不存在,返回它按顺序应该插入的位置。 算法思想步骤: 定义左右边界: 我们使…...

jsp怎么实现点赞功能

在JSP中实现点赞功能通常涉及前端页面的设计、后端逻辑处理以及数据存储。为了实现点赞功能,你可以使用以下步骤: 前端(JSP页面)设计 前端部分包括显示点赞按钮,并通过Ajax发送点赞请求,以避免页面刷新。 …...

取消microsoft edge作为默认浏览器 ,修改方法,默认修改不了的原因

将Microsoft Edge或其它浏览器设置为默认浏览器,可以尝试以下方法来解决此问题: 一, 通过浏览器设置修改:打开Microsoft Edge浏览器,单击右上角的“更多”按钮,然后选择“设置”。在设置页面左侧找到“默认…...

C++面试速通宝典——17

283. Nginx负载均衡算法 ‌‌‌‌  Nginx支持多种负载均衡算法。 轮询(Round Robin):默认算法,按顺序逐个分配请求到后端服务器。加权轮询(Weighted Round Robin):与轮询类似,但…...

10、论文阅读:基于双阶对比损失解纠缠表示的无监督水下图像增强

Unsupervised Underwater Image Enhancement Based on Disentangled Representations via Double-Order Contrastive Loss 前言引言方法介绍解耦框架多尺度生成器双阶对比损失双阶对比损失总结损失函数实验前言 在水下环境中拍摄的图像通常会受到颜色失真、低对比度和视觉质量…...

Git配置token免密登录

配置token免密登录 如果不用ssh免密登录,还有其他基于Token那得免密登录方法吗? 2021年开始,github就不能使用密码登录git了,需要使用token作为密码登录,需要自己在setting中创建。 那么每次都需要我手动输入token密…...

活动预告|博睿数据将受邀出席GOPS全球运维大会上海站!

第二十四届 GOPS 全球运维大会暨研运数智化技术峰会上海站将于2024年10月18日-19日在上海中庚聚龙酒店召开。大会将为期2天,侧重大模型、DevOps、SRE、AIOps、BizDevOps、云原生及安全等热门技术领域。特设了如大模型 运维/研发测试、银行/证券数字化转型、平台工程…...

Flutter技术学习

以下内容更适用于 不拘泥于教程学习,而是从简单项目入手的初学者。 在开始第一个项目之前,我们先要了解 两个概念。 Widget 和 属性 Widget 是用户界面的基本构建块,可以是任何 UI 元素。属性 是 widget 类中定义的变量,用于配…...

Kubernetes网络通讯模式深度解析

Kubernetes的网络模型建立在所有Pod能够直接相互通讯的假设之上,这构建了一个扁平且互联的网络空间。在如GCE(Google Cloud Engine)等云环境中,这一网络模型已预先配置,但在自建的Kubernetes集群中,我们需要…...

SBTI科学碳目标是什么?有什么重要意义

SBTI(Science Based Targets initiative),即科学碳目标倡议,是一个由全球环境信息研究中心(CDP)、联合国全球契约组织(UNGC)、世界资源研究所(WRI)和世界自然…...

英特尔新旗舰 CPU 将运行更凉爽、更高效,适合 PC 游戏

英特尔终于解决了台式机 CPU 发热和耗电的问题。英特尔的新旗舰 Core Ultra 200S 系列处理器将于 10 月 24 日上市,该系列专注于每瓦性能,比之前的第 14 代芯片运行更凉爽、更高效。这些代号为 Arrow Lake S 的处理器也是英特尔首款内置 NPU(…...

龙虎榜——20250610

上证指数放量收阴线,个股多数下跌,盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型,指数短线有调整的需求,大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的:御银股份、雄帝科技 驱动…...

渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止

<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet&#xff1a; https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...

LLM基础1_语言模型如何处理文本

基于GitHub项目&#xff1a;https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken&#xff1a;OpenAI开发的专业"分词器" torch&#xff1a;Facebook开发的强力计算引擎&#xff0c;相当于超级计算器 理解词嵌入&#xff1a;给词语画"…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术&#xff0c;它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton)&#xff1a;由层级结构的骨头组成&#xff0c;类似于人体骨骼蒙皮 (Mesh Skinning)&#xff1a;将模型网格顶点绑定到骨骼上&#xff0c;使骨骼移动…...

基于matlab策略迭代和值迭代法的动态规划

经典的基于策略迭代和值迭代法的动态规划matlab代码&#xff0c;实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...

大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计

随着大语言模型&#xff08;LLM&#xff09;参数规模的增长&#xff0c;推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长&#xff0c;而KV缓存的内存消耗可能高达数十GB&#xff08;例如Llama2-7B处理100K token时需50GB内存&a…...

C++使用 new 来创建动态数组

问题&#xff1a; 不能使用变量定义数组大小 原因&#xff1a; 这是因为数组在内存中是连续存储的&#xff0c;编译器需要在编译阶段就确定数组的大小&#xff0c;以便正确地分配内存空间。如果允许使用变量来定义数组的大小&#xff0c;那么编译器就无法在编译时确定数组的大…...

uniapp 开发ios, xcode 提交app store connect 和 testflight内测

uniapp 中配置 配置manifest 文档&#xff1a;manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号&#xff1a;4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...

R 语言科研绘图第 55 期 --- 网络图-聚类

在发表科研论文的过程中&#xff0c;科研绘图是必不可少的&#xff0c;一张好看的图形会是文章很大的加分项。 为了便于使用&#xff0c;本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中&#xff0c;获取方式&#xff1a; R 语言科研绘图模板 --- sciRplothttps://mp.…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)

一、OpenBCI_GUI 项目概述 &#xff08;一&#xff09;项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台&#xff0c;其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言&#xff0c;首次接触 OpenBCI 设备时&#xff0c;往…...