当前位置: 首页 > news >正文

用大模型学大模型05-线性回归

deepseek.com:多元线性回归的目标函数,损失函数,梯度下降 标量和矩阵形式的数学推导,pytorch真实能跑的代码案例以及模型,数据,预测结果的可视化展示, 模型应用场景和优缺点,及如何改进解决及改进方法数据推导。

一、数学推导

1. 模型定义
  • 输入
    • 样本数 n n n,特征数 m m m
    • 特征矩阵 X ∈ R n × ( m + 1 ) X \in \mathbb{R}^{n \times (m+1)} XRn×(m+1)(含截距项全1列)。
    • 参数向量 β = [ β 0 , β 1 , … , β m ] T ∈ R ( m + 1 ) × 1 \beta = [\beta_0, \beta_1, \dots, \beta_m]^T \in \mathbb{R}^{(m+1) \times 1} β=[β0,β1,,βm]TR(m+1)×1
  • 预测值
    y ^ = X β 或标量形式 y ^ i = β 0 + ∑ j = 1 m β j x i j \hat{y} = X \beta \quad \text{或标量形式} \quad \hat{y}_i = \beta_0 + \sum_{j=1}^m \beta_j x_{ij} y^=或标量形式y^i=β0+j=1mβjxij

2. 目标函数与损失函数
  • 目标:最小化预测值与真实值的平方误差。
  • 损失函数(MSE)
    L ( β ) = 1 2 n ∑ i = 1 n ( y ^ i − y i ) 2 = 1 2 n ∥ X β − y ∥ 2 2 L(\beta) = \frac{1}{2n} \sum_{i=1}^n (\hat{y}_i - y_i)^2 = \frac{1}{2n} \| X \beta - y \|_2^2 L(β)=2n1i=1n(y^iyi)2=2n1y22
    • 系数 1 2 n \frac{1}{2n} 2n1:简化梯度计算,避免平方项导数的系数干扰。

3. 梯度下降推导
标量形式

对每个参数 β j \beta_j βj 求偏导:

  1. 截距项 β 0 \beta_0 β0
    ∂ L ∂ β 0 = 1 n ∑ i = 1 n ( y ^ i − y i ) \frac{\partial L}{\partial \beta_0} = \frac{1}{n} \sum_{i=1}^n (\hat{y}_i - y_i) β0L=n1i=1n(y^iyi)
  2. 特征权重 β j \beta_j βj j ≥ 1 j \geq 1 j1)
    ∂ L ∂ β j = 1 n ∑ i = 1 n ( y ^ i − y i ) x i j \frac{\partial L}{\partial \beta_j} = \frac{1}{n} \sum_{i=1}^n (\hat{y}_i - y_i) x_{ij} βjL=n1i=1n(y^iyi)xij
矩阵形式

利用矩阵微分法则:
∇ β L = 1 n X T ( X β − y ) \nabla_\beta L = \frac{1}{n} X^T (X \beta - y) βL=n1XT(y)

  • 推导过程
    L ( β ) = 1 2 n ( X β − y ) T ( X β − y ) ⟹ ∂ L ∂ β = 1 n X T ( X β − y ) L(\beta) = \frac{1}{2n} (X \beta - y)^T (X \beta - y) \implies \frac{\partial L}{\partial \beta} = \frac{1}{n} X^T (X \beta - y) L(β)=2n1(y)T(y)βL=n1XT(y)
梯度下降更新公式

β ( t + 1 ) = β ( t ) − η ∇ β L = β ( t ) − η n X T ( X β ( t ) − y ) \beta^{(t+1)} = \beta^{(t)} - \eta \nabla_\beta L = \beta^{(t)} - \frac{\eta}{n} X^T (X \beta^{(t)} - y) β(t+1)=β(t)ηβL=β(t)nηXT(Xβ(t)y)

  • 学习率 η \eta η:控制参数更新步长。

二、应用场景

  1. 连续值预测
    • 房价预测、销售额预测、股票价格趋势分析。
  2. 因果关系分析
    • 研究广告投入与销量的量化关系。
  3. 基线模型
    • 作为复杂模型(如神经网络)的性能对比基准。

三、优缺点及解决方法

