当前位置: 首页 > news >正文

PyTorch之完整的神经网络模型训练

简单的示例:

在PyTorch中,可以使用nn.Module类来定义神经网络模型。以下是一个示例的神经网络模型定义的代码:

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 定义神经网络的层和参数self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc1 = nn.Linear(32 * 14 * 14, 128)self.fc2 = nn.Linear(128, 10)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.maxpool(x)x = x.view(x.size(0), -1)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.softmax(x)return x

在上面的示例中,定义了一个名为MyModel的神经网络模型,继承自nn.Module类。在__init__方法中,我们定义了模型的层和参数。具体来说:

  • 代码定义了一个卷积层,输入通道数为1,输出通道数为32,卷积核大小为3x3,步长为1,填充为1。
  • 定义了一个ReLU激活函数,用于在卷积层之后引入非线性性质。
  • 定义了一个最大池化层,池化核大小为2x2,步长为2。
  • 定义了一个全连接层,输入大小为32x14x14(经过卷积和池化后的特征图大小),输出大小为128。
  • 定义了另一个全连接层,输入大小为128,输出大小为10。
  • 定义了一个softmax函数,用于将模型的输出转换为概率分布。

forward方法中,定义了模型的前向传播过程。具体来说:

  • x = self.conv1(x): 将输入张量传递给卷积层进行卷积操作。
  • x = self.relu(x): 将卷积层的输出通过ReLU激活函数进行非线性变换。
  • x = self.maxpool(x): 将ReLU激活后的特征图进行最大池化操作。
  • x = x.view(x.size(0), -1): 将池化后的特征图展平为一维,以适应全连接层的输入要求。
  • x = self.fc1(x): 将展平后的特征向量传递给第一个全连接层。
  • x = self.relu(x): 将第一个全连接层的输出通过ReLU激活函数进行非线性变换。
  • x = self.fc2(x): 将第一个全连接层的输出传递给第二个全连接层。
  • x = self.softmax(x): 将第二个全连接层的输出通过softmax函数进行归一化,得到每个类别的概率分布。

这个示例展示了一个简单的卷积神经网络模型,适用于处理单通道的图像数据,并输出10个类别的分类结果。可以根据自己的需求和数据特点来定义和修改神经网络模型。

接下来将用于实际的数据集进行训练:

以下是基于CIFAR10数据集的神经网络训练模型:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from nn_mode import *#准备数据集
train_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=True,transform=torchvision.transforms.ToTensor())
test_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=False,transform=torchvision.transforms.ToTensor())
#输出数据集的长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print(train_data_size)
print(test_data_size)
#加载数据集
train_loader=DataLoader(dataset=train_data,batch_size=64)
test_loader=DataLoader(dataset=test_data,batch_size=64)
#创建神经网络
sjnet=Sjnet()#损失函数
loss_fn=nn.CrossEntropyLoss()
#优化器
learn_lr=0.01#便于修改
YHQ=torch.optim.SGD(sjnet.parameters(),lr=learn_lr)#设置训练网络的参数
train_step=0#训练次数
test_step=0#测试次数
epoch=10#训练轮数writer=SummaryWriter('wanzheng_logs')for i in range(epoch):print("第{}轮训练".format(i+1))#开始训练for data in train_loader:imgs,targets=dataoutputs=sjnet(imgs)loss=loss_fn(outputs,targets)#优化器YHQ.zero_grad()  # 将神经网络的梯度置零,以准备进行反向传播loss.backward()  # 执行反向传播,计算神经网络中各个参数的梯度YHQ.step()  # 调用优化器的step()方法,根据计算得到的梯度更新神经网络的参数,完成一次参数更新train_step =train_step+1if train_step%100==0:print('训练次数为:{},loss为:{}'.format(train_step,loss))writer.add_scalar('train_loss',loss,train_step)#开始测试total_loss=0with torch.no_grad():#上下文管理器,用于指示在接下来的代码块中不计算梯度。for data in test_loader:imgs,targets=dataoutputs = sjnet(imgs)loss = loss_fn(outputs, targets)#使用损失函数 loss_fn 计算预测输出与目标之间的损失。total_loss=total_loss+loss#将当前样本的损失加到总损失上,用于累积所有样本的损失。print('整体测试集上的loss:{}'.format(total_loss))writer.add_scalar('test_loss', total_loss, test_step)test_step = test_step+1torch.save(sjnet,'sjnet_{}.pth'.format(i))print("模型已保存!")writer.close()

 其神经网络训练以及测试时的损失值使用TensorBoard进行展示,如图所示:

相关文章:

PyTorch之完整的神经网络模型训练

简单的示例: 在PyTorch中,可以使用nn.Module类来定义神经网络模型。以下是一个示例的神经网络模型定义的代码: import torch import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 定义神经…...

基于神经网络的偏微分方程求解器再度取得突破,北大字节的研究成果入选Nature子刊

目录 一.引言:神经网络与偏微分方程 二.如何基于神经网络求解偏微分方程 1.简要概述 2.基于神经网络求解偏微分方程的三大方向 2.1数据驱动 基于CNN 基于其他网络 2.2物理约束 PINN 基于 PINN 可测量标签数据 2.3物理驱动(纯物理约束) 全连接神经网路(FC-NN) CN…...

