当前位置: 首页 > news >正文

深度学习实战模拟——softmax回归(图像识别并分类)

目录

1、数据集:

2、完整代码


1、数据集:

1.1 Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。

1.2 Fashion‐MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像和测试数据 集(test dataset)中的1000张图像组成。因此,训练集和测试集分别包含60000和10000张图像。测试数据集 不会用于训练,只用于评估模型性能。

以下函数用于在数字标签索引及其文本名称之间进行转换。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

2、完整代码

import torch
import torchvision
import pylab
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as plt
from d2l import torch as d2l
import timebatch_size = 256
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
num_epochs = 5class Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def cross_entropy(y_hat, y):return -torch.log(y_hat[range(len(y_hat)), y])def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp/partitiondef net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)def get_dataloader_workers():"""使用一个进程来读取的数据"""return 1def get_fashion_mnist_labels(labels):"""返回Fashion-MNIST数据集的文本标签"""#共10个类别text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):"""画一系列图片"""figsize = (num_cols * scale, num_rows * scale)_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)for i, (img, label) in enumerate(zip(imgs, titles)):xloc, yloc = i//num_cols, i % num_colsif torch.is_tensor(img):# 图片张量axes[xloc, yloc].imshow(img.reshape((28, 28)).numpy())else:# PIL图片axes[xloc, yloc].imshow(img)# 设置标题并取消横纵坐标上的刻度axes[xloc, yloc].set_title(label)plt.xticks([], ())axes[xloc, yloc].set_axis_off()pylab.show()def load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = transforms.ToTensor()if resize:trans.insert(0, transforms.Resize(resize))mnist_train = torchvision.datasets.FashionMNIST(root='../data', train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root='../data', train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))def evaluate_accuracy(net, data_iter):"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]def updater(batch_size):lr = 0.1return d2l.sgd([W, b], lr, batch_size)def train_epoch_ch3(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train()metric = Accumulator(3)for X, y in train_iter:y_hat = net(X)lo = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()lo.backward()updater.step()metric.add(float(lo)*len(y), accuracy(y_hat, y), y.size().numel())else:lo.sum().backward()updater(X.shape[0])metric.add(float(lo.sum()), accuracy(y_hat, y), y.numel())return metric[0] / metric[2], metric[1] / metric[2]class Animator:  #@save"""绘制数据"""def __init__(self, legend=None):self.legend = legendself.X = [[], [], []]self.Y = [[], [], []]def add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nfor i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)def show(self):plt.plot(self.X[0], self.Y[0], 'r--')plt.plot(self.X[1], self.Y[1], 'g--')plt.plot(self.X[2], self.Y[2], 'b--')plt.legend(self.legend)plt.xlabel('epoch')plt.ylabel('value')plt.title('Visual')plt.show()def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save"""训练模型"""animator = Animator(legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)train_loss, train_acc = train_metricstest_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, train_metrics + (test_acc,))print(f'epoch: {epoch+1},train_loss:{train_loss:.4f}, train_acc:{train_acc:.4f}, test_acc:{test_acc:.4f}')animator.show()def predict_ch3(net, test_iter, n=12):"""预测标签"""for X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]show_images(X[0:n].reshape((n, 28, 28)), 2, int(n/2), titles=titles[0:n])if __name__ == '__main__':train_iter, test_iter = load_data_fashion_mnist(batch_size)train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)predict_ch3(net, test_iter)

分类效果:

相关文章:

深度学习实战模拟——softmax回归(图像识别并分类)

目录 1、数据集: 2、完整代码 1、数据集: 1.1 Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫&#xf…...

vue实现element-UI中table表格背景颜色设置

目前在style中设置不了,那么就在前面组件给设置上 :header-cell-style"{ color: #ffffff, fontSize: 14px, backgroundColor: #0E2152 }" :cell-style"{ color: #ffffff, fontSize: 14px, backgroundColor: #0E2152 }"...

RabbitMQ学习总结-消息的可靠性

保证MQ消息的可靠性,主要从三个方面:发送者确认可靠性,MQ确认可靠性,消费者确认可靠性。 1.发送者可靠性:主要依赖于发送者重试机制,发送者确认机制; 发送者重试机制,其实就是配置…...

2024蓝桥杯每日一题(BFS)

备战2024年蓝桥杯 -- 每日一题 Python大学A组 试题一:母亲的奶牛 试题二:走迷宫 试题三:八数码1 试题四:全球变暖 试题五:八数码2 试题一:母亲的奶牛 【题目描述】 农夫约…...

