当前位置: 首页 > article >正文

跟着StatQuest学知识07-张量与PyTorch

一、张量tensor

张量重新命名一些数据概念,存储数据以及权重和偏置。

张量还允许与数据相关的数学计算能够相对快速的完成。

通常,张量及其进行的数学计算会通过成为图形处理单元(GPUs)的特殊芯片来加速。但还有张量处理单元(TPUs)专门处理张量,使得神经网络运行相当更快。

另外,张量通过自动微分处理反向传播。

二、PyTorch

以下部分参考 【深度学习基础】用PyTorch从零开始搭建DNN深度神经网络

 图中的这个神经网络的参数都是训练优化好的,下面我们简便起见,假设最后一个参数b_final没有优化过,初始化为0,我们尝试用Pytorch实现一下对这个参数的优化,将final_bias初始化为0,看看最终这个-16可否被优化出来的。首先引入一些相关的库:

import torch
import torch.nn as nn
import torch.nn.functional as Fimport matplotlib.pyplot as plt
import seaborn as sns

其中torch就是PyTorch框架,matplotlib和seaborn都是用来绘图的库。然后我们定义对照着图中的各个参数,搭建神经网络如下: 

class BasicNN_train(nn.Module):  # 继承父类nn.Moduledef __init__(self):super().__init__()  # 对父类的成员进行初始化self.w00 = nn.Parameter(torch.tensor(1.7), requires_grad=False)self.b00 = nn.Parameter(torch.tensor(-0.85), requires_grad=False)self.w01 = nn.Parameter(torch.tensor(-40.8), requires_grad=False)self.w10 = nn.Parameter(torch.tensor(12.6), requires_grad=False)self.b10 = nn.Parameter(torch.tensor(0.0), requires_grad=False)self.w11 = nn.Parameter(torch.tensor(2.7), requires_grad=False)self.final_bias = nn.Parameter(torch.tensor(0.0), requires_grad=True)# requires_grad=True 表示需要优化def forward(self, input):  # 前向传播input_to_top_relu = input * self.w00 + self.b00top_relu_output = F.relu(input_to_top_relu)scaled_top_relu_output = top_relu_output * self.w01input_to_bottom_relu = input * self.w10 + self.b10bottom_relu_output = F.relu(input_to_bottom_relu)scaled_bottom_relu_output = bottom_relu_output * self.w11input_to_final_relu = scaled_top_relu_output + scaled_bottom_relu_output + self.final_biasoutput = F.relu(input_to_final_relu)return output

然后我们实例化这个网路,设定epoch=100,即最多进行100次前向和反向传播,定义损失函数就是预测值和实际值的平方误差,当损失函数之和低于0.0001时,我们就停止训练(最多训练100轮次),代码如下:

if __name__ == '__main__':model = BasicNN_train()  # 实例化神经网络模型inputs = torch.tensor([0., 0.5, 1.])  # 输入张量labels = torch.tensor([0., 1., 0.])  # 输出张量# 定义一个优化器 optimizer,使用随机梯度下降(SGD)算法来更新模型的参数optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # 学习率为0.1print("优化前的final_bias是:" + str(model.final_bias.data) + '\n')# 开始训练,最多100轮次for epoch in range(100):total_loss = 0  # 累积当前 epoch 中所有样本的损失值for iteration in range(len(inputs)): # len(inputs) 表示数据集中样本的数量input_i = inputs[iteration]label_i = labels[iteration]output_i = model(input_i) # 前向传播loss = (output_i - label_i) ** 2loss.backward() # 反向传播# 通过反向传播,PyTorch 会自动计算每个参数的梯度,并存储在参数的 .grad 属性中total_loss += float(loss)# 将每个样本的loss加和
  •  backward() 的功能:

        backward() 使用链式法则计算损失函数 loss 对模型参数的梯度。

        loss.backward() 是从 loss 开始,沿着计算图反向传播梯度,最终得到每个参数的梯度值。这些梯度值(数据)会被存储在模型参数的 .grad 属性中,用于后续的参数更新。

  • 正向传播是怎么实现的?

        model(input_i) 会自动调用 model 中定义的 forward 方法。

        在 Python 中,当一个类的实例被“调用”时(例如 model(input_i)),Python 会尝试调用该实例的 __call__ 方法。

        PyTorch 的 nn.Module 类实现了 __call__ 方法。当你调用 model(input_i) 时,实际上是调用了 model.__call__(input_i)。

        if total_loss < 0.0001:print(f"当前是第{epoch}轮次,已经满足total_loss < 0.0001,结束程序。")breakoptimizer.step()  # 使用优化器(如 SGD)更新模型的权重和偏置,以最小化损失函数。optimizer.zero_grad()  # 清除模型参数的梯度。print(f"当前是第{epoch}轮次,此时的final_bias值为{model.final_bias.data},total_loss为{total_loss}")# 画图如下input_doses = torch.linspace(start=0, end=1, steps=11)output_values = model(input_doses)sns.set(style="whitegrid")sns.lineplot(x=input_doses,y=output_values.detach(),color='green',linewidth=2.5)plt.ylabel('Effectiveness')plt.xlabel('Dose')plt.show()print(f"优化后的final_bias值为:{model.final_bias.data}")

