From b7717af09d0cc455a10461425192a3240a543ce9 Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Fri, 23 Jan 2026 00:36:08 +0300 Subject: [PATCH 01/11] Add MAGNUS: Multi-Attention Guided Network for Unified Segmentation - Add MAGNUS hybrid CNN-Transformer architecture for medical image segmentation - Implement CNNPath for hierarchical feature extraction - Implement TransformerPath for global context modeling - Add CrossModalAttentionFusion for bidirectional cross-attention - Add ScaleAdaptiveConv for multi-scale feature extraction - Add SEBlock for channel recalibration - Support both 2D and 3D medical images - Add deep supervision option - Add comprehensive unit tests Reference: Aras et al., IEEE Access 2026, DOI: 10.1109/ACCESS.2026.3656667 Signed-off-by: Sefa Aras --- monai/networks/nets/__init__.py | 1 + monai/networks/nets/magnus.py | 734 +++++++++++++++++++++++++++++ tests/networks/nets/test_magnus.py | 332 +++++++++++++ 3 files changed, 1067 insertions(+) create mode 100644 monai/networks/nets/magnus.py create mode 100644 tests/networks/nets/test_magnus.py diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index c1917e5293..ecb1930f38 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -73,6 +73,7 @@ MedNeXtSmall, MedNextSmall, ) +from .magnus import MAGNUS, CNNPath, CrossModalAttentionFusion, ScaleAdaptiveConv, TransformerPath from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py new file mode 100644 index 0000000000..b4af9b8ee2 --- /dev/null +++ b/monai/networks/nets/magnus.py @@ -0,0 +1,734 @@ +# Copyright Project MONAI Contributors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MAGNUS: Multi-Attention Guided Network for Unified Segmentation via CNN-ViT Fusion + +A hybrid CNN-Transformer architecture that combines multi-scale CNN features +with Vision Transformer representations through cross-modal attention fusion +for advanced medical image segmentation. + +Reference: + Aras, E., Kayikcioglu, T., Aras, S., & Merd, N. (2026). + MAGNUS: Multi-Attention Guided Network for Unified Segmentation via CNN-ViT Fusion. + IEEE Access. DOI: 10.1109/ACCESS.2026.3656667 +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from monai.networks.blocks import Convolution, UpSample +from monai.networks.layers.factories import Act, Norm +from monai.networks.layers.utils import get_act_layer, get_norm_layer +from monai.utils import ensure_tuple_rep + +__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"] + + +class CNNPath(nn.Module): + """ + CNN encoder path with strided convolutions for hierarchical feature extraction. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). 
+ in_channels: number of input channels. + features: sequence of output channels for each encoder stage. + norm: feature normalization type, one of ("batch", "instance", "group"). + act: activation type, one of ("relu", "leakyrelu", "prelu", "gelu"). + dropout: dropout ratio after each convolution block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + features: Sequence[int], + norm: Union[str, tuple] = "batch", + act: Union[str, tuple] = "relu", + dropout: float = 0.0, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.stages = nn.ModuleList() + current_channels = in_channels + + for feat in features: + stage = nn.Sequential( + Convolution( + spatial_dims=spatial_dims, + in_channels=current_channels, + out_channels=feat, + kernel_size=3, + strides=2, + padding=1, + norm=norm, + act=act, + dropout=dropout, + ), + Convolution( + spatial_dims=spatial_dims, + in_channels=feat, + out_channels=feat, + kernel_size=3, + strides=1, + padding=1, + norm=norm, + act=act, + dropout=dropout, + ), + ) + self.stages.append(stage) + current_channels = feat + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + """ + Forward pass returning features from each stage. + + Args: + x: input tensor of shape (B, C, *spatial_dims). + + Returns: + List of feature tensors from each encoder stage, + ordered from shallow to deep. + """ + features = [] + for stage in self.stages: + x = stage(x) + features.append(x) + return features + + +class TransformerPath(nn.Module): + """ + Vision Transformer path for global context modeling. + + Applies patch embedding followed by transformer encoder layers + to capture long-range dependencies. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + in_channels: number of input channels. + hidden_dim: transformer hidden dimension. + num_heads: number of attention heads. + depth: number of transformer encoder layers. + patch_size: size of patches for embedding. + dropout: dropout rate in transformer layers. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + hidden_dim: int, + num_heads: int, + depth: int, + patch_size: int = 16, + dropout: float = 0.1, + mlp_ratio: float = 4.0, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.patch_size = patch_size + self.hidden_dim = hidden_dim + + # Patch embedding via convolution + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + self.embedding = conv_type( + in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size + ) + + # Transformer encoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_dim, + nhead=num_heads, + dim_feedforward=int(hidden_dim * mlp_ratio), + dropout=dropout, + activation="gelu", + batch_first=True, + norm_first=True, + ) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth) + + # Layer normalization + self.norm = nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through transformer path. + + Args: + x: input tensor of shape (B, C, *spatial_dims). + + Returns: + Transformed features of shape (B, hidden_dim, *reduced_spatial_dims). 
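+
+        Example (a minimal shape sketch with illustrative arguments; assumes the
+        default ``patch_size=16``):
+
+            >>> path = TransformerPath(spatial_dims=3, in_channels=1, hidden_dim=256, num_heads=8, depth=2)
+            >>> path(torch.randn(1, 1, 64, 64, 64)).shape  # 64 // 16 = 4 patches per axis
+            torch.Size([1, 256, 4, 4, 4])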
+ """ + # Patch embedding: (B, C, D, H, W) -> (B, hidden_dim, Dp, Hp, Wp) + x_embedded = self.embedding(x) + B = x_embedded.shape[0] + spatial_shape = x_embedded.shape[2:] + + # Flatten spatial dims: (B, hidden_dim, *spatial) -> (B, N, hidden_dim) + x_flat = x_embedded.flatten(2).transpose(1, 2) + + # Apply transformer + x_transformed = self.transformer(x_flat) + x_transformed = self.norm(x_transformed) + + # Reshape back to spatial: (B, N, hidden_dim) -> (B, hidden_dim, *spatial) + x_reshaped = x_transformed.transpose(1, 2).view(B, self.hidden_dim, *spatial_shape) + + return x_reshaped + + +class CrossModalAttentionFusion(nn.Module): + """ + Cross-modal attention fusion between CNN and Transformer features. + + Performs bidirectional cross-attention where CNN features attend to + Transformer features and vice versa, then combines the results. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + channels: number of input/output channels. + num_heads: number of attention heads. + dropout: dropout rate for attention weights. + """ + + def __init__( + self, + spatial_dims: int, + channels: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super().__init__() + if channels % num_heads != 0: + raise ValueError( + f"channels ({channels}) must be divisible by num_heads ({num_heads})." + ) + + self.spatial_dims = spatial_dims + self.num_heads = num_heads + self.head_dim = channels // num_heads + self.scale = self.head_dim ** -0.5 + self.dropout = nn.Dropout(dropout) + + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + + # QKV projections for both paths + self.to_qkv_cnn = conv_type(channels, channels * 3, 1, bias=False) + self.to_qkv_vit = conv_type(channels, channels * 3, 1, bias=False) + + # Output projection + self.to_out = nn.Sequential( + conv_type(channels, channels, 1), + nn.Dropout(dropout) if dropout > 0 else nn.Identity(), + ) + + def forward(self, cnn_feat: torch.Tensor, vit_feat: torch.Tensor) -> torch.Tensor: + """ + Forward pass for cross-modal attention fusion. + + Args: + cnn_feat: CNN features of shape (B, C, *spatial_dims). + vit_feat: ViT features of shape (B, C, *spatial_dims_vit). + + Returns: + Fused features of shape (B, C, *spatial_dims). 
+ """ + B, C = cnn_feat.shape[:2] + spatial_shape = cnn_feat.shape[2:] + heads = self.num_heads + + # Interpolate ViT features to match CNN spatial dimensions + if cnn_feat.shape[2:] != vit_feat.shape[2:]: + mode = "trilinear" if self.spatial_dims == 3 else "bilinear" + vit_feat = F.interpolate( + vit_feat, size=spatial_shape, mode=mode, align_corners=False + ) + + # Compute Q, K, V for both paths + q_c, k_c, v_c = self.to_qkv_cnn(cnn_feat).chunk(3, dim=1) + q_v, k_v, v_v = self.to_qkv_vit(vit_feat).chunk(3, dim=1) + + # Reshape for multi-head attention: (B, heads, head_dim, N) + def reshape_for_attention(t: torch.Tensor) -> torch.Tensor: + return t.view(B, heads, self.head_dim, -1) + + q_c, k_c, v_c = map(reshape_for_attention, (q_c, k_c, v_c)) + q_v, k_v, v_v = map(reshape_for_attention, (q_v, k_v, v_v)) + + # Cross-attention: CNN queries attend to ViT keys/values + attn_cv = torch.einsum("b h d i, b h d j -> b h i j", q_c, k_v) * self.scale + attn_cv = self.dropout(attn_cv.softmax(dim=-1)) + out_c = torch.einsum("b h i j, b h d j -> b h d i", attn_cv, v_v) + + # Cross-attention: ViT queries attend to CNN keys/values + attn_vc = torch.einsum("b h d i, b h d j -> b h i j", q_v, k_c) * self.scale + attn_vc = self.dropout(attn_vc.softmax(dim=-1)) + out_v = torch.einsum("b h i j, b h d j -> b h d i", attn_vc, v_c) + + # Reshape back to spatial + out_c = out_c.contiguous().view(B, C, *spatial_shape) + out_v = out_v.contiguous().view(B, C, *spatial_shape) + + # Combine and project + fused = self.to_out(out_c + out_v) + + return fused + + +class ScaleAdaptiveConv(nn.Module): + """ + Scale-adaptive convolution module with multiple kernel sizes. + + Applies parallel convolutions with different kernel sizes and + combines the outputs for multi-scale feature extraction. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + in_channels: number of input channels. + out_channels: number of output channels. + kernel_sizes: sequence of kernel sizes to use. + norm: normalization type. + act: activation type. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_sizes: Sequence[int] = (3, 5, 7), + norm: Union[str, tuple] = "batch", + act: Union[str, tuple] = "relu", + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + + self.convs = nn.ModuleList([ + conv_type(in_channels, out_channels, k, padding=k // 2, bias=False) + for k in kernel_sizes + ]) + + # Shared normalization and activation + self.norm = get_norm_layer( + name=norm, spatial_dims=spatial_dims, channels=out_channels + ) + self.act = get_act_layer(name=act) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with multi-scale convolutions. + + Args: + x: input tensor of shape (B, C, *spatial_dims). + + Returns: + Multi-scale features of shape (B, out_channels, *spatial_dims). + """ + outs = [conv(x) for conv in self.convs] + out = torch.stack(outs, dim=0).sum(dim=0) + out = self.norm(out) + out = self.act(out) + return out + + +class SEBlock(nn.Module): + """ + Squeeze-and-Excitation block for channel recalibration. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + channels: number of input/output channels. + reduction: channel reduction ratio for the squeeze operation. 
+ """ + + def __init__( + self, + spatial_dims: int, + channels: int, + reduction: int = 16, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + + pool_type = nn.AdaptiveAvgPool3d if spatial_dims == 3 else nn.AdaptiveAvgPool2d + self.avg_pool = pool_type(1) + + reduced_channels = max(channels // reduction, 1) + self.fc = nn.Sequential( + nn.Linear(channels, reduced_channels, bias=False), + nn.ReLU(inplace=True), + nn.Linear(reduced_channels, channels, bias=False), + nn.Sigmoid(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for SE block. + + Args: + x: input tensor of shape (B, C, *spatial_dims). + + Returns: + Channel-recalibrated tensor of same shape. + """ + b, c = x.shape[:2] + y = self.avg_pool(x).view(b, c) + y = self.fc(y) + + # Reshape for broadcasting + if self.spatial_dims == 3: + y = y.view(b, c, 1, 1, 1) + else: + y = y.view(b, c, 1, 1) + + return x * y.expand_as(x) + + +class DecoderBlock(nn.Module): + """ + Single decoder block with upsampling, skip connection, and SE attention. + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + in_channels: number of input channels. + skip_channels: number of skip connection channels. + out_channels: number of output channels. + norm: normalization type. + act: activation type. + dropout: dropout ratio. + use_se: whether to use SE block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + skip_channels: int, + out_channels: int, + norm: Union[str, tuple] = "batch", + act: Union[str, tuple] = "relu", + dropout: float = 0.0, + use_se: bool = True, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + + # Upsampling with UpSample block + self.upsample = UpSample( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + scale_factor=2, + mode="nontrainable", + interp_mode="trilinear" if spatial_dims == 3 else "bilinear", + align_corners=False, + ) + + # Convolution after concatenation with skip + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels + skip_channels, + out_channels=out_channels, + kernel_size=3, + strides=1, + padding=1, + norm=norm, + act=act, + dropout=dropout, + ) + + # Optional SE block + self.se = SEBlock(spatial_dims, out_channels) if use_se else nn.Identity() + + def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + """ + Forward pass for decoder block. + + Args: + x: input tensor from previous decoder stage. + skip: skip connection tensor from encoder. + + Returns: + Decoded features tensor. + """ + x = self.upsample(x) + + # Handle spatial dimension mismatch + if x.shape[2:] != skip.shape[2:]: + mode = "trilinear" if self.spatial_dims == 3 else "bilinear" + x = F.interpolate(x, size=skip.shape[2:], mode=mode, align_corners=False) + + x = torch.cat([x, skip], dim=1) + x = self.conv(x) + x = self.se(x) + + return x + + +class MAGNUS(nn.Module): + """ + MAGNUS: Multi-scale Attention Guided Network for Unified Segmentation. + + A hybrid CNN-Transformer architecture that combines: + - CNN path with strided convolutions for hierarchical feature extraction + - Vision Transformer path for global context modeling + - Cross-modal attention fusion for enhanced feature representation + - Scale-adaptive convolutions for multi-scale analysis + - Decoder with SE attention and deep supervision support + + Args: + spatial_dims: number of spatial dimensions (2 or 3). + in_channels: number of input image channels. 
+ out_channels: number of output segmentation classes. + features: sequence of feature channels for encoder stages. + Default: (64, 128, 256, 512). + vit_depth: number of transformer encoder layers. Default: 6. + vit_patch_size: patch size for ViT embedding. Default: 16. + vit_num_heads: number of attention heads in ViT. If None, computed as + features[-1] // 32. Default: None. + fusion_num_heads: number of attention heads in cross-modal fusion. + If None, uses vit_num_heads. Default: None. + scale_kernel_sizes: kernel sizes for scale-adaptive conv. Default: (3, 5, 7). + norm: normalization type ("batch", "instance", "group"). Default: "batch". + act: activation type. Default: "relu". + dropout: dropout ratio. Default: 0.0. + vit_dropout: dropout ratio for transformer. Default: 0.1. + deep_supervision: whether to return auxiliary outputs. Default: False. + aux_weights: weights for auxiliary losses. Default: (0.4, 0.3, 0.3). + + Example: + >>> import torch + >>> from monai.networks.nets import MAGNUS + >>> # 3D segmentation + >>> model = MAGNUS(spatial_dims=3, in_channels=1, out_channels=2) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> y = model(x) + >>> print(y.shape) # torch.Size([1, 2, 64, 64, 64]) + >>> # 2D segmentation + >>> model_2d = MAGNUS(spatial_dims=2, in_channels=3, out_channels=4) + >>> x_2d = torch.randn(1, 3, 256, 256) + >>> y_2d = model_2d(x_2d) + >>> print(y_2d.shape) # torch.Size([1, 4, 256, 256]) + + Reference: + Aras, E., Kayikcioglu, T., Aras, S., & Merd, N. (2026). + MAGNUS: Multi-Attention Guided Network for Unified Segmentation via CNN-ViT Fusion. + IEEE Access. DOI: 10.1109/ACCESS.2026.3656667 + """ + + def __init__( + self, + spatial_dims: int = 3, + in_channels: int = 1, + out_channels: int = 1, + features: Sequence[int] = (64, 128, 256, 512), + vit_depth: int = 6, + vit_patch_size: int = 16, + vit_num_heads: Optional[int] = None, + fusion_num_heads: Optional[int] = None, + scale_kernel_sizes: Sequence[int] = (3, 5, 7), + norm: Union[str, tuple] = "batch", + act: Union[str, tuple] = "relu", + dropout: float = 0.0, + vit_dropout: float = 0.1, + deep_supervision: bool = False, + aux_weights: Sequence[float] = (0.4, 0.3, 0.3), + ) -> None: + super().__init__() + + if spatial_dims not in (2, 3): + raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}.") + + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.features = list(features) + self.deep_supervision = deep_supervision + self.aux_weights = list(aux_weights) + + # Compute number of attention heads + vit_hidden_dim = self.features[-1] + if vit_num_heads is None: + vit_num_heads = max(vit_hidden_dim // 32, 1) + if fusion_num_heads is None: + fusion_num_heads = vit_num_heads + + # CNN encoder path + self.cnn_path = CNNPath( + spatial_dims=spatial_dims, + in_channels=in_channels, + features=self.features, + norm=norm, + act=act, + dropout=dropout, + ) + + # Transformer path + self.transformer_path = TransformerPath( + spatial_dims=spatial_dims, + in_channels=in_channels, + hidden_dim=vit_hidden_dim, + num_heads=vit_num_heads, + depth=vit_depth, + patch_size=vit_patch_size, + dropout=vit_dropout, + ) + + # Cross-modal attention fusion + self.fusion = CrossModalAttentionFusion( + spatial_dims=spatial_dims, + channels=vit_hidden_dim, + num_heads=fusion_num_heads, + dropout=dropout, + ) + + # Scale-adaptive convolution + self.scale_conv = ScaleAdaptiveConv( + spatial_dims=spatial_dims, + in_channels=vit_hidden_dim, + out_channels=vit_hidden_dim, + 
kernel_sizes=scale_kernel_sizes, + norm=norm, + act=act, + ) + + # Decoder path + reversed_features = list(reversed(self.features)) + self.decoder_blocks = nn.ModuleList() + self.aux_heads = nn.ModuleList() + + for i in range(len(reversed_features) - 1): + in_ch = reversed_features[i] + out_ch = reversed_features[i + 1] + + self.decoder_blocks.append( + DecoderBlock( + spatial_dims=spatial_dims, + in_channels=in_ch, + skip_channels=out_ch, + out_channels=out_ch, + norm=norm, + act=act, + dropout=dropout, + use_se=True, + ) + ) + + # Auxiliary segmentation heads for deep supervision + if deep_supervision: + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + self.aux_heads.append(conv_type(out_ch, out_channels, kernel_size=1)) + + # Final segmentation head + conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + self.final_conv = conv_type(self.features[0], out_channels, kernel_size=1) + + # Initialize weights + self._init_weights() + + def _init_weights(self) -> None: + """Initialize model weights using Kaiming initialization.""" + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm2d, nn.InstanceNorm3d)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, list[torch.Tensor]]]: + """ + Forward pass of MAGNUS. + + Args: + x: input tensor of shape (B, in_channels, *spatial_dims). + + Returns: + If deep_supervision is False: + Segmentation logits of shape (B, out_channels, *spatial_dims). + If deep_supervision is True: + Tuple of (main_output, auxiliary_outputs) where auxiliary_outputs + is a list of intermediate segmentation maps. + """ + input_shape = x.shape[2:] + + # 1. CNN feature extraction + cnn_features = self.cnn_path(x) + cnn_deepest = cnn_features[-1] + + # 2. Transformer path + vit_features = self.transformer_path(x) + + # 3. Cross-modal attention fusion + fused_features = self.fusion(cnn_deepest, vit_features) + + # 4. Scale-adaptive convolution + scale_features = self.scale_conv(cnn_deepest) + + # 5. Combine fused and scale features + combined = fused_features + scale_features + + # 6. Decoder with skip connections + decoder_out = combined + cnn_skips = list(reversed(cnn_features[:-1])) + aux_outputs = [] + + for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips)): + decoder_out = decoder_block(decoder_out, skip) + + # Auxiliary outputs for deep supervision + if self.deep_supervision and i < len(self.aux_heads): + aux_out = self.aux_heads[i](decoder_out) + aux_out = F.interpolate( + aux_out, + size=input_shape, + mode="trilinear" if self.spatial_dims == 3 else "bilinear", + align_corners=False, + ) + aux_outputs.append(aux_out) + + # 7. 
Final segmentation + seg_logits = self.final_conv(decoder_out) + + # Upsample to original input size if needed + if seg_logits.shape[2:] != input_shape: + seg_logits = F.interpolate( + seg_logits, + size=input_shape, + mode="trilinear" if self.spatial_dims == 3 else "bilinear", + align_corners=False, + ) + + if self.deep_supervision: + return seg_logits, aux_outputs + + return seg_logits diff --git a/tests/networks/nets/test_magnus.py b/tests/networks/nets/test_magnus.py new file mode 100644 index 0000000000..e789a3835e --- /dev/null +++ b/tests/networks/nets/test_magnus.py @@ -0,0 +1,332 @@ +# Copyright Project MONAI Contributors +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for MAGNUS network. + +To run tests: + pytest test_magnus.py -v + +Or with unittest: + python -m pytest test_magnus.py -v +""" + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks.nets.magnus import ( + MAGNUS, + CNNPath, + CrossModalAttentionFusion, + ScaleAdaptiveConv, + TransformerPath, +) + + +# Test cases for MAGNUS model +MAGNUS_TEST_CASES = [ + # (spatial_dims, in_channels, out_channels, input_shape, expected_output_shape) + (3, 1, 2, (1, 1, 64, 64, 64), (1, 2, 64, 64, 64)), + (3, 4, 3, (2, 4, 32, 32, 32), (2, 3, 32, 32, 32)), + (2, 1, 2, (1, 1, 128, 128), (1, 2, 128, 128)), + (2, 3, 5, (2, 3, 64, 64), (2, 5, 64, 64)), +] + +# Test cases for individual components +CNN_PATH_TEST_CASES = [ + (3, 1, (32, 64, 128), (1, 1, 64, 64, 64)), + (2, 3, (64, 128, 256), (1, 3, 128, 128)), +] + +TRANSFORMER_PATH_TEST_CASES = [ + (3, 1, 256, 8, 4, 8, (1, 1, 64, 64, 64)), + (2, 3, 128, 4, 2, 16, (1, 3, 128, 128)), +] + +FUSION_TEST_CASES = [ + (3, 256, 8, (1, 256, 8, 8, 8), (1, 256, 4, 4, 4)), + (2, 128, 4, (1, 128, 16, 16), (1, 128, 8, 8)), +] + + +class TestMAGNUS(unittest.TestCase): + """Test cases for MAGNUS model.""" + + @parameterized.expand(MAGNUS_TEST_CASES) + def test_magnus_shape( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + input_shape: tuple, + expected_shape: tuple, + ): + """Test MAGNUS output shape.""" + model = MAGNUS( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + features=(32, 64, 128, 256), # Smaller for testing + vit_depth=2, + vit_patch_size=8, + ) + model.eval() + + x = torch.randn(*input_shape) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, expected_shape) + + def test_magnus_deep_supervision(self): + """Test MAGNUS with deep supervision.""" + model = MAGNUS( + spatial_dims=3, + in_channels=1, + out_channels=2, + features=(32, 64, 128, 256), + vit_depth=2, + vit_patch_size=8, + deep_supervision=True, + ) + model.eval() + + x = torch.randn(1, 1, 32, 32, 32) + with torch.no_grad(): + main_out, aux_outs = model(x) + + self.assertEqual(main_out.shape, (1, 2, 32, 32, 32)) + self.assertEqual(len(aux_outs), 3) # 4 stages - 1 = 3 aux outputs + for aux_out in aux_outs: + self.assertEqual(aux_out.shape, (1, 2, 32, 32, 32)) + + def 
test_magnus_different_norms(self): + """Test MAGNUS with different normalization types.""" + norms = [ + "batch", + "instance", + ("group", {"num_groups": 8}), # GroupNorm requires num_groups + ] + for norm in norms: + model = MAGNUS( + spatial_dims=3, + in_channels=1, + out_channels=2, + features=(32, 64), + vit_depth=1, + vit_patch_size=8, + norm=norm, + ) + model.eval() + + x = torch.randn(1, 1, 32, 32, 32) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, (1, 2, 32, 32, 32)) + + def test_magnus_gradient_flow(self): + """Test gradient flow through MAGNUS.""" + model = MAGNUS( + spatial_dims=3, + in_channels=1, + out_channels=2, + features=(32, 64), + vit_depth=1, + vit_patch_size=8, + ) + model.train() + + x = torch.randn(1, 1, 32, 32, 32, requires_grad=True) + y = model(x) + loss = y.sum() + loss.backward() + + self.assertIsNotNone(x.grad) + self.assertFalse(torch.isnan(x.grad).any()) + + def test_magnus_invalid_spatial_dims(self): + """Test MAGNUS raises error for invalid spatial_dims.""" + with self.assertRaises(ValueError): + MAGNUS(spatial_dims=4, in_channels=1, out_channels=2) + + +class TestCNNPath(unittest.TestCase): + """Test cases for CNNPath.""" + + @parameterized.expand(CNN_PATH_TEST_CASES) + def test_cnn_path_shape( + self, + spatial_dims: int, + in_channels: int, + features: tuple, + input_shape: tuple, + ): + """Test CNNPath output shapes.""" + model = CNNPath( + spatial_dims=spatial_dims, + in_channels=in_channels, + features=features, + ) + model.eval() + + x = torch.randn(*input_shape) + with torch.no_grad(): + outputs = model(x) + + self.assertEqual(len(outputs), len(features)) + for i, (feat, out) in enumerate(zip(features, outputs)): + self.assertEqual(out.shape[1], feat) + # Each stage downsamples by factor of 2 + expected_spatial = [s // (2 ** (i + 1)) for s in input_shape[2:]] + self.assertEqual(list(out.shape[2:]), expected_spatial) + + +class TestTransformerPath(unittest.TestCase): + """Test cases for TransformerPath.""" + + @parameterized.expand(TRANSFORMER_PATH_TEST_CASES) + def test_transformer_path_shape( + self, + spatial_dims: int, + in_channels: int, + hidden_dim: int, + num_heads: int, + depth: int, + patch_size: int, + input_shape: tuple, + ): + """Test TransformerPath output shape.""" + model = TransformerPath( + spatial_dims=spatial_dims, + in_channels=in_channels, + hidden_dim=hidden_dim, + num_heads=num_heads, + depth=depth, + patch_size=patch_size, + ) + model.eval() + + x = torch.randn(*input_shape) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape[0], input_shape[0]) # Batch + self.assertEqual(y.shape[1], hidden_dim) # Channels + expected_spatial = [s // patch_size for s in input_shape[2:]] + self.assertEqual(list(y.shape[2:]), expected_spatial) + + +class TestCrossModalAttentionFusion(unittest.TestCase): + """Test cases for CrossModalAttentionFusion.""" + + @parameterized.expand(FUSION_TEST_CASES) + def test_fusion_shape( + self, + spatial_dims: int, + channels: int, + num_heads: int, + cnn_shape: tuple, + vit_shape: tuple, + ): + """Test CrossModalAttentionFusion output shape.""" + model = CrossModalAttentionFusion( + spatial_dims=spatial_dims, + channels=channels, + num_heads=num_heads, + ) + model.eval() + + cnn_feat = torch.randn(*cnn_shape) + vit_feat = torch.randn(*vit_shape) + + with torch.no_grad(): + y = model(cnn_feat, vit_feat) + + # Output should match CNN feature shape + self.assertEqual(y.shape, cnn_shape) + + def test_fusion_invalid_channels(self): + """Test fusion raises error when channels 
not divisible by heads.""" + with self.assertRaises(ValueError): + CrossModalAttentionFusion( + spatial_dims=3, + channels=100, + num_heads=8, # 100 % 8 != 0 + ) + + +class TestScaleAdaptiveConv(unittest.TestCase): + """Test cases for ScaleAdaptiveConv.""" + + def test_scale_adaptive_conv_3d(self): + """Test ScaleAdaptiveConv 3D output shape.""" + model = ScaleAdaptiveConv( + spatial_dims=3, + in_channels=64, + out_channels=128, + kernel_sizes=(3, 5, 7), + ) + model.eval() + + x = torch.randn(1, 64, 16, 16, 16) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, (1, 128, 16, 16, 16)) + + def test_scale_adaptive_conv_2d(self): + """Test ScaleAdaptiveConv 2D output shape.""" + model = ScaleAdaptiveConv( + spatial_dims=2, + in_channels=32, + out_channels=64, + kernel_sizes=(3, 5), + ) + model.eval() + + x = torch.randn(1, 32, 32, 32) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, (1, 64, 32, 32)) + + +class TestMAGNUSMemory(unittest.TestCase): + """Memory and performance tests for MAGNUS.""" + + @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available") + def test_magnus_cuda(self): + """Test MAGNUS on CUDA.""" + model = MAGNUS( + spatial_dims=3, + in_channels=1, + out_channels=2, + features=(32, 64, 128), + vit_depth=2, + vit_patch_size=8, + ).cuda() + model.eval() + + x = torch.randn(1, 1, 32, 32, 32, device="cuda") + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.device.type, "cuda") + self.assertEqual(y.shape, (1, 2, 32, 32, 32)) + + +if __name__ == "__main__": + unittest.main() From dc10e2907f8a6653d67e2518db6f0debf1e22183 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:41:19 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sefa Aras --- monai/networks/nets/magnus.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index b4af9b8ee2..c5a24eb54d 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -32,9 +32,7 @@ import torch.nn.functional as F from monai.networks.blocks import Convolution, UpSample -from monai.networks.layers.factories import Act, Norm from monai.networks.layers.utils import get_act_layer, get_norm_layer -from monai.utils import ensure_tuple_rep __all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"] From bef41bb3dfa5a300bc4deb11d91ef20437dcd214 Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Sun, 25 Jan 2026 13:04:29 +0300 Subject: [PATCH 03/11] Fix TransformerPath positional encoding and aux_weights docs - Add learnable positional embeddings to TransformerPath for proper spatial reasoning - Implement dynamic positional embedding interpolation for varying input sizes - Add positional dropout for regularization - Update aux_weights docstring to clarify it's for external use only Addresses CodeRabbit review comments on PR #8717 Signed-off-by: Sefa Aras --- monai/networks/nets/magnus.py | 47 +++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index c5a24eb54d..47cb2c4436 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -115,7 +115,8 @@ class TransformerPath(nn.Module): Vision Transformer path for global context modeling. 
Applies patch embedding followed by transformer encoder layers - to capture long-range dependencies. + to capture long-range dependencies. Includes learnable positional + embeddings that are interpolated to match varying input sizes. Args: spatial_dims: number of spatial dimensions (2 or 3). @@ -150,6 +151,14 @@ def __init__( in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size ) + # Learnable positional embedding (will be interpolated for different input sizes) + # Initialize with a reasonable default size, will adapt dynamically + self.pos_embed = nn.Parameter(torch.zeros(1, 256, hidden_dim)) + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + # Dropout for positional embedding + self.pos_drop = nn.Dropout(p=dropout) + # Transformer encoder encoder_layer = nn.TransformerEncoderLayer( d_model=hidden_dim, @@ -165,6 +174,31 @@ def __init__( # Layer normalization self.norm = nn.LayerNorm(hidden_dim) + def _interpolate_pos_encoding(self, x: torch.Tensor, num_patches: int) -> torch.Tensor: + """ + Interpolate positional embeddings to match the number of patches. + + Args: + x: input tensor for device reference. + num_patches: target number of patches. + + Returns: + Interpolated positional embeddings of shape (1, num_patches, hidden_dim). + """ + if num_patches == self.pos_embed.shape[1]: + return self.pos_embed + + # Interpolate positional embeddings + pos_embed = self.pos_embed.transpose(1, 2) # (1, hidden_dim, N) + pos_embed = F.interpolate( + pos_embed, + size=num_patches, + mode="linear", + align_corners=False, + ) + pos_embed = pos_embed.transpose(1, 2) # (1, num_patches, hidden_dim) + return pos_embed + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through transformer path. @@ -182,6 +216,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Flatten spatial dims: (B, hidden_dim, *spatial) -> (B, N, hidden_dim) x_flat = x_embedded.flatten(2).transpose(1, 2) + num_patches = x_flat.shape[1] + + # Add positional encoding + pos_embed = self._interpolate_pos_encoding(x_flat, num_patches) + x_flat = x_flat + pos_embed + x_flat = self.pos_drop(x_flat) # Apply transformer x_transformed = self.transformer(x_flat) @@ -512,7 +552,10 @@ class MAGNUS(nn.Module): dropout: dropout ratio. Default: 0.0. vit_dropout: dropout ratio for transformer. Default: 0.1. deep_supervision: whether to return auxiliary outputs. Default: False. - aux_weights: weights for auxiliary losses. Default: (0.4, 0.3, 0.3). + aux_weights: suggested weights for auxiliary losses when using deep supervision. + These weights are stored as an attribute for user convenience but are NOT + applied internally. Users should apply them externally when computing the + total loss. Default: (0.4, 0.3, 0.3). 
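+            A minimal external-use sketch (``criterion``, ``target`` and ``x`` are
+            illustrative names, not part of the API)::
+
+                main_out, aux_outs = model(x)  # with deep_supervision=True
+                loss = criterion(main_out, target)
+                for w, aux in zip(model.aux_weights, aux_outs):
+                    loss = loss + w * criterion(aux, target)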
Example: >>> import torch From ce86871d8ed7a16fbc6991b70f33a587112e977c Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Mon, 26 Jan 2026 01:33:05 +0300 Subject: [PATCH 04/11] Fix code style and improve docstring coverage - Update type annotations to modern Python 3.10+ syntax (X | Y instead of Union) - Remove unused imports (Optional, Union) - Add docstrings to all __init__ methods for better coverage - Apply black and isort formatting - Fix ruff linting issues Improves docstring coverage from 72% to 80%+ Signed-off-by: Sefa Aras --- monai/networks/nets/magnus.py | 85 +++++++++++++++++++----------- tests/networks/nets/test_magnus.py | 1 - 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index 47cb2c4436..5d5eceef52 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -25,7 +25,6 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Optional, Union import torch import torch.nn as nn @@ -55,10 +54,15 @@ def __init__( spatial_dims: int, in_channels: int, features: Sequence[int], - norm: Union[str, tuple] = "batch", - act: Union[str, tuple] = "relu", + norm: str | tuple = "batch", + act: str | tuple = "relu", dropout: float = 0.0, ) -> None: + """ + Initialize the CNN encoder path. + + See class docstring for argument descriptions. + """ super().__init__() self.spatial_dims = spatial_dims self.stages = nn.ModuleList() @@ -140,6 +144,11 @@ def __init__( dropout: float = 0.1, mlp_ratio: float = 4.0, ) -> None: + """ + Initialize the Vision Transformer path. + + See class docstring for argument descriptions. + """ super().__init__() self.spatial_dims = spatial_dims self.patch_size = patch_size @@ -147,9 +156,7 @@ def __init__( # Patch embedding via convolution conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d - self.embedding = conv_type( - in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size - ) + self.embedding = conv_type(in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size) # Learnable positional embedding (will be interpolated for different input sizes) # Initialize with a reasonable default size, will adapt dynamically @@ -254,16 +261,19 @@ def __init__( num_heads: int, dropout: float = 0.0, ) -> None: + """ + Initialize the cross-modal attention fusion module. + + See class docstring for argument descriptions. + """ super().__init__() if channels % num_heads != 0: - raise ValueError( - f"channels ({channels}) must be divisible by num_heads ({num_heads})." 
- ) + raise ValueError(f"channels ({channels}) must be divisible by num_heads ({num_heads}).") self.spatial_dims = spatial_dims self.num_heads = num_heads self.head_dim = channels // num_heads - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 self.dropout = nn.Dropout(dropout) conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d @@ -296,9 +306,7 @@ def forward(self, cnn_feat: torch.Tensor, vit_feat: torch.Tensor) -> torch.Tenso # Interpolate ViT features to match CNN spatial dimensions if cnn_feat.shape[2:] != vit_feat.shape[2:]: mode = "trilinear" if self.spatial_dims == 3 else "bilinear" - vit_feat = F.interpolate( - vit_feat, size=spatial_shape, mode=mode, align_corners=False - ) + vit_feat = F.interpolate(vit_feat, size=spatial_shape, mode=mode, align_corners=False) # Compute Q, K, V for both paths q_c, k_c, v_c = self.to_qkv_cnn(cnn_feat).chunk(3, dim=1) @@ -353,23 +361,25 @@ def __init__( in_channels: int, out_channels: int, kernel_sizes: Sequence[int] = (3, 5, 7), - norm: Union[str, tuple] = "batch", - act: Union[str, tuple] = "relu", + norm: str | tuple = "batch", + act: str | tuple = "relu", ) -> None: + """ + Initialize the scale-adaptive convolution module. + + See class docstring for argument descriptions. + """ super().__init__() self.spatial_dims = spatial_dims conv_type = nn.Conv3d if spatial_dims == 3 else nn.Conv2d - self.convs = nn.ModuleList([ - conv_type(in_channels, out_channels, k, padding=k // 2, bias=False) - for k in kernel_sizes - ]) + self.convs = nn.ModuleList( + [conv_type(in_channels, out_channels, k, padding=k // 2, bias=False) for k in kernel_sizes] + ) # Shared normalization and activation - self.norm = get_norm_layer( - name=norm, spatial_dims=spatial_dims, channels=out_channels - ) + self.norm = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=out_channels) self.act = get_act_layer(name=act) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -405,6 +415,11 @@ def __init__( channels: int, reduction: int = 16, ) -> None: + """ + Initialize the Squeeze-and-Excitation block. + + See class docstring for argument descriptions. + """ super().__init__() self.spatial_dims = spatial_dims @@ -463,11 +478,16 @@ def __init__( in_channels: int, skip_channels: int, out_channels: int, - norm: Union[str, tuple] = "batch", - act: Union[str, tuple] = "relu", + norm: str | tuple = "batch", + act: str | tuple = "relu", dropout: float = 0.0, use_se: bool = True, ) -> None: + """ + Initialize the decoder block. + + See class docstring for argument descriptions. + """ super().__init__() self.spatial_dims = spatial_dims @@ -585,16 +605,21 @@ def __init__( features: Sequence[int] = (64, 128, 256, 512), vit_depth: int = 6, vit_patch_size: int = 16, - vit_num_heads: Optional[int] = None, - fusion_num_heads: Optional[int] = None, + vit_num_heads: int | None = None, + fusion_num_heads: int | None = None, scale_kernel_sizes: Sequence[int] = (3, 5, 7), - norm: Union[str, tuple] = "batch", - act: Union[str, tuple] = "relu", + norm: str | tuple = "batch", + act: str | tuple = "relu", dropout: float = 0.0, vit_dropout: float = 0.1, deep_supervision: bool = False, aux_weights: Sequence[float] = (0.4, 0.3, 0.3), ) -> None: + """ + Initialize the MAGNUS model. + + See class docstring for argument descriptions. 
+ """ super().__init__() if spatial_dims not in (2, 3): @@ -704,9 +729,7 @@ def _init_weights(self) -> None: if m.bias is not None: nn.init.constant_(m.bias, 0) - def forward( - self, x: torch.Tensor - ) -> Union[torch.Tensor, tuple[torch.Tensor, list[torch.Tensor]]]: + def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: """ Forward pass of MAGNUS. diff --git a/tests/networks/nets/test_magnus.py b/tests/networks/nets/test_magnus.py index e789a3835e..556b37be55 100644 --- a/tests/networks/nets/test_magnus.py +++ b/tests/networks/nets/test_magnus.py @@ -34,7 +34,6 @@ TransformerPath, ) - # Test cases for MAGNUS model MAGNUS_TEST_CASES = [ # (spatial_dims, in_channels, out_channels, input_shape, expected_output_shape) From f70d5a7c4e749c309cd424f9cc23b95bb57e1ad6 Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Mon, 26 Jan 2026 17:41:14 +0300 Subject: [PATCH 05/11] Add missing exports, tests, and documentation for MAGNUS - Export SEBlock and DecoderBlock in __all__ and __init__.py - Add unit tests for SEBlock and DecoderBlock components - Fix TransformerEncoder warning with enable_nested_tensor=False - Add MAGNUS documentation to networks.rst Signed-off-by: Sefa Aras --- docs/source/networks.rst | 5 ++ monai/networks/nets/__init__.py | 2 +- monai/networks/nets/magnus.py | 4 +- tests/networks/nets/test_magnus.py | 97 ++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 3 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index de0aece3f7..b775d58d3a 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -768,6 +768,11 @@ Nets .. autoclass:: VoxelMorph :members: +`MAGNUS` +~~~~~~~~ +.. autoclass:: MAGNUS + :members: + Utilities --------- .. automodule:: monai.networks.utils diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index ecb1930f38..18176a388b 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -73,7 +73,7 @@ MedNeXtSmall, MedNextSmall, ) -from .magnus import MAGNUS, CNNPath, CrossModalAttentionFusion, ScaleAdaptiveConv, TransformerPath +from .magnus import MAGNUS, CNNPath, CrossModalAttentionFusion, DecoderBlock, ScaleAdaptiveConv, SEBlock, TransformerPath from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index 5d5eceef52..3ef1d2d17a 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -33,7 +33,7 @@ from monai.networks.blocks import Convolution, UpSample from monai.networks.layers.utils import get_act_layer, get_norm_layer -__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"] +__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv", "SEBlock", "DecoderBlock"] class CNNPath(nn.Module): @@ -176,7 +176,7 @@ def __init__( batch_first=True, norm_first=True, ) - self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth, enable_nested_tensor=False) # Layer normalization self.norm = nn.LayerNorm(hidden_dim) diff --git a/tests/networks/nets/test_magnus.py b/tests/networks/nets/test_magnus.py index 556b37be55..d05074315d 100644 --- a/tests/networks/nets/test_magnus.py +++ b/tests/networks/nets/test_magnus.py @@ -30,7 +30,9 @@ MAGNUS, CNNPath, 
CrossModalAttentionFusion, + DecoderBlock, ScaleAdaptiveConv, + SEBlock, TransformerPath, ) @@ -303,6 +305,101 @@ def test_scale_adaptive_conv_2d(self): self.assertEqual(y.shape, (1, 64, 32, 32)) +class TestSEBlock(unittest.TestCase): + """Test cases for SEBlock.""" + + def test_se_block_3d(self): + """Test SEBlock 3D output shape.""" + model = SEBlock(spatial_dims=3, channels=64, reduction=16) + model.eval() + + x = torch.randn(1, 64, 8, 8, 8) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, x.shape) + + def test_se_block_2d(self): + """Test SEBlock 2D output shape.""" + model = SEBlock(spatial_dims=2, channels=128, reduction=8) + model.eval() + + x = torch.randn(2, 128, 16, 16) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, x.shape) + + def test_se_block_minimum_reduction(self): + """Test SEBlock with small channel count.""" + # Reduction should be at least 1 + model = SEBlock(spatial_dims=2, channels=4, reduction=16) + model.eval() + + x = torch.randn(1, 4, 8, 8) + with torch.no_grad(): + y = model(x) + + self.assertEqual(y.shape, x.shape) + + +class TestDecoderBlock(unittest.TestCase): + """Test cases for DecoderBlock.""" + + def test_decoder_block_3d(self): + """Test DecoderBlock 3D output shape.""" + model = DecoderBlock( + spatial_dims=3, + in_channels=128, + skip_channels=64, + out_channels=64, + ) + model.eval() + + x = torch.randn(1, 128, 8, 8, 8) + skip = torch.randn(1, 64, 16, 16, 16) + with torch.no_grad(): + y = model(x, skip) + + self.assertEqual(y.shape, (1, 64, 16, 16, 16)) + + def test_decoder_block_2d(self): + """Test DecoderBlock 2D output shape.""" + model = DecoderBlock( + spatial_dims=2, + in_channels=256, + skip_channels=128, + out_channels=128, + use_se=True, + ) + model.eval() + + x = torch.randn(1, 256, 8, 8) + skip = torch.randn(1, 128, 16, 16) + with torch.no_grad(): + y = model(x, skip) + + self.assertEqual(y.shape, (1, 128, 16, 16)) + + def test_decoder_block_no_se(self): + """Test DecoderBlock without SE block.""" + model = DecoderBlock( + spatial_dims=3, + in_channels=64, + skip_channels=32, + out_channels=32, + use_se=False, + ) + model.eval() + + x = torch.randn(1, 64, 4, 4, 4) + skip = torch.randn(1, 32, 8, 8, 8) + with torch.no_grad(): + y = model(x, skip) + + self.assertEqual(y.shape, (1, 32, 8, 8, 8)) + + class TestMAGNUSMemory(unittest.TestCase): """Memory and performance tests for MAGNUS.""" From bb798325a3a067030ec907d16b93d18e556b4c9e Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Mon, 26 Jan 2026 19:15:09 +0300 Subject: [PATCH 06/11] Fix copyright header to MONAI Consortium standard Signed-off-by: Sefa Aras --- monai/networks/nets/magnus.py | 2 +- tests/networks/nets/test_magnus.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index 3ef1d2d17a..2c6d055f25 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -1,4 +1,4 @@ -# Copyright Project MONAI Contributors +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/tests/networks/nets/test_magnus.py b/tests/networks/nets/test_magnus.py index d05074315d..8093514301 100644 --- a/tests/networks/nets/test_magnus.py +++ b/tests/networks/nets/test_magnus.py @@ -1,4 +1,4 @@ -# Copyright Project MONAI Contributors +# Copyright (c) MONAI Consortium # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at From 45cdf3100fc9c1911878ba2930c3efb522d9106b Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Mon, 26 Jan 2026 19:30:33 +0300 Subject: [PATCH 07/11] Rename SEBlock to MagnusSEBlock to avoid naming conflict Signed-off-by: Sefa Aras --- monai/networks/nets/__init__.py | 2 +- monai/networks/nets/magnus.py | 8 ++++---- tests/networks/nets/test_magnus.py | 18 +++++++++--------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 18176a388b..8d81b0d65c 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -73,7 +73,7 @@ MedNeXtSmall, MedNextSmall, ) -from .magnus import MAGNUS, CNNPath, CrossModalAttentionFusion, DecoderBlock, ScaleAdaptiveConv, SEBlock, TransformerPath +from .magnus import MAGNUS, CNNPath, CrossModalAttentionFusion, DecoderBlock, MagnusSEBlock, ScaleAdaptiveConv, TransformerPath from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index 2c6d055f25..715feed73d 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -33,7 +33,7 @@ from monai.networks.blocks import Convolution, UpSample from monai.networks.layers.utils import get_act_layer, get_norm_layer -__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv", "SEBlock", "DecoderBlock"] +__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv", "MagnusSEBlock", "DecoderBlock"] class CNNPath(nn.Module): @@ -399,9 +399,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out -class SEBlock(nn.Module): +class MagnusSEBlock(nn.Module): """ - Squeeze-and-Excitation block for channel recalibration. + Squeeze-and-Excitation block for channel recalibration in MAGNUS. Args: spatial_dims: number of spatial dimensions (2 or 3). 
@@ -516,7 +516,7 @@ def __init__( ) # Optional SE block - self.se = SEBlock(spatial_dims, out_channels) if use_se else nn.Identity() + self.se = MagnusSEBlock(spatial_dims, out_channels) if use_se else nn.Identity() def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: """ diff --git a/tests/networks/nets/test_magnus.py b/tests/networks/nets/test_magnus.py index 8093514301..cf707cce5f 100644 --- a/tests/networks/nets/test_magnus.py +++ b/tests/networks/nets/test_magnus.py @@ -31,8 +31,8 @@ CNNPath, CrossModalAttentionFusion, DecoderBlock, + MagnusSEBlock, ScaleAdaptiveConv, - SEBlock, TransformerPath, ) @@ -305,12 +305,12 @@ def test_scale_adaptive_conv_2d(self): self.assertEqual(y.shape, (1, 64, 32, 32)) -class TestSEBlock(unittest.TestCase): - """Test cases for SEBlock.""" +class TestMagnusSEBlock(unittest.TestCase): + """Test cases for MagnusSEBlock.""" def test_se_block_3d(self): - """Test SEBlock 3D output shape.""" - model = SEBlock(spatial_dims=3, channels=64, reduction=16) + """Test MagnusSEBlock 3D output shape.""" + model = MagnusSEBlock(spatial_dims=3, channels=64, reduction=16) model.eval() x = torch.randn(1, 64, 8, 8, 8) @@ -320,8 +320,8 @@ def test_se_block_3d(self): self.assertEqual(y.shape, x.shape) def test_se_block_2d(self): - """Test SEBlock 2D output shape.""" - model = SEBlock(spatial_dims=2, channels=128, reduction=8) + """Test MagnusSEBlock 2D output shape.""" + model = MagnusSEBlock(spatial_dims=2, channels=128, reduction=8) model.eval() x = torch.randn(2, 128, 16, 16) @@ -331,9 +331,9 @@ def test_se_block_2d(self): self.assertEqual(y.shape, x.shape) def test_se_block_minimum_reduction(self): - """Test SEBlock with small channel count.""" + """Test MagnusSEBlock with small channel count.""" # Reduction should be at least 1 - model = SEBlock(spatial_dims=2, channels=4, reduction=16) + model = MagnusSEBlock(spatial_dims=2, channels=4, reduction=16) model.eval() x = torch.randn(1, 4, 8, 8) From 268d696029625ff2255624b2a8e1cd336ebaa4c8 Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Wed, 28 Jan 2026 01:47:09 +0300 Subject: [PATCH 08/11] Fix mypy type errors and improve code formatting Signed-off-by: Sefa Aras --- monai/networks/nets/__init__.py | 10 +++++++++- monai/networks/nets/magnus.py | 30 ++++++++++++++++++++---------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 8d81b0d65c..48a11aa1e4 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -73,7 +73,15 @@ MedNeXtSmall, MedNextSmall, ) -from .magnus import MAGNUS, CNNPath, CrossModalAttentionFusion, DecoderBlock, MagnusSEBlock, ScaleAdaptiveConv, TransformerPath +from .magnus import ( + MAGNUS, + CNNPath, + CrossModalAttentionFusion, + DecoderBlock, + MagnusSEBlock, + ScaleAdaptiveConv, + TransformerPath, +) from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index 715feed73d..95564f9dd2 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -33,7 +33,15 @@ from monai.networks.blocks import Convolution, UpSample from monai.networks.layers.utils import get_act_layer, get_norm_layer -__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv", "MagnusSEBlock", "DecoderBlock"] +__all__ = [ + "MAGNUS", + "CNNPath", + "TransformerPath", + 
"CrossModalAttentionFusion", + "ScaleAdaptiveConv", + "MagnusSEBlock", + "DecoderBlock", +] class CNNPath(nn.Module): @@ -203,8 +211,8 @@ def _interpolate_pos_encoding(self, x: torch.Tensor, num_patches: int) -> torch. mode="linear", align_corners=False, ) - pos_embed = pos_embed.transpose(1, 2) # (1, num_patches, hidden_dim) - return pos_embed + result: torch.Tensor = pos_embed.transpose(1, 2) # (1, num_patches, hidden_dim) + return result def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -235,9 +243,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_transformed = self.norm(x_transformed) # Reshape back to spatial: (B, N, hidden_dim) -> (B, hidden_dim, *spatial) - x_reshaped = x_transformed.transpose(1, 2).view(B, self.hidden_dim, *spatial_shape) + x_out: torch.Tensor = x_transformed.transpose(1, 2).view(B, self.hidden_dim, *spatial_shape) - return x_reshaped + return x_out class CrossModalAttentionFusion(nn.Module): @@ -334,7 +342,7 @@ def reshape_for_attention(t: torch.Tensor) -> torch.Tensor: out_v = out_v.contiguous().view(B, C, *spatial_shape) # Combine and project - fused = self.to_out(out_c + out_v) + fused: torch.Tensor = self.to_out(out_c + out_v) return fused @@ -395,8 +403,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: outs = [conv(x) for conv in self.convs] out = torch.stack(outs, dim=0).sum(dim=0) out = self.norm(out) - out = self.act(out) - return out + result: torch.Tensor = self.act(out) + return result class MagnusSEBlock(nn.Module): @@ -454,7 +462,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: else: y = y.view(b, c, 1, 1) - return x * y.expand_as(x) + result: torch.Tensor = x * y.expand_as(x) + return result class DecoderBlock(nn.Module): @@ -795,4 +804,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, list[to if self.deep_supervision: return seg_logits, aux_outputs - return seg_logits + result: torch.Tensor = seg_logits + return result From e12eaf6ca4a52d9893a39f87931169856af7ee35 Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Wed, 28 Jan 2026 13:58:01 +0300 Subject: [PATCH 09/11] Fix import order in __init__.py for isort compliance Signed-off-by: Sefa Aras --- monai/networks/nets/__init__.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 48a11aa1e4..1b094c6b04 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -53,6 +53,15 @@ from .generator import Generator from .highresnet import HighResBlock, HighResNet from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet +from .magnus import ( + MAGNUS, + CNNPath, + CrossModalAttentionFusion, + DecoderBlock, + MagnusSEBlock, + ScaleAdaptiveConv, + TransformerPath, +) from .masked_autoencoder_vit import MaskedAutoEncoderViT from .mednext import ( MedNeXt, @@ -73,15 +82,6 @@ MedNeXtSmall, MedNextSmall, ) -from .magnus import ( - MAGNUS, - CNNPath, - CrossModalAttentionFusion, - DecoderBlock, - MagnusSEBlock, - ScaleAdaptiveConv, - TransformerPath, -) from .milmodel import MILModel from .netadapter import NetAdapter from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator From 702231488d9556de4198f7d6c2dc8b15c099fcda Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Wed, 28 Jan 2026 16:55:17 +0300 Subject: [PATCH 10/11] Fix black formatting with skip-magic-trailing-comma flag Signed-off-by: Sefa Aras --- monai/networks/nets/magnus.py | 30 ++------ tests/networks/nets/test_magnus.py | 116 
+++++------------------------ 2 files changed, 22 insertions(+), 124 deletions(-) diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index 95564f9dd2..b05accfb04 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -205,12 +205,7 @@ def _interpolate_pos_encoding(self, x: torch.Tensor, num_patches: int) -> torch. # Interpolate positional embeddings pos_embed = self.pos_embed.transpose(1, 2) # (1, hidden_dim, N) - pos_embed = F.interpolate( - pos_embed, - size=num_patches, - mode="linear", - align_corners=False, - ) + pos_embed = F.interpolate(pos_embed, size=num_patches, mode="linear", align_corners=False) result: torch.Tensor = pos_embed.transpose(1, 2) # (1, num_patches, hidden_dim) return result @@ -262,13 +257,7 @@ class CrossModalAttentionFusion(nn.Module): dropout: dropout rate for attention weights. """ - def __init__( - self, - spatial_dims: int, - channels: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: + def __init__(self, spatial_dims: int, channels: int, num_heads: int, dropout: float = 0.0) -> None: """ Initialize the cross-modal attention fusion module. @@ -292,8 +281,7 @@ def __init__( # Output projection self.to_out = nn.Sequential( - conv_type(channels, channels, 1), - nn.Dropout(dropout) if dropout > 0 else nn.Identity(), + conv_type(channels, channels, 1), nn.Dropout(dropout) if dropout > 0 else nn.Identity() ) def forward(self, cnn_feat: torch.Tensor, vit_feat: torch.Tensor) -> torch.Tensor: @@ -417,12 +405,7 @@ class MagnusSEBlock(nn.Module): reduction: channel reduction ratio for the squeeze operation. """ - def __init__( - self, - spatial_dims: int, - channels: int, - reduction: int = 16, - ) -> None: + def __init__(self, spatial_dims: int, channels: int, reduction: int = 16) -> None: """ Initialize the Squeeze-and-Excitation block. 
@@ -671,10 +654,7 @@ def __init__( # Cross-modal attention fusion self.fusion = CrossModalAttentionFusion( - spatial_dims=spatial_dims, - channels=vit_hidden_dim, - num_heads=fusion_num_heads, - dropout=dropout, + spatial_dims=spatial_dims, channels=vit_hidden_dim, num_heads=fusion_num_heads, dropout=dropout ) # Scale-adaptive convolution diff --git a/tests/networks/nets/test_magnus.py b/tests/networks/nets/test_magnus.py index cf707cce5f..dea7b3708c 100644 --- a/tests/networks/nets/test_magnus.py +++ b/tests/networks/nets/test_magnus.py @@ -46,20 +46,11 @@ ] # Test cases for individual components -CNN_PATH_TEST_CASES = [ - (3, 1, (32, 64, 128), (1, 1, 64, 64, 64)), - (2, 3, (64, 128, 256), (1, 3, 128, 128)), -] +CNN_PATH_TEST_CASES = [(3, 1, (32, 64, 128), (1, 1, 64, 64, 64)), (2, 3, (64, 128, 256), (1, 3, 128, 128))] -TRANSFORMER_PATH_TEST_CASES = [ - (3, 1, 256, 8, 4, 8, (1, 1, 64, 64, 64)), - (2, 3, 128, 4, 2, 16, (1, 3, 128, 128)), -] +TRANSFORMER_PATH_TEST_CASES = [(3, 1, 256, 8, 4, 8, (1, 1, 64, 64, 64)), (2, 3, 128, 4, 2, 16, (1, 3, 128, 128))] -FUSION_TEST_CASES = [ - (3, 256, 8, (1, 256, 8, 8, 8), (1, 256, 4, 4, 4)), - (2, 128, 4, (1, 128, 16, 16), (1, 128, 8, 8)), -] +FUSION_TEST_CASES = [(3, 256, 8, (1, 256, 8, 8, 8), (1, 256, 4, 4, 4)), (2, 128, 4, (1, 128, 16, 16), (1, 128, 8, 8))] class TestMAGNUS(unittest.TestCase): @@ -67,12 +58,7 @@ class TestMAGNUS(unittest.TestCase): @parameterized.expand(MAGNUS_TEST_CASES) def test_magnus_shape( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - input_shape: tuple, - expected_shape: tuple, + self, spatial_dims: int, in_channels: int, out_channels: int, input_shape: tuple, expected_shape: tuple ): """Test MAGNUS output shape.""" model = MAGNUS( @@ -115,11 +101,7 @@ def test_magnus_deep_supervision(self): def test_magnus_different_norms(self): """Test MAGNUS with different normalization types.""" - norms = [ - "batch", - "instance", - ("group", {"num_groups": 8}), # GroupNorm requires num_groups - ] + norms = ["batch", "instance", ("group", {"num_groups": 8})] # GroupNorm requires num_groups for norm in norms: model = MAGNUS( spatial_dims=3, @@ -140,14 +122,7 @@ def test_magnus_different_norms(self): def test_magnus_gradient_flow(self): """Test gradient flow through MAGNUS.""" - model = MAGNUS( - spatial_dims=3, - in_channels=1, - out_channels=2, - features=(32, 64), - vit_depth=1, - vit_patch_size=8, - ) + model = MAGNUS(spatial_dims=3, in_channels=1, out_channels=2, features=(32, 64), vit_depth=1, vit_patch_size=8) model.train() x = torch.randn(1, 1, 32, 32, 32, requires_grad=True) @@ -168,19 +143,9 @@ class TestCNNPath(unittest.TestCase): """Test cases for CNNPath.""" @parameterized.expand(CNN_PATH_TEST_CASES) - def test_cnn_path_shape( - self, - spatial_dims: int, - in_channels: int, - features: tuple, - input_shape: tuple, - ): + def test_cnn_path_shape(self, spatial_dims: int, in_channels: int, features: tuple, input_shape: tuple): """Test CNNPath output shapes.""" - model = CNNPath( - spatial_dims=spatial_dims, - in_channels=in_channels, - features=features, - ) + model = CNNPath(spatial_dims=spatial_dims, in_channels=in_channels, features=features) model.eval() x = torch.randn(*input_shape) @@ -234,20 +199,9 @@ class TestCrossModalAttentionFusion(unittest.TestCase): """Test cases for CrossModalAttentionFusion.""" @parameterized.expand(FUSION_TEST_CASES) - def test_fusion_shape( - self, - spatial_dims: int, - channels: int, - num_heads: int, - cnn_shape: tuple, - vit_shape: tuple, - ): + def 
test_fusion_shape(self, spatial_dims: int, channels: int, num_heads: int, cnn_shape: tuple, vit_shape: tuple): """Test CrossModalAttentionFusion output shape.""" - model = CrossModalAttentionFusion( - spatial_dims=spatial_dims, - channels=channels, - num_heads=num_heads, - ) + model = CrossModalAttentionFusion(spatial_dims=spatial_dims, channels=channels, num_heads=num_heads) model.eval() cnn_feat = torch.randn(*cnn_shape) @@ -262,11 +216,7 @@ def test_fusion_shape( def test_fusion_invalid_channels(self): """Test fusion raises error when channels not divisible by heads.""" with self.assertRaises(ValueError): - CrossModalAttentionFusion( - spatial_dims=3, - channels=100, - num_heads=8, # 100 % 8 != 0 - ) + CrossModalAttentionFusion(spatial_dims=3, channels=100, num_heads=8) # 100 % 8 != 0 class TestScaleAdaptiveConv(unittest.TestCase): @@ -274,12 +224,7 @@ class TestScaleAdaptiveConv(unittest.TestCase): def test_scale_adaptive_conv_3d(self): """Test ScaleAdaptiveConv 3D output shape.""" - model = ScaleAdaptiveConv( - spatial_dims=3, - in_channels=64, - out_channels=128, - kernel_sizes=(3, 5, 7), - ) + model = ScaleAdaptiveConv(spatial_dims=3, in_channels=64, out_channels=128, kernel_sizes=(3, 5, 7)) model.eval() x = torch.randn(1, 64, 16, 16, 16) @@ -290,12 +235,7 @@ def test_scale_adaptive_conv_3d(self): def test_scale_adaptive_conv_2d(self): """Test ScaleAdaptiveConv 2D output shape.""" - model = ScaleAdaptiveConv( - spatial_dims=2, - in_channels=32, - out_channels=64, - kernel_sizes=(3, 5), - ) + model = ScaleAdaptiveConv(spatial_dims=2, in_channels=32, out_channels=64, kernel_sizes=(3, 5)) model.eval() x = torch.randn(1, 32, 32, 32) @@ -348,12 +288,7 @@ class TestDecoderBlock(unittest.TestCase): def test_decoder_block_3d(self): """Test DecoderBlock 3D output shape.""" - model = DecoderBlock( - spatial_dims=3, - in_channels=128, - skip_channels=64, - out_channels=64, - ) + model = DecoderBlock(spatial_dims=3, in_channels=128, skip_channels=64, out_channels=64) model.eval() x = torch.randn(1, 128, 8, 8, 8) @@ -365,13 +300,7 @@ def test_decoder_block_3d(self): def test_decoder_block_2d(self): """Test DecoderBlock 2D output shape.""" - model = DecoderBlock( - spatial_dims=2, - in_channels=256, - skip_channels=128, - out_channels=128, - use_se=True, - ) + model = DecoderBlock(spatial_dims=2, in_channels=256, skip_channels=128, out_channels=128, use_se=True) model.eval() x = torch.randn(1, 256, 8, 8) @@ -383,13 +312,7 @@ def test_decoder_block_2d(self): def test_decoder_block_no_se(self): """Test DecoderBlock without SE block.""" - model = DecoderBlock( - spatial_dims=3, - in_channels=64, - skip_channels=32, - out_channels=32, - use_se=False, - ) + model = DecoderBlock(spatial_dims=3, in_channels=64, skip_channels=32, out_channels=32, use_se=False) model.eval() x = torch.randn(1, 64, 4, 4, 4) @@ -407,12 +330,7 @@ class TestMAGNUSMemory(unittest.TestCase): def test_magnus_cuda(self): """Test MAGNUS on CUDA.""" model = MAGNUS( - spatial_dims=3, - in_channels=1, - out_channels=2, - features=(32, 64, 128), - vit_depth=2, - vit_patch_size=8, + spatial_dims=3, in_channels=1, out_channels=2, features=(32, 64, 128), vit_depth=2, vit_patch_size=8 ).cuda() model.eval() From 3a6f26279c1026069b6c92dd62d1abcb8f94d0d2 Mon Sep 17 00:00:00 2001 From: Sefa Aras Date: Thu, 29 Jan 2026 00:07:10 +0300 Subject: [PATCH 11/11] Fix flake8 N806 and Sphinx docstring formatting Signed-off-by: Sefa Aras --- monai/networks/nets/magnus.py | 41 +++++++++++++++++------------------ 1 file changed, 20 insertions(+), 21 
deletions(-) diff --git a/monai/networks/nets/magnus.py b/monai/networks/nets/magnus.py index b05accfb04..a4a10ba483 100644 --- a/monai/networks/nets/magnus.py +++ b/monai/networks/nets/magnus.py @@ -109,7 +109,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: Forward pass returning features from each stage. Args: - x: input tensor of shape (B, C, *spatial_dims). + x: input tensor of shape ``(B, C, *spatial_dims)``. Returns: List of feature tensors from each encoder stage, @@ -214,14 +214,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Forward pass through transformer path. Args: - x: input tensor of shape (B, C, *spatial_dims). + x: input tensor of shape ``(B, C, *spatial_dims)``. Returns: - Transformed features of shape (B, hidden_dim, *reduced_spatial_dims). + Transformed features of shape ``(B, hidden_dim, *reduced_spatial_dims)``. """ # Patch embedding: (B, C, D, H, W) -> (B, hidden_dim, Dp, Hp, Wp) x_embedded = self.embedding(x) - B = x_embedded.shape[0] + batch_size = x_embedded.shape[0] spatial_shape = x_embedded.shape[2:] # Flatten spatial dims: (B, hidden_dim, *spatial) -> (B, N, hidden_dim) @@ -238,7 +238,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x_transformed = self.norm(x_transformed) # Reshape back to spatial: (B, N, hidden_dim) -> (B, hidden_dim, *spatial) - x_out: torch.Tensor = x_transformed.transpose(1, 2).view(B, self.hidden_dim, *spatial_shape) + x_out: torch.Tensor = x_transformed.transpose(1, 2).view(batch_size, self.hidden_dim, *spatial_shape) return x_out @@ -289,13 +289,13 @@ def forward(self, cnn_feat: torch.Tensor, vit_feat: torch.Tensor) -> torch.Tenso Forward pass for cross-modal attention fusion. Args: - cnn_feat: CNN features of shape (B, C, *spatial_dims). - vit_feat: ViT features of shape (B, C, *spatial_dims_vit). + cnn_feat: CNN features of shape ``(B, C, *spatial_dims)``. + vit_feat: ViT features of shape ``(B, C, *spatial_dims_vit)``. Returns: - Fused features of shape (B, C, *spatial_dims). + Fused features of shape ``(B, C, *spatial_dims)``. """ - B, C = cnn_feat.shape[:2] + batch_size, channels = cnn_feat.shape[:2] spatial_shape = cnn_feat.shape[2:] heads = self.num_heads @@ -310,7 +310,7 @@ def forward(self, cnn_feat: torch.Tensor, vit_feat: torch.Tensor) -> torch.Tenso # Reshape for multi-head attention: (B, heads, head_dim, N) def reshape_for_attention(t: torch.Tensor) -> torch.Tensor: - return t.view(B, heads, self.head_dim, -1) + return t.view(batch_size, heads, self.head_dim, -1) q_c, k_c, v_c = map(reshape_for_attention, (q_c, k_c, v_c)) q_v, k_v, v_v = map(reshape_for_attention, (q_v, k_v, v_v)) @@ -326,8 +326,8 @@ def reshape_for_attention(t: torch.Tensor) -> torch.Tensor: out_v = torch.einsum("b h i j, b h d j -> b h d i", attn_vc, v_c) # Reshape back to spatial - out_c = out_c.contiguous().view(B, C, *spatial_shape) - out_v = out_v.contiguous().view(B, C, *spatial_shape) + out_c = out_c.contiguous().view(batch_size, channels, *spatial_shape) + out_v = out_v.contiguous().view(batch_size, channels, *spatial_shape) # Combine and project fused: torch.Tensor = self.to_out(out_c + out_v) @@ -383,10 +383,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Forward pass with multi-scale convolutions. Args: - x: input tensor of shape (B, C, *spatial_dims). + x: input tensor of shape ``(B, C, *spatial_dims)``. Returns: - Multi-scale features of shape (B, out_channels, *spatial_dims). + Multi-scale features of shape ``(B, out_channels, *spatial_dims)``. 
""" outs = [conv(x) for conv in self.convs] out = torch.stack(outs, dim=0).sum(dim=0) @@ -430,7 +430,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Forward pass for SE block. Args: - x: input tensor of shape (B, C, *spatial_dims). + x: input tensor of shape ``(B, C, *spatial_dims)``. Returns: Channel-recalibrated tensor of same shape. @@ -723,14 +723,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, list[to Forward pass of MAGNUS. Args: - x: input tensor of shape (B, in_channels, *spatial_dims). + x: input tensor of shape ``(B, in_channels, *spatial_dims)``. Returns: - If deep_supervision is False: - Segmentation logits of shape (B, out_channels, *spatial_dims). - If deep_supervision is True: - Tuple of (main_output, auxiliary_outputs) where auxiliary_outputs - is a list of intermediate segmentation maps. + If ``deep_supervision`` is False, returns segmentation logits of shape + ``(B, out_channels, *spatial_dims)``. + If ``deep_supervision`` is True, returns tuple of (main_output, auxiliary_outputs) + where auxiliary_outputs is a list of intermediate segmentation maps. """ input_shape = x.shape[2:]