Skip to content
121 changes: 119 additions & 2 deletions direct/nn/recurrent/recurrent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

import math
from typing import List, Optional, Tuple

import torch
Expand Down Expand Up @@ -117,9 +118,9 @@ def forward(
Parameters
----------
cell_input: torch.Tensor
Reconstruction input
Input tensor.
previous_state: torch.Tensor
Tensor of previous states.
Tensor of previous hidden state.

Returns
-------
Expand Down Expand Up @@ -161,3 +162,119 @@ def forward(
out = self.conv_blocks[self.num_layers](cell_input)

return out, torch.stack(new_states, dim=-1)


class NormConv2dGRU(nn.Module):
    """2D Convolutional GRU with group normalization wrapped around the recurrent unit.

    The input is group-normalized before being fed to an internal :class:`Conv2dGRU`,
    and the GRU output is de-normalized using the statistics computed from the input.
    Normalization scheme adapted from the NormUnet of [1]_.

    References
    ----------

    .. [1] https://github.com/facebookresearch/fastMRI/
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 2,
        gru_kernel_size=1,
        orthogonal_initialization: bool = True,
        instance_norm: bool = False,
        dense_connect: int = 0,
        replication_padding: bool = True,
        norm_groups: int = 2,
    ):
        """Inits NormConv2dGRU.

        Parameters
        ----------
        in_channels: int
            Number of input channels.
        hidden_channels: int
            Number of hidden channels.
        out_channels: Optional[int]
            Number of output channels. If None, same as in_channels. Default: None.
        num_layers: int
            Number of layers. Default: 2.
        gru_kernel_size: int
            Size of the GRU kernel. Default: 1.
        orthogonal_initialization: bool
            Orthogonal initialization is used if set to True. Default: True.
        instance_norm: bool
            Instance norm is used if set to True. Default: False.
        dense_connect: int
            Number of dense connections.
        replication_padding: bool
            If set to True, replication padding is applied. Default: True.
        norm_groups: int
            Number of groups used for the group normalization. Default: 2.
        """
        super().__init__()
        # Wrapped (un-normalized) recurrent unit; normalization happens in forward().
        self.convgru = Conv2dGRU(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            gru_kernel_size=gru_kernel_size,
            orthogonal_initialization=orthogonal_initialization,
            instance_norm=instance_norm,
            dense_connect=dense_connect,
            replication_padding=replication_padding,
        )
        self.norm_groups = norm_groups

    @staticmethod
    def norm(input_data: torch.Tensor, num_groups: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Performs group normalization, returning the statistics needed to undo it.

        Parameters
        ----------
        input_data: torch.Tensor
            Tensor of shape (batch, channels, height, width); channels * height * width
            must be divisible by `num_groups`.
        num_groups: int
            Number of normalization groups.

        Returns
        -------
        output, mean, std: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Normalized tensor (same shape as input) and the per-group mean and std.
        """
        batch, channels, height, width = input_data.shape
        grouped = input_data.reshape(batch, num_groups, -1)

        # Per-group statistics over the flattened (channel, spatial) dimension.
        std, mean = torch.std_mean(grouped, dim=-1, keepdim=True)

        normalized = (grouped - mean) / std
        return normalized.reshape(batch, channels, height, width), mean, std

    @staticmethod
    def unnorm(input_data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, num_groups: int) -> torch.Tensor:
        """Inverts :meth:`norm` using the previously computed per-group `mean` and `std`."""
        batch, channels, height, width = input_data.shape
        grouped = input_data.reshape(batch, num_groups, -1)
        return (grouped * std + mean).reshape(batch, channels, height, width)

    def forward(
        self,
        cell_input: torch.Tensor,
        previous_state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes NormConv2dGRU forward pass given tensors `cell_input` and `previous_state`.

        It performs group normalization on the input before the forward pass to the Conv2dGRU.
        Output of Conv2dGRU is then un-normalized.

        Parameters
        ----------
        cell_input: torch.Tensor
            Input tensor.
        previous_state: torch.Tensor
            Tensor of previous hidden state.

        Returns
        -------
        out, new_states: (torch.Tensor, torch.Tensor)
            Output and new states.

        """
        # Normalize, run the recurrent unit on the normalized input, then undo
        # the normalization on its output using the input statistics.
        normalized_input, mean, std = self.norm(cell_input, self.norm_groups)
        gru_output, new_state = self.convgru(normalized_input, previous_state)
        gru_output = self.unnorm(gru_output, mean, std, self.norm_groups)

        return gru_output, new_state
1 change: 1 addition & 0 deletions direct/nn/recurrentvarnet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ class RecurrentVarNetConfig(ModelConfig):
initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64) # :math:`n_d`
initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4) # :math:`p`
initializer_multiscale: int = 1
normalized: bool = False
27 changes: 18 additions & 9 deletions direct/nn/recurrentvarnet/recurrentvarnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch.nn.functional as F

from direct.data.transforms import complex_multiplication, conjugate, expand_operator, reduce_operator
from direct.nn.recurrent.recurrent import Conv2dGRU
from direct.nn.recurrent.recurrent import Conv2dGRU, NormConv2dGRU


class RecurrentInit(nn.Module):
Expand Down Expand Up @@ -120,6 +120,7 @@ def __init__(
initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64),
initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4),
initializer_multiscale: int = 1,
normalized: bool = False,
**kwargs,
):
"""Inits RecurrentVarNet.
Expand All @@ -139,7 +140,7 @@ def __init__(
recurrent_num_layers: int
Number of layers for the recurrent unit of the RecurrentVarNet Block (:math:`n_l`). Default: 4.
no_parameter_sharing: bool
If False, the same RecurrentVarNet Block is used for all num_steps. Default: True.
If False, the same :class:`RecurrentVarNetBlock` is used for all num_steps. Default: True.
learned_initializer: bool
If True an RSI module is used. Default: False.
initializer_initialization: str, Optional
Expand All @@ -152,6 +153,8 @@ def __init__(
initializer_multiscale: int
RSI module number of feature layers to aggregate for the output, if 1, multi-scale context aggregation
is disabled. Default: 1.
normalized: bool
If True, :class:`NormConv2dGRU` will be used as a regularizer in the :class:`RecurrentVarNetBlocks`. Default: False.
"""
super(RecurrentVarNet, self).__init__()

Expand Down Expand Up @@ -198,6 +201,7 @@ def __init__(
in_channels=in_channels,
hidden_channels=recurrent_hidden_channels,
num_layers=recurrent_num_layers,
normalized=normalized,
)
)
self.forward_operator = forward_operator
Expand Down Expand Up @@ -317,6 +321,7 @@ def __init__(
in_channels: int = 2,
hidden_channels: int = 64,
num_layers: int = 4,
normalized: bool = False,
):
"""Inits RecurrentVarNetBlock.

Expand All @@ -332,18 +337,22 @@ def __init__(
Hidden channels. Default: 64.
num_layers: int,
Number of layers of :math:`n_l` recurrent unit. Default: 4.
normalized: bool
If True, :class:`NormConv2dGRU` will be used as a regularizer. Default: False.
"""
super().__init__()
self.forward_operator = forward_operator
self.backward_operator = backward_operator

self.learning_rate = nn.Parameter(torch.tensor([1.0])) # :math:`\alpha_t`
self.regularizer = Conv2dGRU(
in_channels=in_channels,
hidden_channels=hidden_channels,
num_layers=num_layers,
replication_padding=True,
) # Recurrent Unit of RecurrentVarNet Block :math:`\mathcal{H}_{\theta_t}`
regularizer_params = {
"in_channels": in_channels,
"hidden_channels": hidden_channels,
"num_layers": num_layers,
"replication_padding": True,
}
# Recurrent Unit of RecurrentVarNet Block :math:`\mathcal{H}_{\theta_t}`
self.regularizer = NormConv2dGRU(**regularizer_params) if normalized else Conv2dGRU(**regularizer_params)

def forward(
self,
Expand All @@ -369,7 +378,7 @@ def forward(
sensitivity_map: torch.Tensor
Coil sensitivities of shape (N, coil, height, width, complex=2).
hidden_state: torch.Tensor or None
ConvGRU hidden state of shape (N, hidden_channels, height, width, num_layers) if not None. Optional.
Recurrent unit hidden state of shape (N, hidden_channels, height, width, num_layers) if not None. Optional.
coil_dim: int
Coil dimension. Default: 1.
complex_dim: int
Expand Down
2 changes: 2 additions & 0 deletions direct/nn/rim/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass
from typing import Optional, Tuple

Expand All @@ -23,6 +24,7 @@ class RIMConfig(ModelConfig):
initializer_channels: Tuple[int, ...] = (32, 32, 64, 64)
initializer_dilations: Tuple[int, ...] = (1, 1, 2, 4)
initializer_multiscale: int = 1
normalized: bool = False


@dataclass
Expand Down
27 changes: 15 additions & 12 deletions direct/nn/rim/rim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch.nn.functional as F

from direct.data import transforms as T
from direct.nn.recurrent.recurrent import Conv2dGRU
from direct.nn.recurrent.recurrent import Conv2dGRU, NormConv2dGRU
from direct.utils.asserts import assert_positive_integer


Expand Down Expand Up @@ -215,6 +215,7 @@ def __init__(
initializer_channels: Optional[Tuple[int, ...]] = (32, 32, 64, 64),
initializer_dilations: Optional[Tuple[int, ...]] = (1, 1, 2, 4),
initializer_multiscale: int = 1,
normalized: bool = False,
**kwargs,
):
"""Inits RIM.
Expand Down Expand Up @@ -254,6 +255,8 @@ def __init__(
If "learned_initializer=False" this is ignored. Default: (1, 1, 2, 4)
initializer_multiscale: int
Number of initializer multiscale. If "learned_initializer=False" this is ignored. Default: 1.
normalized: bool
If True, :class:`NormConv2dGRU` will be used instead of :class:`Conv2dGRU`. Default: False.
"""
super().__init__()

Expand Down Expand Up @@ -299,18 +302,18 @@ def __init__(

self.cell_list = nn.ModuleList()
self.no_parameter_sharing = no_parameter_sharing

conv_unit_params = {
"in_channels": x_channels * 2, # double channels as input is concatenated image and gradient
"out_channels": x_channels,
"hidden_channels": hidden_channels,
"num_layers": depth,
"instance_norm": instance_norm,
"dense_connect": dense_connect,
"replication_padding": replication_padding,
}
for _ in range(length if no_parameter_sharing else 1):
self.cell_list.append(
Conv2dGRU(
in_channels=x_channels * 2, # double channels as input is concatenated image and gradient
out_channels=x_channels,
hidden_channels=hidden_channels,
num_layers=depth,
instance_norm=instance_norm,
dense_connect=dense_connect,
replication_padding=replication_padding,
)
)
self.cell_list.append(NormConv2dGRU(**conv_unit_params) if normalized else Conv2dGRU(**conv_unit_params))

self.length = length
self.depth = depth
Expand Down
15 changes: 12 additions & 3 deletions tests/tests_nn/test_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch

from direct.nn.recurrent.recurrent import Conv2dGRU
from direct.nn.recurrent.recurrent import Conv2dGRU, NormConv2dGRU


def create_input(shape):
Expand All @@ -19,14 +19,23 @@ def create_input(shape):
[
[3, 2, 32, 32],
[3, 2, 16, 16],
[3, 2, 15, 17],
],
)
@pytest.mark.parametrize(
"hidden_channels",
[4, 8],
)
def test_conv2dgru(shape, hidden_channels):
model = Conv2dGRU(shape[1], hidden_channels, shape[1])
@pytest.mark.parametrize(
"normalized",
[True, False],
)
def test_conv2dgru(shape, hidden_channels, normalized):
model = (
NormConv2dGRU(shape[1], hidden_channels, shape[1])
if normalized
else Conv2dGRU(shape[1], hidden_channels, shape[1])
)
data = create_input(shape).cpu()

out = model(data, None)[0]
Expand Down
6 changes: 6 additions & 0 deletions tests/tests_nn/test_recurrentvarnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def create_input(shape):
[False, None, None, None],
],
)
@pytest.mark.parametrize(
"normalized",
[True, False],
)
def test_recurrentvarnet(
shape,
num_steps,
Expand All @@ -54,6 +58,7 @@ def test_recurrentvarnet(
initializer_initialization,
initializer_channels,
initializer_dilations,
normalized,
):
model = RecurrentVarNet(
fft2,
Expand All @@ -66,6 +71,7 @@ def test_recurrentvarnet(
initializer_initialization=initializer_initialization,
initializer_channels=initializer_channels,
initializer_dilations=initializer_dilations,
normalized=normalized,
).cpu()

kspace = create_input(shape + [2]).cpu()
Expand Down
6 changes: 6 additions & 0 deletions tests/tests_nn/test_rim.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def create_input(shape):
"input_image_is_None",
[True, False],
)
@pytest.mark.parametrize(
"normalized",
[True, False],
)
def test_rim(
shape,
hidden_channels,
Expand All @@ -73,6 +77,7 @@ def test_rim(
image_init,
learned_initializer,
input_image_is_None,
normalized,
):
model = RIM(
fft2,
Expand All @@ -86,6 +91,7 @@ def test_rim(
skip_connections=skip_connections,
image_initialization=image_init,
learned_initializer=learned_initializer,
normalized=normalized,
).cpu()

inputs = {
Expand Down