别再死记硬背公式了!用PyTorch代码实战FGM、PGD、FreeLB对抗训练(附避坑指南)

张开发
2026/4/19 18:04:33 15 分钟阅读

分享文章

别再死记硬背公式了!用PyTorch代码实战FGM、PGD、FreeLB对抗训练(附避坑指南)
PyTorch对抗训练实战FGM、PGD与FreeLB的工程化实现与调优指南对抗训练早已从学术论文中的数学公式变成了工业界提升模型鲁棒性的标配技术。但当你真正尝试在PyTorch项目中实现它时可能会遇到各种意想不到的问题——梯度消失、训练速度骤降、与BatchNorm冲突等。本文将用可运行的代码片段带你穿透理论迷雾掌握三种主流对抗训练方法在真实项目中的落地技巧。1. 对抗训练的工程本质从公式到代码Min-Max公式在论文中看起来优雅简洁但实际代码实现时却需要解决一系列工程问题。让我们先理解这个核心公式在PyTorch计算图中的对应关系# 理论公式的伪代码表达 def min_max_loss(model, x, y): # 内层max寻找使loss最大的扰动delta delta find_worst_perturbation(model, x, y) # 外层min用对抗样本训练模型 adv_loss model(x delta, y) return adv_loss实际实现时需要处理的关键问题扰动范围控制ε-ball约束在代码中如何体现梯度计算顺序何时清零梯度何时累加梯度计算效率如何避免重复计算带来的性能损耗提示对抗训练会使单步训练时间增加30%-300%具体取决于算法选择和实现方式2. FGM实现详解与性能优化Fast Gradient Method是最轻量级的对抗训练方案适合作为第一个试水算法。以下是经过生产环境验证的增强版实现class EnhancedFGM: def __init__(self, model, epsilon0.25, emb_layerword_embeddings): self.model model self.epsilon epsilon self.emb_layer emb_layer self.backup {} def attack(self): 生成对抗样本并备份原始参数 for name, param in self.model.named_parameters(): if param.requires_grad and self.emb_layer in name: self.backup[name] param.data.clone() norm param.grad.norm(p2) if norm 1e-8: # 防止除零错误 r self.epsilon * param.grad / (norm 1e-6) # 数值稳定 param.data.add_(r) def restore(self): 恢复原始embedding参数 for name, param in self.model.named_parameters(): if param.requires_grad and self.emb_layer in name: param.data.copy_(self.backup[name]) self.backup {}实战中的五个关键发现梯度累积问题当使用梯度累积策略时需要在每次累积步骤后调用attack()和restore()混合精度训练需在attack()前后手动管理AMP的梯度缩放器层选择策略不仅限于embedding层对CNN的卷积层同样有效ε值调参从0.15开始尝试每0.05为步长调整内存优化使用grad_fn钩子减少中间变量缓存3. PGD的多步攻击实现技巧Projected Gradient Descent相比FGM更加鲁棒但实现复杂度显著增加。以下是避免常见陷阱的实现方案class SafePGD: def __init__(self, model, epsilon0.3, alpha0.1, steps3): self.model model self.epsilon epsilon self.alpha alpha self.steps steps self.emb_backup {} self.grad_backup {} def attack(self, is_first_attackFalse): for name, param in self.model.named_parameters(): if param.requires_grad and embedding in name: if is_first_attack: self.emb_backup[name] param.data.clone() norm param.grad.norm(p2) if norm 1e-8: r self.alpha * param.grad / norm param.data.add_(r) # 投影到ε-ball内 delta param.data - self.emb_backup[name] delta_norm delta.norm(p2) if delta_norm self.epsilon: delta.mul_(self.epsilon / delta_norm) param.data.copy_(self.emb_backup[name] delta) def backup_grad(self): for name, param in self.model.named_parameters(): if param.requires_grad and param.grad is not None: self.grad_backup[name] param.grad.clone() def restore_grad(self): for name, param in self.model.named_parameters(): if param.requires_grad and param.grad is not None: param.grad.copy_(self.grad_backup[name]) def restore_emb(self): for name, param in self.model.named_parameters(): if param.requires_grad and embedding in name: param.data.copy_(self.emb_backup[name])PGD特有的训练循环结构pgd SafePGD(model, epsilon0.3, alpha0.1, steps3) for batch in dataloader: # 正常前向传播 loss model(batch) loss.backward() pgd.backup_grad() # 多步对抗攻击 for step in range(pgd.steps): pgd.attack(is_first_attack(step0)) if step ! pgd.steps - 1: model.zero_grad() else: pgd.restore_grad() loss_adv model(batch) loss_adv.backward() # 恢复并更新 pgd.restore_emb() optimizer.step() model.zero_grad()性能优化对比表优化策略FGM训练时间PGD训练时间效果提升基础实现1.0x3.2x基准梯度检查点0.95x2.8x0.2%混合精度0.6x1.9x-0.1%选择性反向传播0.8x2.5x0.1%4. FreeLB的高级应用与调参FreeLB作为PGD的改进版本在BERT等Transformer模型中表现优异。以下是适配现代预训练模型的实现class FreeLBWrapper: def __init__(self, model, optimizer, adv_lr1e-2, adv_steps3, adv_init_mag1e-2): self.model model self.optimizer optimizer self.adv_lr adv_lr self.adv_steps adv_steps self.adv_init_mag adv_init_mag self.delta None def step(self, inputs): # 初始化扰动 embeddings self.get_embeddings(inputs) if self.delta is None: self.delta torch.zeros_like(embeddings) if self.adv_init_mag 0: self.delta.uniform_(-self.adv_init_mag, self.adv_init_mag) # 多步对抗攻击 for _ in range(self.adv_steps): self.delta.requires_grad_() inputs[inputs_embeds] embeddings self.delta inputs[input_ids] None outputs self.model(**inputs) loss outputs.loss loss loss / self.adv_steps # 梯度累积平均 loss.backward() delta_grad self.delta.grad.detach() # 更新delta denom delta_grad.norm(p2, dim(1,2), keepdimTrue).clamp(min1e-6) self.delta (self.delta self.adv_lr * delta_grad / denom).detach() # 投影到单位球 delta_norm self.delta.norm(p2, dim(1,2)) mask (delta_norm 1.0).float().unsqueeze(-1).unsqueeze(-1) self.delta (self.delta * (1 - mask) mask * self.delta / delta_norm.unsqueeze(-1).unsqueeze(-1)).detach() # 最终对抗训练 inputs[inputs_embeds] embeddings self.delta outputs self.model(**inputs) return outputs.loss def get_embeddings(self, inputs): 提取模型原始embedding return self.model.embeddings.word_embeddings(inputs[input_ids])FreeLB调参指南初始扰动大小adv_init_magBERT类模型1e-2到1e-1CNN模型1e-3到1e-2学习率比例adv_lr# 通常设为模型学习率的5-20倍 adv_lr 20 * optimizer.param_groups[0][lr]训练阶段调度前1/3训练禁用对抗训练中间1/3逐步增加adv_steps1→3最后1/3固定参数训练5. 生产环境中的避坑实践典型问题与解决方案梯度爆炸问题# 在attack方法中添加梯度裁剪 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm1.0)与BatchNorm的冲突方案一冻结BN层的统计量for module in model.modules(): if isinstance(module, torch.nn.BatchNorm1d): module.eval()方案二使用同步BNSyncBatchNorm内存不足的优化# 使用梯度检查点 from torch.utils.checkpoint import checkpoint loss checkpoint(model, batch_input, use_reentrantFalse)多GPU训练注意事项# 确保扰动在所有GPU上同步 if torch.distributed.is_initialized(): torch.distributed.all_reduce(delta, optorch.distributed.ReduceOp.AVG)性能对比数据在文本分类任务上的实测效果BERT-base方法准确率干净数据准确率对抗攻击训练时间/epoch基线92.3%65.2%1.0xFGM92.1% (-0.2%)78.5% (13.3%)1.3xPGD91.8% (-0.5%)82.1% (16.9%)3.5xFreeLB92.5% (0.2%)84.3% (19.1%)2.8x在实现对抗训练时最难调试的往往不是算法本身而是它与现有训练管道的兼容性。建议首次实现时在小型数据集上验证所有组件正常工作再扩展到全量数据。

更多文章