【人工智能】从零开始用Python实现逻辑回归模型:深入理解逻辑回归的原理与应用
解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界
《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门!
逻辑回归是一种经典的统计学习方法,用于分类问题尤其是二分类问题。它通过学习数据的特征和目标标签之间的关系,输出样本属于某个类别的概率。本文将从零开始用Python实现逻辑回归模型,深入探讨其数学原理,并逐步讲解如何编写代码实现。我们会涵盖从数据预处理、特征归一化、梯度下降算法到模型训练的每一个细节。通过丰富的代码示例和详细的中文注释,本文将帮助读者从头掌握逻辑回归模型的构建与优化。
正文
目录
- 逻辑回归简介
- 逻辑回归的数学原理
- 2.1 逻辑回归的假设函数
- 2.2 损失函数的推导
- 2.3 梯度下降优化
- 从零开始实现逻辑回归模型
- 3.1 数据预处理
- 3.2 逻辑回归模型类的定义
- 3.3 编写模型训练方法
- 3.4 编写模型预测方法
- 逻辑回归模型的性能评估
- 4.1 准确率与混淆矩阵
- 4.2 交叉验证
- 完整代码实现
- 实验与结果分析
- 总结
1. 逻辑回归简介
逻辑回归是一种广泛应用的分类模型,通常用于二分类问题。不同于线性回归直接输出一个数值,逻辑回归使用一个Sigmoid函数将预测值映射到[0, 1]区间,以此表示样本属于某个类别的概率。逻辑回归的目标是找到最佳的参数,使得模型能够正确地将输入映射到相应的类别。
2. 逻辑回归的数学原理
逻辑回归的数学基础是线性模型,通过线性组合的输入特征生成一个预测值,然后利用Sigmoid函数将其转换为概率。
2.1 逻辑回归的假设函数
逻辑回归的假设函数为:
h θ ( x ) = σ ( θ T x ) = 1 1 + e − θ T x h_{\theta}(x) = \sigma(\theta^T x) = \frac{1}{1 + e^{-\theta^T x}} hθ(x)=σ(θTx)=1+e−θTx1
其中:
- h θ ( x ) h_{\theta}(x) hθ(x) 是预测的概率值;
- θ \theta θ 是模型的参数向量;
- x x x 是特征向量;
- σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+e−z1是Sigmoid激活函数。
Sigmoid函数的输出在(0, 1)之间,表示样本属于正类(标签为1)的概率。
2.2 损失函数的推导
逻辑回归的目标是最小化损失函数(Cost Function),通常选择对数似然函数作为损失函数。对于单个样本,损失函数为:
L ( h θ ( x ) , y ) = − y log ( h θ ( x ) ) − ( 1 − y ) log ( 1 − h θ ( x ) ) L(h_{\theta}(x), y) = -y \log(h_{\theta}(x)) - (1 - y) \log(1 - h_{\theta}(x)) L(hθ(x),y)=−ylog(hθ(x))−(1−y)log(1−hθ(x))
其中 y 为样本的真实标签。当 y=1 时,我们最小化 − log ( h θ ( x ) ) -\log(h_{\theta}(x)) −log(hθ(x));当 y=0 时,我们最小化 − log ( 1 − h θ ( x ) ) -\log(1 - h_{\theta}(x)) −log(1−hθ(x))。
总体的损失函数,即所有样本的平均损失为:
J ( θ ) = 1 m ∑ i = 1 m ( − y ( i ) log ( h θ ( x ( i ) ) ) − ( 1 − y ( i ) ) log ( 1 − h θ ( x ( i ) ) ) ) J(\theta) = \frac{1}{m} \sum_{i=1}^m \left(-y^{(i)} \log(h_{\theta}(x^{(i)})) - (1 - y^{(i)}) \log(1 - h_{\theta}(x^{(i)}))\right) J(θ)=m1i=1∑m(−y(i)log(hθ(x(i)))−(1−y(i))log(1−hθ(x(i))))
2.3 梯度下降优化
为了优化损失函数,我们可以使用梯度下降法更新参数 ( \theta )。梯度下降的更新公式为:
θ = θ − α ∇ J ( θ ) \theta = \theta - \alpha \nabla J(\theta) θ=θ−α∇J(θ)
其中:
α \alpha α 是学习率;
- ∇ J ( θ ) \nabla J(\theta) ∇J(θ) 是损失函数 J ( θ ) J(\theta) J(θ)关于参数的梯度。
梯度的计算公式为:
∂ J ( θ ) ∂ θ j = 1 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) x j ( i ) \frac{\partial J(\theta)}{\partial \theta_j} = \frac{1}{m} \sum_{i=1}^m \left( h_{\theta}(x^{(i)}) - y^{(i)} \right) x_j^{(i)} ∂θj∂J(θ)=m1i=1∑m(hθ(x(i))−y(i))xj(i)
3. 从零开始实现逻辑回归模型
接下来,我们将逐步实现逻辑回归模型,包括数据预处理、模型定义、训练和预测。
3.1 数据预处理
首先,我们使用一个简单的二分类数据集。数据集需要进行标准化处理,以提高模型的训练效果。
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# 生成简单的二分类数据集
np.random.seed(0)
X = np.random.randn(100, 2) # 100个样本,每个样本2个特征
y = (X[:, 0] + X[:, 1] > 0).astype(int) # 简单的线性可分数据# 数据划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
3.2 逻辑回归模型类的定义
接下来,我们定义逻辑回归模型类,包括初始化、Sigmoid函数和损失函数的实现。
class LogisticRegression:def __init__(self, learning_rate=0.01, n_iterations=1000):self.learning_rate = learning_rateself.n_iterations = n_iterationsself.weights = Noneself.bias = Nonedef sigmoid(self, z):# 定义Sigmoid激活函数return 1 / (1 + np.exp(-z))def compute_cost(self, y, y_pred):# 计算损失函数m = len(y)cost = (-1 / m) * np.sum(y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))return cost
3.3 编写模型训练方法
使用梯度下降算法优化模型参数。梯度计算基于损失函数的偏导数。
def fit(self, X, y):# 初始化参数m, n = X.shapeself.weights = np.zeros(n)self.bias = 0for i in range(self.n_iterations):# 计算线性模型的预测值linear_model = np.dot(X, self.weights) + self.biasy_pred = self.sigmoid(linear_model)# 计算梯度dw = (1 / m) * np.dot(X.T, (y_pred - y))db = (1 / m) * np.sum(y_pred - y)# 更新权重和偏差self.weights -= self.learning_rate * dwself.bias -= self.learning_rate * db# 每100次迭代打印一次损失if i % 100 == 0:cost = self.compute_cost(y, y_pred)print(f"Iteration {i}: Cost {cost}")
3.4 编写模型预测方法
逻辑回归的预测输出是一个概率值,可以通过设置一个阈值(如0.5)来决定样本属于哪个类别。
def predict(self, X):# 预测函数linear_model = np.dot(X, self.weights) + self.biasy_pred = self.sigmoid(linear_model)return [1 if i > 0.5 else 0 for i in y_pred]
4. 逻辑回归模型的性能评估
为了评估逻辑回归模型的性能,我们使用准确率和混淆矩阵。
4.1 准确率与混淆矩阵
准确率表示正确预测的比例,混淆矩阵能更清晰地展示模型的分类表现。
from sklearn.metrics import accuracy_score, confusion_matrix# 模型训练与预测
model = LogisticRegression(learning_rate=0.1, n_iterations=1000)
model.fit(X_train, y_train)
predictions = model.predict(X_test)# 计算准确率和混淆矩阵
accuracy = accuracy_score(y_test, predictions)
conf_matrix = confusion_matrix(y_test, predictions)print(f"Accuracy: {accuracy}")
print(f"Confusion Matrix:\n{conf_matrix}")
4.2
交叉验证
交叉验证通过多次数据划分和训练,得到模型更加稳定和可靠的性能评价。
5. 完整代码实现
以下是逻辑回归模型的完整代码实现:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrixclass LogisticRegression:def __init__(self, learning_rate=0.01, n_iterations=1000):self.learning_rate = learning_rateself.n_iterations = n_iterationsself.weights = Noneself.bias = Nonedef sigmoid(self, z):return 1 / (1 + np.exp(-z))def compute_cost(self, y, y_pred):m = len(y)cost = (-1 / m) * np.sum(y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))return costdef fit(self, X, y):m, n = X.shapeself.weights = np.zeros(n)self.bias = 0for i in range(self.n_iterations):linear_model = np.dot(X, self.weights) + self.biasy_pred = self.sigmoid(linear_model)dw = (1 / m) * np.dot(X.T, (y_pred - y))db = (1 / m) * np.sum(y_pred - y)self.weights -= self.learning_rate * dwself.bias -= self.learning_rate * dbif i % 100 == 0:cost = self.compute_cost(y, y_pred)print(f"Iteration {i}: Cost {cost}")def predict(self, X):linear_model = np.dot(X, self.weights) + self.biasy_pred = self.sigmoid(linear_model)return [1 if i > 0.5 else 0 for i in y_pred]# 数据准备
np.random.seed(0)
X = np.random.randn(100, 2)
y = (X[:, 0] + X[:, 1] > 0).astype(int)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)# 训练与评估
model = LogisticRegression(learning_rate=0.1, n_iterations=1000)
model.fit(X_train, y_train)
predictions = model.predict(X_test)accuracy = accuracy_score(y_test, predictions)
conf_matrix = confusion_matrix(y_test, predictions)print(f"Accuracy: {accuracy}")
print(f"Confusion Matrix:\n{conf_matrix}")
6. 实验与结果分析
在不同的学习率、迭代次数、数据分布下,我们可以调整模型参数并观察损失函数的变化。
7. 总结
本文从零开始用Python实现了逻辑回归模型,包括数学原理、代码实现和性能评估。通过这些步骤,读者可以深入理解逻辑回归的工作原理,并掌握如何从头构建和优化分类模型。
相关文章:
【人工智能】从零开始用Python实现逻辑回归模型:深入理解逻辑回归的原理与应用
解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 逻辑回归是一种经典的统计学习方法,用于分类问题尤其是二分类问题。它通过学习数据的特征和目标标签之间的…...

