pytorch 实现线性回归(深度学习)
一 查看原始函数
初始化
%matplotlib inline
import random
import torch
from d2l import torch as d2l
1.1 生成原始数据
def synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape) # 噪声return x, y.reshape((-1 , 1))
true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)

1.2 数据转换
def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')

1.3 初始化权重
随机初始化,w使用 均值0,方差 0.01 的随机值, b 初始化为1。
w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b
![]()
二 执行训练
查看训练过程中的 参数变化:
print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:# print('param:', param, 'param.grad:', param.grad)param -= lr * param.grad / batch_sizeparam.grad.zero_()lr = 0.03
num_epochs = 1000
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y) # 计算总损失print('w:', w, 'b:', b) # l:', l, '\nl.sum().backward()sgd([w, b], lr, batch_size)


三 测试梯度更新
初始化数据
%matplotlib inline
import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape) # 噪声return x, y.reshape((-1 , 1))true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b
3.1 测试更新
print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
# param -= lr * param.grad / batch_size
# param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y) # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b) # l:', l, '\nl.sum().backward() # 计算更新梯度sgd([w, b], lr, batch_size)
使用 l.sum().backward() # 计算更新梯度:

不使用更新时:
print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
# param -= lr * param.grad / batch_size
# param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y) # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b) # l:', l, '\n# l.sum().backward() # 计算更新梯度sgd([w, b], lr, batch_size)# break