优点
  1. 简单高效:计算复杂度低(( O(nm) ) 每轮梯度下降)。
  2. 可解释性强:参数直接反映特征对目标的影响程度。
  3. 闭式解存在:当 X T X X^T X XTX可逆时,可直接求解 β = ( X T X ) − 1 X T y \beta = (X^T X)^{-1} X^T y β=(XTX)1XTy
缺点及解决方法
缺点解决方法
线性假设限制引入多项式特征或使用非线性模型(如决策树、神经网络)。
多重共线性正则化(岭回归、Lasso)、主成分分析(PCA)降维。
对异常值敏感使用鲁棒损失函数(Huber损失)、数据清洗或加权最小二乘法。
异方差性(方差不均)加权回归、Box-Cox变换稳定方差。
特征维度高时不稳定正则化、逐步回归、特征选择(如基于p值或AIC准则)。

改进方法与数学推导

1. 正则化(Ridge 回归)

目标函数
L = 1 2 m ∥ X w − y ∥ 2 + λ ∥ w ∥ 2 L = \frac{1}{2m} \|Xw - y\|^2 + \lambda \|w\|^2 L=2m1Xwy2+λw2

梯度更新
∇ w L = 1 m X T ( X w − y ) + 2 λ m w \nabla_w L = \frac{1}{m} X^T (Xw - y) + \frac{2\lambda}{m} w wL=m1XT(Xwy)+m2λw

PyTorch 实现

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1.0)  # weight_decay 对应 λ
2. 数据预处理
  • 标准化:使特征均值为 0,方差为 1,加速收敛。
  • 异常值处理:使用 IQR 或 Z-Score 过滤离群点。
3. 特征工程
  • 多项式扩展:将 x 1 , x 2 x_1, x_2 x1,x2 扩展为 x 1 2 , x 2 2 , x 1 x 2 x_1^2, x_2^2, x_1x_2 x12,x22,x1x2 等,再用线性回归。

数学形式
y ^ = w 1 x 1 + w 2 x 2 + w 3 x 1 2 + w 4 x 2 2 + w 5 x 1 x 2 + b \hat{y} = w_1 x_1 + w_2 x_2 + w_3 x_1^2 + w_4 x_2^2 + w_5 x_1x_2 + b y^=w1x1+w2x2+w3x12+w4x22+w5x1x2+b


四、关键公式总结

内容标量形式矩阵形式
预测值 y ^ i = β 0 + β 1 x i 1 + ⋯ + β m x i m \hat{y}_i = \beta_0 + \beta_1 x_{i1} + \dots + \beta_m x_{im} y^i=β0+β1xi1++βmxim y ^ = X β \hat{y} = X \beta y^=
损失函数 L = 1 2 n ∑ i = 1 n ( y ^ i − y i ) 2 L = \frac{1}{2n} \sum_{i=1}^n (\hat{y}_i - y_i)^2 L=2n1i=1n(y^iyi)2 L = 1 2 n ∣ X β − y ∣ 2 2 L = \frac{1}{2n} | X \beta - y |_2^2 L=2n1y22
梯度 ∂ L ∂ β j = 1 n ∑ i = 1 n ( y ^ i − y i ) x i j \frac{\partial L}{\partial \beta_j} = \frac{1}{n} \sum_{i=1}^n (\hat{y}_i - y_i) x_{ij} βjL=n1i=1n(y^iyi)xij ∇ β L = 1 n X T ( X β − y ) \nabla_\beta L = \frac{1}{n} X^T (X \beta - y) βL=n1XT(y)

五、实际应用示例

  1. 房价预测
    • 特征:房屋面积、卧室数量、地理位置。
    • 输出:房价。
    • 方法:通过梯度下降拟合参数,预测新样本价格。
  2. 广告效果分析
    • 特征:电视、网络、报纸广告投入。
    • 输出:销售额增长。
    • 结论:参数正负性指示广告渠道的有效性。

六、扩展:正则化改进

  • 岭回归(L2正则化)
    L ( β ) = 1 2 n ∥ X β − y ∥ 2 2 + λ ∥ β ∥ 2 2 L(\beta) = \frac{1}{2n} \| X \beta - y \|_2^2 + \lambda \| \beta \|_2^2 L(β)=2n1y22+λβ22
    • 解决多重共线性,防止过拟合。
  • Lasso(L1正则化)
    L ( β ) = 1 2 n ∥ X β − y ∥ 2 2 + λ ∥ β ∥ 1 L(\beta) = \frac{1}{2n} \| X \beta - y \|_2^2 + \lambda \| \beta \|_1 L(β)=2n1y22+λβ1
    • 自动特征选择,稀疏解。

