当前位置: 首页 > news >正文

深度学习笔记1:神经网络与模型训练过程

参考博客:PyTorch深度学习实战(1)——神经网络与模型训练过程详解_pytorch 实战-CSDN博客

人工神经网络

ANN:张量及数学运算的集合,排列方式近似于松散的人脑神经元排列

组成

1)输入层

2)隐藏层(中间层):连接输入层和输出层,在输入数据上执行转换,此外隐藏层利用神经元将输入值修改为更高、更低维的值,通过修改中间结点的激活函数可以实现复杂表示函数

3)输出层

训练

本质就是通过重复前向传播和后向传播两个关键步骤来调整神经网络的参数

前向传播--》输入经过隐藏层的到输出结果,第一次正向传播 权重初始化 计算预测值

后向传播--》根据误差相应调整权重来减小误差 修正参数

神经网络重复正向传播与反向传播以预测输出,指导获得令误差较小的权重为止

前向传播

1)输入值乘以权重计算隐藏层值

2)计算激活值

3)在每个神经元上重复前两个步骤,直到输出层

4)计算loss

封装为函数

def forward(inputs,outputs,weights):#如果是首次迭代,随机初始化pre_hidden = np.dot(inputs,weights[0]) + weights[1]#向量点积hidden = 1/(1+np.exp(-pre_hidden))#sigmoid激活函数pred_out = np.dot(hidden,weights[2]) + weights[3]#计算输出mse = np.mean(np.square(pred_out - outputs))#lossreturn mse

反向传播

与前向传播相反,利用从前向传播中计算的损失值,以最小化损失值为目标更新网络权重

1)每次对神经网络中的每个权重进行少量修改

2)测量权重变化时的损失变化 ---》损失值关于权重的梯度

3)计算-α δL/δW 更新权重 α为学习率

如果改变权重损失变化大那么就大幅更新权重

否则 小幅跟新权重

在整个数据集上执行n次前向传播以及后向传播,表示模型进行了n个epoch的训练 ,执行一次算一个epoch

学习率

有助于构建更稳定的算法

梯度下降

更新权重以减小误差值的整个过程

SGD是将误差最小化的一种方法:随计算则数据中的训练数据样本,并根据该样本做出决策

实现梯度下降算法

1)定义前馈神经网络并计算均方误差值

2)为每个权重和偏执项增加一个非常小的量0.001,并针对每个权重和偏差的更新计算一个mse

from copy import deepcopy 
import numpy as np
#创建update_weights函数,执行梯度下降来更新权重
def update_weights(inputs, outputs, weights, lr):#使用deepcopy可以确保处理多个权重副本,不会影响实际权重original_weights = deepcopy(weights)temp_weights  = deepcopy(weights)updated_weights = deepcopy(weights)original_loss = forward(inputs, outputs, original_weigts)#遍历网络的所有层for i, layer in enumerate(orignial_weights):#循环遍历每个参数列表的所有参数 共四个参数列表,前两个表示输入连接到隐藏层的权重和偏置项参数#另外两个表示链接隐藏层和输出层的偏置参数for index, weight in np.ndenumerate(layer):temp_weights = deepcopy(weights)#原始权重集temp_weights[i][index] += 0.0001#增加很小的值权重更新_loss_plus = forward(inputs, outputs, temp_weights)#更新损失grad = (_loss_plus - original_loss)/0.0001updated_weights[i][index] -= grad * lr #利用损失变化更新权重,使用学习率领权重变化更稳定return updated_weights, original_loss#返回更新后的权重

合并前向传播和反向传播

构建一个带有隐藏层的简单神经网络,

1)输入连接到具有三个神经元的隐藏层

2)隐藏层连接到具有一个神经元的输出层

