当前位置: 首页 > news >正文

深度学习笔记1:神经网络与模型训练过程

参考博客:PyTorch深度学习实战(1)——神经网络与模型训练过程详解_pytorch 实战-CSDN博客

人工神经网络

ANN:张量及数学运算的集合,排列方式近似于松散的人脑神经元排列

组成

1)输入层

2)隐藏层(中间层):连接输入层和输出层,在输入数据上执行转换,此外隐藏层利用神经元将输入值修改为更高、更低维的值,通过修改中间结点的激活函数可以实现复杂表示函数

3)输出层

训练

本质就是通过重复前向传播和后向传播两个关键步骤来调整神经网络的参数

前向传播--》输入经过隐藏层的到输出结果,第一次正向传播 权重初始化 计算预测值

后向传播--》根据误差相应调整权重来减小误差 修正参数

神经网络重复正向传播与反向传播以预测输出,指导获得令误差较小的权重为止

前向传播

1)输入值乘以权重计算隐藏层值

2)计算激活值

3)在每个神经元上重复前两个步骤,直到输出层

4)计算loss

封装为函数

def forward(inputs,outputs,weights):#如果是首次迭代,随机初始化pre_hidden = np.dot(inputs,weights[0]) + weights[1]#向量点积hidden = 1/(1+np.exp(-pre_hidden))#sigmoid激活函数pred_out = np.dot(hidden,weights[2]) + weights[3]#计算输出mse = np.mean(np.square(pred_out - outputs))#lossreturn mse

反向传播

与前向传播相反,利用从前向传播中计算的损失值,以最小化损失值为目标更新网络权重

1)每次对神经网络中的每个权重进行少量修改

2)测量权重变化时的损失变化 ---》损失值关于权重的梯度

3)计算-α δL/δW 更新权重 α为学习率

如果改变权重损失变化大那么就大幅更新权重

否则 小幅跟新权重

在整个数据集上执行n次前向传播以及后向传播,表示模型进行了n个epoch的训练 ,执行一次算一个epoch

学习率

有助于构建更稳定的算法

梯度下降

更新权重以减小误差值的整个过程

SGD是将误差最小化的一种方法:随计算则数据中的训练数据样本,并根据该样本做出决策

实现梯度下降算法

1)定义前馈神经网络并计算均方误差值

2)为每个权重和偏执项增加一个非常小的量0.001,并针对每个权重和偏差的更新计算一个mse

from copy import deepcopy 
import numpy as np
#创建update_weights函数,执行梯度下降来更新权重
def update_weights(inputs, outputs, weights, lr):#使用deepcopy可以确保处理多个权重副本,不会影响实际权重original_weights = deepcopy(weights)temp_weights  = deepcopy(weights)updated_weights = deepcopy(weights)original_loss = forward(inputs, outputs, original_weigts)#遍历网络的所有层for i, layer in enumerate(orignial_weights):#循环遍历每个参数列表的所有参数 共四个参数列表,前两个表示输入连接到隐藏层的权重和偏置项参数#另外两个表示链接隐藏层和输出层的偏置参数for index, weight in np.ndenumerate(layer):temp_weights = deepcopy(weights)#原始权重集temp_weights[i][index] += 0.0001#增加很小的值权重更新_loss_plus = forward(inputs, outputs, temp_weights)#更新损失grad = (_loss_plus - original_loss)/0.0001updated_weights[i][index] -= grad * lr #利用损失变化更新权重,使用学习率领权重变化更稳定return updated_weights, original_loss#返回更新后的权重

合并前向传播和反向传播

构建一个带有隐藏层的简单神经网络,

1)输入连接到具有三个神经元的隐藏层

2)隐藏层连接到具有一个神经元的输出层