完整代码示例


import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 生成随机数据
n_samples = 100
n_features = 2
X = torch.randn(n_samples, n_features)
true_w = torch.tensor([[3.0], [4.0]])
true_b = torch.tensor([2.0])
y = X @ true_w + true_b + torch.randn(n_samples, 1) * 0.1# 定义模型
class LinearRegression(nn.Module):def __init__(self, input_dim, output_dim):super(LinearRegression, self).__init__()self.linear = nn.Linear(input_dim, output_dim)def forward(self, x):return self.linear(x)model = LinearRegression(n_features, 1)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
n_epochs = 1000
losses = []  # 初始化一个空列表来存储损失值for epoch in range(n_epochs):# 前向传播y_pred = model(X)# 计算损失loss = criterion(y_pred, y)losses.append(loss.item())  # 将损失值添加到列表中# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch+1) % 100 == 0:print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}')# 可视化损失函数
plt.plot(losses)  # 绘制损失函数随训练轮数的变化
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.savefig("lr.png")
plt.show()

七、总结

多元线性回归是机器学习的基石模型,优势在于简单性和可解释性,但受限于线性假设。实际应用中需结合数据预处理、正则化或非线性扩展方法以提升性能。矩阵形式计算高效,适合编程实现;标量形式便于理解梯度下降的微观机制

相关文章:

用大模型学大模型05-线性回归

deepseek.com:多元线性回归的目标函数,损失函数,梯度下降 标量和矩阵形式的数学推导,pytorch真实能跑的代码案例以及模型,数据,预测结果的可视化展示, 模型应用场景和优缺点,及如何改进解决及改进方法数据推…...

Python实现AWS Fargate自动化部署系统

