-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathBatchIterator.lua
134 lines (106 loc) · 4.33 KB
/
BatchIterator.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
require 'image'
require 'utils'
-- Declare the BatchIterator class through torch.class; the methods defined
-- below are attached to the table returned here.
local BatchIterator = torch.class('BatchIterator')
--- Builds an iterator over a training set and a test set.
-- @param config    optional table; reads `batch_size` (default 128),
--                  `pixel_means` (default {0, 0, 0}) and `mr` (stored as-is)
-- @param train_set sequence of entries served in shuffled order (nil = empty)
-- @param test_set  sequence of entries served in sequential order (nil = empty)
function BatchIterator:__init(config, train_set, test_set)
  config = config or {}
  self.batch_size = config.batch_size or 128
  self.pixel_means = config.pixel_means or {0, 0, 0}
  self.mr = config.mr

  -- `i` is the cursor that nextEntry/currentName read and advance; `id` is
  -- kept because the original constructor exposed it.
  self.train = { data = train_set or {}, id = 1, i = 1 }
  self.test = { data = test_set or {}, id = 1, i = 1 }

  -- Training entries are visited in a random permutation.
  if #self.train.data > 0 then
    self.train.order = torch.randperm(#self.train.data)
  else
    self.train.order = torch.Tensor(0)
  end

  -- Test entries are visited in their given order. The empty case is guarded
  -- the same way as the training set so that torch.range is never asked for
  -- the range 1..0.
  if #self.test.data > 0 then
    self.test.order = torch.range(1, #self.test.data)
  else
    self.test.order = torch.Tensor(0)
  end

  self.epoch = 0
end
--- Changes the number of entries returned per batch.
-- @param batch_size new batch size; nil (or false) selects the default of 128
function BatchIterator:setBatchSize(batch_size)
  if not batch_size then
    batch_size = 128
  end
  self.batch_size = batch_size
end
--- Returns the next entry of a set and advances that set's cursor.
-- When the cursor runs past the end of the data it restarts at 1,
-- `self.epoch` is incremented (for either set), and the "train" order is
-- reshuffled with a fresh random permutation.
-- @param set "train" or "test"
-- @return the entry stored at the current position of `self[set].order`
function BatchIterator:nextEntry(set)
  local state = self[set]
  local count = #state.data
  if count == 0 then
    error(string.format("BatchIterator: the '%s' set is empty", tostring(set)), 2)
  end

  local i = state.i or 1
  if i > count then
    if set == "train" then
      state.order = torch.randperm(count)
    end
    i = 1
    self.epoch = self.epoch + 1
  end

  local index = state.order[i]
  -- Store the cursor derived from the (possibly reset) position. The previous
  -- code incremented the stale stored value, so after the first wrap every
  -- call reshuffled, bumped the epoch and returned order[1].
  state.i = i + 1
  return state.data[index]
end
--- Returns the `name` field of the entry most recently returned by nextEntry.
-- @param set "train" or "test"
-- @return the entry's `name`, or nil when no entry has been served yet
function BatchIterator:currentName(set)
  local state = self[set]
  local i = state.i
  -- The cursor points one past the last served entry; a missing cursor or a
  -- cursor of 1 means nothing has been served, so there is no current entry.
  if i == nil or i <= 1 then
    return nil
  end
  local index = state.order[i - 1]
  return state.data[index].name
end
-- Keeps every second row and column (indices 1, 3, 5, ...) of a 3-D tensor.
local function subsample_by_two(img)
  local rows = torch.range(1, img:size(2), 2):long()
  local cols = torch.range(1, img:size(3), 2):long()
  return img:index(2, rows):index(3, cols)
end

-- Concatenates a list of equally sized 3-D tensors along dimension 1 and
-- reshapes the result to batch_size x C x H x W.
local function stack_images(images, batch_size)
  local first = images[1]
  local ch, h, w = first:size(1), first:size(2), first:size(3)
  return torch.cat(images, 1):view(batch_size, ch, h, w)
end

--- Loads and preprocesses the next `self.batch_size` entries of a set.
-- For the "train" set, entries whose `input_file`, `input_valid` or
-- `output_file` is missing are skipped, and the output and valid images are
-- loaded as well.
-- @param set    "train" or "test"
-- @param config optional table; when `config.verbose` is truthy, value ranges
--               and sizes of the loaded images are printed
-- @return table with `input` (batch_size x C x H x W tensor) plus `output`
--         and `valid`, which are tensors of the same layout for "train" and
--         empty tables otherwise
function BatchIterator:nextBatch(set, config)
  config = config or {}
  local verbose = config.verbose
  local is_train = (set == "train")

  -- Means whose largest component is below 1 are subtracted as given;
  -- otherwise they are divided by 255 first (presumably they are on a 0-255
  -- scale while image.load yields values in [0, 1] -- TODO confirm).
  local means_in_unit_range = math.max(unpack(self.pixel_means)) < 1

  local batch = { input = {}, output = {}, valid = {} }

  for _ = 1, self.batch_size do
    local entry = self:nextEntry(set)

    if is_train then
      -- NOTE(review): this loop never terminates if no training entry has all
      -- three files on disk. file_exists is presumably provided by 'utils'.
      while not (file_exists(entry.input_file)
                 and file_exists(entry.input_valid)
                 and file_exists(entry.output_file)) do
        entry = self:nextEntry(set)
      end

      local output = image.load(entry.output_file)
      local valid = image.load(entry.input_valid)

      -- define your data process here
      -- (x - 0.5) * 2, then keep every second row/column
      output = subsample_by_two(output:add(-0.5):mul(2))
      valid = subsample_by_two(valid)
      -- end

      batch.output[#batch.output + 1] = output
      batch.valid[#batch.valid + 1] = valid

      if verbose then
        print(string.format("output max: %f, min: %f, size: %d %d", output:max(), output:min(), output:size(2), output:size(3)))
        print(string.format("valid max: %f, min: %f, size: %d %d", valid:max(), valid:min(), valid:size(2), valid:size(3)))
      end
    end

    local input = image.load(entry.input_file)

    -- process your input here
    -- Keep the first three channels only and subtract the per-channel mean.
    input = input[{{1, 3}, {}, {}}]
    for ch = 1, 3 do
      local mean = self.pixel_means[ch]
      if not means_in_unit_range then
        mean = mean / 255
      end
      input[{ch, {}, {}}]:add(-mean)
    end
    input = subsample_by_two(input)
    -- end

    batch.input[#batch.input + 1] = input

    if verbose then
      print(string.format("input max: %f, min: %f, size: %d %d", input:max(), input:min(), input:size(2), input:size(3)))
    end
  end

  -- format img
  batch.input = stack_images(batch.input, self.batch_size)
  if is_train then
    batch.output = stack_images(batch.output, self.batch_size)
    batch.valid = stack_images(batch.valid, self.batch_size)
  end
  return batch
end