当前位置: 首页 > article >正文

深度学习之神经网络框架搭建及模型优化

神经网络框架搭建及模型优化

目录

  • 神经网络框架搭建及模型优化
    • 1 数据及配置
      • 1.1 配置
      • 1.2 数据
      • 1.3 函数导入
      • 1.4 数据函数
      • 1.5 数据打包
    • 2 神经网络框架搭建
      • 2.1 框架确认
      • 2.2 函数搭建
      • 2.3 框架上传
    • 3 模型优化
      • 3.1 函数理解
      • 3.2 训练模型和测试模型代码
    • 4 最终代码测试
      • 4.1 SGD优化算法
      • 4.2 Adam优化算法
      • 4.3 多次迭代

1 数据及配置


1.1 配置

需要安装PyTorch,下载安装torch、torchvision、torchaudio,GPU需下载cuda版本,CPU可直接下载

cuda版本较大,最后通过控制面板pip install +存储地址离线下载,
CPU版本需再下载安装VC_redist.x64.exe,可下载上述三个后运行,通过报错网址直接下载安装

1.2 数据

使用的是 torchvision.datasets.MNIST的手写数据,包括特征数据和结果类别

1.3 函数导入

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

1.4 数据函数

train_data = datasets.MNIST(root='data',        # 数据集存储的根目录train=True,         # 加载训练集download=True,      # 如果数据集不存在,自动下载transform=ToTensor() # 将图像转换为张量
)
  • root 指定数据集存储的根目录。如果数据集不存在,会自动下载到这个目录。
  • train 决定加载训练集还是测试集。True 表示加载训练集,False 表示加载测试集。
  • download 如果数据集不在 root 指定的目录中,是否自动下载数据集。True 表示自动下载。
  • transform 对加载的数据进行预处理或转换。通常用于将数据转换为模型所需的格式,如将图像转换为张量。

1.5 数据打包

train_dataloader = DataLoader(train_data, batch_size=64)

  • train_data, 打包数据
  • batch_size=64,打包个数

代码展示:

import torch
print(torch.__version__)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortrain_data = datasets.MNIST(root = 'data',train = True,download = True,transform = ToTensor()
)
test_data = datasets.MNIST(root = 'data',train = False,download = True,transform = ToTensor()
)
print(len(train_data))
print(len(test_data))
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label = train_data[i+59000]figure.add_subplot(3,3,i+1)plt.title(label)plt.axis('off')plt.imshow(img.squeeze(),cmap='gray')a = img.squeeze()
plt.show()train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader= DataLoader(test_data, batch_size=64)

运行结果:
在这里插入图片描述

在这里插入图片描述

调试查看:

在这里插入图片描述
:

2 神经网络框架搭建


2.1 框架确认

在搭建神经网络框架前,需先确认建立怎样的框架,目前并没有理论的指导,凭经验建立框架如下:

输入层:输入的图像数据(28*28)个神经元。
中间层1:全连接层,128个神经元,
中间层2:全连接层,256个神经元,
输出层:全连接层,10个神经元,对应10个类别。
需注意,中间层需使用激励函数激活,对累加数进行非线性的映射,以及forward前向传播过程的函数名不可更改

2.2 函数搭建

  • nn.Flatten() , 将输入展平为一维向量
  • nn.Linear(28*28, 128) ,全连接层,需注意每个连接层的输入输出需前后对应
  • torch.sigmoid(x),对中间层的输出应用Sigmoid激活函数
# 定义一个神经网络类,继承自 nn.Module
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()  # 调用父类 nn.Module 的构造函数# 定义网络层self.flatten = nn.Flatten()  # 将输入展平为一维向量,适用于将图像数据(如28x28)展平为784维self.hidden1 = nn.Linear(28*28, 128)  # 第一个全连接层,输入维度为784(28*28),输出维度为128self.hidden2 = nn.Linear(128, 256)    # 第二个全连接层,输入维度为128,输出维度为256self.out = nn.Linear(256, 10)         # 输出层,输入维度为256,输出维度为10(对应10个类别)# 定义前向传播过程def forward(self, x):x = self.flatten(x)       # 将输入数据展平x = self.hidden1(x)       # 通过第一个全连接层x = torch.sigmoid(x)      # 对第一个全连接层的输出应用Sigmoid激活函数x = self.hidden2(x)       # 通过第二个全连接层x = torch.sigmoid(x)      # 对第二个全连接层的输出应用Sigmoid激活函数x = self.out(x)           # 通过输出层return x                  # 返回最终的输出

