当前位置: 首页 > news >正文

如何使用pytorch定义一个多层感知神经网络模型——拓展到所有模型知识

# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets# 定义MLP模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()# 创建一个顺序的层序列:包括一个扁平化层、两个全连接层和ReLU激活self.layers = nn.Sequential(nn.Flatten(),                       # 将28x28的图像扁平化为784维向量nn.Linear(28 * 28, 512),            # 第一个全连接层,784->512nn.ReLU(),                          # ReLU激活函数nn.Linear(512, 256),                # 第二个全连接层,512->256nn.ReLU(),                          # ReLU激活函数nn.Linear(256, 10)                  # 第三个全连接层,256->10 (输出10个类别))def forward(self, x):return self.layers(x)                   # 定义前向传播# 加载FashionMNIST数据集
# 定义图像的预处理:转换为Tensor并标准化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 下载FashionMNIST数据并应用转换
dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)# 划分数据集为训练集和验证集
train_len = int(0.8 * len(dataset))           # 计算80%的长度作为训练数据
val_len = len(dataset) - train_len            # 剩下的20%作为验证数据
train_dataset, val_dataset = random_split(dataset, [train_len, val_len])# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # 训练数据加载器,批量大小64,打乱数据
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)     # 验证数据加载器,批量大小64,不打乱# 初始化模型、损失函数和优化器
model = MLP()                                 # 创建MLP模型实例
criterion = nn.CrossEntropyLoss()             # 定义交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 使用Adam优化器# 训练模型
epochs = 5                                    # 定义训练5个epochs
for epoch in range(epochs):model.train()                             # 将模型设置为训练模式for inputs, labels in train_loader:       # 从训练加载器中获取批次数据outputs = model(inputs)               # 前向传播loss = criterion(outputs, labels)     # 计算损失optimizer.zero_grad()                 # 清除之前的梯度loss.backward()                       # 反向传播,计算梯度optimizer.step()                      # 更新权重# 在每个epoch结束时验证模型性能model.eval()                              # 将模型设置为评估模式total_correct = 0with torch.no_grad():                     # 不计算梯度,节省内存和计算量for inputs, labels in val_loader:     # 从验证加载器中获取批次数据outputs = model(inputs)           # 前向传播_, predicted = outputs.max(1)     # 获取预测的类别total_correct += (predicted == labels).sum().item()  # 统计正确的预测数量accuracy = total_correct / val_len        # 计算验证准确性print(f"Epoch {epoch + 1}/{epochs} - Validation accuracy: {accuracy:.4f}")  # 打印验证准确性

nn.Flatten() 是一个特殊的层,它将多维的输入数据“展平”为一维数据。这在处理图像数据时尤为常见,因为图像通常是多维的(例如,一个大小为28x28的灰度图像在PyTorch中会有一个形状为[28, 28]的张量)。

在神经网络的某些层,特别是全连接层(如nn.Linear)之前,通常需要对数据进行扁平化处理。因为全连接层期望其输入是一维的(或者更准确地说,它期望输入的最后一个维度对应于特征,其他维度对应于数据的批次)。

为了更具体,让我们看一个例子:

考虑一个大小为[batch_size, 28, 28]的张量,这可以看作是一个batch_size数量的28x28图像的批次。当我们传递这个批次的图像到一个nn.Linear(28*28, 512)层时,我们需要先将图像展平。也就是说,每个28x28的图像需要转换为长度为784的一维向量。因此,输入数据的形状会从[batch_size, 28, 28]变为[batch_size, 784]。

nn.Flatten()就是做这个转换的。在这个特定的例子中,它会将[batch_size, 28, 28]的形状转换为[batch_size, 784]。

总结一下:nn.Flatten()用于将多维输入数据转换为一维,从而使其可以作为全连接层(如nn.Linear)的输入。

  • transforms.Compose:
    这是一个简单的方式来链接(组合)多个图像转换操作。它会按照提供的顺序执行列表中的每个转换。

  • transforms.ToTensor():
    这个转换将PIL图像或NumPy的ndarray转换为FloatTensor。并且它将图像的像素值范围从0-255变为0-1。简言之,它为我们完成了数据类型和值范围的转换。

  • transforms.Normalize((0.5,), (0.5,)):
    这个转换标准化张量图像。给定的参数是均值和标准差。在这里,均值和标准差都是0.5。
    使用给定的均值和标准差,这会将值范围从[0,1]转换为[-1,1]。