推荐一款功能强大的光学识别OCR软件:Readiris Dyslexic
Readiris Dyslexic是一款功能强大的光学识别OCR软件,可以扫描任何纸质文档并将其转换为完全可编辑的数字文件(Word,Excel,PDF),然后用你喜欢的编辑器进行编辑。该软件提供了一种轻松创建,修改和签名PDF的完整解决方法&…...

Python爬虫----python爬虫基础
一、python爬虫基础-爬虫简介 1、现实生活中实际爬虫有哪些? 2、什么是网络爬虫? 3、什么是通用爬虫和聚焦爬虫? 4、为什么要用python写爬虫程序 5、环境和工具 二、python爬虫基础-http协议和chrome抓包工具 1、什么是http和https协议…...

css-50 Projects in 50 Days(3)
html <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>旋转页面</title><link rel"sty…...

另外一种缓冲式图片组件的用法
文章目录 1. 概念介绍2. 使用方法2.1 基本用法2.2 缓冲原理3. 示例代码4. 内容总结我们在上一章回中介绍了"FadeInImage组件"相关的内容,本章回中将介绍CachedNetworkImage组件.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 我们在本章回中介绍的CachedNetwo…...

字节青训-小C的外卖超时判断、小C的排列询问
目录 一、小C的外卖超时判断 问题描述 测试样例 解题思路: 问题理解 数据结构选择 算法步骤 最终代码: 运行结果: 二、小C的排列询问 问题描述 测试样例 最终代码: 运行结果: 编辑 一、小C的外卖超时判断…...
PHP 伪静态详解及实现方法
概述 在现代 Web 开发中,URL 的设计对用户体验和搜索引擎优化(SEO)至关重要。动态 URL 虽然功能强大,但往往显得冗长且不友好。伪静态(URL 重写)技术通过将动态 URL 转换为静态样式,不仅提高了…...
Spring Boot 简单预览PDF例子
目录 前言 一、引入依赖 二、使用步骤 1.创建 Controller 处理 PDF 生成和预览 2.创建预览页面 总结 前言 使用 Spring Boot 创建一个生成 PDF 并进行预览的项目,你可以按以下步骤进行。我们将使用 Spring Boot、Thymeleaf、iText 等技术来完成这个任务。 一、引入…...

