-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathopt_nf_diff_ms.py
More file actions
197 lines (162 loc) · 8.84 KB
/
opt_nf_diff_ms.py
File metadata and controls
197 lines (162 loc) · 8.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as nnf
import argparse
import os
from torch import optim
import sys
import torch
from utils.csv_logger import infer_csv
from data import datasets
from models.dnvf.networks.siren import Siren, VecInt
from utils.train_utils import get_mgrid
import utils.losses as losses
import utils.utils as utils
from utils.train_utils import (adjust_learning_rate,
save_checkpoint)
# Parse and validate command-line arguments (module level: `main` reads `args`).
parser = argparse.ArgumentParser(
    description='Optimize a SIREN velocity field per test pair, coarse to fine, '
                'and log DSC / Jacobian-determinant metrics.')
parser.add_argument('--dataset', type=str, default='OASIS',
                    help='Test dataset name (OASIS, IXI, MindBoggle, LPBA or AbdomenCT)')
parser.add_argument('--save_dir', type=str, default='./infer_results/', help='The directory to save infer results')
parser.add_argument('--weight_jdet', type=int, default=50,
                    help='Weight of the negative-Jacobian-determinant loss term')
parser.add_argument('--val_iter', type=int, default=5,
                    help='Evaluate DSC / Jacobian every N epochs (and at the last epoch of each scale)')
parser.add_argument('--lr', type=float, default=1e-4, help='Adam learning rate')
parser.add_argument('--hidden_layers', type=int, default=5, help='Number of SIREN hidden layers')
parser.add_argument('--hidden_features', type=int, default=512, help='Width of each SIREN hidden layer')
parser.add_argument(
    "--downsample_factors",
    nargs="+",  # one or more values
    type=int,
    required=True,
    help="list of the downsample factors. coarsest to finest"
)
parser.add_argument(
    "--optimization_epochs",
    nargs="+",  # one or more values, paired positionally with --downsample_factors
    type=int,
    required=True,
    help="list of the optimization epochs, one per downsample factor. coarsest to finest"
)
args = parser.parse_args()
# main() zips these two lists together; a length mismatch would silently drop scales.
if len(args.downsample_factors) != len(args.optimization_epochs):
    parser.error('--downsample_factors and --optimization_epochs must have the same length '
                 '(got {} and {})'.format(len(args.downsample_factors), len(args.optimization_epochs)))
# The grid shape is img_size // factor, so factors must be >= 1.
if any(f < 1 for f in args.downsample_factors):
    parser.error('--downsample_factors must be positive integers')
# Metrics are only computed inside the epoch loop, so at least one epoch must run.
if not any(e > 0 for e in args.optimization_epochs):
    parser.error('--optimization_epochs needs at least one positive value')
# Used as a modulus in the validation schedule.
if args.val_iter < 1:
    parser.error('--val_iter must be a positive integer')
def _get_dice_fn(dataset):
    """Return the ``utils`` Dice function used to score segmentations of *dataset*.

    Raises:
        ValueError: if *dataset* has no Dice function wired up in this script.
    """
    # Resolved lazily with getattr so only the selected metric must exist in utils.
    dice_fn_names = {
        "OASIS": "dice_OASIS",
        "IXI": "dice_IXI",
        "MindBoggle": "dice_MindBoggle",
        "LPBA": "dice_LPBA",
        "AbdomenCT": "dice_AbdomenCT",
    }
    if dataset not in dice_fn_names:
        raise ValueError("Unsupported dataset {!r}; expected one of {}".format(
            dataset, sorted(dice_fn_names)))
    return getattr(utils, dice_fn_names[dataset])


def _normalized_grid(img_size, downsample_factor):
    """Build the SIREN input coordinate grid for one scale and move it to the GPU.

    The grid has shape ``img_size // downsample_factor`` with the coordinate axis
    moved last. Each coordinate ``c`` on axis ``i`` is mapped to
    ``2 * (c / (n_i - 1) - 0.5)``. Assuming ``get_mgrid`` yields voxel indices
    ``0 .. n_i - 1``, this is [-1, 1]; TODO confirm.
    """
    siren_input_shape = np.array(img_size) // downsample_factor
    # Presumably (1, 3, D, H, W) -> (1, D, H, W, 3); verify against get_mgrid.
    siren_grid = get_mgrid(siren_input_shape).permute(0, 2, 3, 4, 1)
    norm_grid = siren_grid.clone().float()
    for axis in range(3):
        norm_grid[..., axis] = 2 * (siren_grid[..., axis] / (siren_input_shape[axis] - 1) - 0.5)
    return norm_grid.cuda()