整个transform的目的是:

  • 将图像数据从PIL格式转换为PyTorch张量格式。
  • 将像素值从[0,255]范围转换为[0,1]范围。
  • 使用给定的均值和标准差进一步标准化像素值,使其范围为[-1,1]。

初始化模型、损失函数和优化器

  • model = MLP():

    • 这里我们实例化了我们之前定义的MLP类,从而创建了一个多层感知器(MLP)模型。
  • criterion = nn.CrossEntropyLoss():

    • 在分类任务中,交叉熵损失函数 (CrossEntropyLoss) 是最常用的损失函数之一。它衡量真实标签和预测之间的差异。
    • 注意:CrossEntropyLoss在内部执行softmax操作,因此模型输出应该是未经softmax处理的原始分数(logits)。
  • optimizer = optim.Adam(model.parameters(), lr=0.001):

    • 优化器负责更新模型的权重,基于计算的梯度来减少损失。
    • Adam是一种流行的优化器,它结合了两种扩展的随机梯度下降:Adaptive Gradients 和 Momentum。
    • model.parameters()是传递给优化器的,它告诉优化器应该优化/更新哪些权重。
    • lr=0.001定义了学习率,这是一个超参数,表示每次权重更新的步长大小。

常见的相关资料解答

  1. 模型 (在torch.nn中):

除了基本的MLP外,PyTorch提供了很多预定义的层和模型,常见的包括:

Convolutional Neural Networks (CNNs):nn.Conv2d: 2D卷积层,常用于图像处理。nn.Conv3d: 3D卷积层,常用于视频处理或医学图像。nn.MaxPool2d: 最大池化层。Recurrent Neural Networks (RNNs):nn.RNN: 基本的RNN层。nn.LSTM: 长短时记忆网络。nn.GRU: 门控循环单元。Transformer Architecture:nn.Transformer: 用于自然语言处理任务的Transformer模型。Batch Normalization, Dropout等:nn.BatchNorm2d: 批量归一化。nn.Dropout: 防止过拟合的正则化方法。
  1. 损失函数 (在torch.nn中):

常见的损失函数有:

Classification:nn.CrossEntropyLoss: 用于分类任务的交叉熵损失。nn.BCEWithLogitsLoss: 用于二分类任务的二元交叉熵损失,包括内部的sigmoid操作。nn.MultiLabelSoftMarginLoss: 用于多标签分类任务。Regression:nn.MSELoss: 均方误差,用于回归任务。nn.L1Loss: L1误差。Generative models:nn.KLDivLoss: Kullback-Leibler散度,常用于生成模型。
  1. 优化器 (在torch.optim中):

常见的优化器有:

optim.SGD: 随机梯度下降。
optim.Adam: 一个非常受欢迎的优化器,结合了AdaGrad和RMSProp的特点。
optim.RMSprop: 常用于深度学习任务。
optim.Adagrad: 自适应学习率优化器。
optim.Adadelta: 类似于Adagrad,但试图解决其快速降低学习率的问题。
optim.AdamW: Adam的变种,加入了权重衰减。

在这里插入图片描述

每文一语

学习是不断的发展的

相关文章:

如何使用pytorch定义一个多层感知神经网络模型——拓展到所有模型知识

# 导入必要的库 import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, random_split import torchvision.transforms as transforms import torchvision.datasets as datasets# 定义MLP模型 class MLP(nn.Module):def __…...

为什么引入SVG文件,给它定义属性不生效原理分析

背景&#xff1a; 我使用antd 的Icon组件引入SVG图片&#xff0c;但给svg图片定义styles样式时&#xff0c;不生效&#xff0c;为什么呢&#xff1f; 我们平时用antd组件库的 < ArrowRightOutlined style{{color: red }}>时为什么会生效呢&#xff0c;但我图一这样定义就…...

Integer包装类常用方法和属性

包装类 什么是包装类Integer包装类常用方法和属性 什么是包装类 Java 包装类是指为了方便处理基本数据类型而提供的对应的引用类型。Java 提供了八个基本数据类型&#xff08;boolean、byte、short、int、long、float、double、char&#xff09;&#xff0c;每个基本数据类型对…...

基于Spring boot轻松实现一个多数据源框架

Spring Boot 提供了 Data JPA 的包&#xff0c;允许你使用类似 ORM 的接口连接到 RDMS。它很容易使用和实现&#xff0c;只需要在 pom.xml 中添加一个条目&#xff08;如果使用的是 Maven&#xff0c;Gradle 则是在 build.gradle 文件中&#xff09;。 <dependencies>&l…...

