当前位置: 首页 > news >正文

(3)深度学习学习笔记-简单线性模型

文章目录

  • 一、线性模型
  • 二、实例
    • 1.pytorch求导功能
    • 2.简单线性模型(人工数据集)
  • 来源


一、线性模型

一个简单模型:假设一个房子的价格由卧室、卫生间、居住面积决定,用x1,x2,x3表示。
那么房价y就可以认为y=w1x1+w2x2+w3x3+b,w为权重,b为偏差。
第一步
在这里插入图片描述
线性模型可以看做是单层(带权重的层是1层)神经网络。
在这里插入图片描述
第二步:
定义loss,衡量预估质量:真实值和预测值的差距
在这里插入图片描述
这里带1/2是方便求导的时候把2消去。
训练数据:收集数据来决定权重和偏差
训练损失:loss=1/n∑[(真实值-预测值(xi和权重的内积-偏差))平方]。目标是找到最小的loss
在这里插入图片描述
第三步:优化
优化方法:梯度下降。先挑选一个初值w0,之后不断更新w0使他接近最优解。更新方法是wt=wt-1 - 学习速率
梯度。
在这里插入图片描述
Learning rate不能太小(到达一个点要走很多步),也不能太大(一直震荡没有真的下降)
在这里插入图片描述
在这里插入图片描述
在整个训练集上梯度下降太贵,跑一次模型可能要数分钟/小时。所以采用小批量随机梯度下降,随机采样b个样本用这b个样本来近似损失。b不能太大也不能太小
在这里插入图片描述

二、实例

1.pytorch求导功能

代码如下:

# 自动求导
import torch
# 假设对函数y=2xT x关于列向量求导
x = torch.arange(4.0)
# 算y关于x的梯度之前,需要一个地方来存储梯度
x.requires_grad_(True)  # 等价于x=torch.arange(4.0,requires_grad=True)
print(x.grad)  # 默认值是None y关于x的导数存在这里y=2*torch.dot(x,x)
y.backward() # 求导
print(x.grad)
print(x.grad==4*x)# 在默认情况下,PyTorch会累积梯度,需要清除之前的值
x.grad.zero_()y = x.sum()
y.backward()
print(x.grad)x.grad.zero_()
y=x*x
u=y.detach()# 把y当成常数而不是x的函数
z=u*x
z.sum().backward()
print(x.grad==u)

2.简单线性模型(人工数据集)

代码如下:

# 构建人工数据集(好处是知道w和b)
# 根据w=[2,-3.4] b=4.2 和噪声生成数据集和标签 y=Xw+b+噪声
import numpy as np
import torch
from torch import nn
from torch.utils import data# 生成数据
def synthetic_data(w, b, num_examples):"""生成y=Xw+b+噪声"""X = np.random.normal(0, 1, (num_examples, len(w)))  # 均值为0,方差为1,num_ex个样本,列数=w的个数y = np.dot(X, w) + b  # y=Xw+by += np.random.normal(0, 0.01, y.shape)  # 加上随机噪音x1 = torch.tensor(X, dtype=torch.float32)  # 把np转化为torchy1 = torch.tensor(y, dtype=torch.float32)return x1, y1.reshape((-1, 1))  # 列向量反馈# 读取数据
def load_array(data_arrays, batch_size, is_train=True):"""构造一个PyTorch数据迭代器"""dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)  # shuffle:是否需要随机打乱true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)batch_size = 10
data_iter = load_array((features, labels), batch_size)
# 使用iter构造Python迭代器,并使用next从迭代器中获取第一项。
print(next(iter(data_iter)))# 定义模型
# nn是神经网络的缩写
net = nn.Sequential(nn.Linear(2, 1))  # 输入维度是2 输出是1 sequential相当于一个list of layer# 初始化参数 net[0]访问这个layer
net[0].weight.data.normal_(0, 0.01)  # normal:使用正态分布替换weight的值,均值为0,方差为0.01
net[0].bias.data.fill_(0)  # bias直接设为0# 定义loss:mseloss类
loss = nn.MSELoss()# 定义优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03)  # net.patameters():所有参数  ,lr:learning rate# 训练
num_epochs = 3
for epoch in range(num_epochs):  # 对所有数据扫一遍for X, y in data_iter:  # 拿出一个批量大小的x和yl = loss(net(X), y)  # x和y的小批量损失trainer.zero_grad()  # 梯度清零l.backward()  # 计算梯度trainer.step()  # 模型更新l = loss(net(features), labels)  # 计算损失print(f'epoch {epoch + 1}, loss {l:f}')

来源

b站 跟李沐学AI 动手学深度学习v2 08

相关文章:

(3)深度学习学习笔记-简单线性模型

