当前位置: 首页 > news >正文

深度学习之神经网络框架搭建及模型优化

神经网络框架搭建及模型优化

目录

  • 神经网络框架搭建及模型优化
    • 1 数据及配置
      • 1.1 配置
      • 1.2 数据
      • 1.3 函数导入
      • 1.4 数据函数
      • 1.5 数据打包
    • 2 神经网络框架搭建
      • 2.1 框架确认
      • 2.2 函数搭建
      • 2.3 框架上传
    • 3 模型优化
      • 3.1 函数理解
      • 3.2 训练模型和测试模型代码
    • 4 最终代码测试
      • 4.1 SGD优化算法
      • 4.2 Adam优化算法
      • 4.3 多次迭代

1 数据及配置


1.1 配置

需要安装PyTorch,下载安装torch、torchvision、torchaudio,GPU需下载cuda版本,CPU可直接下载

cuda版本较大,最后通过控制面板pip install +存储地址离线下载,
CPU版本需再下载安装VC_redist.x64.exe,可下载上述三个后运行,通过报错网址直接下载安装

1.2 数据

使用的是 torchvision.datasets.MNIST的手写数据,包括特征数据和结果类别

1.3 函数导入

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

1.4 数据函数

train_data = datasets.MNIST(root='data',        # 数据集存储的根目录train=True,         # 加载训练集download=True,      # 如果数据集不存在,自动下载transform=ToTensor() # 将图像转换为张量
)
  • root 指定数据集存储的根目录。如果数据集不存在,会自动下载到这个目录。
  • train 决定加载训练集还是测试集。True 表示加载训练集,False 表示加载测试集。
  • download 如果数据集不在 root 指定的目录中,是否自动下载数据集。True 表示自动下载。
  • transform 对加载的数据进行预处理或转换。通常用于将数据转换为模型所需的格式,如将图像转换为张量。

1.5 数据打包

train_dataloader = DataLoader(train_data, batch_size=64)

  • train_data, 打包数据
  • batch_size=64,打包个数

代码展示:

import torch
print(torch.__version__)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortrain_data = datasets.MNIST(root = 'data',train = True,download = True,transform = ToTensor()
)
test_data = datasets.MNIST(root = 'data',train = False,download = True,transform = ToTensor()
)
print(len(train_data))
print(len(test_data))
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label = train_data[i+59000]figure.add_subplot(3,3,i+1)plt.title(label)plt.axis('off')plt.imshow(img.squeeze(),cmap='gray')a = img.squeeze()
plt.show()train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader= DataLoader(test_data, batch_size=64)

运行结果:
在这里插入图片描述

在这里插入图片描述

调试查看:

在这里插入图片描述
:

2 神经网络框架搭建


2.1 框架确认

在搭建神经网络框架前,需先确认建立怎样的框架,目前并没有理论的指导,凭经验建立框架如下:

输入层:输入的图像数据(28*28)个神经元。
中间层1:全连接层,128个神经元,
中间层2:全连接层,256个神经元,
输出层:全连接层,10个神经元,对应10个类别。
需注意,中间层需使用激励函数激活,对累加数进行非线性的映射,以及forward前向传播过程的函数名不可更改

2.2 函数搭建

  • nn.Flatten() , 将输入展平为一维向量
  • nn.Linear(28*28, 128) ,全连接层,需注意每个连接层的输入输出需前后对应
  • torch.sigmoid(x),对中间层的输出应用Sigmoid激活函数
# 定义一个神经网络类,继承自 nn.Module
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()  # 调用父类 nn.Module 的构造函数# 定义网络层self.flatten = nn.Flatten()  # 将输入展平为一维向量,适用于将图像数据(如28x28)展平为784维self.hidden1 = nn.Linear(28*28, 128)  # 第一个全连接层,输入维度为784(28*28),输出维度为128self.hidden2 = nn.Linear(128, 256)    # 第二个全连接层,输入维度为128,输出维度为256self.out = nn.Linear(256, 10)         # 输出层,输入维度为256,输出维度为10(对应10个类别)# 定义前向传播过程def forward(self, x):x = self.flatten(x)       # 将输入数据展平x = self.hidden1(x)       # 通过第一个全连接层x = torch.sigmoid(x)      # 对第一个全连接层的输出应用Sigmoid激活函数x = self.hidden2(x)       # 通过第二个全连接层x = torch.sigmoid(x)      # 对第二个全连接层的输出应用Sigmoid激活函数x = self.out(x)           # 通过输出层return x                  # 返回最终的输出