2.3 框架上传

  • device = ‘cuda’ if torch.cuda.is_available() else ‘mps’ if torch.backends.mps.is_available() else ‘cpu’,确认设备, 检查是否有可用的GPU设备,如果有则使用GPU,否则使用CPU
  • model = NeuralNetwork().to(device),框架上传到GPU/CPU

模型输出展示:

在这里插入图片描述

3 模型优化


3.1 函数理解

  • optimizer = torch.optim.Adam(model.parameters(), lr=0.001),定义优化器:
    • Adam()使用Adam优化算法,也可为SGD等优化算法
    • model.parameters()为优化模型的参数
    • lr为学习率/梯度下降步长为0.001
  • loss_fn = nn.CrossEntropyLoss(pre,y),定义损失函数,使用交叉熵损失函数,适用于分类任务
    • pre,预测结果
    • y,真实结果
    • loss_fn.item(),当前损失值
  • model.train() ,将模型设置为训练模式,模型参数是可变
  • x, y = x.to(device), y.to(device),将数据移动到指定设备(GPU或CPU)
  • 反向传播:清零梯度,计算梯度,更新模型参数
    • optimizer.zero_grad()清零梯度缓存
      loss.backward(), 计算梯度
      optimizer.step()更新模型参数
  • model.eval(),将模型设置为评估模式模型参数是不可变
  • with torch.no_grad(),禁用梯度计算,在测试过程中不需要计算梯度

3.2 训练模型和测试模型代码

optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
loss_fn = nn.CrossEntropyLoss()
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num %100 ==0:print(f'loss: {loss_value:>7f}  [number: {batch_size_num}]')batch_size_num +=1train(train_dataloader,model,loss_fn,optimizer)def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct +=(pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1)==y)b = (pred.argmax(1)==y).type(torch.float)test_loss /=num_batchescorrect /= sizeprint(f'test result: \n Accuracy: {(100*correct)}%, Avg loss:{test_loss}')

4 最终代码测试


4.1 SGD优化算法

torch.optim.SGD(model.parameters(),lr=0.01)

代码展示:

import torchprint(torch.__version__)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortrain_data = datasets.MNIST(root = 'data',train = True,download = True,transform = ToTensor()
)
test_data = datasets.MNIST(root = 'data',train = False,download = True,transform = ToTensor()
)
print(len(train_data))
print(len(test_data))
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label = train_data[i+59000]figure.add_subplot(3,3,i+1)plt.title(label)plt.axis('off')plt.imshow(img.squeeze(),cmap='gray')a = img.squeeze()
plt.show()train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader= DataLoader(test_data, batch_size=64)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28*28,128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256,10)def forward(self,x):x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x)x = self.out(x)return x
model = NeuralNetwork().to(device)
#
print(model)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
loss_fn = nn.CrossEntropyLoss()
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num %100 ==0:print(f'loss: {loss_value:>7f}  [number: {batch_size_num}]')batch_size_num +=1def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct +=(pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1)==y)b = (pred.argmax(1)==y).type(torch.float)test_loss /=num_batchescorrect /= sizeprint(f'test result: \n Accuracy: {(100*correct)}%, Avg loss:{test_loss}')
#train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model,loss_fn)

运行结果:
在这里插入图片描述

4.2 Adam优化算法

自适应算法,torch.optim.Adam(model.parameters(),lr=0.01)

运行结果:
在这里插入图片描述

4.3 多次迭代

代码展示:

