本周深入学习了多类别分类任务和CNN
通过 Otto Group 产品分类案例,深入理解了多类别分类任务的关键——包括标签编码、数据类型与形状规范、分层抽样、模型输出层设计以及交叉熵损失的使用。
在 CNN 部分,学习了如何利用卷积层、池化层和残差结构来提升模型的特征提取能力和分类性能,并掌握了 ResNet 等经典网络的设计思想。
整体上,已经从简单的二分类问题,逐步迈向了复杂的多类别场景,并借助 CNN 这一强大的模型架构,将理论与实践结合,掌握了构建和优化深度神经网络解决实际分类任务的能力。
Part 1:多类别分类
项目案例:Otto Group 产品分类 (9个类别)
从二分类迈向多类别分类,是掌握分类任务的关键一步。
本项目中,遇到了三个核心挑战:标签处理、模型输出层设计和损失函数选择,并逐一攻克。
1.1 数据预处理的关键决策
1.1.1 标签编码 (Label Encoding)
- 问题: 原始目标标签是文本(如
'Class_1'
,'Class_2'
),神经网络无法直接处理。 - 解决方案: 转换为整数索引(0, 1, 2, …)。
实现代码:
1 | # --- 标签编码实现 --- |
1.1.2 标签的数据类型与形状 (for nn.CrossEntropyLoss
)
- 核心问题:
CrossEntropyLoss
对输入标签的形状和类型有严格要求。 - 辨析:
- 形状
- 二分类中我们常用
.reshape(-1, 1)
得到列向量。 - 结论: 多类别分类 不需要,标签应为一维张量
[batch_size]
。
- 二分类中我们常用
- 数据类型
- 标签必须是
torch.long
,因为损失函数会把它作为 类别索引 使用。
- 标签必须是
- 形状
实现代码:
1 | from torch.utils.data import Dataset |
1.1.3 分层抽样 (Stratified Sampling)
- 问题: Otto 数据集类别不均衡,随机划分可能导致验证集分布失衡。
- 解决方案: 使用
train_test_split(..., stratify=y)
。
实现代码:
1 | from sklearn.model_selection import train_test_split |
1.2 模型架构与损失函数
- 输出层:
nn.Linear
的输出神经元数必须等于类别数。 - 最佳实践: 动态获取类别数,而非硬编码。
1 | import numpy as np |
- 损失函数: 使用
nn.CrossEntropyLoss
。 - 注意: 不需要手动添加
Softmax
,因为它已集成在损失函数内部。
1.3 Kaggle 提交格式
- 问题: Kaggle 要求提交的是各类别概率,而不是类别索引。
- 说明: 提交样例
1,0,0,0,0,0,0,0,0
代表100%概率属于 Class_1,而不是标签。
实现代码:
1 | import torch.nn.functional as F |
1.4 Otto Group 产品分类完整代码
1 | import torch |
Part2:计算机视觉入门 (CNN)
项目案例: MNIST 手写数字识别、Fashion-MNIST 服装分类
2.1 图像数据预处理 (torchvision.transforms
)
- 核心工具:
transforms.Compose
—— 图像处理的“流水线”,将多个转换按顺序组合。
1 | transform = transforms.Compose([ |
关键转换操作
transforms.ToTensor()
- 格式转换:
PIL Image → Tensor
- 数值归一化:
[0, 255] → [0.0, 1.0]
- 维度重排:
(H, W, C) → (C, H, W)
- 格式转换:
transforms.Normalize(mean, std)
- 将数据标准化为均值0、标准差1,提高收敛速度与稳定性。
- 为什么
(0.1307,)
要写成单元素元组?
因为函数期望输入元组,每个元素对应一个通道。MNIST 是单通道灰度图,因此需要(mean,)
的形式。
数据增强 (Data Augmentation)
- 核心思想: 不增加新样本,通过对训练集随机变换(裁剪、翻转等)生成多样性数据,防止过拟合。
- 原则: 只增强训练集,不动测试集,确保评估客观。
实现代码:
1 | # --- 训练集(含数据增强) --- |
2.2 从全连接到CNN:空间结构的保留
问题: 为什么全连接网络在图像任务表现差?
结论: 因为输入需要“压平”为一维向量(
28*28=784
),破坏了像素之间的空间关系,模型无法学习边缘/形状等特征。CNN的优势: 卷积核滑动操作保留二维结构,能高效提取空间模式。
2.3 简单CNN的维度变化追踪
1 | class CNNModel(nn.Module): |
假设 batch_size = 64
:
- 输入:
[64, 1, 28, 28]
Conv2d(1, 16, kernel_size=5, padding=2)
→[64, 16, 28, 28]
MaxPool2d(2)
→[64, 16, 14, 14]
Conv2d(16, 32, kernel_size=5, padding=2)
→[64, 32, 14, 14]
MaxPool2d(2)
→[64, 32, 7, 7]
.view(-1, 32*7*7)
→[64, 1568]
Linear(1568, 10)
→[64, 10]
简单CNN实现MNIST手写数据集识别
1 | import torch |
2.4 ResNet架构:跨越深度的鸿沟
现代 CNN 的基石,解决了“深度的诅咒”。
2.4.1 深度的诅咒:梯度消失
- 问题: VGG 等深层网络,反向传播时梯度因连乘效应衰减为零。
- 结果: 底层网络学不到有效特征。
2.4.2 解决方案:构建残差块 (ResidualBlock)
核心思想:
使用 跳跃连接 (Skip Connection),公式:
$$
H(x) = F(x) + x
$$让梯度能直接反传,避免消失。
第一部分: ResidualBlock
(残差块) - ResNet的“乐高积木”
这个类定义了ResNet最基本、可重复使用的单元。它的天才之处在于同时包含了一条“普通公路”和一条“高速公路”。
1 | class ResidualBlock(nn.Module): |
__init__
(初始化 - 定义两条路)
- 主路径 (
self.conv1
,self.bn1
, etc.):- 这是一条标准的卷积路径:
卷积 -> 批量归一化 -> ReLU -> 卷积 -> 批量归一化
。 - 它负责学习输入特征
x
和期望输出H(x)
之间的“差异”或“残差”F(x)
。 nn.BatchNorm2d
(批量归一化) 是一个重要的辅助层,它可以稳定和加速训练过程。
- 这是一条标准的卷积路径:
- 快捷连接 (
self.shortcut
):- 这是ResNet的灵魂。它的目标是让输入
x
能够直接与主路径的输出out
相加。 if stride != 1 or in_channels != out_channels:
: 这是一个至关重要的判断。加法要求两个张量的维度必须完全相同。这个if
检查了两种维度可能发生变化的情况:stride != 1
: 如果步长为2,主路径的输出图像尺寸会减半。in_channels != out_channels
: 如果输入通道数和输出通道数不同。
- 如果维度不匹配,快捷连接就不能直接“跳过”,它也需要经过一个简单的变换来匹配主路径输出的维度。这个变换通常是一个1x1的卷积,它的作用就是调整通道数和尺寸。
- 这是ResNet的灵魂。它的目标是让输入
forward
(前向传播 - 数据如何在两条路上跑)
out = ...
: 输入x
先走一遍主路径,得到变换后的结果F(x)
。out += self.shortcut(x)
: 将主路径的结果F(x)
与(可能经过变换的)原始输入x
逐元素相加。这一步就是梯度的“高速公路”的入口,确保了梯度可以无损地反向传播。out = self.relu(out)
: 对相加后的结果进行最终的激活。
第二部分: ResNet
- 用“乐高积木”搭建城堡
这个类使用我们上面定义的 ResidualBlock
作为基本组件,来搭建一个完整的、分阶段的ResNet模型。
1 | class ResNet(nn.Module): |
__init__
(初始化 - 定义模型蓝图)
- 初始卷积层 (
self.conv1
,self.bn1
): 在进入残差块之前,对输入图片进行一次初步的特征提取。 - 残差层 (
self.layer1
,layer2
,layer3
):- 这是模型的主体,由
_make_layer
这个辅助函数创建。 self.layer1
:stride=1
,不改变图像尺寸,只加深通道。self.layer2
:stride=2
,将图像尺寸减半,并加深通道。self.layer3
:stride=2
,将图像尺寸再次减半,并加深通道。
- 这是模型的主体,由
- 分类器 (
self.avg_pool
,self.fc
):nn.AdaptiveAvgPool2d((1, 1))
: 这是一个非常智能的池化层。不管输入的特征图尺寸是多大(比如7x7
或5x5
),它都能将其降维成1x1
。它通过计算每个通道特征图的平均值来实现。这使得模型对输入图片尺寸的变化更具鲁棒性。self.fc
: 最后的标准全连接层,用于最终的分类。
_make_layer
(辅助函数 - “施工队”)
- 这个函数的作用是“修建”一个完整的残差阶段,比如
self.layer2
。 - 一个阶段包含
num_blocks
个残差块。 strides = [stride] + [1]*(num_blocks-1)
: 这个聪明的写法确保了在一个阶段中,只有第一个残差块可能会进行下采样(stride=2
),而后续的所有块都保持尺寸不变(stride=1
)。self.in_channels = out_channels
: 这是一个重要的状态更新,确保下一个阶段的输入通道数是正确的。
forward
(前向传播 - 数据的完整旅程)
forward
方法清晰地展示了数据是如何流经整个ResNet的:从初始卷积,到一系列的残差块堆叠,再到最后的池化和分类,一气呵成。
- 1x1 卷积的角色
- 作用:调整通道数 (
out_channels
) 和空间尺寸 (stride
)。 - 使
F(x)
和x
的维度一致,从而可相加。
- 作用:调整通道数 (
bias=False
in Conv2d- 原因:卷积后的
BatchNorm
会抵消偏置,因此可省略,减少冗余参数。
- 原因:卷积后的
inplace=True
in ReLU- 原因:节省内存,直接在输入张量上修改数据,适用于
conv → bn → relu
链。
- 原因:节省内存,直接在输入张量上修改数据,适用于
ResNet 结构图与维度变化
假设输入:[Batch_Size, 1, 28, 28]
(例如:[64, 1, 28, 28],即64张28x28的灰度图)
1 | 输入图片 (x) |
2.5 概念解惑
2.5.1 下采样(Downsampling)
核心定义:降低图像分辨率
下采样,在计算机视觉的语境下,最直观的理解就是缩小图像的尺寸。
想象一下,你有一张 100x100
像素的高清图片,你把它在画图软件里缩小到 50x50
像素。这个过程就是下采样。你为了让图片变小,丢弃了一部分像素信息,保留了最关键的视觉特征。
在卷积神经网络中,下采样指的是降低特征图 (Feature Map) 的空间维度(高度和宽度)。
CNN中实现下采样的两种主要方式
1. 池化层 (Pooling Layers),尤其是 nn.MaxPool2d
这是最直接、最常见的下采样方法。
- 工作原理:
MaxPool2d(kernel_size=2)
会将输入的特征图分割成一个个不重叠的2x2
的小方块。然后,在每个小方块中,它只保留值最大的那一个像素,并丢弃其他三个。 - 效果: 因为它把
2x2
的区域压缩成了1x1
的区域,所以特征图的高度和宽度都减半了。
一个简单的例子:
1 | 原始 4x4 区域 经过 MaxPool2d(2) 后 |
- 左上角的
[[1, 2], [3, 4]]
中最大的是4
。 - 右上角的
[[5, 6], [7, 8]]
中最大的是8
。 - 以此类推…
2. 带步长(Stride)的卷积 (nn.Conv2d(..., stride=2)
)
这是另一种更现代的下采样方法,我们在ResNet
的shortcut
连接中学到过。
- 工作原理: 普通卷积的步长
stride=1
,意味着卷积核每次在图像上移动一个像素。而当stride=2
时,卷积核会每次跳过一个像素,移动两个像素的距离。 - 效果: 因为卷积核“跳”着走,它进行计算的次数变少了,自然产生的输出特征图的高度和宽度也大约减半了。
为什么要进行下采样?(三个核心目的)
1. 减少计算量和参数数量
- 这是最直接的好处。特征图的尺寸减半后,后续卷积层需要处理的数据量就大大减少了(减少了75%!)。这使得我们可以在不耗尽内存和计算资源的情况下,构建更深、更强大的网络。
2. 扩大感受野 (Receptive Field)
- 这是最重要的 концептуальный 好处。
- 感受野指的是输出特征图上的一个像素,对应到原始输入图像上的区域大小。
- 在下采样之前,一个
3x3
的卷积核看到的是原始图像上3x3
的区域。 - 在一次
2x2
的下采样之后,新的特征图上的一个像素就代表了原始图像上一个2x2
的区域。此时,再对这个新特征图用一个3x3
的卷积核,它看到的区域实际上对应到了原始图像上一个更大的区域(比如6x6
)。 - 比喻: 就像你看地图。当你“下采样”(缩小地图)时,你屏幕上的一个点从代表一条街道,变成了代表一个街区,甚至整个城市。
- 效果: 下采样让网络中更深层的卷积核能够“看到”更大范围的特征,从而学习到更宏观、更抽象的模式(比如从“边缘”和“曲线”组合成“眼睛”或“轮胎”)。
3. 增加特征的平移不变性 (Translation Invariance)
MaxPool2d
尤其能带来这个好处。如果在一个2x2
的区域内,那个最重要的特征(最大值)稍微移动了一下位置,但仍然在这个2x2
的区域内,那么池化后的输出是完全一样的。- 效果: 这让模型对目标物体在图像中的微小位移不那么敏感。无论数字“7”的笔画稍微偏左还是偏右一点,模型都能稳定地识别出它是“7”。
总结: 下采样是CNN中的一个核心操作。它通过牺牲空间分辨率,换来了计算效率的提升、感受野的扩大和特征的鲁棒性,是CNN能够学习到从低级到高级的层次化特征的关键所在。
2.6 ResNet-Fasion-MNIST完整代码
1 | import torch |
2.7 ResNet 应用到更具挑战性的彩色图片数据集——CIFAR-10
新的挑战和学习点:
- 彩色图像 (3通道):如何处理彩色图片作为输入,以及这会如何影响你的第一个卷积层。
- 更复杂的特征: CIFAR-10 的图片比 Fashion-MNIST 更具多样性,需要模型学习更抽象的视觉特征。
- 数据增强 (Data Augmentation):为了让模型更好地泛化,我们将介绍并应用一些基本的数据增强技术。
核心修改点解析:
数据加载 (
datasets.CIFAR10
): 切换到CIFAR-10数据集。transforms.Normalize
参数:- CIFAR-10是彩色图片,所以
mean
(均值) 和std
(标准差) 都变成了包含三个值的元组,分别对应R、G、B三个通道。这些是根据整个CIFAR-10训练集的统计数据计算出的。
- CIFAR-10是彩色图片,所以
数据增强 (
train_transform
):transforms.RandomCrop(32, padding=4)
: 随机裁剪,先在图像边缘填充4个像素,然后在32x32
的区域内随机裁剪出32x32
的图像。这有助于模型学习到物体在图片中的不同位置。transforms.RandomHorizontalFlip()
: 随机水平翻转图像。这有助于模型学习到物体左右翻转后的特征。- 这些技术能有效地增加训练数据的多样性,防止模型过拟合,提高泛化能力。
ResNet
类的conv1
层:self.conv1 = nn.Conv2d(3, 16, ...)
: 最关键的修改!因为CIFAR-10是RGB三通道彩色图像,所以第一个卷积层的in_channels
必须从1
(灰度图) 变为3
。
EPOCHS
和scheduler.step_size
:- CIFAR-10比Fashion-MNIST更难,因此通常需要更多的训练周期 (
EPOCHS
)。 StepLR
的step_size
也进行了调整,表示每隔多少个 epoch 学习率衰减一次。
完整代码:
- CIFAR-10比Fashion-MNIST更难,因此通常需要更多的训练周期 (
1 | import torch |
Part3 训练技巧
在构建了强大的模型架构(如 ResNet)之后,下一步是采用更智能的训练策略来最大化其性能。
3.1 正则化 (Regularization):防止模型“死记硬背”
- 问题:过拟合 (Overfitting)
- 定义:当模型在训练数据上表现完美,但在未见过的数据上表现很差时,就发生了过拟合。
- 类比:一个学生死记硬背练习册答案,却不会解新题。
- 解决方案:Dropout
- 思想:在训练时,以概率
p
随机将部分神经元输出置为零。 - 效果:迫使网络学到更稳健的特征,减少对单一神经元组合的依赖。
- 原则:
- 仅在
model.train()
模式下生效; - 在
model.eval()
时自动关闭。
- 仅在
- 思想:在训练时,以概率
实现示例
1 | self.net = nn.Sequential( |
3.2 智能训练循环:早停法与学习率调度
- 问题:如何避免过拟合,又能自动找到最佳训练轮数与学习率?
早停法 (Early Stopping)
- 策略:
- 每个 epoch 结束后计算验证集损失;
- 若验证损失连续
patience
个 epoch 无改善,则提前终止; - 保存验证损失最低时的模型作为最终模型。
- 效果:自动找到欠拟合与过拟合之间的平衡点,节省训练时间。
学习率调度器 (Learning Rate Scheduler)
- 策略:根据验证指标的变化自动调整学习率。
- 常见调度器:
ReduceLROnPlateau
:验证损失长期不降时,学习率 × factor;StepLR
:每隔固定 epoch 衰减学习率;CosineAnnealingLR
:余弦退火曲线,常见于 ResNet、Transformer;OneCycleLR
:在 NLP、CV 任务中广泛使用。
智能训练循环框架
1 | import copy |
训练循环逻辑图
1 | 训练开始 → 每个 epoch 结束 → 计算验证损失 |