2.3 框架上传

  • device = ‘cuda’ if torch.cuda.is_available() else ‘mps’ if torch.backends.mps.is_available() else ‘cpu’,确认设备, 检查是否有可用的GPU设备,如果有则使用GPU,否则使用CPU
  • model = NeuralNetwork().to(device),框架上传到GPU/CPU

模型输出展示:

在这里插入图片描述

3 模型优化


3.1 函数理解

  • optimizer = torch.optim.Adam(model.parameters(), lr=0.001),定义优化器:
    • Adam()使用Adam优化算法,也可为SGD等优化算法
    • model.parameters()为优化模型的参数
    • lr为学习率/梯度下降步长为0.001
  • loss_fn = nn.CrossEntropyLoss(pre,y),定义损失函数,使用交叉熵损失函数,适用于分类任务
    • pre,预测结果
    • y,真实结果
    • loss_fn.item(),当前损失值
  • model.train() ,将模型设置为训练模式,模型参数是可变
  • x, y = x.to(device), y.to(device),将数据移动到指定设备(GPU或CPU)
  • 反向传播:清零梯度,计算梯度,更新模型参数
    • optimizer.zero_grad()清零梯度缓存
      loss.backward(), 计算梯度
      optimizer.step()更新模型参数
  • model.eval(),将模型设置为评估模式模型参数是不可变
  • with torch.no_grad(),禁用梯度计算,在测试过程中不需要计算梯度

3.2 训练模型和测试模型代码

optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
loss_fn = nn.CrossEntropyLoss()
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num %100 ==0:print(f'loss: {loss_value:>7f}  [number: {batch_size_num}]')batch_size_num +=1train(train_dataloader,model,loss_fn,optimizer)def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct +=(pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1)==y)b = (pred.argmax(1)==y).type(torch.float)test_loss /=num_batchescorrect /= sizeprint(f'test result: \n Accuracy: {(100*correct)}%, Avg loss:{test_loss}')

4 最终代码测试


4.1 SGD优化算法

torch.optim.SGD(model.parameters(),lr=0.01)

代码展示:

import torchprint(torch.__version__)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortrain_data = datasets.MNIST(root = 'data',train = True,download = True,transform = ToTensor()
)
test_data = datasets.MNIST(root = 'data',train = False,download = True,transform = ToTensor()
)
print(len(train_data))
print(len(test_data))
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label = train_data[i+59000]figure.add_subplot(3,3,i+1)plt.title(label)plt.axis('off')plt.imshow(img.squeeze(),cmap='gray')a = img.squeeze()
plt.show()train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader= DataLoader(test_data, batch_size=64)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28*28,128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256,10)def forward(self,x):x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x)x = self.out(x)return x
model = NeuralNetwork().to(device)
#
print(model)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
loss_fn = nn.CrossEntropyLoss()
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num %100 ==0:print(f'loss: {loss_value:>7f}  [number: {batch_size_num}]')batch_size_num +=1def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct +=(pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1)==y)b = (pred.argmax(1)==y).type(torch.float)test_loss /=num_batchescorrect /= sizeprint(f'test result: \n Accuracy: {(100*correct)}%, Avg loss:{test_loss}')
#train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model,loss_fn)

运行结果:
在这里插入图片描述

4.2 Adam优化算法

自适应算法,torch.optim.Adam(model.parameters(),lr=0.01)

运行结果:
在这里插入图片描述

4.3 多次迭代

代码展示:

