-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathevaluate.py
More file actions
123 lines (106 loc) · 4.46 KB
/
evaluate.py
File metadata and controls
123 lines (106 loc) · 4.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import numpy as np
import sys
import getopt
import os
import shutil
import matplotlib.pyplot as plt
import datetime
from Network import TOFlow
import warnings
warnings.filterwarnings("ignore", module="matplotlib.pyplot")
# ------------------------------
# Use the non-interactive backend so the script runs on headless machines.
plt.switch_backend('agg')

# Command-line configuration (filled in by the option parser below).
task = ''           # one of: interp, denoise/denoising, sr/super-resolution
dataset_dir = ''    # root directory of the (possibly degraded) Vimeo-90K dataset
pathlistfile = ''   # text file listing the sample sub-paths to evaluate
model_path = ''     # path of the trained TOFlow model weights
gpuID = None        # GPU index, or None to run on the CPU

# NOTE: this script is evaluate.py — the usage line previously said train.py.
HELP_TEXT = """pytoflow version 1.0
usage: python3 evaluate.py [[option] [value]]...
options:
--task training task, like interp, denoising, super-resolution
valid values:[interp, denoise, denoising, sr, super-resolution]
--dataDir the directory of the input image dataset(Vimeo-90K, Vimeo-90K with noise, blurred Vimeo-90K)
--pathlist the text file records which are the images for train.
--model the path of the model used.
--gpuID the No. of the GPU you want to use.
--help get help."""

# Print help when requested, or when no arguments were given at all
# (previously `sys.argv[1]` raised IndexError on an empty command line).
if len(sys.argv) < 2 or sys.argv[1] in ['-h', '--help']:
    print(HELP_TEXT)
    exit(0)

# Parse with an explicit option table. The original rebuilt the long-option
# names from sys.argv[1::2] itself, which mis-parsed any malformed command
# line instead of reporting a clear getopt error.
for strOption, strArgument in getopt.getopt(
        sys.argv[1:], 'h',
        ['task=', 'dataDir=', 'pathlist=', 'model=', 'gpuID=', 'help'])[0]:
    if strOption == '--task':        # evaluation task
        task = strArgument
    elif strOption == '--dataDir':   # dataset root directory
        dataset_dir = strArgument
    elif strOption == '--pathlist':  # pathlist index file
        pathlistfile = strArgument
    elif strOption == '--model':     # trained model weights
        model_path = strArgument
    elif strOption == '--gpuID':     # GPU index
        gpuID = int(strArgument)
    elif strOption in ('-h', '--help'):
        print(HELP_TEXT)
        exit(0)

# Validate the required options (error messages unchanged).
if task == '':
    raise ValueError('Missing [--task].\nPlease enter the training task.')
elif task not in ['interp', 'denoise', 'denoising', 'sr', 'super-resolution']:
    raise ValueError('Invalid [--task].\nOnly support: [interp, denoise/denoising, sr/super-resolution]')
if dataset_dir == '':
    raise ValueError('Missing [--dataDir].\nPlease provide the directory of the dataset. (Vimeo-90K)')
if pathlistfile == '':
    raise ValueError('Missing [--pathlist].\nPlease provide the pathlist index file for test.')
if model_path == '':
    raise ValueError('Missing [--model model_path].\nPlease provide the path of the toflow model.')

# Select the compute device.
if gpuID is None:
    cuda_flag = False
else:
    cuda_flag = True
    torch.cuda.set_device(gpuID)
# --------------------------------------------------------------
def mkdir_if_not_exist(path):
    """Create directory *path* if it does not already exist.

    Uses ``os.makedirs(..., exist_ok=True)`` instead of the original
    check-then-``os.mkdir`` pair: it is race-free (no TOCTOU window between
    the existence check and the creation) and also creates any missing
    intermediate directories.
    """
    os.makedirs(path, exist_ok=True)
def vimeo_evaluate(input_dir, out_img_dir, test_codelistfile, task='', cuda_flag=True):
    """Run the trained TOFlow network over every sample in the pathlist.

    Parameters:
        input_dir: root of the (possibly degraded) Vimeo-90K dataset; each
            code in the pathlist names a 'video/sep' sub-directory below it.
        out_img_dir: directory the predictions are written to, one
            'video/sep/out.png' per processed sample.
        test_codelistfile: text file with one 'video/sep' code per line.
        task: one of [interp, denoise/denoising, sr/super-resolution].
        cuda_flag: run the network on the current CUDA device when True.

    Raises:
        ValueError: if *task* is not one of the supported values.

    NOTE(review): the model weights are loaded from the module-level
    ``model_path`` global, not from a parameter — confirm before reusing
    this function outside this script.
    """
    mkdir_if_not_exist(out_img_dir)

    # Build the network and load the trained weights.
    net = TOFlow(256, 448, cuda_flag=cuda_flag, task=task)
    net.load_state_dict(torch.load(model_path))
    if cuda_flag:
        net.cuda().eval()
    else:
        net.eval()

    # Context manager guarantees the pathlist file is closed.
    with open(test_codelistfile) as fp:
        test_img_list = fp.read().splitlines()

    # Interpolation uses only the two bracketing frames (im1/im3); the
    # restoration tasks consume all seven frames (im0001..im0007).
    if task == 'interp':
        process_index = [1, 3]
        str_format = 'im%d.png'
    elif task in ['denoise', 'denoising', 'sr', 'super-resolution']:
        # 'interp' removed from this list: it was dead (caught above).
        process_index = [1, 2, 3, 4, 5, 6, 7]
        str_format = 'im%04d.png'
    else:
        raise ValueError('Invalid [--task].\nOnly support: [interp, denoise/denoising, sr/super-resolution]')

    total_count = len(test_img_list)
    count = 0
    pre = datetime.datetime.now()
    for code in test_img_list:
        count += 1
        video = code.split('/')[0]
        sep = code.split('/')[1]
        mkdir_if_not_exist(os.path.join(out_img_dir, video))
        mkdir_if_not_exist(os.path.join(out_img_dir, video, sep))

        # Read the input frames and arrange them as (N, C, H, W).
        input_frames = []
        for i in process_index:
            input_frames.append(plt.imread(os.path.join(input_dir, code, str_format % i)))
        input_frames = np.transpose(np.array(input_frames), (0, 3, 1, 2))

        input_frames = torch.from_numpy(input_frames)
        if cuda_flag:
            input_frames = input_frames.cuda()
        # Prepend the batch dimension: (1, N, C, H, W).
        input_frames = input_frames.view(1, input_frames.size(0), input_frames.size(1),
                                         input_frames.size(2), input_frames.size(3))

        # Inference only: no_grad avoids building the autograd graph and
        # saves memory; the output is already detached before saving.
        with torch.no_grad():
            predicted_img = net(input_frames)[0, :, :, :]
        plt.imsave(os.path.join(out_img_dir, video, sep, 'out.png'),
                   predicted_img.permute(1, 2, 0).cpu().detach().numpy())

        # Rough progress estimate: average wall-clock seconds per sample.
        cur = datetime.datetime.now()
        processing_time = (cur - pre).seconds / count
        print('%.2fs per frame.\t%.2fs left.' % (processing_time, processing_time * (total_count - count)))
vimeo_evaluate(dataset_dir, './evaluate', pathlistfile, task=task, cuda_flag=cuda_flag)