力扣思路题:最长特殊序列1

int findLUSlength(char * a, char * b){int alenstrlen(a),blenstrlen(b);if (strcmp(a,b)0)return -1;return alen>blen?alen:blen; }...

c# 的ref 和out

在C#中,ref和out是用于方法参数的关键字,它们都允许在方法调用中对参数进行修改。 ref关键字用于传递参数的引用。当使用ref关键字声明一个参数时,实际上是在告诉编译器此参数在调用方法之前必须被赋值。ref参数传递的是参数的引用地址&…...

ONLYOFFICE文档8.0全新发布:私有部署、卓越安全的协同办公解决方案

ONLYOFFICE文档8.0全新发布:私有部署、卓越安全的协同办公解决方案 文章目录 ONLYOFFICE文档8.0全新发布:私有部署、卓越安全的协同办公解决方案摘要📑引言 🌟正文📚一、ONLYOFFICE文档概述 📊二、ONLYOFFI…...

Mar 14 | Datawhale 01~04 打卡 | Leetcode面试下

第一阶段主要就是学习字符串的处理和二叉树的遍历。前一段时间学习了二叉树的遍历,记忆还比较深刻,这几个题还是比较轻松的做出来了;但是像Longest Palindromic Substring这样的题除了简单的字符串处理(回文判断)&…...

【计算机网络】什么是http?

​ 目录 前言 1. 什么是HTTP协议? 2. 为什么使用HTTP协议? 3. HTTP协议通信过程 4. 什么是url? 5. HTTP报文 5.1 请求报文 5.2 响应报文 6. HTTP请求方式 7. HTTP头部字段 8. HTTP状态码 9. 连接管理 长连接与短连接 管线化连接…...

【python开发】并发编程(上)

并发编程(上) 一、进程和线程(一)多线程(二)多进程(三)GIL锁 二、多线程开发(一)t.start()(二)t.join()(三)t.…...

用c语言实现扫雷游戏

文章目录 概要整体架构流程代码的实现小结 概要 学习了c语言后,我们可以通过c语言写一些小程序,然后我们这篇文章主要是,扫雷游戏的一步一步游戏。 整体架构流程 扫雷网页版 根据上面网页版的基础扫雷可以看出是一个99的格子,…...

LeetCode 2882.删去重复的行

DataFrame customers ------------------- | Column Name | Type | ------------------- | customer_id | int | | name | object | | email | object | ------------------- 在 DataFrame 中基于 email 列存在一些重复行。 编写一个解决方案,删除这些重复行&#…...

对OceanBase进行 sysbench 压测前,如何用 obdiag巡检

有一些用户想对 OceanBase 进行 sysbench 压测,并向我询问是否需要对数据库的各种参数进行调整。我想起有一个工具 obdiag ,具备对集群进行巡检的功能。因此,我正好借此机会试用一下这个工具。 obdiag 功能的比较丰富,详细情况可参…...

每天学习几道面试题|Kafka架构设计类

文章目录 1. Kafka 是如何保证高可用性和容错性的?2. Kafka 的存储机制是怎样的?它是如何处理大量数据的?3. Kafka 如何处理消费者的消费速率低于生产者的生产速率?4. Kafka 集群中的 Controller 是什么?它的作用是什么…...

.rmallox勒索病毒解密方法|勒索病毒解决|勒索病毒恢复|数据库修复

导言: 近年来,勒索病毒的威胁日益增加,其中一种名为.rmallox的勒索病毒备受关注。这种病毒通过加密文件并勒索赎金来威胁受害者。本文将介绍.rmallox勒索病毒的特点,以及如何恢复被其加密的数据文件,并提供预防措施&a…...

安卓性能优化面试题 11-15

11. 简述APK安装包瘦身方案 ?(1):剔 除掉冗余的代码与不必要的jar包;具体来讲的话,我们可以使用SDK集成的ProGuard混淆工具,它可以在编译时检查并删除未使用的类、字段、方法 和属性,它会遍历所有代码找到无用处的代码,所有那些不可达的代码都会在生成最终apk文件之前被…...

Python错题集-9PermissionError:[Errno 13] (权限错误)

1问题描述 Traceback (most recent call last): File "D:\pycharm\projects\5-《Python数学建模算法与应用》程序和数据\02第2章 Python使用入门\ex2_38_1.py", line 9, in <module> fpd.ExcelWriter(data2_38_3.xlsx) #创建文件对象 File "D:…...

QT TCP通信介绍

