# Stop Letting Your Model "Play Favorites": Long-Tailed CIFAR-10 Classification with PyTorch (Complete Code Included)

The first time you work with a real-world dataset, you may run into a frustrating pattern: some classes have more samples than you could ever use, while others have barely any. In plant disease recognition, for example, common diseases may have thousands of images while rare ones have only a few dozen. Plotted out, the distribution looks like a long tail: head classes are sample-rich, tail classes are sample-poor. This is the classic long-tailed distribution problem.

On benchmark datasets like CIFAR-10 we usually see balanced data, but real-world data rarely is. Train a model on imbalanced data and it behaves like a student who only studies their favorite subjects: excellent on the "major subjects" (head classes) and hopeless on the "minor" ones (tail classes). This post walks through how to fix that with PyTorch and turn your model into an all-around performer.

## 1. Understanding the Long-Tail Problem and CIFAR-10-LT

Long-tailed distributions are everywhere in the real world: e-commerce product photos, medical imaging, security surveillance. Common classes vastly outnumber rare ones, and a naively trained model performs poorly at test time (where the distribution is usually balanced), with especially low recognition rates on tail classes.

CIFAR-10-LT is the standard benchmark for studying this problem. It is built by subsampling the original CIFAR-10 dataset to create an artificial long-tailed distribution. An imbalance factor (IF) of 100 means the largest class has 100x as many samples as the smallest.

```python
import matplotlib.pyplot as plt
import numpy as np

# Simulated CIFAR-10-LT class distribution (IF = 100)
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
# Exponential decay from 5000 down to 5000 / 100 = 50
num_samples = [5000 * (0.01) ** (i / 9) for i in range(10)]
num_samples = [int(n) for n in num_samples]

plt.figure(figsize=(10, 5))
plt.bar(classes, num_samples)
plt.xticks(rotation=45)
plt.title('CIFAR-10-LT class distribution (IF ≈ 100)')
plt.ylabel('Number of samples')
plt.show()
```

This code produces a bar chart of the imbalanced class distribution: the first class has 5000 samples while the last has only about 50. Note that while the training set is imbalanced, the test set should remain balanced so that performance on every class is evaluated fairly.

## 2. Data Loading and Resampling Strategies

The first line of defense against long-tailed data is at the data level. PyTorch's `WeightedRandomSampler` lets us implement a variety of resampling strategies. Let's look at a few common ones.

### 2.1 Class-Balanced Sampling

Class-balanced sampling gives every class the same probability of appearing in each batch, which is especially useful under extreme imbalance.

```python
from torch.utils.data import WeightedRandomSampler

# Suppose we have a list with the class index of every sample:
# targets = [0, 0, 0, 1, 1, 2, ...]

# Per-class weights: inversely proportional to class frequency
class_counts = np.bincount(targets)
num_samples = len(targets)
class_weights = num_samples / (len(class_counts) * class_counts)

# Assign each sample its class's weight
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=num_samples,
    replacement=True  # allow repeated sampling
)
```

### 2.2 Progressively Balanced Sampling

This is a hybrid strategy: early in training, use instance-balanced sampling (classes with more samples are drawn more often); as training progresses, gradually shift to class-balanced sampling.

```python
def get_progressive_sampler(targets, epoch, max_epoch):
    class_counts = np.bincount(targets)
    num_samples = len(targets)
    # Instance-balanced weights: every sample equally likely
    ib_weights = np.ones(num_samples) / num_samples
    # Class-balanced weights: every class equally likely
    cb_weights = 1. / class_counts[targets]
    cb_weights = cb_weights / cb_weights.sum()
    # Interpolate: alpha moves from 0 (instance-balanced) to 1 (class-balanced)
    alpha = epoch / max_epoch
    weights = (1 - alpha) * ib_weights + alpha * cb_weights
    return WeightedRandomSampler(weights, num_samples, replacement=True)
```

### 2.3 Square-Root Sampling

Square-root sampling is a compromise between instance-balanced and class-balanced sampling: take the square root of each class frequency before computing the sampling weights.

```python
def get_sqrt_sampler(targets):
    class_counts = np.bincount(targets)
    sqrt_counts = np.sqrt(class_counts)
    class_weights = 1. / sqrt_counts
    sample_weights = class_weights[targets]
    return WeightedRandomSampler(sample_weights, len(targets), replacement=True)
```

**Tip:** resampling is effective, but it also reuses tail-class samples over and over, which invites overfitting. Pair it with data augmentation.

## 3. Loss Function Design: From Focal Loss to Balanced Softmax

Data-level adjustments are only part of the solution; we also need to work on the loss function. Plain cross-entropy performs poorly on long-tailed data because it treats every sample and every class equally.

### 3.1 Focal Loss

Focal Loss down-weights easy examples so the model focuses on hard ones.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # optional per-class weights
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets,
                                  reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()
```

### 3.2 Class-Balanced Loss

Class-Balanced Loss reweights the loss according to each class's frequency, via the "effective number" of samples.

```python
class ClassBalancedLoss(nn.Module):
    def __init__(self, class_counts, beta=0.9999):
        super().__init__()
        effective_num = 1.0 - np.power(beta, class_counts)
        weights = (1.0 - beta) / effective_num
        weights = weights / np.sum(weights) * len(class_counts)
        self.weights = torch.FloatTensor(weights)

    def forward(self, inputs, targets):
        return F.cross_entropy(inputs, targets,
                               weight=self.weights.to(inputs.device))
```

### 3.3 Balanced Softmax

Balanced Softmax accounts for the balanced test-time distribution by adjusting the logits with the log class priors before the softmax.

```python
class BalancedSoftmax(nn.Module):
    def __init__(self, class_counts):
        super().__init__()
        priors = torch.tensor(class_counts / sum(class_counts),
                              dtype=torch.float)
        self.register_buffer('class_priors', priors)

    def forward(self, inputs, targets):
        adjusted_inputs = inputs + torch.log(self.class_priors.unsqueeze(0))
        return F.cross_entropy(adjusted_inputs, targets)
```

| Loss | Strengths | Weaknesses | When to use |
|---|---|---|---|
| Focal Loss | Focuses on hard samples; does not need class frequencies | `gamma` requires tuning | Imbalance between classes is not extreme |
| Class-Balanced Loss | Explicitly models class frequency | Requires the class distribution | Class frequencies are known |
| Balanced Softmax | Theoretical optimality guarantee | Assumes a balanced test set | Test set really is balanced |

## 4. Model Architecture and Training Tricks

Beyond data and loss functions, the architecture and training strategy also matter. Here are a few practical techniques.

### 4.1 Two-Stage Training

First train the feature extractor on the imbalanced data, then fine-tune the classifier on rebalanced data. (The snippet below is schematic: `create_model`, `train`, and `get_balanced_sampler` stand in for your own helpers.)

```python
# Stage 1: representation learning on the raw imbalanced data
model = create_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
train(model, train_loader, criterion, optimizer, epochs=100)

# Stage 2: classifier adjustment
for param in model.feature_extractor.parameters():
    param.requires_grad = False  # freeze the feature extractor

# Fine-tune the classifier on resampled, balanced data
balanced_sampler = get_balanced_sampler(targets)
balanced_loader = DataLoader(dataset, sampler=balanced_sampler)
optimizer = torch.optim.SGD(model.classifier.parameters(), lr=0.01)
train(model, balanced_loader, criterion, optimizer, epochs=50)
```

### 4.2 Decoupling Representation Learning from Classification

Recent work has found that decoupling feature learning from classifier learning significantly improves long-tailed classification performance.

```python
# Representation learning stage
model = create_model()
train(model, imbalanced_loader, nn.CrossEntropyLoss(), epochs=100)

# Classifier adjustment stage, using class-balanced data
features = extract_features(model, balanced_loader)
labels = get_labels(balanced_loader)

# Learn a fresh linear classifier and swap it in
classifier = learn_linear_classifier(features, labels)
model.classifier = classifier
```

### 4.3 A Memory Bank for Tail Classes

Maintain a memory bank of class prototypes and use it during training to strengthen the representation of tail classes.

```python
class MemoryBank:
    def __init__(self, num_classes, feature_dim):
        self.bank = torch.zeros(num_classes, feature_dim)
        self.count = torch.zeros(num_classes)

    def update(self, features, labels):
        for feat, label in zip(features, labels):
            self.bank[label] += feat
            self.count[label] += 1

    def get_prototypes(self):
        # clamp avoids division by zero for classes not yet seen
        return self.bank / self.count.clamp(min=1).unsqueeze(1)

# During training
memory_bank = MemoryBank(num_classes=10, feature_dim=512)
for inputs, labels in dataloader:
    features = model.extract_features(inputs)
    memory_bank.update(features.detach(), labels)
    # Use the prototypes to augment representation learning
    prototypes = memory_bank.get_prototypes()
    # e.g. compute a contrastive loss against the prototypes...
```

## 5. Full Implementation and Experimental Analysis

Now let's pull these techniques together into a complete PyTorch implementation on CIFAR-10-LT with IF = 100.

### 5.1 Data Preparation

```python
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Custom CIFAR-10-LT dataset
class CIFAR10LT(Dataset):
    def __init__(self, root, train=True, download=True,
                 imbalance_factor=100, transform=None):
        self.dataset = torchvision.datasets.CIFAR10(
            root=root, train=train, download=download)
        self.transform = transform

        # Build the long-tailed distribution
        targets = np.array(self.dataset.targets)
        class_counts = np.bincount(targets)
        num_classes = len(class_counts)

        # Per-class sample counts, decaying exponentially
        # from max_count down to max_count / imbalance_factor
        max_count = int(min(class_counts))
        new_counts = [int(max_count * (1 / imbalance_factor) ** (i / (num_classes - 1)))
                      for i in range(num_classes)]

        # Subsample each class
        self.data = []
        self.targets = []
        for class_idx in range(num_classes):
            idx = np.where(targets == class_idx)[0]
            np.random.shuffle(idx)
            selected = idx[:new_counts[class_idx]]
            self.data.append(self.dataset.data[selected])
            self.targets.extend([class_idx] * len(selected))
        self.data = np.concatenate(self.data)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        img = Image.fromarray(self.data[idx])  # HWC uint8 -> PIL
        target = self.targets[idx]
        if self.transform is not None:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)
        return img, target

# Data augmentation
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

# Datasets
train_dataset = CIFAR10LT(root='./data', train=True,
                          imbalance_factor=100, transform=train_transform)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)

# Class-balanced sampler
class_counts = np.bincount(train_dataset.targets)
class_weights = 1. / class_counts
sample_weights = class_weights[train_dataset.targets]
sampler = WeightedRandomSampler(sample_weights, len(train_dataset),
                                replacement=True)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=128, sampler=sampler,
                          num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False,
                         num_workers=2, pin_memory=True)
```

### 5.2 Model Definition

```python
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])
```

### 5.3 Training Loop

```python
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} '
                  f'| Loss: {loss.item():.3f} | Acc: {100. * correct / total:.1f}%')
    return train_loss / (batch_idx + 1), 100. * correct / total

def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    class_correct = [0] * 10
    class_total = [0] * 10
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            for i in range(10):
                idx = targets == i
                class_correct[i] += predicted[idx].eq(targets[idx]).sum().item()
                class_total[i] += idx.sum().item()

    # Print per-class accuracy
    print('\nPer-class accuracy:')
    for i in range(10):
        print(f'Class {i}: {100 * class_correct[i] / class_total[i]:.1f}%')
    return test_loss / (batch_idx + 1), 100. * correct / total

# Initialize model, loss, and optimizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ResNet18().to(device)
criterion = FocalLoss(gamma=2.0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Training loop
for epoch in range(200):
    train_loss, train_acc = train(model, train_loader, criterion,
                                  optimizer, epoch)
    test_loss, test_acc = test(model, test_loader, criterion)
    scheduler.step()
    print(f'Epoch: {epoch} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc:.1f}% '
          f'| Test Loss: {test_loss:.3f} | Test Acc: {test_acc:.1f}%')
```

### 5.4 Analyzing the Results

After 200 epochs of training, you can expect results in this ballpark:

- Overall accuracy: ~78.5%
- Head classes (many samples): 85-90%
- Tail classes (few samples): 65-70%

Tail accuracy still lags behind head accuracy, but this is a significant improvement over plain cross-entropy, where tail classes may only reach 30-40%.

For further analysis, we can plot a confusion matrix:

```python
from sklearn.metrics import confusion_matrix
import seaborn as sns

def plot_confusion_matrix(model, test_loader):
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = outputs.max(1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.numpy())

    cm = confusion_matrix(all_targets, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

plot_confusion_matrix(model, test_loader)
```

The confusion matrix shows at a glance which classes the model confuses, pointing the way to further improvements.
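Beyond eyeballing the confusion matrix, it is useful to report accuracy aggregated over head and tail groups, since that is the headline number in most long-tail papers. Below is a minimal sketch of such a helper; the `grouped_accuracy` function, the group boundaries, and the toy 4-class matrix are illustrative assumptions, not part of the pipeline above.

```python
import numpy as np

def grouped_accuracy(cm, group_slices):
    """Per-group accuracy from a confusion matrix.

    cm[i, j] counts test samples of true class i predicted as class j.
    Classes are assumed ordered from most to fewest training samples,
    so a slice over leading indices selects head classes.
    """
    per_class = np.diag(cm) / cm.sum(axis=1)  # recall of each class
    return {name: per_class[sl].mean() for name, sl in group_slices.items()}

# Toy 4-class confusion matrix (100 balanced test samples per class)
cm = np.array([
    [90,  5,  3,  2],
    [ 4, 88,  5,  3],
    [10, 10, 70, 10],
    [15, 10, 15, 60],
])
groups = {'head': slice(0, 2), 'tail': slice(2, 4)}
print(grouped_accuracy(cm, groups))  # head ≈ 0.89, tail ≈ 0.65
```

With the real model, feed in the `cm` computed by `plot_confusion_matrix` and split the ten CIFAR-10-LT classes into head/medium/tail slices by their training counts.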
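As a final note, the decoupling line of work in Section 4.2 also offers a training-free alternative to re-learning the classifier: tau-normalization (Kang et al.), which simply rescales each class's weight vector in the final linear layer by its norm raised to a power tau. The sketch below is an assumption-laden illustration, not part of the pipeline above; the bias-free head and the `tau` value are placeholders to tune.

```python
import torch
import torch.nn as nn

def tau_normalize(linear, tau=1.0):
    """Rescale each row of a linear classifier's weight by ||w_i||^tau.

    Head classes tend to have larger weight norms after imbalanced
    training; shrinking norms toward equality boosts tail logits.
    """
    with torch.no_grad():
        w = linear.weight
        norms = w.norm(dim=1, keepdim=True)
        linear.weight.copy_(w / norms.pow(tau))
    return linear

# Illustrative use on a stand-alone 512-d, 10-class head
head = nn.Linear(512, 10, bias=False)
tau_normalize(head, tau=0.7)
# with tau=1.0 all rows would end up with unit norm
```

Applied to the trained `model.linear` from Section 5, this adjusts predictions toward tail classes with no extra training; `tau` is typically selected on a held-out balanced set.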