文章目录 一、线性模型二、实例1.pytorch求导功能2.简单线性模型(人工数据集) 来源 一、线性模型 一个简单模型:假设一个房子的价格由卧室、卫生间、居住面积决定,用x1,x2,x3表示。 那么房价y就可以认为yw…...

pytorch3d 安装报错 RuntimeError: Not compiled with GPU support pytorch3d

安装环境 NVIDIA GeForce RTX 3090 cuda 11.3 python 3.8.5 torch 1.11.0 torchvision 0.12.0 环境安装命令 conda install pytorch1.11.0 torchvision0.12.0 torchaudio0.11.0 cudatoolkit11.3 -c pytorch安装pytorch3d参考官网链接 https://github.com/facebookresearch/p…...

spring工程的启动流程?bean的生命周期?提供哪些扩展点?管理事务?解决循环依赖问题的?事务传播行为有哪些?

1.Spring工程的启动流程: Spring工程的启动流程主要包括以下几个步骤: 加载配置文件:Spring会读取配置文件(如XML配置文件或注解配置)来获取应用程序的配置信息。实例化并初始化IoC容器:Spring会创建并初…...

使用 Zabbix 监控 RocketMQ列举监控项和触发器

在使用 Zabbix 监控 RocketMQ 的过程中,以下是一些可能的监控项和触发器: 监控项 集群总体健康状况生产者和消费者的连接数量Broker 的状态消息的生产和消费速度队列深度(即队列中的消息数量)磁盘空间使用内存使用CPU使用网络流…...

uniApp:路由与页面跳转及传参

方式一:声明式导航 声明式导航,通过组件进行跳转。官方文档:详情 使用 navigator 组件进行页面跳转。 属性类型默认值说明urlString应用内的跳转链接,值为相对路径或绝对路径,如:“…/first/first”&#x…...

Java中操作文件(二)

目录 一、什么是数据流 二、InputStream概述 2.1、方法 2.2、说明 三、FileInputStream概述 3.1、构造方法 3.2、利用Scanner进行字符串读取,简化操作 四、OutputStream概述 4.1、方法 4.2、PrinterWriter简化写操作 五、小程序练习 示例1 示例…...

