class NormConv2dGRU(nn.Module):
    """Normalized 2D Convolutional GRU Network.

    Wraps a :class:`Conv2dGRU`: the input is group-normalized before it is passed to the recurrent
    unit, and the output of the recurrent unit is un-normalized with the statistics of the input.

    Normalization methods adapted from NormUnet of [1]_.

    References
    ----------

    .. [1] https://github.com/facebookresearch/fastMRI
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 2,
        gru_kernel_size: int = 1,
        orthogonal_initialization: bool = True,
        instance_norm: bool = False,
        dense_connect: int = 0,
        replication_padding: bool = True,
        norm_groups: int = 2,
    ):
        """Inits NormConv2dGRU.

        Parameters
        ----------
        in_channels: int
            Number of input channels.
        hidden_channels: int
            Number of hidden channels.
        out_channels: Optional[int]
            Number of output channels. If None, same as in_channels. Default: None.
        num_layers: int
            Number of layers. Default: 2.
        gru_kernel_size: int
            Size of the GRU kernel. Default: 1.
        orthogonal_initialization: bool
            Orthogonal initialization is used if set to True. Default: True.
        instance_norm: bool
            Instance norm is used if set to True. Default: False.
        dense_connect: int
            Number of dense connections. Default: 0.
        replication_padding: bool
            If set to True replication padding is applied. Default: True.
        norm_groups: int
            Number of normalization groups. Default: 2.
        """
        super().__init__()
        # All arguments except `norm_groups` are forwarded unchanged to the wrapped recurrent unit.
        self.convgru = Conv2dGRU(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            gru_kernel_size=gru_kernel_size,
            orthogonal_initialization=orthogonal_initialization,
            instance_norm=instance_norm,
            dense_connect=dense_connect,
            replication_padding=replication_padding,
        )
        self.norm_groups = norm_groups

    @staticmethod
    def norm(input_data: torch.Tensor, num_groups: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs group normalization.

        Each sample is flattened into `num_groups` equally sized groups and every group is shifted and
        scaled by its own mean and standard deviation.

        Parameters
        ----------
        input_data: torch.Tensor
            Tensor with four dimensions. It is reshaped to (dim 0, `num_groups`, -1), so the number of
            elements per sample must be divisible by `num_groups`.
        num_groups: int
            Number of normalization groups.

        Returns
        -------
        output, mean, std: (torch.Tensor, torch.Tensor, torch.Tensor)
            Normalized tensor with the shape of `input_data`, and the per-group mean and standard
            deviation, each of shape (dim 0, `num_groups`, 1). Standard deviations that are zero or NaN
            are returned as 1, so that :meth:`unnorm` inverts this operation exactly.
        """
        b, c, h, w = input_data.shape
        grouped = input_data.reshape(b, num_groups, -1)

        mean = grouped.mean(-1, keepdim=True)
        std = grouped.std(-1, keepdim=True)
        # A constant group (e.g. an all-zero input) has zero std; dividing by it would fill the output
        # with NaN/inf. Use a unit scale for such groups instead; groups with std > 0 are unaffected.
        std = torch.where(std > 0, std, torch.ones_like(std))

        output = ((grouped - mean) / std).reshape(b, c, h, w)

        return output, mean, std

    @staticmethod
    def unnorm(input_data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, num_groups: int) -> torch.Tensor:
        """Reverts group normalization.

        Parameters
        ----------
        input_data: torch.Tensor
            Tensor with four dimensions. It is reshaped to (dim 0, `num_groups`, -1) before scaling.
        mean: torch.Tensor
            Per-group mean, as returned by :meth:`norm`.
        std: torch.Tensor
            Per-group standard deviation, as returned by :meth:`norm`.
        num_groups: int
            Number of normalization groups.

        Returns
        -------
        torch.Tensor
            `input_data * std + mean` computed per group, with the shape of `input_data`.
        """
        b, c, h, w = input_data.shape
        grouped = input_data.reshape(b, num_groups, -1)
        return (grouped * std + mean).reshape(b, c, h, w)

    def forward(
        self,
        cell_input: torch.Tensor,
        previous_state: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes NormConv2dGRU forward pass given tensors `cell_input` and `previous_state`.

        It performs group normalization on the input before the forward pass to the Conv2dGRU.
        Output of Conv2dGRU is then un-normalized.

        Parameters
        ----------
        cell_input: torch.Tensor
            Input tensor.
        previous_state: Optional[torch.Tensor]
            Tensor of previous hidden state. Passed unchanged to the wrapped Conv2dGRU.

        Returns
        -------
        out, new_states: (torch.Tensor, torch.Tensor)
            Output and new states.
        """
        # Normalize the input and keep its statistics for the output.
        normalized_input, mean, std = self.norm(cell_input, self.norm_groups)
        # Pass normalized input through the recurrent unit.
        out, new_states = self.convgru(normalized_input, previous_state)
        # Un-normalize the output with the statistics of the input.
        out = self.unnorm(out, mean, std, self.norm_groups)

        return out, new_states
learned_initializer: bool If True an RSI module is used. Default: False. initializer_initialization: str, Optional @@ -152,6 +153,8 @@ def __init__( initializer_multiscale: int RSI module number of feature layers to aggregate for the output, if 1, multi-scale context aggregation is disabled. Default: 1. + normalized: bool + If True, :class:`NormConv2dGRU` will be used as a regularizer in the :class:`RecurrentVarNetBlocks`. Default: False. """ super(RecurrentVarNet, self).__init__() @@ -198,6 +201,7 @@ def __init__( in_channels=in_channels, hidden_channels=recurrent_hidden_channels, num_layers=recurrent_num_layers, + normalized=normalized, ) ) self.forward_operator = forward_operator @@ -317,6 +321,7 @@ def __init__( in_channels: int = 2, hidden_channels: int = 64, num_layers: int = 4, + normalized: bool = False, ): """Inits RecurrentVarNetBlock. @@ -332,18 +337,22 @@ def __init__( Hidden channels. Default: 64. num_layers: int, Number of layers of :math:`n_l` recurrent unit. Default: 4. + normalized: bool + If True, :class:`NormConv2dGRU` will be used as a regularizer. Default: False. """ super().__init__() self.forward_operator = forward_operator self.backward_operator = backward_operator self.learning_rate = nn.Parameter(torch.tensor([1.0])) # :math:`\alpha_t` - self.regularizer = Conv2dGRU( - in_channels=in_channels, - hidden_channels=hidden_channels, - num_layers=num_layers, - replication_padding=True, - ) # Recurrent Unit of RecurrentVarNet Block :math:`\mathcal{H}_{\theta_t}` + regularizer_params = { + "in_channels": in_channels, + "hidden_channels": hidden_channels, + "num_layers": num_layers, + "replication_padding": True, + } + # Recurrent Unit of RecurrentVarNet Block :math:`\mathcal{H}_{\theta_t}` + self.regularizer = NormConv2dGRU(**regularizer_params) if normalized else Conv2dGRU(**regularizer_params) def forward( self, @@ -369,7 +378,7 @@ def forward( sensitivity_map: torch.Tensor Coil sensitivities of shape (N, coil, height, width, complex=2). 
hidden_state: torch.Tensor or None - ConvGRU hidden state of shape (N, hidden_channels, height, width, num_layers) if not None. Optional. + Recurrent unit hidden state of shape (N, hidden_channels, height, width, num_layers) if not None. Optional. coil_dim: int Coil dimension. Default: 1. complex_dim: int diff --git a/direct/nn/rim/config.py b/direct/nn/rim/config.py index 9c799889..c6e53642 100644 --- a/direct/nn/rim/config.py +++ b/direct/nn/rim/config.py @@ -1,5 +1,6 @@ # coding=utf-8 # Copyright (c) DIRECT Contributors + from dataclasses import dataclass from typing import Optional, Tuple @@ -23,6 +24,7 @@ class RIMConfig(ModelConfig): initializer_channels: Tuple[int, ...] = (32, 32, 64, 64) initializer_dilations: Tuple[int, ...] = (1, 1, 2, 4) initializer_multiscale: int = 1 + normalized: bool = False @dataclass diff --git a/direct/nn/rim/rim.py b/direct/nn/rim/rim.py index fa9bf18c..2c05f83f 100644 --- a/direct/nn/rim/rim.py +++ b/direct/nn/rim/rim.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from direct.data import transforms as T -from direct.nn.recurrent.recurrent import Conv2dGRU +from direct.nn.recurrent.recurrent import Conv2dGRU, NormConv2dGRU from direct.utils.asserts import assert_positive_integer @@ -215,6 +215,7 @@ def __init__( initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64), initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4), initializer_multiscale: int = 1, + normalized: bool = False, **kwargs, ): """Inits RIM. @@ -254,6 +255,8 @@ def __init__( If "learned_initializer=False" this is ignored. Default: (1, 1, 2, 4) initializer_multiscale: int Number of initializer multiscale. If "learned_initializer=False" this is ignored. Default: 1. + normalized: bool + If True, :class:`NormConv2dGRU` will be used instead of :class:`Conv2dGRU`. Default: False. 
""" super().__init__() @@ -299,18 +302,18 @@ def __init__( self.cell_list = nn.ModuleList() self.no_parameter_sharing = no_parameter_sharing + + conv_unit_params = { + "in_channels": x_channels * 2, # double channels as input is concatenated image and gradient + "out_channels": x_channels, + "hidden_channels": hidden_channels, + "num_layers": depth, + "instance_norm": instance_norm, + "dense_connect": dense_connect, + "replication_padding": replication_padding, + } for _ in range(length if no_parameter_sharing else 1): - self.cell_list.append( - Conv2dGRU( - in_channels=x_channels * 2, # double channels as input is concatenated image and gradient - out_channels=x_channels, - hidden_channels=hidden_channels, - num_layers=depth, - instance_norm=instance_norm, - dense_connect=dense_connect, - replication_padding=replication_padding, - ) - ) + self.cell_list.append(NormConv2dGRU(**conv_unit_params) if normalized else Conv2dGRU(**conv_unit_params)) self.length = length self.depth = depth diff --git a/tests/tests_nn/test_recurrent.py b/tests/tests_nn/test_recurrent.py index aad78200..ead8b9c5 100644 --- a/tests/tests_nn/test_recurrent.py +++ b/tests/tests_nn/test_recurrent.py @@ -4,7 +4,7 @@ import pytest import torch -from direct.nn.recurrent.recurrent import Conv2dGRU +from direct.nn.recurrent.recurrent import Conv2dGRU, NormConv2dGRU def create_input(shape): @@ -19,14 +19,23 @@ def create_input(shape): [ [3, 2, 32, 32], [3, 2, 16, 16], + [3, 2, 15, 17], ], ) @pytest.mark.parametrize( "hidden_channels", [4, 8], ) -def test_conv2dgru(shape, hidden_channels): - model = Conv2dGRU(shape[1], hidden_channels, shape[1]) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_conv2dgru(shape, hidden_channels, normalized): + model = ( + NormConv2dGRU(shape[1], hidden_channels, shape[1]) + if normalized + else Conv2dGRU(shape[1], hidden_channels, shape[1]) + ) data = create_input(shape).cpu() out = model(data, None)[0] diff --git 
a/tests/tests_nn/test_recurrentvarnet.py b/tests/tests_nn/test_recurrentvarnet.py index 823a3af7..c3e1804f 100644 --- a/tests/tests_nn/test_recurrentvarnet.py +++ b/tests/tests_nn/test_recurrentvarnet.py @@ -44,6 +44,10 @@ def create_input(shape): [False, None, None, None], ], ) +@pytest.mark.parametrize( + "normalized", + [True, False], +) def test_recurrentvarnet( shape, num_steps, @@ -54,6 +58,7 @@ def test_recurrentvarnet( initializer_initialization, initializer_channels, initializer_dilations, + normalized, ): model = RecurrentVarNet( fft2, @@ -66,6 +71,7 @@ def test_recurrentvarnet( initializer_initialization=initializer_initialization, initializer_channels=initializer_channels, initializer_dilations=initializer_dilations, + normalized=normalized, ).cpu() kspace = create_input(shape + [2]).cpu() diff --git a/tests/tests_nn/test_rim.py b/tests/tests_nn/test_rim.py index 71f9cfbe..4dc8b2ea 100644 --- a/tests/tests_nn/test_rim.py +++ b/tests/tests_nn/test_rim.py @@ -61,6 +61,10 @@ def create_input(shape): "input_image_is_None", [True, False], ) +@pytest.mark.parametrize( + "normalized", + [True, False], +) def test_rim( shape, hidden_channels, @@ -73,6 +77,7 @@ def test_rim( image_init, learned_initializer, input_image_is_None, + normalized, ): model = RIM( fft2, @@ -86,6 +91,7 @@ def test_rim( skip_connections=skip_connections, image_initialization=image_init, learned_initializer=learned_initializer, + normalized=normalized, ).cpu() inputs = {