-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquick_test.py
More file actions
117 lines (97 loc) · 5.05 KB
/
quick_test.py
File metadata and controls
117 lines (97 loc) · 5.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
from monai.utils import set_determinism
import utils
from monai.transforms import Compose,EnsureChannelFirst,Activations, AsDiscrete
from glob import glob
import os
from monai.data import ImageDataset,DataLoader,decollate_batch
import config
import argparse
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
import numpy as np
def create_dataset_for_test(img_path, seg_path):
    """Build a batch-size-1 DataLoader over paired image/label files for testing.

    Args:
        img_path: directory containing the test images (matched by ``*.*``).
        seg_path: directory containing the matching segmentation labels.

    Returns:
        A MONAI ``DataLoader`` yielding ``(image, seg, image_meta, seg_meta)``
        per case (``image_only=False`` makes ``ImageDataset`` return metadata,
        which the caller uses for the affine when saving outputs).
    """
    # Only ensure channel-first layout at test time; no other preprocessing.
    val_imtrans = Compose([EnsureChannelFirst()])
    val_segtrans = Compose([EnsureChannelFirst()])
    # Sorting both listings keeps images and labels aligned by filename.
    images = sorted(glob(os.path.join(img_path, "*.*")))
    segs = sorted(glob(os.path.join(seg_path, "*.*")))
    val_ds = ImageDataset(images, segs, transform=val_imtrans,
                          seg_transform=val_segtrans, image_only=False)
    # pin_memory expects a bool; the original passed the int 0 (falsy but untyped).
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=False)
    return val_loader
def test(model, val_loader, channel_m, modalities, model_net_type, model_modalities_trained_on, model_channel_map, device, save_outputs, save_path):
    """Evaluate ``model`` on ``val_loader`` using sliding-window inference.

    Args:
        model: trained network to evaluate.
        val_loader: DataLoader yielding ``(image, seg, image_meta, seg_meta)``.
        channel_m: channel-map indices (see readme), parsed from CLI.
        modalities: modality-index combination to test with.
        model_net_type: network type string; ``"UNet"`` triggers input remapping.
        model_modalities_trained_on: modalities the loaded model was trained on.
        model_channel_map: channel map of the loaded model (currently unused here).
        device: torch device to run inference on.
        save_outputs: if True, write each prediction as NIfTI under ``save_path``.
        save_path: prefix for saved output files.

    Returns:
        Tuple ``(dice_metrics, segment_pixel_vol, gt_pixel_vol)`` — per-case
        Dice scores, predicted foreground voxel counts, and ground-truth
        foreground voxel counts.
    """
    cropped_input_size = [128, 128, 128]
    # Can add other metrics here (currently only Dice is reported).
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    # Post-processing pipeline is loop-invariant: build it once, not per case.
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    with torch.no_grad():
        val_images = None
        val_labels = None
        val_outputs = None
        steps = 0
        dice_metric.reset()
        dice_metrics = []
        segment_pixel_vol = []
        gt_pixel_vol = []
        for val_data in val_loader:
            roi_size = (cropped_input_size[0], cropped_input_size[1], cropped_input_size[2])
            sw_batch_size = 1
            # UNet models need the input channels rearranged to the layout
            # the model was trained on before inference.
            if model_net_type == "UNet":
                val_data[0] = utils.create_UNET_input_quicktest(val_data, modalities, channel_m, model_modalities_trained_on)
            val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
            # Test using sliding window over the full volume.
            val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
            val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
            # Compute the Dice score for the current case and RECORD it:
            # the original discarded it, so the returned list was always empty.
            current_dice = dice_metric(y_pred=val_outputs, y=val_labels)
            dice_metrics.append(current_dice.mean().item())
            if save_outputs:
                # Save output with the original affine taken from the image metadata.
                file_save_path = save_path + str(steps) + '_' + str(current_dice) + ".nii.gz"
                utils.save_nifti(val_outputs[0], file_save_path, val_data[3]["affine"])
            # torch.count_nonzero works on GPU tensors directly; np.count_nonzero
            # would raise on CUDA tensors without an explicit .cpu() first.
            pixels_segmented = torch.count_nonzero(val_outputs[0]).item()
            gt_segmented = torch.count_nonzero(val_labels[0]).item()
            segment_pixel_vol.append(pixels_segmented)
            gt_pixel_vol.append(gt_segmented)
            steps += 1
        # Aggregate mean Dice over all cases seen since the last reset.
        metric = dice_metric.aggregate().item()
        print("DICE Metric:")
        print(metric)
        dice_metric.reset()
    return dice_metrics, segment_pixel_vol, gt_pixel_vol
if __name__=="__main__":
    # CLI entry point: load a trained model (per config.Test_config) and run
    # quick Dice evaluation on a test image/label directory pair.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device_id", help="ID of the GPU", type=int, default=0)
    parser.add_argument("--img_path", help="path of the dataset for testing (image)", type=str)
    parser.add_argument("--seg_path", help="path of the dataset for testing (label)", type=str)
    parser.add_argument("--modalities_to_test", help="The modalities for testing (the index of the modalities for that input),using '_' to separate if 0_1_2 for BRATS it would mean test on FLAIR, T1, T1c", type=str)
    # NOTE: string default '0' is parsed through type=int by argparse, so the
    # effective default is the int 0.
    parser.add_argument("--test_all_combinations", help="0 or 1 1 if testing on all possible modality combinations", type=int, default='0')
    parser.add_argument("--channel_m", help="channel_map instruction can be seen in the readme", type=str, default='0')
    args = parser.parse_args()
    # "0_1_2"-style string -> list of channel indices.
    channel_m=[int(x) for x in args.channel_m.split("_")]
    test_all_combinations=bool(args.test_all_combinations)
    cuda_id = "cuda:" + str(args.device_id)
    device = torch.device(cuda_id)
    torch.cuda.set_device(cuda_id)
    # NOTE(review): populated nowhere in the visible code — presumably filled
    # per-combination further down; confirm against the full file before removing.
    results = {}
    # Model location/type/training modalities all come from config.Test_config.
    Test_config=config.Test_config()
    print("*************** TESTING NET " + str(Test_config.model_file_path) + " **************")
    model = utils.create_net(Test_config.model_file_path,Test_config.model_net_type,Test_config.model_modalities_trained_on, device, cuda_id)
    print("************** TESTING DATASET FROM " + args.img_path + " ***************")
    dataloader = create_dataset_for_test(args.img_path,args.seg_path)
    # Either enumerate every subset of the requested modalities, or test the
    # single combination given on the command line.
    if test_all_combinations:
        modalities = utils.create_modality_combinations([int(x) for x in args.modalities_to_test.split("_")])
    else:
        modalities = [[int(x) for x in args.modalities_to_test.split("_")]]
    for combination in modalities:
        print(combination)
        dsc_scores, seg_pix_vols, gt_pix_vols = test(model,
                                                     dataloader,
                                                     channel_m,
                                                     combination,
                                                     Test_config.model_net_type,
                                                     Test_config.model_modalities_trained_on,
                                                     Test_config.model_channel_map,
                                                     device,
                                                     save_outputs= Test_config.save_segs,
                                                     save_path=Test_config.save_path,
                                                     )