【魔珐有言-注册/登录安全分析报告-无验证方式导致安全隐患】
前言 由于网站注册入口容易被机器执行自动化程序攻击,存在如下风险: 暴力破解密码,造成用户信息泄露,不符合国家等级保护的要求。短信盗刷带来的拒绝服务风险 ,造成用户无法登陆、注册,大量收到垃圾短信的…...

LabVIEW 使用 Snippet
在 LabVIEW 中,Snippet(代码片段) 是一个非常有用的功能,它允许你将 一小段可重用的代码 保存为一个 图形化的代码片段,并能够在不同的 VI 中通过拖放来使用。 什么是 Snippet? Snippet 就是 LabVIEW 中的…...

单片机_day3_GPIO
目录 1. 灯如何才能亮 1.1原理图 1.2 二极管 1.3 换了一个灯和原理图 编辑 1.4 三极管 1.4.1 NPN型三极管 1.4.2 PNP型三极管 2. 基本概念 3. 输入 3.1 浮空输入 3.2 上拉输入 3.3 下拉输入 3.4 模拟输入 4. 输出 4.1 推挽输出 4.2 开漏输出 如何让开漏输出…...
Python小游戏24——小恐龙躲避游戏
首先,你需要安装Pygame库。如果你还没有安装,可以通过以下命令安装: 【bash】 pip install pygame 【python】代码 import pygame import random # 初始化Pygame pygame.init() # 设置屏幕尺寸 screen_width 800 screen_height 600 screen …...

Python 的多态笔记
Python的多态实际是通过instance 实现的 class Person:def __init__(self, name,age):self.name nameself.age agedef feed_pet(self,pet):#isinastance(obj,类)-->判断obj,是不是这个类的对象,或者判断obj是不是该类的子类的对象if isinstance(pet, Pet):sel…...

