如何使用pytorch定义一个多层感知神经网络模型——拓展到所有模型知识
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets# 定义MLP模型
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()# 创建一个顺序的层序列:包括一个扁平化层、两个全连接层和ReLU激活self.layers = nn.Sequential(nn.Flatten(), # 将28x28的图像扁平化为784维向量nn.Linear(28 * 28, 512), # 第一个全连接层,784->512nn.ReLU(), # ReLU激活函数nn.Linear(512, 256), # 第二个全连接层,512->256nn.ReLU(), # ReLU激活函数nn.Linear(256, 10) # 第三个全连接层,256->10 (输出10个类别))def forward(self, x):return self.layers(x) # 定义前向传播# 加载FashionMNIST数据集
# 定义图像的预处理:转换为Tensor并标准化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 下载FashionMNIST数据并应用转换
dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)# 划分数据集为训练集和验证集
train_len = int(0.8 * len(dataset)) # 计算80%的长度作为训练数据
val_len = len(dataset) - train_len # 剩下的20%作为验证数据
train_dataset, val_dataset = random_split(dataset, [train_len, val_len])# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # 训练数据加载器,批量大小64,打乱数据
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False) # 验证数据加载器,批量大小64,不打乱# 初始化模型、损失函数和优化器
model = MLP() # 创建MLP模型实例
criterion = nn.CrossEntropyLoss() # 定义交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # 使用Adam优化器# 训练模型
epochs = 5 # 定义训练5个epochs
for epoch in range(epochs):model.train() # 将模型设置为训练模式for inputs, labels in train_loader: # 从训练加载器中获取批次数据outputs = model(inputs) # 前向传播loss = criterion(outputs, labels) # 计算损失optimizer.zero_grad() # 清除之前的梯度loss.backward() # 反向传播,计算梯度optimizer.step() # 更新权重# 在每个epoch结束时验证模型性能model.eval() # 将模型设置为评估模式total_correct = 0with torch.no_grad(): # 不计算梯度,节省内存和计算量for inputs, labels in val_loader: # 从验证加载器中获取批次数据outputs = model(inputs) # 前向传播_, predicted = outputs.max(1) # 获取预测的类别total_correct += (predicted == labels).sum().item() # 统计正确的预测数量accuracy = total_correct / val_len # 计算验证准确性print(f"Epoch {epoch + 1}/{epochs} - Validation accuracy: {accuracy:.4f}") # 打印验证准确性
nn.Flatten() 是一个特殊的层,它将多维的输入数据“展平”为一维数据。这在处理图像数据时尤为常见,因为图像通常是多维的(例如,一个大小为28x28的灰度图像在PyTorch中会有一个形状为[28, 28]的张量)。
在神经网络的某些层,特别是全连接层(如nn.Linear)之前,通常需要对数据进行扁平化处理。因为全连接层期望其输入是一维的(或者更准确地说,它期望输入的最后一个维度对应于特征,其他维度对应于数据的批次)。
为了更具体,让我们看一个例子:
考虑一个大小为[batch_size, 28, 28]的张量,这可以看作是一个batch_size数量的28x28图像的批次。当我们传递这个批次的图像到一个nn.Linear(28*28, 512)层时,我们需要先将图像展平。也就是说,每个28x28的图像需要转换为长度为784的一维向量。因此,输入数据的形状会从[batch_size, 28, 28]变为[batch_size, 784]。
nn.Flatten()就是做这个转换的。在这个特定的例子中,它会将[batch_size, 28, 28]的形状转换为[batch_size, 784]。
总结一下:nn.Flatten()用于将多维输入数据转换为一维,从而使其可以作为全连接层(如nn.Linear)的输入。
-
transforms.Compose:
这是一个简单的方式来链接(组合)多个图像转换操作。它会按照提供的顺序执行列表中的每个转换。 -
transforms.ToTensor():
这个转换将PIL图像或NumPy的ndarray转换为FloatTensor。并且它将图像的像素值范围从0-255变为0-1。简言之,它为我们完成了数据类型和值范围的转换。 -
transforms.Normalize((0.5,), (0.5,)):
这个转换标准化张量图像。给定的参数是均值和标准差。在这里,均值和标准差都是0.5。
使用给定的均值和标准差,这会将值范围从[0,1]转换为[-1,1]。
整个transform的目的是:
- 将图像数据从PIL格式转换为PyTorch张量格式。
- 将像素值从[0,255]范围转换为[0,1]范围。
- 使用给定的均值和标准差进一步标准化像素值,使其范围为[-1,1]。
初始化模型、损失函数和优化器
-
model = MLP():
- 这里我们实例化了我们之前定义的MLP类,从而创建了一个多层感知器(MLP)模型。
-
criterion = nn.CrossEntropyLoss():
- 在分类任务中,交叉熵损失函数 (CrossEntropyLoss) 是最常用的损失函数之一。它衡量真实标签和预测之间的差异。
- 注意:CrossEntropyLoss在内部执行softmax操作,因此模型输出应该是未经softmax处理的原始分数(logits)。
-
optimizer = optim.Adam(model.parameters(), lr=0.001):
- 优化器负责更新模型的权重,基于计算的梯度来减少损失。
- Adam是一种流行的优化器,它结合了两种扩展的随机梯度下降:Adaptive Gradients 和 Momentum。
- model.parameters()是传递给优化器的,它告诉优化器应该优化/更新哪些权重。
- lr=0.001定义了学习率,这是一个超参数,表示每次权重更新的步长大小。
常见的相关资料解答
- 模型 (在torch.nn中):
除了基本的MLP外,PyTorch提供了很多预定义的层和模型,常见的包括:
Convolutional Neural Networks (CNNs):nn.Conv2d: 2D卷积层,常用于图像处理。nn.Conv3d: 3D卷积层,常用于视频处理或医学图像。nn.MaxPool2d: 最大池化层。Recurrent Neural Networks (RNNs):nn.RNN: 基本的RNN层。nn.LSTM: 长短时记忆网络。nn.GRU: 门控循环单元。Transformer Architecture:nn.Transformer: 用于自然语言处理任务的Transformer模型。Batch Normalization, Dropout等:nn.BatchNorm2d: 批量归一化。nn.Dropout: 防止过拟合的正则化方法。
- 损失函数 (在torch.nn中):
常见的损失函数有:
Classification:nn.CrossEntropyLoss: 用于分类任务的交叉熵损失。nn.BCEWithLogitsLoss: 用于二分类任务的二元交叉熵损失,包括内部的sigmoid操作。nn.MultiLabelSoftMarginLoss: 用于多标签分类任务。Regression:nn.MSELoss: 均方误差,用于回归任务。nn.L1Loss: L1误差。Generative models:nn.KLDivLoss: Kullback-Leibler散度,常用于生成模型。
- 优化器 (在torch.optim中):
常见的优化器有:
optim.SGD: 随机梯度下降。
optim.Adam: 一个非常受欢迎的优化器,结合了AdaGrad和RMSProp的特点。
optim.RMSprop: 常用于深度学习任务。
optim.Adagrad: 自适应学习率优化器。
optim.Adadelta: 类似于Adagrad,但试图解决其快速降低学习率的问题。
optim.AdamW: Adam的变种,加入了权重衰减。