一、背景介绍 在现代云原生应用开发中,自动化部署是提高开发效率和保证部署质量的关键。AWS Fargate作为一项无服务器计算引擎,可以让我们专注于应用程序开发而无需管理底层基础设施。本文将详细介绍如何使用Python实现AWS Fargate的完整自动化部署流程。 © ivwdcwso (ID…...

国产编辑器EverEdit - 上下翻滚不迷路(历史编辑位置、历史光标位置回溯功能)

1 光标位置跳转 1.1 应用场景 某些场景下,用户从当前编辑位置跳转到别的位置查阅信息,如果要快速跳转回之前编辑位置,则可以使用光标跳转相关功能。 1.2 使用方法 1.2.1 上一个编辑位置 跳转到上一个编辑位置,即文本修改过的位…...

今日写题work05

题目:用队列实现栈 思路 队列的特点是先进先出,而栈的特点是后进先出。所以想要用队列实现模拟栈,我们可以使用两个队列,一个队列负责压栈,一个队列负责出栈。压栈很简单就是检空再调用队列的push就好,那出…...

[C++语法基础与基本概念] std::function与可调用对象

std::function与可调用对象 函数指针lambda表达式std::function与std::bind仿函数总结std::thread与可调用对象std::async与可调用对象回调函数 可调用对象是指那些像函数一样可以直接被调用的对象,他们广泛用于C的算法,回调,事件处理等机制。…...

两个实用且热门的 Python 爬虫案例,结合动态/静态网页抓取和反爬策略,附带详细代码和实现说明

在这个瞬息万变的世界里,保持一颗探索的心,永远怀揣梦想前行。即使有时会迷失方向,也不要忘记内心深处那盏指引你前进的明灯。它代表着你的希望、你的信念以及对未来的无限憧憬。每一个不曾起舞的日子,都是对生命的辜负&#xff1…...

华象新闻 | 2月20日前谨慎升级 PostgreSQL 版本

各位 PostgreSQL 用户,建议近期进行升级 PostgreSQL 版本。 2月20日计划进行非周期性版本发布 PostgreSQL全球开发团队计划于2025年2月20日进行一次非周期性发布,以解决2025年2月13日更新版本中引入的一个回归问题。 2月13日的更新版本包括了17.3、16.7、…...

跳跃游戏 II - 贪心算法解法

问题描述&#xff1a; 给定一个长度为 n 的 0 索引整数数组 nums&#xff0c;我们从数组的第一个元素 nums[0] 开始。每个元素 nums[i] 表示从索引 i 可以跳跃的最大长度&#xff0c;换句话说&#xff0c;从位置 i&#xff0c;你可以跳到位置 i j&#xff0c;其中 0 < j &…...

图像质量评价指标-UCIQE-UIQM

一、评价指标UCIQE 在文章《An underwater color image quality evaluation metric》中&#xff0c;提到的了评价指标UCIQE&#xff08;Underwater Colour Image Quality Evaluation&#xff09;&#xff0c;是一种无参考图像质量评价指标&#xff0c;主要用于评估水下图像的质…...

CentOS上安装WordPress

在CentOS上安装WordPress是一个相对直接的过程&#xff0c;可以通过多种方法完成&#xff0c;包括使用LAMP&#xff08;Linux, Apache, MySQL, PHP&#xff09;栈或使用更现代的LEMP&#xff08;Linux, Nginx, MySQL, PHP&#xff09;栈。 我选择的是&#xff08;Linux, Nginx…...

Spring Boot 原理分析

spring-boot.version&#xff1a;2.4.3.RELEASE Spring Boot 依赖管理 spring-boot-starter-parent 配置文件管理 <resources> <resource> <directory>${basedir}/src/main/resources</directory> <filtering>true&l…...

Git 本地项目上传 GitHub 全指南(SSH Token 两种上传方式详细讲解)

前言&#xff1a;Git 与 GitHub 的区别与联系 在学习如何将本地项目上传到 GitHub 之前&#xff0c;先来弄清楚 Git 和 GitHub 的区别以及它们之间的联系。 对比项GitGitHub定义分布式版本控制系统&#xff08;DVCS&#xff09;&#xff0c;用于本地和远程管理代码版本托管 G…...

jenkins服务启动-排错

服务状态为active (exited) 且进程不在 查看/etc/rc.d/init.d/jenkins配置 获取配置参数 [rootfy-jenkins-prod jenkins]# cat /etc/rc.d/init.d/jenkins | grep -v #JENKINS_WAR"/usr/lib/jenkins/jenkins.war" test -r "$JENKINS_WAR" || { echo "…...

CF 144A.Arrival of the General(Java实现)

题目分析 一个n个身高数据&#xff0c;问最高的到最前面&#xff0c;最矮的到最后面的最短交换次数 思路分析 首先&#xff0c;如果数据有重复项&#xff0c;例如示例二中&#xff0c;最矮的数据就是最后一个出现的数据位置&#xff0c;最高的数据就是最先出现的数据位置&…...

SAP-ABAP:SAP中REPORT程序和online程序的区别对比

在SAP中&#xff0c;REPORT程序和Online程序&#xff08;通常指Dialog程序&#xff09;是两种常见的ABAP程序类型&#xff0c;它们在用途、结构和用户交互方式上有显著区别。以下是它们的详细对比&#xff1a; 1. 用途 REPORT程序Online程序主要用于数据查询、报表生成和批量数…...

Java发展史

JavaEE的由来 语言的诞生 Java的前身是Oak语言&#xff0c;其目的是搞嵌入式开发开发智能面包机 叮~~~&#x1f35e;&#x1f35e;&#x1f35e; 产品以失败告终 巅峰 网景公司需要网景浏览器打开网页&#xff0c;Oak->Java&#xff0c;进行前端开发&#xff08;相关技…...

vue3--SVG图标的封装与使用

流程 终端输入- -安装下面这个包 npm install vite-plugin-svg-icons -Dvite.config.ts文件中引入 import {createSvgIconsPlugin} from vite-plugin-svg-iconsvite.config.ts文件中配置plugins选项 将下面代码 createSvgIconsPlugin({//用于指定包含 SVG 图标的文件夹路径…...

Datawhale Ollama教程笔记3

小白的看课思路&#xff1a; Ollama REST API 是什么&#xff1f; 想象一下&#xff0c;你有一个智能的“盒子”&#xff08;Ollama&#xff09;&#xff0c;里面装了很多聪明的“小助手”&#xff08;语言模型&#xff09;。如果你想让这些“小助手”帮你完成一些任务&#…...

学习数据结构(10)栈和队列下+二叉树(堆)上

1.关于栈和队列的算法题 &#xff08;1&#xff09;用队列实现栈 解法一&#xff1a;&#xff08;参考代码&#xff09; 题目要求实现六个函数&#xff0c;分别是栈初始化&#xff0c;入栈&#xff0c;移除并返回栈顶元素&#xff0c;返回栈顶元素&#xff0c;判空&#xff0…...

洛谷 P3660 USACO17FEB Why Did the Cow Cross the Road III 题解

题意 有一个圆&#xff0c;圆周上按顺时针方向给出 2 n 2n 2n个点。第 i i i个点的颜色是 c o l o r i color_i colori​&#xff0c;其中数据保证 1 ≤ c o l o r i ≤ n 1\le color_i\le n 1≤colori​≤n&#xff0c;而且每种不同的颜色有且只有两个点。不存在位置重叠的点…...

AtCoder 第409​场初级竞赛 A~E题解

A Conflict 【题目链接】 原题链接&#xff1a;A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串&#xff0c;只有在同时为 o 时输出 Yes 并结束程序&#xff0c;否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...

【服务器压力测试】本地PC电脑作为服务器运行时出现卡顿和资源紧张(Windows/Linux)

要让本地PC电脑作为服务器运行时出现卡顿和资源紧张的情况&#xff0c;可以通过以下几种方式模拟或触发&#xff1a; 1. 增加CPU负载 运行大量计算密集型任务&#xff0c;例如&#xff1a; 使用多线程循环执行复杂计算&#xff08;如数学运算、加密解密等&#xff09;。运行图…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)

文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...

AspectJ 在 Android 中的完整使用指南

一、环境配置&#xff08;Gradle 7.0 适配&#xff09; 1. 项目级 build.gradle // 注意&#xff1a;沪江插件已停更&#xff0c;推荐官方兼容方案 buildscript {dependencies {classpath org.aspectj:aspectjtools:1.9.9.1 // AspectJ 工具} } 2. 模块级 build.gradle plu…...

接口自动化测试:HttpRunner基础

相关文档 HttpRunner V3.x中文文档 HttpRunner 用户指南 使用HttpRunner 3.x实现接口自动化测试 HttpRunner介绍 HttpRunner 是一个开源的 API 测试工具&#xff0c;支持 HTTP(S)/HTTP2/WebSocket/RPC 等网络协议&#xff0c;涵盖接口测试、性能测试、数字体验监测等测试类型…...

FFmpeg avformat_open_input函数分析

函数内部的总体流程如下&#xff1a; avformat_open_input 精简后的代码如下&#xff1a; int avformat_open_input(AVFormatContext **ps, const char *filename,ff_const59 AVInputFormat *fmt, AVDictionary **options) {AVFormatContext *s *ps;int i, ret 0;AVDictio…...

全面解析数据库:从基础概念到前沿应用​

在数字化时代&#xff0c;数据已成为企业和社会发展的核心资产&#xff0c;而数据库作为存储、管理和处理数据的关键工具&#xff0c;在各个领域发挥着举足轻重的作用。从电商平台的商品信息管理&#xff0c;到社交网络的用户数据存储&#xff0c;再到金融行业的交易记录处理&a…...

如何在Windows本机安装Python并确保与Python.NET兼容

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…...

书籍“之“字形打印矩阵(8)0609

题目 给定一个矩阵matrix&#xff0c;按照"之"字形的方式打印这个矩阵&#xff0c;例如&#xff1a; 1 2 3 4 5 6 7 8 9 10 11 12 ”之“字形打印的结果为&#xff1a;1&#xff0c;…...

统计学(第8版)——统计抽样学习笔记(考试用)

一、统计抽样的核心内容与问题 研究内容 从总体中科学抽取样本的方法利用样本数据推断总体特征&#xff08;均值、比率、总量&#xff09;控制抽样误差与非抽样误差 解决的核心问题 在成本约束下&#xff0c;用少量样本准确推断总体特征量化估计结果的可靠性&#xff08;置…...