当前位置: 首页 > news >正文

从0开始深度学习(12)——多层感知机的逐步实现

依然以Fashion-MNIST图像分类数据集为例,手动实现多层感知机和激活函数的编写,大部分代码均在从0开始深度学习(9)——softmax回归的逐步实现中实现过

1 读取数据

import torch
from torchvision import transforms
import torchvision
from torch.utils import data# 读取数据
def load_data_fashion_mnist(batch_size, resize=None):  #@savetrans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="D:/DL_Data/", train=True, transform=trans, download=False)mnist_test = torchvision.datasets.FashionMNIST(root="D:/DL_Data/", train=False, transform=trans, download=False)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=12),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=12))train_iter, test_iter = load_data_fashion_mnist(256, resize=28)

2 初始化模型参数

以单隐藏层的多层感知机为例,选择使用256个隐藏单元

from torch import nn# 初始化模型参数
num_inputs=784      # 28*28
num_outputs=10
num_hiddens=256     # 我们选择使用256个隐藏单元,注意,一般选择使用2的若干次幂,因为内存的特殊性,可以在计算上更高效w1 = nn.Parameter(torch.randn(num_inputs,num_hiddens,requires_grad=True)*0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens,requires_grad=True))w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [w1, b1, w2, b2]

3 激活函数、损失函数、建立模型

# 激活函数
def relu(x):a=torch.zeros_like(x) # 保证全零张量和x的形状一致,利于广播计算return torch.max(x,a)# 损失函数
loss = nn.CrossEntropyLoss(reduction='none')#建立模型
def net(x):x=x.reshape((-1,num_inputs))#展开H=relu(x@w1+b1)# @表示矩阵乘法return (H@w2+b2)

4 训练模型

优化器使用SGD

#训练,优化器使用sgd
num_epochs=5
lr=00.1
updater=torch.optim.SGD(params,lr=lr)def train_epoch(net, train_iter, loss, updater):if isinstance(net, torch.nn.Module):net.train()  # 将模型设置为训练模式metric = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()updater.step()else:l.backward()updater([w, b], lr, batch_size)metric.add(float(l) * y.numel(), compute_accuracy(y_hat, y), y.numel())return metric[0] / metric[2], metric[1] / metric[2]def train(net, train_iter, test_iter, loss, num_epochs, updater):for epoch in range(num_epochs):train_metrics = train_epoch(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)print(f'Epoch {epoch + 1}: Train Loss {train_metrics[0]:.3f}, Train Acc {train_metrics[1]:.3f}, Test Acc {test_acc:.3f}')class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def compute_accuracy(y_hat, y):  # 预测值、真实值if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)  # 找到一个样本中,对应的最大概率的类别cmp = y_hat.type(y.dtype) == y  # 将预测值 y_hat 与真实标签 y 进行比较,生成一个布尔张量 cmpreturn float(cmp.type(y.dtype).sum())# 计算在指定数据集上模型的准确率
def evaluate_accuracy(net, data_iter):  if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 累加多个变量的总和。这里初始化了一个包含两个元素的累加器,分别用来存储正确预测的数量和总的预测数量。with torch.no_grad():for X, y in data_iter:metric.add(compute_accuracy(net(X), y), y.numel())return metric[0] / metric[1]train(net, train_iter, test_iter, loss, num_epochs, updater)

在这里插入图片描述

5 预测

import matplotlib.pyplot as plt
# 定义 Fashion-MNIST 标签的文本描述
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 预测并显示结果
def predict(net, test_iter, n=6):for X, y in test_iter:break  # 只取一个批次的数据trues = get_fashion_mnist_labels(y)preds = get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true + '\n' + pred for true, pred in zip(trues, preds)]n = min(n, X.shape[0])fig, axs = plt.subplots(1, n, figsize=(12, 3))for i in range(n):axs[i].imshow(X[i].permute(1, 2, 0).squeeze().numpy(), cmap='gray')axs[i].set_title(titles[i])axs[i].axis('off')plt.show()# 调用预测函数
predict(net, test_iter, n=6)

在这里插入图片描述

相关文章:

从0开始深度学习(12)——多层感知机的逐步实现

依然以Fashion-MNIST图像分类数据集为例,手动实现多层感知机和激活函数的编写,大部分代码均在从0开始深度学习(9)——softmax回归的逐步实现中实现过 1 读取数据 import torch from torchvision import transforms import torchv…...

如何利用OpenCV和yolo实现人脸检测

在之前的blog里面,我们有介绍OpenCV和yolo的区别,本文就人脸检测为例,分别介绍下OpenCV和yolo的实现方式。 OpenCV实现人脸检测 一、安装 OpenCV 首先确保你已经安装了 OpenCV 库。可以通过以下方式安装: 使用包管理工具安装&…...