每文一语
学习是不断的发展的
相关文章:
如何使用pytorch定义一个多层感知神经网络模型——拓展到所有模型知识
# 导入必要的库 import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, random_split import torchvision.transforms as transforms import torchvision.datasets as datasets# 定义MLP模型 class MLP(nn.Module):def __…...
为什么引入SVG文件,给它定义属性不生效原理分析
背景: 我使用antd 的Icon组件引入SVG图片,但给svg图片定义styles样式时,不生效,为什么呢? 我们平时用antd组件库的 < ArrowRightOutlined style{{color: red }}>时为什么会生效呢,但我图一这样定义就…...
Integer包装类常用方法和属性
包装类 什么是包装类Integer包装类常用方法和属性 什么是包装类 Java 包装类是指为了方便处理基本数据类型而提供的对应的引用类型。Java 提供了八个基本数据类型(boolean、byte、short、int、long、float、double、char),每个基本数据类型对…...
基于Spring boot轻松实现一个多数据源框架
Spring Boot 提供了 Data JPA 的包,允许你使用类似 ORM 的接口连接到 RDMS。它很容易使用和实现,只需要在 pom.xml 中添加一个条目(如果使用的是 Maven,Gradle 则是在 build.gradle 文件中)。 <dependencies>&l…...
vue前端实现打印功能并约束纸张大小---调用浏览器打印功能打印页面部分元素并固定纸张大小
需求是打印指定div实现小票打印功能。调用浏览器的自带打印功能只能实现打印可视区域,所以这里采用截图新窗口打开打印去实现此需求。 1.安装html2canvas库实现截图功能 npm install html2canvas --save2.在需要进行截图和打印的组件中,引入html2canvas…...
音乐播放器蜂鸣器ROM存储歌曲verilog,代码/视频
名称:音乐播放器蜂鸣器ROM存储歌曲 软件:Quartus 语言:Verilog 代码功能: 设计音乐播放器,要求至少包含2首歌曲,使用按键切换歌曲,使用开发板的蜂鸣器播放音乐,使用Quartus内的RO…...
Arduino Nano 引脚复用分析
近期开发的项目为气体传感器采集仪,综合需求,选取NANO作为主控,附属设备有 oled、旋转编码器、H桥板、蠕动泵、开关、航插等,主要是用现有接口怎么合理配置实现功能。 不管stm32 还是 Arduino 都要看清引脚图 D2 D3 引脚是两个外…...
Go 函数多返回值错误处理与error 类型介绍
Go 函数多返回值错误处理与error 类型介绍 文章目录 Go 函数多返回值错误处理与error 类型介绍一、error 类型与错误值构造1.1 Error 接口介绍1.2 构造错误值的方法1.2.1 使用errors包1.2.2 自定义错误类型 二、error 类型的好处2.1 第一点:统一了错误类型2.2 第二点…...
数论分块
本质就是利用取整分数值的块状分布。 UVA11526 H(n) 题意: 求 ∑ i 1 n n i \sum_{i1}^{n} \frac {n}{i} ∑i1nin。 解析: ⌊ n i ⌋ \lfloor \frac{n}{i} \rfloor ⌊in⌋ 只有 O ( n ) O(\sqrt n) O(n ) 种取值,考虑将相同值同…...
宏任务与微任务,代码执行顺序
js引擎工作进程是同步的。事件循环机制,事件队列。 脚本代码执行顺序,是先执行同步代码,遇到微任务,就把它推进任务队列中。每个宏任务完成后,再执行下一个宏任务。 宏任务有哪些: i/o读写 定时器setTi…...
正方形(Squares, ACM/ICPC World Finals 1990, UVa201)rust解法
有n行n列(2≤n≤9)的小黑点,还有m条线段连接其中的一些黑点。统计这些线段连成了多少个正方形(每种边长分别统计)。 行从上到下编号为1~n,列从左到右编号为1~n。边用H i j和V i j表示…...
【算法设计与分析qwl】伪码——顺序检索,插入排序
伪代码: 例子: 改进的顺序检索 Search(L,x)输入:数组L[1...n],元素从小到大排序,数x输出:若x在L中,输出x位置下标 j ,否则输出0 j <- 1 while j<n and x>L[j] do j <- j1 if x<…...
Uniapp路由拦截-自定义路由白名单
步骤一:新建routerIntercept.js文件 步骤二:routerIntercept文件中写入:(根据自己需要修改whiteList白名单中的页面路径和自己的逻辑处理) import Vue from vue // 白名单 const whiteList = [/pages/public/login,/pages/public/privacyAgreement, ]export default asy…...
在中国可以使用 HubSpot 吗?
当谈到市场营销和客户关系管理工具时,HubSpot通常是一家企业的首选。然而,对于许多中国的企业来说,一个重要的问题是:在中国可以使用HubSpot吗?这个问题涉及到不同的方面,包括政策法规、社交媒体平台、语言…...
Java的基础应用
Java是一种广泛应用于软件开发的编程语言,基础应用涵盖了很多方面。以下是Java的一些基础应用方面的介绍: 1. 控制流语句:Java中的程序流程控制语句分为选择语句和循环语句。选择语句包括if-else语句和switch语句,循环语句包括fo…...
【excel】列转行
列转行 工作中有一些数据是列表,现在需要转行 选表格内容:在excel表格中选中表格数据区域。点击复制:在选中表格区域处右击点击复制。点击选择性粘贴:在表格中鼠标右击点击选择性粘贴。勾选转置:在选择性粘勾选转置选…...
用Bing绘制「V我50」漫画;GPT-5业内交流笔记;LLM大佬的跳槽建议;Stable Diffusion生态全盘点第一课 | ShowMeAI日报
👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 🔥 美国升级AI芯片出口禁令,13家中国GPU企业被列入实体清单 nytimes.com/2023/10/05/technology/chip-makers-china-lobbying…...
Java身份证实名认证-阿里云API 【姓名、身份证号】
1. 阿里云API市场 https://market.aliyun.com/products/57126001/cmapi00053442.html?spm5176.2020520132.101.3.a6217218nxxEiy#skuyuncode47442000022 购买对应套餐 2. 复制AppCode https://market.console.aliyun.com/imageconsole/index.htm#/?_kl85e10 云市场-已购买服…...
ND协议——无状态地址自动配置 (SLAAC)
参考学习:计算机网络 | 思科网络 | 无状态地址自动配置 (SLAAC) | 什么是SLAAC_瘦弱的皮卡丘的博客-CSDN博客 与 IPv4 类似,可以手动或动态配置 IPv6 全局单播地址。但是,动态分配 IPv6 全局单播地址有两种方法: 如图所示&#…...
iOS开发UITableView的使用,区别Plain模式和Grouped模式
简单赘述一下 的创建步骤 // 创建UITableView self.tableView [[UITableView alloc] initWithFrame:self.view.bounds style:UITableViewStylePlain]; // 设置数据源和代理 self.tableView.dataSource self; self.tableView.delegate self; // 注册自定义UITableViewCe…...
铭豹扩展坞 USB转网口 突然无法识别解决方法
当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...
Debian系统简介
目录 Debian系统介绍 Debian版本介绍 Debian软件源介绍 软件包管理工具dpkg dpkg核心指令详解 安装软件包 卸载软件包 查询软件包状态 验证软件包完整性 手动处理依赖关系 dpkg vs apt Debian系统介绍 Debian 和 Ubuntu 都是基于 Debian内核 的 Linux 发行版ÿ…...
【HarmonyOS 5.0】DevEco Testing:鸿蒙应用质量保障的终极武器
——全方位测试解决方案与代码实战 一、工具定位与核心能力 DevEco Testing是HarmonyOS官方推出的一体化测试平台,覆盖应用全生命周期测试需求,主要提供五大核心能力: 测试类型检测目标关键指标功能体验基…...
基于Flask实现的医疗保险欺诈识别监测模型
基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...
AtCoder 第409场初级竞赛 A~E题解
A Conflict 【题目链接】 原题链接:A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串,只有在同时为 o 时输出 Yes 并结束程序,否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...
VM虚拟机网络配置(ubuntu24桥接模式):配置静态IP
编辑-虚拟网络编辑器-更改设置 选择桥接模式,然后找到相应的网卡(可以查看自己本机的网络连接) windows连接的网络点击查看属性 编辑虚拟机设置更改网络配置,选择刚才配置的桥接模式 静态ip设置: 我用的ubuntu24桌…...
【JVM面试篇】高频八股汇总——类加载和类加载器
目录 1. 讲一下类加载过程? 2. Java创建对象的过程? 3. 对象的生命周期? 4. 类加载器有哪些? 5. 双亲委派模型的作用(好处)? 6. 讲一下类的加载和双亲委派原则? 7. 双亲委派模…...
沙箱虚拟化技术虚拟机容器之间的关系详解
问题 沙箱、虚拟化、容器三者分开一一介绍的话我知道他们各自都是什么东西,但是如果把三者放在一起,它们之间到底什么关系?又有什么联系呢?我不是很明白!!! 就比如说: 沙箱&#…...
企业大模型服务合规指南:深度解析备案与登记制度
伴随AI技术的爆炸式发展,尤其是大模型(LLM)在各行各业的深度应用和整合,企业利用AI技术提升效率、创新服务的步伐不断加快。无论是像DeepSeek这样的前沿技术提供者,还是积极拥抱AI转型的传统企业,在面向公众…...
二维FDTD算法仿真
二维FDTD算法仿真,并带完全匹配层,输入波形为高斯波、平面波 FDTD_二维/FDTD.zip , 6075 FDTD_二维/FDTD_31.m , 1029 FDTD_二维/FDTD_32.m , 2806 FDTD_二维/FDTD_33.m , 3782 FDTD_二维/FDTD_34.m , 4182 FDTD_二维/FDTD_35.m , 4793...