import torchprint(torch.__version__)
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensortrain_data = datasets.MNIST(root = 'data',train = True,download = True,transform = ToTensor()
)
test_data = datasets.MNIST(root = 'data',train = False,download = True,transform = ToTensor()
)
print(len(train_data))
print(len(test_data))
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,label = train_data[i+59000]figure.add_subplot(3,3,i+1)plt.title(label)plt.axis('off')plt.imshow(img.squeeze(),cmap='gray')a = img.squeeze()
plt.show()train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader= DataLoader(test_data, batch_size=64)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')
class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28*28,128)self.hidden2 = nn.Linear(128, 256)self.out = nn.Linear(256,10)def forward(self,x):x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x)x = self.out(x)return x
model = NeuralNetwork().to(device)
#
print(model)
optimizer = torch.optim.Adam(model.parameters(),lr=0.01)
loss_fn = nn.CrossEntropyLoss()
def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num %100 ==0:print(f'loss: {loss_value:>7f}  [number: {batch_size_num}]')batch_size_num +=1def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct +=(pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1)==y)b = (pred.argmax(1)==y).type(torch.float)test_loss /=num_batchescorrect /= sizeprint(f'test result: \n Accuracy: {(100*correct)}%, Avg loss:{test_loss}')
#train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model,loss_fn)
#
e = 30
for i in range(e):print(f'e: {i+1}\n------------------')train(train_dataloader, model, loss_fn, optimizer)
print('done')test(test_dataloader, model, loss_fn)

运行结果:
在这里插入图片描述

相关文章:

深度学习之神经网络框架搭建及模型优化

神经网络框架搭建及模型优化 目录 神经网络框架搭建及模型优化1 数据及配置1.1 配置1.2 数据1.3 函数导入1.4 数据函数1.5 数据打包 2 神经网络框架搭建2.1 框架确认2.2 函数搭建2.3 框架上传 3 模型优化3.1 函数理解3.2 训练模型和测试模型代码 4 最终代码测试4.1 SGD优化算法…...

采用分步式无线控制架构实现水池液位自动化管理

以下是基于巨控GRM241Q-4D4I4QHE模块的完整技术方案,采用分步式无线控制架构实现水池液位自动化管理: 一、系统架构设计 硬件部署 山顶单元 GRM241Q模块(带4G功能) 液位计(4-20mA) 功能:实时采…...

OpenEuler学习笔记(二十三):在OpenEuler上部署开源MES系统

在OpenEuler上部署小企业开源MES(制造执行系统,Manufacturing Execution System)是一个非常有价值的项目,可以帮助企业实现生产过程的数字化管理。以下是基于开源MES系统(如 Odoo MES 或 OpenMES)的部署步骤…...

SpringSecurity:授权服务器与客户端应用(入门案例)

文章目录 一、需求概述二、基本授权登录功能实现1、授权服务器开发2、客户端开发3、功能测试 三、自定义授权服务器登录页1、授权服务器开发2、功能测试 四、自定义授权服务器授权页1、授权服务器开发2、功能测试 五、客户端信息保存数据库1、授权服务器开发2、功能测试 一、需…...

没用的文章又➕1

次次登陆GitHub都让我抓心挠肝,用了热度最高的法子也不抵事儿。谁说github上全是大神了,也要有我这样的小菜鸟。下面是我的失败记录… 查询目标网站的DNS 在whois上输入目标网站github.com,在查询结果当中选取任意一个DNS将地址和名称添加在…...

BiGRU双向门控循环单元多变量多步预测,光伏功率预测(Matlab完整源码和数据)

代码地址:BiGRU双向门控循环单元多变量多步预测,光伏功率预测(Matlab完整源码和数据) BiGRU双向门控循环单元多变量多步预测,光伏功率预测 一、引言 1.1、研究背景和意义 随着全球对可再生能源需求的不断增长,光伏…...

谷歌浏览器多开指南:如何完成独立IP隔离?

对于跨境电商来说,在进行社交媒体营销、广告投放等业务活动时,往往需要同时登录多个账号来提高运营效率和提升营销效果。然而,如果这些账号共享相同的 IP 地址,很容易被平台检测为关联账号,进而触发安全验证甚至封禁。…...

Django开发入门 – 3.用Django创建一个Web项目

Django开发入门 – 3.用Django创建一个Web项目 Build A Web Based Project With Django By JacksonML 本文简要介绍如何利用最新版Python 3.13.2来搭建Django环境,以及创建第一个Django Web应用项目,并能够运行Django Web服务器。 创建该Django项目需…...

【Java】多线程和高并发编程(三):锁(下)深入ReentrantReadWriteLock