vue前端实现打印功能并约束纸张大小---调用浏览器打印功能打印页面部分元素并固定纸张大小

需求是打印指定div实现小票打印功能。调用浏览器的自带打印功能只能实现打印可视区域&#xff0c;所以这里采用截图新窗口打开打印去实现此需求。 1.安装html2canvas库实现截图功能 npm install html2canvas --save2.在需要进行截图和打印的组件中&#xff0c;引入html2canvas…...

音乐播放器蜂鸣器ROM存储歌曲verilog,代码/视频

名称&#xff1a;音乐播放器蜂鸣器ROM存储歌曲 软件&#xff1a;Quartus 语言&#xff1a;Verilog 代码功能&#xff1a; 设计音乐播放器&#xff0c;要求至少包含2首歌曲&#xff0c;使用按键切换歌曲&#xff0c;使用开发板的蜂鸣器播放音乐&#xff0c;使用Quartus内的RO…...

Arduino Nano 引脚复用分析

近期开发的项目为气体传感器采集仪&#xff0c;综合需求&#xff0c;选取NANO作为主控&#xff0c;附属设备有 oled、旋转编码器、H桥板、蠕动泵、开关、航插等&#xff0c;主要是用现有接口怎么合理配置实现功能。 不管stm32 还是 Arduino 都要看清引脚图 D2 D3 引脚是两个外…...

Go 函数多返回值错误处理与error 类型介绍

Go 函数多返回值错误处理与error 类型介绍 文章目录 Go 函数多返回值错误处理与error 类型介绍一、error 类型与错误值构造1.1 Error 接口介绍1.2 构造错误值的方法1.2.1 使用errors包1.2.2 自定义错误类型 二、error 类型的好处2.1 第一点&#xff1a;统一了错误类型2.2 第二点…...

数论分块

本质就是利用取整分数值的块状分布。 UVA11526 H(n) 题意&#xff1a; 求 ∑ i 1 n n i \sum_{i1}^{n} \frac {n}{i} ∑i1n​in​。 解析&#xff1a; ⌊ n i ⌋ \lfloor \frac{n}{i} \rfloor ⌊in​⌋ 只有 O ( n ) O(\sqrt n) O(n ​) 种取值&#xff0c;考虑将相同值同…...

宏任务与微任务,代码执行顺序

js引擎工作进程是同步的。事件循环机制&#xff0c;事件队列。 脚本代码执行顺序&#xff0c;是先执行同步代码&#xff0c;遇到微任务&#xff0c;就把它推进任务队列中。每个宏任务完成后&#xff0c;再执行下一个宏任务。 宏任务有哪些&#xff1a; i/o读写 定时器setTi…...

正方形(Squares, ACM/ICPC World Finals 1990, UVa201)rust解法

有n行n列&#xff08;2≤n≤9&#xff09;的小黑点&#xff0c;还有m条线段连接其中的一些黑点。统计这些线段连成了多少个正方形&#xff08;每种边长分别统计&#xff09;。 行从上到下编号为1&#xff5e;n&#xff0c;列从左到右编号为1&#xff5e;n。边用H i j和V i j表示…...

【算法设计与分析qwl】伪码——顺序检索,插入排序

伪代码&#xff1a; 例子&#xff1a; 改进的顺序检索 Search(L,x)输入&#xff1a;数组L[1...n]&#xff0c;元素从小到大排序&#xff0c;数x输出&#xff1a;若x在L中&#xff0c;输出x位置下标 j ,否则输出0 j <- 1 while j<n and x>L[j] do j <- j1 if x<…...

Uniapp路由拦截-自定义路由白名单

步骤一:新建routerIntercept.js文件 步骤二:routerIntercept文件中写入:(根据自己需要修改whiteList白名单中的页面路径和自己的逻辑处理) import Vue from vue // 白名单 const whiteList = [/pages/public/login,/pages/public/privacyAgreement, ]export default asy…...

在中国可以使用 HubSpot 吗?

当谈到市场营销和客户关系管理工具时&#xff0c;HubSpot通常是一家企业的首选。然而&#xff0c;对于许多中国的企业来说&#xff0c;一个重要的问题是&#xff1a;在中国可以使用HubSpot吗&#xff1f;这个问题涉及到不同的方面&#xff0c;包括政策法规、社交媒体平台、语言…...

