当前位置: 首页 > news >正文

Pytorch深度学习教程_5_编写第一个神经网络

欢迎来到《pytorch深度学习教程》系列的第五篇!在前面的四篇中,我们已经介绍了Python、numpy及pytorch的基本使用,并在上一个教程中介绍了梯度。今天,我们将探索神经网络,对于神经网络进行概述并进行简单的实践学习

欢迎订阅专栏:

深度学习保姆教程_tRNA做科研的博客-CSDN博客 

神经网络是受人类大脑启发的计算模型,旨在识别数据中的模式。它们是深度学习的核心,推动了各个领域的突破。


目录

1.生物学启示

(1)大脑作为模型

(2)深度学习中的生物学原理

2.生物启发式AI的未来

(1)应用生物学启示:从理论到实践

3.人工神经元

(1)人工神经元的结构

(2)神经元的计算

(3)数学表达式总结 

(4)常见的激活函数:

(5)神经元在神经网络中的作用

(6)常见神经网络 

4.建立第一个神经网络:前馈神经网络

(1)前馈神经网络的结构

(2)前馈神经网络的工作原理

(3)训练前馈神经网络

(4)你的第一个神经网络训练

训练过程

5.结语


1.生物学启示

几个世纪以来,自然界一直是科学家和工程师宝贵的灵感来源。人工智能领域,特别是深度学习,也不例外。通过研究生物系统,研究人员开发了模仿大脑处理信息方式的创新算法和架构。

(1)大脑作为模型

人脑是一个复杂的、相互连接的神经元网络。这种复杂结构启发了人工神经网络的发展。

  • 神经元:大脑的基本构建块,负责处理和传输信息。
  • 突触:神经元之间的连接,促进通信。
  • 神经网络:受大脑结构启发,人工神经网络由相互连接的节点(神经元)组成,用于处理信息。

(2)深度学习中的生物学原理

  • 学习:大脑通过经验学习,调整突触连接。深度学习模型通过反向传播从数据中学习。
  • 层次性:大脑以层次化的方式处理信息。深度神经网络也采用层次化的表示方法。
  • 特征提取:大脑从感官输入中提取特征。卷积神经网络在图像特征提取方面表现出色。
  • 适应性:大脑适应新信息和环境。深度学习模型可以针对特定任务进行微调

2.生物启发式AI的未来

神经科学和人工智能的交叉领域具有巨大的潜力。通过继续探索大脑的机制,我们可以开发出更强大和更智能的AI系统。

(1)应用生物学启示:从理论到实践

了解神经网络的生物学基础是重要的,但其真正的力量在于实际应用。让我们探讨如何利用这些概念来构建现实世界的系统。

卷积神经网络(CNNs):视觉皮层的对应物

  • 图像识别:训练一个CNN将图像分类为不同的类别(例如,猫与狗)。
  • 目标检测:检测并定位图像中的对象。
  • 图像分割:图像的像素级分类。

循环神经网络(RNNs):处理顺序数据

  • 自然语言处理:构建用于文本分类、情感分析和语言翻译的模型。
  • 时间序列分析:基于过去的数据预测未来的值。
  • 语音识别:将音频信号转换为文本。

长短期记忆(LSTM)网络:捕捉长期依赖关系

  • 自然语言处理:处理复杂的语言模式和长期依赖关系。
  • 时间序列预测:基于长期模式预测未来的值。
  • 机器翻译:将一种语言的文本翻译成另一种语言。

 

3.人工神经元

人工神经元是神经网络的基本计算单元。受其生物对应物的启发,这些数学函数处理输入数据,应用变换并产生输出

(1)人工神经元的结构

一个人工神经元由几个组件组成:

  • 输入:输入神经元的数据。
  • 权重:分配给每个输入的数值,表示该输入的重要性。
  • 偏置:加到输入加权和上的常数值。
  • 激活函数:应用于神经元输出的非线性函数。

(2)神经元的计算

神经元的输出分两步计算:

  • 加权和:将每个输入乘以其对应的权重,对结果求和,并加上偏置。
  • 激活:将激活函数应用于加权和的结果。

我们给一个通俗的解释:

import numpy as np
import matplotlib.pyplot as pltdef sigmoid(x):return 1 / (1 + np.exp(-x))# 示例神经元
inputs = [1, 2, 3]
weights = [0.2, 0.3, 0.4]
bias = 0.1# 计算加权和
weighted_sum = np.dot(inputs, weights) + bias# 应用激活函数
output = sigmoid(weighted_sum)# 绘制 Sigmoid 函数
x = np.linspace(-10, 10, 100)
y = sigmoid(x)plt.figure(figsize=(8, 4))
plt.plot(x, y, label='Sigmoid Function')
plt.scatter([weighted_sum], [output], color='red', label=f'Output ({output:.2f})')
plt.axvline(x=weighted_sum, color='gray', linestyle='--', label=f'Weighted Sum ({weighted_sum:.2f})')
plt.xlabel('x')
plt.ylabel('Sigmoid(x)')
plt.title('Sigmoid')
plt.legend()
plt.grid(True)
plt.show()

然后我们详细看看发生了什么,这是什么:

-1输入与权重

  • 输入:inputs = [1, 2, 3] 表示神经元接收到三个输入值。
  • 权重:weights = [0.2, 0.3, 0.4] 表示每个输入对应的权重。

 -2计算加权和

  • 点积运算:np.dot(inputs, weights) 计算输入向量与权重向量的点积,即 1×0.2+2×0.3+3×0.41×0.2+2×0.3+3×0.4。具体计算:1×0.2=0.2, 2×0.3=0.6, 3×0.4=1.2
  • 点积结果:0.2+0.6+1.2=2.0
  • 加上偏置:bias = 0.1,所以加权和为 2.0+0.1=2.1。

 -3应用激活函数

 

(3)数学表达式总结 

整个神经元的计算过程可以用以下数学表达式表示: 

 

(4)常见的激活函数:

  • Sigmoid:输出介于0和1之间的值。
  • ReLU(修正线性单元):输出输入与0的最大值。
  • Tanh:输出介于-1和1之间的值。
  • Softmax:用于分类任务,输出每个类别的概率。

(5)神经元在神经网络中的作用

神经元在神经网络中按层组织:

  • 输入层:接收输入数据。
  • 隐藏层:通过多层神经元处理信息。
  • 输出层:产生最终输出。

注意事项

  • 梯度消失问题:可能出现在深层网络中,使训练变得困难。
  • 过拟合:模型可能变得过于复杂,在新数据上的表现不佳。
  • 计算成本:训练大型神经网络可能需要大量的计算资源。

(6)常见神经网络 

1.前馈神经网络

在前馈神经网络中,信息单向流动,从输入层流向输出层,没有循环或周期。

  • 全连接层:一层中的每个神经元都连接到下一层中的每个神经元。
  • 激活函数:应用于每个神经元的输出以引入非线性。

2.深度神经网络

深度神经网络具有多个隐藏层,允许它们学习复杂的模式。

  • 深度:隐藏层的数量。
  • 宽度:每层中神经元的数量。

3.循环神经网络(RNNs)

RNNs在相同层的神经元之间引入连接,形成周期。这使得它们能够处理序列数据。

  • 梯度消失问题:RNNs在处理长期依赖时可能会遇到困难。
  • LSTM和GRU:解决梯度消失问题的RNN变体。

4.卷积神经网络(CNNs)

CNNs专为处理网格状数据(如图像)而设计。

  • 卷积层:应用滤波器以从输入数据中提取特征。
  • 池化层:降低维度同时保留重要信息。

挑战和注意事项

  • 过拟合:神经网络容易过拟合,需要正则化技术
  • 超参数调整:找到最佳超参数对于性能至关重要。

4.建立第一个神经网络:前馈神经网络

前馈神经网络(FNNs)是最简单的人工神经网络类型。信息单向流动,从输入到输出,不形成循环。它们是理解更复杂架构的基础。

(1)前馈神经网络的结构

一个典型的FNN由以下部分组成:

  • 输入层:从外部世界接收数据。
  • 隐藏层:通过多层神经元处理信息。
  • 输出层:产生最终结果。

一层中的神经元与下一层中的每个神经元相连,形成全连接网络。

(2)前馈神经网络的工作原理

  • 输入:数据被输入到输入层。
  • 传播:信息通过隐藏层传递,每个神经元应用其激活函数
  • 输出:最终层产生输出,可以是分类、回归或其他所需的结果。

(3)训练前馈神经网络

  • 反向传播:一种基于预测输出与实际输出之间的误差来调整权重和偏置的算法。
  • 损失函数:衡量预测值与实际值之间的差异。
  • 优化器:更新网络的参数以最小化损失函数。

(4)你的第一个神经网络训练