文章目录 4、深入ReentrantReadWriteLock4.1 为什么要出现读写锁4.2 读写锁的实现原理4.3 写锁分析4.3.1 写锁加锁流程概述4.3.2 写锁加锁源码分析4.3.3 写锁释放锁流程概述&释放锁源码 4.4 读锁分析4.4.1 读锁加锁流程概述4.4.1.1 基础读锁流程4.4.1.2 读锁重入流程4.4.1.…...

讲解ES6中的变量和对象的解构赋值

在 ES6 中,解构赋值是一种非常方便的语法,它使得从数组或对象中提取值变得更加简洁和直观。解构赋值支持变量赋值,可以通过单独提取数组或对象的元素来赋值给变量。 下面我将分别讲解 数组解构 和 对象解构 的基本用法和一些高级特性。 1. …...

DeepSeek Coder + IDEA 辅助开发工具

开发者工具 我之前用的是Codegeex4模型,现在写一款DeepSeek Coder 本地模型 DeepSeek为什么火,我在网上看到一个段子下棋DeepSeek用兵法赢了ChatGpt,而没有用技术赢,这就是AI的思维推理,深入理解孙子兵法&#xff0c…...

云计算——AWS Solutions Architect – Associate(saa)4.安全组和NACL

安全组一充当虚拟防火墙对于关联实例,在实例级别控制入站和出站流量。 网络访问控制列表(NACL)一充当防火墙关联子网,在子网级别控制入站和出站流量。 在专有网络中,安全组和网络ACL(NACL)一起帮助构建分层网络防御。 安全组在实例级别操作…...

动量+均线组合策略关键点

动量均线组合策略关键点: 趋势确认: MA系统判断主趋势方向动量指标判断趋势强度 入场条件: 价格站上重要均线(如20日线)动量指标向上并保持高位短期均线上穿长期均线 出场条件: 价格跌破均线系统动量指标见顶回落短期均线下…...

Blazor-<select>

今天我们来说说<select>标签的用法&#xff0c;我们还是从一个示例代码开始 page "/demoPage" rendermode InteractiveAuto inject ILogger<InjectPage> logger; <h3>demoPage</h3> <select multiple>foreach (var item in list){<…...

Synchronized使用

文章目录 synchronized使用基本概念使用方法实现原理锁的粒度并发编程注意事项与Lock锁对比比较线程安全性与性能 synchronized使用 当涉及到多线程编程时&#xff0c;保证数据的正确性和一致性是至关重要的。而synchronized关键字是Java语言中最基本的同步机制之一&#xff0…...

OpenStack四种创建虚拟机的方式

实例&#xff08;Instances&#xff09;是在云内部运行的虚拟机。您可以从以下来源启动实例&#xff1a; 一、上传到镜像服务的镜像&#xff08;Image&#xff09; 使用已上传到镜像服务的镜像来启动实例。 二、复制到持久化卷的镜像&#xff08;Volume&#xff09; 使用已…...

Expo运行模拟器失败错误解决(xcrun simctl )

根据你的描述&#xff0c;问题主要涉及两个方面&#xff1a;xcrun simctl 错误和 Expo 依赖版本不兼容。以下是针对这两个问题的解决方案&#xff1a; 解决 xcrun simctl 错误 错误代码 72 通常表明 simctl 工具未正确配置或路径未正确设置。以下是解决步骤&#xff1a; 确保 …...

Docker从入门到精通- 容器化技术全解析

第一章&#xff1a;Docker 入门 一、什么是 Docker&#xff1f; Docker 就像一个超级厉害的 “打包神器”。它能帮咱们把应用程序和它运行所需要的东东都整整齐齐地打包到一起&#xff0c;形成一个独立的小盒子&#xff0c;这个小盒子在 Docker 里叫容器。以前呢&#xff0c;…...

开启对话式智能分析新纪元——Wyn商业智能 BI 携手Deepseek 驱动数据分析变革

2月18号&#xff0c;Wyn 商业智能 V8.0Update1 版本将重磅推出对话式智能分析&#xff0c;集成Deepseek R1大模型&#xff0c;通过AI技术的深度融合&#xff0c;致力于打造"会思考的BI系统"&#xff0c;让数据价值触手可及&#xff0c;助力企业实现从数据洞察到决策执…...

RabbitMQ 消息顺序性保证

