当前位置: 首页 > news >正文

(3)深度学习学习笔记-简单线性模型

文章目录

  • 一、线性模型
  • 二、实例
    • 1.pytorch求导功能
    • 2.简单线性模型(人工数据集)
  • 来源


一、线性模型

一个简单模型:假设一个房子的价格由卧室、卫生间、居住面积决定,用x1,x2,x3表示。
那么房价y就可以认为y=w1x1+w2x2+w3x3+b,w为权重,b为偏差。
第一步
在这里插入图片描述
线性模型可以看做是单层(带权重的层是1层)神经网络。
在这里插入图片描述
第二步:
定义loss,衡量预估质量:真实值和预测值的差距
在这里插入图片描述
这里带1/2是方便求导的时候把2消去。
训练数据:收集数据来决定权重和偏差
训练损失:loss=1/n∑[(真实值-预测值(xi和权重的内积-偏差))平方]。目标是找到最小的loss
在这里插入图片描述
第三步:优化
优化方法:梯度下降。先挑选一个初值w0,之后不断更新w0使他接近最优解。更新方法是wt=wt-1 - 学习速率
梯度。
在这里插入图片描述
Learning rate不能太小(到达一个点要走很多步),也不能太大(一直震荡没有真的下降)
在这里插入图片描述
在这里插入图片描述
在整个训练集上梯度下降太贵,跑一次模型可能要数分钟/小时。所以采用小批量随机梯度下降,随机采样b个样本用这b个样本来近似损失。b不能太大也不能太小
在这里插入图片描述

二、实例

1.pytorch求导功能

代码如下:

# 自动求导
import torch
# 假设对函数y=2xT x关于列向量求导
x = torch.arange(4.0)
# 算y关于x的梯度之前,需要一个地方来存储梯度
x.requires_grad_(True)  # 等价于x=torch.arange(4.0,requires_grad=True)
print(x.grad)  # 默认值是None y关于x的导数存在这里y=2*torch.dot(x,x)
y.backward() # 求导
print(x.grad)
print(x.grad==4*x)# 在默认情况下,PyTorch会累积梯度,需要清除之前的值
x.grad.zero_()y = x.sum()
y.backward()
print(x.grad)x.grad.zero_()
y=x*x
u=y.detach()# 把y当成常数而不是x的函数
z=u*x
z.sum().backward()
print(x.grad==u)

2.简单线性模型(人工数据集)

代码如下:

# 构建人工数据集(好处是知道w和b)
# 根据w=[2,-3.4] b=4.2 和噪声生成数据集和标签 y=Xw+b+噪声
import numpy as np
import torch
from torch import nn
from torch.utils import data# 生成数据
def synthetic_data(w, b, num_examples):"""生成y=Xw+b+噪声"""X = np.random.normal(0, 1, (num_examples, len(w)))  # 均值为0,方差为1,num_ex个样本,列数=w的个数y = np.dot(X, w) + b  # y=Xw+by += np.random.normal(0, 0.01, y.shape)  # 加上随机噪音x1 = torch.tensor(X, dtype=torch.float32)  # 把np转化为torchy1 = torch.tensor(y, dtype=torch.float32)return x1, y1.reshape((-1, 1))  # 列向量反馈# 读取数据
def load_array(data_arrays, batch_size, is_train=True):"""构造一个PyTorch数据迭代器"""dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)  # shuffle:是否需要随机打乱true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)batch_size = 10
data_iter = load_array((features, labels), batch_size)
# 使用iter构造Python迭代器,并使用next从迭代器中获取第一项。
print(next(iter(data_iter)))# 定义模型
# nn是神经网络的缩写
net = nn.Sequential(nn.Linear(2, 1))  # 输入维度是2 输出是1 sequential相当于一个list of layer# 初始化参数 net[0]访问这个layer
net[0].weight.data.normal_(0, 0.01)  # normal:使用正态分布替换weight的值,均值为0,方差为0.01
net[0].bias.data.fill_(0)  # bias直接设为0# 定义loss:mseloss类
loss = nn.MSELoss()# 定义优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03)  # net.patameters():所有参数  ,lr:learning rate# 训练
num_epochs = 3
for epoch in range(num_epochs):  # 对所有数据扫一遍for X, y in data_iter:  # 拿出一个批量大小的x和yl = loss(net(X), y)  # x和y的小批量损失trainer.zero_grad()  # 梯度清零l.backward()  # 计算梯度trainer.step()  # 模型更新l = loss(net(features), labels)  # 计算损失print(f'epoch {epoch + 1}, loss {l:f}')

来源

b站 跟李沐学AI 动手学深度学习v2 08

相关文章:

(3)深度学习学习笔记-简单线性模型

文章目录 一、线性模型二、实例1.pytorch求导功能2.简单线性模型(人工数据集) 来源 一、线性模型 一个简单模型:假设一个房子的价格由卧室、卫生间、居住面积决定,用x1,x2,x3表示。 那么房价y就可以认为yw…...