#导入相关库
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
x = np.array([[1,1]])
y = np.array(([[0]]))
#随机初始化权重和偏置值
W = [np.array([[-0.05, 0.3793],[-0.5820, -0.5204],[-0.2723, 0.1896]], dtype=np.float32).T, np.array([-0.0140, 0.5607, -0.0628], dtype=np.float32), #隐藏层两个偏置值np.array([[ 0.1528, -0.1745, -0.1135]], dtype=np.float32).T, np.array([-0.5516], dtype=np.float32)#输出层一个偏置值
]#在一百个epoch内执行前向传播和反向传播,使用之前定义的forward和
def forward(inputs,outputs,weights):#如果是首次迭代,随机初始化pre_hidden = np.dot(inputs,weights[0]) + weights[1]#向量点积hidden = 1/(1+np.exp(-pre_hidden))#sigmoid激活函数pred_out = np.dot(hidden,weights[2]) + weights[3]#计算输出mse = np.mean(np.square(pred_out - outputs))#lossreturn msedef update_weights(inputs, outputs, weights, lr):#使用deepcopy可以确保处理多个权重副本,不会影响实际权重original_weights = deepcopy(weights)temp_weights  = deepcopy(weights)updated_weights = deepcopy(weights)original_loss = forward(inputs, outputs, original_weights)#遍历网络的所有层for i, layer in enumerate(original_weights):#循环遍历每个参数列表的所有参数 共四个参数列表,前两个表示输入连接到隐藏层的权重和偏置项参数#另外两个表示链接隐藏层和输出层的偏置参数for index, weight in np.ndenumerate(layer):temp_weights = deepcopy(weights)#原始权重集temp_weights[i][index] += 0.0001#增加很小的值权重更新_loss_plus = forward(inputs, outputs, temp_weights)#更新损失grad = (_loss_plus - original_loss)/0.0001updated_weights[i][index] -= grad * lr #利用损失变化更新权重,使用学习率领权重变化更稳定return updated_weights, original_loss#返回更新后的权重#绘制损失值
losses = []
for  epoch in range(100):W, loss = update_weights(x, y, W, 0.01)losses.append(loss)
plt.plot(losses)
plt.title('loss over increasing number of  epochs')
plt.xlabel('epochs')
plt.ylabel('loss value')
plt.show()
print(W)
#获取更新后的权值之后通过将输出传递给网络对输入进行预测计算输出值
pre_hidden = np.dot(x, W[0]) + W[1]
hidden = 1/(1+np.exp(-pre_hidden))
pre_out = np.dot(hidden, W[2]) + W[3]
print(pre_out)

注:在服务器上运行时im.show()无法展示图片

总结

训练神经网络主要是通过重复两个关键步骤,及用给定的学习率进行前向传播和反向传播,最终得到最佳权重

相关文章:

深度学习笔记1:神经网络与模型训练过程

参考博客:PyTorch深度学习实战(1)——神经网络与模型训练过程详解_pytorch 实战-CSDN博客 人工神经网络 ANN:张量及数学运算的集合,排列方式近似于松散的人脑神经元排列 组成 1)输入层 2)隐…...

什么是 DevOps 自动化?

DevOps 自动化是一种现代软件开发方法,它使用工具和流程来自动化任务并简化工作流程。它将开发人员、IT 运营和安全团队聚集在一起,帮助他们有效协作并交付可靠的软件。借助 DevOps 自动化,组织能够处理重复性任务、优化流程并更快地将应用程…...

使用 Python 操作 MySQL 数据库的实用工具类:MySQLHandler

操作数据库是非常常见的需求,使用 Python 和 pymysql 库封装一个通用的 MySQL 数据库操作工具类,并通过示例演示如何使用这个工具类高效地管理数据库。 工具类的核心代码解析 MySQLHandler 类简介 MySQLHandler 是一个 Python 类,用于简化…...

DB-GPT V0.6.3 版本更新:支持 SiliconCloud 模型、新增知识处理工作流等