Linux的基本权限

一、对shell的浅显认识 shell是操作系统下的一个外壳程序,无论是Linux操作系统,还是Windows操作系统,用户都不会直接对操作系统本身直接进行操作,需要通过一个外壳程序去间接的进行各种操作 在Linux的shell外壳就是命令行&#…...

指纹加密U盘/指纹KEY方案——采用金融级安全芯片 ACH512

方案概述 指纹加密U盘解决方案可实现指纹算法处理、数据安全加密、数据高速存取(EMMC/TF卡/NandFlash),可有效保护用户数据安全。 方案特点 • 采用金融级安全芯片 ACH512 • 存储介质:EMMC、TF卡、NandFlash • 支持全系列国密…...

Cloud-Sleuth分布式链路追踪(服务跟踪)

简介 在微服务框架中,一个由客户端发起的请求在后端系统中会经过多个不同的服务节点调用来协同产生最后的请求结果,每一个前端请求都会形成一条复杂的分布式服务调用链路,链路中的任何一环出现高延时或错误都会引起整个请求最后的失败 GitHub - spring-cloud/spring-cloud-sl…...

flink重温笔记(十四): flink 高级特性和新特性(3)——数据类型及 Avro 序列化

Flink学习笔记 前言:今天是学习 flink 的第 14 天啦!学习了 flink 高级特性和新特性之数据类型及 avro 序列化,主要是解决大数据领域数据规范化写入和规范化读取的问题,avro 数据结构可以节约存储空间,本文中结合企业真…...

python75-Python的函数参数,关键字(keyword)参数

在定义Python函数时可定义形参(形式参数的意思)这些形参的值要等到调用时才能确定下来,由函数的调用者负责为形参传入参数值。简单来说,就是谁调用函数,谁负责传入参数值。 关键字(keyword)参数 Python函数的参数名不是无意义的&#xff0c…...

Java宝典-抽象类和接口

目录 1. 抽象类1.1 抽象类的概念1.2 抽象类的语法1.3 抽象类的特点 2. 接口2.1 接口的概念2.2 接口的语法2.3 接口的特点2.4 实现多个接口2.5 接口的继承 3. 接口使用案例 铁汁们好,今天我们学习抽象类和接口~ 1. 抽象类 1.1 抽象类的概念 什么是抽象类?在面向对象中,如果一…...

6. Gin集成redis

文章目录 一:连接Redis二:基本使用三:字符串四:列表五:哈希六:Set七:管道八、事务九:示例 代码地址:https://gitee.com/lymgoforIT/golang-trick/tree/master/14-go-redi…...

DxO PureRAW:赋予RAW图像生命,打造非凡视觉体验 mac/win版

DxO PureRAW 是一款专为RAW图像处理而设计的软件,旨在帮助摄影师充分利用RAW格式的优势,实现更加纯净、细腻的图像效果。该软件凭借其强大的功能和易于使用的界面,成为了RAW图像处理领域的佼佼者。 DxO PureRAW 软件获取 首先,Dx…...

【MySQL | 第四篇】区分SQL语句的书写和执行顺序

文章目录 4.区分SQL语句的书写和执行顺序4.1书写顺序4.2执行顺序4.3总结4.4扩充&#xff1a;辨别having与where的异同&#xff1f;4.5聚合查询 4.区分SQL语句的书写和执行顺序 注意&#xff1a;SQL 语句的书写顺序与执行顺序不是一致的 4.1书写顺序 SELECT <字段名> …...

服务器又被挖矿记录

写在前面 23年11月的时候我写过一篇记录服务器被挖矿的情况&#xff0c;点我查看。当时是在桌面看到了bash进程CPU占用异常发现了服务器被挖矿。 而过了几个月没想到又被攻击&#xff0c;这次比上次攻击手段要更高明点&#xff0c;在这记录下吧。 发现过程 服务器用的是4090…...

嵌入式学习day34 网络

TCP包头: 1.序号:发送端发送数据包的编号 2.确认号:已经确认接收到的数据的编号(只有当ACK为1时,确认号才有用) TCP为什么安全可靠: 1.在通信前建立三次握手连接 SYN SYNACK ACK 2.在通信过程中通过序列号和确认号保障数据传输的完整性 本次发送序列号:上次…...

欧科云链:角力Web3.0,香港如何为合规设线?

在香港拥抱Web3.0的过程中,以欧科云链为代表的合规科技企业将凸显更大重要性。 ——据香港商报网报道 据香港明报、商报等媒体报道&#xff0c;港区全国政协兼香港选委界立法会议员吴杰庄在日前召开的全国两会上提出在大湾区建设国际中小企业创新Web3融资平台等提案&#xff0…...

Android SDK2 (实操三个小目标)

书接上回&#xff1a;Android SDK 1&#xff08;概览&#xff09;-CSDN博客 今天讲讲三个实际练手内容&#xff0c;用的是瑞星微的sdk。 1 实操编译Android.bp 首先还是感叹下&#xff0c;现在的系统真的越搞越复杂&#xff0c;最早只有gcc&#xff0c;后面多了make&#xf…...

