当前位置: 首页 > news >正文

动手学深度学习(Pytorch版)代码实践 -深度学习基础-02线性回归基础版

02线性回归基础版

主要内容

  1. 数据生成:使用线性模型 ( y = X*w + b ) 加上噪声生成人造数据集。
  2. 数据读取:通过小批量读取数据集来实现批量梯度下降,打乱数据顺序并逐批返回特征和标签。
  3. 模型参数初始化:随机初始化权重和偏置,并设置为可计算梯度。
  4. 模型定义:实现线性回归模型 ( y = X*w + b )。
  5. 损失函数:实现均方误差损失函数。
  6. 优化函数:实现小批量随机梯度下降用于更新模型参数。
  7. 模型训练:设定学习率和迭代次数,通过每个批量计算损失、反向传播和参数更新。
import random
import torch# 生成数据集
def synthetic_data(w, b, num_examples):"""生成 y = Xw + b + 噪声"""# torch.normal: 返回一个从均值为0,标准差为1的正态分布中提取的随机数的张量# 生成形状为(num_examples, len(w))的矩阵X = torch.normal(0, 1, (num_examples, len(w)))# torch.matmul: 矩阵乘法y = torch.matmul(X, w) + b# 添加噪声:torch.normal(0, 0.01, y.shape)y += torch.normal(0, 0.01, y.shape)# reshape: 只改变张量的视图,不改变数据,将y转换为列向量return X, y.reshape((-1, 1))# 定义真实的权重和偏置
true_w = torch.tensor([2, -3.4])
true_b = 4.2
# 生成特征和标签
features, labels = synthetic_data(true_w, true_b, 1000)# 读取数据集
def data_iter(batch_size, features, labels):num_examples = len(features)# 生成一个从0到num_examples-1的整数列表indices = list(range(num_examples))# 将列表的次序打乱random.shuffle(indices)# 每次迭代生成一个小批量数据for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]# 设置批量大小
batch_size = 10# 初始化模型参数 
# 随机初始化权重,设置requires_grad=True以计算梯度
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 初始化偏置为0,设置requires_grad=True以计算梯度
b = torch.zeros(1, requires_grad=True)  # 定义模型
def linreg(X, w, b):"""线性回归模型"""return torch.matmul(X, w) + b# 定义损失函数
def squared_loss(y_hat, y):"""均方损失函数"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2# 定义优化函数
def sgd(params, lr, batch_size):"""小批量随机梯度下降"""# 更新参数时不需要计算梯度with torch.no_grad():for param in params:param -= lr * param.grad / batch_size  # 参数更新param.grad.zero_()  # 梯度清零# 模型训练
lr = 0.03  # 学习率
num_epochs = 5  # 迭代周期数
net = linreg  # 线性回归模型
loss = squared_loss  # 损失函数# 开始训练
for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)  # 计算小批量数据的损失l.sum().backward()  # 计算梯度sgd([w, b], lr, batch_size)  # 更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)  # 计算整个数据集上的损失print(f'第{epoch + 1}轮,损失: {float(train_l.mean()):f}')# 打印权重和偏置的估计误差
print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')# 示例输出:
# 第1轮,损失: 0.036624
# 第2轮,损失: 0.000131
# 第3轮,损失: 0.000052
# 第4轮,损失: 0.000052
# 第5轮,损失: 0.000052
# w的估计误差: tensor([-0.0003, -0.0008], grad_fn=<SubBackward0>)
# b的估计误差: tensor([0.0007], grad_fn=<RsubBackward1>)

相关文章:

动手学深度学习(Pytorch版)代码实践 -深度学习基础-02线性回归基础版

02线性回归基础版 主要内容 数据生成&#xff1a;使用线性模型 ( y X*w b ) 加上噪声生成人造数据集。数据读取&#xff1a;通过小批量读取数据集来实现批量梯度下降&#xff0c;打乱数据顺序并逐批返回特征和标签。模型参数初始化&#xff1a;随机初始化权重和偏置&#x…...

信息学奥赛初赛天天练-15-阅读程序-深入解析二进制原码、反码、补码,位运算技巧,以及lowbit的神奇应用

更多资源请关注纽扣编程微信公众号 1 2021 CSP-J 阅读程序1 阅读程序&#xff08;程序输入不超过数组或字符串定义的范围&#xff1b;判断题正确填 √&#xff0c;错误填&#xff1b;除特 殊说明外&#xff0c;判断题 1.5 分&#xff0c;选择题 3 分&#xff09; 源码 #in…...

期权具体怎么交易详细的操作流程?

期权就是股票&#xff0c;唯一区别标的物上证指数&#xff0c;会看大盘吧&#xff0c;交易两个方向认购做多&#xff0c;认沽做空&#xff0c;双向t0交易&#xff0c;期权具体交易流程可以理解选择方向多和空&#xff0c;选开仓的合约&#xff0c;买入开仓和平仓没了&#xff0…...

系统架构设计师【第3章】: 信息系统基础知识 (核心总结)

文章目录 3.1 信息系统概述3.1.1 信息系统的定义3.1.2 信息系统的发展3.1.3 信息系统的分类3.1.4 信息系统的生命周期3.1.5 信息系统建设原则3.1.6 信息系统开发方法 3.2 业务处理系统&#xff08;TPS&#xff09;3.2.1 业务处理系统的概念3.2.2 业务处理系统的功能 …...

Linux 驱动设备匹配过程

一、Linux 驱动-总线-设备模型 1、驱动分层 Linux内核需要兼容多个平台&#xff0c;不同平台的寄存器设计不同导致操作方法不同&#xff0c;故内核提出分层思想&#xff0c;抽象出与硬件无关的软件层作为核心层来管理下层驱动&#xff0c;各厂商根据自己的硬件编写驱动…...

游戏子弹类python设计与实现详解

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言 二、子弹类设计思路 1. 属性定义 2. 方法设计 三、子弹类实现详解 1. 定义子弹…...

Python基础学习笔记(六)——列表

目录 一、一维列表的介绍和创建二、序列的基本操作1. 索引的查询与返回2. 切片3. 序列加 三、元素的增删改1. 添加元素2. 删除元素3. 更改元素 四、排序五、列表生成式 一、一维列表的介绍和创建 列表&#xff08;list&#xff09;&#xff0c;也称数组&#xff0c;是一种有序、…...

帝国CMS跳过选择会员类型直接注册方法

国CMS因允许多用户组注册&#xff0c;所以在注册页面会有一个选择注册用户组的界面&#xff0c;即使网站只用了一个用户组也会出现。 如果想去掉这个页面&#xff0c;直接进入注册页面&#xff0c;那么可按以下办法修改 打开 e/class/user.php 文件 查找&#xff1a; $chan…...

【python】python tkinter 计算器GUI版本(模仿windows计算器 源码)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…...

黑马es数据同步mq解决方案

方式一:同步调用 优点:实现简单&#xff0c;粗暴 缺点:业务耦合度高 方式二:异步通知 优点:低耦含&#xff0c;实现难度一般 缺点:依赖mq的可靠性 方式三:监听binlog 优点:完全解除服务间耦合 缺点:开启binlog增加数据库负担、实现复杂度高 利用MQ实现mysql与elastics…...

通过LLM多轮对话生成单元测试用例

通过LLM多轮对话生成单元测试用例 代码 在采用 随机生成pytorch算子测试序列且保证算子参数合法 这种方法之前,曾通过本文的方法生成算子组合测试用例。目前所测LLM生成的代码均会出现BUG,且多次交互后仍不能解决.也许随着LLM的更新,这个问题会得到解决.记录备用。 代码 impo…...

[Redis]String类型

基本命令 set命令 将 string 类型的 value 设置到 key 中。如果 key 之前存在&#xff0c;则覆盖&#xff0c;无论原来的数据类型是什么。之前关于此 key 的 TTL 也全部失效。 set key value [expiration EX seconds|PX milliseconds] [NX|XX] 选项[EX|PX] EX seconds⸺使用…...

Ai速递5.29

全球AI新闻速递 1.摩尔线程与无问芯穹合作&#xff0c;实现国产 GPU 端到端 AI 大模型实训。 2.宝马工厂&#xff1a;机器狗上岗&#xff0c;可“嗅探”故障隐患。 3.ChatGPT&#xff1a;macOS 开始公测。 4.Stability AI&#xff1a;推出Stable Assistant&#xff0c;可用S…...

Android9.0 MTK平台如何增加一个系统应用

在安卓定制化开发过程中&#xff0c;难免遇到要把自己的app预置到系统中&#xff0c;作为系统应用使用&#xff0c;其实方法有很多&#xff0c;过程很简单&#xff0c;今天分享一下我是怎么做的&#xff0c;共总分两步&#xff1a; 第一步&#xff1a;要找到当前系统应用apk存…...

LabVIEW中实现Trio控制器的以太网通讯

在LabVIEW中实现与Trio控制器的以太网通讯&#xff0c;可以通过使用TCP/IP协议来完成。这种方法包括配置Trio控制器的网络设置、使用LabVIEW中的TCP/IP函数库进行数据传输和接收&#xff0c;以及处理通讯中的错误和数据解析。本文将详细说明实现步骤&#xff0c;包括配置、编程…...

C/C++运行时库与 UCRT 通用运行时库:全面总结与问题实例剖析

推荐一个AI网站&#xff0c;免费使用豆包AI模型&#xff0c;快去白嫖&#x1f449;海鲸AI 1. 概述 在开发C/C应用程序时&#xff0c;运行时库&#xff08;Runtime Library&#xff09;是不可或缺的一部分。它们提供了一系列函数和功能&#xff0c;使得开发者能够更方便地进行编…...

【Python001】python批量下载、插入与读取Oracle中图片数据(已更新)

1.熟悉、梳理、总结数据分析实战中的python、oracle研发知识体系 2.欢迎点赞、关注、批评、指正,互三走起来,小手动起来! 文章目录 1.背景说明2.环境搭建2.1 参考链接2.2 `oracle`查询测试代码3.数据请求与插入3.1 `Oracle`建表语句3.2 `Python`代码实现3.3 效果示例4.问题链…...

流形学习(Manifold Learning)

基本概念 Manifold Learning&#xff08;流形学习&#xff09;是一种机器学习和数据分析的方法&#xff0c;它专注于从高维数据中发现低维的非线性结构。流形学习的基本假设是&#xff0c;尽管数据可能在高维空间中呈现&#xff0c;但它们实际上分布在一个低维的流形上。这个流…...

区块链技术和应用

文章目录 前言 一、区块链是什么&#xff1f; 二、区块链核心数据结构 2.1 交易 2.2 区块 三、交易 3.1 交易的生命周期 3.2 节点类型 3.3 分布式系统 3.4 节点数据库 3.5 智能合约 3.6 多个记账节点-去中心化 3.7 双花问题 3.8 共识算法 3.8.1 POW工作量证明 总结 前言 学习长…...

Docker拉取镜像报错:x509: certificate has expired or is not yet v..

太久没有使用docker进行镜像拉取&#xff0c;今天使用docker-compose拉取mongo发现报错&#xff08;如下图&#xff09;&#xff1a; 报错信息翻译&#xff1a;证书已过期或尚未有效。 解决办法&#xff1a; 1.一般都是证书问题或者系统时间问题导致&#xff0c;可以先执行 da…...

Vue记事本应用实现教程

文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展&#xff1a;显示创建时间8. 功能扩展&#xff1a;记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

Java如何权衡是使用无序的数组还是有序的数组

在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案

问题描述&#xff1a;iview使用table 中type: "index",分页之后 &#xff0c;索引还是从1开始&#xff0c;试过绑定后台返回数据的id, 这种方法可行&#xff0c;就是后台返回数据的每个页面id都不完全是按照从1开始的升序&#xff0c;因此百度了下&#xff0c;找到了…...

sqlserver 根据指定字符 解析拼接字符串

DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

GitFlow 工作模式(详解)

今天再学项目的过程中遇到使用gitflow模式管理代码&#xff0c;因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存&#xff0c;无论是github还是gittee&#xff0c;都是一种基于git去保存代码的形式&#xff0c;这样保存代码…...

离线语音识别方案分析

随着人工智能技术的不断发展&#xff0c;语音识别技术也得到了广泛的应用&#xff0c;从智能家居到车载系统&#xff0c;语音识别正在改变我们与设备的交互方式。尤其是离线语音识别&#xff0c;由于其在没有网络连接的情况下仍然能提供稳定、准确的语音处理能力&#xff0c;广…...

Unity中的transform.up

2025年6月8日&#xff0c;周日下午 在Unity中&#xff0c;transform.up是Transform组件的一个属性&#xff0c;表示游戏对象在世界空间中的“上”方向&#xff08;Y轴正方向&#xff09;&#xff0c;且会随对象旋转动态变化。以下是关键点解析&#xff1a; 基本定义 transfor…...

rknn toolkit2搭建和推理

安装Miniconda Miniconda - Anaconda Miniconda 选择一个 新的 版本 &#xff0c;不用和RKNN的python版本保持一致 使用 ./xxx.sh进行安装 下面配置一下载源 # 清华大学源&#xff08;最常用&#xff09; conda config --add channels https://mirrors.tuna.tsinghua.edu.cn…...

前端工具库lodash与lodash-es区别详解

lodash 和 lodash-es 是同一工具库的两个不同版本&#xff0c;核心功能完全一致&#xff0c;主要区别在于模块化格式和优化方式&#xff0c;适合不同的开发环境。以下是详细对比&#xff1a; 1. 模块化格式 lodash 使用 CommonJS 模块格式&#xff08;require/module.exports&a…...