别再死磕有监督了!用PyTorch复现Mean Teacher,让你的小样本数据集也能训出好模型

张开发
2026/4/10 5:29:11 15 分钟阅读

分享文章

别再死磕有监督了!用PyTorch复现Mean Teacher,让你的小样本数据集也能训出好模型
半监督学习实战用PyTorch实现Mean Teacher突破小样本训练瓶颈当你的标注数据只有几百张医疗影像或几十份工业质检图片时传统有监督学习往往捉襟见肘。上周我接手一个皮肤病变分类项目客户仅提供了387张标注图像——这个数量连ResNet的浅层都难以充分训练。经过多次实验Mean Teacher框架最终将模型准确率从68%提升到82%而今天我要分享的正是这套实战方案。1. Mean Teacher核心原理与项目适配半监督学习中的Mean Teacher框架本质上构建了一个动态知识蒸馏系统。与静态的师生模型不同这里的教师模型会随着学生模型的进化而持续更新——通过指数移动平均(EMA)机制。在医疗影像场景中这种设计尤其宝贵教师模型提供的伪标签既保持了时间维度上的稳定性又能渐进式地吸收学生模型学到的新特征。关键组件实现要点# 初始化学生和教师模型结构相同 student_model create_custom_resnet(dropout_rate0.5) teacher_model create_custom_resnet(dropout_rate0.5) # 教师模型设置为eval模式且不更新梯度 teacher_model.eval() for param in teacher_model.parameters(): param.requires_grad FalseEMA更新策略需要特别注意衰减率(decay)的设定。在工业质检这类特征稳定的场景我通常使用0.99这样的高衰减率而对于风格多变的自然图像0.95可能更合适。这个参数实际上控制着教师模型对新知识的吸收速度场景类型推荐EMA衰减率更新频率医疗影像0.995每个batch工业缺陷检测0.99每个batch自然图像分类0.95-0.97每个batch实际项目中发现当标注数据少于500样本时EMA衰减率不宜超过0.999否则教师模型参数更新过慢会导致伪标签质量下降2. 数据流设计与增强策略优化在小样本场景下数据增强不再只是提升泛化的手段更成为构建一致性约束的关键。我们的dataloader需要同时处理标注和未标注数据class SemiSupervisedDataset(Dataset): def __init__(self, labeled_data, unlabeled_data): self.labeled labeled_data self.unlabeled unlabeled_data def __getitem__(self, idx): if idx len(self.labeled): img, label self.labeled[idx] return img, label, True # 标注数据标记 else: img self.unlabeled[idx - len(self.labeled)] return img, None, False # 未标注数据对于医疗影像我推荐使用以下增强组合弱增强随机水平翻转小角度旋转(±15°)轻度色彩抖动强增强CutOut(最大遮挡30%)GridDistortion弹性变换# 一致性损失计算示例 weak_aug weak_transform(unlabeled_imgs) # 学生模型输入 strong_aug strong_transform(unlabeled_imgs) # 教师模型输入 student_pred student_model(weak_aug) with torch.no_grad(): teacher_pred teacher_model(strong_aug) consistency_loss F.mse_loss(student_pred, teacher_pred)3. 损失函数动态平衡策略单纯的MSE一致性损失在小样本场景往往效果不佳。经过多个项目验证我发现动态调整的损失权重方案最为可靠def get_consistency_weight(epoch, max_epochs300): 余弦退火调整权重 return 10 * (math.cos(math.pi * epoch / max_epochs) 1) / 2训练过程中的典型损失组成监督损失标注数据交叉熵一致性损失未标注数据MSE正则化损失如权重衰减在训练初期监督损失应占主导随着模型逐渐稳定逐步提高一致性损失的权重。这个过渡过程对最终性能影响显著训练阶段监督损失权重一致性损失权重0-50轮1.00.5-2.050-150轮0.83.0-5.0150轮后0.58.0-10.0重要提示当标注数据极度稀缺(100样本)时建议先用监督损失预训练20轮再引入一致性损失4. 实战调优技巧与性能对比在皮肤病变分类项目中我们对比了三种训练方案纯监督训练仅使用387张标注图像伪标签法固定教师模型生成伪标签Mean Teacher动态EMA更新教师模型实验结果令人印象深刻方法验证集准确率过拟合程度纯监督68.2%严重静态伪标签75.1%中等Mean Teacher82.3%轻微实现这一提升的关键调参经验包括使用AdamW优化器(初始lr3e-4)配合余弦学习率衰减每轮增加未标注数据的batch size初始为标注数据的2倍在最后50轮冻结教师模型参数# 优化器配置示例 optimizer AdamW([ {params: student_model.parameters(), lr: 3e-4}, ], weight_decay0.05) scheduler CosineAnnealingLR(optimizer, T_max300, eta_min1e-6)当处理特别敏感的小样本数据时如某些罕见病影像我会在训练后期添加标签平滑(label smoothing)技术这能进一步提升模型对边界病例的判别能力。

更多文章