当前位置: 首页 > news >正文

24/11/14 算法笔记 GMM高斯混合模型

高斯混合模型(Gaussian Mixture Model,简称 GMM)是一种概率模型,用于表示具有多个子群体的数据集,其中每个子群体的数据分布可以用高斯分布(正态分布)来描述。GMM 是一种软聚类方法,意味着它为每个数据点分配一个属于每个聚类的概率分布,而不是硬聚类方法中的严格分类。

GMM 的组成

一个 GMM 由以下几个部分组成:

  1. 聚类数量(K):模型中高斯分布(聚类)的数量。
  2. 均值向量(μkμk​):每个高斯分布的均值向量,其中 kk 表示聚类索引。
  3. 协方差矩阵(ΣkΣk​):每个高斯分布的协方差矩阵,描述了数据在各个维度上的分布范围和形状。
  4. 混合系数(πkπk​):每个高斯分布的权重,表示数据属于该聚类的概率,所有混合系数之和为1。

GMM 的数学表达

GMM 的概率密度函数(PDF)可以表示为:

GMM 的学习

GMM 的参数学习通常使用 EM 算法进行,EM算法前面有将,是一个策略优化算法

24/11/14 算法笔记 EM算法期望最大化算法-CSDN博客

我们来看一下简单的GMM源代码

import numpy as np
from scipy.stats import multivariate_normalclass GaussianMixture:def __init__(self, n_components, covariance_type='full', n_iter=100, random_state=None):self.n_components = n_components  # 聚类数量self.covariance_type = covariance_type  # 协方差类型self.n_iter = n_iter  # 迭代次数self.random_state = random_state  # 随机种子self.weights_ = None  # 混合系数self.means_ = None  # 均值self.covariances_ = None  # 协方差def _initialize_parameters(self, X):"""随机初始化均值、协方差和权重"""n_samples, n_features = X.shapeself.weights_ = np.ones(self.n_components) / self.n_components  # 初始化权重random_indices = np.random.choice(n_samples, self.n_components, replace=False)self.means_ = X[random_indices]  # 随机选择均值self.covariances_ = np.array([np.eye(n_features)] * self.n_components)  # 初始化协方差为单位矩阵def _e_step(self, X):"""E步骤:计算每个数据点属于每个高斯分布的责任"""n_samples = X.shape[0]responsibilities = np.zeros((n_samples, self.n_components))for k in range(self.n_components):rv = multivariate_normal(mean=self.means_[k], cov=self.covariances_[k])responsibilities[:, k] = self.weights_[k] * rv.pdf(X)# 归一化责任responsibilities /= responsibilities.sum(axis=1, keepdims=True)return responsibilitiesdef _m_step(self, X, responsibilities):"""M步骤:更新均值、协方差和权重"""n_samples = X.shape[0]effective_n = responsibilities.sum(axis=0)  # 每个聚类的有效样本数量# 更新权重self.weights_ = effective_n / n_samples# 更新均值self.means_ = np.dot(responsibilities.T, X) / effective_n[:, np.newaxis]# 更新协方差for k in range(self.n_components):diff = X - self.means_[k]self.covariances_[k] = np.dot(responsibilities[:, k] * diff.T, diff) / effective_n[k]def fit(self, X):"""训练模型"""self._initialize_parameters(X)  # 初始化参数for _ in range(self.n_iter):  # 迭代更新responsibilities = self._e_step(X)  # E步骤self._m_step(X, responsibilities)  # M步骤def predict(self, X):"""预测数据点的聚类标签"""responsibilities = self._e_step(X)  # 计算责任return np.argmax(responsibilities, axis=1)  # 返回最大责任的聚类索引def sample(self, n_samples):"""从模型中生成新样本"""samples = np.zeros((n_samples, self.means_.shape[1]))for i in range(n_samples):k = np.random.choice(self.n_components, p=self.weights_)  # 根据权重选择聚类samples[i] = np.random.multivariate_normal(self.means_[k], self.covariances_[k])  # 生成样本return samples

接下来让我们分析下每段代码

1.初始化函数 __init__

