当前位置: 首页 > news >正文

深度学习笔记1:神经网络与模型训练过程

参考博客:PyTorch深度学习实战(1)——神经网络与模型训练过程详解_pytorch 实战-CSDN博客

人工神经网络

ANN:张量及数学运算的集合,排列方式近似于松散的人脑神经元排列

组成

1)输入层

2)隐藏层(中间层):连接输入层和输出层,在输入数据上执行转换,此外隐藏层利用神经元将输入值修改为更高、更低维的值,通过修改中间结点的激活函数可以实现复杂表示函数

3)输出层

训练

本质就是通过重复前向传播和后向传播两个关键步骤来调整神经网络的参数

前向传播--》输入经过隐藏层的到输出结果,第一次正向传播 权重初始化 计算预测值

后向传播--》根据误差相应调整权重来减小误差 修正参数

神经网络重复正向传播与反向传播以预测输出,指导获得令误差较小的权重为止

前向传播

1)输入值乘以权重计算隐藏层值

2)计算激活值

3)在每个神经元上重复前两个步骤,直到输出层

4)计算loss

封装为函数

def forward(inputs,outputs,weights):#如果是首次迭代,随机初始化pre_hidden = np.dot(inputs,weights[0]) + weights[1]#向量点积hidden = 1/(1+np.exp(-pre_hidden))#sigmoid激活函数pred_out = np.dot(hidden,weights[2]) + weights[3]#计算输出mse = np.mean(np.square(pred_out - outputs))#lossreturn mse

反向传播

与前向传播相反,利用从前向传播中计算的损失值,以最小化损失值为目标更新网络权重

1)每次对神经网络中的每个权重进行少量修改

2)测量权重变化时的损失变化 ---》损失值关于权重的梯度

3)计算-α δL/δW 更新权重 α为学习率

如果改变权重损失变化大那么就大幅更新权重

否则 小幅跟新权重

在整个数据集上执行n次前向传播以及后向传播,表示模型进行了n个epoch的训练 ,执行一次算一个epoch

学习率

有助于构建更稳定的算法

梯度下降

更新权重以减小误差值的整个过程

SGD是将误差最小化的一种方法:随计算则数据中的训练数据样本,并根据该样本做出决策

实现梯度下降算法

1)定义前馈神经网络并计算均方误差值

2)为每个权重和偏执项增加一个非常小的量0.001,并针对每个权重和偏差的更新计算一个mse

from copy import deepcopy 
import numpy as np
#创建update_weights函数,执行梯度下降来更新权重
def update_weights(inputs, outputs, weights, lr):#使用deepcopy可以确保处理多个权重副本,不会影响实际权重original_weights = deepcopy(weights)temp_weights  = deepcopy(weights)updated_weights = deepcopy(weights)original_loss = forward(inputs, outputs, original_weigts)#遍历网络的所有层for i, layer in enumerate(orignial_weights):#循环遍历每个参数列表的所有参数 共四个参数列表,前两个表示输入连接到隐藏层的权重和偏置项参数#另外两个表示链接隐藏层和输出层的偏置参数for index, weight in np.ndenumerate(layer):temp_weights = deepcopy(weights)#原始权重集temp_weights[i][index] += 0.0001#增加很小的值权重更新_loss_plus = forward(inputs, outputs, temp_weights)#更新损失grad = (_loss_plus - original_loss)/0.0001updated_weights[i][index] -= grad * lr #利用损失变化更新权重,使用学习率领权重变化更稳定return updated_weights, original_loss#返回更新后的权重

合并前向传播和反向传播

构建一个带有隐藏层的简单神经网络,

1)输入连接到具有三个神经元的隐藏层

2)隐藏层连接到具有一个神经元的输出层