最终的输出结果如下:

  一共34轮训练后,就实现了总损失小于0.001的要求,也看到最终的优化结果final_bia大概是-16,与之前我们的结论一致。 损失函数变化曲线如下:

    最终迭代到第34轮次后,实现了最终的效果: 

相关文章:

跟着StatQuest学知识07-张量与PyTorch

一、张量tensor 张量重新命名一些数据概念&#xff0c;存储数据以及权重和偏置。 张量还允许与数据相关的数学计算能够相对快速的完成。 通常&#xff0c;张量及其进行的数学计算会通过成为图形处理单元&#xff08;GPUs&#xff09;的特殊芯片来加速。但还有张量处理单元&am…...

nginx配置https域名后,代理后端服务器流式接口变慢

目录 问题描述原因解决办法 问题描述 使用nginx配置域名和https的ssl证书后&#xff0c;代理后端web服务器&#xff0c;发现流式接口比原来直接用服务器外部ip后端web服务器端口变慢了很多。 原因 在于 HTTP 和 HTTPS 在 Nginx 代理中的处理方式不同。以下几点解释了为什么 …...

前端字段名和后端不一致?解锁 JSON 映射的“隐藏规则” !!!

&#x1f680; 前端字段名和后端不一致&#xff1f;解锁 JSON 映射的“隐藏规则” &#x1f31f; 嘿&#xff0c;技术冒险家们&#xff01;&#x1f44b; 今天我们要聊一个开发中常见的“坑”&#xff1a;前端传来的 JSON 参数字段名和后端对象字段名不一致&#xff0c;会发生…...

基于springboot的新闻推荐系统(045)

摘要 随着信息互联网购物的飞速发展&#xff0c;国内放开了自媒体的政策&#xff0c;一般企业都开始开发属于自己内容分发平台的网站。本文介绍了新闻推荐系统的开发全过程。通过分析企业对于新闻推荐系统的需求&#xff0c;创建了一个计算机管理新闻推荐系统的方案。文章介绍了…...

2024年数维杯数学建模C题天然气水合物资源量评价解题全过程论文及程序

2024年数维杯数学建模 C题 天然气水合物资源量评价 原题再现&#xff1a; 天然气水合物&#xff08;Natural Gas Hydrate/Gas Hydrate&#xff09;即可燃冰&#xff0c;是天然气与水在高压低温条件下形成的类冰状结晶物质&#xff0c;因其外观像冰&#xff0c;遇火即燃&#…...

Linux与HTTP中的Cookie和Session

HTTP中的Cookie和Session 本篇介绍 前面几篇已经基本介绍了HTTP协议的大部分内容&#xff0c;但是前面提到了一点「HTTP是无连接、无状态的协议」&#xff0c;那么到底有什么无连接以及什么是无状态。基于这两个问题&#xff0c;随后解释什么是Cookie和Session&#xff0c;以…...

linux 备份工具,常用的Linux备份工具及其备份数据的语法

在Linux系统中&#xff0c;备份数据是确保数据安全性和完整性的关键步骤。以下是一些常用的Linux备份工具及其备份数据的语法&#xff1a; 1. tar命令 tar命令是Linux系统中常用的打包和压缩工具&#xff0c;可以将多个文件或目录打包成一个文件&#xff0c;并可以选择添加压…...

C++核心语法快速整理

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 本文主要为学过多门语言玩家快速入门C 没有基础的就放弃吧。 全部都是精华&#xff0c;看完能直接上手改别人的项目。 输出内容 std::代表了这里的cout使用的标准库&#xff0c;避免不同库中的相同命名导致混乱 …...

STM32八股【3】------RAM和片上FLASH

1、RAM和FLASH构成 1.RAM ┌──────────────────────────┐ │ 栈区 (Stack) │ ← 从RAM顶端向下扩展&#xff08;存储局部变量、函数调用信息&#xff09; │--------------------------│ │ 堆区 (Heap) │ ← …...

