-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path: testCNN.lua
More file actions
126 lines (119 loc) · 4.13 KB
/
testCNN.lua
File metadata and controls
126 lines (119 loc) · 4.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
--
-- Copyright (c) 2014, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Test-time setup. testLogger, matio, and batchNum are deliberately
-- file-level (global) values: test() and testBatchCNN() below read them
-- by name, as may testBatchRCNN defined elsewhere.
testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
-- Returns a closure that resets the test loader and yields one batch per
-- call. NOTE(review): appears unused in this chunk -- confirm callers
-- elsewhere before removing.
local testDataIterator = function()
testLoader:reset()
return function() return testLoader:get_batch(false) end
end
matio = require 'matio'  -- .mat writer used by testBatchCNN to save predictions
-- Number of test batches; nTest is set in 1_data.lua.
batchNum = math.ceil(nTest/opt.batchSize)
if opt.testNum~=-1 then -- opt.testNum caps the run to only a few batches
batchNum=opt.testNum
end
-- Accumulators shared (as upvalues/globals) with test() and testBatchCNN.
local top1_center, loss
local timer = torch.Timer()
--- Run one evaluation pass over the test set.
-- Dispatches each not-yet-saved batch to a donkey worker thread, which
-- loads the data on the CPU; the result is handed to testBatchCNN (or
-- testBatchRCNN for RCNN-type nets) on the main thread. Accumulated
-- loss / top-1 accuracy are normalized and written to testLogger.
-- Reads globals: epoch, nTest, testLoader, donkeys, model, opt, batchNum.
function test()
print('==> doing epoch on validation data:')
print("==> online epoch # " .. epoch)
cutorch.synchronize()
timer:reset()
-- set the dropouts to evaluate mode
-- model:training()
model:evaluate()
-- Reset accumulators; updated per-batch inside testBatchCNN.
top1_center = 0
top1_count = 0
loss = 0
-- Saved-prediction filename suffix depends on whether f7 features are dumped.
local saveSuf='.mat'
if opt.testSave=='f7' then
saveSuf='_f7.mat'
end
local i;
-- opt.testReverse=-1
-- Iterate batches forward or backward depending on the sign of
-- opt.testReverse. NOTE(review): a step of math.abs(opt.testReverse)
-- errors out if opt.testReverse is 0 -- confirm it is always +/-1.
for ii=1,batchNum,math.abs(opt.testReverse) do -- nTest is set in 1_data.lua
if opt.testReverse<0 then
i=batchNum+1-ii
else
i=ii;
end
-- Skip batches whose result file already exists (resumable evaluation).
if not paths.filep(opt.save..'test_'..i..saveSuf) then
-- print('do:'..i)
local indexStart = (i-1) * opt.batchSize + 1
local indexEnd = math.min(indexStart + opt.batchSize - 1, nTest)
-- The job closure runs on a worker thread; its upvalues
-- (indexStart, indexEnd, i) are captured at addjob time, so the
-- loop advancing afterwards is safe.
if string.match(opt.netType, "RCNN") then
donkeys:addjob(
function()
local inputs, labels = testLoader:get(indexStart, indexEnd)
return inputs[1],inputs[2], labels, i
end,
testBatchRCNN
)
else
donkeys:addjob(
function()
local inputs, labels = testLoader:get(indexStart, indexEnd)
return inputs, labels, i
end,
testBatchCNN
)
end
end
end
-- Wait for all workers and pending GPU work before reading accumulators.
donkeys:synchronize()
cutorch.synchronize()
-- NOTE(review): if every batch was skipped (or opt.testDisplay~=1),
-- top1_count stays 0 and this divides by zero -- confirm acceptable.
top1_center = top1_center * 100 / top1_count
-- loss accumulated per-sample in testBatchCNN; normalize by set size.
loss = loss / nTest
testLogger:add{
['% top1 accuracy (test set) (center crop)'] = top1_center,
['avg loss (test set)'] = loss
}
print(string.format('Epoch: [%d][TESTING SUMMARY] Total Time(s): %.2f \t'
.. 'average loss (per batch): %.2f \t '
.. 'accuracy [Center](%%):\t top-1 %.2f\t ',
epoch, timer:time().real, loss, top1_center))
print('\n')
end -- of test()
-----------------------------------------------------------------------------
-- Preallocated GPU staging buffers, reused across batches to avoid
-- reallocating device memory on every call. They are upvalues of the
-- testBatch* callbacks below.
local inputs = torch.CudaTensor()
local inputs_roi = torch.CudaTensor()  -- presumably ROI inputs for testBatchRCNN (not used in this chunk) -- verify
local labels = torch.CudaTensor()
--- Per-batch evaluation callback for plain (non-RCNN) nets, run on the
-- main thread once a donkey worker returns a CPU batch. Copies the batch
-- to the GPU, runs a forward pass, optionally accumulates loss and top-1
-- accuracy into the shared accumulators (loss, top1_center, top1_count),
-- and optionally saves predictions (and f7 features) to a .mat file.
-- @param inputsCPU   CPU tensor with this batch's inputs
-- @param labelsCPU   CPU tensor with this batch's ground-truth labels
-- @param batchNumber 1-based batch index (used for logging and filenames)
function testBatchCNN(inputsCPU, labelsCPU, batchNumber)
   -- Stage the batch on the GPU, reusing the preallocated buffers.
   inputs:resize(inputsCPU:size()):copy(inputsCPU)
   labels:resize(labelsCPU:size()):copy(labelsCPU)
   local outputs = model:forward(inputs)
   cutorch.synchronize()
   -- Single device->host transfer; reused everywhere below. The original
   -- called outputs:float() up to three times, copying the whole batch
   -- from the GPU redundantly.
   local pred = outputs:float()
   if opt.testDisplay==1 then
      local err = criterion:forward(outputs, labels)
      -- criterion returns a per-sample average; re-weight by batch size
      -- so test() can normalize by nTest.
      loss = loss + err*outputs:size(1)
      local top1 = 0
      local pred_sorted
      if outputs:size(2)==1 then
         -- Single-output head: threshold at 0.5.
         -- NOTE(review): gt() yields 0/1 -- confirm labels use the same
         -- encoding, or accuracy is silently wrong here.
         pred_sorted = pred:gt(0.5)
      else
         _,pred_sorted = pred:sort(2, true) -- descending; column 1 = argmax
      end
      for i=1,pred:size(1) do
         local g = labelsCPU[i]
         if pred_sorted[i][1] == g then
            top1_center = top1_center + 1
            top1 = top1 + 1
         end
      end
      top1 = top1 * 100 / pred:size(1);
      top1_count = top1_count + pred:size(1)
      print(('Epoch: [%d][%d/%d]\t Err %.4f Top1-%%: %.2f'):format(epoch, batchNumber, batchNum, err, top1))
   else
      print(('Epoch: [%d][%d/%d]\t' ):format(epoch, batchNumber, batchNum))
   end
   -- Optionally dump predictions (and, for 'f7', intermediate features
   -- from the module selected by opt.testSaveId) for offline analysis.
   if opt.testSave=='f8' then
      matio.save(opt.save..'test_'..batchNumber..'.mat',{pred=pred})
   elseif opt.testSave=='f7' then
      matio.save(opt.save..'test_'..batchNumber..'_f7.mat',{pred=pred,f7=model.modules[2].modules[opt.testSaveId].output:float()})
   end
end