相关文章:
pytorch 实现线性回归(深度学习)
一 查看原始函数 初始化 %matplotlib inline import random import torch from d2l import torch as d2l 1.1 生成原始数据 def synthetic_data(w, b, num_examples):x torch.normal(0, 1, (num_examples, len(w)))y torch.matmul(x, w) bprint(x:, x)print(y:, y)y tor…...
[Doris] Doris的安装和部署 (二)
文章目录 1.安装要求1.1 Linux操作系统要求1.2 软件需求1.3 注意事项1.4 内部端口 2.集群部署2.1 操作系统安装要求2.2 下载安装包2.3 解压2.4 配置FE2.5 配置BE2.6 添加BE2.7 FE 扩容和缩容2.8 Doris 集群群起脚本 3.图形化 1.安装要求 1.1 Linux操作系统要求 1.2 软件需求 1…...
【QT+QGIS跨平台编译】之三十五:【cairo+Qt跨平台编译】(一套代码、一套框架,跨平台编译)
文章目录 一、cairo介绍二、文件下载三、文件分析四、pro文件五、编译实践一、cairo介绍 Cairo是一个功能强大的开源2D图形库,它提供了一套跨平台的API,用于绘制矢量图形和文本。Cairo支持多种输出目标,包括屏幕、图像文件、PDF、SVG等。 Cairo的设计目标是简单易用、高效…...
MySQL(基础)
第01章_数据库概述 1. 为什么要使用数据库 持久化(persistence):把数据保存到可掉电式存储设备中以供之后使用。大多数情况下,特别是企业级应用,数据持久化意味着将内存中的数据保存到硬盘上加以”固化”,而持久化的实现过程大多…...
STM32F1 - 中断系统
Interrupt 1> 硬件框图2> NVIC 中断管理3> EXTI 中断管理3.1> EXTI与NVIC3.2> EXTI内部框图 4> 外部中断实验4.1> 实验概述4.2> 程序设计 5> 中断向量表6> 总结 1> 硬件框图 NVIC:Nested Vectored Interrupt Controller【嵌套向量…...
【Linux系统化学习】缓冲区
目录 缓冲区 一个样例 现象解释 缓冲区存在的位置 缓冲区 在刚开始学习C语言的时候我们就听过缓冲区这个名词,很是晦涩难懂;在Linux下进程退出时也包含缓冲区,因此缓冲区到底是什么?有什么作用? 让我们先从一个小…...
基于BP算法的SAR成像matlab仿真
目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 BP算法的基本原理 4.2 BP算法的优点与局限性 5.完整工程文件 1.课题概述 基于BP算法的SAR成像。合成孔径雷达(SAR)是一种高分辨率的雷达系统,能够在各种天气和光…...
【C++ STL】你真的了解string吗?浅谈string的底层实现
文章目录 底层结构概述扩容机制浅拷贝与深拷贝插入和删除的效率浅谈VS和g的优化总结 底层结构概述 string可以帮助我们很好地管理字符串,但是你真的了解她吗?事实上,string的设计是非常复杂的,拥有上百个接口,但最常用…...
17.3.1.3 灰度
版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 灰度的算法主要有以下三种: 1、最大值法: 原图像:颜色值color(R,G,B&a…...
基于CAS操作的atomic原子类型
在上一节的卖票程序中,我们讲解了如何在多线程中保证临界资源的正确访问——使用互斥锁,即 lock_guard<mutex> lock(mtx); count;lock_guard<mutex> lock(mtx); count--; 从汇编角度解释线程间互斥-mutex互斥锁与lock_guard的使用-CSDN博客…...
Rust HashMap详解及单词统计示例
在Rust中,HashMap是一种非常有用的数据结构,用于存储键值对。本文将深入介绍HashMap的特性,以及通过一个单词统计的例子展示其用法。 HashMap简介 HashMap是Rust标准库提供的用于存储键值对的数据结构。它允许通过键快速查找对应的值&#…...
命令执行讲解和函数
命令执行漏洞简介 命令执行漏洞产生原因 应用未对用户输入做严格得检查过滤,导致用户输入得参数被当成命令来执行 命令执行漏洞的危害 1.继承Web服务程序的权限去执行系统命会或读写文件 2.反弹shell,获得目标服务器的权限 3.进一步内网渗透 远程代…...
外包实在是太坑了,划水三年,感觉人都废了
先说一下自己的情况,专科生,19年通过校招进入杭州某个外包软件公司,干了接近3年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落! 而我已经在一个企业干了3年的功…...
代码随想录算法训练营第19天
77. 组合 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 class Solution:def combine(self, n: int, k: int) -> List[List[int]]:path []res []def dfs(n,k,index):if len(path) k:res.append(path[:])returnfor i in range(index,n1):…...
树莓派5 EEPROM引导加载程序恢复镜像
树莓派5不能正常启动,可以通过电源led灯的闪码来判断错误发生的大致情形。 LED警告闪码 如果树莓派由于某种原因无法启动,或者不得不关闭,在许多情况下,LED会闪烁特定的次数来指示发生了什么。LED会闪烁几次长闪烁,然…...
循序渐进-讲解Markdown进阶(Mermaid绘图)-附使用案例
Markdown 进阶操作 查看更多学习笔记:GitHub:LoveEmiliaForever Mermaid官网 由于CSDN对某些Mermaid或Markdown语法不支持,因此我的某些效果展示使用图片进行 下面的笔记内容全部是我根据Mermaid官方文档学习的,因为是初学者所以…...
寒假作业2月6号
第五章 静态成员与友元 一、填空题 1、一个类的头文件如下所示,num初始化值为5,程序产生对象T,且修改num为10,并使用show()函数输出num的值10。 #include <iostream.h> class Test { private: static int num; publi…...
ChatGPT绘图指南:DALL.E3玩法大全(一)
一、 DALLE.3 模型介绍 1、什么是 DALLE.3 模型? DALLE-3模型,是一种由OpenAI研发的技术,它是一种先进的生成模型,可以将文字描述转化为清晰的图片。这种模型的名称"DALLE"实际上是"Deep Auto-regressive Latent …...
计算机服务器中了_locked勒索病毒怎么办?Encrypted勒索病毒解密数据恢复
随着网络技术的不断发展,数字化办公已经成为企业生产运营的根本,对于企业来说,数据至关重要,但网络威胁无处不在,近期,云天数据恢复中心接到很多企业的求助,企业的计算机服务器遭到了_locked勒索…...
VueCLI核心知识3:全局事件总线、消息订阅与发布
这两种方式都可以实现任意两个组件之间的通信 1 全局事件总线 1.安装全局事件总线 import Vue from vue import App from ./App.vueVue.config.productionTip false/* 1.第一种写法 */ // const Demo Vue.extend({}) // const d new Demo()// Vue.prototype.x d // 把Dem…...
多种风格导航菜单 HTML 实现(附源码)
下面我将为您展示 6 种不同风格的导航菜单实现,每种都包含完整 HTML、CSS 和 JavaScript 代码。 1. 简约水平导航栏 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport&qu…...
Java面试专项一-准备篇
一、企业简历筛选规则 一般企业的简历筛选流程:首先由HR先筛选一部分简历后,在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如:Boss直聘(招聘方平台) 直接按照条件进行筛选 例如:…...
全志A40i android7.1 调试信息打印串口由uart0改为uart3
一,概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本:2014.07; Kernel版本:Linux-3.10; 二,Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01),并让boo…...
AI病理诊断七剑下天山,医疗未来触手可及
一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...
scikit-learn机器学习
# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: # Also add the following code, # so that every time the environment (kernel) starts, # just run the following code: import sys sys.path.append(/home/aistudio/external-libraries)机…...
宇树科技,改名了!
提到国内具身智能和机器人领域的代表企业,那宇树科技(Unitree)必须名列其榜。 最近,宇树科技的一项新变动消息在业界引发了不少关注和讨论,即: 宇树向其合作伙伴发布了一封公司名称变更函称,因…...
Qemu arm操作系统开发环境
使用qemu虚拟arm硬件比较合适。 步骤如下: 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载,下载地址:https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...
人工智能--安全大模型训练计划:基于Fine-tuning + LLM Agent
安全大模型训练计划:基于Fine-tuning LLM Agent 1. 构建高质量安全数据集 目标:为安全大模型创建高质量、去偏、符合伦理的训练数据集,涵盖安全相关任务(如有害内容检测、隐私保护、道德推理等)。 1.1 数据收集 描…...
Linux部署私有文件管理系统MinIO
最近需要用到一个文件管理服务,但是又不想花钱,所以就想着自己搭建一个,刚好我们用的一个开源框架已经集成了MinIO,所以就选了这个 我这边对文件服务性能要求不是太高,单机版就可以 安装非常简单,几个命令就…...
绕过 Xcode?使用 Appuploader和主流工具实现 iOS 上架自动化
iOS 应用的发布流程一直是开发链路中最“苹果味”的环节:强依赖 Xcode、必须使用 macOS、各种证书和描述文件配置……对很多跨平台开发者来说,这一套流程并不友好。 特别是当你的项目主要在 Windows 或 Linux 下开发(例如 Flutter、React Na…...