方式一&#xff1a;Consumer设置exclusive 注意条件 作用于basic.consume不支持quorum queue 当同时有A、B两个消费者调用basic.consume方法消费&#xff0c;并将exclusive设置为true时&#xff0c;第二个消费者会抛出异常&#xff1a; com.rabbitmq.client.AlreadyClosedEx…...

防御保护作业二

拓扑图 需求 需求一&#xff1a; 需求二&#xff1a; 需求三&#xff1a; 需求四&#xff1a; 需求五&#xff1a; 需求六&#xff1a; 需求七&#xff1a; 需求分析 1.按照要求进行设备IP地址的配置 2.在FW上开启DHCP功能&#xff0c;并配置不同的全局地址池&#xff0c;为…...

Spring Boot中实现多租户架构

文章目录 Spring Boot中实现多租户架构多租户架构概述核心思想多租户的三种模式优势挑战租户识别机制1. 租户标识(Tenant Identifier)2. 常见的租户识别方式3. 实现租户识别的关键点4. 租户识别示例代码5. 租户识别机制的挑战数据库隔离的实现1. 数据库隔离的核心目标2. 数据…...

【AI-27】DPO和PPO的区别

DPO&#xff08;Direct Preference Optimization&#xff09;和 PPO&#xff08;Proximal Policy Optimization&#xff09;有以下区别&#xff1a; 核心原理 DPO&#xff1a;基于用户偏好或人类反馈直接优化&#xff0c;核心是对比学习或根据偏好数据调整策略&#xff0c;将…...

Git stash 暂存你的更改(隐藏存储)

一、Git Stash 概述 在开发的时候经常会遇到切换分支时需要你存储当前的更改&#xff0c;如果你暂时不想应用当前更改也不想放弃更改&#xff0c;那么你可以使用 git stash先将其隐藏存储&#xff0c;这样代码就会变成未修改的状态&#xff0c;等解决其他问题后&#xff0c;在…...

负载测试和压力测试的原理分别是什么

负载测试和压力测试是性能测试的两种主要类型&#xff0c;它们的原理和应用场景有所不同。 负载测试&#xff08;Load Testing&#xff09; 原理&#xff1a; 负载测试通过模拟实际用户行为&#xff0c;逐步增加系统负载&#xff0c;观察系统在不同负载下的表现。目的是评估系…...

shell脚本控制——定时运行作业

在使用脚本时&#xff0c;你也许希望脚本能在以后某个你无法亲临现场的时候运行。Linux系统提供了多个在预选时间运行脚本的方法&#xff1a;at命令、cron表以及anacron。每种方法都使用不同的技术来安排脚本的运行时间和频率。接下来将依次介绍这些方法。 1.使用at命令调度作…...

LeetCode 热题 100 回顾

目录 一、哈希部分 1.两数之和 &#xff08;简单&#xff09; 2.字母异位词分组 &#xff08;中等&#xff09; 3.最长连续序列 &#xff08;中等&#xff09; 二、双指针部分 4.移动零 &#xff08;简单&#xff09; 5.盛最多水的容器 &#xff08;中等&#xff09; 6…...

HTML5--网页前端编程(上)

HTML5–网页前端编程(上) 1.网页 (1)网站是根据一定的规则,使用html制作的相关的网页的集合。 网页是网站上的一页,通常是html格式的文件,他要通过浏览器来阅读。网页是网站的基本元素,由图片链接声音文字等元素造成,以.html或.htm后缀结尾的文件称为html文件。 (2…...

气体控制器联动风机,检测到环境出现异常时自动打开风机进行排风;

一、功能&#xff1a;检测到环境出现异常时自动打开风机进行排风&#xff1b; 二、设备&#xff1a; 1.气体控制器主机&#xff1a;温湿度&#xff0c;TVOC等探头的主机&#xff0c;可上报数据&#xff0c;探头监测到异常时&#xff0c;主机会监测到异常可联动风机或声光报警…...

示波器使用指南

耦合方式 在示波器中&#xff0c;耦合方式决定了信号源与示波器输入之间的信号传输方式。具体来说&#xff0c;直流耦合、交流耦合和接地耦合这三种方式有不同的工作原理和应用场景&#xff0c;下面是它们的差异&#xff1a; 1. 直流耦合&#xff08;DC Coupling&#xff09;…...