015集——c# 实现CAD excel交互(CAD—C#二次开发入门)

第一步:添加引用 程序集—>扩展 namespace WindowsFormsApp2 {public partial class Form1 : Form{public Form1(){InitializeComponent();}private void Form1_Load(object sender, EventArgs e){}private void 获取当前excel_Click(object sender, EventArgs e…...

【计网笔记】以太网

经典以太网 总线拓扑 物理层 Manchester编码 数据链路层 MAC子层 MAC帧 DIX格式与IEEE802.3格式 IEEE802.3格式兼容DIX格式 前导码(帧开始定界符SOF) 8字节 前7字节均为0xAA第8字节为0xAB前7字节的Manchester编码将产生稳定方波,用于…...

Java 入门基础篇14 - java面向对象思想以及特性

学习目标: 一、目标 面向对象思想类和对象对象的创建和使用属性和方法封装 开始学习: 二、编程思想 2.1 什么是编程思想 做人有做人的原则,编程也有编程的原则。这些编程的原则,就叫做编程思想。 2.2 面向过程和面向对象 二…...

第15篇:网络架构优化与综合案例分析

目录 引言 15.1 网络性能优化的方法与工具 15.1.1 带宽管理与流量控制 15.1.2 负载均衡 15.1.3 缓存优化 15.2 网络故障的排查与解决 15.2.1 常用的网络故障排查工具 15.2.2 网络故障排查案例 15.3 网络安全架构的综合设计案例 15.3.1 企业网络安全架构的要求 15.3.…...

UI自动化测试实战

补充:Selenium主要用于Web页面的自动化测试,它可以模拟用户的各种操作,如点击、输入、滚动等,来测试网页的功能。而Appium是一个开源的移动端自动化测试工具。 一、自动化测试实战章节 自动化测试流程测试用例编写项目自动化测试…...

东方智者颜廷利:以哲学思想促进世界和谐与无私奉献

【本社讯】在全球化的今天,东方智慧与哲学思想正逐渐成为促进世界和谐与理解的重要力量。近日,祖籍齐鲁大地山东济南的东方智者颜廷利以其深邃的哲学思想和对人类社会的深刻洞察,引起了国际社会的广泛关注。 颜廷利,一位致力于哲学研究与实践的智者,他的思想跨越古今,融合了东…...

基于 springboot vue停车场管理系统 设计与实现

博主介绍:专注于Java(springboot ssm 等开发框架) vue .net php phython node.js uniapp 微信小程序 等诸多技术领域和毕业项目实战、企业信息化系统建设,从业十五余年开发设计教学工作 ☆☆☆ 精彩专栏推荐订阅☆☆☆☆☆不…...

如何验证ssl私钥和证书是否匹配?

从证书(CRT)文件提取公钥 openssl x509 -in server.crt -pubkey -noout | openssl sha256从证书签名请求(CSR)文件提取公钥 openssl req -in server.csr -pubkey -noout | openssl sha256从私钥(KEY)文件…...

MongoDB的基本操作

🌷数据库准备 🎈Mongoshell 1.在指定目录下创建mongodb文件夹、其子文件log和data以及mongodb.log cd /home/ubuntu mkdir -p mongodb/data mkdir -p mongodb/log touch mongodb/log/mongodb.log 执行mongodb命令启动mongdb服务 mongod --dbpath /h…...

spring mvc后端实现过程

文章目录 一、Spring mvc1、controller1.1、LoginController011.2、LoginController 2、service2.1、LoginService2.1、LoginInimplements 3、dao3.1、LoginMapper3.1、LoginMapper.xml 4、实体类 一、Spring mvc 1、controller 控制器层、处理用户的请求和响应, …...

102005

import os os.environ["CUDA_VISIBLE_DEVICES"] "0" # 设定使用的 GPUimport tensorflow as tf from dataset import generate_data import numpy as np from model import enhancednet# 检查 TensorFlow 是否可以识别 GPU gpus tf.config.list_physica…...

Cisco ACI环境给Leaf配置OOB带外管理IP方法

可以通过GUI 或CLI进行配置 通过CLI更简单,和配置传统交换机差不多, ACI中共有3大组件 APIC 控制器 SPINE 核心 LEAF 接入 下面我们将3种角色的带外IP配置方法都列出来 1 APIC配置带外IP This example shows how to configure out-of-band managemen…...

免费送源码:Java+B/S+MySQL springboot电影推荐系统 计算机毕业设计原创定制

摘 要 随着互联网与移动互联网迅速普及,网络上的电影娱乐信息数量相当庞大,人们对获取感兴趣的电影娱乐信息的需求越来越大,个性化的电影推荐系统成为一个热门。然而电影信息的表示相当复杂,己有的相似度计算方法与推荐算法都各有优势&#…...

数据清洗(脚本)

使用脚本清洗数据时,可以根据具体的数据问题选择编程语言,如Shell、Python、SQL等。这里我以 Python(Pandas库) 和 SQL 为例,演示如何通过脚本进行数据清洗。 1. 使用 Python(Pandas库) 进行数…...

jmeter中发送post请求遇到的问题

用jmeter发送post请求,把请求参数放在Body Data处,参数都写得正确,但没想到结果每次都报错,直接响应结果乱七八糟,改成用Parameters,反而不乱报错了。 上图 请求里如下 另外一些请求也是这样 这个响应结果也是错误的…...

Java中使用protobuf

一、简介 Protocal Buffers(简称protobuf)是谷歌的一项技术,用于结构化的数据序列化、反序列化。 Protocol Buffers 是一种语言无关、平台无关、可扩展的序列化结构数据的方法,它可用于(数据)通信协议、数据存储等。 Protocol B…...

2020款Macbook Pro A2251无法充电无法开机定位及修复

问题背景 up主有一台2020年的Macbook Pro,带Touch Bar,16G512G,四核I5,型号A2251 应该是一周没充电了,之前还用的好好的,后来有一天出差想带上 打开没电,手头上有个小米的66W快充头&#xff0c…...

Spring Cloud --- 引入Gateway网关

引入Gateway网关 介绍 Spring Cloud Gateway 组件的核心是一系列的过滤器,通过这些过滤器可以将客户端发送的请求转发(路由)到对应的微服务。 Spring Cloud Gateway 是加在整个微服务最前沿的防火墙和代理器,隐藏微服务结点 IP 端口信息,从…...

利用ngx_stream_return_module构建简易 TCP/UDP 响应网关

一、模块概述 ngx_stream_return_module 提供了一个极简的指令&#xff1a; return <value>;在收到客户端连接后&#xff0c;立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量&#xff08;如 $time_iso8601、$remote_addr 等&#xff09;&a…...

golang循环变量捕获问题​​

在 Go 语言中&#xff0c;当在循环中启动协程&#xff08;goroutine&#xff09;时&#xff0c;如果在协程闭包中直接引用循环变量&#xff0c;可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下&#xff1a; 问题背景 看这个代码片段&#xff1a; fo…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具

文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”&#xff0c;物流的终极形态正在诞生 想象这样的场景&#xff1a; 凌晨3点&#xff0c;某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径&#xff1b;AI视觉系统在0.1秒内扫描包裹信息&#xff1b;数字孪生平台正模拟次日峰值流量压力…...

Element Plus 表单(el-form)中关于正整数输入的校验规则

目录 1 单个正整数输入1.1 模板1.2 校验规则 2 两个正整数输入&#xff08;联动&#xff09;2.1 模板2.2 校验规则2.3 CSS 1 单个正整数输入 1.1 模板 <el-formref"formRef":model"formData":rules"formRules"label-width"150px"…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行

项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战&#xff0c;克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

C#学习第29天:表达式树(Expression Trees)

目录 什么是表达式树&#xff1f; 核心概念 1.表达式树的构建 2. 表达式树与Lambda表达式 3.解析和访问表达式树 4.动态条件查询 表达式树的优势 1.动态构建查询 2.LINQ 提供程序支持&#xff1a; 3.性能优化 4.元数据处理 5.代码转换和重写 适用场景 代码复杂性…...

go 里面的指针

指针 在 Go 中&#xff0c;指针&#xff08;pointer&#xff09;是一个变量的内存地址&#xff0c;就像 C 语言那样&#xff1a; a : 10 p : &a // p 是一个指向 a 的指针 fmt.Println(*p) // 输出 10&#xff0c;通过指针解引用• &a 表示获取变量 a 的地址 p 表示…...

《Docker》架构

文章目录 架构模式单机架构应用数据分离架构应用服务器集群架构读写分离/主从分离架构冷热分离架构垂直分库架构微服务架构容器编排架构什么是容器&#xff0c;docker&#xff0c;镜像&#xff0c;k8s 架构模式 单机架构 单机架构其实就是应用服务器和单机服务器都部署在同一…...

门静脉高压——表现

一、门静脉高压表现 00:01 1. 门静脉构成 00:13 组成结构&#xff1a;由肠系膜上静脉和脾静脉汇合构成&#xff0c;是肝脏血液供应的主要来源。淤血后果&#xff1a;门静脉淤血会同时导致脾静脉和肠系膜上静脉淤血&#xff0c;引发后续系列症状。 2. 脾大和脾功能亢进 00:46 …...