import torchprint(torch.__version__)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortrain_data = datasets.MNIST(root = 'data',train = True,download = True,transform = ToTensor()
)
test_data = datasets.MNIST(root = 'data',train = False,download = True,transform = ToTensor()
)
print(len(train_data))
print(len(test_data))
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label = train_data[i+59000]figure.add_subplot(3,3,i+1)plt.title(label)plt.axis('off')plt.imshow(img.squeeze(),cmap='gray')a = img.squeeze()
plt.show()train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader= DataLoader(test_data, batch_size=64)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28*28,128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256,10)def forward(self,x):x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x)x = self.out(x)return x
model = NeuralNetwork().to(device)
#
print(model)
optimizer = torch.optim.Adam(model.parameters(),lr=0.01)
loss_fn = nn.CrossEntropyLoss()
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num %100 ==0:print(f'loss: {loss_value:>7f}  [number: {batch_size_num}]')batch_size_num +=1def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct +=(pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1)==y)b = (pred.argmax(1)==y).type(torch.float)test_loss /=num_batchescorrect /= sizeprint(f'test result: \n Accuracy: {(100*correct)}%, Avg loss:{test_loss}')
#train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model,loss_fn)
#
e = 30
for i in range(e):print(f'e: {i+1}\n------------------')train(train_dataloader, model, loss_fn, optimizer)
print('done')test(test_dataloader, model, loss_fn)

运行结果:
在这里插入图片描述

相关文章:

深度学习之神经网络框架搭建及模型优化

神经网络框架搭建及模型优化 目录 神经网络框架搭建及模型优化1 数据及配置1.1 配置1.2 数据1.3 函数导入1.4 数据函数1.5 数据打包 2 神经网络框架搭建2.1 框架确认2.2 函数搭建2.3 框架上传 3 模型优化3.1 函数理解3.2 训练模型和测试模型代码 4 最终代码测试4.1 SGD优化算法…...

excel 日期转换

需求如下: 在excel 里面输入一个4515,4表示年份,2024年,51表示该年的51周,5表示日,周日用1表示,周一用2表示,以此类推,需要转换为年份/月份/日期 若想用公式来实现这一转换&#x…...

Awtk 如何添加开机画面

场景 我们知道在工程中,Ui是一个线程,并且需要一直存在,当我们使用的开机画面在这个线程开启就直接展示的时候,因为awtk的界面是window_open入栈的,即首次打开的窗口会记录在top,往后的窗口会依次往后存放&…...

【设计模式】【行为型模式】命令模式(Command)

👋hi,我不是一名外包公司的员工,也不会偷吃茶水间的零食,我的梦想是能写高端CRUD 🔥 2025本人正在沉淀中… 博客更新速度 📫 欢迎V: flzjcsg2,我们共同讨论Java深渊的奥秘 &#x1f…...

C++模拟实现AVL树

目录 1.文章概括 2.AVL树概念 3.AVL树的性质 4.AVL树的插入 5.旋转控制 1.左单旋 2. 右单旋 3.左右双旋 4.右左双旋 6.全部代码 1.文章概括 本文适合理解平衡二叉树的读者阅读,因为AVL树是平衡二叉树的一种优化,其大部分实现逻辑与平衡二叉树是…...

推荐算法实践:movielens数据集

MovieLens 数据集介绍 MovieLens 数据集是由明尼苏达大学的GroupLens研究小组维护的一个广泛使用的电影评分数据集,主要用于推荐系统的研究。该数据集包含用户对电影的评分、标签以及其他相关信息,是电影推荐系统开发与研究的常用数据源。 数据集版本 …...

dynamic_cast和static_cast和const_cast

dynamic_cast 在 C 中的作用 dynamic_cast 是 C 运行时类型转换(RTTI, Run-Time Type Identification)的一部分,主要用于: 安全的多态类型转换检查类型的有效性向下转换(Downcasting)跨类层次的指针或引用…...

React进行路由跳转的方法汇总

在 React 中进行路由跳转有多种方法,具体取决于你使用的路由库和版本。以下是常见的路由跳转方法汇总,主要基于 react-router-dom 库。 1. 使用 useNavigate 钩子(适用于 react-router-dom v6) useNavigate 是 react-router-dom…...

python卷积神经网络人脸识别示例实现详解

目录 一、准备 1)使用pytorch 2)安装pytorch 3)准备训练和测试资源 二、卷积神经网络的基本结构 三、代码实现 1)导入库 2)数据预处理 3)加载数据 4)构建一个卷积神经网络 5&#xff0…...

以Unity6.0为例,如何在Unity中开启DLSS功能