数字编码与字符编码:解锁编程世界的基石

在计算机的世界里&#xff0c;一切信息都是以数字的形式存在。但是&#xff0c;你有没有想过&#xff0c;我们是如何在这个由0和1构成的数字世界中表示复杂的信息&#xff0c;如文本、图像和声音的呢&#xff1f;本篇文章将带你深入探索数字编码与字符编码的奥秘&#xff0c;它…...

C语言-写一个简单的Web服务器(一)

基于TCP的web服务器 概述 C语言可以干大事&#xff0c;我们基于C语言可以完成一个简易的Web服务器。当你能够自行完成web服务器&#xff0c;你会对C语言有更深入的理解。对于网络编程&#xff0c;字符串的使用&#xff0c;文件使用等等都会有很大的提高。 关于网络的TCP协议在…...

MySQL底层原理

1. 请解释MySQL的逻辑架构和物理架构。 MySQL的逻辑架构和物理架构涉及到多个层面&#xff0c;包括网络连接、服务处理、存储引擎以及数据存储等部分。具体如下&#xff1a; 逻辑架构&#xff1a; 连接层&#xff08;Connection Layer&#xff09;&#xff1a;客户端通过TCP…...

复盘-word

word-大学生网络创业交流会 设置段落&#xff0c;段后行距才有分 word-选中左边几行字进行操作 按住alt键进行选中 word复制excel随excel改变&#xff08;选择性粘贴&#xff09; 页边距为普通页边距定义 ##### word 在内容控件里面填文字&#xff08;调属性&#xff09…...

Vue中的组件:构建现代Web应用的基石

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

DBAPI如何优雅的获取单条数据

API如何优雅的获取单条数据 案例一 对于查询类API&#xff0c;查询的是单条数据&#xff0c;比如根据主键ID查询用户信息&#xff0c;sql如下&#xff1a; select id, name, age from user where id #{id}API默认返回的数据格式是多条的&#xff0c;如下&#xff1a; {&qu…...

C#中的CLR属性、依赖属性与附加属性

CLR属性的主要特征 封装性&#xff1a; 隐藏字段的实现细节 提供对字段的受控访问 访问控制&#xff1a; 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性&#xff1a; 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑&#xff1a; 可以…...

android13 app的触摸问题定位分析流程

一、知识点 一般来说,触摸问题都是app层面出问题,我们可以在ViewRootImpl.java添加log的方式定位;如果是touchableRegion的计算问题,就会相对比较麻烦了,需要通过adb shell dumpsys input > input.log指令,且通过打印堆栈的方式,逐步定位问题,并找到修改方案。 问题…...

【从零开始学习JVM | 第四篇】类加载器和双亲委派机制(高频面试题)

前言&#xff1a; 双亲委派机制对于面试这块来说非常重要&#xff0c;在实际开发中也是经常遇见需要打破双亲委派的需求&#xff0c;今天我们一起来探索一下什么是双亲委派机制&#xff0c;在此之前我们先介绍一下类的加载器。 目录 ​编辑 前言&#xff1a; 类加载器 1. …...

DBLP数据库是什么?

DBLP&#xff08;Digital Bibliography & Library Project&#xff09;Computer Science Bibliography是全球著名的计算机科学出版物的开放书目数据库。DBLP所收录的期刊和会议论文质量较高&#xff0c;数据库文献更新速度很快&#xff0c;很好地反映了国际计算机科学学术研…...

【WebSocket】SpringBoot项目中使用WebSocket

1. 导入坐标 如果springboot父工程没有加入websocket的起步依赖&#xff0c;添加它的坐标的时候需要带上版本号。 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId> </dep…...

Java详解LeetCode 热题 100(26):LeetCode 142. 环形链表 II(Linked List Cycle II)详解

文章目录 1. 题目描述1.1 链表节点定义 2. 理解题目2.1 问题可视化2.2 核心挑战 3. 解法一&#xff1a;HashSet 标记访问法3.1 算法思路3.2 Java代码实现3.3 详细执行过程演示3.4 执行结果示例3.5 复杂度分析3.6 优缺点分析 4. 解法二&#xff1a;Floyd 快慢指针法&#xff08;…...

2025年低延迟业务DDoS防护全攻略:高可用架构与实战方案

一、延迟敏感行业面临的DDoS攻击新挑战 2025年&#xff0c;金融交易、实时竞技游戏、工业物联网等低延迟业务成为DDoS攻击的首要目标。攻击呈现三大特征&#xff1a; AI驱动的自适应攻击&#xff1a;攻击流量模拟真实用户行为&#xff0c;差异率低至0.5%&#xff0c;传统规则引…...

嵌入式面试常问问题

以下内容面向嵌入式/系统方向的初学者与面试备考者,全面梳理了以下几大板块,并在每个板块末尾列出常见的面试问答思路,帮助你既能夯实基础,又能应对面试挑战。 一、TCP/IP 协议 1.1 TCP/IP 五层模型概述 链路层(Link Layer) 包括网卡驱动、以太网、Wi‑Fi、PPP 等。负责…...