大模型训练为什么选择交叉熵损失(Cross-Entropy Loss):均方误差(MSE)和交叉熵损失的深入对比
交叉熵损失:深度学习中的基石与洞见
交叉熵损失(Cross-Entropy Loss)是现代深度学习中分类任务的核心损失函数,尤其在训练大规模模型(如 transformers 等大型语言模型 LLM)时,几乎无处不在。对于深度学习研究者而言,理解交叉熵的理论基础、与其他损失函数(如均方误差 MSE)的差异,以及它为何在分类任务中占据主导地位,不仅是技术层面的必需,更是深入洞察模型优化本质的关键。本文将从数学定义、理论特性、与 MSE 的对比,以及适用于分类任务的深刻原因等方面,详细剖析交叉熵损失,并提供一些独特的洞见。
一、交叉熵损失的数学定义与直觉
交叉熵起源于信息论,衡量的是两个概率分布之间的“距离”。在深度学习中,交叉熵损失通常用于监督学习中的分类任务,形式化定义如下:
对于一个多分类问题,假设有 ( C C C ) 个类别,真值标签为 ( y = [ y 1 , y 2 , … , y C ] y = [y_1, y_2, \dots, y_C] y=[y1,y2,…,yC] )(通常是 one-hot 编码,如 ( [ 0 , 1 , 0 ] [0, 1, 0] [0,1,0] )),模型预测的概率分布为 ( y ^ = [ y ^ 1 , y ^ 2 , … , y ^ C ] \hat{y} = [\hat{y}_1, \hat{y}_2, \dots, \hat{y}_C] y^=[y^1,y^2,…,y^C] )(通常通过 softmax 函数从 logits 得到),交叉熵损失为:
L C E = − ∑ i = 1 C y i log ( y ^ i ) L_{CE} = -\sum_{i=1}^{C} y_i \log(\hat{y}_i) LCE=−i=1∑Cyilog(y^i)
对于二分类问题,常用形式是二元交叉熵(Binary Cross-Entropy):
L B C E = − [ y log ( y ^ ) + ( 1 − y ) log ( 1 − y ^ ) ] L_{BCE} = -[y \log(\hat{y}) + (1-y) \log(1-\hat{y})] LBCE=−[ylog(y^)+(1−y)log(1−y^)]
直觉上,交叉熵损失惩罚的是模型预测分布与真实分布之间的差异。当 ( y ^ i \hat{y}_i y^i ) 接近 ( y i y_i yi ) 时,损失趋近于 0;当预测偏离真值时(例如 ( y i = 1 y_i = 1 yi=1 ) 但 ( y ^ i → 0 \hat{y}_i \to 0 y^i→0 )),损失趋于无穷大。这种“无限惩罚”的特性使得交叉熵对错误的预测非常敏感。
从信息论的角度看,交叉熵可以理解为:给定真实分布 ( P ( y ) P(y) P(y) ),用预测分布 ( Q ( y ^ ) Q(\hat{y}) Q(y^) ) 对其编码所需的额外比特数。其公式与 KL 散度(Kullback-Leibler Divergence)密切相关:
H ( P , Q ) = H ( P ) + D K L ( P ∣ ∣ Q ) H(P, Q) = H(P) + D_{KL}(P || Q) H(P,Q)=H(P)+DKL(P∣∣Q)
其中 ( H ( P ) H(P) H(P) ) 是真实分布的熵(在监督学习中为常数),( D K L ( P ∣ ∣ Q ) D_{KL}(P || Q) DKL(P∣∣Q) ) 是 KL 散度。交叉熵损失本质上优化的是 ( D K L D_{KL} DKL ),即让预测分布尽量接近真实分布。
二、交叉熵与 MSE 的核心区别
均方误差(Mean Squared Error, MSE)是另一种常见的损失函数,定义为:
L M S E = 1 C ∑ i = 1 C ( y i − y ^ i ) 2 L_{MSE} = \frac{1}{C} \sum_{i=1}^{C} (y_i - \hat{y}_i)^2 LMSE=C1i=1∑C(yi−y^i)2
MSE 直观地度量了预测值与真实值之间的欧几里得距离,适用于回归任务,但在分类任务中却表现不佳。以下是交叉熵与 MSE 的几个关键区别:
-
分布假设
- 交叉熵假定输出是概率分布(通常搭配 softmax),直接优化模型输出的概率特性。
- MSE 假定误差符合高斯分布,适用于连续值的回归问题,而分类任务的输出(类别标签)是离散的,违背了这一假设。(下文有详细解释)
-
梯度特性 (由于篇幅限制,具体推导请移步笔者的另一篇博客:MSE分类时梯度消失的问题详解和交叉熵损失的梯度推导)
- 交叉熵的梯度与误差成正比,且在预测错误时(例如 ( y ^ → 0 \hat{y} \to 0 y^→0 ) 而 ( y = 1 y = 1 y=1 ))梯度较大,有助于快速修正。
- MSE 的梯度与误差线性相关,当预测值接近 0 或 1 时,梯度趋于 0,导致“梯度消失”问题,模型学习缓慢。
-
惩罚机制
- 交叉熵对错误预测的惩罚是非线性的,尤其是对置信度过高的错误预测(confidence penalty),这与分类任务中区分清晰的需求一致。
- MSE 的惩罚是二次型的,对所有误差的处理相对均匀,无法很好地反映分类任务中“正确与否”的二元性。
三、为什么大模型训练偏爱交叉熵?
在训练大型深度学习模型(如 BERT、GPT 等)时,交叉熵几乎是默认选择,而 MSE 很少被采用。这一现象背后有深刻的理论和实践原因:
-
概率输出的天然适配
大模型通常通过 softmax 或 sigmoid 输出概率分布,交叉熵直接作用于这些概率,优化目标明确且一致。MSE 则需要将离散标签转化为连续值,引入不必要的复杂性。 -
梯度动态的优势
在深度网络中,梯度传播是优化的核心。交叉熵的梯度形式为 ( y i − y ^ i y_i - \hat{y}_i yi−y^i )(经过 softmax 后),简单且高效,即使网络很深也能保持较好的梯度流。而 MSE 的梯度在概率值接近边界时趋于 0,不利于深层网络的训练。 -
分类任务的本质需求
大模型(如 LLM)常用于语言建模或多分类任务,目标是最大化正确类别的概率并压制其他类别。交叉熵通过对数惩罚机制,天然契合这一需求,而 MSE 更适合平滑的回归预测。 -
信息论的启发
交叉熵与最大似然估计(MLE)等价,优化交叉熵等同于最大化数据似然。这种统计一致性在大规模数据集上尤为重要,而 MSE 缺乏类似的理论支撑。
四、交叉熵为何适合分类任务?
交叉熵在分类任务中的优越性可以从以下几个方面深入理解:
-
对数损失的敏感性
分类任务关心的是“对错”,而不是“差多少”。交叉熵通过对数形式放大预测错误的代价,尤其是在高置信度错误时(例如预测 99% 概率为错误类别),这与分类的决策边界需求高度吻合。 -
与 softmax 的协同效应
Softmax 函数将 logits 转化为概率分布,而交叉熵直接基于这些概率计算损失,二者结合形成端到端的优化闭环。这种协同效应在多分类问题中尤为高效。 -
避免过平滑
MSE 倾向于让预测值向均值靠拢,可能导致模型输出的概率分布过于“平滑”,无法清晰区分类别。而交叉熵鼓励模型对正确类别输出高置信度,对错误类别输出低置信度。
五、MSE 不适合分类任务吗?
MSE 在分类任务中的局限性主要体现在以下几点:
-
梯度消失问题
当预测值接近 0 或 1 时,MSE 的梯度趋于 0,导致模型难以进一步优化。这在二分类或多分类中尤其致命,因为分类任务需要明确的边界。 -
对离散标签的不适配
分类标签是离散的(如 0 或 1),而 MSE 假设输出是连续的。这种假设会导致模型无法正确捕捉类别间的“跳跃性”差异。 -
缺乏概率解释
MSE 的输出难以直接解释为概率,而分类任务通常需要概率输出(例如用于后续的决策或评估指标如 AUC)。交叉熵则天然与概率分布挂钩。
尽管如此,MSE 在某些特殊场景下并非完全无用。例如,在有序分类(ordinal classification)中,类别间存在连续性(如评分 1 到 5),MSE 可以作为一种折中选择。但对于标准离散分类任务,MSE 的表现远不及交叉熵。
六、深入洞见与扩展思考
-
交叉熵的局限性 (下文有详细解释)
尽管交叉熵强大,它对标签噪声(label noise)较敏感,因为错误的 one-hot 标签会导致损失剧增。研究者常通过标签平滑(label smoothing)或加权交叉熵来缓解这一问题。 -
与生成模型的联系
在生成式大模型(如 GAN 或扩散模型)中,交叉熵的变体(如 focal loss 或 contrastive loss)也被广泛探索,显示其适用性远超传统分类。 -
从优化角度的启发 (下文有详细解释)
交叉熵的非凸性与深度网络的非线性结合,使得优化过程更倾向于找到“尖锐”的解(sharp minima),这可能解释了大模型的高泛化能力,而 MSE 倾向于平滑解,可能限制模型表达力。 -
未来的方向
随着自监督学习和多模态任务的兴起,交叉熵的变种(如 InfoNCE)正在成为研究热点。研究者可以进一步探索如何将交叉熵的优点扩展到非分类任务中。
七、总结
交叉熵损失因其数学优雅、优化高效以及与分类任务需求的契合,成为深度学习中的基石。与 MSE 相比,交叉熵更适合处理概率分布、提供动态梯度并捕捉分类的本质特性。大模型训练中选择交叉熵,不仅是实践上的惯例,更是理论与性能的必然结果。对于深度学习研究者而言,理解交叉熵的深层机制,不仅能优化模型设计,还能启发新的损失函数创新,推动领域的前沿发展。
大模型(如 LLaMA 或 Qwen)在训练时如何使用交叉熵损失
大模型(如 LLaMA 或 Qwen)在训练时如何使用交叉熵损失,特别是分类类别是否对应词表大小,真值标签是否是 one-hot 编码,以及交叉熵的具体运作过程。下面会从大模型(尤其是语言模型)的训练流程、词表设计、损失计算和优化细节等方面详细解答。
一、大模型训练中的分类任务与词表
在像 LLaMA 或 Qwen 这样的大型语言模型(LLM)中,训练目标通常是自回归语言建模(autoregressive language modeling),即给定前文 ( x 1 , x 2 , … , x t − 1 x_1, x_2, \dots, x_{t-1} x1,x2,…,xt−1 ),预测下一个词 ( x t x_t xt )。这里的“分类任务”并不是传统意义上的固定类别分类(如猫狗分类),而是将预测下一个词视为一个多分类问题,类别数量等于词表大小(vocabulary size)。
1. 词表大小与分类类别
- 词表(Vocabulary):大模型通常使用分词器(tokenizer)将文本分解为 token(如 BPE 或 WordPiece),每个 token 在词表中有一个唯一的索引。词表大小 ( C C C ) 通常在几万到几十万之间。例如:
- LLaMA 的词表大小约为 32,000(具体取决于版本)。
- Qwen 的词表大小可能在 50,000 或更高(视具体实现)。
- 分类类别:在语言建模中,模型的输出层是一个全连接层,输出维度等于词表大小 ( C C C )。对于每个时间步 ( t t t ),模型预测下一个 token 的概率分布 ( y ^ \hat{y} y^ ),其维度为 ( [ C ] [C] [C] ),表示词表中每个 token 的概率。
2. 真值标签是否是 one-hot 编码?
- 理论上是 one-hot,但在实现中通常不是:
- 在数学定义上,真值标签 ( y = [ y 1 , y 2 , … , y C ] y = [y_1, y_2, \dots, y_C] y=[y1,y2,…,yC] ) 是 one-hot 形式,例如若正确 token 是词表中的第 ( k k k ) 个,则 ( y k = 1 y_k = 1 yk=1 ),其他 ( y i = 0 y_i = 0 yi=0 )。
- 但在实际训练中,为了计算效率,不会显式构造 one-hot 向量。框架如 PyTorch 或 TensorFlow 提供的高效实现(如
torch.nn.CrossEntropyLoss)接受真实标签为标量索引(即 ( k k k )),而不是完整的 one-hot 向量。这是由于 softmax 和交叉熵的结合可以直接基于索引计算损失,避免稀疏向量的存储和运算开销。
二、交叉熵损失在大模型中的具体运作
让我们详细拆解交叉熵在大模型训练中的过程,从输入到损失计算。
1. 模型架构与输出
- 输入:给定一个序列 ( x 1 , x 2 , … , x t − 1 x_1, x_2, \dots, x_{t-1} x1,x2,…,xt−1 )(token 索引),通过分词器转换为整数序列。
- Transformer 前向传播:输入序列经过 embedding 层、多个 Transformer 层,最终在每个时间步 ( t t t) 输出一个 logits 向量 ( z t = [ z t , 1 , z t , 2 , … , z t , C ] \mathbf{z}_t = [z_{t,1}, z_{t,2}, \dots, z_{t,C}] zt=[zt,1,zt,2,…,zt,C] )(维度为词表大小 ( C C C ))。
- Softmax 转换:logits 通过 softmax 函数转化为概率分布:
y ^ t , i = e z t , i ∑ j = 1 C e z t , j \hat{y}_{t,i} = \frac{e^{z_{t,i}}}{\sum_{j=1}^C e^{z_{t,j}}} y^t,i=∑j=1Cezt,jezt,i
其中 ( y ^ t , i \hat{y}_{t,i} y^t,i ) 是时间步 ( t t t ) 对词表中第 ( i i i ) 个 token 的预测概率。
2. 真值标签的准备
- 真实 token:对于时间步 ( t t t ),下一个真实的 token 是 ( x t x_t xt ),其在词表中的索引为 ( k k k )(例如 ( x t = k x_t = k xt=k ))。
- 标签形式:理论上,( y t = [ 0 , 0 , … , 1 , … , 0 ] y_t = [0, 0, \dots, 1, \dots, 0] yt=[0,0,…,1,…,0] )(第 ( k k k ) 位为 1),但实际中 ( y t y_t yt ) 直接用索引 ( k k k ) 表示。
3. 交叉熵损失计算
-
数学形式:
L C E = − ∑ i = 1 C y t , i log ( y ^ t , i ) L_{CE} = -\sum_{i=1}^C y_{t,i} \log(\hat{y}_{t,i}) LCE=−i=1∑Cyt,ilog(y^t,i)
因为 ( y t y_t yt ) 是 one-hot,只有 ( y t , k = 1 y_{t,k} = 1 yt,k=1 ),其他为 0,所以:
L C E = − log ( y ^ t , k ) L_{CE} = -\log(\hat{y}_{t,k}) LCE=−log(y^t,k)
即损失只依赖于正确 token 的预测概率 ( y ^ t , k \hat{y}_{t,k} y^t,k )。 -
实际实现:
- 在 PyTorch 中,
torch.nn.CrossEntropyLoss内部结合了 softmax 和交叉熵计算:- 输入:logits ( z t \mathbf{z}_t zt )(未经过 softmax)。
- 目标:索引 ( k k k )(而不是 one-hot 向量)。
- 计算:直接对 ( z t \mathbf{z}_t zt) 应用 softmax,然后取 ( − log ( y ^ t , k ) -\log(\hat{y}_{t,k}) −log(y^t,k) )。
- 这避免了显式计算整个 ( y ^ t \hat{y}_t y^t ) 向量,只需计算正确索引处的概率,大幅提升效率。
- 在 PyTorch 中,
4. 序列上的总损失
- 对于一个长度为 ( T T T ) 的序列,模型会预测每个位置的下一个 token,总损失是所有时间步损失的平均:
L = 1 T ∑ t = 1 T L C E , t = − 1 T ∑ t = 1 T log ( y ^ t , k t ) L = \frac{1}{T} \sum_{t=1}^T L_{CE,t} = -\frac{1}{T} \sum_{t=1}^T \log(\hat{y}_{t, k_t}) L=T1t=1∑TLCE,t=−T1t=1∑Tlog(y^t,kt)
其中 ( k t k_t kt ) 是时间步 ( t t t ) 的真实 token 索引。
三、大模型训练中的具体流程
以 LLaMA 或 Qwen 为例,训练时的交叉熵运作过程如下:
-
数据准备:
- 预处理大规模文本语料(如 Wikipedia、Books),用分词器(如 SentencePiece)生成 token 序列。
- 构造训练样本:输入序列 ( [ x 1 , x 2 , … , x T − 1 ] [x_1, x_2, \dots, x_{T-1}] [x1,x2,…,xT−1] ) 和目标序列 ( [ x 2 , x 3 , … , x T ] [x_2, x_3, \dots, x_T] [x2,x3,…,xT] )。
-
前向传播:
- 输入序列通过 embedding 层映射为向量。
- Transformer 层逐层处理,输出每个位置的 logits ( z t \mathbf{z}_t zt )。
-
损失计算:
- 对于每个时间步 ( t t t ),从 ( z t \mathbf{z}_t zt ) 计算 softmax 概率 ( y ^ t \hat{y}_t y^t )。
- 用真实 token 索引 ( k t k_t kt ) 计算 ( − log ( y ^ t , k t ) -\log(\hat{y}_{t, k_t}) −log(y^t,kt) )。
-
反向传播:
- 计算梯度 ( ∂ L C E ∂ z t , j = y ^ t , j − y t , j \frac{\partial L_{CE}}{\partial z_{t,j}} = \hat{y}_{t,j} - y_{t,j} ∂zt,j∂LCE=y^t,j−yt,j ):
- 若 ( j = k t j = k_t j=kt ):( y ^ t , j − 1 \hat{y}_{t,j} - 1 y^t,j−1 );
- 若 ( j ≠ k t j \neq k_t j=kt ):( y ^ t , j \hat{y}_{t,j} y^t,j )。
- 梯度通过 Transformer 层反向传播,更新参数。
- 计算梯度 ( ∂ L C E ∂ z t , j = y ^ t , j − y t , j \frac{\partial L_{CE}}{\partial z_{t,j}} = \hat{y}_{t,j} - y_{t,j} ∂zt,j∂LCE=y^t,j−yt,j ):
-
优化:
- 使用优化器(如 AdamW)根据梯度更新模型权重。
- 重复迭代,直到模型在验证集上的困惑度(perplexity,( e L C E e^{L_{CE}} eLCE ))收敛。
四、为什么用交叉熵?具体特性在大模型中的体现
-
词表大小的适配:
- 词表大小 ( C C C ) 很大(例如 50,000),交叉熵直接优化正确 token 的概率,计算高效且目标明确。
-
概率分布优化:
- 语言建模需要输出合理的概率分布,交叉熵通过 ( − log ( y ^ t , k ) -\log(\hat{y}_{t,k}) −log(y^t,k) ) 惩罚低概率预测,与最大似然估计一致。
-
梯度特性:
- 当 ( y ^ t , k → 0 \hat{y}_{t,k} \to 0 y^t,k→0 )(正确 token 概率很低),梯度 ( y ^ t , k − 1 ≈ − 1 \hat{y}_{t,k} - 1 \approx -1 y^t,k−1≈−1 ),推动模型快速修正。
- 对于错误 token,若 ( y ^ t , j → 1 \hat{y}_{t,j} \to 1 y^t,j→1 )(( j ≠ k j \neq k j=k )),梯度 ( y ^ t , j \hat{y}_{t,j} y^t,j ) 推动其概率下降。
-
避免 MSE 的问题:
- 若用 MSE,误差 ( y t , k − y ^ t , k y_{t,k} - \hat{y}_{t,k} yt,k−y^t,k ) 需要将离散标签视为连续值,且梯度在 ( y ^ → 0 \hat{y} \to 0 y^→0 ) 或 1 时消失,不适合大词表的分类。
五、实际细节与优化
-
批处理(Batching):
- 真实训练中,输入是批量序列(batch size × sequence length),损失在 batch 和序列维度上平均。
-
掩码(Masking):
- 对于填充(padding)部分,使用掩码忽略损失计算,确保只对有效 token 计算交叉熵。
-
数值稳定性:
- 直接计算 ( log ( softmax ( z ) ) \log(\text{softmax}(z)) log(softmax(z)) ) 可能因指数溢出而不稳定,框架通常使用 log-sum-exp 技巧:
log ( y ^ t , k ) = z t , k − log ( ∑ j = 1 C e z t , j ) \log(\hat{y}_{t,k}) = z_{t,k} - \log\left(\sum_{j=1}^C e^{z_{t,j}}\right) log(y^t,k)=zt,k−log(j=1∑Cezt,j)
- 直接计算 ( log ( softmax ( z ) ) \log(\text{softmax}(z)) log(softmax(z)) ) 可能因指数溢出而不稳定,框架通常使用 log-sum-exp 技巧:
总结
在大模型(如 LLaMA 或 Qwen)训练中,交叉熵损失的分类类别确实是词表大小 ( C C C )(几万到几十万),真值标签理论上是 one-hot,但在实现中用索引表示以提高效率。具体过程包括:输入序列经过 Transformer 输出 logits,经 softmax 转为概率,交叉熵计算正确 token 的对数损失,梯度驱动优化。交叉熵的高效性和概率优化特性使其成为语言建模的理想选择。
MSE(均方误差)的分布假设:误差符合高斯分布
我们将深入探讨 MSE(均方误差)的分布假设,特别是“误差符合高斯分布”的含义,以及这如何影响其在回归和分类任务中的适用性。这部分内容会从统计学和概率论的角度详细展开,适合高理论水平的深度学习研究者。
MSE 的分布假设:误差符合高斯分布
均方误差(Mean Squared Error, MSE)作为损失函数,其理论基础可以追溯到统计学中的最大似然估计(Maximum Likelihood Estimation, MLE)和最小二乘法。在使用 MSE 时,隐含的假设是模型预测值与真实值之间的误差服从高斯分布(正态分布)。以下是详细解释:
1. MSE 的定义
对于一个预测任务,假设真实值为 ( y y y ),模型预测值为 ( y ^ \hat{y} y^ ),MSE 定义为:
L M S E = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 L_{MSE} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 LMSE=n1i=1∑n(yi−y^i)2
其中 ( n n n ) 是样本数。我们关心的是误差 ( ϵ i = y i − y ^ i \epsilon_i = y_i - \hat{y}_i ϵi=yi−y^i ),MSE 实际上是对误差平方的平均。
2. 误差符合高斯分布的含义
“误差符合高斯分布”指的是,假设误差 ( ϵ = y − y ^ \epsilon = y - \hat{y} ϵ=y−y^ ) 服从均值为 0、方差为 ( σ 2 \sigma^2 σ2 ) 的正态分布,即:
ϵ ∼ N ( 0 , σ 2 ) \epsilon \sim \mathcal{N}(0, \sigma^2) ϵ∼N(0,σ2)
或者写成概率密度形式:
p ( ϵ ) = 1 2 π σ 2 exp ( − ϵ 2 2 σ 2 ) p(\epsilon) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{\epsilon^2}{2\sigma^2}\right) p(ϵ)=2πσ21exp(−2σ2ϵ2)
在监督学习中,真实值 ( y y y ) 可以看作是由模型预测 ( y ^ \hat{y} y^ ) 加上一个随机噪声 ( ϵ \epsilon ϵ ) 得到的:
y = y ^ + ϵ y = \hat{y} + \epsilon y=y^+ϵ
如果 ( ϵ \epsilon ϵ ) 服从高斯分布,那么 ( y y y ) 给定 ( y ^ \hat{y} y^) 的条件概率为:
p ( y ∣ y ^ ) = 1 2 π σ 2 exp ( − ( y − y ^ ) 2 2 σ 2 ) p(y | \hat{y}) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(y - \hat{y})^2}{2\sigma^2}\right) p(y∣y^)=2πσ21exp(−2σ2(y−y^)2)
3. 从最大似然估计推导 MSE
假设我们有 ( n n n ) 个独立同分布的样本 ( { ( y 1 , y ^ 1 ) , ( y 2 , y ^ 2 ) , … , ( y n , y ^ n ) } (y_1, \hat{y}_1), (y_2, \hat{y}_2), \dots, (y_n, \hat{y}_n) \} (y1,y^1),(y2,y^2),…,(yn,y^n)} ),其联合概率(似然函数)为:
L = ∏ i = 1 n p ( y i ∣ y ^ i ) = ∏ i = 1 n 1 2 π σ 2 exp ( − ( y i − y ^ i ) 2 2 σ 2 ) L = \prod_{i=1}^n p(y_i | \hat{y}_i) = \prod_{i=1}^n \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(y_i - \hat{y}_i)^2}{2\sigma^2}\right) L=i=1∏np(yi∣y^i)=i=1∏n2πσ21exp(−2σ2(yi−y^i)2)
为了最大化似然,通常取对数(对数似然):
log L = ∑ i = 1 n log ( 1 2 π σ 2 exp ( − ( y i − y ^ i ) 2 2 σ 2 ) ) \log L = \sum_{i=1}^n \log \left( \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(y_i - \hat{y}_i)^2}{2\sigma^2}\right) \right) logL=i=1∑nlog(2πσ21exp(−2σ2(yi−y^i)2))
= ∑ i = 1 n [ − 1 2 log ( 2 π σ 2 ) − ( y i − y ^ i ) 2 2 σ 2 ] = \sum_{i=1}^n \left[ -\frac{1}{2} \log(2\pi\sigma^2) - \frac{(y_i - \hat{y}_i)^2}{2\sigma^2} \right] =i=1∑n[−21log(2πσ2)−2σ2(yi−y^i)2]
= − n 2 log ( 2 π σ 2 ) − 1 2 σ 2 ∑ i = 1 n ( y i − y ^ i ) 2 = -\frac{n}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (y_i - \hat{y}_i)^2 =−2nlog(2πσ2)−2σ21i=1∑n(yi−y^i)2
最大化 ( log L \log L logL ) 等价于最小化其负值。忽略常数项 ( − n 2 log ( 2 π σ 2 ) -\frac{n}{2} \log(2\pi\sigma^2) −2nlog(2πσ2) )(不影响优化),目标变为:
minimize 1 2 σ 2 ∑ i = 1 n ( y i − y ^ i ) 2 \text{minimize} \quad \frac{1}{2\sigma^2} \sum_{i=1}^n (y_i - \hat{y}_i)^2 minimize2σ21i=1∑n(yi−y^i)2
由于 ( σ 2 \sigma^2 σ2 ) 是固定的方差(不随参数变化),这等价于最小化:
∑ i = 1 n ( y i − y ^ i ) 2 \sum_{i=1}^n (y_i - \hat{y}_i)^2 i=1∑n(yi−y^i)2
这就是 MSE 的形式。因此,MSE 等价于在高斯噪声假设下进行最大似然估计。
4. “误差”指的是什么?
这里的“误差”具体是指真实值 ( y y y ) 与模型预测值 ( y ^ \hat{y} y^ ) 之间的差 ( ϵ = y − y ^ \epsilon = y - \hat{y} ϵ=y−y^ )。在回归任务中,( y y y ) 是连续值,( y ^ \hat{y} y^ ) 是模型试图逼近的连续预测值,误差 ( ϵ \epsilon ϵ ) 被假设为随机噪声,符合高斯分布。这种假设在许多实际问题中是合理的,例如测量数据中的噪声通常近似正态。
MSE 分布假设的适用性
1. 适用于连续值的回归问题
高斯分布是一个连续分布,其概率密度在实数轴上有定义。MSE 的高斯假设天然适用于回归任务,因为:
- 真实值 ( y y y ) 是连续的(如温度、房价)。
- 误差 ( ϵ \epsilon ϵ ) 可以取任意实数值,且高斯分布的对称性和“钟形曲线”特性与许多自然现象的噪声分布一致。
- MSE 通过平方惩罚大误差,鼓励模型预测值尽量接近真实值,形成平滑的拟合。
2. 在分类任务中的违背
分类任务的输出(类别标签)是离散的,例如二分类 ( y ∈ { 0 , 1 } y \in \{0, 1\} y∈{0,1} ) 或多分类 ( y ∈ { 1 , 2 , … , C } y \in \{1, 2, \dots, C\} y∈{1,2,…,C} )。即使模型输出 ( y ^ \hat{y} y^ ) 被设计为概率(通过 sigmoid 或 softmax),以下问题使得 MSE 的高斯假设不适用:
-
离散标签与连续误差冲突:
如果 ( y y y ) 是离散的(如 0 或 1),而 ( \hat{y} ) 是连续概率(如 0.7),误差 ( ϵ = y − y ^ \epsilon = y - \hat{y} ϵ=y−y^ ) 仍然是连续的,但其分布不再是简单的高斯分布。例如,当 ( y = 1 y = 1 y=1 ) 时,( ϵ = 1 − y ^ \epsilon = 1 - \hat{y} ϵ=1−y^ ) 的取值范围是 ( [0, 1] ),这与高斯分布的无界性(( − ∞ , ∞ -\infty, \infty −∞,∞ ))不符。 -
分布假设不匹配:
分类任务的真实标签更适合用 categorical 分布(多分类)或 Bernoulli 分布(二分类)建模,而非高斯分布。例如,二分类标签 ( y y y ) 的分布是:
p ( y ) = y ^ y ( 1 − y ^ ) 1 − y p(y) = \hat{y}^y (1 - \hat{y})^{1-y} p(y)=y^y(1−y^)1−y
这与交叉熵的假设一致,而非 MSE 的高斯假设。 -
误差的意义不同:
在分类中,我们关心的是“类别是否正确”,而不是“预测值与标签的数值差”。MSE 假设误差是对称的、连续的,但分类任务需要的是概率分布的优化,而不是数值逼近。
MSE 与交叉熵的分布假设对比
-
交叉熵的分布假设
- 交叉熵假定模型输出 ( y ^ \hat{y} y^ ) 是概率分布(通过 softmax 或 sigmoid),真实标签 ( y y y ) 服从 categorical 或 Bernoulli 分布。
- 优化目标是让 ( y ^ \hat{y} y^ ) 接近 ( y y y ) 的分布,度量的是 KL 散度(或等价的对数似然)。
- 适用于离散类别输出,直接优化概率特性。
-
MSE 的分布假设
- MSE 假定误差 ( y − y ^ y - \hat{y} y−y^ ) 服从高斯分布,适用于连续值的回归。
- 在分类中强行使用 MSE,会导致模型试图将离散标签(如 0 或 1)当作连续值逼近,违背了任务的本质。
示例分析
- 回归:预测房价 ( y = 100 y = 100 y=100 ) 万,( y ^ = 95 \hat{y} = 95 y^=95 ) 万,误差 ( ϵ = 5 \epsilon = 5 ϵ=5 ) 万,可能是高斯噪声,MSE 合理。
- 分类:预测类别 ( y = 1 y = 1 y=1 )(正类),( y ^ = 0.2 \hat{y} = 0.2 y^=0.2 )(概率),误差 ( ϵ = 0.8 \epsilon = 0.8 ϵ=0.8 ) 不应视为高斯噪声,而是概率分布的偏离,交叉熵更合适。
总结
MSE 的分布假设是误差 ( ϵ = y − y ^ \epsilon = y - \hat{y} ϵ=y−y^ ) 服从均值为 0 的高斯分布,这来源于最大似然估计下的高斯噪声模型。它适用于连续值的回归任务,因为误差的连续性和对称性与高斯分布吻合。然而,在分类任务中,真实标签是离散的,误差不再符合高斯分布,而是与概率分布(如 Bernoulli 或 categorical)的偏离相关,因此 MSE 的假设被违背,交叉熵成为更自然的选择。
交叉熵的局限性和缓解方法详解
我们将深入探讨交叉熵损失的局限性,特别是它对标签噪声的敏感性,以及如何通过标签平滑(label smoothing)来缓解这一问题。这部分内容会从理论、数学推导和实际应用的角度详细展开,适合高理论水平的深度学习研究者。
一、交叉熵的局限性
交叉熵损失(Cross-Entropy Loss)虽然在分类任务中表现出色,但在某些场景下存在局限性,尤其是对标签噪声的敏感性。以下是其主要局限性的详细分析:
1. 对标签噪声的敏感性
-
定义与机制:
交叉熵损失形式为:
L C E = − ∑ i = 1 C y i log ( y ^ i ) L_{CE} = -\sum_{i=1}^C y_i \log(\hat{y}_i) LCE=−i=1∑Cyilog(y^i)
对于 one-hot 编码的真值标签 ( y y y )(例如 ( y k = 1 y_k = 1 yk=1 ),其他 ( y i = 0 y_i = 0 yi=0 )),损失简化为:
L C E = − log ( y ^ k ) L_{CE} = -\log(\hat{y}_k) LCE=−log(y^k)
这里 ( y ^ k \hat{y}_k y^k ) 是模型对正确类别的预测概率。如果 ( y ^ k → 0 \hat{y}_k \to 0 y^k→0 )(模型对正确类别置信度极低),损失会趋于无穷大。 -
噪声的影响:
在真实数据中,标签可能存在噪声(label noise),即标注错误。例如,正确类别应为 ( k k k ),但标签错误地标记为 ( m ≠ k m \neq k m=k )。此时:- 交叉熵会强制模型优化 ( y ^ m \hat{y}_m y^m )(错误类别)的概率,使其接近 1。
- 如果 ( y ^ m → 0 \hat{y}_m \to 0 y^m→0 )(模型正确倾向于 ( k k k ) 而非 ( m m m )),损失 ( − log ( y ^ m ) -\log(\hat{y}_m) −log(y^m) ) 剧增,导致梯度过大,迫使模型偏离正确方向。
-
后果:
- 模型过度拟合噪声标签,泛化性能下降。
- 在大模型(如 LLM)中,噪声标签可能放大为系统性偏差,尤其在数据量大但质量参差不齐时。
2. 过度自信问题
- 交叉熵鼓励模型对正确类别输出接近 1 的概率(即 ( y ^ k → 1 \hat{y}_k \to 1 y^k→1 )),这可能导致模型过于自信(overconfidence)。
- 在多分类任务中,这种特性使得模型对次优类别(non-target classes)的概率分配过低(接近 0),丧失了不确定性表达能力。
- 对于噪声数据,这种过度自信会加剧错误标签的影响,因为模型无法“怀疑”标签的正确性。
3. 不适应不平衡数据
- 当数据集类别分布不平衡时,交叉熵对少数类样本的优化力度不足,因为损失主要由多数类主导。
- 噪声标签在少数类中若出现,其影响会被放大,导致模型对少数类的预测更加不可靠。
4. 对分布偏移的脆弱性
- 交叉熵假设训练和测试数据的标签分布一致。如果测试时分布发生偏移(例如新类别出现或噪声模式改变),模型可能因过度依赖训练标签而失效。
二、标签平滑(Label Smoothing)的详细解释
为了缓解交叉熵对标签噪声的敏感性和过度自信问题,研究者提出了标签平滑(label smoothing)技术。下面从定义、数学推导、作用机制和实际应用详细介绍。
1. 标签平滑的定义
-
传统 one-hot 标签:对于类别 ( k k k ),真值标签 ( y = [ 0 , 0 , … , 1 , … , 0 ] y = [0, 0, \dots, 1, \dots, 0] y=[0,0,…,1,…,0] )(第 ( k k k ) 位为 1)。
-
平滑后的标签:将 one-hot 标签替换为一个混合分布:
y i ′ = { 1 − α + α / C , if i = k α / C , if i ≠ k y_i' = \begin{cases} 1 - \alpha + \alpha / C, & \text{if } i = k \\ \alpha / C, & \text{if } i \neq k \end{cases} yi′={1−α+α/C,α/C,if i=kif i=k
其中:- ( α \alpha α ) 是平滑参数(通常为小正数,如 0.1)。
- ( C C C ) 是类别数(例如词表大小)。
- 平滑后的 ( y ′ y' y′ ) 仍然满足 ( ∑ i = 1 C y i ′ = 1 \sum_{i=1}^C y_i' = 1 ∑i=1Cyi′=1 )。
-
直觉:不再强制正确类别概率为 1,而是分配一部分概率(( α / C \alpha / C α/C ))给其他类别,形成一个均匀分布的“软标签”。
2. 平滑后的交叉熵损失
- 原始交叉熵:
L C E = − ∑ i = 1 C y i log ( y ^ i ) = − log ( y ^ k ) L_{CE} = -\sum_{i=1}^C y_i \log(\hat{y}_i) = -\log(\hat{y}_k) LCE=−i=1∑Cyilog(y^i)=−log(y^k) - 平滑后的交叉熵:
L C E ′ = − ∑ i = 1 C y i ′ log ( y ^ i ) L_{CE}' = -\sum_{i=1}^C y_i' \log(\hat{y}_i) LCE′=−i=1∑Cyi′log(y^i)
代入平滑标签:
L C E ′ = − ( 1 − α + α / C ) log ( y ^ k ) − ∑ i ≠ k ( α / C ) log ( y ^ i ) L_{CE}' = -(1 - \alpha + \alpha / C) \log(\hat{y}_k) - \sum_{i \neq k} (\alpha / C) \log(\hat{y}_i) LCE′=−(1−α+α/C)log(y^k)−i=k∑(α/C)log(y^i)
整理:
L C E ′ = − ( 1 − α ) log ( y ^ k ) − α C ∑ i = 1 C log ( y ^ i ) L_{CE}' = -(1 - \alpha) \log(\hat{y}_k) - \frac{\alpha}{C} \sum_{i=1}^C \log(\hat{y}_i) LCE′=−(1−α)log(y^k)−Cαi=1∑Clog(y^i)- 第一项:仍然优化正确类别的概率,但权重降低为 ( 1 − α 1 - \alpha 1−α )。
- 第二项:引入对所有类别概率的对数和,鼓励均匀性。
3. 数学推导与理论意义
-
与 KL 散度关系:
平滑后的损失可以看作原始交叉熵与均匀分布之间的正则化:
L C E ′ = ( 1 − α ) L C E + α H ( u , y ^ ) L_{CE}' = (1 - \alpha) L_{CE} + \alpha H(u, \hat{y}) LCE′=(1−α)LCE+αH(u,y^)
其中 ( H ( u , y ^ ) = − 1 C ∑ i = 1 C log ( y ^ i ) H(u, \hat{y}) = -\frac{1}{C} \sum_{i=1}^C \log(\hat{y}_i) H(u,y^)=−C1∑i=1Clog(y^i) ) 是预测分布 ( y ^ \hat{y} y^ ) 与均匀分布 ( u = [ 1 / C , 1 / C , … , 1 / C ] u = [1/C, 1/C, \dots, 1/C] u=[1/C,1/C,…,1/C] ) 的交叉熵。这等价于在原始损失上加了一个 KL 散度正则项:
L C E ′ = L C E + α 1 − α D K L ( u ∣ ∣ y ^ ) L_{CE}' = L_{CE} + \frac{\alpha}{1 - \alpha} D_{KL}(u || \hat{y}) LCE′=LCE+1−ααDKL(u∣∣y^)
(推导略,基于 ( H ( p , q ) = H ( p ) + D K L ( p ∣ ∣ q ) H(p, q) = H(p) + D_{KL}(p || q) H(p,q)=H(p)+DKL(p∣∣q) ))。 -
正则化效应:
- 标签平滑通过 ( D K L ( u ∣ ∣ y ^ ) D_{KL}(u || \hat{y}) DKL(u∣∣y^) ) 惩罚过于尖锐的预测分布(即 ( y ^ i → 0 \hat{y}_i \to 0 y^i→0 ) 或 1),鼓励模型输出更平滑的概率分布。
- 这减少了模型对噪声标签的过度拟合。
4. 梯度分析
-
原始交叉熵梯度:
∂ L C E ∂ z j = y ^ j − y j \frac{\partial L_{CE}}{\partial z_j} = \hat{y}_j - y_j ∂zj∂LCE=y^j−yj
若 ( j = k j = k j=k ),( y ^ k − 1 \hat{y}_k - 1 y^k−1 );若 ( j ≠ k j \neq k j=k ),( y ^ j \hat{y}_j y^j )。 -
平滑后梯度:
∂ L C E ′ ∂ z j = y ^ j − y j ′ \frac{\partial L_{CE}'}{\partial z_j} = \hat{y}_j - y_j' ∂zj∂LCE′=y^j−yj′- 若 ( j = k j = k j=k ):( y ^ k − ( 1 − α + α / C ) \hat{y}_k - (1 - \alpha + \alpha / C) y^k−(1−α+α/C) )
- 若 ( j ≠ k j \neq k j=k ):( y ^ j − α / C \hat{y}_j - \alpha / C y^j−α/C )
当 ( y ^ k → 0 \hat{y}_k \to 0 y^k→0 ) 时,梯度幅度减小(从 -1 变为 ( − ( 1 − α + α / C ) -(1 - \alpha + \alpha / C) −(1−α+α/C) )),对错误标签的惩罚不再无限大。
5. 作用机制
- 缓解噪声敏感性:
- 噪声标签(如错误标记为 ( m m m ) 而非 ( k k k ))不再要求 ( y ^ m = 1 \hat{y}_m = 1 y^m=1 ),而是 ( y ^ m = 1 − α + α / C \hat{y}_m = 1 - \alpha + \alpha / C y^m=1−α+α/C ),降低了损失剧增的风险。
- 模型对错误标签的过度优化被抑制。
- 减少过度自信:
- 正确类别目标从 1 变为 ( 1 − α + α / C < 1 1 - \alpha + \alpha / C < 1 1−α+α/C<1 ),其他类别从 0 变为 ( α / C > 0 \alpha / C > 0 α/C>0 ),模型不会将概率推向极端。
- 提升泛化性:
- 平滑分布更接近真实世界的不确定性(例如语言中的歧义),提高模型在测试集上的鲁棒性。
6. 实际应用
- 大模型中的使用:
- 在训练 LLaMA、BERT 或 Qwen 等模型时,标签平滑常用于语言建模或预训练任务。例如,( α = 0.1 \alpha = 0.1 α=0.1 ) 是一个常见选择。
- 在词表大小很大的场景(如 ( C = 50 , 000 C = 50,000 C=50,000 )),平滑后的 ( α / C \alpha / C α/C ) 很小,但仍有效正则化。
- 实现:
- PyTorch 中,
torch.nn.CrossEntropyLoss支持label_smoothing参数,直接应用平滑。
- PyTorch 中,
三、其他缓解方法:加权交叉熵
- 定义:
对不同类别或样本赋予权重,调整损失贡献:
L W C E = − ∑ i = 1 C w i y i log ( y ^ i ) L_{WCE} = -\sum_{i=1}^C w_i y_i \log(\hat{y}_i) LWCE=−i=1∑Cwiyilog(y^i) - 作用:
- 对噪声样本降低权重(如基于置信度估计)。
- 对少数类提高权重,缓解不平衡问题。
- 局限:
- 需要额外的噪声检测或权重设计,复杂度高于标签平滑。
四、总结与洞见
交叉熵的局限性主要体现在对标签噪声的敏感性(损失剧增)和过度自信倾向,这在大模型训练中可能导致过拟合和泛化能力下降。标签平滑通过将 one-hot 标签替换为软标签(混合均匀分布),降低了噪声影响、抑制了过度自信,并通过正则化提升了模型鲁棒性。其数学本质是在交叉熵上加了一个均匀分布的 KL 散度惩罚,梯度调整也更温和。实际中,标签平滑已成为大模型训练的标准技术,而加权交叉熵等方法则提供了进一步的灵活性。
交叉熵损失: 优化过程倾向于找到“尖锐”的解(sharp minima)
我们将深入探讨第 3 点“从优化角度的启发”,特别是关于交叉熵损失的非凸性如何与深度网络的非线性结合,导致优化过程倾向于找到“尖锐”的解(sharp minima),以及这如何可能解释大模型的高泛化能力。同时,我们会对比 MSE(均方误差)倾向于平滑解的特性,并解释其对模型表达力的限制。这部分内容会从优化理论、损失函数的几何特性以及深度学习的实际现象出发,尽可能严谨且深入。
一、优化问题的背景
在深度学习中,训练目标是通过优化损失函数 ( $L(\theta) )(其中 ( θ \theta θ ) 是模型参数)来调整网络权重。损失函数通常定义在高维参数空间中,其形状(即曲面特性)决定了优化的轨迹和最终解的性质。关键概念包括:
- 凸性与非凸性:凸函数有全局最优解,而非凸函数可能有多个局部最优解。
- 解的“尖锐性”与“平坦性”:在参数空间中,局部极小值(minima)的曲率决定了其“尖锐”(高曲率)或“平坦”(低曲率)。这与 Hessian 矩阵(二阶导数)的特征值相关。
二、交叉熵的非凸性与深度网络的非线性
1. 交叉熵损失的非凸性
交叉熵损失形式为:
L C E = − ∑ i = 1 C y i log ( y ^ i ) L_{CE} = -\sum_{i=1}^C y_i \log(\hat{y}_i) LCE=−i=1∑Cyilog(y^i)
其中 ( y ^ i = softmax ( z i ) \hat{y}_i = \text{softmax}(z_i) y^i=softmax(zi) ),( z i z_i zi ) 是模型输出的 logits,依赖于参数 ( θ \theta θ )。
- 非线性来源:
- Softmax 函数 ( y ^ i = e z i ∑ j = 1 C e z j \hat{y}_i = \frac{e^{z_i}}{\sum_{j=1}^C e^{z_j}} y^i=∑j=1Cezjezi ) 是非线性的,引入指数运算和归一化。
- 在深度网络中,( z i = f ( θ , x ) z_i = f(\theta, x) zi=f(θ,x) ) 是输入 ( x x x ) 通过多层非线性变换(如 ReLU、sigmoid)得到的复杂函数。
- 非凸性证明:
- 对于线性模型(如逻辑回归),交叉熵可能是凸的(对 ( \theta ) 而言)。
- 但在深度网络中,多层非线性激活函数和参数交互使得 ( L C E ( θ ) L_{CE}(\theta) LCE(θ) ) 成为高度非凸的函数。Hessian 矩阵的特征值既有正值也有负值,表明存在鞍点和多个局部极小值。
2. 深度网络的非线性放大
- 深度网络(如 Transformer 或 CNN)的每一层都引入非线性变换(如 ReLU、LayerNorm),叠加后形成复杂的参数-损失映射。
- 这种非线性与交叉熵的非凸性结合,导致损失表面充满“山峰”和“山谷”,优化过程容易陷入局部极小值或鞍点。
3. 优化倾向于“尖锐”的解
- 尖锐解的定义:
- 在参数空间中,尖锐的局部极小值(sharp minima)周围的损失函数变化很快,即二阶导数(Hessian 的特征值)较大。
- 几何上,这对应于狭窄的“山谷”,参数稍有扰动,损失就显著增加。
- 为什么交叉熵倾向于尖锐解:
- 交叉熵的目标是让正确类别的 ( y ^ k → 1 \hat{y}_k \to 1 y^k→1 ),其他 ( y ^ i → 0 \hat{y}_i \to 0 y^i→0 )。这要求 logits ( z k z_k zk ) 远大于其他 ( z i z_i zi )(由于 softmax 的指数特性)。
- 这种极端化的优化目标使得参数 ( θ \theta θ ) 被推向损失表面上曲率较高的区域。
- 梯度下降(尤其是随机梯度下降 SGD)在非凸表面上的动态(如高噪声、步幅选择)进一步偏向尖锐解,因为平坦区域的梯度较小,难以驱动参数快速收敛。
三、尖锐解与大模型的高泛化能力
1. 传统观点的挑战
- 传统优化理论认为,平坦的局部极小值(flat minima)更可取,因为它们对参数扰动不敏感,泛化性能更好(损失在测试集上仍较低)。
- 然而,大模型(如 LLaMA、GPT)的研究表明,尖锐解并不一定导致差的泛化,反而可能与高性能相关。
2. 尖锐解的可能优势
- 高维参数空间的特性:
- 在高维空间中(大模型参数量动辄亿级),局部极小值的数量呈指数增长。即使是尖锐解,也可能分布在损失表面的大量“优质”区域。
- 尖锐解可能对应于特定的数据模式,捕捉训练数据的复杂结构。
- 正则化效应:
- 交叉熵的非凸优化结合 SGD 的随机性,隐式起到正则化作用,避免模型陷入过于简单的平坦解。
- 尖锐解可能迫使模型学习更具区分性的特征,而非过度平滑的泛化。
- 理论支持:
- Keskar 等(2017)提出“大批量训练倾向于尖锐解,小批量倾向于平坦解”,但后续研究(如 Dinh 等,2017)表明尖锐性与泛化关系的复杂性。
- 对于大模型,尖锐解可能与“过参数化”相关:模型容量足够大,即使解尖锐,也能覆盖测试分布。
3. 大模型的泛化能力解释
- 现象:LLaMA、Qwen 等模型在海量数据上训练,尽管损失表面非凸且解尖锐,却表现出惊人的零样本(zero-shot)和少样本(few-shot)性能。
- 解释:
- 尖锐解可能对应于数据的高阶模式(如语言的语法、语义),而不是简单的线性规律。
- 非凸性与深度网络的非线性使得模型探索了更丰富的解空间,尖锐解可能是这些解中的“幸运一击”。
四、MSE 的平滑解特性及其限制
1. MSE 的性质
MSE 定义为:
L M S E = 1 C ∑ i = 1 C ( y i − y ^ i ) 2 L_{MSE} = \frac{1}{C} \sum_{i=1}^C (y_i - \hat{y}_i)^2 LMSE=C1i=1∑C(yi−y^i)2
- 凸性:
- 对于线性模型,MSE 是凸函数。
- 在深度网络中,MSE 仍是非凸的,但其二次形式使其损失表面比交叉熵更平滑。
- 平滑性:
- MSE 的二阶导数(Hessian)是常数或随 ( y ^ \hat{y} y^ ) 平滑变化的函数,曲率变化较缓。
2. 倾向于平坦解
- 优化目标:
- MSE 最小化预测值与目标值的欧几里得距离,鼓励 ( y ^ i \hat{y}_i y^i ) 接近 ( y i y_i yi ),但不强制极端值(如 0 或 1)。
- 在分类中,( y ^ \hat{y} y^ ) 可能是概率,MSE 会拉平输出分布(例如预测值趋向均值)。
- 几何特性:
- MSE 的损失表面更“圆润”,局部极小值周围的曲率较低,形成平坦的“盆地”。
- 梯度 ( ∂ L M S E ∂ z j = ( y ^ j − y j ) ⋅ y ^ j ( 1 − y ^ j ) \frac{\partial L_{MSE}}{\partial z_j} = (\hat{y}_j - y_j) \cdot \hat{y}_j (1 - \hat{y}_j) ∂zj∂LMSE=(y^j−yj)⋅y^j(1−y^j) ) 在边界(( y ^ j → 0 \hat{y}_j \to 0 y^j→0 ) 或 1)时趋于 0,推动优化停留在平坦区域。
3. 对表达力的限制
- 平滑解的局限:
- 平坦解倾向于生成平滑的预测分布,难以捕捉分类任务中的清晰决策边界。
- 在语言建模中,MSE 会使概率分布过于均匀,无法突出正确 token 的尖锐概率。
- 表达力不足:
- MSE 的二次惩罚假设误差符合高斯分布,限制了模型学习复杂、非线性的数据模式。
- 对于大模型,平滑解可能导致“欠拟合”高维特征,降低其捕捉复杂语义的能力。
五、交叉熵与 MSE 的对比洞见
-
损失表面对比:
- 交叉熵:非凸、崎岖,充满尖锐极小值,鼓励极端概率。
- MSE:非凸但平滑,倾向于平坦极小值,鼓励均匀分布。
-
优化动态:
- 交叉熵结合 SGD 在非凸表面上跳跃,易陷入尖锐解。
- MSE 的梯度较温和,收敛到平坦区域。
-
泛化与表达力:
- 交叉熵的尖锐解可能挖掘数据深层结构,适合大模型的高泛化需求。
- MSE 的平滑解限制了模型的区分能力,可能更适合简单回归任务。
六、总结
交叉熵的非凸性与深度网络的非线性结合,使优化倾向于尖锐解。这种特性可能解释了大模型的高泛化能力,因为尖锐解在高维空间中捕捉了复杂的训练数据模式,而非凸优化提供了探索丰富解空间的机会。相比之下,MSE 的平滑解特性使其更适合连续回归,但限制了其在分类任务中的表达力,尤其对于需要清晰区分的大模型而言。尖锐解与平坦解的争论仍在继续,但交叉熵的成功无疑为优化理论提供了新的启发。
后记
2025年3月21日21点33分于上海,在Grok 3大模型辅助下完成。
相关文章:
大模型训练为什么选择交叉熵损失(Cross-Entropy Loss):均方误差(MSE)和交叉熵损失的深入对比
交叉熵损失:深度学习中的基石与洞见 交叉熵损失(Cross-Entropy Loss)是现代深度学习中分类任务的核心损失函数,尤其在训练大规模模型(如 transformers 等大型语言模型 LLM)时,几乎无处不在。对…...
C++|GLog开源库的使用 如何实现自定义类型消息日志
参考: C glog使用教程与代码演示 C第三方日志库Glog的安装与使用超详解 GLOG从入门到入门 glog 设置日志级别_glog C版本代码分析 文章目录 日志等级自定义消息创建使用宏定义 日志等级 在 glog 中,日志的严重性是通过 LogSeverity 来区分的,…...
cursor常用快捷键(JetBrains Darcula主题风格)
一、基础操作速查 打开/创建项目 打开项目:Ctrl Shift O(选择文件夹)新建文件:Ctrl N保存文件:Ctrl S关闭当前标签页:Ctrl F4 代码编辑 复制当前行:Ctrl D删除当前行:Ctrl …...
区块链学习总结
Hardhat 是一个用于 Ethereum 智能合约开发 的开发环境,专为 Solidity 语言编写的智能合约提供工具支持。它能够帮助开发者 编译、部署、测试和调试 智能合约,并提供一个本地的以太坊测试网络。 Hardhat 的核心功能 本地开发网络(Hardhat Ne…...
《深入剖析鸿蒙生态原生应用:一次开发多端部署的技术革新》
在数字化时代飞速发展的浪潮中,鸿蒙生态以其独特的技术理念和强大的创新能力,为开发者和用户带来了全新的体验。其中,“一次开发多端部署”作为鸿蒙生态原生应用开发的核心技术之一,不仅是技术上的重大突破,更是对未来…...
知识蒸馏:让大模型“瘦身“而不失智慧的魔术
引言:当AI模型需要"减肥" 在人工智能领域,一个有趣的悖论正在上演:大模型的参数规模每年以10倍速度增长,而移动设备的算力却始终受限。GPT-4的1750亿参数需要价值500万美元的GPU集群运行,但现实中的智能设备…...
JavaScript 获取 URL 中参数值的详解
JavaScript 获取 URL 中参数值的详解 1. 了解 URL 参数2. 使用 URLSearchParams 获取参数值2.1 什么是 URLSearchParams?2.2 示例代码2.3 优缺点 3. 使用正则表达式获取参数值3.1 示例代码3.2 分析 4. 自定义解析函数4.1 示例代码4.2 分析 5. 小结与注意事项 在开发…...
识别并脱敏上传到deepseek/chatgpt的文本文件中的身份证/手机号
本文将介绍一种简单高效的方法解决用户在上传文件到DeepSeek、ChatGPT,文心一言,AI等大语言模型平台过程中的身份证号以及手机号等敏感数据识别和脱敏问题。 DeepSeek、ChatGPT,Qwen,Claude等AI平台工具快速的被接受和使用,用户每天上传的文本数据中潜藏着大量敏感信息,…...
ruoyi-vue部署4
1.jdk-linux安装 2.tomcat-linux安装 3.ruoy后台部署 4.nginx-linux安装5.ruoyi前端部署...
【秣厉科技】LabVIEW工具包——OpenCV 教程(12):机器学习
文章目录 前言机器学习例1:支持向量机(SVM)做平面向量二分类例2: K邻近算法(KNearest)实现分类 总结 前言 需要下载安装OpenCV工具包的朋友,请前往 此处 ;系统要求:Wind…...
分布式事务解决方案简介
一、分布式事务的挑战 在分布式系统中,多个服务协同完成一个业务操作时,可能会遇到数据一致性问题。传统单体应用的ACID事务无法直接扩展到分布式环境,主要矛盾在于: • 网络不可靠:服务间通信可能失败。 • 并发冲突…...
【leetcode hot 100 17】电话号码的字母组合
分析:当设计关键字“所有组合”时,要考虑深度优先遍历、广度优先遍历(层次遍历),其中: 深度优先搜索: 自顶向下的递归实现深搜定义子问题在当前递归层结合子问题结果解决原问题 广度优先搜索 利…...
UI数据处理新隐私保护:确保用户新信息安全
hello宝子们...我们是艾斯视觉擅长ui设计和前端数字孪生、大数据、三维建模、三维动画10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩! 在这个数字时代,我们的个人信息似乎无处不在。从社交媒体上的点滴分享,到在线…...
【Javascrip】Javascript练习01 REST API using Express.js.
针对该问题的项目路径 要求部分 what you need to doReview the tasks provided in the section below.Obtain the boilerplate code.Use your local development environment to implement a solution.Upload your solution for marking via Gradescope. There is no attempt…...
分析K8S中Node状态为`NotReady`问题
在Kubernetes(k8s)集群中,Node状态为NotReady通常意味着节点上存在某些问题,下面为你分析正常情况下节点应运行的容器以及解决NotReady状态的方法。 正常情况下Node节点应运行的容器 1. kubelet kubelet是节点上的核心组件&…...
小样本学习综述
小样本学习综述 📕[1]潘雪玲,李国和,郑艺峰. 面向深度网络的小样本学习综述 [J]. 计算机应用研究, 2023, 40 (10): 2881-28882895. DOI:10.19734/j.issn.1001-3695.2023.02.0074. 主要是该论文的一些摘要。 小样本学习旨在利用较少目标数据训练模型快速学习的。 …...
挂谷问题与挂谷猜想:从平面转针到高维拓扑
挂谷问题与挂谷猜想:从平面转针到高维拓扑 目录 挂谷问题的起源数学定义与基本性质研究进展挂谷集合与挂谷猜想王虹与Joshua Zahl的突破意义与影响 挂谷问题的起源 1917年,日本数学家挂谷宗一(かけや そういち Soichi Kakeya,1886-1947)提…...
火语言RPA--表格数据导出
表格数据导出 🚩【组件功能】:导出表格内数据到指定的文件 配置预览 配置说明 导出格式 Excel:导出Excel文档格式,CSV:导出CSV数据格式。 导出文件夹 支持T或# 导出文件需要保存的文件夹路径。 导出文件名支持T或# 导出文…...
数学建模:MATLAB卷积神经网络
一、简述 卷积神经网络是一种处理具有网格结构数据的深度学习模型,由输入层、卷积层、池化层、全连接层、输出层组成。 输出层:将图像转换为其对应的由像素值构成的二维矩阵,并存储二维矩阵 卷积层:提取图像的底层特征…...
Vue3 基础语法指南:响应式系统与 Ref 应用
1、Reactive 的深度响应式 1.1、基本用法 vue <script setup> import { reactive } from vueconst state reactive({count: 0,user: {name: Alice,age: 30} })const increment () > state.count const updateName () > state.user.name Bob </script>1…...
学习笔记:黑马程序员JavaWeb开发教程(2025.3.21)
10.10 案例-员工管理-删除员工 前端中有两个删除按键,一个是删除员工,一个是批量删除,我们只需要将删除员工作为特殊的批量删除,就是只删除一个,开发一个接口就行 用id in ()来批量删除&…...
xLua_003 Lua访问C#
1、new C# 对象(创建游戏物体) LuaCallCSharp.cs using UnityEngine; using XLua;public class LuaCallCSharp : MonoBehaviour {public LuaEnv env null;void Start(){LuaEnv env new LuaEnv();env.DoString("requireLuaCallCSharp");}pr…...
mysql 磐维(opengauss)tidb误删数据之高级恢复
Mysql参考: Mysql 8.0 XtraBackupMysqlbinlog 完全恢复 - 墨天轮 Mysql 8.0 XtraBackupMysqlbinlog 完全恢复[TOC]# 一、安装mysql 8.0.19## 1.1https://www.modb.pro/db/509223MySQL 的全量备份、增量备份与 Binlog 时间点恢复_mysqlbinlog自动备份吗-CSDN博客文章…...
区块链技术在供应链管理中的应用与创新
在当今全球化的商业环境中,供应链管理的复杂性与日俱增。从原材料采购到最终产品交付,涉及众多环节和参与者,信息的透明度、准确性和安全性至关重要。区块链技术的出现,为供应链管理带来了全新的解决方案,正在逐步改变…...
字符指针的三道例题+算法改进
目录 一.杨氏矩阵 1.初级 2.想把下标带回来 二.字符串左旋 算法改进 三.判断是否为字符串旋转结果 算法改进 四. 3个字符函数 1.strcat 2.strncat 3.strstr 一.杨氏矩阵 数字矩阵,每行从左到右递增,每列从上到下递增,编写程序在矩…...
PostgreSQL用SQL实现俄罗斯方块
📢📢📢📣📣📣 作者:IT邦德 中国DBA联盟(ACDU)成员,10余年DBA工作经验 Oracle、PostgreSQL ACE CSDN博客专家及B站知名UP主,全网粉丝10万 擅长主流Oracle、MySQL、PG、高斯…...
如何构建简单有效的AI Agents代理?
工程技术 在过去的一年里,我们与数十个跨行业的团队合作,构建基于大型语言模型(LLM)的代理。我们发现,最成功的实现并不是使用复杂的框架或专门的库,而是采用简单、可组合的模式。 在本文中,我…...
【虚幻引擎UE5】SpawnActor生成Character实例不执行AI Move To,未初始化AIController的原因和解决方法
虚幻引擎版本:5.5.4 问题描述 刚创建的Third Person项目里,定义一个BP_Enemy蓝图,拖拽到场景中产生的实例会追随玩家,但SpawnActor产生的实例会固定不动。BP_Enemy蓝图具体设计如下: BP_Enemy的Event Graph 又定义…...
查看GPU型号、大小;CPU型号、个数、核数、内存
GPU型号、大小 nvidia-smiCPU型号 cat /proc/cpuinfo | grep model name | uniqCPU个数 cat /proc/cpuinfo | grep "physical id" | uniq | wc -lCPU核数 cat /proc/cpuinfo | grep "cpu cores" | uniqCPU内存 cat /proc/meminfo | grep MemTotal参考…...
xcode中移除安装的package dependency
有的依赖包安装之后,没有用,所以就需要把这个依赖项去掉,找了好久没有找到在哪里,最后发现在项目详情里面: 选中这一项,然后删除就可以了...