def __init__(self, n_components, covariance_type='full', n_iter=100, random_state=None):self.n_components = n_components  # 聚类数量self.covariance_type = covariance_type  # 协方差类型self.n_iter = n_iter  # 迭代次数self.random_state = random_state  # 随机种子self.weights_ = None  # 混合系数self.means_ = None  # 均值self.covariances_ = None  # 协方差

这是类的构造函数,用于初始化GMM模型的参数:

  • n_components:模型中高斯分布(聚类)的数量。
  • covariance_type:协方差矩阵的类型,可以是'full''diag''spherical',分别表示全协方差、对角协方差和球面协方差。
  • n_iter:EM算法的最大迭代次数。
  • random_state:随机数生成器的种子,用于结果的可重复性。
  • weights_means_covariances_:这些属性将在模型训练后存储模型参数。

2.参数初始化函数 _initialize_parameters

def _initialize_parameters(self, X):"""随机初始化均值、协方差和权重"""n_samples, n_features = X.shapeself.weights_ = np.ones(self.n_components) / self.n_components  # 初始化权重random_indices = np.random.choice(n_samples, self.n_components, replace=False)self.means_ = X[random_indices]  # 随机选择均值self.covariances_ = np.array([np.eye(n_features)] * self.n_components)  # 初始化协方差为单位矩阵

这个函数用于随机初始化模型参数:

  • self.weights_:权重初始化为均等分布。
  • self.means_:均值初始化为数据集中随机选择的点。
  • self.covariances_:协方差矩阵初始化为单位矩阵,适用于全协方差情况。
  • 协方差可以告诉我们两个变量是如何一起变化的。如果两个变量的协方差是正的,那么它们倾向于朝相同的方向变化;如果协方差是负的,那么一个变量增加时,另一个变量倾向于减少。

3.E步骤函数 _e_step

def _e_step(self, X):"""E步骤:计算每个数据点属于每个高斯分布的责任"""n_samples = X.shape[0]responsibilities = np.zeros((n_samples, self.n_components))for k in range(self.n_components):#函数用于生成符合多元正态分布的随机样本。rv = multivariate_normal(mean=self.means_[k], cov=self.covariances_[k])responsibilities[:, k] = self.weights_[k] * rv.pdf(X)# 归一化责任responsibilities /= responsibilities.sum(axis=1, keepdims=True)return responsibilities

E步骤计算每个数据点属于每个高斯分布的责任(后验概率):

  • 使用multivariate_normal.pdf计算每个高斯分布的PDF值。
  • 将每个高斯分布的PDF值乘以相应的权重,得到未归一化的责任。
  • 通过将每个数据点的责任除以其总和来归一化责任,确保每个数据点的责任之和为1。

PDF值通常指的是概率密度函数(Probability Density Function)的值。概率密度函数是连续概率分布的一个核心概念,它描述了随机变量在给定区间内取值的概率密度。对于连续随机变量,其概率密度函数的图形可以告诉我们随机变量取某个特定值的可能性。

4.M步骤函数 _m_step

def _m_step(self, X, responsibilities):"""M步骤:更新均值、协方差和权重"""n_samples = X.shape[0]effective_n = responsibilities.sum(axis=0)  # 每个聚类的有效样本数量# 更新权重self.weights_ = effective_n / n_samples# 更新均值self.means_ = np.dot(responsibilities.T, X) / effective_n[:, np.newaxis]# 更新协方差for k in range(self.n_components):diff = X - self.means_[k]self.covariances_[k] = np.dot(responsibilities[:, k] * diff.T, diff) / effective_n[k]

M步骤根据E步骤计算的责任更新模型参数:

  • self.weights_:权重更新为每个聚类的有效样本数量除以总样本数量。
  • self.means_:均值更新为加权平均,权重是每个数据点对每个聚类的责任。
  • self.covariances_:协方差更新为加权的样本偏差的外积,权重是每个数据点对每个聚类的责任。

5.训练函数 fit

def fit(self, X):"""训练模型"""self._initialize_parameters(X)  # 初始化参数for _ in range(self.n_iter):  # 迭代更新responsibilities = self._e_step(X)  # E步骤self._m_step(X, responsibilities)  # M步骤

  • 首先调用_initialize_parameters函数初始化参数。
  • 然后进行指定次数的迭代,每次迭代都包括E步骤和M步骤。