DLSS DLSS(NVIDIA 深度学习超级采样):NVIDIA DLSS 是一套由 GeForce RTX™ Tensor Core 提供支持的神经渲染技术,可提高帧率,同时提供可与原生分辨率相媲美的清晰、高质量图像。目前最新突破DLSS 4 带来了新的多帧…...

CSDN 大模型 笔记

AI 3大范式:计算 发发 交互 L1 生成代码 复制到IDEA (22年12-23年6,7月份) L2 部分自动编程 定义class 设计interface 让其填充实现 (23年7,8月份) L3 通用任务 CRUD (24年) L4 高度自动编程 通用领域专有任务&#xf…...

Flink怎么保证Exactly - Once 语义

Exactly - Once 语义是消息处理领域中的一种严格数据处理语义,指每条数据都只会被精确消费和处理一次,既不会丢失,也不会重复。 以下从消息传递语义对比、实现方式、应用场景等方面详细介绍: 与其他消息传递语义对比 在消息传递…...

AOS安装及操作演示

文章目录 一、安装node1.1 在 macOS 上管理 Node版本1.1.1 安装 nvm1.1.2 验证 nvm 是否安装成功1.1.3 使用 nvm 安装/切换 Node.js 版本1.1.4 卸载 Node.js 版本 1.2 在 windows 上管理 Node版本1.2.1 安装 nvm-windows1.2.2 安装 Node.js 版本1.2.3 切换 Node.js 版本1.2.4 卸…...

Python 操作 MongoDB 教程

一、引言 在当今数字化时代,数据的存储和管理至关重要。传统的关系型数据库在处理一些复杂场景时可能会显得力不从心,而 NoSQL 数据库应运而生。MongoDB 作为一款开源的、面向文档的 NoSQL 数据库,凭借其高性能、高可扩展性和灵活的数据模型…...

Stability AI 联合 UIUC 提出单视图 3D 重建方法SPAR3D,可0.7秒完成重建并支持交互式用户编辑。

Stability AI 联合 UIUC 提出一种简单而有效的单视图 3D 重建方法 SPAR3D,这是一款最先进的 3D 重建器,可以从单视图图像重建高质量的 3D 网格。SPAR3D 的重建速度很快,只需 0.7 秒,并支持交互式用户编辑。 相关链接 论文&#xf…...

网易易盾接入DeepSeek,数字内容安全“智”理能力全面升级

今年农历新年期间,全球AI领域再度掀起了一波革命性浪潮,国产通用大模型DeepSeek凭借其强大的多场景理解与内容生成能力迅速“出圈”,彻底改写全球人工智能产业的格局。 作为国内领先的数字内容风控服务商,网易易盾一直致力于探索…...

自动驾驶---如何打造一款属于自己的自动驾驶系统

在笔者的专栏《自动驾驶Planning决策规划》中,主要讲解了行车的相关知识,从Routing,到Behavior Planning,再到Motion Planning,以及最后的Control,笔者都做了相关介绍,其中主要包括算法在量产上…...

局域网使用Ollama(Linux)

解决局域网无法连接Ollama服务的问题 在搭建和使用Ollama服务的过程中,可能会遇到局域网内无法连接的情况。经过排查发现,若开启了代理软件,尤其是Hiddify,会导致此问题。这一发现耗费了我数小时的排查时间,希望能给大…...

聚焦 AUTO TECH China 2025,共探汽车内外饰新未来Automotive Interiors

全球汽车产业蓬勃发展的大背景下,汽车内外饰作为汽车重要组成部分,其市场需求与技术创新不断推动着行业变革。2025年11月20日至22日,一场备受瞩目的行业盛会 ——AUTO TECH China 2025 广州国际汽车内外饰技术展览会将在广州保利世贸博览馆盛…...

Moretl 增量文件采集工具

永久免费: <下载> <使用说明> 用途 定时全量或增量采集工控机,电脑文件或日志. 优势 开箱即用: 解压直接运行.不需额外下载.管理设备: 后台统一管理客户端.无人值守: 客户端自启动,自更新.稳定安全: 架构简单,兼容性好,通过授权控制访问. 架构 技术架构: Asp…...