QT是一个跨平台的C应用程序开发框架&#xff0c;它提供了一套完整的工具和库&#xff0c;用于开发各种类型的应用程序&#xff0c;包括图形用户界面(GUI)应用程序、命令行工具、网络应用程序等。QT提供了丰富的功能和类来简化网络通信的开发&#xff0c;其中包括TCP通信。 TCP…...

保姆级教学!微信小程序设计全攻略!

微信小程序开启了互联网软件的新使用模式。在各种微信小程序争相抢占流量的同时&#xff0c;如何设计微信小程序&#xff1f;让用户感到舒适是设计师在产品设计初期应该考虑的问题。那么如何做好微信小程序的设计呢&#xff1f;即时设计总结了以下设计指南&#xff0c;希望对准…...

日期差值的计算

1、枚举所有数值进行日期判断 时间复杂度是o(n)的&#xff0c;比较慢&#xff0c;单实例能凑合用&#xff0c;多实例的话时间复杂度有点高。 核心代码就是判断某个八位数能否表示一个日期。 static int[] month {0,31,28,31,30,31,30,31,31,30,31,30,31};static String a, b…...

golang循环变量捕获问题​​

在 Go 语言中&#xff0c;当在循环中启动协程&#xff08;goroutine&#xff09;时&#xff0c;如果在协程闭包中直接引用循环变量&#xff0c;可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下&#xff1a; 问题背景 看这个代码片段&#xff1a; fo…...

相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解

【关注我&#xff0c;后续持续新增专题博文&#xff0c;谢谢&#xff01;&#xff01;&#xff01;】 上一篇我们讲了&#xff1a; 这一篇我们开始讲&#xff1a; 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下&#xff1a; 一、场景操作步骤 操作步…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

Razor编程中@Html的方法使用大全

文章目录 1. 基础HTML辅助方法1.1 Html.ActionLink()1.2 Html.RouteLink()1.3 Html.Display() / Html.DisplayFor()1.4 Html.Editor() / Html.EditorFor()1.5 Html.Label() / Html.LabelFor()1.6 Html.TextBox() / Html.TextBoxFor() 2. 表单相关辅助方法2.1 Html.BeginForm() …...

MySQL 部分重点知识篇

一、数据库对象 1. 主键 定义 &#xff1a;主键是用于唯一标识表中每一行记录的字段或字段组合。它具有唯一性和非空性特点。 作用 &#xff1a;确保数据的完整性&#xff0c;便于数据的查询和管理。 示例 &#xff1a;在学生信息表中&#xff0c;学号可以作为主键&#xff…...

Spring Security 认证流程——补充

一、认证流程概述 Spring Security 的认证流程基于 过滤器链&#xff08;Filter Chain&#xff09;&#xff0c;核心组件包括 UsernamePasswordAuthenticationFilter、AuthenticationManager、UserDetailsService 等。整个流程可分为以下步骤&#xff1a; 用户提交登录请求拦…...

Linux中《基础IO》详细介绍

目录 理解"文件"狭义理解广义理解文件操作的归类认知系统角度文件类别 回顾C文件接口打开文件写文件读文件稍作修改&#xff0c;实现简单cat命令 输出信息到显示器&#xff0c;你有哪些方法stdin & stdout & stderr打开文件的方式 系统⽂件I/O⼀种传递标志位…...

高考志愿填报管理系统---开发介绍

高考志愿填报管理系统是一款专为教育机构、学校和教师设计的学生信息管理和志愿填报辅助平台。系统基于Django框架开发&#xff0c;采用现代化的Web技术&#xff0c;为教育工作者提供高效、安全、便捷的学生管理解决方案。 ## &#x1f4cb; 系统概述 ### &#x1f3af; 系统定…...

实战设计模式之模板方法模式

概述 模板方法模式定义了一个操作中的算法骨架&#xff0c;并将某些步骤延迟到子类中实现。模板方法使得子类可以在不改变算法结构的前提下&#xff0c;重新定义算法中的某些步骤。简单来说&#xff0c;就是在一个方法中定义了要执行的步骤顺序或算法框架&#xff0c;但允许子类…...

​​企业大模型服务合规指南:深度解析备案与登记制度​​

伴随AI技术的爆炸式发展&#xff0c;尤其是大模型&#xff08;LLM&#xff09;在各行各业的深度应用和整合&#xff0c;企业利用AI技术提升效率、创新服务的步伐不断加快。无论是像DeepSeek这样的前沿技术提供者&#xff0c;还是积极拥抱AI转型的传统企业&#xff0c;在面向公众…...