DB-GPT V0.6.3版本现已上线,快速预览新特性: 新特性 1. 支持 SiliconCloud 模型,让用户体验多模型的管理能力 如何使用: 修改环境变量文件.env,配置SiliconCloud模型 # 使用 SiliconCloud 的代理模型 LLM_MODELsiliconflow_p…...

亚式期权定价模型Turnbull-Wakeman进行delta对冲

Turnbull-Wakeman Model是一种用于定价和对冲亚式期权的数学模型。该模型由David Turnbull和Keith Wakeman在1990年提出,用于解决亚式期权的定价问题。 亚式期权是一种路径依赖类型的期权,其期权价格与标的资产价格某个期间内的平均值有关,假…...

Java的list中状态属性相同返回true的实现方案

文章目录 项目背景方案一、for循环实现实现思路 方案二、stream实现实现思路 项目背景 在项目中会遇到list中多个状态判断,状态值相等时,总体返回为true。 方案一、for循环实现 实现思路 遍历list,当出现不一致时,直接跳出循环…...

在 React 项目中安装和配置 Three.js

React 与 Three.js 的结合 :通过 React 管理组件化结构和应用逻辑,利用 Three.js 实现 3D 图形的渲染与交互。使用这种方法,我们可以在保持代码清晰和结构化的同时,实现令人惊叹的 3D 效果。 在本文中,我们将以一个简…...

服务器压力测试怎么做

在部署任何Web应用程序或服务之前,进行服务器压力测试(也称为负载测试)是确保系统能够处理预期用户流量的关键步骤。通过模拟大量并发请求,可以评估服务器的性能、稳定性和响应时间,识别潜在瓶颈,并优化资源…...

TCN-Transformer+LSTM多变量回归预测(Matlab)添加气泡图、散点密度图

TCN-TransformerLSTM多变量回归预测(Matlab)添加气泡图、散点密度图 目录 TCN-TransformerLSTM多变量回归预测(Matlab)添加气泡图、散点密度图预测效果基本介绍程序设计参考资料 预测效果 基本介绍 基本介绍 1.双路创新&#xff…...

Mac 查询IP配置,网络代理

常用命令 1.查询IP ifconfig | grep "inet" 2.ping查询 ping 172.18.54.19(自己IP) 3.取消代理,通过在终端执行以下命令,可以取消 Git 的代理设置 git config --global --unset http.proxy git config --global …...

Vue2五、商品分类:My-Tag表头组件,My-Table整个组件

准备: 安包 npm less less-loader。拆分:一共分成两个组件部分: 1:My-Tag 标签一个组件。2:My-Table 整体一个组件(表头不固定,内容不固定(插槽)) 一&…...

梯度下降法求六轴机械臂逆向解