我们用一个简单的二分类神经网络的例子进行学习:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np# 设置随机种子以确保结果可复现
torch.manual_seed(42)# 生成一些简单的二维数据用于分类
def generate_data(num_samples=100):# 生成两个类别的数据,每个类别50个样本data = []labels = []for i in range(num_samples // 2):# 第一类:左下角data.append([np.random.rand() * 0.6 - 0.3, np.random.rand() * 0.6 - 0.3])labels.append(0)# 第二类:右上角data.append([np.random.rand() * 0.6 + 0.2, np.random.rand() * 0.6 + 0.2])labels.append(1)return torch.tensor(data, dtype=torch.float32), torch.tensor(labels, dtype=torch.long)# 定义一个简单的前馈神经网络
class SimpleFFNN(nn.Module):def __init__(self):super(SimpleFFNN, self).__init__()# 输入层到隐藏层,输入特征为2,隐藏单元为10self.fc1 = nn.Linear(2, 10)# 隐藏层到输出层,输出类别为2self.fc2 = nn.Linear(10, 2)def forward(self, x):# 使用ReLU激活函数x = torch.relu(self.fc1(x))# 输出层不使用激活函数,因为后续会使用交叉熵损失函数x = self.fc2(x)return x# 初始化模型、损失函数和优化器
model = SimpleFFNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)# 生成数据
X, y = generate_data()# 训练模型
num_epochs = 100
losses = []for epoch in range(num_epochs):optimizer.zero_grad()  # 清零梯度outputs = model(X)      # 前向传播loss = criterion(outputs, y)  # 计算损失loss.backward()         # 反向传播optimizer.step()        # 更新参数losses.append(loss.item())if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 绘制损失曲线
plt.figure(figsize=(8, 4))
plt.plot(range(1, num_epochs + 1), losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.legend()
plt.show()# 可视化分类结果
def plot_decision_boundary(model, X, y):# 创建网格h = 0.02x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, h),np.arange(y_min, y_max, h))grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)with torch.no_grad():Z = model(grid)Z = torch.argmax(Z, dim=1).reshape(xx.shape)plt.contourf(xx, yy, Z, alpha=0.3, cmap=plt.cm.Paired)plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired, edgecolors='k')# 绘制决策边界
plt.figure(figsize=(8, 6))
plot_decision_boundary(model, X.numpy(), y.numpy())
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Decision Boundary')
plt.show()

那么这个神经网络做了什么事情呢?我们来仔细剖析

网络结构

  1. 输入层:接收二维特征数据(每个样本有两个特征)。

  2. 隐藏层:包含10个神经元。使用ReLU激活函数,为网络引入非线性特性,使其能够学习更复杂的模式。

  3. 输出层:包含2个神经元,对应两个类别。不使用激活函数,因为后续使用了交叉熵损失函数,它内部已经包含了softmax操作。

训练过程

  1. 数据生成:生成了两类二维数据,每类50个样本,分别位于二维空间的左下角和右上角。

  2. 损失函数与优化器

    使用交叉熵损失函数(CrossEntropyLoss)来衡量预测值与真实标签之间的差异。使用随机梯度下降(SGD)优化器,学习率为0.1,用于更新网络参数以最小化损失。
  3. 训练步骤

    进行100个训练周期(epochs)。每个周期执行前向传播、计算损失、反向传播和参数更新。每10个周期打印一次当前的损失值,以便监控训练进度。

5.结语

我们对于神经网络的基本结构和神经元都有了基本的认识,也编写了我们第一个神经网络!这是一个非常值得纪念的里程碑

相关文章:

Pytorch深度学习教程_5_编写第一个神经网络

欢迎来到《pytorch深度学习教程》系列的第五篇!在前面的四篇中,我们已经介绍了Python、numpy及pytorch的基本使用,并在上一个教程中介绍了梯度。今天,我们将探索神经网络,对于神经网络进行概述并进行简单的实践学习 欢…...

ImportError: cannot import name ‘FixtureDef‘ from ‘pytest‘

错误信息表明 pytest 在尝试导入 FixtureDef 时出现了问题。通常是由于 pytest 版本不兼容 或 插件版本冲突 引起的。以下是详细的排查步骤和解决方案: 1. 检查 pytest 版本 首先,确认当前安装的 pytest 版本。某些插件可能需要特定版本的 pytest 才能…...

改BUG:Mock测试的时候,when失效

问题再现: 这里我写了一测试用户注册接口的测试类,并通过when模拟下层的服务,但实际上when并没有奏效,还是走了真实的service层的逻辑。 package cn.ac.evo.review.test;import cn.ac.evo.review.user.UserMainApplication; imp…...

【自动化脚本工具】AutoHotkey (Windows)

目录 1. 介绍AutoHotkey2. 功能脚本集锦2.1 桌面键盘显示 1. 介绍AutoHotkey 支持Windows安装使用,下载地址为:https://www.autohotkey.com/ 2. 功能脚本集锦 2.1 桌面键盘显示 便于练习键盘盲打 脚本地址:https://blog.csdn.net/weixin_6…...

专题--Linux体系

Linux体系结构相关| ProcessOn免费在线作图,在线流程图,在线思维导图 ProcessOn是一个在线协作绘图平台,为用户提供强大、易用的作图工具!支持在线创作流程图、思维导图、组织结构图、网络拓扑图、BPMN、UML图、UI界面原型设计、iOS界面原型设计等。同时…...

【DeepSeek】Mac m1电脑部署DeepSeek

一、电脑配置 个人电脑配置 二、安装ollama 简介:Ollama 是一个强大的开源框架,是一个为本地运行大型语言模型而设计的工具,它帮助用户快速在本地运行大模型,通过简单的安装指令,可以让用户执行一条命令就在本地运…...

Spring AI + Ollama 实现调用DeepSeek-R1模型API

一、前言 随着人工智能技术的飞速发展,大语言模型(LLM)在各个领域的应用越来越广泛。DeepSeek 作为一款备受瞩目的国产大语言模型,凭借其强大的自然语言处理能力和丰富的知识储备,迅速成为业界关注的焦点。无论是文本生…...

如何在本地和服务器新建Redis用户和密码

文章目录 一. Redis安装二. 新建Redis用户,测试连接2.1 本地数据库2.2 线上数据库2.2.1 安装和配置2.2.2 测试连接 三. 配置四. 分布式 一. Redis安装 Redis安装 可以设置开机自动启动,也可以在去查看系统服务,按[win R],输入命…...

jmeter接口测试(一)

一、什么是接口测试?为什么要做接口测试? 接口测试:就是测试项目和项目之间,模块和模块之间,组件和组件之间的数据交互和权限鉴定(鉴权)。 前后端分离:前后端联调。mock模拟&#x…...

Java-11

淘天集团2025届春季校园招聘在线笔试-研发 1。设有一个顺序共享栈storageArray[70],其中栈X的栈顶指针top1的初值为-1,栈Y的栈顶指针top2的初值为70,通过不断进行入栈操作,直到storageArray数组已满,此时top1 top2 …...

js中常用方法整理

数据类型 typeOf()Number()parseInt()parseFloat()- * / %检测数据类型转换为数字转换为整数类型转换为浮点类型非加法的数字运算toString()Boolean()String()转换为字符串,不能转换undefined/null字符串拼接转换为布尔类型转换为字符串、所有…...

umi react+antd 判断渲染消息提示、input搜索、多选按钮组

记得map里返回的每层遍历结构都要带上key(图里没加,最近在接手react,熟悉中......

Day15-后端Web实战-登录认证——会话技术JWT令牌过滤器拦截器

目录 登录认证1. 登录功能1.1 需求1.2 接口文档1.3 思路分析1.4 功能开发1.5 测试 2. 登录校验2.1 问题分析2.2 会话技术2.2.1 会话技术介绍2.2.2 会话跟踪方案2.2.2.1 方案一 - Cookie2.2.2.2 方案二 - Session2.2.2.3 方案三 - 令牌技术 2.3 JWT令牌2.3.1 介绍2.3.2 生成和校…...

【嵌入式常用工具】Srecord使用

文件格式 -Intel 表示hex格式-Motorola 表示S19格式-BINary 表示bin格式 截取指定地址段 srec_cat input.s19 -Motorola -crop 0x80010000 0x80380000 -output output.s19 -Motorola -address-length4填充指定地址段 srec_cat input.s19 -Motorola -fill 0xFF 0x100 0x200 …...

SwiftUI基础组件之HStack、VStack、ZStack详解

文章目录 引言一、HStack(水平堆栈)1.1 基本概念1.2 基本创建1.3 常用属性1.3.1 spacing1.3.2 alignment 二、VStack(垂直堆栈)2.1 基本概念2.2 基本创建2.3 常用属性2.3.1 spacing2.3.2 alignment 三、ZStack(深度堆栈…...

第2章 深入理解Thread构造函数

Thread的构造函数。 2.1 线程的命名 在构造一个Thread时可以为其命名。 2.1.1 线程的默认命名 下面构造函数中,并没有为线程命名。 Thread() Thread(Runnable target) Thread(ThreadGroup group, Runnable target)打开源码会看到 public Thread(Runnable targe…...

PLC扫描周期和工作原理

可编程逻辑控制器(PLC)的运行原理和扫描周期是其实现工业自动化的核心机制。以下从运行原理、扫描周期组成、关键特性及优化方向等方面进行详细阐述: 一、PLC运行原理 PLC采用**循环扫描(Cyclic Scan)**的工作模式&am…...

玩转Docker | 使用Docker部署本地自托管reference速查表工具

玩转Docker | 使用Docker部署本地自托管reference速查表工具 前言一、Reference介绍Reference简介主要特点二、系统要求环境要求环境检查Docker版本检查检查操作系统版本三、部署reference服务下载镜像创建容器检查容器状态检查服务端口安全设置四、访问reference应用五、测试与…...

MySQL数据库入门到大蛇尚硅谷宋红康老师笔记 高级篇 part 2

第02章_MySQL的数据目录 1. MySQL8的主要目录结构 1.1 数据库文件的存放路径 MySQL数据库文件的存放路径:/var/lib/mysql/ MySQL服务器程序在启动时会到文件系统的某个目录下加载一些文件,之后在运行过程中产生的数据也都会存储到这个目录下的某些文件…...

跟着 Lua 5.1 官方参考文档学习 Lua (3)

文章目录 2.5 – Expressions2.5.1 – Arithmetic Operators2.5.2 – Relational Operators2.5.3 – Logical Operators2.5.4 – Concatenation2.5.5 – The Length Operator2.5.6 – Precedence2.5.7 – Table Constructors2.5.8 – Function Calls2.5.9 – Function Definiti…...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

什么是EULA和DPA

文章目录 EULA&#xff08;End User License Agreement&#xff09;DPA&#xff08;Data Protection Agreement&#xff09;一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA&#xff08;End User License Agreement&#xff09; 定义&#xff1a; EULA即…...

汇编常见指令

汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX&#xff08;不访问内存&#xff09;XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

3403. 从盒子中找出字典序最大的字符串 I

3403. 从盒子中找出字典序最大的字符串 I 题目链接&#xff1a;3403. 从盒子中找出字典序最大的字符串 I 代码如下&#xff1a; class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...

QT: `long long` 类型转换为 `QString` 2025.6.5

在 Qt 中&#xff0c;将 long long 类型转换为 QString 可以通过以下两种常用方法实现&#xff1a; 方法 1&#xff1a;使用 QString::number() 直接调用 QString 的静态方法 number()&#xff0c;将数值转换为字符串&#xff1a; long long value 1234567890123456789LL; …...

LeetCode - 199. 二叉树的右视图

题目 199. 二叉树的右视图 - 力扣&#xff08;LeetCode&#xff09; 思路 右视图是指从树的右侧看&#xff0c;对于每一层&#xff0c;只能看到该层最右边的节点。实现思路是&#xff1a; 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...

Yolov8 目标检测蒸馏学习记录

yolov8系列模型蒸馏基本流程&#xff0c;代码下载&#xff1a;这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中&#xff0c;**知识蒸馏&#xff08;Knowledge Distillation&#xff09;**被广泛应用&#xff0c;作为提升模型…...

Caliper 配置文件解析:fisco-bcos.json

config.yaml 文件 config.yaml 是 Caliper 的主配置文件,通常包含以下内容: test:name: fisco-bcos-test # 测试名称description: Performance test of FISCO-BCOS # 测试描述workers:type: local # 工作进程类型number: 5 # 工作进程数量monitor:type: - docker- pro…...

uniapp 实现腾讯云IM群文件上传下载功能

UniApp 集成腾讯云IM实现群文件上传下载功能全攻略 一、功能背景与技术选型 在团队协作场景中&#xff0c;群文件共享是核心需求之一。本文将介绍如何基于腾讯云IMCOS&#xff0c;在uniapp中实现&#xff1a; 群内文件上传/下载文件元数据管理下载进度追踪跨平台文件预览 二…...

用鸿蒙HarmonyOS5实现中国象棋小游戏的过程

下面是一个基于鸿蒙OS (HarmonyOS) 的中国象棋小游戏的实现代码。这个实现使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chinesechess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├──…...