6.预测函数 predict

def predict(self, X):"""预测数据点的聚类标签"""responsibilities = self._e_step(X)  # 计算责任return np.argmax(responsibilities, axis=1)  # 返回最大责任的聚类索引
  • 首先调用_e_step函数计算新数据点对每个聚类的责任。
  • 然后返回责任最大的聚类索引作为预测标签。

7.采样函数 sample

def sample(self, n_samples):"""从模型中生成新样本"""samples = np.zeros((n_samples, self.means_.shape[1]))for i in range(n_samples):k = np.random.choice(self.n_components, p=self.weights_)  # 根据权重选择聚类samples[i] = np.random.multivariate_normal(self.means_[k], self.covariances_[k])  # 生成样本return samples
  • 首先初始化一个空的样本数组。
  • 然后根据每个聚类的权重随机选择一个聚类。
  • 从选定的聚类对应的高斯分布中生成一个样本。
  • 重复上述过程,直到生成所需数量的样本。

相关文章:

24/11/14 算法笔记 GMM高斯混合模型

高斯混合模型(Gaussian Mixture Model,简称 GMM)是一种概率模型,用于表示具有多个子群体的数据集,其中每个子群体的数据分布可以用高斯分布(正态分布)来描述。GMM 是一种软聚类方法,…...

Linux下编译安装Nginx

以下是在Linux下编译安装Nginx的详细步骤: 一、安装依赖库 安装基本编译工具和库 在Debian/Ubuntu系统中,使用以下命令安装:sudo apt -y update sudo apt -y install build - essential libpcre3 - dev zlib1g - dev libssl - dev在CentOS/…...

算力100问☞第4问:算力的构成元素有哪些?

算力的构成元素是一个多维度且相互交织的体系,它融合了硬件基础设施、软件优化策略、数据处理效能以及分布式计算技术等多个层面,共同塑造了强大的计算能力。具体如下: 1、硬件基础设施 中央处理器(CPU):…...

安装paddle

网址:飞桨PaddlePaddle-源于产业实践的开源深度学习平台 或者找对应python和cuda版本的paddle下载后安装: https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html 你想要安装paddlepaddle - gpu2.6.1.post112版本。在你提供的文件列表中&am…...

飞凌嵌入式RK3576核心板已适配Android 14系统

在今年3月举办的RKDC2024大会上,飞凌嵌入式FET3576-C核心板作为瑞芯微RK3576处理器的行业首秀方案重磅亮相,并于今年6月率先量产发货,为客户持续稳定地供应,得到了众多合作伙伴的认可。 FET3576-C核心板此前已提供了Linux 6.1.57…...

SpringBoot+MyBatis+MySQL的Point实现范围查找

前言 最近做了一个功能,需要通过用户当前位置点获取指定范围内的数据。由于后端存储用的是 MySQL,故选择使用 MySQL 中的 Point 实现范围查找功能。ORM 框架用的是 MyBatis,MyBatis 原生并不支持 Point 字段与 POJO 的映射,需要自…...

【Apache Paimon】-- 1 -- Apache Paimon 是什么?

目录 1、简介 2、概览 3、哪些场景可以使用 Paimon 4、周边生态 5、小结 6、参考 1、简介 我们听说过数据仓库、数据湖、数据湖仓,那你听说过流式数据仓库(Stream warehouse,简称:Streamhouse)吗?那我们今天就来解锁看看他们之中的新秀: Apache paimon 到底是什么…...

解决VsCode无法跳转问题

