多任务学习(MTL)实战:从加权策略到不确定性建模

张开发
2026/4/18 0:36:12 15 分钟阅读

分享文章

多任务学习(MTL)实战:从加权策略到不确定性建模
1. 多任务学习入门从单任务到多任务的跃迁第一次接触多任务学习MTL时我正被公司要求同时优化推荐系统的点击率和停留时长两个指标。当时傻乎乎地训练了两个独立模型结果线上部署时发现资源消耗翻倍两个模型的预测结果还经常打架。直到同事扔给我一篇MTL论文才恍然大悟原来一个模型可以同时搞定多个任务多任务学习的本质就像让一个学生同时学习数学和语文。传统单任务学习是培养偏科生而MTL要培养全能型选手。在实际工业场景中这种全能模型的优势非常明显部署成本低一个模型服务替代多个独立模型资源共享底层特征表示可以在任务间共享泛化更强任务间的相关性起到正则化作用但MTL最让人头疼的就是损失加权问题。就像老师要给不同科目分配课时一样我们需要决定每个任务在总损失中的比重。早期我试过最简单的加权平均法# 手工加权示例 total_loss 0.3 * loss1 0.7 * loss2这种粗暴方法很快就让我栽了跟头——权重稍微变化0.1线上指标就能波动5%。后来才发现优秀的MTL实现需要更精细的加权策略这正是本文要重点探讨的内容。2. 手工加权简单但危险的起点2.1 基础加权方法手工加权就像给多个任务分配固定比例的资源。假设我们要同时优化分类准确率和回归误差def manual_weighted_loss(loss1, loss2): return alpha * loss1 (1-alpha) * loss2这里的alpha就是需要人工调整的超参数。我在电商搜索业务中实践时发现这种方法的痛点非常明显敏感度过高当alpha从0.4调整到0.5时AUC可能提升2%但MAE会恶化15%任务量纲差异分类loss通常在0-1之间而回归loss可能达到几十动态适应性差不同训练阶段任务难度会变化2.2 改进方案标准化加权后来我采用了一种改进方案——先对各个loss进行标准化处理# 对loss进行标准化 normalized_loss1 loss1 / loss1.detach() normalized_loss2 loss2 / loss2.detach() total_loss w1 * normalized_loss1 w2 * normalized_loss2这种方法确实缓解了量纲问题但依然需要大量实验来确定最佳权重。在推荐系统场景下我们通常要跑数十组AB测试才能找到相对合理的权重组合。提示手工加权适合任务间关系稳定且量级相近的场景比如同时预测用户年龄和性别。对于差异大的任务建议考虑动态加权方法。3. 动态加权平均DWA让任务平衡学习3.1 DWA算法原理Dynamic Weight Averaging的核心思想很直观——根据任务的学习速度动态调整权重。就像老师会根据学生各科进步速度调整教学重点计算各任务loss在相邻epoch的变化率对变化率进行softmax归一化用归一化结果作为当前epoch的权重具体实现可以参考这个PyTorch示例class DWA(nn.Module): def __init__(self, num_tasks, temp2.0): super().__init__() self.temp temp self.register_buffer(prev_loss, torch.zeros(num_tasks)) def forward(self, losses): if self.prev_loss.sum() 0: # 第一轮平均加权 return torch.softmax(torch.ones_like(losses), dim0) loss_ratio losses / self.prev_loss weights torch.softmax(loss_ratio / self.temp, dim0) self.prev_loss losses.detach() return weights3.2 实战经验与调参技巧在商品多属性预测任务中DWA表现出色但需要注意温度系数temp控制权重分布平滑度通常1.0-3.0之间初始阶段稳定前几个epoch建议使用固定权重异常值处理单个epoch的剧烈波动需要平滑处理实测发现对于价格预测销量预测的双任务场景DWA相比手工加权能使模型收敛速度提升30%最终指标也更加均衡。4. 不确定性加权更科学的自适应方法4.1 不确定性理论基础不确定性加权方法源自论文《Multi-task learning using uncertainty to weigh losses》它将任务权重与不确定性建模相结合。这里需要区分两种不确定性认知不确定性数据不足导致可通过增加数据缓解偶然不确定性数据本身噪声导致与数据量无关该方法主要针对第二种不确定性中的同方差情况。推导后的loss函数形式非常优雅L \sum_i \frac{1}{2\sigma_i^2}L_i \log\sigma_i其中σ是任务相关的不确定性参数会被自动学习。4.2 代码实现详解以下是完整的AutomaticWeightedLoss实现我在原基础上添加了数值稳定处理class RobustAutomaticWeightedLoss(nn.Module): def __init__(self, num_tasks, eps1e-6): super().__init__() self.params nn.Parameter(torch.ones(num_tasks)) self.eps eps def forward(self, *losses): total_loss 0 for i, loss in enumerate(losses): sigma torch.clamp(self.params[i], minself.eps) total_loss 0.5/(sigma**2)*loss torch.log(1 sigma**2) return total_loss使用时需要特别注意优化器配置model MultiTaskModel() awl RobustAutomaticWeightedLoss(2) # 关键awl参数需要单独配置优化器 optimizer torch.optim.Adam([ {params: model.parameters()}, {params: awl.parameters(), weight_decay: 0} # 禁止权重衰减 ], lr1e-3)4.3 工业场景应用案例在视频推荐系统中我们同时优化点击率预测分类任务观看时长预测回归任务使用不确定性加权后模型自动学习到两个任务的σ值分别为0.8和1.2这与我们手动分析的任务噪声水平一致。最终线上AB测试显示相比DWA方法不确定性加权使时长预测的MAE降低了12%而点击率保持稳定。5. 进阶技巧与避坑指南5.1 多任务架构设计除了损失加权网络结构设计同样重要。分享几个实用技巧硬参数共享底层共享顶层独立适合相关任务软参数共享各任务有独立参数但保持相似适合差异任务任务门控学习任务特定的特征组合方式# 门控共享示例 class TaskGate(nn.Module): def __init__(self, input_dim, num_tasks): super().__init__() self.gates nn.ModuleList([ nn.Linear(input_dim, input_dim) for _ in range(num_tasks) ]) def forward(self, x, task_id): return x * torch.sigmoid(self.gates[task_id](x))5.2 常见问题排查踩过无数坑后我总结出MTL调试checklist梯度检查各任务梯度量级是否均衡权重监控动态权重的变化曲线是否合理表征分析共享层的特征是否被某些任务主导资源分配显存占用是否超出预期遇到问题时可以先用简单的加权方法建立baseline再逐步引入复杂方法。记住不是所有场景都需要花哨的加权算法有时候简单的加权平均配合好的网络结构就能取得不错的效果。

更多文章