如何使用pytorch定义一个多层感知神经网络模型——拓展到所有模型知识
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets# 定义MLP模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()# 创建一个顺序的层序列:包括一个扁平化层、两个全连接层和ReLU激活self.layers = nn.Sequential(nn.Flatten(), # 将28x28的图像扁平化为784维向量nn.Linear(28 * 28, 512), # 第一个全连接层,784->512nn.ReLU(), # ReLU激活函数nn.Linear(512, 256), # 第二个全连接层,512->256nn.ReLU(), # ReLU激活函数nn.Linear(256, 10) # 第三个全连接层,256->10 (输出10个类别))def forward(self, x):return self.layers(x) # 定义前向传播# 加载FashionMNIST数据集
# 定义图像的预处理:转换为Tensor并标准化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 下载FashionMNIST数据并应用转换
dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)# 划分数据集为训练集和验证集
train_len = int(0.8 * len(dataset)) # 计算80%的长度作为训练数据
val_len = len(dataset) - train_len # 剩下的20%作为验证数据
train_dataset, val_dataset = random_split(dataset, [train_len, val_len])# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 训练数据加载器,批量大小64,打乱数据
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False) # 验证数据加载器,批量大小64,不打乱# 初始化模型、损失函数和优化器
model = MLP() # 创建MLP模型实例
criterion = nn.CrossEntropyLoss() # 定义交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用Adam优化器# 训练模型
epochs = 5 # 定义训练5个epochs
for epoch in range(epochs):model.train() # 将模型设置为训练模式for inputs, labels in train_loader: # 从训练加载器中获取批次数据outputs = model(inputs) # 前向传播loss = criterion(outputs, labels) # 计算损失optimizer.zero_grad() # 清除之前的梯度loss.backward() # 反向传播,计算梯度optimizer.step() # 更新权重# 在每个epoch结束时验证模型性能model.eval() # 将模型设置为评估模式total_correct = 0with torch.no_grad(): # 不计算梯度,节省内存和计算量for inputs, labels in val_loader: # 从验证加载器中获取批次数据outputs = model(inputs) # 前向传播_, predicted = outputs.max(1) # 获取预测的类别total_correct += (predicted == labels).sum().item() # 统计正确的预测数量accuracy = total_correct / val_len # 计算验证准确性print(f"Epoch {epoch + 1}/{epochs} - Validation accuracy: {accuracy:.4f}") # 打印验证准确性
nn.Flatten() 是一个特殊的层,它将多维的输入数据“展平”为一维数据。这在处理图像数据时尤为常见,因为图像通常是多维的(例如,一个大小为28x28的灰度图像在PyTorch中会有一个形状为[28, 28]的张量)。
在神经网络的某些层,特别是全连接层(如nn.Linear)之前,通常需要对数据进行扁平化处理。因为全连接层期望其输入是一维的(或者更准确地说,它期望输入的最后一个维度对应于特征,其他维度对应于数据的批次)。
为了更具体,让我们看一个例子:
考虑一个大小为[batch_size, 28, 28]的张量,这可以看作是一个batch_size数量的28x28图像的批次。当我们传递这个批次的图像到一个nn.Linear(28*28, 512)层时,我们需要先将图像展平。也就是说,每个28x28的图像需要转换为长度为784的一维向量。因此,输入数据的形状会从[batch_size, 28, 28]变为[batch_size, 784]。
nn.Flatten()就是做这个转换的。在这个特定的例子中,它会将[batch_size, 28, 28]的形状转换为[batch_size, 784]。
总结一下:nn.Flatten()用于将多维输入数据转换为一维,从而使其可以作为全连接层(如nn.Linear)的输入。
-
transforms.Compose:
这是一个简单的方式来链接(组合)多个图像转换操作。它会按照提供的顺序执行列表中的每个转换。 -
transforms.ToTensor():
这个转换将PIL图像或NumPy的ndarray转换为FloatTensor。并且它将图像的像素值范围从0-255变为0-1。简言之,它为我们完成了数据类型和值范围的转换。 -
transforms.Normalize((0.5,), (0.5,)):
这个转换标准化张量图像。给定的参数是均值和标准差。在这里,均值和标准差都是0.5。
使用给定的均值和标准差,这会将值范围从[0,1]转换为[-1,1]。
整个transform的目的是:
- 将图像数据从PIL格式转换为PyTorch张量格式。
- 将像素值从[0,255]范围转换为[0,1]范围。
- 使用给定的均值和标准差进一步标准化像素值,使其范围为[-1,1]。
初始化模型、损失函数和优化器
-
model = MLP():
- 这里我们实例化了我们之前定义的MLP类,从而创建了一个多层感知器(MLP)模型。
-
criterion = nn.CrossEntropyLoss():
- 在分类任务中,交叉熵损失函数 (CrossEntropyLoss) 是最常用的损失函数之一。它衡量真实标签和预测之间的差异。
- 注意:CrossEntropyLoss在内部执行softmax操作,因此模型输出应该是未经softmax处理的原始分数(logits)。
-
optimizer = optim.Adam(model.parameters(), lr=0.001):
- 优化器负责更新模型的权重,基于计算的梯度来减少损失。
- Adam是一种流行的优化器,它结合了两种扩展的随机梯度下降:Adaptive Gradients 和 Momentum。
- model.parameters()是传递给优化器的,它告诉优化器应该优化/更新哪些权重。
- lr=0.001定义了学习率,这是一个超参数,表示每次权重更新的步长大小。
常见的相关资料解答
- 模型 (在torch.nn中):
除了基本的MLP外,PyTorch提供了很多预定义的层和模型,常见的包括:
Convolutional Neural Networks (CNNs):nn.Conv2d: 2D卷积层,常用于图像处理。nn.Conv3d: 3D卷积层,常用于视频处理或医学图像。nn.MaxPool2d: 最大池化层。Recurrent Neural Networks (RNNs):nn.RNN: 基本的RNN层。nn.LSTM: 长短时记忆网络。nn.GRU: 门控循环单元。Transformer Architecture:nn.Transformer: 用于自然语言处理任务的Transformer模型。Batch Normalization, Dropout等:nn.BatchNorm2d: 批量归一化。nn.Dropout: 防止过拟合的正则化方法。
- 损失函数 (在torch.nn中):
常见的损失函数有:
Classification:nn.CrossEntropyLoss: 用于分类任务的交叉熵损失。nn.BCEWithLogitsLoss: 用于二分类任务的二元交叉熵损失,包括内部的sigmoid操作。nn.MultiLabelSoftMarginLoss: 用于多标签分类任务。Regression:nn.MSELoss: 均方误差,用于回归任务。nn.L1Loss: L1误差。Generative models:nn.KLDivLoss: Kullback-Leibler散度,常用于生成模型。
- 优化器 (在torch.optim中):
常见的优化器有:
optim.SGD: 随机梯度下降。
optim.Adam: 一个非常受欢迎的优化器,结合了AdaGrad和RMSProp的特点。
optim.RMSprop: 常用于深度学习任务。
optim.Adagrad: 自适应学习率优化器。
optim.Adadelta: 类似于Adagrad,但试图解决其快速降低学习率的问题。
optim.AdamW: Adam的变种,加入了权重衰减。

