Fisher——Fisher投影分类
Fisher投影加速分类
线性判别分析(LDA)详解:从理论到代码实现
一、引言
线性判别分析(Linear Discriminant Analysis, LDA)是一种经典的监督学习降维与分类方法。其核心思想是通过最大化类间散度、最小化类内散度,将高维数据投影到低维空间并实现高效分类。本文以鸢尾花数据集为例,详解LDA的数学原理,并逐行解析代码实现。
二、LDA的核心原理
1. 核心目标
LDA的目标是找到一个最优投影方向 $ W $,使得以下准则成立:
- 类内散度最小:同一类样本在投影后的空间尽可能紧凑。
- 类间散度最大:不同类样本在投影后的空间尽可能分离。
数学表达式为:
$$
J(W) = \frac{W^T S_B W}{W^T S_W W}
$$
其中:
- $ S_W $:类内散度矩阵(Within-class Scatter Matrix)
- $ S_B $:类间散度矩阵(Between-class Scatter Matrix)
2. 数学推导
(1) 类均值向量
对每个类别 $ c $,计算其特征均值向量 $ \mu_c $:
$$
\mu_c = \frac{1}{N_c} \sum_{x \in C_c} x
$$
其中 $ N_c $ 是类别 $ c $ 的样本数。
(2) 类内散度矩阵 $ S_W $
衡量同一类样本的离散程度:
$$
S_W = \sum_{c} \sum_{x \in C_c} (x - \mu_c)(x - \mu_c)^T
$$
(3) 类间散度矩阵 $ S_B $
衡量不同类中心之间的距离:
$$
S_B = \sum_{c} N_c (\mu_c - \mu)(\mu_c - \mu)^T
$$
其中 $ \mu $ 是全局均值向量。
(4) 投影矩阵计算
通过广义特征值问题求解最优投影方向:
$$
S_W^{-1} S_B w = \lambda w
$$
选择最大特征值对应的特征向量 $ w $ 作为投影方向。
三、代码详解与实现
1. 数据准备与预处理
1 | from sklearn.datasets import load_iris |
- 功能:加载鸢尾花数据集并筛选前两类,构造二分类任务。
- 原理:LDA适用于多分类,但二分类更简单直观。
2. 类均值计算
1 | def class_means(X, y): |
- 功能:计算每个类别的均值向量。
- 数学对应:实现公式 $ \mu_c $ 的计算。
3. 类内散度矩阵
1 | def within_class_scatter(X, y): |
- 功能:计算类内散度矩阵 $ S_W $。
- 关键点:通过外积累加每个样本与类均值的偏差。
4. 类间散度矩阵
1 | def between_class_scatter(X, y): |
- 功能:计算类间散度矩阵 $ S_B $。
- 局限性:仅适用于二分类,多分类需扩展。
5. 投影矩阵求解
1 | def projection_matrix(X, y): |
- 功能:求解广义特征值问题,取最大特征值对应的特征向量作为投影方向。
- 数学对应:求解 $ S_W^{-1} S_B $ 的最大特征值对应的特征向量。
6. 数据投影与分类
1 | W = projection_matrix(X_train, y_train) |
- 决策规则:使用两类均值的中点作为阈值进行分类。
四、实验结果与可视化
1. 可视化设计
1 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6)) |
- 图形说明:
- 横轴:投影后的特征值;纵轴:辅助可视化(固定为0)。
- 绿色虚线:分类阈值。
- 颜色区分类别,散点分布反映分离效果。
2. 输出示例
1 | Train Accuracy: 0.9875 |
- 结果分析:在鸢尾花二分类任务中,LDA通常能达到95%以上的准确率。
四、关键问题与改进建议
1. 类间散度矩阵的改进
- 当前实现:仅使用两类均值差的外积,未考虑样本数比例。
- 改进建议:
对于二分类问题,正确的 $ S_B $ 应为:
$$
S_B = \frac{n_1 n_2}{n} (\mu_1 - \mu_2)(\mu_1 - \mu_2)^T
$$
2. 特征值问题的稳定性
- 潜在问题:若 $ S_W $ 奇异(如高维小样本),求逆会导致数值不稳定。
- 改进建议:添加正则化项 $ \lambda I $ 到 $ S_W $。
3. 阈值选择的优化
- 当前方法:固定阈值为均值中点。
- 改进建议:通过交叉验证选择最优阈值。
五、总结
- 核心贡献:本文从数学公式到代码实现,完整解析了LDA的原理与应用。
- 代码特点:基于NumPy实现,无需依赖 sklearn.discriminant_analysis.LinearDiscriminantAnalysis,适合理解底层机制。
- 扩展方向:可推广到多分类任务,结合非线性核函数处理复杂分布。
完整的工程请见
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 Katarina's diary!