#导入相关库
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
x = np.array([[1,1]])
y = np.array(([[0]]))
#随机初始化权重和偏置值
W = [np.array([[-0.05, 0.3793],[-0.5820, -0.5204],[-0.2723, 0.1896]], dtype=np.float32).T, np.array([-0.0140, 0.5607, -0.0628], dtype=np.float32), #隐藏层两个偏置值np.array([[ 0.1528, -0.1745, -0.1135]], dtype=np.float32).T, np.array([-0.5516], dtype=np.float32)#输出层一个偏置值
]#在一百个epoch内执行前向传播和反向传播,使用之前定义的forward和
def forward(inputs,outputs,weights):#如果是首次迭代,随机初始化pre_hidden = np.dot(inputs,weights[0]) + weights[1]#向量点积hidden = 1/(1+np.exp(-pre_hidden))#sigmoid激活函数pred_out = np.dot(hidden,weights[2]) + weights[3]#计算输出mse = np.mean(np.square(pred_out - outputs))#lossreturn msedef update_weights(inputs, outputs, weights, lr):#使用deepcopy可以确保处理多个权重副本,不会影响实际权重original_weights = deepcopy(weights)temp_weights  = deepcopy(weights)updated_weights = deepcopy(weights)original_loss = forward(inputs, outputs, original_weights)#遍历网络的所有层for i, layer in enumerate(original_weights):#循环遍历每个参数列表的所有参数 共四个参数列表,前两个表示输入连接到隐藏层的权重和偏置项参数#另外两个表示链接隐藏层和输出层的偏置参数for index, weight in np.ndenumerate(layer):temp_weights = deepcopy(weights)#原始权重集temp_weights[i][index] += 0.0001#增加很小的值权重更新_loss_plus = forward(inputs, outputs, temp_weights)#更新损失grad = (_loss_plus - original_loss)/0.0001updated_weights[i][index] -= grad * lr #利用损失变化更新权重,使用学习率领权重变化更稳定return updated_weights, original_loss#返回更新后的权重#绘制损失值
losses = []
for  epoch in range(100):W, loss = update_weights(x, y, W, 0.01)losses.append(loss)
plt.plot(losses)
plt.title('loss over increasing number of  epochs')
plt.xlabel('epochs')
plt.ylabel('loss value')
plt.show()
print(W)
#获取更新后的权值之后通过将输出传递给网络对输入进行预测计算输出值
pre_hidden = np.dot(x, W[0]) + W[1]
hidden = 1/(1+np.exp(-pre_hidden))
pre_out = np.dot(hidden, W[2]) + W[3]
print(pre_out)

注:在服务器上运行时im.show()无法展示图片

总结

训练神经网络主要是通过重复两个关键步骤,及用给定的学习率进行前向传播和反向传播,最终得到最佳权重

相关文章:

深度学习笔记1:神经网络与模型训练过程

参考博客:PyTorch深度学习实战(1)——神经网络与模型训练过程详解_pytorch 实战-CSDN博客 人工神经网络 ANN:张量及数学运算的集合,排列方式近似于松散的人脑神经元排列 组成 1)输入层 2)隐…...

什么是 DevOps 自动化?

DevOps 自动化是一种现代软件开发方法,它使用工具和流程来自动化任务并简化工作流程。它将开发人员、IT 运营和安全团队聚集在一起,帮助他们有效协作并交付可靠的软件。借助 DevOps 自动化,组织能够处理重复性任务、优化流程并更快地将应用程…...

使用 Python 操作 MySQL 数据库的实用工具类:MySQLHandler

操作数据库是非常常见的需求,使用 Python 和 pymysql 库封装一个通用的 MySQL 数据库操作工具类,并通过示例演示如何使用这个工具类高效地管理数据库。 工具类的核心代码解析 MySQLHandler 类简介 MySQLHandler 是一个 Python 类,用于简化…...

DB-GPT V0.6.3 版本更新:支持 SiliconCloud 模型、新增知识处理工作流等

DB-GPT V0.6.3版本现已上线,快速预览新特性: 新特性 1. 支持 SiliconCloud 模型,让用户体验多模型的管理能力 如何使用: 修改环境变量文件.env,配置SiliconCloud模型 # 使用 SiliconCloud 的代理模型 LLM_MODELsiliconflow_p…...

