动手学深度学习(Pytorch版)代码实践 -深度学习基础-02线性回归基础版
02线性回归基础版
主要内容
- 数据生成:使用线性模型 ( y = X*w + b ) 加上噪声生成人造数据集。
- 数据读取:通过小批量读取数据集来实现批量梯度下降,打乱数据顺序并逐批返回特征和标签。
- 模型参数初始化:随机初始化权重和偏置,并设置为可计算梯度。
- 模型定义:实现线性回归模型 ( y = X*w + b )。
- 损失函数:实现均方误差损失函数。
- 优化函数:实现小批量随机梯度下降用于更新模型参数。
- 模型训练:设定学习率和迭代次数,通过每个批量计算损失、反向传播和参数更新。
import random
import torch# 生成数据集
def synthetic_data(w, b, num_examples):"""生成 y = Xw + b + 噪声"""# torch.normal: 返回一个从均值为0,标准差为1的正态分布中提取的随机数的张量# 生成形状为(num_examples, len(w))的矩阵X = torch.normal(0, 1, (num_examples, len(w)))# torch.matmul: 矩阵乘法y = torch.matmul(X, w) + b# 添加噪声:torch.normal(0, 0.01, y.shape)y += torch.normal(0, 0.01, y.shape)# reshape: 只改变张量的视图,不改变数据,将y转换为列向量return X, y.reshape((-1, 1))# 定义真实的权重和偏置
true_w = torch.tensor([2, -3.4])
true_b = 4.2
# 生成特征和标签
features, labels = synthetic_data(true_w, true_b, 1000)# 读取数据集
def data_iter(batch_size, features, labels):num_examples = len(features)# 生成一个从0到num_examples-1的整数列表indices = list(range(num_examples))# 将列表的次序打乱random.shuffle(indices)# 每次迭代生成一个小批量数据for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]# 设置批量大小
batch_size = 10# 初始化模型参数
# 随机初始化权重,设置requires_grad=True以计算梯度
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 初始化偏置为0,设置requires_grad=True以计算梯度
b = torch.zeros(1, requires_grad=True) # 定义模型
def linreg(X, w, b):"""线性回归模型"""return torch.matmul(X, w) + b# 定义损失函数
def squared_loss(y_hat, y):"""均方损失函数"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2# 定义优化函数
def sgd(params, lr, batch_size):"""小批量随机梯度下降"""# 更新参数时不需要计算梯度with torch.no_grad():for param in params:param -= lr * param.grad / batch_size # 参数更新param.grad.zero_() # 梯度清零# 模型训练
lr = 0.03 # 学习率
num_epochs = 5 # 迭代周期数
net = linreg # 线性回归模型
loss = squared_loss # 损失函数# 开始训练
for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y) # 计算小批量数据的损失l.sum().backward() # 计算梯度sgd([w, b], lr, batch_size) # 更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels) # 计算整个数据集上的损失print(f'第{epoch + 1}轮,损失: {float(train_l.mean()):f}')# 打印权重和偏置的估计误差
print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')# 示例输出:
# 第1轮,损失: 0.036624
# 第2轮,损失: 0.000131
# 第3轮,损失: 0.000052
# 第4轮,损失: 0.000052
# 第5轮,损失: 0.000052
# w的估计误差: tensor([-0.0003, -0.0008], grad_fn=<SubBackward0>)
# b的估计误差: tensor([0.0007], grad_fn=<RsubBackward1>)
相关文章:
动手学深度学习(Pytorch版)代码实践 -深度学习基础-02线性回归基础版
02线性回归基础版 主要内容 数据生成:使用线性模型 ( y X*w b ) 加上噪声生成人造数据集。数据读取:通过小批量读取数据集来实现批量梯度下降,打乱数据顺序并逐批返回特征和标签。模型参数初始化:随机初始化权重和偏置&#x…...

信息学奥赛初赛天天练-15-阅读程序-深入解析二进制原码、反码、补码,位运算技巧,以及lowbit的神奇应用
更多资源请关注纽扣编程微信公众号 1 2021 CSP-J 阅读程序1 阅读程序(程序输入不超过数组或字符串定义的范围;判断题正确填 √,错误填;除特 殊说明外,判断题 1.5 分,选择题 3 分) 源码 #in…...

期权具体怎么交易详细的操作流程?
期权就是股票,唯一区别标的物上证指数,会看大盘吧,交易两个方向认购做多,认沽做空,双向t0交易,期权具体交易流程可以理解选择方向多和空,选开仓的合约,买入开仓和平仓没了࿰…...

系统架构设计师【第3章】: 信息系统基础知识 (核心总结)
文章目录 3.1 信息系统概述3.1.1 信息系统的定义3.1.2 信息系统的发展3.1.3 信息系统的分类3.1.4 信息系统的生命周期3.1.5 信息系统建设原则3.1.6 信息系统开发方法 3.2 业务处理系统(TPS)3.2.1 业务处理系统的概念3.2.2 业务处理系统的功能 …...

Linux 驱动设备匹配过程
一、Linux 驱动-总线-设备模型 1、驱动分层 Linux内核需要兼容多个平台,不同平台的寄存器设计不同导致操作方法不同,故内核提出分层思想,抽象出与硬件无关的软件层作为核心层来管理下层驱动,各厂商根据自己的硬件编写驱动…...

游戏子弹类python设计与实现详解
新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、引言 二、子弹类设计思路 1. 属性定义 2. 方法设计 三、子弹类实现详解 1. 定义子弹…...
Python基础学习笔记(六)——列表
目录 一、一维列表的介绍和创建二、序列的基本操作1. 索引的查询与返回2. 切片3. 序列加 三、元素的增删改1. 添加元素2. 删除元素3. 更改元素 四、排序五、列表生成式 一、一维列表的介绍和创建 列表(list),也称数组,是一种有序、…...
帝国CMS跳过选择会员类型直接注册方法
国CMS因允许多用户组注册,所以在注册页面会有一个选择注册用户组的界面,即使网站只用了一个用户组也会出现。 如果想去掉这个页面,直接进入注册页面,那么可按以下办法修改 打开 e/class/user.php 文件 查找: $chan…...

【python】python tkinter 计算器GUI版本(模仿windows计算器 源码)【独一无二】
👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化【获取源码商业合作】 👉荣__誉👈:阿里云博客专家博主、5…...
黑马es数据同步mq解决方案
方式一:同步调用 优点:实现简单,粗暴 缺点:业务耦合度高 方式二:异步通知 优点:低耦含,实现难度一般 缺点:依赖mq的可靠性 方式三:监听binlog 优点:完全解除服务间耦合 缺点:开启binlog增加数据库负担、实现复杂度高 利用MQ实现mysql与elastics…...
通过LLM多轮对话生成单元测试用例
通过LLM多轮对话生成单元测试用例 代码 在采用 随机生成pytorch算子测试序列且保证算子参数合法 这种方法之前,曾通过本文的方法生成算子组合测试用例。目前所测LLM生成的代码均会出现BUG,且多次交互后仍不能解决.也许随着LLM的更新,这个问题会得到解决.记录备用。 代码 impo…...

[Redis]String类型
基本命令 set命令 将 string 类型的 value 设置到 key 中。如果 key 之前存在,则覆盖,无论原来的数据类型是什么。之前关于此 key 的 TTL 也全部失效。 set key value [expiration EX seconds|PX milliseconds] [NX|XX] 选项[EX|PX] EX seconds⸺使用…...

Ai速递5.29
全球AI新闻速递 1.摩尔线程与无问芯穹合作,实现国产 GPU 端到端 AI 大模型实训。 2.宝马工厂:机器狗上岗,可“嗅探”故障隐患。 3.ChatGPT:macOS 开始公测。 4.Stability AI:推出Stable Assistant,可用S…...

Android9.0 MTK平台如何增加一个系统应用
在安卓定制化开发过程中,难免遇到要把自己的app预置到系统中,作为系统应用使用,其实方法有很多,过程很简单,今天分享一下我是怎么做的,共总分两步: 第一步:要找到当前系统应用apk存…...

LabVIEW中实现Trio控制器的以太网通讯
在LabVIEW中实现与Trio控制器的以太网通讯,可以通过使用TCP/IP协议来完成。这种方法包括配置Trio控制器的网络设置、使用LabVIEW中的TCP/IP函数库进行数据传输和接收,以及处理通讯中的错误和数据解析。本文将详细说明实现步骤,包括配置、编程…...
C/C++运行时库与 UCRT 通用运行时库:全面总结与问题实例剖析
推荐一个AI网站,免费使用豆包AI模型,快去白嫖👉海鲸AI 1. 概述 在开发C/C应用程序时,运行时库(Runtime Library)是不可或缺的一部分。它们提供了一系列函数和功能,使得开发者能够更方便地进行编…...

【Python001】python批量下载、插入与读取Oracle中图片数据(已更新)
1.熟悉、梳理、总结数据分析实战中的python、oracle研发知识体系 2.欢迎点赞、关注、批评、指正,互三走起来,小手动起来! 文章目录 1.背景说明2.环境搭建2.1 参考链接2.2 `oracle`查询测试代码3.数据请求与插入3.1 `Oracle`建表语句3.2 `Python`代码实现3.3 效果示例4.问题链…...
流形学习(Manifold Learning)
基本概念 Manifold Learning(流形学习)是一种机器学习和数据分析的方法,它专注于从高维数据中发现低维的非线性结构。流形学习的基本假设是,尽管数据可能在高维空间中呈现,但它们实际上分布在一个低维的流形上。这个流…...

区块链技术和应用
文章目录 前言 一、区块链是什么? 二、区块链核心数据结构 2.1 交易 2.2 区块 三、交易 3.1 交易的生命周期 3.2 节点类型 3.3 分布式系统 3.4 节点数据库 3.5 智能合约 3.6 多个记账节点-去中心化 3.7 双花问题 3.8 共识算法 3.8.1 POW工作量证明 总结 前言 学习长…...

Docker拉取镜像报错:x509: certificate has expired or is not yet v..
太久没有使用docker进行镜像拉取,今天使用docker-compose拉取mongo发现报错(如下图): 报错信息翻译:证书已过期或尚未有效。 解决办法: 1.一般都是证书问题或者系统时间问题导致,可以先执行 da…...

基于FPGA的PID算法学习———实现PID比例控制算法
基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容:参考网站: PID算法控制 PID即:Proportional(比例)、Integral(积分&…...
Java如何权衡是使用无序的数组还是有序的数组
在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

跨链模式:多链互操作架构与性能扩展方案
跨链模式:多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈:模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展(H2Cross架构): 适配层…...
如何为服务器生成TLS证书
TLS(Transport Layer Security)证书是确保网络通信安全的重要手段,它通过加密技术保护传输的数据不被窃听和篡改。在服务器上配置TLS证书,可以使用户通过HTTPS协议安全地访问您的网站。本文将详细介绍如何在服务器上生成一个TLS证…...

tree 树组件大数据卡顿问题优化
问题背景 项目中有用到树组件用来做文件目录,但是由于这个树组件的节点越来越多,导致页面在滚动这个树组件的时候浏览器就很容易卡死。这种问题基本上都是因为dom节点太多,导致的浏览器卡顿,这里很明显就需要用到虚拟列表的技术&…...
Python 包管理器 uv 介绍
Python 包管理器 uv 全面介绍 uv 是由 Astral(热门工具 Ruff 的开发者)推出的下一代高性能 Python 包管理器和构建工具,用 Rust 编写。它旨在解决传统工具(如 pip、virtualenv、pip-tools)的性能瓶颈,同时…...
蓝桥杯 冶炼金属
原题目链接 🔧 冶炼金属转换率推测题解 📜 原题描述 小蓝有一个神奇的炉子用于将普通金属 O O O 冶炼成为一种特殊金属 X X X。这个炉子有一个属性叫转换率 V V V,是一个正整数,表示每 V V V 个普通金属 O O O 可以冶炼出 …...
动态 Web 开发技术入门篇
一、HTTP 协议核心 1.1 HTTP 基础 协议全称 :HyperText Transfer Protocol(超文本传输协议) 默认端口 :HTTP 使用 80 端口,HTTPS 使用 443 端口。 请求方法 : GET :用于获取资源,…...

搭建DNS域名解析服务器(正向解析资源文件)
正向解析资源文件 1)准备工作 服务端及客户端都关闭安全软件 [rootlocalhost ~]# systemctl stop firewalld [rootlocalhost ~]# setenforce 0 2)服务端安装软件:bind 1.配置yum源 [rootlocalhost ~]# cat /etc/yum.repos.d/base.repo [Base…...

通过 Ansible 在 Windows 2022 上安装 IIS Web 服务器
拓扑结构 这是一个用于通过 Ansible 部署 IIS Web 服务器的实验室拓扑。 前提条件: 在被管理的节点上安装WinRm 准备一张自签名的证书 开放防火墙入站tcp 5985 5986端口 准备自签名证书 PS C:\Users\azureuser> $cert New-SelfSignedCertificate -DnsName &…...