go module使用
go module介绍 go module是go官⽅⾃带的go依赖管理库,在1.13版本正式推荐使⽤ go module可以将某个项⽬(⽂件夹)下的所有依赖整理成⼀个 go.mod ⽂件,⾥⾯写⼊了依赖的版本等 使⽤ go module之后我们可不⽤将代码放置在src下了 使⽤ go module 管理依赖后会在项⽬根⽬录下⽣成…...
c ++零基础可视化——数组
c 零基础可视化 数组 一些知识: 关于给数组赋值,一个函数为memset,其在cplusplus.com中的描述如下: void * memset ( void * ptr, int value, size_t num );Sets the first num bytes of the block of memory pointed by ptr to…...

CVE-2024-2961漏洞的简单学习
简单介绍 PHP利用glibc iconv()中的一个缓冲区溢出漏洞,实现将文件读取提升为任意命令执行漏洞 在php读取文件的时候可以使用 php://filter伪协议利用 iconv 函数, 从而可以利用该漏洞进行 RCE 漏洞的利用场景 PHP的所有标准文件读取操作都受到了影响࿱…...
计算机组成原理笔记----基础篇
计算机系统硬件软件 软件 ├── 系统软件 │ ├── 操作系统 │ └── 工具软件 └── 应用软件├── 办公软件├── 媒体软件└── 浏览器软件硬件 ├── 计算机硬件 │ ├── 中央处理器(CPU) │ ├── 存储设备 │ │ ├── …...
TheadLocal出现的内存泄漏具体泄漏的是什么?弱引用在里面有什么作用?什么情景什么问题?
首先ThreadLocal是什么就不介绍了!这篇是讲讲里面的东西。 再简单说一下强引用和弱引用,举个例子,我们平常new出来的对象就是强引用的,在栈中有强引用,所以在gc的时候,堆中的实例对象不会被清除掉。 弱引…...

AI在电商平台中的创新应用:提升销售效率与用户体验的数字化转型
1. 引言 AI技术在电商平台的应用已不仅仅停留在基础的数据分析和自动化推荐上。随着人工智能的迅速发展,越来越多的电商平台开始将AI技术深度融合到用户体验、定价策略、供应链优化、客户服务等核心业务中,从而显著提升运营效率和用户满意度。在这篇文章…...
CTF-RE 从0到N:RC4
RC4加密算法简介 RC4是由Ron Rivest于1987年设计的一种流加密算法。它通过伪随机数生成器生成密钥流,并将该密钥流与明文进行异或运算来完成加密和解密。 RC4的加密流程 RC4主要包含两个阶段: 密钥调度算法 (Key Scheduling Algorithm, KSA)ÿ…...
浅谈 React Hooks
React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...

龙虎榜——20250610
上证指数放量收阴线,个股多数下跌,盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型,指数短线有调整的需求,大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的:御银股份、雄帝科技 驱动…...

centos 7 部署awstats 网站访问检测
一、基础环境准备(两种安装方式都要做) bash # 安装必要依赖 yum install -y httpd perl mod_perl perl-Time-HiRes perl-DateTime systemctl enable httpd # 设置 Apache 开机自启 systemctl start httpd # 启动 Apache二、安装 AWStats࿰…...

STM32标准库-DMA直接存储器存取
文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA(Direct Memory Access)直接存储器存取 DMA可以提供外设…...
spring:实例工厂方法获取bean
spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂ÿ…...
拉力测试cuda pytorch 把 4070显卡拉满
import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...
docker 部署发现spring.profiles.active 问题
报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

安宝特方案丨船舶智造的“AR+AI+作业标准化管理解决方案”(装配)
船舶制造装配管理现状:装配工作依赖人工经验,装配工人凭借长期实践积累的操作技巧完成零部件组装。企业通常制定了装配作业指导书,但在实际执行中,工人对指导书的理解和遵循程度参差不齐。 船舶装配过程中的挑战与需求 挑战 (1…...

认识CMake并使用CMake构建自己的第一个项目
1.CMake的作用和优势 跨平台支持:CMake支持多种操作系统和编译器,使用同一份构建配置可以在不同的环境中使用 简化配置:通过CMakeLists.txt文件,用户可以定义项目结构、依赖项、编译选项等,无需手动编写复杂的构建脚本…...

【Post-process】【VBA】ETABS VBA FrameObj.GetNameList and write to EXCEL
ETABS API实战:导出框架元素数据到Excel 在结构工程师的日常工作中,经常需要从ETABS模型中提取框架元素信息进行后续分析。手动复制粘贴不仅耗时,还容易出错。今天我们来用简单的VBA代码实现自动化导出。 🎯 我们要实现什么? 一键点击,就能将ETABS中所有框架元素的基…...