亚式期权定价模型Turnbull-Wakeman进行delta对冲

Turnbull-Wakeman Model是一种用于定价和对冲亚式期权的数学模型。该模型由David Turnbull和Keith Wakeman在1990年提出,用于解决亚式期权的定价问题。 亚式期权是一种路径依赖类型的期权,其期权价格与标的资产价格某个期间内的平均值有关,假…...

Java的list中状态属性相同返回true的实现方案

文章目录 项目背景方案一、for循环实现实现思路 方案二、stream实现实现思路 项目背景 在项目中会遇到list中多个状态判断,状态值相等时,总体返回为true。 方案一、for循环实现 实现思路 遍历list,当出现不一致时,直接跳出循环…...

在 React 项目中安装和配置 Three.js

React 与 Three.js 的结合 :通过 React 管理组件化结构和应用逻辑,利用 Three.js 实现 3D 图形的渲染与交互。使用这种方法,我们可以在保持代码清晰和结构化的同时,实现令人惊叹的 3D 效果。 在本文中,我们将以一个简…...

服务器压力测试怎么做

在部署任何Web应用程序或服务之前,进行服务器压力测试(也称为负载测试)是确保系统能够处理预期用户流量的关键步骤。通过模拟大量并发请求,可以评估服务器的性能、稳定性和响应时间,识别潜在瓶颈,并优化资源…...

TCN-Transformer+LSTM多变量回归预测(Matlab)添加气泡图、散点密度图

TCN-TransformerLSTM多变量回归预测(Matlab)添加气泡图、散点密度图 目录 TCN-TransformerLSTM多变量回归预测(Matlab)添加气泡图、散点密度图预测效果基本介绍程序设计参考资料 预测效果 基本介绍 基本介绍 1.双路创新&#xff…...

Mac 查询IP配置,网络代理

常用命令 1.查询IP ifconfig | grep "inet" 2.ping查询 ping 172.18.54.19(自己IP) 3.取消代理,通过在终端执行以下命令,可以取消 Git 的代理设置 git config --global --unset http.proxy git config --global …...

Vue2五、商品分类:My-Tag表头组件,My-Table整个组件

准备: 安包 npm less less-loader。拆分:一共分成两个组件部分: 1:My-Tag 标签一个组件。2:My-Table 整体一个组件(表头不固定,内容不固定(插槽)) 一&…...

梯度下降法求六轴机械臂逆向解