在settings.json中加入以下代码 { "files.associations": { "*.c":"c", "*.h":"c", "*.s":"masm" }, "includePath":[ "${workspaceFold…...

优化C++设计模式:用模板代替虚函数与多态机制

文章目录 0. 引言1. 模板编程替换虚函数和多态的必要性1.1. MISRA C对类型转换和虚函数的规定1.2. 虚函数与多态问题的影响及如何适应MISRA C要求1.3. 模板编程的优势:替代虚函数和多态机制 2. 设计模式改进2.1. 单例模式的改进与静态局部变量的对比(第二种实现) 2.…...

浪浪云轻量服务器搭建vulfocus网络安全靶场

什么是网络安全靶场 网络安全靶场是一个模拟真实网络环境的训练平台,旨在为网络安全专业人员提供一个安全的环境来测试和提高他们的技能。靶场通常包括各种网络设备、操作系统、应用程序和安全工具,允许用户在其中进行攻击和防御练习。以下是网络安全靶…...

C++builder中的人工智能(23):在现代C++ Windows上轻松录制声音

在这篇文章中,我们将探讨如何在现代C Windows上轻松录制声音。声音以波形和数字形式存在,其音量随时间变化。在C Builder中,使用Windows设备进行录音非常简单。要录制声音,在多设备应用程序中,必须使用FMX.Media.hpp头…...

避免误差!Android 中正确计算时间差的方式

在 Android 开发中,计时和计算时间差异是非常常见的需求,比如记录事件发生的间隔、统计应用启动时间、测量网络请求的响应时间等。在实现这些功能时,我们通常需要一个可靠的时间源来确保计时的准确性。那么为什么 Android 推荐使用 SystemClo…...

unity3d————Resources异步加载

知识点一:Resources异步加载是什么? 在Unity中,资源加载可以分为同步加载和异步加载两种方式。同步加载会在主线程中直接进行,如果加载的资源过大,可能会导致程序卡顿,因为从硬盘读取数据到内存并进行处理…...

YOLOv11改进,YOLOv11添加GnConv递归门控卷积,二次创新C3k2结构

摘要 视觉 Transformer 在多种任务中取得了显著的成功,这得益于基于点积自注意力的新空间建模机制。视觉 Transformer 中的关键因素——即输入自适应、长距离和高阶空间交互——也可以通过卷积框架高效实现。作者提出了递归门控卷积(Recursive Gated Convolution,简称 gnCo…...

如何选择国产化CMS来建设政务网站?

在介绍CMS之前,我们先了解国家为什么要网站为什么要完成国产化改造? 1、信创国产化网站建站响应了国家的信息安全战略,支持自主可控的信息技术产业的发展,减少对进口软硬件的依赖,保障国家信息安全。 2、国产替代&…...

C/C++语言基础--initializer_list表达式、tuple元组、pair对组简介

本专栏目的 更新C/C的基础语法,包括C的一些新特性 前言 initializer_list表达式、tuple元组、pair对组再C日常还是比较常用的,尤其是对组在刷算法还是挺好用的,这里做一个简介;这三个语法结合C17的结构化绑定会更好用&#xff…...

paddle表格识别数据制作

数据格式 其中主要数据有两个一个表格结构的检测框&#xff0c;一个是tokens&#xff0c;注意的地方是 1、只能使用双引号&#xff0c;单引号不行 2、使用带引号的地方是tokens里面 "<tr>", "<td", " colspan2", ">",&quo…...

python selenium库的使用:通过兴趣点获取坐标

通过兴趣点获取坐标 from selenium import webdriver from selenium.webdriver.common.by import By from selenium.webdriver.common.keys import Keys from selenium.common.exceptions import TimeoutException# 保存Cookies到文件&#xff08;可选&#xff09; import pi…...

如何优化Kafka消费者的性能

要优化 Kafka 消费者性能&#xff0c;你可以考虑以下策略&#xff1a; 并行消费&#xff1a;通过增加消费者组中的消费者数量来并行处理更多的消息&#xff0c;从而提升消费速度。 批量消费&#xff1a;配置 fetch.min.bytes 和 fetch.max.wait.ms 参数来控制批量消费的大小和…...

机器学习 决策树

决策树-分类 1 概念 1、决策节点通过条件判断而进行分支选择的节点。如&#xff1a;将某个样本中的属性值(特征值)与决策节点上的值进行比较&#xff0c;从而判断它的流向。 2、叶子节点没有子节点的节点&#xff0c;表示最终的决策结果。 3、决策树的深度所有节点的最大层…...

AI人脸生成新范式:IP-Adapter-FaceID PlusV2双重嵌入技术解析

AI人脸生成新范式&#xff1a;IP-Adapter-FaceID PlusV2双重嵌入技术解析 【免费下载链接】IP-Adapter-FaceID 项目地址: https://ai.gitcode.com/hf_mirrors/h94/IP-Adapter-FaceID 在AI人脸生成领域&#xff0c;如何在保持身份一致性的同时实现风格的灵活控制&#x…...

VINS-Mono跑EUROC数据集后,如何用evo工具包进行轨迹精度评估与可视化(附完整命令)

VINS-Mono轨迹精度评估实战&#xff1a;从EUROC数据集到evo工具包全流程解析 在完成VINS-Mono算法在EUROC数据集上的运行后&#xff0c;如何科学评估其轨迹精度成为算法优化和论文撰写的关键环节。本文将深入讲解使用evo工具包进行定量分析的完整流程&#xff0c;涵盖指标计算、…...

网易云音乐无损解析:5大核心技术构建个人高品质音乐库

网易云音乐无损解析&#xff1a;5大核心技术构建个人高品质音乐库 【免费下载链接】Netease_url 网易云无损解析 项目地址: https://gitcode.com/gh_mirrors/ne/Netease_url 在数字音乐时代&#xff0c;如何突破平台限制&#xff0c;建立个人专属的高品质音乐库&#xf…...

Kubernetes资源监控与告警:从指标到行动的完整闭环

Kubernetes资源监控与告警&#xff1a;从指标到行动的完整闭环没有监控的集群就是黑盒&#xff0c;没有告警的监控就是摆设。监控体系架构 一个完整的K8s监控体系包含三个层次&#xff1a; ┌────────────────────────────────────────…...

数据库扩展实战:如何用ShardingCore实现高性能分库分表

数据库扩展实战&#xff1a;如何用ShardingCore实现高性能分库分表 【免费下载链接】sharding-core high performance lightweight solution for efcore sharding table and sharding database support read-write-separation .一款ef-core下高性能、轻量级针对分表分库读写分离…...

RAG实战解析:如何通过检索增强生成提升知识密集型NLP任务性能

1. RAG技术为什么能改变知识密集型NLP任务格局 第一次听说RAG&#xff08;Retrieval-Augmented Generation&#xff09;这个概念时&#xff0c;我正被一个开放域问答项目折磨得焦头烂额。当时我们用纯BART模型生成的答案总是出现事实性错误&#xff0c;比如把"特斯拉创始人…...

Halcon角度计算双雄对比:orientation_region和smallest_rectangle2到底该用哪个?

Halcon角度计算双雄对比&#xff1a;orientation_region与smallest_rectangle2的实战抉择 在工业视觉检测中&#xff0c;区域角度计算是定位、对齐和测量的基础操作。Halcon作为机器视觉领域的标杆工具&#xff0c;提供了orientation_region和smallest_rectangle2两个核心算子来…...

GoLang实战:5分钟搞定Langchaingo调用DeepSeek-R1大模型(附完整代码)

GoLang实战&#xff1a;5分钟搞定Langchaingo调用DeepSeek-R1大模型&#xff08;附完整代码&#xff09; 如果你是一位Go开发者&#xff0c;正需要在项目中快速集成大语言模型能力&#xff0c;却苦于时间有限、文档繁杂&#xff0c;那么这篇文章就是为你量身定制的。我们将用最…...

OpenClaw私有化方案:Qwen3-VL:30B+飞书自动化助手实战

OpenClaw私有化方案&#xff1a;Qwen3-VL:30B飞书自动化助手实战 1. 为什么选择私有化AI助手 去年我接手了一个特殊项目&#xff1a;需要将公司内部的技术文档自动整理成知识库&#xff0c;并推送到飞书文档。这个需求看似简单&#xff0c;但涉及几个棘手问题&#xff1a;文档…...

HoloPart:当3D模型学会自我解剖,深度学习的“X光眼“如何看透一切

HoloPart&#xff1a;当3D模型学会自我解剖&#xff0c;深度学习的"X光眼"如何看透一切 【免费下载链接】HoloPart Generative 3D Part Amodal Segmentation 项目地址: https://gitcode.com/gh_mirrors/ho/HoloPart 你是否曾对着一个复杂的3D模型感到困惑——…...