diff --git a/direct/data/transforms.py b/direct/data/transforms.py index 05584e2c..c829c290 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -591,6 +591,7 @@ def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1) torch.Tensor : RSS of the input tensor. """ if is_complex_data(data): + return torch.sqrt((data ** 2).sum(complex_dim).sum(dim)) return torch.sqrt((data ** 2).sum(dim)) diff --git a/direct/nn/varnet/varnet.py b/direct/nn/varnet/varnet.py index a1e8fd3f..42209b6d 100644 --- a/direct/nn/varnet/varnet.py +++ b/direct/nn/varnet/varnet.py @@ -23,6 +23,7 @@ def __init__( regularizer_num_filters: int = 18, regularizer_num_pull_layers: int = 4, regularizer_dropout: float = 0.0, + in_channels: int = 2, **kwargs, ): """ @@ -126,6 +127,7 @@ def __init__( self.learning_rate = nn.Parameter(torch.tensor([1.0])) self._coil_dim = 1 + self._complex_dim = -1 self._spatial_dims = (2, 3) def forward( @@ -159,12 +161,24 @@ def forward( current_kspace - masked_kspace, ) - regularization_term = reduce_operator( - self.backward_operator(current_kspace, dim=self._spatial_dims), sensitivity_map, dim=self._coil_dim + regularization_term = torch.cat( + [ + reduce_operator( + self.backward_operator(kspace, dim=self._spatial_dims), sensitivity_map, dim=self._coil_dim + ) + for kspace in torch.split(current_kspace, 2, self._complex_dim) + ], + dim=self._complex_dim, ).permute(0, 3, 1, 2) regularization_term = self.regularizer_model(regularization_term).permute(0, 2, 3, 1) - regularization_term = self.forward_operator( - expand_operator(regularization_term, sensitivity_map, dim=self._coil_dim), dim=self._spatial_dims + regularization_term = torch.cat( + [ + self.forward_operator( + expand_operator(image, sensitivity_map, dim=self._coil_dim), dim=self._spatial_dims + ) + for image in torch.split(regularization_term, 2, self._complex_dim) + ], + dim=self._complex_dim, ) return current_kspace - self.learning_rate * kspace_error + regularization_term