-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path: eval.py
More file actions
137 lines (100 loc) · 4.82 KB
/
eval.py
File metadata and controls
137 lines (100 loc) · 4.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import time
import cv2
import numpy as np
from tqdm import tqdm
import datetime
from argparse import ArgumentParser
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from utils import PSNR, SSIM, time2file_name, AverageMeter
from data.dotadataset import make_dataset
from network.BMNet import BMNet
import einops
# ---------------------------------------------------------------------------
# Command-line interface and one-time script setup (runs at import time).
# ---------------------------------------------------------------------------
parser = ArgumentParser(description='BMI')
parser.add_argument('--gpu', type=str, default='3', help='gpu index')
parser.add_argument('--data_path', type=str, default="./samples", help='path to test set')
parser.add_argument('--model_path', type=str,
                    default="/data2/wangzhibin/_trainning_ckpt/_RSSCI_v2.0/baseline-16/model_best.pth",
                    help='trained or pre-trained model directory')
parser.add_argument('--results_path', type=str, default='./results', help='results for reconstructed images')
parser.add_argument('--image_size', type=int, nargs='+', default=[512, 512], help='image size')
parser.add_argument('--resize_size', type=int, nargs='+', default=None, help='image size')
parser.add_argument("--cs_ratio", type=int, nargs='+', default=[4, 4], help="compression ratio")
parser.add_argument("--num_show", type=int, default=1, help="number of images to show")
parser.add_argument('--seed', type=int, default=42, help='random seed')
args = parser.parse_args()

# Pin the visible GPU before torch initializes CUDA.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed both CPU and CUDA RNGs for reproducibility.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Reconstructed images go into a timestamped sub-directory of --results_path.
date_time = time2file_name(str(datetime.datetime.now()))
results_dir = os.path.join(args.results_path, date_time)
# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs(results_dir, exist_ok=True)

# Spatial compression factors along height and width respectively.
cr1, cr2 = args.cs_ratio
def eval():
    """Evaluate BMNet reconstruction on the test set and report metrics.

    Loads the checkpoint given by ``--model_path`` (which also carries the
    learned sampling mask under the ``'mask'`` key), iterates the test
    dataloader with batch size 1, reconstructs each Y-channel image from its
    masked compressive measurement, and accumulates PSNR, SSIM and per-image
    inference time.  The first ``--num_show`` images are also written to
    ``results_dir`` as original/reconstruction JPEG pairs.
    """
    time_avg_meter = AverageMeter()
    psnr_avg_meter = AverageMeter()
    ssim_avg_meter = AverageMeter()
    show_left = args.num_show  # how many sample images remain to be dumped

    # Build the model and restore weights.
    model = BMNet(in_chans=1, num_stage=10, embed_dim=32, cs_ratio=args.cs_ratio).to(device)
    ckpt = torch.load(args.model_path, map_location='cpu')
    mask = ckpt['mask'].to(device)
    # Strip DataParallel's 'module.' prefix and drop the mask entry so the
    # remaining keys match the bare model.
    state = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items()}
    state.pop('mask', None)
    model.load_state_dict(state, strict=False)
    model.to(device)
    model.eval()

    _, test_dataset = make_dataset(args.data_path, args.data_path, cr=max(args.cs_ratio),
                                   ycrcb=True, name=True, seed=args.seed)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8)

    # torch.cuda.synchronize() raises on CPU-only runs; guard it once here.
    use_cuda = torch.cuda.is_available()
    print('#################Testing##################')
    for [img_y, img_ycrcb], name in tqdm(test_dataloader):
        bs = img_y.shape[0]
        img_test = img_y.float().to(device)
        if args.resize_size:
            input_img = transforms.Resize(args.resize_size)(img_test)
        else:
            input_img = img_test
        # Scale the learned mask and replicate it across the batch dimension.
        input_mask = mask.unsqueeze(0).expand(bs, -1, -1, -1, -1).to(device) * model.scaler
        # Split the image into cr1*cr2 spatial sub-frames, one per mask slice.
        input_img = einops.rearrange(input_img, "b c (cr1 h) (cr2 w) -> b (cr1 cr2) c h w",
                                     cr1=cr1, cr2=cr2)
        # Compressive measurement: mask-weighted sum over the sub-frame axis.
        meas = torch.sum(input_img * input_mask, dim=1, keepdim=True)
        with torch.no_grad():
            if use_cuda:
                torch.cuda.synchronize()
            st = time.time()
            out_test = model(meas, input_mask)
            if use_cuda:
                torch.cuda.synchronize()
            time_avg_meter.update(time.time() - st)
        if args.resize_size:
            out_test = transforms.Resize(args.image_size)(out_test)
        out_test = torch.clamp(out_test, 0, 1)
        psnr_avg_meter.update(PSNR(out_test, img_y, 1.))
        ssim_avg_meter.update(SSIM(out_test, img_y, 1.))
        if show_left > 0:
            # `name` comes batched from the DataLoader (a tuple/list of one
            # string at batch_size=1); take the first element so the filename
            # is a plain string rather than the repr of a tuple.
            fname = name[0] if isinstance(name, (list, tuple)) else name
            ycrcb = (img_ycrcb[0].cpu().data.numpy() * 255.).astype(np.uint8).transpose(1, 2, 0)
            orig_bgr = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR)
            # Replace the Y channel with the reconstruction, keep original chroma.
            y = (out_test[0].squeeze(0).cpu().data.numpy() * 255.).astype(np.uint8)
            recon = ycrcb.copy()
            recon[:, :, 0] = y
            recon_bgr = cv2.cvtColor(recon, cv2.COLOR_YCrCb2BGR)
            cv2.imwrite(os.path.join(results_dir, f'orig_{fname}.jpg'), orig_bgr)
            cv2.imwrite(os.path.join(results_dir, f'recon_{fname}.jpg'), recon_bgr)
            show_left -= 1
    print("test psnr: %.4f" % psnr_avg_meter.avg)
    print("test ssim: %.4f" % ssim_avg_meter.avg)
    # The quantity is per-image latency, not throughput (units are s/image).
    print("avg inference time: %.4f s/image" % time_avg_meter.avg)
# Entry-point guard: run the evaluation only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    eval()