springboot+vue在线考试系统(java项目源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的在线考试系统。项目源码以及部署相关请联系风歌,文末附上联系信息 。 💕💕作者:风歌&a…...

样式方案:在 Vite 中接入现代化的 CSS 工程化方案

上一小节,我们使用 Vite 初始化了一个 Web 项目,迈出了使用 Vite 的第一步。但在实际工作中,仅用 Vite 官方的脚手架项目是不够的,往往还需要考虑诸多的工程化因素,借助 Vite 本身的配置以及业界的各种生态&#xff0c…...

C#获取根目录实现方法汇总

以下是C#获取不同类型项目根目录的实现方法汇总,以及在 .NET Core 中获取项目根目录的方法: 控制台应用程序 string rootPath Environment.CurrentDirectory; string rootPath AppDomain.CurrentDomain.BaseDirectory; string rootPath Path.GetFul…...

vue获取当前坐标并通过天地图逆转码为省市区

因为需求需要获取用户当前的地理位置用于分析 通过原生的navigator.geolocation.getCurrentPosition获取经纬度 这个方法在谷歌浏览器会失效(原因未知),目前ie浏览器是可以获取的 getCurrentPosition() {if (navigator.geolocation) {var o…...

【MySQL】事务及其隔离性/隔离级别

目录 一、事务的概念 1、事务的四种特性 2、事务的作用 3、存储引擎对事务的支持 4、事务的提交方式 二、事务的启动、回滚与提交 1、准备工作:调整MySQL的默认隔离级别为最低/创建测试表 2、事务的启动、回滚与提交 3、启动事务后未commit,但是…...

计算机由于找不到d3dx9_35.dll,无法启动软件游戏的三个修复方法

在打开游戏的时候,计算机提示由于找不到d3dx9_35.dll,无法正常启动运行。这个是为什么呢?d3dx9_35.dll是DirectX 9.0里面的一个动态连结库文件,它包含了Direct3D、DirectPlay几个组件的二进制文件,为软件提供了多媒体图…...

第三章 模型篇:模型与模型的搭建

写在前面的话 这部分只解释代码,不对线性层(全连接层),卷积层等layer的原理进行解释。 尽量写的比较全了,但是自身水平有限,不太确定是否有遗漏重要的部分。 教程参考: https://pytorch.org/tutorials/ https://githu…...

深度学习一些简单概念的整理笔记

大概看了一点动手学深度学习,简单整理一些概念。 一些问题 测试结果 Precision-Recall曲线定性分析模型精度average precision(AP) 平均精度 Precision :检索出来的条目中有多大比例是我们需要的。 一些概念 损失函数(loss function&…...

Vue3中引入Element-plus

安装 npm install element-plus --save完整引入 打包后体积很大,适合学习,不适合生产。 此方法对于 vite 和 cli 脚手架创建的vue3均适用 // main.ts import { createApp } from vue //引入element-plus import ElementPlus from element-plus //引入…...

如何查看 Facebook 公共主页的广告数量上限?

作为Facebook的资深人员,了解如何查看公共主页的广告数量上限对于有效管理和优化广告策略至关重要。本文将详细介绍如何轻松查看Facebook公共主页的广告数量上限,以帮助您更好地掌握广告投放策略。 一、什么是Facebook公共主页的广告数量上限&#xff1f…...

U-Boot移植 (2)- LCD 驱动修改和网络驱动修改

文章目录 1. LCD 驱动修改1.1 修改c文件配置1.2 修改h文件配置1.3 编译测试 2. 网络驱动修改2.1 I.MX6U-ALPHA 开发板网络简介2.2 网络 PHY 地址修改2.3 删除 uboot 中 74LV595 的驱动代码2.4 添加开发板网络复位引脚驱动2.5 更新 PHY 的连接状态和速度2.6 烧写调试2.7 测试一下…...

Ubuntu 23.10 现在由Linux内核6.3提供支持

对于那些希望在Ubuntu上尝试最新的Linux 6.3内核系列的人来说,今天有一个好消息,因为即将发布的Ubuntu 23.10(Mantic Minotaur)已经重新基于Linux内核6.3。 Ubuntu 23.10的开发工作于4月底开始,基于目前的临时版本Ubu…...

Python 学习之NumPy(一)

文章目录 1.为什么要学习NumPy2.NumPy的数组变换以及索引访问3.NumPy筛选使用介绍筛选出上面nb数组中能被3整除的所有数筛选出数组中小于9的所有数提取出数组中所有的奇数数组中所有的奇数替换为-1二维数组交换2列生成数值5—10,shape 为(3,5)的二维随机浮点数 NumP…...

Nftables栈溢出漏洞(CVE-2022-1015)复现

背景介绍 Nftables Nftables 是一个基于内核的包过滤框架,用于 Linux 操作系统中的网络安全和防火墙功能。nftables 的设计目标是提供一种更简单、更灵活和更高效的方式来管理网络数据包的流量。 钩子点(Hook Point) 钩子点的作用是拦截数…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,百货中心供应链管理系统被用户普遍使用,为方…...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框,很难让人不联想到SQL注入,但提示都说了不是SQL注入,所以就不往这方面想了 ​ 先查看一下网页源码,发现一段JavaScript代码,有一个关键类ctfs…...

【Oracle】分区表

个人主页:Guiat 归属专栏:Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合

在汽车智能化的汹涌浪潮中,车辆不再仅仅是传统的交通工具,而是逐步演变为高度智能的移动终端。这一转变的核心支撑,来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒(T-Box)方案:NXP S32K146 与…...

Kafka入门-生产者

生产者 生产者发送流程: 延迟时间为0ms时,也就意味着每当有数据就会直接发送 异步发送API 异步发送和同步发送的不同在于:异步发送不需要等待结果,同步发送必须等待结果才能进行下一步发送。 普通异步发送 首先导入所需的k…...

热烈祝贺埃文科技正式加入可信数据空间发展联盟

2025年4月29日,在福州举办的第八届数字中国建设峰会“可信数据空间分论坛”上,可信数据空间发展联盟正式宣告成立。国家数据局党组书记、局长刘烈宏出席并致辞,强调该联盟是推进全国一体化数据市场建设的关键抓手。 郑州埃文科技有限公司&am…...

前端开发者常用网站

Can I use网站:一个查询网页技术兼容性的网站 一个查询网页技术兼容性的网站Can I use:Can I use... Support tables for HTML5, CSS3, etc (查询浏览器对HTML5的支持情况) 权威网站:MDN JavaScript权威网站:JavaScript | MDN...

数据结构第5章:树和二叉树完全指南(自整理详细图文笔记)

名人说:莫道桑榆晚,为霞尚满天。——刘禹锡(刘梦得,诗豪) 原创笔记:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 上一篇:《数据结构第4章 数组和广义表》…...

命令行关闭Windows防火墙

命令行关闭Windows防火墙 引言一、防火墙:被低估的"智能安检员"二、优先尝试!90%问题无需关闭防火墙方案1:程序白名单(解决软件误拦截)方案2:开放特定端口(解决网游/开发端口不通)三、命令行极速关闭方案方法一:PowerShell(推荐Win10/11)​方法二:CMD命令…...