如何使用pytorch定义一个多层感知神经网络模型——拓展到所有模型知识
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets# 定义MLP模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()# 创建一个顺序的层序列:包括一个扁平化层、两个全连接层和ReLU激活self.layers = nn.Sequential(nn.Flatten(), # 将28x28的图像扁平化为784维向量nn.Linear(28 * 28, 512), # 第一个全连接层,784->512nn.ReLU(), # ReLU激活函数nn.Linear(512, 256), # 第二个全连接层,512->256nn.ReLU(), # ReLU激活函数nn.Linear(256, 10) # 第三个全连接层,256->10 (输出10个类别))def forward(self, x):return self.layers(x) # 定义前向传播# 加载FashionMNIST数据集
# 定义图像的预处理:转换为Tensor并标准化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 下载FashionMNIST数据并应用转换
dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)# 划分数据集为训练集和验证集
train_len = int(0.8 * len(dataset)) # 计算80%的长度作为训练数据
val_len = len(dataset) - train_len # 剩下的20%作为验证数据
train_dataset, val_dataset = random_split(dataset, [train_len, val_len])# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 训练数据加载器,批量大小64,打乱数据
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False) # 验证数据加载器,批量大小64,不打乱# 初始化模型、损失函数和优化器
model = MLP() # 创建MLP模型实例
criterion = nn.CrossEntropyLoss() # 定义交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用Adam优化器# 训练模型
epochs = 5 # 定义训练5个epochs
for epoch in range(epochs):model.train() # 将模型设置为训练模式for inputs, labels in train_loader: # 从训练加载器中获取批次数据outputs = model(inputs) # 前向传播loss = criterion(outputs, labels) # 计算损失optimizer.zero_grad() # 清除之前的梯度loss.backward() # 反向传播,计算梯度optimizer.step() # 更新权重# 在每个epoch结束时验证模型性能model.eval() # 将模型设置为评估模式total_correct = 0with torch.no_grad(): # 不计算梯度,节省内存和计算量for inputs, labels in val_loader: # 从验证加载器中获取批次数据outputs = model(inputs) # 前向传播_, predicted = outputs.max(1) # 获取预测的类别total_correct += (predicted == labels).sum().item() # 统计正确的预测数量accuracy = total_correct / val_len # 计算验证准确性print(f"Epoch {epoch + 1}/{epochs} - Validation accuracy: {accuracy:.4f}") # 打印验证准确性
nn.Flatten() 是一个特殊的层,它将多维的输入数据“展平”为一维数据。这在处理图像数据时尤为常见,因为图像通常是多维的(例如,一个大小为28x28的灰度图像在PyTorch中会有一个形状为[28, 28]的张量)。
在神经网络的某些层,特别是全连接层(如nn.Linear)之前,通常需要对数据进行扁平化处理。因为全连接层期望其输入是一维的(或者更准确地说,它期望输入的最后一个维度对应于特征,其他维度对应于数据的批次)。
为了更具体,让我们看一个例子:
考虑一个大小为[batch_size, 28, 28]的张量,这可以看作是一个batch_size数量的28x28图像的批次。当我们传递这个批次的图像到一个nn.Linear(28*28, 512)层时,我们需要先将图像展平。也就是说,每个28x28的图像需要转换为长度为784的一维向量。因此,输入数据的形状会从[batch_size, 28, 28]变为[batch_size, 784]。
nn.Flatten()就是做这个转换的。在这个特定的例子中,它会将[batch_size, 28, 28]的形状转换为[batch_size, 784]。
总结一下:nn.Flatten()用于将多维输入数据转换为一维,从而使其可以作为全连接层(如nn.Linear)的输入。
-
transforms.Compose:
这是一个简单的方式来链接(组合)多个图像转换操作。它会按照提供的顺序执行列表中的每个转换。 -
transforms.ToTensor():
这个转换将PIL图像或NumPy的ndarray转换为FloatTensor。并且它将图像的像素值范围从0-255变为0-1。简言之,它为我们完成了数据类型和值范围的转换。 -
transforms.Normalize((0.5,), (0.5,)):
这个转换标准化张量图像。给定的参数是均值和标准差。在这里,均值和标准差都是0.5。
使用给定的均值和标准差,这会将值范围从[0,1]转换为[-1,1]。
整个transform的目的是:
- 将图像数据从PIL格式转换为PyTorch张量格式。
- 将像素值从[0,255]范围转换为[0,1]范围。
- 使用给定的均值和标准差进一步标准化像素值,使其范围为[-1,1]。
初始化模型、损失函数和优化器
-
model = MLP():
- 这里我们实例化了我们之前定义的MLP类,从而创建了一个多层感知器(MLP)模型。
-
criterion = nn.CrossEntropyLoss():
- 在分类任务中,交叉熵损失函数 (CrossEntropyLoss) 是最常用的损失函数之一。它衡量真实标签和预测之间的差异。
- 注意:CrossEntropyLoss在内部执行softmax操作,因此模型输出应该是未经softmax处理的原始分数(logits)。
-
optimizer = optim.Adam(model.parameters(), lr=0.001):
- 优化器负责更新模型的权重,基于计算的梯度来减少损失。
- Adam是一种流行的优化器,它结合了两种扩展的随机梯度下降:Adaptive Gradients 和 Momentum。
- model.parameters()是传递给优化器的,它告诉优化器应该优化/更新哪些权重。
- lr=0.001定义了学习率,这是一个超参数,表示每次权重更新的步长大小。
常见的相关资料解答
- 模型 (在torch.nn中):
除了基本的MLP外,PyTorch提供了很多预定义的层和模型,常见的包括:
Convolutional Neural Networks (CNNs):nn.Conv2d: 2D卷积层,常用于图像处理。nn.Conv3d: 3D卷积层,常用于视频处理或医学图像。nn.MaxPool2d: 最大池化层。Recurrent Neural Networks (RNNs):nn.RNN: 基本的RNN层。nn.LSTM: 长短时记忆网络。nn.GRU: 门控循环单元。Transformer Architecture:nn.Transformer: 用于自然语言处理任务的Transformer模型。Batch Normalization, Dropout等:nn.BatchNorm2d: 批量归一化。nn.Dropout: 防止过拟合的正则化方法。
- 损失函数 (在torch.nn中):
常见的损失函数有:
Classification:nn.CrossEntropyLoss: 用于分类任务的交叉熵损失。nn.BCEWithLogitsLoss: 用于二分类任务的二元交叉熵损失,包括内部的sigmoid操作。nn.MultiLabelSoftMarginLoss: 用于多标签分类任务。Regression:nn.MSELoss: 均方误差,用于回归任务。nn.L1Loss: L1误差。Generative models:nn.KLDivLoss: Kullback-Leibler散度,常用于生成模型。
- 优化器 (在torch.optim中):
常见的优化器有:
optim.SGD: 随机梯度下降。
optim.Adam: 一个非常受欢迎的优化器,结合了AdaGrad和RMSProp的特点。
optim.RMSprop: 常用于深度学习任务。
optim.Adagrad: 自适应学习率优化器。
optim.Adadelta: 类似于Adagrad,但试图解决其快速降低学习率的问题。
optim.AdamW: Adam的变种,加入了权重衰减。