#导入相关库
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
x = np.array([[1,1]])
y = np.array(([[0]]))
#随机初始化权重和偏置值
W = [np.array([[-0.05, 0.3793],[-0.5820, -0.5204],[-0.2723, 0.1896]], dtype=np.float32).T, np.array([-0.0140, 0.5607, -0.0628], dtype=np.float32), #隐藏层两个偏置值np.array([[ 0.1528, -0.1745, -0.1135]], dtype=np.float32).T, np.array([-0.5516], dtype=np.float32)#输出层一个偏置值
]#在一百个epoch内执行前向传播和反向传播,使用之前定义的forward和
def forward(inputs,outputs,weights):#如果是首次迭代,随机初始化pre_hidden = np.dot(inputs,weights[0]) + weights[1]#向量点积hidden = 1/(1+np.exp(-pre_hidden))#sigmoid激活函数pred_out = np.dot(hidden,weights[2]) + weights[3]#计算输出mse = np.mean(np.square(pred_out - outputs))#lossreturn msedef update_weights(inputs, outputs, weights, lr):#使用deepcopy可以确保处理多个权重副本,不会影响实际权重original_weights = deepcopy(weights)temp_weights  = deepcopy(weights)updated_weights = deepcopy(weights)original_loss = forward(inputs, outputs, original_weights)#遍历网络的所有层for i, layer in enumerate(original_weights):#循环遍历每个参数列表的所有参数 共四个参数列表,前两个表示输入连接到隐藏层的权重和偏置项参数#另外两个表示链接隐藏层和输出层的偏置参数for index, weight in np.ndenumerate(layer):temp_weights = deepcopy(weights)#原始权重集temp_weights[i][index] += 0.0001#增加很小的值权重更新_loss_plus = forward(inputs, outputs, temp_weights)#更新损失grad = (_loss_plus - original_loss)/0.0001updated_weights[i][index] -= grad * lr #利用损失变化更新权重,使用学习率领权重变化更稳定return updated_weights, original_loss#返回更新后的权重#绘制损失值
losses = []
for  epoch in range(100):W, loss = update_weights(x, y, W, 0.01)losses.append(loss)
plt.plot(losses)
plt.title('loss over increasing number of  epochs')
plt.xlabel('epochs')
plt.ylabel('loss value')
plt.show()
print(W)
#获取更新后的权值之后通过将输出传递给网络对输入进行预测计算输出值
pre_hidden = np.dot(x, W[0]) + W[1]
hidden = 1/(1+np.exp(-pre_hidden))
pre_out = np.dot(hidden, W[2]) + W[3]
print(pre_out)

注:在服务器上运行时im.show()无法展示图片

总结

训练神经网络主要是通过重复两个关键步骤,及用给定的学习率进行前向传播和反向传播,最终得到最佳权重

相关文章:

深度学习笔记1:神经网络与模型训练过程

参考博客:PyTorch深度学习实战(1)——神经网络与模型训练过程详解_pytorch 实战-CSDN博客 人工神经网络 ANN:张量及数学运算的集合,排列方式近似于松散的人脑神经元排列 组成 1)输入层 2)隐…...

什么是 DevOps 自动化?

DevOps 自动化是一种现代软件开发方法,它使用工具和流程来自动化任务并简化工作流程。它将开发人员、IT 运营和安全团队聚集在一起,帮助他们有效协作并交付可靠的软件。借助 DevOps 自动化,组织能够处理重复性任务、优化流程并更快地将应用程…...

使用 Python 操作 MySQL 数据库的实用工具类:MySQLHandler

操作数据库是非常常见的需求,使用 Python 和 pymysql 库封装一个通用的 MySQL 数据库操作工具类,并通过示例演示如何使用这个工具类高效地管理数据库。 工具类的核心代码解析 MySQLHandler 类简介 MySQLHandler 是一个 Python 类,用于简化…...

DB-GPT V0.6.3 版本更新:支持 SiliconCloud 模型、新增知识处理工作流等