pytorch3d 安装报错 RuntimeError: Not compiled with GPU support pytorch3d

安装环境 NVIDIA GeForce RTX 3090 cuda 11.3 python 3.8.5 torch 1.11.0 torchvision 0.12.0 环境安装命令 conda install pytorch1.11.0 torchvision0.12.0 torchaudio0.11.0 cudatoolkit11.3 -c pytorch安装pytorch3d参考官网链接 https://github.com/facebookresearch/p…...

spring工程的启动流程?bean的生命周期?提供哪些扩展点?管理事务?解决循环依赖问题的?事务传播行为有哪些?

1.Spring工程的启动流程: Spring工程的启动流程主要包括以下几个步骤: 加载配置文件:Spring会读取配置文件(如XML配置文件或注解配置)来获取应用程序的配置信息。实例化并初始化IoC容器:Spring会创建并初…...

使用 Zabbix 监控 RocketMQ列举监控项和触发器

在使用 Zabbix 监控 RocketMQ 的过程中,以下是一些可能的监控项和触发器: 监控项 集群总体健康状况生产者和消费者的连接数量Broker 的状态消息的生产和消费速度队列深度(即队列中的消息数量)磁盘空间使用内存使用CPU使用网络流…...

uniApp:路由与页面跳转及传参

方式一:声明式导航 声明式导航,通过组件进行跳转。官方文档:详情 使用 navigator 组件进行页面跳转。 属性类型默认值说明urlString应用内的跳转链接,值为相对路径或绝对路径,如:“…/first/first”&#x…...

Java中操作文件(二)

目录 一、什么是数据流 二、InputStream概述 2.1、方法 2.2、说明 三、FileInputStream概述 3.1、构造方法 3.2、利用Scanner进行字符串读取,简化操作 四、OutputStream概述 4.1、方法 4.2、PrinterWriter简化写操作 五、小程序练习 示例1 示例…...

