Enhanced DeepLabv3+ implementation with multiple backbone options, attention mechanisms, and dual encoder architectures.
- Multiple Backbones: MobileNetV2, Xception, Swin Transformer, EfficientNet, PVTv2, ConvNeXt
- Dual Encoder Architectures: Combine CNN + Transformer for richer features
- SAM-style Prompt Attention: Learnable prompt tokens for task-specific focus
- Attention Modules: SE, CBAM, ECA with configurable positions
- Configurable Decoder: Adjustable decoder channels (256, 512, 1024)
- Pretrained Weights: Support via the `timm` library
- 1-Channel Input: Optimized for grayscale/medical images
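Pretrained ImageNet weights expect 3-channel input, so a 1-channel model needs an adapted stem. A common trick is to sum the pretrained RGB filters over the input-channel axis. As far as we understand, `timm` does something similar when `in_chans=1`. The sketch below shows the idea on a plain `nn.Conv2d`. `adapt_stem_to_grayscale` is a hypothetical helper and is not part of this repo or of `timm`.

```python
import torch
import torch.nn as nn


def adapt_stem_to_grayscale(conv: nn.Conv2d) -> nn.Conv2d:
    """Hypothetical helper: build a 1-channel copy of a 3-channel stem conv.

    Summing the RGB filters means that feeding a grayscale image to the new
    conv gives the same output as feeding that image, replicated to 3
    channels, to the original conv.
    """
    assert conv.in_channels == 3 and conv.groups == 1, "expects a plain RGB stem"
    new_conv = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )
    with torch.no_grad():
        # (out, 3, kH, kW) -> (out, 1, kH, kW)
        new_conv.weight.copy_(conv.weight.sum(dim=1, keepdim=True))
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


rgb_stem = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
gray_stem = adapt_stem_to_grayscale(rgb_stem)
x = torch.randn(2, 1, 64, 64)
same = torch.allclose(gray_stem(x), rgb_stem(x.repeat(1, 3, 1, 1)), atol=1e-5)
print(same)
```

Compared with randomly initializing a new 1-channel stem, this approach keeps the pretrained edge and texture detectors usable from the first training step.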
```bash
pip install torch torchvision timm thop
```

```python
from nets.deeplabv3_plus_v2 import deeplabv3_plus

model = deeplabv3_plus(
    num_classes=4,
    backbone='swin_base',
    decoder_channels=1024,  # NEW: configurable decoder size
    attention_block='eca',
    attention_position='aspp_post'
)
```

Dual V2 combines EfficientNet + PVTv2 with a Gated Fusion mechanism.
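The following minimal sketch makes "gated fusion" concrete. It assumes both encoder feature maps have already been brought to the same spatial size. A learned sigmoid gate then decides, per pixel and per channel, how much of the CNN feature versus the transformer feature to keep. The class name and channel counts are illustrative assumptions. The actual module in `DeepLabDualV2` may be structured differently.

```python
import torch
import torch.nn as nn


class GatedFusion(nn.Module):
    """Illustrative gated fusion of a CNN map and a transformer map (same H, W)."""

    def __init__(self, cnn_channels: int, vit_channels: int, out_channels: int):
        super().__init__()
        # 1x1 projections bring both encoders to a common width.
        self.proj_cnn = nn.Conv2d(cnn_channels, out_channels, kernel_size=1)
        self.proj_vit = nn.Conv2d(vit_channels, out_channels, kernel_size=1)
        # The gate looks at both projected maps and outputs weights in (0, 1).
        self.gate = nn.Sequential(
            nn.Conv2d(2 * out_channels, out_channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, f_cnn: torch.Tensor, f_vit: torch.Tensor) -> torch.Tensor:
        a = self.proj_cnn(f_cnn)
        b = self.proj_vit(f_vit)
        g = self.gate(torch.cat([a, b], dim=1))
        # Convex combination: g weights the CNN branch, (1 - g) the transformer branch.
        return g * a + (1 - g) * b


# Assumed channel counts, chosen only for illustration.
fusion = GatedFusion(cnn_channels=80, vit_channels=64, out_channels=128)
fused = fusion(torch.randn(1, 80, 32, 32), torch.randn(1, 64, 32, 32))
print(fused.shape)
```

The gate adds only a few 1x1 convolutions, which may help explain why V2 is the lighter of the two dual-encoder variants.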
```python
from nets.deeplabv3_plus_dual_v2 import DeepLabDualV2

model = DeepLabDualV2(
    num_classes=4,
    efficientnet_variant='b7',  # b0-b7
    pvtv2_variant='b2',         # b0-b5
    decoder_channels=512,
    low_level_channels=128,
    high_level_channels=512,
    pretrained=True,
    in_chans=1
)
```

Dual V3 uses Cross-Attention for deeper feature interaction between the two backbones.
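The sketch below shows one plausible shape for cross-attention fusion, using the standard `torch.nn.MultiheadAttention`. CNN feature positions act as queries, and transformer feature positions act as keys and values. It assumes both maps share the same channel count, which must also be divisible by `num_heads`. The class name and the residual + LayerNorm arrangement are assumptions and do not come from `DeepLabDualV3`.

```python
import torch
import torch.nn as nn


class CrossAttentionFusion(nn.Module):
    """Illustrative fusion: CNN tokens (queries) attend to transformer tokens."""

    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(channels)

    def forward(self, f_cnn: torch.Tensor, f_vit: torch.Tensor) -> torch.Tensor:
        b, c, h, w = f_cnn.shape
        # (B, C, H, W) -> (B, H*W, C) token sequences; f_vit may have a different H, W.
        q = f_cnn.flatten(2).transpose(1, 2)
        kv = f_vit.flatten(2).transpose(1, 2)
        out, _ = self.attn(q, kv, kv)
        # Residual connection keeps the original CNN signal.
        out = self.norm(q + out)
        return out.transpose(1, 2).reshape(b, c, h, w)


fuse = CrossAttentionFusion(channels=64, num_heads=8)
y = fuse(torch.randn(1, 64, 32, 32), torch.randn(1, 64, 16, 16))
print(y.shape)
```

In this sketch, the attention matrix has one entry per (query position, key position) pair. At 1/16 of a 512x512 input, that is 1024 x 1024 entries per head. Quadratic growth of this kind is a plausible reason V3 runs out of memory before V2 does.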
```python
from nets.deeplabv3_plus_dual_v3 import DeepLabDualV3

model = DeepLabDualV3(
    num_classes=4,
    efficientnet_variant='b7',
    pvtv2_variant='b2',
    decoder_channels=512,
    num_heads=8,  # Cross-attention heads
    pretrained=True,
    in_chans=1
)
```

DeepLabPrompt combines EfficientNet + DeepLab with learnable (SAM-style) prompt attention.
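Here is a rough sketch of "learnable prompt tokens". A small set of `num_prompts` learned vectors first attends to the image tokens to gather context. The image tokens then attend back to those prompts, loosely following SAM's two-way attention. The class name, initialization scale, and layer order are assumptions. The real `DeepLabPrompt` module may differ.

```python
import torch
import torch.nn as nn


class PromptAttention(nn.Module):
    """Illustrative prompt attention with learnable prompt tokens."""

    def __init__(self, channels: int, num_prompts: int = 8, num_heads: int = 8):
        super().__init__()
        # Learnable prompt tokens shared across the batch.
        self.prompts = nn.Parameter(torch.randn(num_prompts, channels) * 0.02)
        self.prompt_to_image = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.image_to_prompt = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(channels)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        b, c, h, w = feats.shape
        tokens = feats.flatten(2).transpose(1, 2)               # (B, H*W, C)
        prompts = self.prompts.unsqueeze(0).expand(b, -1, -1)   # (B, P, C)
        # 1) Prompts gather context from the image.
        prompts = prompts + self.prompt_to_image(prompts, tokens, tokens)[0]
        # 2) Image tokens read back from the updated prompts.
        tokens = self.norm(tokens + self.image_to_prompt(tokens, prompts, prompts)[0])
        return tokens.transpose(1, 2).reshape(b, c, h, w)


block = PromptAttention(channels=64, num_prompts=8, num_heads=8)
z = block(torch.randn(1, 64, 32, 32))
print(z.shape)
```

Because attention here runs between H*W image tokens and only `num_prompts` prompts, the cost grows linearly with the number of pixels. Full pixel-to-pixel attention would grow quadratically.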
```python
from nets.deeplabv3_plus_prompt import DeepLabPrompt

model = DeepLabPrompt(
    num_classes=4,
    efficientnet_variant='b7',
    num_prompts=8,  # Learnable prompt tokens
    decoder_channels=512,
    pretrained=True,
    in_chans=1
)
```

| Model | Params | Description |
|---|---|---|
| Standard + Swin-Base | ~88M | Transformer backbone |
| Dual V2 (B7 + B2) | ~110M | EfficientNet + PVTv2 with Gated Fusion |
| Dual V3 (B7 + B2) | ~120M | Cross-Attention fusion |
| DeepLabPrompt | ~100M | SAM-style prompt attention |
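The parameter counts above are approximate and depend on the chosen variants and `decoder_channels`. To check a specific configuration without `thop`, you can sum the tensor sizes directly. `count_parameters_m` is a hypothetical helper, shown here on a single conv layer so the arithmetic is easy to follow.

```python
import torch.nn as nn


def count_parameters_m(model: nn.Module, trainable_only: bool = False) -> float:
    """Return the number of parameters in millions (hypothetical helper)."""
    params = (p for p in model.parameters() if p.requires_grad or not trainable_only)
    return sum(p.numel() for p in params) / 1e6


# A 3x3 conv from 1 to 64 channels: 1 * 64 * 9 weights + 64 biases = 640 params.
conv = nn.Conv2d(1, 64, kernel_size=3)
print(f"{count_parameters_m(conv):.5f}M")  # 0.00064M
```

Passing any of the models above (for example a `DeepLabDualV2` instance) works the same way. With `trainable_only=True`, frozen backbone weights are excluded.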

| Parameter | Values | Description |
|---|---|---|
| `decoder_channels` | 256, 512, 1024 | Decoder width (larger = more capacity, but more memory and compute) |
| `downsample_factor` | 8, 16 | Encoder output stride (8 = finer feature maps, slower and more memory; 16 = faster) |
| `num_prompts` | 4, 8, 16 | Number of learnable prompt tokens in the SAM-style model |
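The trade-offs in this table come down to simple arithmetic. `downsample_factor` sets the size of the deepest feature map, and a 3x3 conv with `decoder_channels` inputs and outputs has `9 * C * C` weights. The sketch assumes a 512x512 input and decoder convs that keep `decoder_channels` in and out. The reference DeepLabv3+ decoder follows this pattern with 256 channels. Both helpers are hypothetical.

```python
def feature_map_size(input_size: int, downsample_factor: int) -> int:
    """Spatial size of the deepest encoder feature map."""
    return input_size // downsample_factor


def conv3x3_params(in_ch: int, out_ch: int, bias: bool = False) -> int:
    """Weights of a 3x3 conv: in_ch * out_ch * 9 (+ out_ch biases if used)."""
    return in_ch * out_ch * 9 + (out_ch if bias else 0)


for factor in (16, 8):
    s = feature_map_size(512, factor)
    print(f"downsample_factor={factor}: {s}x{s} = {s * s} positions")
# downsample_factor=16: 32x32 = 1024 positions
# downsample_factor=8: 64x64 = 4096 positions

for ch in (256, 512, 1024):
    print(f"decoder_channels={ch}: {conv3x3_params(ch, ch) / 1e6:.2f}M params per 3x3 conv")
# decoder_channels=256: 0.59M params per 3x3 conv
# decoder_channels=512: 2.36M params per 3x3 conv
# decoder_channels=1024: 9.44M params per 3x3 conv
```

Halving the output stride quadruples the number of positions each encoder and ASPP layer processes. Doubling `decoder_channels` roughly quadruples the cost of each such decoder conv.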
- MobileNetV2 / Xception: `mobilenet`, `xception`
- Swin Transformer: `swin_tiny`, `swin_small`, `swin_base`
- EfficientNet: `b0` to `b7`
- PVTv2: `b0` to `b5`
- ConvNeXt: `tiny`, `small`, `base`, `large`

| Position | Description |
|---|---|
| `none` | No attention |
| `aspp_pre` | After each ASPP branch |
| `aspp_post` | After the ASPP branches are concatenated |
| `decoder` | After decoder fusion |
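As an example of one attention block, here is a minimal ECA (Efficient Channel Attention) module following the ECA-Net formulation. It applies global average pooling, then a 1D conv across channels with an adaptively sized odd kernel, then a sigmoid gate. The usage applies it in the `aspp_post` style to a concatenated ASPP output. The 1280-channel input assumes five ASPP branches of 256 channels each, as in the reference DeepLabv3. This repo's ASPP width may differ.

```python
import math

import torch
import torch.nn as nn


class ECA(nn.Module):
    """Efficient Channel Attention: per-channel weights from a 1D conv over pooled features."""

    def __init__(self, channels: int, gamma: int = 2, b: int = 1):
        super().__init__()
        # Adaptive kernel size k = |(log2(C) + b) / gamma|, forced to be odd.
        t = int(abs((math.log2(channels) + b) / gamma))
        k = t if t % 2 else t + 1
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=k // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.pool(x)                                      # (B, C, 1, 1)
        y = self.conv(y.squeeze(-1).transpose(-1, -2))        # (B, 1, C)
        y = self.sigmoid(y.transpose(-1, -2).unsqueeze(-1))   # (B, C, 1, 1)
        return x * y  # rescale each channel


# aspp_post style: one ECA on the concatenated ASPP output (assumed 5 x 256 channels).
aspp_concat = torch.randn(2, 1280, 32, 32)
eca = ECA(1280)
print(eca(aspp_concat).shape)  # torch.Size([2, 1280, 32, 32])
```

With `aspp_pre`, you would instead create one such module per branch, each with that branch's channel count. ECA adds only `k` weights per module, so the placement choice mostly affects compute, not parameter count.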
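```python
import torch
from nets.deeplabv3_plus_dual_v2 import DeepLabDualV2

# Create model (falls back to CPU if no GPU is available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DeepLabDualV2(num_classes=4, in_chans=1).to(device)
model.eval()  # avoids BatchNorm errors on 1x1 pooled features with batch size 1

# Forward pass on a single-channel 512x512 image
x = torch.randn(1, 1, 512, 512, device=device)
with torch.no_grad():
    output = model(x)
print(output.shape)  # torch.Size([1, 4, 512, 512])

# Get model info
info = model.get_model_info(input_size=(1, 1, 512, 512))
print(f"Params: {info['parameters_M']:.2f}M, GFLOPs: {info['gflops']:.2f}")
```

- OOM with Dual V3: Use V2 (Gated Fusion) or reduce the batch size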
- Memory issues: Try `decoder_channels=256` or a smaller backbone
- Training crashes: Lower the `DataLoader` `num_workers` value (set it to 0 while debugging to rule out worker-process issues)
MIT License