使用HAI来打通DeepSeek的任督二脉

一、什么是HAI HAI是一款专注于AI与科学计算领域的云服务产品&#xff0c;旨在为开发者、企业及科研人员提供高效、易用的算力支持与全栈解决方案。主要使用场景为&#xff1a; AI作画&#xff0c;AI对话/写作、AI开发/测试。 二、开通HAI 选择CPU算力 16核32GB&#xff0c;这…...

深入理解Aider sends a repo map

你提到的这个链接&#xff08;https://aider.chat/2023/10/22/repomap.html&#xff09;是 Aider 的官方文档&#xff0c;介绍了一种叫做“Repo Map”&#xff08;仓库地图&#xff09;的功能。Aider 是一个 AI 编程辅助工具&#xff0c;主要通过与大语言模型&#xff08;如 GP…...

【day2】数据结构刷题 栈

一 有效的括号 给定一个只包括 (&#xff0c;)&#xff0c;{&#xff0c;}&#xff0c;[&#xff0c;] 的字符串 s &#xff0c;判断字符串是否有效。 有效字符串需满足&#xff1a; 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个右括号都有一个对应的…...

第16章:基于CNN和Transformer对心脏左心室的实验分析及改进策略

目录 1. 项目需求 2. 网络选择 2.1 UNet模块 2.2 TransUnet 2.2.1 SE模块 2.2.2 CBAM 2.3 关键代码 3 对比试验 3.1 unet 3.2 transformerSE 3.3 transformerCBAM 4. 结果分析 5. 推理 6. 下载 1. 项目需求 本文需要做的工作是基于CNN和Transformer的心脏左心室…...

云上 Redis 迁移至本地机房

文章目录 摘要在 IDC 搭建读写分离 redis 集群一、环境准备二、部署主从架构1. 安装Redis2. 配置主节点3. 配置从节点4. 所有 Redis 节点设置开机自启动三、部署代理层(读写分离)1. 安装Twemproxy2. 配置Twemproxy3. 配置开机自启动四、高可用配置(哨兵模式)1. 配置哨兵节点…...

zabbix数据库溯源

0x00 背景 zabbix数据库如果密码泄露被登录并新增管理员如何快速发现&#xff1f;并进行溯源&#xff1f; 本文介绍数据库本身未开启access log的情况。 0x01 实践 Mysql 数据库查insert SELECT * FROM sys.host_summary_by_statement_type where statement like %insert% 查…...

ZYNQ的cache原理与一致性操作

在Xilinx Zynq SoC中&#xff0c;Cache管理是确保处理器与外部设备&#xff08;如FPGA逻辑、DMA控制器&#xff09;之间数据一致性的关键。Zynq的ARM Cortex-A9处理器包含L1 Cache&#xff08;指令/数据&#xff09;和L2 Cache&#xff0c;其刷新&#xff08;Flush/Invalidate&…...

React 中useMemo和useCallback Hook 的作用,在什么场景下使用它们?

大白话React 中useMemo和useCallback Hook 的作用&#xff0c;在什么场景下使用它们&#xff1f; 在 React 里&#xff0c;useMemo 和 useCallback 这两个 Hook 可有用啦&#xff0c;能帮咱优化组件性能&#xff0c;避免不必要的计算和渲染。下面咱就来详细聊聊它们的作用和使…...

Android笔记之项目引用第三方库(如:Github等)

前言&#xff1a;原生Android开发时引用github上的仓库内容&#xff0c;故出此文。 方式一&#xff1a;使用 JitPack&#xff08;推荐&#xff09; 步骤 1&#xff1a;在项目的 build.gradle 文件中添加 JitPack 仓库 打开项目根目录下的 build.gradle 文件&#xff0c;在 a…...

Linux 系统性能优化高级全流程指南

Linux 系统性能优化高级全流程指南 一、系统基础状态捕获 1. 系统信息建档 除了原有的硬件、内核和存储拓扑信息收集&#xff0c;还增加 CPU 缓存、网络设备详细信息等。 # 硬件信息 lscpu > /opt/tuning/lscpu.origin dmidecode -t memory > /opt/tuning/meminfo.or…...

SQL Server——表数据的插入、修改和删除

目录 一、引言 二、表数据的插入、修改和删除 &#xff08;一&#xff09;方法一&#xff1a;在SSMS控制台上进行操作 1.向表中添加数据 2.对表中的数据进行修改 3.对表中的数据进行删除 &#xff08;二&#xff09;方法二&#xff1a;使用 SQL 代码进行操作 1.向表中添…...

WPF 布局中的共性尺寸组(Shared Size Group)

1. 什么是共性尺寸组&#xff1f; 在 WPF 的 Grid 布局中&#xff0c;SharedSizeGroup 允许多个 Grid 共享同一列或行的尺寸&#xff0c;即使它们属于不同的 Grid 也能保持大小一致。这样可以保证界面元素的对齐性&#xff0c;提高布局的一致性。 SharedSizeGroup 主要用于需…...

deepSeek-SSE流式推送数据

1、背景 DeepSeek作为当前最火的AI大模型&#xff0c; 使用的时候用户在输入框输入问题&#xff0c;大模型进行思考回答你&#xff0c;然后会有一个逐步显示的过程效果&#xff0c;而不是一次性返回整个答案给前端页面进行展示&#xff0c;为了搞清楚其中的原理&#xff0c;我们…...

【北京迅为】iTOP-RK3568开发板OpenHarmony系统南向驱动开发UART接口运作机制

瑞芯微RK3568芯片是一款定位中高端的通用型SOC&#xff0c;采用22nm制程工艺&#xff0c;搭载一颗四核Cortex-A55处理器和Mali G52 2EE 图形处理器。RK3568 支持4K 解码和 1080P 编码&#xff0c;支持SATA/PCIE/USB3.0 外围接口。RK3568内置独立NPU&#xff0c;可用于轻量级人工…...

Leetcode 3495. Minimum Operations to Make Array Elements Zero

Leetcode 3495. Minimum Operations to Make Array Elements Zero 1. 解题思路2. 代码实现 题目链接&#xff1a;3495. Minimum Operations to Make Array Elements Zero 1. 解题思路 这一题的话核心就是统计对任意自然数 n n n&#xff0c;从 1 1 1到 n n n当中所有的数字对…...

C#实现自己的Json解析器(LALR(1)+miniDFA)

C#实现自己的Json解析器(LALR(1)miniDFA) Json是一个用处广泛、文法简单的数据格式。本文介绍如何用bitParser&#xff08;拥有自己的解析器&#xff08;C#实现LALR(1)语法解析器和miniDFA词法分析器的生成器&#xff09;迅速实现一个简单高效的Json解析器。 读者可在&#xf…...

机器学习——KNN数据均一化

在KNN&#xff08;K-近邻&#xff09;算法中&#xff0c;数据均一化&#xff08;归一化&#xff09;是预处理的关键步骤&#xff0c;用于消除不同特征量纲差异对距离计算的影响。以下是两种常用的归一化操作及其核心要点&#xff1a; 质押 一 、主要思想 1. 最值归一化&#…...

异步编程与流水线架构:从理论到高并发

目录 一、异步编程核心机制解析 1.1 同步与异步的本质区别 1.1.1 控制流模型 1.1.2 资源利用对比 1.2 阻塞与非阻塞的技术实现 1.2.1 阻塞I/O模型 1.2.2 非阻塞I/O模型 1.3 异步编程关键技术 1.3.1 事件循环机制 1.3.2 Future/Promise模式 1.3.3 协程&#xff08;Cor…...

哈尔滨工业大学DeepSeek公开课人工智能:大模型原理 技术与应用-从GPT到DeepSeek|附视频下载方法

导 读INTRODUCTION 今天继续哈尔滨工业大学车万翔教授带来了一场主题为“DeepSeek 技术前沿与应用”的报告。 本报告深入探讨了大语言模型在自然语言处理&#xff08;NLP&#xff09;领域的核心地位及其发展历程&#xff0c;从基础概念出发&#xff0c;延伸至语言模型在机器翻…...

制作Oracle11g Docker 镜像

基于Linux系统&#xff0c;宿主主机要设置如下环境变量&#xff0c;oracle为64位版本 dockerfile中需要的数据库安装包可从csdn下载内找到 #!/bin/bash # 在宿主机上运行以设置Oracle所需的内核参数 # 这些命令需要root权限cat > /etc/sysctl.d/99-oracle.conf << EO…...

Excel处理控件Spire.XLS系列教程:C# 在 Excel 中添加或删除单元格边框

单元格边框是指在单元格或单元格区域周围添加的线条。它们可用于不同的目的&#xff0c;如分隔工作表中的部分、吸引读者注意重要的单元格或使工作表看起来更美观。本文将介绍如何使用 Spire.XLS for .NET 在 C# 中添加或删除 Excel 单元格边框。 安装 Spire.XLS for .NET E-…...