每文一语
学习是不断的发展的
相关文章:
如何使用pytorch定义一个多层感知神经网络模型——拓展到所有模型知识
# 导入必要的库 import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, random_split import torchvision.transforms as transforms import torchvision.datasets as datasets# 定义MLP模型 class MLP(nn.Module):def __…...
为什么引入SVG文件,给它定义属性不生效原理分析
背景: 我使用antd 的Icon组件引入SVG图片,但给svg图片定义styles样式时,不生效,为什么呢? 我们平时用antd组件库的 < ArrowRightOutlined style{{color: red }}>时为什么会生效呢,但我图一这样定义就…...
Integer包装类常用方法和属性
包装类 什么是包装类Integer包装类常用方法和属性 什么是包装类 Java 包装类是指为了方便处理基本数据类型而提供的对应的引用类型。Java 提供了八个基本数据类型(boolean、byte、short、int、long、float、double、char),每个基本数据类型对…...
基于Spring boot轻松实现一个多数据源框架
Spring Boot 提供了 Data JPA 的包,允许你使用类似 ORM 的接口连接到 RDMS。它很容易使用和实现,只需要在 pom.xml 中添加一个条目(如果使用的是 Maven,Gradle 则是在 build.gradle 文件中)。 <dependencies>&l…...
vue前端实现打印功能并约束纸张大小---调用浏览器打印功能打印页面部分元素并固定纸张大小
需求是打印指定div实现小票打印功能。调用浏览器的自带打印功能只能实现打印可视区域,所以这里采用截图新窗口打开打印去实现此需求。 1.安装html2canvas库实现截图功能 npm install html2canvas --save2.在需要进行截图和打印的组件中,引入html2canvas…...
音乐播放器蜂鸣器ROM存储歌曲verilog,代码/视频
名称:音乐播放器蜂鸣器ROM存储歌曲 软件:Quartus 语言:Verilog 代码功能: 设计音乐播放器,要求至少包含2首歌曲,使用按键切换歌曲,使用开发板的蜂鸣器播放音乐,使用Quartus内的RO…...
Arduino Nano 引脚复用分析
近期开发的项目为气体传感器采集仪,综合需求,选取NANO作为主控,附属设备有 oled、旋转编码器、H桥板、蠕动泵、开关、航插等,主要是用现有接口怎么合理配置实现功能。 不管stm32 还是 Arduino 都要看清引脚图 D2 D3 引脚是两个外…...
Go 函数多返回值错误处理与error 类型介绍
Go 函数多返回值错误处理与error 类型介绍 文章目录 Go 函数多返回值错误处理与error 类型介绍一、error 类型与错误值构造1.1 Error 接口介绍1.2 构造错误值的方法1.2.1 使用errors包1.2.2 自定义错误类型 二、error 类型的好处2.1 第一点:统一了错误类型2.2 第二点…...
数论分块
本质就是利用取整分数值的块状分布。 UVA11526 H(n) 题意: 求 ∑ i 1 n n i \sum_{i1}^{n} \frac {n}{i} ∑i1nin。 解析: ⌊ n i ⌋ \lfloor \frac{n}{i} \rfloor ⌊in⌋ 只有 O ( n ) O(\sqrt n) O(n ) 种取值,考虑将相同值同…...
宏任务与微任务,代码执行顺序
js引擎工作进程是同步的。事件循环机制,事件队列。 脚本代码执行顺序,是先执行同步代码,遇到微任务,就把它推进任务队列中。每个宏任务完成后,再执行下一个宏任务。 宏任务有哪些: i/o读写 定时器setTi…...
正方形(Squares, ACM/ICPC World Finals 1990, UVa201)rust解法
有n行n列(2≤n≤9)的小黑点,还有m条线段连接其中的一些黑点。统计这些线段连成了多少个正方形(每种边长分别统计)。 行从上到下编号为1~n,列从左到右编号为1~n。边用H i j和V i j表示…...
【算法设计与分析qwl】伪码——顺序检索,插入排序
伪代码: 例子: 改进的顺序检索 Search(L,x)输入:数组L[1...n],元素从小到大排序,数x输出:若x在L中,输出x位置下标 j ,否则输出0 j <- 1 while j<n and x>L[j] do j <- j1 if x<…...
Uniapp路由拦截-自定义路由白名单
步骤一:新建routerIntercept.js文件 步骤二:routerIntercept文件中写入:(根据自己需要修改whiteList白名单中的页面路径和自己的逻辑处理) import Vue from vue // 白名单 const whiteList = [/pages/public/login,/pages/public/privacyAgreement, ]export default asy…...
在中国可以使用 HubSpot 吗?
当谈到市场营销和客户关系管理工具时,HubSpot通常是一家企业的首选。然而,对于许多中国的企业来说,一个重要的问题是:在中国可以使用HubSpot吗?这个问题涉及到不同的方面,包括政策法规、社交媒体平台、语言…...
Java的基础应用
Java是一种广泛应用于软件开发的编程语言,基础应用涵盖了很多方面。以下是Java的一些基础应用方面的介绍: 1. 控制流语句:Java中的程序流程控制语句分为选择语句和循环语句。选择语句包括if-else语句和switch语句,循环语句包括fo…...
【excel】列转行
列转行 工作中有一些数据是列表,现在需要转行 选表格内容:在excel表格中选中表格数据区域。点击复制:在选中表格区域处右击点击复制。点击选择性粘贴:在表格中鼠标右击点击选择性粘贴。勾选转置:在选择性粘勾选转置选…...
用Bing绘制「V我50」漫画;GPT-5业内交流笔记;LLM大佬的跳槽建议;Stable Diffusion生态全盘点第一课 | ShowMeAI日报
👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 🔥 美国升级AI芯片出口禁令,13家中国GPU企业被列入实体清单 nytimes.com/2023/10/05/technology/chip-makers-china-lobbying…...
Java身份证实名认证-阿里云API 【姓名、身份证号】
1. 阿里云API市场 https://market.aliyun.com/products/57126001/cmapi00053442.html?spm5176.2020520132.101.3.a6217218nxxEiy#skuyuncode47442000022 购买对应套餐 2. 复制AppCode https://market.console.aliyun.com/imageconsole/index.htm#/?_kl85e10 云市场-已购买服…...
ND协议——无状态地址自动配置 (SLAAC)
参考学习:计算机网络 | 思科网络 | 无状态地址自动配置 (SLAAC) | 什么是SLAAC_瘦弱的皮卡丘的博客-CSDN博客 与 IPv4 类似,可以手动或动态配置 IPv6 全局单播地址。但是,动态分配 IPv6 全局单播地址有两种方法: 如图所示&#…...
iOS开发UITableView的使用,区别Plain模式和Grouped模式
简单赘述一下 的创建步骤 // 创建UITableView self.tableView [[UITableView alloc] initWithFrame:self.view.bounds style:UITableViewStylePlain]; // 设置数据源和代理 self.tableView.dataSource self; self.tableView.delegate self; // 注册自定义UITableViewCe…...
【Oracle APEX开发小技巧12】
有如下需求: 有一个问题反馈页面,要实现在apex页面展示能直观看到反馈时间超过7天未处理的数据,方便管理员及时处理反馈。 我的方法:直接将逻辑写在SQL中,这样可以直接在页面展示 完整代码: SELECTSF.FE…...
23-Oracle 23 ai 区块链表(Blockchain Table)
小伙伴有没有在金融强合规的领域中遇见,必须要保持数据不可变,管理员都无法修改和留痕的要求。比如医疗的电子病历中,影像检查检验结果不可篡改行的,药品追溯过程中数据只可插入无法删除的特性需求;登录日志、修改日志…...
vscode(仍待补充)
写于2025 6.9 主包将加入vscode这个更权威的圈子 vscode的基本使用 侧边栏 vscode还能连接ssh? debug时使用的launch文件 1.task.json {"tasks": [{"type": "cppbuild","label": "C/C: gcc.exe 生成活动文件"…...
解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八
现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet,点击确认后如下提示 最终上报fail 解决方法 内核升级导致,需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...
使用分级同态加密防御梯度泄漏
抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...
第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明
AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...
爬虫基础学习day2
# 爬虫设计领域 工商:企查查、天眼查短视频:抖音、快手、西瓜 ---> 飞瓜电商:京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空:抓取所有航空公司价格 ---> 去哪儿自媒体:采集自媒体数据进…...
GC1808高性能24位立体声音频ADC芯片解析
1. 芯片概述 GC1808是一款24位立体声音频模数转换器(ADC),支持8kHz~96kHz采样率,集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器,适用于高保真音频采集场景。 2. 核心特性 高精度:24位分辨率,…...
html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码
目录 一、👨🎓网站题目 二、✍️网站描述 三、📚网站介绍 四、🌐网站效果 五、🪓 代码实现 🧱HTML 六、🥇 如何让学习不再盲目 七、🎁更多干货 一、👨…...
动态 Web 开发技术入门篇
一、HTTP 协议核心 1.1 HTTP 基础 协议全称 :HyperText Transfer Protocol(超文本传输协议) 默认端口 :HTTP 使用 80 端口,HTTPS 使用 443 端口。 请求方法 : GET :用于获取资源,…...