每文一语
学习是不断的发展的
相关文章:
如何使用pytorch定义一个多层感知神经网络模型——拓展到所有模型知识
# 导入必要的库 import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, random_split import torchvision.transforms as transforms import torchvision.datasets as datasets# 定义MLP模型 class MLP(nn.Module):def __…...
为什么引入SVG文件,给它定义属性不生效原理分析
背景: 我使用antd 的Icon组件引入SVG图片,但给svg图片定义styles样式时,不生效,为什么呢? 我们平时用antd组件库的 < ArrowRightOutlined style{{color: red }}>时为什么会生效呢,但我图一这样定义就…...
Integer包装类常用方法和属性
包装类 什么是包装类Integer包装类常用方法和属性 什么是包装类 Java 包装类是指为了方便处理基本数据类型而提供的对应的引用类型。Java 提供了八个基本数据类型(boolean、byte、short、int、long、float、double、char),每个基本数据类型对…...
基于Spring boot轻松实现一个多数据源框架
Spring Boot 提供了 Data JPA 的包,允许你使用类似 ORM 的接口连接到 RDMS。它很容易使用和实现,只需要在 pom.xml 中添加一个条目(如果使用的是 Maven,Gradle 则是在 build.gradle 文件中)。 <dependencies>&l…...
vue前端实现打印功能并约束纸张大小---调用浏览器打印功能打印页面部分元素并固定纸张大小
需求是打印指定div实现小票打印功能。调用浏览器的自带打印功能只能实现打印可视区域,所以这里采用截图新窗口打开打印去实现此需求。 1.安装html2canvas库实现截图功能 npm install html2canvas --save2.在需要进行截图和打印的组件中,引入html2canvas…...
音乐播放器蜂鸣器ROM存储歌曲verilog,代码/视频
名称:音乐播放器蜂鸣器ROM存储歌曲 软件:Quartus 语言:Verilog 代码功能: 设计音乐播放器,要求至少包含2首歌曲,使用按键切换歌曲,使用开发板的蜂鸣器播放音乐,使用Quartus内的RO…...
Arduino Nano 引脚复用分析
近期开发的项目为气体传感器采集仪,综合需求,选取NANO作为主控,附属设备有 oled、旋转编码器、H桥板、蠕动泵、开关、航插等,主要是用现有接口怎么合理配置实现功能。 不管stm32 还是 Arduino 都要看清引脚图 D2 D3 引脚是两个外…...
Go 函数多返回值错误处理与error 类型介绍
Go 函数多返回值错误处理与error 类型介绍 文章目录 Go 函数多返回值错误处理与error 类型介绍一、error 类型与错误值构造1.1 Error 接口介绍1.2 构造错误值的方法1.2.1 使用errors包1.2.2 自定义错误类型 二、error 类型的好处2.1 第一点:统一了错误类型2.2 第二点…...
数论分块
本质就是利用取整分数值的块状分布。 UVA11526 H(n) 题意: 求 ∑ i 1 n n i \sum_{i1}^{n} \frac {n}{i} ∑i1nin。 解析: ⌊ n i ⌋ \lfloor \frac{n}{i} \rfloor ⌊in⌋ 只有 O ( n ) O(\sqrt n) O(n ) 种取值,考虑将相同值同…...
宏任务与微任务,代码执行顺序
js引擎工作进程是同步的。事件循环机制,事件队列。 脚本代码执行顺序,是先执行同步代码,遇到微任务,就把它推进任务队列中。每个宏任务完成后,再执行下一个宏任务。 宏任务有哪些: i/o读写 定时器setTi…...
正方形(Squares, ACM/ICPC World Finals 1990, UVa201)rust解法
有n行n列(2≤n≤9)的小黑点,还有m条线段连接其中的一些黑点。统计这些线段连成了多少个正方形(每种边长分别统计)。 行从上到下编号为1~n,列从左到右编号为1~n。边用H i j和V i j表示…...
【算法设计与分析qwl】伪码——顺序检索,插入排序
伪代码: 例子: 改进的顺序检索 Search(L,x)输入:数组L[1...n],元素从小到大排序,数x输出:若x在L中,输出x位置下标 j ,否则输出0 j <- 1 while j<n and x>L[j] do j <- j1 if x<…...
Uniapp路由拦截-自定义路由白名单
步骤一:新建routerIntercept.js文件 步骤二:routerIntercept文件中写入:(根据自己需要修改whiteList白名单中的页面路径和自己的逻辑处理) import Vue from vue // 白名单 const whiteList = [/pages/public/login,/pages/public/privacyAgreement, ]export default asy…...
在中国可以使用 HubSpot 吗?
当谈到市场营销和客户关系管理工具时,HubSpot通常是一家企业的首选。然而,对于许多中国的企业来说,一个重要的问题是:在中国可以使用HubSpot吗?这个问题涉及到不同的方面,包括政策法规、社交媒体平台、语言…...
Java的基础应用
Java是一种广泛应用于软件开发的编程语言,基础应用涵盖了很多方面。以下是Java的一些基础应用方面的介绍: 1. 控制流语句:Java中的程序流程控制语句分为选择语句和循环语句。选择语句包括if-else语句和switch语句,循环语句包括fo…...
【excel】列转行
列转行 工作中有一些数据是列表,现在需要转行 选表格内容:在excel表格中选中表格数据区域。点击复制:在选中表格区域处右击点击复制。点击选择性粘贴:在表格中鼠标右击点击选择性粘贴。勾选转置:在选择性粘勾选转置选…...
用Bing绘制「V我50」漫画;GPT-5业内交流笔记;LLM大佬的跳槽建议;Stable Diffusion生态全盘点第一课 | ShowMeAI日报
👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 🔥 美国升级AI芯片出口禁令,13家中国GPU企业被列入实体清单 nytimes.com/2023/10/05/technology/chip-makers-china-lobbying…...
Java身份证实名认证-阿里云API 【姓名、身份证号】
1. 阿里云API市场 https://market.aliyun.com/products/57126001/cmapi00053442.html?spm5176.2020520132.101.3.a6217218nxxEiy#skuyuncode47442000022 购买对应套餐 2. 复制AppCode https://market.console.aliyun.com/imageconsole/index.htm#/?_kl85e10 云市场-已购买服…...
ND协议——无状态地址自动配置 (SLAAC)
参考学习:计算机网络 | 思科网络 | 无状态地址自动配置 (SLAAC) | 什么是SLAAC_瘦弱的皮卡丘的博客-CSDN博客 与 IPv4 类似,可以手动或动态配置 IPv6 全局单播地址。但是,动态分配 IPv6 全局单播地址有两种方法: 如图所示&#…...
iOS开发UITableView的使用,区别Plain模式和Grouped模式
简单赘述一下 的创建步骤 // 创建UITableView self.tableView [[UITableView alloc] initWithFrame:self.view.bounds style:UITableViewStylePlain]; // 设置数据源和代理 self.tableView.dataSource self; self.tableView.delegate self; // 注册自定义UITableViewCe…...
3.3.1_1 检错编码(奇偶校验码)
从这节课开始,我们会探讨数据链路层的差错控制功能,差错控制功能的主要目标是要发现并且解决一个帧内部的位错误,我们需要使用特殊的编码技术去发现帧内部的位错误,当我们发现位错误之后,通常来说有两种解决方案。第一…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...
MMaDA: Multimodal Large Diffusion Language Models
CODE : https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA,它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构…...
Frozen-Flask :将 Flask 应用“冻结”为静态文件
Frozen-Flask 是一个用于将 Flask 应用“冻结”为静态文件的 Python 扩展。它的核心用途是:将一个 Flask Web 应用生成成纯静态 HTML 文件,从而可以部署到静态网站托管服务上,如 GitHub Pages、Netlify 或任何支持静态文件的网站服务器。 &am…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...
React---day11
14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store: 我们在使用异步的时候理应是要使用中间件的,但是configureStore 已经自动集成了 redux-thunk,注意action里面要返回函数 import { configureS…...
HashMap中的put方法执行流程(流程图)
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...
sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!
简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求,并检查收到的响应。它以以下模式之一…...
基于IDIG-GAN的小样本电机轴承故障诊断
目录 🔍 核心问题 一、IDIG-GAN模型原理 1. 整体架构 2. 核心创新点 (1) 梯度归一化(Gradient Normalization) (2) 判别器梯度间隙正则化(Discriminator Gradient Gap Regularization) (3) 自注意力机制(Self-Attention) 3. 完整损失函数 二…...
站群服务器的应用场景都有哪些?
站群服务器主要是为了多个网站的托管和管理所设计的,可以通过集中管理和高效资源的分配,来支持多个独立的网站同时运行,让每一个网站都可以分配到独立的IP地址,避免出现IP关联的风险,用户还可以通过控制面板进行管理功…...