DB-GPT V0.6.3版本现已上线,快速预览新特性: 新特性 1. 支持 SiliconCloud 模型,让用户体验多模型的管理能力 如何使用: 修改环境变量文件.env,配置SiliconCloud模型 # 使用 SiliconCloud 的代理模型 LLM_MODELsiliconflow_p…...

亚式期权定价模型Turnbull-Wakeman进行delta对冲

Turnbull-Wakeman Model是一种用于定价和对冲亚式期权的数学模型。该模型由David Turnbull和Keith Wakeman在1990年提出,用于解决亚式期权的定价问题。 亚式期权是一种路径依赖类型的期权,其期权价格与标的资产价格某个期间内的平均值有关,假…...

Java的list中状态属性相同返回true的实现方案

文章目录 项目背景方案一、for循环实现实现思路 方案二、stream实现实现思路 项目背景 在项目中会遇到list中多个状态判断,状态值相等时,总体返回为true。 方案一、for循环实现 实现思路 遍历list,当出现不一致时,直接跳出循环…...

在 React 项目中安装和配置 Three.js

React 与 Three.js 的结合 :通过 React 管理组件化结构和应用逻辑,利用 Three.js 实现 3D 图形的渲染与交互。使用这种方法,我们可以在保持代码清晰和结构化的同时,实现令人惊叹的 3D 效果。 在本文中,我们将以一个简…...

服务器压力测试怎么做

在部署任何Web应用程序或服务之前,进行服务器压力测试(也称为负载测试)是确保系统能够处理预期用户流量的关键步骤。通过模拟大量并发请求,可以评估服务器的性能、稳定性和响应时间,识别潜在瓶颈,并优化资源…...

TCN-Transformer+LSTM多变量回归预测(Matlab)添加气泡图、散点密度图

TCN-TransformerLSTM多变量回归预测(Matlab)添加气泡图、散点密度图 目录 TCN-TransformerLSTM多变量回归预测(Matlab)添加气泡图、散点密度图预测效果基本介绍程序设计参考资料 预测效果 基本介绍 基本介绍 1.双路创新&#xff…...

Mac 查询IP配置,网络代理

常用命令 1.查询IP ifconfig | grep "inet" 2.ping查询 ping 172.18.54.19(自己IP) 3.取消代理,通过在终端执行以下命令,可以取消 Git 的代理设置 git config --global --unset http.proxy git config --global …...

Vue2五、商品分类:My-Tag表头组件,My-Table整个组件

准备: 安包 npm less less-loader。拆分:一共分成两个组件部分: 1:My-Tag 标签一个组件。2:My-Table 整体一个组件(表头不固定,内容不固定(插槽)) 一&…...

梯度下降法求六轴机械臂逆向解

