动手学深度学习(Pytorch版)代码实践 -深度学习基础-02线性回归基础版
02线性回归基础版
主要内容
- 数据生成:使用线性模型 ( y = X*w + b ) 加上噪声生成人造数据集。
- 数据读取:通过小批量读取数据集来实现批量梯度下降,打乱数据顺序并逐批返回特征和标签。
- 模型参数初始化:随机初始化权重和偏置,并设置为可计算梯度。
- 模型定义:实现线性回归模型 ( y = X*w + b )。
- 损失函数:实现均方误差损失函数。
- 优化函数:实现小批量随机梯度下降用于更新模型参数。
- 模型训练:设定学习率和迭代次数,通过每个批量计算损失、反向传播和参数更新。
import random
import torch# 生成数据集
def synthetic_data(w, b, num_examples):"""生成 y = Xw + b + 噪声"""# torch.normal: 返回一个从均值为0,标准差为1的正态分布中提取的随机数的张量# 生成形状为(num_examples, len(w))的矩阵X = torch.normal(0, 1, (num_examples, len(w)))# torch.matmul: 矩阵乘法y = torch.matmul(X, w) + b# 添加噪声:torch.normal(0, 0.01, y.shape)y += torch.normal(0, 0.01, y.shape)# reshape: 只改变张量的视图,不改变数据,将y转换为列向量return X, y.reshape((-1, 1))# 定义真实的权重和偏置
true_w = torch.tensor([2, -3.4])
true_b = 4.2
# 生成特征和标签
features, labels = synthetic_data(true_w, true_b, 1000)# 读取数据集
def data_iter(batch_size, features, labels):num_examples = len(features)# 生成一个从0到num_examples-1的整数列表indices = list(range(num_examples))# 将列表的次序打乱random.shuffle(indices)# 每次迭代生成一个小批量数据for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]# 设置批量大小
batch_size = 10# 初始化模型参数
# 随机初始化权重,设置requires_grad=True以计算梯度
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 初始化偏置为0,设置requires_grad=True以计算梯度
b = torch.zeros(1, requires_grad=True) # 定义模型
def linreg(X, w, b):"""线性回归模型"""return torch.matmul(X, w) + b# 定义损失函数
def squared_loss(y_hat, y):"""均方损失函数"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2# 定义优化函数
def sgd(params, lr, batch_size):"""小批量随机梯度下降"""# 更新参数时不需要计算梯度with torch.no_grad():for param in params:param -= lr * param.grad / batch_size # 参数更新param.grad.zero_() # 梯度清零# 模型训练
lr = 0.03 # 学习率
num_epochs = 5 # 迭代周期数
net = linreg # 线性回归模型
loss = squared_loss # 损失函数# 开始训练
for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y) # 计算小批量数据的损失l.sum().backward() # 计算梯度sgd([w, b], lr, batch_size) # 更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels) # 计算整个数据集上的损失print(f'第{epoch + 1}轮,损失: {float(train_l.mean()):f}')# 打印权重和偏置的估计误差
print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')# 示例输出:
# 第1轮,损失: 0.036624
# 第2轮,损失: 0.000131
# 第3轮,损失: 0.000052
# 第4轮,损失: 0.000052
# 第5轮,损失: 0.000052
# w的估计误差: tensor([-0.0003, -0.0008], grad_fn=<SubBackward0>)
# b的估计误差: tensor([0.0007], grad_fn=<RsubBackward1>)
相关文章:
动手学深度学习(Pytorch版)代码实践 -深度学习基础-02线性回归基础版
02线性回归基础版 主要内容 数据生成:使用线性模型 ( y X*w b ) 加上噪声生成人造数据集。数据读取:通过小批量读取数据集来实现批量梯度下降,打乱数据顺序并逐批返回特征和标签。模型参数初始化:随机初始化权重和偏置&#x…...

信息学奥赛初赛天天练-15-阅读程序-深入解析二进制原码、反码、补码,位运算技巧,以及lowbit的神奇应用
更多资源请关注纽扣编程微信公众号 1 2021 CSP-J 阅读程序1 阅读程序(程序输入不超过数组或字符串定义的范围;判断题正确填 √,错误填;除特 殊说明外,判断题 1.5 分,选择题 3 分) 源码 #in…...

期权具体怎么交易详细的操作流程?
期权就是股票,唯一区别标的物上证指数,会看大盘吧,交易两个方向认购做多,认沽做空,双向t0交易,期权具体交易流程可以理解选择方向多和空,选开仓的合约,买入开仓和平仓没了࿰…...

系统架构设计师【第3章】: 信息系统基础知识 (核心总结)
文章目录 3.1 信息系统概述3.1.1 信息系统的定义3.1.2 信息系统的发展3.1.3 信息系统的分类3.1.4 信息系统的生命周期3.1.5 信息系统建设原则3.1.6 信息系统开发方法 3.2 业务处理系统(TPS)3.2.1 业务处理系统的概念3.2.2 业务处理系统的功能 …...

Linux 驱动设备匹配过程
一、Linux 驱动-总线-设备模型 1、驱动分层 Linux内核需要兼容多个平台,不同平台的寄存器设计不同导致操作方法不同,故内核提出分层思想,抽象出与硬件无关的软件层作为核心层来管理下层驱动,各厂商根据自己的硬件编写驱动…...

游戏子弹类python设计与实现详解
新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、引言 二、子弹类设计思路 1. 属性定义 2. 方法设计 三、子弹类实现详解 1. 定义子弹…...
Python基础学习笔记(六)——列表
目录 一、一维列表的介绍和创建二、序列的基本操作1. 索引的查询与返回2. 切片3. 序列加 三、元素的增删改1. 添加元素2. 删除元素3. 更改元素 四、排序五、列表生成式 一、一维列表的介绍和创建 列表(list),也称数组,是一种有序、…...
帝国CMS跳过选择会员类型直接注册方法
国CMS因允许多用户组注册,所以在注册页面会有一个选择注册用户组的界面,即使网站只用了一个用户组也会出现。 如果想去掉这个页面,直接进入注册页面,那么可按以下办法修改 打开 e/class/user.php 文件 查找: $chan…...

【python】python tkinter 计算器GUI版本(模仿windows计算器 源码)【独一无二】
👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化【获取源码商业合作】 👉荣__誉👈:阿里云博客专家博主、5…...
黑马es数据同步mq解决方案
方式一:同步调用 优点:实现简单,粗暴 缺点:业务耦合度高 方式二:异步通知 优点:低耦含,实现难度一般 缺点:依赖mq的可靠性 方式三:监听binlog 优点:完全解除服务间耦合 缺点:开启binlog增加数据库负担、实现复杂度高 利用MQ实现mysql与elastics…...
通过LLM多轮对话生成单元测试用例
通过LLM多轮对话生成单元测试用例 代码 在采用 随机生成pytorch算子测试序列且保证算子参数合法 这种方法之前,曾通过本文的方法生成算子组合测试用例。目前所测LLM生成的代码均会出现BUG,且多次交互后仍不能解决.也许随着LLM的更新,这个问题会得到解决.记录备用。 代码 impo…...

[Redis]String类型
基本命令 set命令 将 string 类型的 value 设置到 key 中。如果 key 之前存在,则覆盖,无论原来的数据类型是什么。之前关于此 key 的 TTL 也全部失效。 set key value [expiration EX seconds|PX milliseconds] [NX|XX] 选项[EX|PX] EX seconds⸺使用…...

Ai速递5.29
全球AI新闻速递 1.摩尔线程与无问芯穹合作,实现国产 GPU 端到端 AI 大模型实训。 2.宝马工厂:机器狗上岗,可“嗅探”故障隐患。 3.ChatGPT:macOS 开始公测。 4.Stability AI:推出Stable Assistant,可用S…...

Android9.0 MTK平台如何增加一个系统应用
在安卓定制化开发过程中,难免遇到要把自己的app预置到系统中,作为系统应用使用,其实方法有很多,过程很简单,今天分享一下我是怎么做的,共总分两步: 第一步:要找到当前系统应用apk存…...

LabVIEW中实现Trio控制器的以太网通讯
在LabVIEW中实现与Trio控制器的以太网通讯,可以通过使用TCP/IP协议来完成。这种方法包括配置Trio控制器的网络设置、使用LabVIEW中的TCP/IP函数库进行数据传输和接收,以及处理通讯中的错误和数据解析。本文将详细说明实现步骤,包括配置、编程…...
C/C++运行时库与 UCRT 通用运行时库:全面总结与问题实例剖析
推荐一个AI网站,免费使用豆包AI模型,快去白嫖👉海鲸AI 1. 概述 在开发C/C应用程序时,运行时库(Runtime Library)是不可或缺的一部分。它们提供了一系列函数和功能,使得开发者能够更方便地进行编…...

【Python001】python批量下载、插入与读取Oracle中图片数据(已更新)
1.熟悉、梳理、总结数据分析实战中的python、oracle研发知识体系 2.欢迎点赞、关注、批评、指正,互三走起来,小手动起来! 文章目录 1.背景说明2.环境搭建2.1 参考链接2.2 `oracle`查询测试代码3.数据请求与插入3.1 `Oracle`建表语句3.2 `Python`代码实现3.3 效果示例4.问题链…...
流形学习(Manifold Learning)
基本概念 Manifold Learning(流形学习)是一种机器学习和数据分析的方法,它专注于从高维数据中发现低维的非线性结构。流形学习的基本假设是,尽管数据可能在高维空间中呈现,但它们实际上分布在一个低维的流形上。这个流…...

区块链技术和应用
文章目录 前言 一、区块链是什么? 二、区块链核心数据结构 2.1 交易 2.2 区块 三、交易 3.1 交易的生命周期 3.2 节点类型 3.3 分布式系统 3.4 节点数据库 3.5 智能合约 3.6 多个记账节点-去中心化 3.7 双花问题 3.8 共识算法 3.8.1 POW工作量证明 总结 前言 学习长…...

Docker拉取镜像报错:x509: certificate has expired or is not yet v..
太久没有使用docker进行镜像拉取,今天使用docker-compose拉取mongo发现报错(如下图): 报错信息翻译:证书已过期或尚未有效。 解决办法: 1.一般都是证书问题或者系统时间问题导致,可以先执行 da…...

UDP(Echoserver)
网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...
Matlab | matlab常用命令总结
常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...
【python异步多线程】异步多线程爬虫代码示例
claude生成的python多线程、异步代码示例,模拟20个网页的爬取,每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程:允许程序同时执行多个任务,提高IO密集型任务(如网络请求)的效率…...
css3笔记 (1) 自用
outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size:0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格ÿ…...

C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...

HarmonyOS运动开发:如何用mpchart绘制运动配速图表
##鸿蒙核心技术##运动开发##Sensor Service Kit(传感器服务)# 前言 在运动类应用中,运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据,如配速、距离、卡路里消耗等,用户可以更清晰…...

C# 表达式和运算符(求值顺序)
求值顺序 表达式可以由许多嵌套的子表达式构成。子表达式的求值顺序可以使表达式的最终值发生 变化。 例如,已知表达式3*52,依照子表达式的求值顺序,有两种可能的结果,如图9-3所示。 如果乘法先执行,结果是17。如果5…...

【Linux手册】探秘系统世界:从用户交互到硬件底层的全链路工作之旅
目录 前言 操作系统与驱动程序 是什么,为什么 怎么做 system call 用户操作接口 总结 前言 日常生活中,我们在使用电子设备时,我们所输入执行的每一条指令最终大多都会作用到硬件上,比如下载一款软件最终会下载到硬盘上&am…...
Spring Security 认证流程——补充
一、认证流程概述 Spring Security 的认证流程基于 过滤器链(Filter Chain),核心组件包括 UsernamePasswordAuthenticationFilter、AuthenticationManager、UserDetailsService 等。整个流程可分为以下步骤: 用户提交登录请求拦…...
Python实现简单音频数据压缩与解压算法
Python实现简单音频数据压缩与解压算法 引言 在音频数据处理中,压缩算法是降低存储成本和传输效率的关键技术。Python作为一门灵活且功能强大的编程语言,提供了丰富的库和工具来实现音频数据的压缩与解压。本文将通过一个简单的音频数据压缩与解压算法…...