Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tox.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade pip setuptools wheel
pip install tox tox-gh-actions
- name: Test with tox
run: tox
4 changes: 3 additions & 1 deletion direct/data/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def __init__(self, datasets: List, batch_size: int, seed: Optional[int] = None):
self.weights = np.asarray([len(_) for _ in datasets])
self.cumulative_sizes = self.cumsum(datasets)

self.logger.info(f"Sampling batches with weights {self.weights} with cumulative sizes {self.cumulative_sizes}.")
self.logger.info(
f"Sampling batches with weights {self.weights} with cumulative sizes {self.cumulative_sizes}."
)
self._batch_samplers = [
self.batch_sampler(sampler, 0 if idx == 0 else self.cumulative_sizes[idx - 1])
for idx, sampler in enumerate(self.samplers)
Expand Down
222 changes: 84 additions & 138 deletions direct/nn/rim/rim_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__(
mixed_precision=mixed_precision,
**models,
)
self._complex_dim = -1
self._coil_dim = 1

def _do_iteration(
self,
Expand Down Expand Up @@ -98,11 +100,10 @@ def _do_iteration(
# The sensitivity map needs to be normalized such that
# \sum_{i \in \text{coils}} S_i S_i^* = 1

complex_dim, coil_dim = -1, 1
sensitivity_map_norm = torch.sqrt(
((sensitivity_map ** 2).sum(complex_dim)).sum(coil_dim)
((sensitivity_map ** 2).sum(self._complex_dim)).sum(self._coil_dim)
) # shape (batch, [slice], height, width)
sensitivity_map_norm = sensitivity_map_norm.unsqueeze(1).unsqueeze(-1)
sensitivity_map_norm = sensitivity_map_norm.unsqueeze(self._coil_dim).unsqueeze(self._complex_dim)
data["sensitivity_map"] = T.safe_divide(sensitivity_map, sensitivity_map_norm)

if self.cfg.model.scale_loglikelihood: # type: ignore
Expand All @@ -116,8 +117,6 @@ def _do_iteration(
for _ in range(self.cfg.model.steps): # type: ignore
with autocast(enabled=self.mixed_precision):
if input_image is not None:
# TODO(gy): is this print here needed?
print(input_image.shape, input_image.names, "input_image")
input_image = input_image.permute((0, 2, 3, 4, 1) if input_image.ndim == 5 else (0, 2, 3, 1))
reconstruction_iter, hidden_state = self.model(
**data,
Expand Down Expand Up @@ -257,9 +256,27 @@ def evaluate(
loss_fns: Optional[Dict[str, Callable]],
regularizer_fns: Optional[Dict[str, Callable]] = None,
crop: Optional[str] = None,
is_validation_process=True,
is_validation_process: bool = True,
):
"""
Validation process. Assumes that each batch only contains slices of the same volume *AND* that these
are sequentially ordered.

Parameters
----------
data_loader : DataLoader
loss_fns : Dict[str, Callable], optional
regularizer_fns : Dict[str, Callable], optional
crop : str, optional
is_validation_process : bool

Returns
-------
loss_dict, all_gathered_metrics, visualize_slices, visualize_target

# TODO(jt): visualization should be a namedtuple or a dict or so

"""
self.models_to_device()
self.models_validation_mode()
torch.cuda.empty_cache()
Expand All @@ -271,6 +288,7 @@ def evaluate(

# filenames can be in the volume_indices attribute of the dataset
num_for_this_process = None
all_filenames = None
if hasattr(data_loader.dataset, "volume_indices"):
all_filenames = list(data_loader.dataset.volume_indices.keys())
num_for_this_process = len(list(data_loader.batch_sampler.sampler.volume_indices.keys()))
Expand All @@ -280,9 +298,9 @@ def evaluate(
)

filenames_seen = 0

reconstruction_output: DefaultDict = defaultdict(list)
targets_output: DefaultDict = defaultdict(list)
if is_validation_process:
targets_output: DefaultDict = defaultdict(list)
val_losses = []
val_volume_metrics: Dict[PathLike, Dict] = defaultdict(dict)
last_filename = None
Expand All @@ -298,11 +316,11 @@ def evaluate(

# Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
# splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
# that the slices are outputted from the Dataset *sequentially* for each volume one by one.
# that the Dataset outputs the slices *sequentially*, one volume after another, and that each batch only
# contains data from one volume.
time_start = time.time()

for iter_idx, data in enumerate(data_loader):
# data = AddNames()(data)
filenames = data.pop("filename")
if len(set(filenames)) != 1:
raise ValueError(
Expand Down Expand Up @@ -356,49 +374,6 @@ def evaluate(
if last_filename is None:
last_filename = filename # First iteration last_filename is not set.

# If the new filename is not the previous one, then we can reconstruct the volume as the sampling
# is linear.
# For the last case we need to check if we are at the last batch *and* at the last element in the batch.
is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
if filename != last_filename or is_last_element_of_last_batch:
filenames_seen += 1
# Now we can ditch the reconstruction dict by reconstructing the volume,
# will take too much memory otherwise.
# TODO: Stack does not support named tensors.
volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]])
if is_validation_process:

target = torch.stack([_[1] for _ in targets_output[last_filename]])
curr_metrics = {
metric_name: metric_fn(target, volume) for metric_name, metric_fn in volume_metrics.items()
}
val_volume_metrics[last_filename] = curr_metrics
# Log the center slice of the volume
if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore
visualize_slices.append(volume[volume.shape[0] // 2])
visualize_target.append(target[target.shape[0] // 2])

# Delete outputs from memory, and recreate dictionary.
# This is not needed when not in validation as we are actually interested
# in the iteration output.
del targets_output
targets_output = defaultdict(list)
del reconstruction_output
reconstruction_output = defaultdict(list)

if all_filenames:
log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
else:
log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"

self.logger.info(
f"{log_prefix} {last_filename}"
f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
)
# restart timer
time_start = time.time()
last_filename = filename

curr_slice = output_abs[idx].detach()
slice_no = int(slice_nos[idx].numpy())

Expand All @@ -408,6 +383,49 @@ def evaluate(
if is_validation_process:
targets_output[filename].append((slice_no, target_abs[idx].cpu()))

# If the new filename is not the previous one, then we can reconstruct the volume as the sampling
# is linear. For the last case we need to check if we are at the last batch *and* at the last
# element in the batch.
is_last_element_of_last_batch = iter_idx + 1 == len(data_loader) and idx + 1 == len(data["target"])
reconstruction_conditions = [filename != last_filename, is_last_element_of_last_batch]
for condition in reconstruction_conditions:
if condition:
filenames_seen += 1

# Now we can ditch the reconstruction dict by reconstructing the volume;
# keeping it around would take too much memory otherwise.
volume = torch.stack([_[1] for _ in reconstruction_output[last_filename]])
if is_validation_process:
target = torch.stack([_[1] for _ in targets_output[last_filename]])
curr_metrics = {
metric_name: metric_fn(target, volume)
for metric_name, metric_fn in volume_metrics.items()
}
val_volume_metrics[last_filename] = curr_metrics
# Log the center slice of the volume
if len(visualize_slices) < self.cfg.logging.tensorboard.num_images: # type: ignore
visualize_slices.append(volume[volume.shape[0] // 2])
visualize_target.append(target[target.shape[0] // 2])

# Delete outputs from memory, and recreate dictionary.
# This is not needed when not in validation as we are actually interested
# in the iteration output.
del targets_output[last_filename]
del reconstruction_output[last_filename]

if all_filenames:
log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
else:
log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"

self.logger.info(
f"{log_prefix} {last_filename}"
f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
)
# restart timer
time_start = time.time()
last_filename = filename

# Average loss dict
loss_dict = reduce_list_of_dicts(val_losses)
reduce_tensor_dict(loss_dict)
Expand All @@ -420,79 +438,8 @@ def evaluate(
if not is_validation_process:
return loss_dict, reconstruction_output

# TODO: Apply named tuples where applicable
# TODO: Several functions have multiple DoIterationOutput values, in many cases
# TODO: it would be more convenient to convert this to namedtuples.
return loss_dict, all_gathered_metrics, visualize_slices, visualize_target

# TODO: WORK ON THIS.
# def do_something_with_the_noise(self, data):
# # Seems like a better idea to compute noise in image space
# masked_kspace = data["masked_kspace"]
# sensitivity_map = data["sensitivity_map"]
# masked_image_forward = self.backward_operator(masked_kspace)
# masked_image_forward = masked_image_forward.align_to(
# *self.complex_names(add_coil=True)
# )
# noise_vector = self.compute_model_per_coil("noise_model", masked_image_forward)
#
# # Create a complex noise vector
# noise_vector = torch.view_as_complex(
# noise_vector.reshape(
# noise_vector.shape[0],
# noise_vector.shape[1],
# noise_vector.shape[-1] // 2,
# 2,
# )
# )
#
# # Apply prewhitening
# # https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.1241
# noise_int = noise_vector.reshape(
# noise_vector.shape[0], noise_vector.shape[1], -1
# )
# noise_int *= 1 / (noise_int.shape[-1] - 1)
#
# phi = T.complex_bmm(noise_int, noise_int.conj().transpose(1, 2))
# # TODO(jt): No cholesky nor inverse on GPU yet...
# new_basis = torch.inverse(torch.cholesky(phi.cpu())).to(phi.device) / np.sqrt(
# 2.0
# )
#
# # TODO(jt): Likely we need something a bit more elaborate e.g. percentile
# masked_kspace_max = masked_kspace.max()
# masked_kspace = self.view_as_complex(masked_kspace)
# prewhitened_kspace = (
# T.complex_bmm(
# new_basis,
# masked_kspace.rename(None).reshape(
# masked_kspace.shape[0], masked_kspace.shape[1], -1
# ),
# )
# .reshape(masked_kspace.shape)
# .refine_names(*masked_kspace.names)
# )
# prewhitened_kspace = self.view_as_real(prewhitened_kspace)
#
# # kspace has different values after whitening, lets map back
# prewhitened_kspace = (
# prewhitened_kspace / prewhitened_kspace.max() * masked_kspace_max
# )
# data["masked_kspace"] = prewhitened_kspace
#
# sensitivity_map = self.view_as_complex(sensitivity_map)
# prewhitened_sensitivity_map = (
# T.complex_bmm(
# new_basis,
# sensitivity_map.rename(None).reshape(
# masked_kspace.shape[0], masked_kspace.shape[1], -1
# ),
# )
# .reshape(masked_kspace.shape)
# .refine_names(*sensitivity_map.names)
# )
# sensitivity_map = self.view_as_real(prewhitened_sensitivity_map)

def process_output(self, data, scaling_factors=None, resolution=None):
# data is of shape (batch, complex=2, height, width)
if scaling_factors is not None:
Expand Down Expand Up @@ -555,11 +502,10 @@ def compute_model_per_coil(self, model_name, data):
# data is of shape (batch, coil, complex=2, [slice], height, width)
output = []

coil_index = 1
for idx in range(data.size(coil_index)):
subselected_data = data.select(coil_index, idx)
for idx in range(data.size(self._coil_dim)):
subselected_data = data.select(self._coil_dim, idx)
output.append(self.models[model_name](subselected_data))
output = torch.stack(output, dim=coil_index)
output = torch.stack(output, dim=self._coil_dim)

# output is of shape (batch, coil, complex=2, [slice], height, width)
return output
Expand Down Expand Up @@ -589,18 +535,19 @@ def __init__(
mixed_precision=mixed_precision,
**models,
)
self._slice_dim = -3

def process_output(self, data, scaling_factors=None, resolution=None):
# Data has shape (batch, complex, slice, height, width)
# TODO(gy): verify shape

slice_dim = -3
center_slice = data.size(slice_dim) // 2
self._slice_dim = -3
center_slice = data.size(self._slice_dim) // 2

if scaling_factors is not None:
data = data * scaling_factors.view(-1, *((1,) * (len(data.shape) - 1))).to(data.device)

data = T.modulus_if_complex(data).select(slice_dim, center_slice)
data = T.modulus_if_complex(data).select(self._slice_dim, center_slice)

if len(data.shape) == 3: # (batch, height, width)
data = data.unsqueeze(1) # Added channel dimension.
Expand All @@ -621,22 +568,21 @@ def cropper(self, source, target, resolution=(320, 320)):

"""
# TODO(gy): Verify target shape
slice_index = -3
self._slice_dim = -3

# TODO(gy): Why is this set to True and then have an if statement?
# TODO(jt): Because it might be the case we do it differently in say 3D. Just a placeholder really
use_center_slice = True
if use_center_slice:
# Source and target have a different number of slices when trimming in depth
source = source.select(
slice_index, source.size(slice_index) // 2
self._slice_dim, source.size(self._slice_dim) // 2
) # shape (batch, complex=2, height, width)
target = target.select(slice_index, target.size(slice_index) // 2).unsqueeze(
target = target.select(self._slice_dim, target.size(self._slice_dim) // 2).unsqueeze(
1
) # shape (batch, complex=1, height, width)
# else:
# source = source.permute(0, 2, 3, 4, 1) # shape (batch *slice, height, width, complex=2)
# source = source.flatten(0, 1).permute(0, 3, 1, 2) # shape (batch *slice, complex=2, height, width)
# target = target.flatten(0, 1).unsqueeze(1) # shape (batch*slice, 1, height, width)
else:
raise NotImplementedError("Only center slice cropping supported.")

source_abs = T.modulus(source) # shape (batch, height, width)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"scikit-image>=0.18.1",
"scikit-learn>=0.24.2",
"pyxb==1.2.6",
"ismrmrd @ git+https://git@github.com/ismrmrd/ismrmrd-python.git@v1.8.0#egg=ismrmrd",
"ismrmrd==1.9.1",
"tensorboard>=2.5.0",
],
extras_require={
Expand Down