梯度下降法求六轴机械臂逆向解 一、几何基础 对于上述六轴机械臂的数学建模来说,可以构建一个六轴机械臂的运动学正逆解的数学模型,在一个直角坐标系中有如下旋转矩阵: 绕x轴旋转 R x ( θ x ) [ 1 0 0 0 cos ⁡ θ x sin ⁡ θ x 0 − …...

【生成模型之九】Paint by Example: Exemplar-based Image Editing with Diffusion Models

论文:Paint by Example: Exemplar-based Image Editing with Diffusion Models 代码:https://github. com/Fantasy-Studio/Paint-by-Example 为了实现高质量的基于样本的图像编辑,我们引入了四项关键技术,即利用图像先验、强数据-mask增强、内容瓶颈CLIP class token和无…...

集成RabbitMQ+MQ常用操作

文章目录 1.环境搭建1.Docker安装RabbitMQ1.拉取镜像2.安装命令3.开启5672和15672端口4.登录控制台 2.整合Spring AMQP1.sun-common模块下创建新模块2.引入amqp依赖和fastjson 3.新建一个mq-demo的模块1.在sun-frame下创建mq-demo2.然后在mq-demo下创建生产者和消费者子模块3.查…...

PVE虚拟化平台之开启虚拟机IP显示方法

PVE虚拟化平台之开启虚拟机IP显示方法 一、PVE平台介绍1.1 PVE简介1.2 PVE特点1.3 PVE主要使用场景二、检查PVE环境2.1 环境介绍2.2 检查PVE和虚拟机状态三、虚拟机开启Qemu代理四、Linux虚拟机安装Guest-Agent4.1 进入虚拟机VNC控制台4.2 查看虚拟机IP五、Windows虚拟机安装Gu…...

子Shell及Shell嵌套模式

子Shell 概念 Shell子进程,Shell脚本是从上至下,从左至右依次执行每一行的命令及语句的,即执行完一个命令之后再执行下一个。如果在shell脚本中遇到子脚本(即脚本嵌套),就会先执行子脚本的内容,完成后再返回父脚本继…...

Onedrive精神分裂怎么办(有变更却不同步)

Onedrive有时候会分裂,你在本地删除文件,并没有同步到云端,但是本地却显示同步成功。 比如删掉了一个目录,在本地看已经删掉,onedrive显示已同步,但是别的电脑并不会同步到这个删除操作,在网页版…...

【gym】给定的强化学习环境简介(二)

文章目录 环境介绍一 box2dbipedal_walkercar_dynamicscar_racinglunar_lander 二、 classic_controlacrobotCartPolecontinuous_mountain_carmountain_carpendulum 三、toy_textblackjackcliffwalkingfrozentaxi 四、mujocoAnt:HalfCheetah:Hopper&…...

ctfhub disable_functions关卡

1.CTFHub Bypass disable_function —— LD_PRELOAD 2.CTFHub Bypass disable_function —— ShellShock 3.CTFHub Bypass disable_function —— Apache Mod CGI 4.CTFHub Bypass disable_function —— 攻击PHP-FPM 5.CTFHub Bypass disable_function —— GC UAF 6.CTFHub B…...

SpringAI人工智能开发框架006---SpringAI多模态接口_编程测试springai多模态接口支持

可以看到springai对多模态的支持. 同样去创建一个项目 也是跟之前的项目一样,修改版本1.0.0 这里 然后修改仓库地址,为springai的地址 然后开始写代码...

基于Flask实现的医疗保险欺诈识别监测模型

基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...

转转集团旗下首家二手多品类循环仓店“超级转转”开业

6月9日,国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解,“超级…...

2025盘古石杯决赛【手机取证】

前言 第三届盘古石杯国际电子数据取证大赛决赛 最后一题没有解出来,实在找不到,希望有大佬教一下我。 还有就会议时间,我感觉不是图片时间,因为在电脑看到是其他时间用老会议系统开的会。 手机取证 1、分析鸿蒙手机检材&#x…...

【C语言练习】080. 使用C语言实现简单的数据库操作

080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

[Java恶补day16] 238.除自身以外数组的乘积

给你一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法,且在 O(n) 时间复杂度…...

力扣-35.搜索插入位置

题目描述 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...

Java数值运算常见陷阱与规避方法

整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...

【JavaSE】多线程基础学习笔记

多线程基础 -线程相关概念 程序(Program) 是为完成特定任务、用某种语言编写的一组指令的集合简单的说:就是我们写的代码 进程 进程是指运行中的程序,比如我们使用QQ,就启动了一个进程,操作系统就会为该进程分配内存…...

R 语言科研绘图第 55 期 --- 网络图-聚类

在发表科研论文的过程中,科研绘图是必不可少的,一张好看的图形会是文章很大的加分项。 为了便于使用,本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中,获取方式: R 语言科研绘图模板 --- sciRplothttps://mp.…...

HubSpot推出与ChatGPT的深度集成引发兴奋与担忧

上周三,HubSpot宣布已构建与ChatGPT的深度集成,这一消息在HubSpot用户和营销技术观察者中引发了极大的兴奋,但同时也存在一些关于数据安全的担忧。 许多网络声音声称,这对SaaS应用程序和人工智能而言是一场范式转变。 但向任何技…...