梯度下降法求六轴机械臂逆向解 一、几何基础 对于上述六轴机械臂的数学建模来说,可以构建一个六轴机械臂的运动学正逆解的数学模型,在一个直角坐标系中有如下旋转矩阵: 绕x轴旋转 R x ( θ x ) [ 1 0 0 0 cos ⁡ θ x sin ⁡ θ x 0 − …...

【生成模型之九】Paint by Example: Exemplar-based Image Editing with Diffusion Models

论文:Paint by Example: Exemplar-based Image Editing with Diffusion Models 代码:https://github. com/Fantasy-Studio/Paint-by-Example 为了实现高质量的基于样本的图像编辑,我们引入了四项关键技术,即利用图像先验、强数据-mask增强、内容瓶颈CLIP class token和无…...

集成RabbitMQ+MQ常用操作

文章目录 1.环境搭建1.Docker安装RabbitMQ1.拉取镜像2.安装命令3.开启5672和15672端口4.登录控制台 2.整合Spring AMQP1.sun-common模块下创建新模块2.引入amqp依赖和fastjson 3.新建一个mq-demo的模块1.在sun-frame下创建mq-demo2.然后在mq-demo下创建生产者和消费者子模块3.查…...

PVE虚拟化平台之开启虚拟机IP显示方法

PVE虚拟化平台之开启虚拟机IP显示方法 一、PVE平台介绍1.1 PVE简介1.2 PVE特点1.3 PVE主要使用场景二、检查PVE环境2.1 环境介绍2.2 检查PVE和虚拟机状态三、虚拟机开启Qemu代理四、Linux虚拟机安装Guest-Agent4.1 进入虚拟机VNC控制台4.2 查看虚拟机IP五、Windows虚拟机安装Gu…...

子Shell及Shell嵌套模式

子Shell 概念 Shell子进程,Shell脚本是从上至下,从左至右依次执行每一行的命令及语句的,即执行完一个命令之后再执行下一个。如果在shell脚本中遇到子脚本(即脚本嵌套),就会先执行子脚本的内容,完成后再返回父脚本继…...

Onedrive精神分裂怎么办(有变更却不同步)

Onedrive有时候会分裂,你在本地删除文件,并没有同步到云端,但是本地却显示同步成功。 比如删掉了一个目录,在本地看已经删掉,onedrive显示已同步,但是别的电脑并不会同步到这个删除操作,在网页版…...

【gym】给定的强化学习环境简介(二)

文章目录 环境介绍一 box2dbipedal_walkercar_dynamicscar_racinglunar_lander 二、 classic_controlacrobotCartPolecontinuous_mountain_carmountain_carpendulum 三、toy_textblackjackcliffwalkingfrozentaxi 四、mujocoAnt:HalfCheetah:Hopper&…...

ctfhub disable_functions关卡

1.CTFHub Bypass disable_function —— LD_PRELOAD 2.CTFHub Bypass disable_function —— ShellShock 3.CTFHub Bypass disable_function —— Apache Mod CGI 4.CTFHub Bypass disable_function —— 攻击PHP-FPM 5.CTFHub Bypass disable_function —— GC UAF 6.CTFHub B…...

SpringAI人工智能开发框架006---SpringAI多模态接口_编程测试springai多模态接口支持

可以看到springai对多模态的支持. 同样去创建一个项目 也是跟之前的项目一样,修改版本1.0.0 这里 然后修改仓库地址,为springai的地址 然后开始写代码...

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑:陈萍萍的公主一点人工一点智能 未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战,在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

Vue记事本应用实现教程

文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展:显示创建时间8. 功能扩展:记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)

HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

Unity3D中Gfx.WaitForPresent优化方案

前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...

el-switch文字内置

el-switch文字内置 效果 vue <div style"color:#ffffff;font-size:14px;float:left;margin-bottom:5px;margin-right:5px;">自动加载</div> <el-switch v-model"value" active-color"#3E99FB" inactive-color"#DCDFE6"…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年&#xff0c;截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始&#xff0c;将英文题库免费公布出来&#xff0c;并进行解析&#xff0c;帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

OpenLayers 分屏对比(地图联动)

注&#xff1a;当前使用的是 ol 5.3.0 版本&#xff0c;天地图使用的key请到天地图官网申请&#xff0c;并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能&#xff0c;和卷帘图层不一样的是&#xff0c;分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南&#xff1a;计算机基础与源码原理深度解析 第一轮提问&#xff1a;基础概念问题 1. 请解释什么是进程和线程的区别&#xff1f; 面试官&#xff1a;进程是程序的一次执行过程&#xff0c;是系统进行资源分配和调度的基本单位&#xff1b;而线程是进程中的…...

嵌入式常见 CPU 架构

架构类型架构厂商芯片厂商典型芯片特点与应用场景PICRISC (8/16 位)MicrochipMicrochipPIC16F877A、PIC18F4550简化指令集&#xff0c;单周期执行&#xff1b;低功耗、CIP 独立外设&#xff1b;用于家电、小电机控制、安防面板等嵌入式场景8051CISC (8 位)Intel&#xff08;原始…...