梯度下降法求六轴机械臂逆向解 一、几何基础 对于上述六轴机械臂的数学建模来说,可以构建一个六轴机械臂的运动学正逆解的数学模型,在一个直角坐标系中有如下旋转矩阵: 绕x轴旋转 R x ( θ x ) [ 1 0 0 0 cos ⁡ θ x sin ⁡ θ x 0 − …...

【生成模型之九】Paint by Example: Exemplar-based Image Editing with Diffusion Models

论文:Paint by Example: Exemplar-based Image Editing with Diffusion Models 代码:https://github. com/Fantasy-Studio/Paint-by-Example 为了实现高质量的基于样本的图像编辑,我们引入了四项关键技术,即利用图像先验、强数据-mask增强、内容瓶颈CLIP class token和无…...

集成RabbitMQ+MQ常用操作

文章目录 1.环境搭建1.Docker安装RabbitMQ1.拉取镜像2.安装命令3.开启5672和15672端口4.登录控制台 2.整合Spring AMQP1.sun-common模块下创建新模块2.引入amqp依赖和fastjson 3.新建一个mq-demo的模块1.在sun-frame下创建mq-demo2.然后在mq-demo下创建生产者和消费者子模块3.查…...

PVE虚拟化平台之开启虚拟机IP显示方法

PVE虚拟化平台之开启虚拟机IP显示方法 一、PVE平台介绍1.1 PVE简介1.2 PVE特点1.3 PVE主要使用场景二、检查PVE环境2.1 环境介绍2.2 检查PVE和虚拟机状态三、虚拟机开启Qemu代理四、Linux虚拟机安装Guest-Agent4.1 进入虚拟机VNC控制台4.2 查看虚拟机IP五、Windows虚拟机安装Gu…...

子Shell及Shell嵌套模式

子Shell 概念 Shell子进程,Shell脚本是从上至下,从左至右依次执行每一行的命令及语句的,即执行完一个命令之后再执行下一个。如果在shell脚本中遇到子脚本(即脚本嵌套),就会先执行子脚本的内容,完成后再返回父脚本继…...

Onedrive精神分裂怎么办(有变更却不同步)

Onedrive有时候会分裂,你在本地删除文件,并没有同步到云端,但是本地却显示同步成功。 比如删掉了一个目录,在本地看已经删掉,onedrive显示已同步,但是别的电脑并不会同步到这个删除操作,在网页版…...

【gym】给定的强化学习环境简介(二)

文章目录 环境介绍一 box2dbipedal_walkercar_dynamicscar_racinglunar_lander 二、 classic_controlacrobotCartPolecontinuous_mountain_carmountain_carpendulum 三、toy_textblackjackcliffwalkingfrozentaxi 四、mujocoAnt:HalfCheetah:Hopper&…...

ctfhub disable_functions关卡

1.CTFHub Bypass disable_function —— LD_PRELOAD 2.CTFHub Bypass disable_function —— ShellShock 3.CTFHub Bypass disable_function —— Apache Mod CGI 4.CTFHub Bypass disable_function —— 攻击PHP-FPM 5.CTFHub Bypass disable_function —— GC UAF 6.CTFHub B…...

SpringAI人工智能开发框架006---SpringAI多模态接口_编程测试springai多模态接口支持

可以看到springai对多模态的支持. 同样去创建一个项目 也是跟之前的项目一样,修改版本1.0.0 这里 然后修改仓库地址,为springai的地址 然后开始写代码...

uniapp 对接腾讯云IM群组成员管理(增删改查)

UniApp 实战:腾讯云IM群组成员管理(增删改查) 一、前言 在社交类App开发中,群组成员管理是核心功能之一。本文将基于UniApp框架,结合腾讯云IM SDK,详细讲解如何实现群组成员的增删改查全流程。 权限校验…...

conda相比python好处

Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理&#xff1a…...

React 第五十五节 Router 中 useAsyncError的使用详解

前言 useAsyncError 是 React Router v6.4 引入的一个钩子,用于处理异步操作(如数据加载)中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误:捕获在 loader 或 action 中发生的异步错误替…...

【Python】 -- 趣味代码 - 小恐龙游戏

文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...

day52 ResNet18 CBAM

在深度学习的旅程中,我们不断探索如何提升模型的性能。今天,我将分享我在 ResNet18 模型中插入 CBAM(Convolutional Block Attention Module)模块,并采用分阶段微调策略的实践过程。通过这个过程,我不仅提升…...

vscode(仍待补充)

写于2025 6.9 主包将加入vscode这个更权威的圈子 vscode的基本使用 侧边栏 vscode还能连接ssh? debug时使用的launch文件 1.task.json {"tasks": [{"type": "cppbuild","label": "C/C: gcc.exe 生成活动文件"…...

(二)TensorRT-LLM | 模型导出(v0.20.0rc3)

0. 概述 上一节 对安装和使用有个基本介绍。根据这个 issue 的描述,后续 TensorRT-LLM 团队可能更专注于更新和维护 pytorch backend。但 tensorrt backend 作为先前一直开发的工作,其中包含了大量可以学习的地方。本文主要看看它导出模型的部分&#x…...

五年级数学知识边界总结思考-下册

目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解:由来、作用与意义**一、知识点核心内容****二、知识点的由来:从生活实践到数学抽象****三、知识的作用:解决实际问题的工具****四、学习的意义:培养核心素养…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具

文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...

linux arm系统烧录

1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 (忘了有没有这步了 估计有) 刷机程序 和 镜像 就不提供了。要刷的时…...