Java的基础应用

Java是一种广泛应用于软件开发的编程语言&#xff0c;基础应用涵盖了很多方面。以下是Java的一些基础应用方面的介绍&#xff1a; 1. 控制流语句&#xff1a;Java中的程序流程控制语句分为选择语句和循环语句。选择语句包括if-else语句和switch语句&#xff0c;循环语句包括fo…...

【excel】列转行

列转行 工作中有一些数据是列表&#xff0c;现在需要转行 选表格内容&#xff1a;在excel表格中选中表格数据区域。点击复制&#xff1a;在选中表格区域处右击点击复制。点击选择性粘贴&#xff1a;在表格中鼠标右击点击选择性粘贴。勾选转置&#xff1a;在选择性粘勾选转置选…...

用Bing绘制「V我50」漫画;GPT-5业内交流笔记;LLM大佬的跳槽建议;Stable Diffusion生态全盘点第一课 | ShowMeAI日报

&#x1f440;日报&周刊合集 | &#x1f3a1;生产力工具与行业应用大全 | &#x1f9e1; 点赞关注评论拜托啦&#xff01; &#x1f525; 美国升级AI芯片出口禁令&#xff0c;13家中国GPU企业被列入实体清单 nytimes.com/2023/10/05/technology/chip-makers-china-lobbying…...

Java身份证实名认证-阿里云API 【姓名、身份证号】

1. 阿里云API市场 https://market.aliyun.com/products/57126001/cmapi00053442.html?spm5176.2020520132.101.3.a6217218nxxEiy#skuyuncode47442000022 购买对应套餐 2. 复制AppCode https://market.console.aliyun.com/imageconsole/index.htm#/?_kl85e10 云市场-已购买服…...

ND协议——无状态地址自动配置 (SLAAC)

参考学习&#xff1a;计算机网络 | 思科网络 | 无状态地址自动配置 (SLAAC) | 什么是SLAAC_瘦弱的皮卡丘的博客-CSDN博客 与 IPv4 类似&#xff0c;可以手动或动态配置 IPv6 全局单播地址。但是&#xff0c;动态分配 IPv6 全局单播地址有两种方法&#xff1a; 如图所示&#…...

iOS开发UITableView的使用,区别Plain模式和Grouped模式

简单赘述一下 的创建步骤 // 创建UITableView self.tableView [[UITableView alloc] initWithFrame:self.view.bounds style:UITableViewStylePlain]; // 设置数据源和代理 self.tableView.dataSource self; self.tableView.delegate self; // 注册自定义UITableViewCe…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

Debian系统简介

目录 Debian系统介绍 Debian版本介绍 Debian软件源介绍 软件包管理工具dpkg dpkg核心指令详解 安装软件包 卸载软件包 查询软件包状态 验证软件包完整性 手动处理依赖关系 dpkg vs apt Debian系统介绍 Debian 和 Ubuntu 都是基于 Debian内核 的 Linux 发行版&#xff…...

深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法

深入浅出&#xff1a;JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中&#xff0c;随机数的生成看似简单&#xff0c;却隐藏着许多玄机。无论是生成密码、加密密钥&#xff0c;还是创建安全令牌&#xff0c;随机数的质量直接关系到系统的安全性。Jav…...

为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?

在建筑行业&#xff0c;项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升&#xff0c;传统的管理模式已经难以满足现代工程的需求。过去&#xff0c;许多企业依赖手工记录、口头沟通和分散的信息管理&#xff0c;导致效率低下、成本失控、风险频发。例如&#…...

什么是EULA和DPA

文章目录 EULA&#xff08;End User License Agreement&#xff09;DPA&#xff08;Data Protection Agreement&#xff09;一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA&#xff08;End User License Agreement&#xff09; 定义&#xff1a; EULA即…...

c#开发AI模型对话

AI模型 前面已经介绍了一般AI模型本地部署&#xff0c;直接调用现成的模型数据。这里主要讲述讲接口集成到我们自己的程序中使用方式。 微软提供了ML.NET来开发和使用AI模型&#xff0c;但是目前国内可能使用不多&#xff0c;至少实践例子很少看见。开发训练模型就不介绍了&am…...

网络编程(UDP编程)

思维导图 UDP基础编程&#xff08;单播&#xff09; 1.流程图 服务器&#xff1a;短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件&#xff0c;用于在原生应用中加载 HTML 页面&#xff1a; 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...

rnn判断string中第一次出现a的下标

# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...