def _velocity_to_voxels(vec, img_size):
    """Rescale channel ``i`` of *vec* (N, 3, ...) by ``(img_size[i] - 1) / 2`` in place; return it."""
    for axis in range(3):
        vec[:, axis] = vec[:, axis] / 2 * (img_size[axis] - 1)
    return vec


def main():
    """Fit a SIREN velocity field per test pair, coarse to fine, and log metrics to CSV.

    For each pair from the test loader, the script:
    - Builds a fresh ``Siren`` plus an Adam optimizer.
    - Optimizes them over every (downsample factor, epoch count) scale, using
      NCC + L2 gradient + weighted negative-Jacobian losses on the flow
      produced by ``VecInt``.
    - Validates every ``val_iter`` epochs and at the last epoch of each scale.

    The final DSC (before and after warping) and the fraction of non-positive
    Jacobian determinants are appended to a CSV in ``save_dir``. ``avg`` and
    ``std`` rows are written at the end.

    Reads configuration from the module-level ``args``.

    Raises:
        ValueError: on mismatched scale/epoch list lengths, if no scale has a
            positive epoch count, on a non-positive ``val_iter``, or on an
            unsupported dataset.
    """
    downsample_factors = args.downsample_factors
    optimization_epochs = args.optimization_epochs
    # zip() below would otherwise silently drop unmatched scales.
    if len(downsample_factors) != len(optimization_epochs):
        raise ValueError('downsample_factors and optimization_epochs must have the same length '
                         '(got {} and {})'.format(len(downsample_factors), len(optimization_epochs)))
    # Metrics are only bound inside the epoch loop, so some epoch must run.
    if not any(e > 0 for e in optimization_epochs):
        raise ValueError('optimization_epochs needs at least one positive value')
    val_iter = args.val_iter
    if val_iter < 1:
        raise ValueError('val_iter must be a positive integer')
    # Fail fast instead of a NameError after the first optimization pass.
    dice_fn = _get_dice_fn(args.dataset)

    # Placeholder order now matches the labels in the template.
    csv_name = 'diff_nf_jDet{}_numlayer_{}_numfeature_{}_ms_{}_epoch_{}_{}'.format(
        args.weight_jdet,
        args.hidden_layers,
        args.hidden_features,
        "_".join(str(d) for d in downsample_factors),
        "_".join(str(e) for e in optimization_epochs),
        args.dataset) + '.csv'
    weight_jdet = args.weight_jdet
    lr = args.lr
    hidden_layers = args.hidden_layers
    hidden_features = args.hidden_features
    save_dir = args.save_dir + args.dataset + '/' + 'nf_{}_{}'.format(args.hidden_layers, args.hidden_features) + '/'
    os.makedirs(save_dir, exist_ok=True)

    # Dataset
    test_dataset, img_size = datasets.load_test_dataset(args.dataset)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False, drop_last=False)

    # Spatial transformers: nearest for label maps, bilinear for intensities.
    reg_model = utils.register_model(img_size, 'nearest')
    reg_model.cuda()
    reg_model_img = utils.register_model(img_size, 'bilinear')
    reg_model_img.cuda()

    # Loss functions
    criterion_ncc = losses.NCC_vxm()
    criterion_reg = losses.Grad3d(penalty='l2')
    criterion_jdet = losses.NegJDet(img_size).cuda()

    # Running metrics across pairs
    eval_dsc_def = utils.AverageMeter()
    eval_dsc_raw = utils.AverageMeter()
    eval_det = utils.AverageMeter()

    # One normalized input grid per scale, shared by all pairs.
    norm_siren_grids = []
    for downsample_factor in downsample_factors:
        norm_siren_grid = _normalized_grid(img_size, downsample_factor)
        print("flow field input shape: ", norm_siren_grid.shape)
        norm_siren_grids.append(norm_siren_grid)

    for data_idx, data in enumerate(test_loader):
        x, y, x_seg, y_seg = [t.cuda() for t in data]
        print('Processing pair: {}'.format(data_idx))

        # Unwarped baseline Dice does not depend on the flow: compute once per pair.
        with torch.no_grad():
            dsc_raw = dice_fn(x_seg.long(), y_seg.long())

        # A fresh model/optimizer per pair; the optimizer is shared across scales.
        vec_field = Siren(in_features=3, hidden_features=hidden_features, hidden_layers=hidden_layers, out_features=3, outermost_linear=True)
        vec_field = vec_field.cuda()
        vec_field.train()  # mode is never changed below, so set it once
        adam = optim.Adam(vec_field.parameters(), lr=lr, weight_decay=0, amsgrad=True)
        vec_int = VecInt(img_size).cuda()

        for downsample_factor, optimization_epoch, norm_siren_grid in zip(downsample_factors, optimization_epochs, norm_siren_grids):
            print('Current scale: {}x downsample, {} optimization epochs'.format(downsample_factor, optimization_epoch))
            for epoch in range(optimization_epoch):
                # The LR schedule restarts at each scale.
                adjust_learning_rate(adam, epoch, optimization_epoch, lr)
                coarse_vec = vec_field(norm_siren_grid)
                # Channels-last SIREN output -> channels-first, upsampled to full resolution.
                full_vec = nnf.interpolate(coarse_vec.permute(0, 4, 1, 2, 3), size=img_size, mode='trilinear', align_corners=True)
                full_vec = _velocity_to_voxels(full_vec, img_size)
                output_full_flow = vec_int(full_vec)
                warped_x = reg_model_img([x, output_full_flow])

                loss_ncc = criterion_ncc(warped_x, y)
                loss_reg = criterion_reg(output_full_flow)
                loss_jdet = criterion_jdet(output_full_flow)
                loss = loss_ncc + loss_reg + loss_jdet * weight_jdet

                adam.zero_grad()
                loss.backward()
                adam.step()

                current_lr = adam.param_groups[0]['lr']
                sys.stdout.write('\r Pair {}: Epoch [{}/{}] - loss {:.4f}, Img Sim: {:.6f}, Reg: {:.6f}, JDet: {:.6f}, lr: {:.6f}'.format(
                    data_idx, epoch, optimization_epoch, loss.item(), loss_ncc.item(), loss_reg.item(), loss_jdet.item(), current_lr))
                sys.stdout.flush()

                if epoch % val_iter == 0 or epoch == optimization_epoch - 1:
                    with torch.no_grad():
                        def_out = reg_model([x_seg.float(), output_full_flow])
                        dsc_trans = dice_fn(def_out.long(), y_seg.long())
                        # Fraction of voxels whose Jacobian determinant is non-positive.
                        jac_det = utils.jacobian_determinant_vxm(output_full_flow.detach().cpu().numpy()[0, :, :, :, :])
                        Jdet = np.sum(jac_det <= 0) / y[0, 0].numel()
                        print('\n Val: Pair {}: Epoch [{}/{}] - DSC: {:.6f}, Jdet: {:.8f}, lr: {:.6f}\n'.format(
                            data_idx, epoch, optimization_epoch, dsc_trans, Jdet, current_lr))
                        torch.cuda.empty_cache()

        # Record the metrics from the last validation of this pair.
        eval_det.update(Jdet, x.size(0))
        eval_dsc_raw.update(dsc_raw, x.size(0))
        eval_dsc_def.update(dsc_trans, x.size(0))
        infer_csv(save_dir, csv_name, data_idx, dsc_raw, dsc_trans, Jdet)
        print()

    infer_csv(save_dir, csv_name, 'avg', eval_dsc_raw.avg, eval_dsc_def.avg, eval_det.avg)
    infer_csv(save_dir, csv_name, 'std', eval_dsc_raw.std, eval_dsc_def.std, eval_det.std)
if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()