支持多种网络数据库格式的自动化转换工具——VisualXML

一、VisualXML软件介绍 对于DBC、ARXML……文件的编辑、修改等繁琐操作&#xff0c;WINDHILL风丘科技开发的总线设计工具——VisualXML&#xff0c;可轻松解决这一问题&#xff0c;提升工作效率。 VisualXML是一个强大且基于Excel表格生成多种网络数据库文件的转换工具&#…...

mysql8 用C++源码角度看客户端发起sql网络请求,并处理sql命令

MySQL 8 的 C 源码中&#xff0c;处理网络请求和 SQL 命令的流程涉及多个函数和类。以下是关键的函数和类&#xff0c;以及它们的作用&#xff1a; 1. do_command 函数 do_command 函数是 MySQL 服务器中处理客户端命令的核心函数。它从客户端读取一个命令并执行。这个函数在…...

四、OSG学习笔记-基础图元

前一章节&#xff1a; 三、OSG学习笔记-应用基础-CSDN博客https://blog.csdn.net/weixin_36323170/article/details/145514021 代码&#xff1a;CuiQingCheng/OsgStudy - Gitee.com 一、绘制盒子模型 下面一个简单的 demo #include<windows.h> #include<osg/Node&…...

使用vllm docker容器部署大语言模型

说明 最近deepseek比较火&#xff0c;我在一台4卡4090的服务器上尝试部署了一下&#xff0c;记录下部署步骤。 安装过程 安卓docker和nvidia-container-toolkit 安装19.03版本以上的docker-ce即可。安装步骤参考清华docker源上的安装步骤&#xff1a;Docker CE 软件仓库 为…...

window 安装GitLab服务器笔记

目录 视频&#xff1a; 资源&#xff1a; Linux CeneOS7&#xff1a; VMware&#xff1a; Linux无法安装 yum install vim -y 1.手动创建目录 2.下载repo PS 补充视频不可复制的代码 安装GitLab *修改root用户密码相关&#xff08;我卡在第一步就直接放弃了这个操作&…...

MySQL数据库入门到大蛇尚硅谷宋红康老师笔记 基础篇 part 10

第10章_创建和管理表 DDL&#xff1a;数据定义语言。CREATE \ALTER\ DROP \RENAME TRUNCATE DML&#xff1a;数据操作语言。INSERT \DELETE \UPDATE \SELECT&#xff08;重中之重&#xff09; DCL&#xff1a;数据控制语言。COMMIT \…...

react项目引入tailwindcss不生效解决方案

根据tailwindcss官网的操作步骤下来&#xff0c;样式未生效&#xff0c;且未报错&#xff0c;看了挺多的资料&#xff0c;还是并未解决。 后面在另一个项目尝试时&#xff0c;报了下面的问题&#xff1a; Error: PostCSS plugin tailwindcss requires PostCSS 8 根据这个链接…...

Expo运行模拟器失败错误解决(xcrun simctl )

根据你的描述&#xff0c;问题主要涉及两个方面&#xff1a;xcrun simctl 错误和 Expo 依赖版本不兼容。以下是针对这两个问题的解决方案&#xff1a; 解决 xcrun simctl 错误 错误代码 72 通常表明 simctl 工具未正确配置或路径未正确设置。以下是解决步骤&#xff1a; 确保 …...

【系统架构设计师】体系结构文档化

目录 1. 说明2. 重要性3. 主要内容4. 编写原则5. 实践建议6. 例题6.1 例题1 1. 说明 1.绝大多数的体系结构都是抽象的&#xff0c;由一些概念上的构建组成。2.层的概念在任何程序设计语言中都不存在。3.要让系统分析员和程序员去实现体系结构&#xff0c;还必须将体系结构进行…...

【0403】Postgres内核 检查(procArray )给定 db 是否有其他 backend process 正在运行

文章目录 1. 给定 db 是否有其他 backend 正在运行1.1 获取 allPgXact[] 索引1.1.1 MyProc 中 databaseId 初始化实现1.2 allProcs[] 中各 databaseId 判断1. 给定 db 是否有其他 backend 正在运行 CREATE DATABASE 语句创建用户指定 数据库名(database-name)时候, 会通过 …...