PyTorch实战:手把手教你复现HVMUNet网络(含Mamba核心模块代码详解)

张开发
2026/4/11 22:05:22 15 分钟阅读

分享文章

PyTorch实战:手把手教你复现HVMUNet网络(含Mamba核心模块代码详解)
在医学图像分割领域,U型网络架构一直是主流选择。最近,结合状态空间模型(SSM)的新型网络架构HVMUNet引起了广泛关注。本文将带您从零开始实现这个融合了Mamba模块的创新架构,特别针对那些希望深入理解代码实现细节的开发者。

## 1. 环境准备与基础模块实现

### 1.1 安装必要依赖

确保您的环境已安装PyTorch 1.8+和CUDA 11+。推荐使用conda创建独立环境:

```bash
conda create -n hvmunet python=3.8
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install einops timm
```

### 1.2 深度可分离卷积实现

HVMUNet的基础构建块之一是深度可分离卷积。我们先实现一个高效的版本:

```python
class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size=3, padding=1, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size,
                               padding=padding, stride=stride, groups=dim_in)
        self.norm = nn.GroupNorm(4, dim_in)
        self.conv2 = nn.Conv2d(dim_in, dim_out, kernel_size=1)

    def forward(self, x):
        return self.conv2(self.norm(self.conv1(x)))
```

这个实现包含了三个关键部分:深度卷积(逐通道卷积)、组归一化层、逐点卷积(1x1卷积)。

## 2. Mamba核心模块解析与实现

### 2.1 Local_SS2D模块

Local_SS2D是HVMUNet的核心创新之一,结合了卷积的局部特征提取和SS2D的全局感知能力:

```python
class Local_SS2D(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.dw = nn.Conv2d(dim//2, dim//2, kernel_size=3, padding=1,
                            groups=dim//2, bias=False)
        self.complex_weight = nn.Parameter(
            torch.randn(dim//2, h, w, 2) * 0.02)
        self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.SS2D = SS2D(d_model=dim//2, d_state=16)

    def forward(self, x):
        x = self.pre_norm(x)
        x1, x2 = torch.chunk(x, 2, dim=1)
        x1 = self.dw(x1)
        x2 = x2.permute(0, 2, 3, 1)
        x2 = self.SS2D(x2).permute(0, 3, 1, 2)
        x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2)
        return self.post_norm(x.reshape(x1.shape[0], -1, *x1.shape[2:]))
```

关键点说明:

- 输入特征被均分为两部分处理;
- 一半通道走传统卷积路径,另一半走SS2D路径;
- 最终结果拼接后通过LayerNorm。

### 2.2 高阶H_SS2D实现

高阶视觉状态空间模块是HVMUNet的精华所在,下面是5阶实现:

```python
class H_SS2D(nn.Module):
    def __init__(self, dim, order=5, gflayer=None):
        super().__init__()
        self.order = order
        self.dims = [dim // (2**i) for i in range(order)][::-1]
        self.proj_in = nn.Conv2d(dim, 2*dim, 1)
        self.dwconv = gflayer(sum(self.dims)) if gflayer else \
            get_dwconv(sum(self.dims), 7, True)
        self.proj_out = nn.Conv2d(dim, dim, 1)
        self.pws = nn.ModuleList([
            nn.Conv2d(self.dims[i], self.dims[i+1], 1)
            for i in range(order-1)
        ])
        self.ss2ds = nn.ModuleList([
            SS2D(d_model=d, dropout=0, d_state=16)
            for d in self.dims[1:]
        ])
        self.ss2d_in = SS2D(d_model=self.dims[0], dropout=0, d_state=16)

    def forward(self, x):
        B, C, H, W = x.shape
        fused_x = self.proj_in(x)
        pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
        dw_abc = self.dwconv(abc)
        dw_list = torch.split(dw_abc, self.dims, dim=1)
        x = pwa * dw_list[0]
        x = self.ss2d_in(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        for i in range(self.order-1):
            x = self.pws[i](x) * dw_list[i+1]
            x = self.ss2ds[i](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return self.proj_out(x)
```

这个模块实现了论文中描述的多阶特征交互:

- 输入特征通过proj_in扩展通道;
- 分割为不同阶次的子特征;
- 每阶特征经过深度卷积和SS2D处理;
- 通过Hadamard积实现特征交互。

## 3. 注意力桥接模块实现

### 3.1 通道注意力桥(CAB)

```python
class Channel_Att_Bridge(nn.Module):
    def __init__(self, c_list):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.get_all_att = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)
        self.att_layers = nn.ModuleList([
            nn.Linear(sum(c_list)-c_list[-1], c) for c in c_list[:-1]
        ])
        self.sigmoid = nn.Sigmoid()

    def forward(self, *features):
        att = torch.cat([self.avgpool(f) for f in features], dim=1)
        att = self.get_all_att(att.squeeze(-1).transpose(-1, -2))
        return [self.sigmoid(l(att)).transpose(-1, -2).unsqueeze(-1).expand_as(f)
                for l, f in zip(self.att_layers, features)]
```

