支持向量机(SVM)原理详解与代码解析

一、SVM算法原理详解

1. 核心思想

支持向量机通过最大化分类间隔实现最优分类,其数学本质是求解一个凸二次优化问题。核心思想包括:

  • 最大间隔原则:寻找使类别间距离最大的分类超平面
  • 支持向量:决定分类边界的关键样本点
  • 核技巧:通过核函数将低维不可分数据映射到高维空间

2. 数学基础

(1) 线性可分情况

$$
\min_{w,b} \frac{1}{2}||w||^2 \quad \text{s.t.} \quad y_i(w^T x_i + b) \geq 1
$$

  • $ w $:超平面法向量
  • $ b $:偏置项
  • $ y_i \in {-1,1} $:类别标签

(2) 非线性情况(使用核函数)

$$
K(x_i,x_j) = \phi(x_i)^T\phi(x_j)
$$

常用核函数:

  • 线性核:$ K(x,y) = x^Ty $
  • 多项式核:$ K(x,y) = (x^Ty + c)^d $
  • RBF核:$ K(x,y) = e^{-\gamma ||x-y||^2} $

(3) 正则化参数C

$$
\min_{w,b,\xi} \frac{1}{2}||w||^2 + C\sum_{i=1}^n \xi_i
$$

  • $ C $:控制误分类惩罚强度
  • $ \xi_i $:松弛变量,允许一定程度的误分类

二、代码逐句解析

1. 数据预处理阶段

1
2
3
def Iris_label(s):
it = {b'Iris-setosa':0, b'Iris-versicolor':1, b'Iris-virginica':2}
return it[s]
  • 功能:将原始文本标签转换为数字编码
  • 原理:机器学习模型需要数值输入,将类别标签转换为0-1编码
  • 改进建议:使用LabelEncoder更通用的标签编码方式
1
2
3
df = pd.read_csv('iris.txt', header=None)
df[4] = df[4].map({'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2})
data = df.values
  • 功能:读取数据并转换标签
  • 关键点:使用pandas进行数据清洗,替代np.loadtxt更灵活
  • 注意:原始数据集包含5列(4个特征+1个标签),索引0-4,因此不存在列越界问题
1
2
3
x, y = np.split(data, (4,), axis=1)
x = x[:, :2]
y = y.ravel()
  • 功能:特征与标签分离
  • 原理:np.split按列分割数据,x[:, :2]选择前两个特征用于可视化
  • 注意:ravel()确保标签为1D数组,符合scikit-learn输入要求

2. 数据集划分

1
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=1)
  • 功能:按7:3划分训练集和测试集
  • 原理:train_test_split内部执行数据打乱和分割
  • 参数说明:
    • test_size=0.3:测试集占比30%
    • random_state=1:确保结果可复现

3. SVM模型构建

1
model = svm.SVC(C=2, kernel='rbf', gamma=10, decision_function_shape='ovo')
  • 参数详解:
    • C=2:正则化参数,值越大对误分类的惩罚越强
    • kernel=’rbf’:使用径向基函数(RBF)作为核函数
    • gamma=10:核函数系数,值越大模型复杂度越高
    • decision_function_shape=’ovo’:采用一对一的多分类策略

4. 模型训练

1
model.fit(x_train, y_train.ravel())
  • 功能:训练SVM模型
  • 原理:内部调用SMO算法求解优化问题
  • 数学对应:优化目标是最大化分类间隔同时最小化分类误差

5. 模型评估

1
2
train_score = model.score(x_train, y_train)
test_score = model.score(x_test, y_test)
  • 功能:计算分类准确率
  • 原理:准确率公式 $ Accuracy = \frac{TP+TN}{TP+FP+TN+FN} $
  • 训练集准确率:98.57%; 测试集准确率:95.56%

三、关键概念与代码对应关系

代码实现 数学原理
kernel='rbf' 核技巧 $ K(x,y) = e^{-\gamma} $
C=2 正则化参数 $ \min \frac{1}{2} $
gamma=10 核函数参数 $ \gamma $ 控制模型复杂度
decision\_function\_shape='ovo' 一对一多分类策略
train\_test\_split(...) 数据集划分保证泛化性

四、模型可视化实现(补充)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def plot_decision_regions(X, y, classifier):
from matplotlib.colors import ListedColormap
X1, X2 = np.meshgrid(np.arange(X[:,0].min()-1, X[:,0].max()+1, 0.02),
np.arange(X[:,1].min()-1, X[:,1].max()+1, 0.02))
Z = classifier.predict(np.c_[X1.ravel(), X2.ravel()])
Z = Z.reshape(X1.shape)
plt.contourf(X1, X2, Z, alpha=0.4, cmap=ListedColormap(['lightcoral','lightblue','lightgreen']))
plt.scatter(X[y==0,0], X[y==0,1], c='red', label='Setosa')
plt.scatter(X[y==1,0], X[y==1,1], c='blue', label='Versicolor')
plt.scatter(X[y==2,0], X[y==2,1], c='green', label='Virginica')
plt.title('SVM Result')
plt.xlabel('Sepal Length')
plt.ylabel('Sepal Width')
plt.legend()
plt.show()

plot_decision_regions(x_test, y_test, model)
  • 功能:绘制决策边界
  • 原理:
    1. 生成网格点作为测试点
    2. 对每个点进行预测
    3. 将预测结果可视化为背景颜色
  • 可视化要素:
    • 背景色:表示模型的决策区域
    • 散点:实际测试集样本分布
    • 标题:显示模型名称

五、参数调优建议

  1. 核函数选择:
    • 线性可分数据使用linear核
    • 非线性数据使用rbf核(默认)
    • 多项式核适用于特定场景
  2. 参数调优:
    1
    2
    3
    4
    from sklearn.model_selection import GridSearchCV
    param_grid = {'C': [0.1, 1, 10], 'gamma': [0.1, 1, 10]}
    grid = GridSearchCV(svm.SVC(), param_grid, cv=5)
    grid.fit(x_train, y_train)
  3. 多分类策略:
    • decision_function_shape=’ovo’:一对一(适合小数据集)
    • decision_function_shape=’ovr’:一对多(适合大数据集)

六、完整工作流程图

text
1
数据预处理 → 特征选择 → 模型初始化 → 参数调优 → 模型训练 → 可视化 → 性能评估

七、注意事项

  1. 特征标准化:
    1
    2
    3
    from sklearn.preprocessing import StandardScaler
    sc = StandardScaler()
    x_train = sc.fit_transform(x_train)
    • 原因:SVM对特征尺度敏感
    • 方法:Z-score标准化
  2. 特征选择:
    • 仅使用前两个特征会损失部分信息
    • 建议使用PCA降维代替人工选择
  3. 过拟合预防:
    • 当gamma过大时可能导致过拟合
    • 可通过交叉验证选择最优参数

八、扩展应用

  1. 多分类问题:
    • 鸢尾花有3个类别
    • SVM原生支持多分类(非二分类扩展)
  2. 非线性分类:
    • RBF核处理非线性可分问题
    • 决策边界可以是非线性的
  3. 软间隔:
    • C=2允许部分样本位于间隔带内
    • 平衡分类精度和泛化能力