springboot+vue在线考试系统(java项目源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的在线考试系统。项目源码以及部署相关请联系风歌,文末附上联系信息 。 💕💕作者:风歌&a…...

样式方案:在 Vite 中接入现代化的 CSS 工程化方案

上一小节,我们使用 Vite 初始化了一个 Web 项目,迈出了使用 Vite 的第一步。但在实际工作中,仅用 Vite 官方的脚手架项目是不够的,往往还需要考虑诸多的工程化因素,借助 Vite 本身的配置以及业界的各种生态&#xff0c…...

C#获取根目录实现方法汇总

以下是C#获取不同类型项目根目录的实现方法汇总,以及在 .NET Core 中获取项目根目录的方法: 控制台应用程序 string rootPath Environment.CurrentDirectory; string rootPath AppDomain.CurrentDomain.BaseDirectory; string rootPath Path.GetFul…...

vue获取当前坐标并通过天地图逆转码为省市区

因为需求需要获取用户当前的地理位置用于分析 通过原生的navigator.geolocation.getCurrentPosition获取经纬度 这个方法在谷歌浏览器会失效(原因未知),目前ie浏览器是可以获取的 getCurrentPosition() {if (navigator.geolocation) {var o…...

【MySQL】事务及其隔离性/隔离级别

目录 一、事务的概念 1、事务的四种特性 2、事务的作用 3、存储引擎对事务的支持 4、事务的提交方式 二、事务的启动、回滚与提交 1、准备工作:调整MySQL的默认隔离级别为最低/创建测试表 2、事务的启动、回滚与提交 3、启动事务后未commit,但是…...

计算机由于找不到d3dx9_35.dll,无法启动软件游戏的三个修复方法

在打开游戏的时候,计算机提示由于找不到d3dx9_35.dll,无法正常启动运行。这个是为什么呢?d3dx9_35.dll是DirectX 9.0里面的一个动态连结库文件,它包含了Direct3D、DirectPlay几个组件的二进制文件,为软件提供了多媒体图…...

第三章 模型篇:模型与模型的搭建

写在前面的话 这部分只解释代码,不对线性层(全连接层),卷积层等layer的原理进行解释。 尽量写的比较全了,但是自身水平有限,不太确定是否有遗漏重要的部分。 教程参考: https://pytorch.org/tutorials/ https://githu…...

深度学习一些简单概念的整理笔记

大概看了一点动手学深度学习,简单整理一些概念。 一些问题 测试结果 Precision-Recall曲线定性分析模型精度average precision(AP) 平均精度 Precision :检索出来的条目中有多大比例是我们需要的。 一些概念 损失函数(loss function&…...

Vue3中引入Element-plus

安装 npm install element-plus --save完整引入 打包后体积很大,适合学习,不适合生产。 此方法对于 vite 和 cli 脚手架创建的vue3均适用 // main.ts import { createApp } from vue //引入element-plus import ElementPlus from element-plus //引入…...

如何查看 Facebook 公共主页的广告数量上限?

作为Facebook的资深人员,了解如何查看公共主页的广告数量上限对于有效管理和优化广告策略至关重要。本文将详细介绍如何轻松查看Facebook公共主页的广告数量上限,以帮助您更好地掌握广告投放策略。 一、什么是Facebook公共主页的广告数量上限&#xff1f…...

U-Boot移植 (2)- LCD 驱动修改和网络驱动修改

文章目录 1. LCD 驱动修改1.1 修改c文件配置1.2 修改h文件配置1.3 编译测试 2. 网络驱动修改2.1 I.MX6U-ALPHA 开发板网络简介2.2 网络 PHY 地址修改2.3 删除 uboot 中 74LV595 的驱动代码2.4 添加开发板网络复位引脚驱动2.5 更新 PHY 的连接状态和速度2.6 烧写调试2.7 测试一下…...

Ubuntu 23.10 现在由Linux内核6.3提供支持

对于那些希望在Ubuntu上尝试最新的Linux 6.3内核系列的人来说,今天有一个好消息,因为即将发布的Ubuntu 23.10(Mantic Minotaur)已经重新基于Linux内核6.3。 Ubuntu 23.10的开发工作于4月底开始,基于目前的临时版本Ubu…...

Python 学习之NumPy(一)

文章目录 1.为什么要学习NumPy2.NumPy的数组变换以及索引访问3.NumPy筛选使用介绍筛选出上面nb数组中能被3整除的所有数筛选出数组中小于9的所有数提取出数组中所有的奇数数组中所有的奇数替换为-1二维数组交换2列生成数值5—10,shape 为(3,5)的二维随机浮点数 NumP…...

Nftables栈溢出漏洞(CVE-2022-1015)复现

背景介绍 Nftables Nftables 是一个基于内核的包过滤框架,用于 Linux 操作系统中的网络安全和防火墙功能。nftables 的设计目标是提供一种更简单、更灵活和更高效的方式来管理网络数据包的流量。 钩子点(Hook Point) 钩子点的作用是拦截数…...

【C++】 Qt-事件(上)(事件、重写事件、事件分发)

文章目录 事件重写事件事件分发 事件 事件(event)是由系统或Qt本身在不同的时刻发出的。比如,当用户按下鼠标,敲下键盘,或窗口需要重新绘制的时候,都会发出一个相应的事件。一些事件是在对用户操作做出响应…...

k8s部署springboot

前言 首先以SpringBoot应用为例介绍一下k8s的部署步骤。 1.从代码仓库下载代码,比如GitLab; 2.接着是进行打包,比如使用Maven; 3.编写Dockerfile文件,把步骤2产生的包制作成镜像; 4.上传步骤3的镜像到远程…...

备战秋招002(20230704)

文章目录 前言一、今天学习了什么?二、关于问题的答案1.线程池2.synchronized关键字3、volatile 总结 前言 提示:这里为每天自己的学习内容心情总结; Learn By Doing,Now or Never,Writing is organized thinking. …...

游泳买耳机买什么的比较好,列举几款实战性好的游泳耳机

对于运动用户来说,在运动时都会选择听一些节奏感比较强的音乐,让自己运动是更有活力。现在已经是三伏天中的前伏期间,不少人会选择在三伏天的日子里进行减肥瘦身,耳游泳已经成为很多人都首选运动,游泳是非常好的有氧运…...

【无线传感器】使用 MATLAB和 XBee连续监控温度传感器无线网络研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

Java基础-多线程JUC-生产者和消费者

1. 生产者与消费者 实现线程轮流交替执行的结果; 实现线程休眠和唤醒均要使用到锁对象; 修改标注位(foodFlag); 代码实现: public class demo11 {public static void main(String[] args) {/*** 需求&#…...

day2 QT按钮与容器

目录 按钮 1、QPushButton 2、QToolButton 3、QRadioButton 4、QCheckBox 示例 容器 ​编辑 1. QGroupBox(分组框) 2. QScrollArea(滚动区域) 3. QToolBox(工具箱) 4. QTabWidget(选…...

JPA 批量插入较大数据 解决性能慢问题

JPA 批量插入较大数据 解决性能慢问题 使用jpa saveAll接口的话需要了解原理&#xff1a; TransactionalOverridepublic <S extends T> List<S> saveAll(Iterable<S> entities) {Assert.notNull(entities, "Entities must not be null!");List<…...

为啥离不了 linux

Linux与Windows都是十分常见的电脑操作系统&#xff0c;相信你对它们二者都有所了解&#xff01;在你的使用过程中&#xff0c;是否有什么事让你觉得在Linux上顺理成章&#xff0c;换到Windows上就令你费解&#xff1f;亦或者关于这二者你有任何想要分享的&#xff0c;都可以在…...

基于分形的置乱算法和基于混沌系统的置乱算法哪种更安全?

在信息安全领域中&#xff0c;置乱算法是一种重要的加密手段&#xff0c;它可以将明文进行混淆和打乱&#xff0c;从而实现保密性和安全性。常见的置乱算法包括基于分形的置乱算法和基于混沌系统的置乱算法。下面将从理论和实践两方面&#xff0c;对这两种置乱算法进行比较和分…...