PyTorch搭建神经网络入门教程
PyTorch搭建神经网络入门教程
在机器学习和深度学习中,神经网络是最常用的模型之一,而 PyTorch 是一个强大的深度学习框架,适合快速开发与研究。在这篇文章中,我们将带你一步步搭建一个简单的神经网络,并介绍 PyTorch 的基本用法。
1. 环境准备
首先,确保你已经安装了 PyTorch。可以使用 pip 安装:
pip install torch torchvision
安装完成后,你可以使用以下代码检查安装是否成功:
import torch
print(torch.__version__)
如果没有报错,说明 PyTorch 安装成功。
2. 定义数据集
在深度学习中,数据是非常重要的。我们将使用 PyTorch 内置的 torchvision 库中的 MNIST 数据集。这个数据集包含手写数字,适合用来训练图像分类模型。
import torch
from torchvision import datasets, transforms# 定义数据的预处理步骤
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 下载并加载训练集和测试集
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
上面的代码中,我们将数据进行了标准化处理,并使用 DataLoader 将数据分批加载。batch_size=64 意味着每次会处理64张图片。
3. 定义神经网络模型
我们将使用 torch.nn.Module 来定义一个简单的全连接神经网络。这个网络包含输入层、隐藏层和输出层,每层使用线性激活和 ReLU 非线性激活函数。
import torch.nn as nn
import torch.nn.functional as Fclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()# 定义两层全连接层self.fc1 = nn.Linear(28 * 28, 128) # 输入层,将28x28的图像展平为784个输入节点self.fc2 = nn.Linear(128, 64) # 隐藏层,128个神经元self.fc3 = nn.Linear(64, 10) # 输出层,10个分类(对应0-9的手写数字)def forward(self, x):# 展平图像数据(batch_size, 28*28)x = x.view(x.shape[0], -1)# 通过隐藏层和激活函数x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))# 输出层,直接输出为logitsx = self.fc3(x)return x
在上面的代码中,我们定义了一个三层的全连接神经网络。第一层将 28x28 像素的图像展平为一维数组,接下来通过隐藏层并应用 ReLU 激活函数,最终通过输出层得到 10 个类别的预测。
4. 定义损失函数与优化器
在训练神经网络之前,需要定义损失函数和优化器。我们将使用交叉熵损失函数和 Adam 优化器。
import torch.optim as optim# 实例化神经网络模型
model = SimpleNN()# 定义损失函数(交叉熵损失)
criterion = nn.CrossEntropyLoss()# 定义优化器(Adam)
optimizer = optim.Adam(model.parameters(), lr=0.003)
这里,CrossEntropyLoss 是常用于分类问题的损失函数,而 Adam 是一种自适应学习率的优化方法,适用于大多数深度学习任务。
5. 训练神经网络
我们可以开始训练神经网络了。训练过程通常包含前向传播、计算损失、反向传播和参数更新四个步骤。
epochs = 5 # 定义训练轮次
for epoch in range(epochs):running_loss = 0for images, labels in trainloader:# 将梯度清零optimizer.zero_grad()# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播并更新权重loss.backward()optimizer.step()# 记录损失running_loss += loss.item()print(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(trainloader)}")
这里我们遍历训练数据集,通过每个批次的数据进行前向传播和反向传播,并根据损失更新模型的参数。
6. 测试模型
在模型训练完成后,我们需要在测试集上评估模型的表现。
correct = 0
total = 0# 不需要计算梯度
with torch.no_grad():for images, labels in testloader:outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the 10000 test images: {100 * correct / total}%')
通过 torch.no_grad(),我们可以避免在测试时计算梯度,提升性能。通过比较模型的预测结果与真实标签,我们可以计算测试集的准确率。
7. 小结
本文介绍了使用 PyTorch 搭建神经网络的基本步骤,主要包括数据处理、模型定义、损失函数与优化器的设置、训练与测试。虽然这个神经网络比较简单,但已经可以作为深度学习任务的入门示例。
接下来,你可以尝试:
- 调整神经网络的结构,如增加更多的隐藏层或使用不同的激活函数。
- 使用不同的数据集进行训练和测试。
- 结合 GPU 加速训练过程,提升模型的训练速度。
通过 PyTorch,搭建神经网络和进行深度学习研究变得更加简单高效。希望这篇教程对你有所帮助!
相关文章:
PyTorch搭建神经网络入门教程
PyTorch搭建神经网络入门教程 在机器学习和深度学习中,神经网络是最常用的模型之一,而 PyTorch 是一个强大的深度学习框架,适合快速开发与研究。在这篇文章中,我们将带你一步步搭建一个简单的神经网络,并介绍 PyTorch…...
你的电脑能不能安装windows 11,用这个软件检测下就知道了
为了应对Windows 11的推出,一款名为WhyNotWin11的创新型诊断软件应运而生。这个强大的工具不仅仅是一个简单的兼容性检测器,它更像是一位细心的数字医生,全方位评估您的计算机是否准备好迎接微软最新操作系统的挑战。 WhyNotWin11的功能远超…...
BF 算法
目录 BF算法 算法思路 完整代码 时间复杂度 查找所有起始位置 BF算法 BF算法:即暴力(Brute Force)算法,是一种模式匹配算法,将目标串 S 的第一个字符与模式串 T 的第一个字符进行匹配,若相等,则继续比较 S 的第二…...
SHOW-O——一款结合多模态理解和生成的单一Transformer
1.前言 大型语言模型 (LLM) 的重大进步激发了多模态大型语言模型 (MLLM) 的发展。早期的 MLLM 工作,例如 LLaVA、MiniGPT-4 和 InstructBLIP,展示了卓越的多模态理解能力。为了将 LLM 集成到多模态领域,这些研究探索了将预训练的模态特定编码…...
缓存框架JetCache源码解析-缓存变更通知机制
为什么需要缓存变更通知机制?如果我们使用的是本地缓存或者多级缓存(本地缓存远程缓存),当其中一个节点的本地缓存变更之后,为了保证缓存尽量的一致性,此时其他节点的本地缓存也需要去变更,这时…...
Android 设置特定Activity内容顶部显示在状态栏底部,也就是状态栏的下层 以及封装一个方法修改状态栏颜色
推荐:https://github.com/gyf-dev/ImmersionBar 在 Android 中要实现特定 Activity 内容顶部显示在状态栏底部以及封装方法修改状态栏颜色,可以通过以下步骤来完成: 一、让 Activity 内容显示在状态栏底部 在 AndroidManifest.xml 文件中,为特…...
用自己的数据集复现YOLOv5
yolov5已经出了很多版本了,这里我以目前最新的版本为例,先在官网下载源码:GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite 然后下载预训练模型,需要哪个就点击哪个模型就行&am…...
如何在博客中插入其他的博客链接(超简单)最新版
如何在博客中插入其他的博客链接 1.复制自己要添加的网址(组合键:Ctrlc)2. 点击超链接按钮3. 粘贴自己刚才复制的网址(组合键:Ctrlv)并点击确定即可4.让博客链接显示中文5.点击蓝字即可打开 1.复制自己要添…...
JS通过递归函数来剔除树结构特定节点
最近在处理权限类问题过程中,遇到多次需要过滤一下来列表的数据,针对不同用户看到的数据不同。记录一下 我的数据大致是这样的: class UserTree {constructor() {this.userTreeData [// 示例数据{ nodeid: "1", nodename: "R…...
javayufa
1.变量、运算符、表达式、输入输出 编写一个简单的Java程序–手速练习 public class Main { public static void main(String[] args) { System.out.println("Hello World"); } } 三、语法基础 变量 变量必须先定义,才可以使用。不能重名。 变量定义的方…...
软考-高级系统分析师知识点-补充篇
云计算 云计算的体系结构由5部分组成,分别为应用层,平台层,资源层,用户访问层和管理层,云计算的本质是通过网络提供服务,所以其体系结构以服务为核心。 系统的可靠性技术---容错技术---冗余技术 容错是指系…...
JavaScript全面指南(四)
🌈个人主页:前端青山 🔥系列专栏:JavaScript篇 🔖人终将被年少不可得之物困其一生 依旧青山,本期给大家带来JavaScript篇专栏内容:JavaScript全面指南 目录 61、如何防止XSRF攻击 62、如何判断一个对象是否为数组&…...
2024年诺贝尔物理学奖的创新之举
对于2024年诺贝尔物理学奖的这一创新之举,我的观点可以从以下几点展开: 跨学科融合的里程碑:将诺贝尔物理学奖颁发给机器学习与神经网络领域的研究者,标志着科学界对跨学科合作和融合的认可达到新高度。这不仅体现了理论物理与计算…...
FileLink内外网文件交换——致力企业高效安全文件共享
随着数字化转型的推进,企业之间的文件交流需求日益增加。然而,传统的文件传输方式往往无法满足速度和安全性的双重要求。FileLink作为一款专注于跨网文件交换的工具,致力于为企业提供高效、安全的文件共享解决方案。 应用场景一:项…...
使用Python在Jupyter Notebook中显示Markdown文本
使用Python在Jupyter Notebook中显示Markdown文本 引言1. 导入必要的模块2. 定义一个函数来显示Markdown文本3. 使用print_md函数显示Markdown文本4. 总结 引言 作为一名Python初级程序员,你可能已经熟悉了Jupyter Notebook这个强大的工具。Jupyter Notebook不仅支…...
G1 GAN生成MNIST手写数字图像
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 G1 GAN生成MNIST手写数字图像 1. 生成对抗网络 (GAN) 简介 生成对抗网络 (GAN) 是一种通过“对抗性”学习生成数据的深度学习模型,通常用于生成…...
WPFDeveloper正式版发布
WPFDeveloper WPFDeveloper一个基于WPF自定义高级控件的WPF开发人员UI库,它提供了众多的自定义控件。 该项目的创建者和主要维护者是现役微软MVP 闫驚鏵: https://github.com/yanjinhuagood 该项目还有众多的维护者,详情可以访问github上的README&…...
实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”)
要实现鼠标经过某个元素时弹出提示框(通常称为“工具提示”或“悬浮提示”),你可以使用 JavaScript 结合 CSS 来创建这个效果。以下是详细步骤,包括 HTML、CSS 和 JavaScript 的代码示例。 HTML 结构 首先,创建一个简…...
【GAMES101笔记速查——Lecture 17 Materials and Appearances】
目录 1 材质和外观 1.1 自然界中,外观是光线和材质共同作用的结果 1.2 图形学中,什么是材质? 1.2.1 渲染方程严格正确,其中BRDF项决定了物体的材质 1.2.2 漫反射材质 (1)如何定义漫反射系数࿱…...
对于从vscode ssh到virtualBox的timeout记录
如题,解决方式如下: 1.把虚拟机关机退出来,在这个界面进行网络设置:选桥接网卡 2.然后再进系统,使用命令 ip addr查看如今的ip地址,应该和在本机里面看到的是一个网段 3.打开vscode,该干啥干…...
Browsershot终极教程:从零开始掌握Chrome无头浏览器
Browsershot终极教程:从零开始掌握Chrome无头浏览器 【免费下载链接】browsershot Convert HTML to an image, PDF or string 项目地址: https://gitcode.com/gh_mirrors/br/browsershot Browsershot是一款强大的工具,能够轻松实现HTML到图片、PD…...
GLM-4.6V-Flash-WEB效果展示:智能识别华硕/戴尔/联想BIOS界面
GLM-4.6V-Flash-WEB效果展示:智能识别华硕/戴尔/联想BIOS界面 1. 引言:BIOS界面识别的技术挑战 面对不同品牌电脑的BIOS设置界面,即使是经验丰富的技术人员也常常感到头疼。华硕的UEFI界面、戴尔的BIOS配置、联想的设置菜单——每个厂商都有…...
电源环路分析仪不会用?2026年硬件工程师的必备技能该补上了
电源环路分析仪不会用?2026年硬件工程师的必备技能该补上了实验室里,Buck电源刚调通,输出纹波看着也不错,但一上动态负载,输出电压就开始剧烈振荡。换了几组补偿参数,还是没找到症结所在。这时候,旁边有经验的前辈说了一句:"你测过环路稳定性吗?"说实话,…...
想做市场品牌策划?这3大秘诀让你的品牌脱颖而出!
行业痛点分析当前品牌策划领域面临诸多技术挑战。许多企业有产品无品牌,产品品质过硬、技术领先,但缺乏清晰的品牌定位与价值表达,陷入 “酒香也怕巷子深” 的困境,只能靠低价竞争。数据表明,约 60%的企业因品牌定位不…...
开始你的「一人公司」
未来大部分的公司,都将是「一个人 N 个 AI」的模式。 这意味着你不再需要很多前置条件,就能开始交付真正的产品。 阻碍你行动的不再是资金、团队或资源,而更多是——你有没有意愿。一、AI 会让认知成本趋近于零这是最关键的判断。电的出现让…...
侧信道攻击防御指南:从智能家居到云服务器的7个关键防护措施
侧信道攻击防御指南:从智能家居到云服务器的7个关键防护措施 在数字化浪潮席卷全球的今天,数据安全已成为企业生存的命脉。然而,当大多数安全团队还在与传统的网络攻击周旋时,一种更为隐蔽的威胁正在悄然蔓延——侧信道攻击。这种…...
零成本上手:在魔塔社区用免费GPU微调InternLM2.5-7B-Chat实战
1. 为什么选择魔塔社区进行大模型微调 第一次接触大模型微调的朋友们可能都有这样的困惑:动辄几十GB的模型参数,没有高端显卡怎么玩得转?这里就要给大家安利一个宝藏平台——阿里魔塔社区。我去年刚开始研究大模型时,也是被硬件门…...
Jupyter Notebook安全配置全攻略:如何在Linux上设置密码保护与远程访问
Jupyter Notebook安全配置全攻略:如何在Linux上设置密码保护与远程访问 在数据科学和机器学习领域,Jupyter Notebook已经成为不可或缺的工具,它提供了交互式编程环境,让开发者能够轻松地进行数据探索、可视化和模型训练。然而&…...
企业网络架构设计:如何选择核心交换机、汇聚交换机和接入交换机(含真实案例)
企业网络架构设计实战:核心层、汇聚层与接入层交换机选型指南 当一家200人规模的制造企业决定升级网络基础设施时,IT负责人发现市场上交换机的型号多达上千种,价格从几百元到几十万元不等。核心交换机是否必须选用思科Catalyst 9500系列&…...
信息化基础设施层建设
4.1 基础设施层建设 4.1.4 基础软件环境 基础软件环境的理论定位 基础软件环境是企业信息化建设的“操作系统”,其理论任务是为上层应用系统提供统一的运行环境、开发框架、数据服务和协作工具,包括操作系统、数据库、中间件、开发框架、版本控制、协…...