### 3.2 空间注意力桥(SAB)

```python
class Spatial_Att_Bridge(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_conv = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, *features):
        att_maps = []
        for f in features:
            avg = torch.mean(f, dim=1, keepdim=True)
            max_ = torch.max(f, dim=1, keepdim=True)[0]
            att = self.shared_conv(torch.cat([avg, max_], dim=1))
            att_maps.append(att)
        return att_maps
```

## 4. 完整HVMUNet网络集成

### 4.1 编码器-解码器结构

```python
class HVMUNet(nn.Module):
    def __init__(self, in_ch=3, num_classes=1,
                 channels=[8, 16, 32, 64, 128, 256]):
        super().__init__()
        # 编码器部分
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_ch, channels[0], 3, padding=1),
            nn.GroupNorm(4, channels[0])
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(channels[0], channels[1], 3, padding=1),
            nn.GroupNorm(4, channels[1])
        )
        self.enc3 = self._make_layer(channels[1], channels[2], order=2)
        self.enc4 = self._make_layer(channels[2], channels[3], order=3)
        self.enc5 = self._make_layer(channels[3], channels[4], order=4)
        self.enc6 = self._make_layer(channels[4], channels[5], order=5)
        # 解码器部分
        self.dec1 = self._make_layer(channels[5], channels[4], order=5)
        self.dec2 = self._make_layer(channels[4], channels[3], order=4)
        self.dec3 = self._make_layer(channels[3], channels[2], order=3)
        self.dec4 = self._make_layer(channels[2], channels[1], order=2)
        self.dec5 = nn.Sequential(
            nn.Conv2d(channels[1], channels[0], 3, padding=1),
            nn.GroupNorm(4, channels[0])
        )
        self.scab = SC_Att_Bridge(channels[:-1])
        self.final = nn.Conv2d(channels[0], num_classes, 1)

    def _make_layer(self, in_ch, out_ch, order):
        return nn.Sequential(
            Block(in_ch, H_SS2D=partial(H_SS2D, order=order)),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(4, out_ch)
        )
```

### 4.2 前向传播实现

```python
    def forward(self, x):
        # 编码过程
        e1 = F.gelu(F.max_pool2d(self.enc1(x), 2))
        e2 = F.gelu(F.max_pool2d(self.enc2(e1), 2))
        e3 = F.gelu(F.max_pool2d(self.enc3(e2), 2))
        e4 = F.gelu(F.max_pool2d(self.enc4(e3), 2))
        e5 = F.gelu(F.max_pool2d(self.enc5(e4), 2))
        e6 = F.gelu(self.enc6(e5))
        # 注意力桥接
        e1, e2, e3, e4, e5 = self.scab(e1, e2, e3, e4, e5)
        # 解码过程(逐级上采样并与编码特征相加)
        d1 = F.gelu(self.dec1(e6))
        d1 = d1 + e5
        d2 = F.gelu(F.interpolate(self.dec2(d1), scale_factor=2))
        d2 = d2 + e4
        d3 = F.gelu(F.interpolate(self.dec3(d2), scale_factor=2))
        d3 = d3 + e3
        d4 = F.gelu(F.interpolate(self.dec4(d3), scale_factor=2))
        d4 = d4 + e2
        d5 = F.gelu(F.interpolate(self.dec5(d4), scale_factor=2))
        d5 = d5 + e1
        out = self.final(F.interpolate(d5, scale_factor=2))
        return torch.sigmoid(out)
```

## 5. 训练技巧与实战建议

### 5.1 数据预处理配置

医学图像分割通常需要特定的数据增强策略:

```python
train_transform = A.Compose([
    A.RandomRotate90(),
    A.Flip(),
    A.ElasticTransform(alpha=120, sigma=120*0.05, alpha_affine=120*0.03),
    A.RandomGamma(gamma_limit=(80, 120)),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2()
])
```

### 5.2 损失函数选择

推荐使用组合损失函数:

```python
class HybridLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        return self.alpha*self.bce(pred, target) + \
            (1-self.alpha)*self.dice(pred, target)
```

### 5.3 训练参数配置

```python
model = HVMUNet().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
loss_fn = HybridLoss(alpha=0.7)
# 混合精度训练
scaler = torch.cuda.amp.GradScaler()
```

### 5.4 常见问题排查

- **显存不足**:减小batch size;使用梯度累积;尝试混合精度训练。
- **训练不稳定**:检查学习率是否合适;确认数据归一化是否正确;尝试添加梯度裁剪。
- **性能不佳**:调整H-VSS模块的阶数;修改通道注意力桥的实现方式;增加数据增强的多样性。

更多文章