-
Notifications
You must be signed in to change notification settings - Fork 30
/
theano_backend.py
923 lines (699 loc) · 28.5 KB
/
theano_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
import theano
from theano import tensor as T
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
from theano.tensor.signal import pool
from theano.tensor.nnet import conv3d2d
import inspect
import numpy as np
from .common import _FLOATX, _EPSILON
# INTERNAL UTILS
# Make Theano's default float dtype match this backend's _FLOATX (imported
# from .common above). Note: this mutates Theano's global configuration as a
# side effect of importing this module.
theano.config.floatX = _FLOATX
# VARIABLE MANIPULATION
def variable(value, dtype=_FLOATX, name=None):
    '''Instantiate a tensor variable.

    `value` is converted to a numpy array of `dtype` and wrapped in a
    Theano shared variable (created with strict=False).
    '''
    array = np.asarray(value, dtype=dtype)
    return theano.shared(value=array, name=name, strict=False)
def placeholder(shape=None, ndim=None, dtype=_FLOATX, name=None):
    '''Instantiate an input data placeholder variable.

    Either `shape` or `ndim` must be given. When `shape` is given, only its
    length is used (as the number of dimensions) and takes precedence over
    `ndim`. Every dimension is declared non-broadcastable.
    '''
    if shape is None and ndim is None:
        raise Exception('Specify either a shape or ndim value.')
    rank = ndim if shape is None else len(shape)
    tensor_type = T.TensorType(dtype, (False,) * rank)
    return tensor_type(name)
def shape(x):
    '''Return `x.shape`.

    Warning: the returned type differs between backends -- a Theano tensor
    here, a TF TensorShape in the TensorFlow backend.
    '''
    result = x.shape
    return result
def ndim(x):
    '''Return the number of dimensions of `x` (its `ndim` attribute).'''
    rank = x.ndim
    return rank
def eval(x):
    '''Run a graph: evaluate `x` by calling its `eval()` method.'''
    value = x.eval()
    return value
def zeros(shape, dtype=_FLOATX, name=None):
    '''Instantiate an all-zeros variable of the given `shape`.

    Built through `variable`, so the result is a shared variable of `dtype`.
    '''
    initial = np.zeros(shape)
    return variable(initial, dtype, name)
def ones(shape, dtype=_FLOATX, name=None):
    '''Instantiate an all-ones variable of the given `shape`.

    Built through `variable`, so the result is a shared variable of `dtype`.
    '''
    initial = np.ones(shape)
    return variable(initial, dtype, name)
def ones_like(x):
    '''Return a tensor of ones with the same shape as `x` (T.ones_like).'''
    filled = T.ones_like(x)
    return filled
def zeros_like(x):
    '''Return a tensor of zeros with the same shape as `x` (T.zeros_like).'''
    filled = T.zeros_like(x)
    return filled
def count_params(x):
    '''Return number of scalars in a tensor.

    The symbolic shape is evaluated and its entries multiplied.

    Return: numpy integer.
    '''
    concrete_shape = x.shape.eval()
    return np.prod(concrete_shape)
def cast(x, dtype):
    '''Cast `x` to `dtype` (T.cast).'''
    converted = T.cast(x, dtype)
    return converted
# LINEAR ALGEBRA
'''
Not wrapped here: Theano tensors overload these operators directly:
+, -, /, *, +=, -=, *=, /=
'''
def dot(x, y):
    '''Dot product of `x` and `y` (T.dot).'''
    product = T.dot(x, y)
    return product
def batch_dot(x, y, axes=None):
    '''Batch-wise dot product via T.batched_tensordot.

    With the default `axes=None`, the last axis of `x` is contracted with
    the second-to-last axis of `y` (behaves like tf.batch_matmul).
    '''
    if axes is not None:
        return T.batched_tensordot(x, y, axes=axes)
    default_axes = [(x.ndim - 1,), (y.ndim - 2,)]
    return T.batched_tensordot(x, y, axes=default_axes)
def transpose(x):
    '''Reverse the axes of `x` (T.transpose).'''
    flipped = T.transpose(x)
    return flipped
def gather(reference, indices):
    '''Index `reference` with `indices`.

    # Arguments
        reference: a tensor.
        indices: an int tensor of indices.

    # Returns
        `reference[indices]`, a tensor of the same type as `reference`.
    '''
    selected = reference[indices]
    return selected
# ELEMENT-WISE OPERATIONS
def max(x, axis=None, keepdims=False):
    '''Maximum of `x` along `axis` (all axes when None).'''
    reduced = T.max(x, axis=axis, keepdims=keepdims)
    return reduced
def min(x, axis=None, keepdims=False):
    '''Minimum of `x` along `axis` (all axes when None).'''
    reduced = T.min(x, axis=axis, keepdims=keepdims)
    return reduced
def sum(x, axis=None, keepdims=False):
    '''Sum of the values in a tensor along `axis` (all axes when None).'''
    reduced = T.sum(x, axis=axis, keepdims=keepdims)
    return reduced
def prod(x, axis=None, keepdims=False):
    '''Product of the values in a tensor along `axis` (all axes when None).'''
    reduced = T.prod(x, axis=axis, keepdims=keepdims)
    return reduced
def mean(x, axis=None, keepdims=False):
    '''Mean of `x` along `axis` (all axes when None).

    For integer inputs (any dtype whose name contains 'int'), the mean is
    computed in `_FLOATX`, presumably to get a floating-point result.
    Otherwise the output dtype is left to Theano (dtype=None).
    '''
    dtype = _FLOATX if 'int' in x.dtype else None
    return T.mean(x, axis=axis, keepdims=keepdims, dtype=dtype)
def std(x, axis=None, keepdims=False):
    '''Standard deviation of `x` along `axis` (all axes when None).'''
    deviation = T.std(x, axis=axis, keepdims=keepdims)
    return deviation
def any(x, axis=None, keepdims=False):
    '''Bitwise reduction (logical OR) of `x` along `axis`.'''
    reduced = T.any(x, axis=axis, keepdims=keepdims)
    return reduced
def argmax(x, axis=-1):
    '''Index of the maximum of `x` along `axis` (default: last axis).'''
    indices = T.argmax(x, axis=axis, keepdims=False)
    return indices
def argmin(x, axis=-1):
    '''Index of the minimum of `x` along `axis` (default: last axis).'''
    indices = T.argmin(x, axis=axis, keepdims=False)
    return indices
def square(x):
    '''Element-wise square (T.sqr).'''
    squared = T.sqr(x)
    return squared
def abs(x):
    '''Element-wise absolute value (T.abs_).'''
    magnitude = T.abs_(x)
    return magnitude
def sqrt(x):
    '''Element-wise square root; negative inputs are clipped to 0 first.'''
    non_negative = T.clip(x, 0., np.inf)
    return T.sqrt(non_negative)
def exp(x):
    '''Element-wise exponential (T.exp).'''
    result = T.exp(x)
    return result
def log(x):
    '''Element-wise natural logarithm (T.log).'''
    result = T.log(x)
    return result
def round(x):
    '''Element-wise rounding (T.round, Theano's default rounding mode).'''
    result = T.round(x)
    return result
def pow(x, a):
    '''Element-wise power `x ** a` (T.pow).'''
    result = T.pow(x, a)
    return result
def clip(x, min_value, max_value):
    '''Element-wise clip of `x` into [min_value, max_value].

    If `max_value < min_value`, the upper bound is raised to `min_value`,
    so every element ends up equal to `min_value`.
    '''
    upper = min_value if max_value < min_value else max_value
    return T.clip(x, min_value, upper)
def equal(x, y):
    '''Element-wise equality test (T.eq).'''
    result = T.eq(x, y)
    return result
def not_equal(x, y):
    '''Element-wise inequality test (T.neq).'''
    result = T.neq(x, y)
    return result
def maximum(x, y):
    '''Element-wise maximum of `x` and `y` (T.maximum).'''
    result = T.maximum(x, y)
    return result
def minimum(x, y):
    '''Element-wise minimum of `x` and `y` (T.minimum).'''
    result = T.minimum(x, y)
    return result
# SHAPE OPERATIONS
def concatenate(tensors, axis=-1):
    '''Concatenate a sequence of tensors along `axis` (default: last axis).'''
    joined = T.concatenate(tensors, axis=axis)
    return joined
def reshape(x, shape):
    '''Reshape `x` to `shape` (T.reshape).'''
    reshaped = T.reshape(x, shape)
    return reshaped
def permute_dimensions(x, pattern):
    '''Transpose dimensions.

    `pattern` is a tuple or list of dimension indices, e.g. [0, 2, 1]. It is
    converted to a tuple and handed to `x.dimshuffle`.
    '''
    return x.dimshuffle(tuple(pattern))
def repeat_elements(x, rep, axis):
    '''Repeat each element of `x` `rep` times along `axis`, like np.repeat.

    If x has shape (s1, s2, s3) and axis=1, the output
    will have shape (s1, s2 * rep, s3).
    '''
    repeated = T.repeat(x, rep, axis=axis)
    return repeated
def resize_images(X, height_factor, width_factor, dim_ordering):
    '''Upsample the images in a 4D tensor by element repetition.

    X has shape
    - [batch, channels, height, width] for 'th' dim_ordering
    - [batch, height, width, channels] for 'tf' dim_ordering.
    Height is repeated `height_factor` times and width `width_factor` times.
    Both factors should be positive integers.
    '''
    if dim_ordering == 'th':
        height_axis = 2
    elif dim_ordering == 'tf':
        height_axis = 1
    else:
        raise Exception('Invalid dim_ordering: ' + dim_ordering)
    # width always follows height directly in both layouts
    output = X
    for offset, factor in enumerate((height_factor, width_factor)):
        output = repeat_elements(output, factor, axis=height_axis + offset)
    return output
def resize_volumes(X, depth_factor, height_factor, width_factor, dim_ordering):
    '''Upsample the volumes in a 5D tensor by element repetition.

    X has shape
    - [batch, channels, depth, height, width] for 'th' dim_ordering
    - [batch, depth, height, width, channels] for 'tf' dim_ordering.
    Depth, height and width are repeated by `depth_factor`, `height_factor`
    and `width_factor` respectively. All factors should be positive integers.
    '''
    if dim_ordering == 'th':
        depth_axis = 2
    elif dim_ordering == 'tf':
        depth_axis = 1
    else:
        raise Exception('Invalid dim_ordering: ' + dim_ordering)
    # depth, height, width are contiguous axes in both layouts
    output = X
    for offset, factor in enumerate((depth_factor, height_factor, width_factor)):
        output = repeat_elements(output, factor, axis=depth_axis + offset)
    return output
def repeat(x, n):
    '''Repeat a 2D tensor along a new middle axis.

    If x has shape (samples, dim) and n=2,
    the output will have shape (samples, 2, dim).
    '''
    assert x.ndim == 2
    # insert a broadcastable axis between samples and features
    expanded = x.dimshuffle((0, 'x', 1))
    return T.extra_ops.repeat(expanded, n, axis=1)
def tile(x, n):
    '''Tile `x` `n` times (T.tile).'''
    tiled = T.tile(x, n)
    return tiled
def flatten(x):
    '''Flatten `x` (T.flatten).'''
    flat = T.flatten(x)
    return flat
def batch_flatten(x):
    '''Turn an n-D tensor into a 2D tensor, keeping the first dimension.

    The second dimension becomes the product of all remaining dimensions.
    '''
    batch_size = x.shape[0]
    return T.reshape(x, (batch_size, T.prod(x.shape) // batch_size))
def expand_dims(x, dim=-1):
    '''Add a 1-sized dimension at index "dim".

    `dim` follows numpy.expand_dims conventions. For an input with n
    dimensions the valid range is -(n + 1) .. n. A negative `dim` counts from
    the end of the *output* shape, so -1 appends a trailing axis and
    -(n + 1) prepends a leading one.
    '''
    rank = x.type.ndim
    pattern = list(range(rank))
    if dim < 0:
        # Map -1 -> n, -2 -> n - 1, ..., -(n + 1) -> 0. The former
        # `dim % n + 1` sent -(n + 1) to 1 instead of 0.
        dim += rank + 1
    pattern.insert(dim, 'x')
    return x.dimshuffle(pattern)
def squeeze(x, axis):
    '''Remove the size-1 dimension at index `axis` from `x`.

    Only `axis` is dropped; any other broadcastable (size-1) dimensions are
    kept. `axis` may be negative, counted from the end as in numpy.
    '''
    if axis < 0:
        # addbroadcast matches axes by non-negative position only
        axis += x.ndim
    # Mark the axis broadcastable so dimshuffle is allowed to drop it.
    x = T.addbroadcast(x, axis)
    kept_axes = [i for i in range(x.ndim) if i != axis]
    return x.dimshuffle(kept_axes)
def temporal_padding(x, padding=1):
    '''Pad the middle dimension of a 3D tensor
    with "padding" zeros left and right.

    Apologies for the inane API, but Theano makes this
    really hard.
    '''
    batch, steps, features = x.shape[0], x.shape[1], x.shape[2]
    padded = T.zeros((batch, steps + 2 * padding, features))
    # write x into the centre of the zero tensor
    return T.set_subtensor(padded[:, padding:steps + padding, :], x)
def spatial_2d_padding(x, padding=(1, 1), dim_ordering='th'):
    '''Symmetrically zero-pad the two spatial dimensions of a 4D tensor.

    padding[0] zeros are added on each side of the row dimension and
    padding[1] on each side of the column dimension. The spatial dimensions
    are axes 2, 3 for 'th' and axes 1, 2 for 'tf'.
    '''
    if dim_ordering == 'th':
        spatial_axes = (2, 3)
    elif dim_ordering == 'tf':
        spatial_axes = (1, 2)
    else:
        raise Exception('Invalid dim_ordering: ' + dim_ordering)
    in_shape = x.shape
    out_shape = [in_shape[axis] for axis in range(4)]
    region = [slice(None)] * 4
    for i, axis in enumerate(spatial_axes):
        pad = padding[i]
        out_shape[axis] = in_shape[axis] + 2 * pad
        region[axis] = slice(pad, in_shape[axis] + pad)
    padded = T.zeros(tuple(out_shape))
    return T.set_subtensor(padded[tuple(region)], x)
def custom_spatial_2d_padding(x, padding=(1, 1, 1, 1), dim_ordering='th'):
    '''Asymmetrically zero-pad the two spatial dimensions of a 4D tensor.

    `padding` is (top, bottom, left, right). padding[0] / padding[1] zeros
    go before / after the row dimension, and padding[2] / padding[3] before /
    after the column dimension. The spatial dimensions are axes 2, 3 for
    'th' and axes 1, 2 for 'tf'.
    '''
    if dim_ordering == 'th':
        row_axis, col_axis = 2, 3
    elif dim_ordering == 'tf':
        row_axis, col_axis = 1, 2
    else:
        raise Exception('Invalid dim_ordering: ' + dim_ordering)
    top, bottom = padding[0], padding[1]
    left, right = padding[2], padding[3]
    in_shape = x.shape
    out_shape = [in_shape[axis] for axis in range(4)]
    out_shape[row_axis] = top + in_shape[row_axis] + bottom
    out_shape[col_axis] = left + in_shape[col_axis] + right
    region = [slice(None)] * 4
    region[row_axis] = slice(top, in_shape[row_axis] + top)
    region[col_axis] = slice(left, in_shape[col_axis] + left)
    padded = T.zeros(tuple(out_shape))
    return T.set_subtensor(padded[tuple(region)], x)
def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='th'):
    '''Symmetrically zero-pad the three spatial dimensions of a 5D tensor.

    padding[0], padding[1] and padding[2] zeros are added on each side of
    the first, second and third spatial dimension respectively. The spatial
    dimensions are axes 2, 3, 4 for 'th' and axes 1, 2, 3 for 'tf'.
    '''
    if dim_ordering == 'th':
        spatial_axes = (2, 3, 4)
    elif dim_ordering == 'tf':
        spatial_axes = (1, 2, 3)
    else:
        raise Exception('Invalid dim_ordering: ' + dim_ordering)
    in_shape = x.shape
    out_shape = [in_shape[axis] for axis in range(5)]
    region = [slice(None)] * 5
    for i, axis in enumerate(spatial_axes):
        pad = padding[i]
        out_shape[axis] = in_shape[axis] + 2 * pad
        region[axis] = slice(pad, in_shape[axis] + pad)
    padded = T.zeros(tuple(out_shape))
    return T.set_subtensor(padded[tuple(region)], x)
def pack(x):
    '''Stack the tensors in the sequence `x` into one tensor (T.stack).'''
    stacked = T.stack(*x)
    return stacked
# VALUE MANIPULATION
def get_value(x):
    '''Return the current value of the variable `x` via `x.get_value()`.

    Raises:
        Exception: if `x` has no `get_value` method, e.g. when it is a
            symbolic expression rather than a variable (use eval() then).
    '''
    if not hasattr(x, 'get_value'):
        raise Exception('get_value() can only be called on a variable. '
                        'If you have an expression instead, use eval().')
    return x.get_value()
def set_value(x, value):
    '''Assign `value` to the variable `x`, cast to `x.dtype` first.'''
    new_value = np.asarray(value, dtype=x.dtype)
    x.set_value(new_value)
# GRAPH MANIPULATION
class Function(object):
    '''Callable wrapper around a compiled `theano.function`.

    The compiled function is built with `allow_input_downcast=True` (see the
    theano.function docs). Calls take a single list/tuple of inputs, which is
    unpacked positionally.
    '''

    def __init__(self, inputs, outputs, updates=None, **kwargs):
        # `None` default instead of a shared mutable `[]`.
        if updates is None:
            updates = []
        self.function = theano.function(inputs, outputs, updates=updates,
                                        allow_input_downcast=True, **kwargs)

    def __call__(self, inputs):
        assert isinstance(inputs, (list, tuple))
        return self.function(*inputs)
def function(inputs, outputs, updates=None, **kwargs):
    '''Compile a `Function` from symbolic inputs, outputs and updates.

    Extra keyword arguments are forwarded to `theano.function` after
    checking that it accepts them by name.

    Raises:
        ValueError: if a keyword argument is not a parameter of
            `theano.function`.
    '''
    if updates is None:
        updates = []
    if kwargs:
        # inspect.getargspec was removed in Python 3.11; prefer
        # getfullargspec and fall back to getargspec on Python 2.
        get_spec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        function_args = get_spec(theano.function)[0]
        for key in kwargs:
            if key not in function_args:
                msg = "Invalid argument '%s' passed to K.function" % key
                raise ValueError(msg)
    return Function(inputs, outputs, updates=updates, **kwargs)
def gradients(loss, variables):
    '''Symbolic gradients of `loss` with respect to `variables` (T.grad).'''
    grads = T.grad(loss, variables)
    return grads
# CONTROL FLOW
def rnn(step_function, inputs, initial_states,
        go_backwards=False, mask=None, constants=None):
    '''Iterates over the time dimension of a tensor.

    # Arguments
        inputs: tensor of temporal data of shape (samples, time, ...)
            (at least 3D).
        step_function:
            Parameters:
                input: tensor with shape (samples, ...) (no time dimension),
                    representing input for the batch of samples at a certain
                    time step.
                states: list of tensors.
            Returns:
                output: tensor with shape (samples, ...) (no time dimension),
                new_states: list of tensors, same length and shapes
                    as 'states'. Must be a list: it is concatenated as
                    `[output] + new_states` when no mask is given.
        initial_states: list of tensors with shape (samples, ...) (no time
            dimension), containing the initial values for the states used
            in the step function. Used with list concatenation below, so a
            list (not a single tensor) is required.
        go_backwards: boolean. If True, do the iteration over
            the time dimension in reverse order.
        mask: binary tensor with shape (samples, time),
            with a zero for every element that is masked.
        constants: a list of constant values passed at each step.

    # Returns
        A tuple (last_output, outputs, new_states).
            last_output: the latest output of the rnn, of shape (samples, ...)
            outputs: tensor with shape (samples, time, ...) where each
                entry outputs[s, t] is the output of the step function
                at time t for sample s.
            new_states: list of tensors, latest states returned by
                the step function, of shape (samples, ...).
    '''
    ndim = inputs.ndim
    assert ndim >= 3, 'Input should be at least 3D.'
    # Move time to the leading axis: theano.scan iterates over axis 0.
    axes = [1, 0] + list(range(2, ndim))
    inputs = inputs.dimshuffle(axes)
    if mask is not None:
        # A mask with one dimension fewer than the inputs gets a trailing
        # axis so it can be dimshuffled with the same `axes`.
        if mask.ndim == ndim-1:
            mask = expand_dims(mask)
        assert mask.ndim == ndim
        mask = mask.dimshuffle(axes)
        if constants is None:
            constants = []
        # build an all-zero tensor of shape (samples, output_dim)
        initial_output = step_function(inputs[0], initial_states + constants)[0] * 0
        # Theano gets confused by broadcasting patterns in the scan op
        initial_output = T.unbroadcast(initial_output, 0, 1)

        def _step(input, mask, output_tm1, *states):
            # NOTE(review): theano.scan passes non_sequences (the constants)
            # after the recurrent states, so `states` here presumably also
            # carries the constants — confirm against the Theano scan docs.
            output, new_states = step_function(input, states)
            # output previous output if masked.
            output = T.switch(mask, output, output_tm1)
            # zip() stops at len(new_states); assuming step_function returns
            # one new state per recurrent state, constants are not returned.
            return_states = []
            for state, new_state in zip(states, new_states):
                return_states.append(T.switch(mask, new_state, state))
            return [output] + return_states

        results, _ = theano.scan(
            _step,
            sequences=[inputs, mask],
            outputs_info=[initial_output] + initial_states,
            non_sequences=constants,
            go_backwards=go_backwards)
    else:
        def _step(input, *states):
            output, new_states = step_function(input, states)
            return [output] + new_states

        # outputs_info is None for the output: it is not fed back into _step.
        results, _ = theano.scan(
            _step,
            sequences=inputs,
            outputs_info=[None] + initial_states,
            non_sequences=constants,
            go_backwards=go_backwards)
    # deal with Theano API inconsistency
    # (scan may return a bare variable instead of a list for one output)
    if type(results) is list:
        outputs = results[0]
        states = results[1:]
    else:
        outputs = results
        states = []
    # NOTE(review): T.squeeze drops every broadcastable axis, not one chosen
    # axis; confirm no meaningful size-1 broadcastable axis can reach here.
    outputs = T.squeeze(outputs)
    last_output = outputs[-1]
    # Swap time and samples back: (samples, time, ...).
    axes = [1, 0] + list(range(2, outputs.ndim))
    outputs = outputs.dimshuffle(axes)
    # Keep only each state's value at the final scan step.
    states = [T.squeeze(state[-1]) for state in states]
    return last_output, outputs, states
def switch(condition, then_expression, else_expression):
    '''Select `then_expression` where `condition` holds, else `else_expression`.

    condition: scalar tensor.
    '''
    chosen = T.switch(condition, then_expression, else_expression)
    return chosen
# NN OPERATIONS
def relu(x, alpha=0., max_value=None):
    '''Rectified linear unit via T.nnet.relu.

    `alpha` is forwarded as T.nnet.relu's second argument (the slope for
    negative inputs, per the Theano docs). When `max_value` is given, the
    output is additionally capped at it.
    '''
    assert hasattr(T.nnet, 'relu'), ('It looks like like your version of '
                                     'Theano is out of date. '
                                     'Install the latest version with:\n'
                                     'pip install git+git://github.com/Theano/Theano.git --upgrade --no-deps')
    activated = T.nnet.relu(x, alpha)
    if max_value is None:
        return activated
    return T.minimum(activated, max_value)
def softmax(x):
    '''Softmax activation (T.nnet.softmax).'''
    probabilities = T.nnet.softmax(x)
    return probabilities
def softplus(x):
    '''Softplus activation (T.nnet.softplus).'''
    activated = T.nnet.softplus(x)
    return activated
def categorical_crossentropy(output, target, from_logits=False):
    '''Categorical crossentropy between predictions `output` and `target`.

    If `from_logits`, `output` is first passed through softmax. Otherwise it
    is renormalised so that each sample's class probabilities (last axis)
    sum to 1. Either way, it is then clipped to [_EPSILON, 1 - _EPSILON] to
    avoid log(0).
    '''
    if not from_logits:
        # scale preds so that the class probas of each sample sum to 1
        output /= output.sum(axis=-1, keepdims=True)
    else:
        output = T.nnet.softmax(output)
    clipped = T.clip(output, _EPSILON, 1.0 - _EPSILON)
    return T.nnet.categorical_crossentropy(clipped, target)
def binary_crossentropy(output, target, from_logits=False):
    '''Binary crossentropy between predictions `output` and `target`.

    If `from_logits`, `output` is first passed through a sigmoid. The
    probabilities are then clipped to [_EPSILON, 1 - _EPSILON] to avoid
    log(0).
    '''
    probabilities = T.nnet.sigmoid(output) if from_logits else output
    clipped = T.clip(probabilities, _EPSILON, 1.0 - _EPSILON)
    return T.nnet.binary_crossentropy(clipped, target)
def sigmoid(x):
    '''Logistic sigmoid (T.nnet.sigmoid).'''
    activated = T.nnet.sigmoid(x)
    return activated
def hard_sigmoid(x):
    '''Piecewise-linear sigmoid approximation (T.nnet.hard_sigmoid).'''
    activated = T.nnet.hard_sigmoid(x)
    return activated
def tanh(x):
    '''Hyperbolic tangent (T.tanh).'''
    activated = T.tanh(x)
    return activated
def dropout(x, level, seed=None):
    '''Randomly zero elements of `x` with probability `level`.

    Surviving elements are divided by (1 - level), so the expected value of
    each element is unchanged. A new MRG random stream is created on every
    call. When `seed` is None, it is seeded from numpy's global RNG.

    Raises:
        Exception: if `level` is outside [0, 1).
    '''
    if level < 0. or level >= 1:
        raise Exception('Dropout level must be in interval [0, 1[.')
    rng = RandomStreams(seed=np.random.randint(10e6) if seed is None else seed)
    retain_prob = 1. - level
    keep_mask = rng.binomial(x.shape, p=retain_prob, dtype=x.dtype)
    return x * keep_mask / retain_prob
def l2_normalize(x, axis):
    '''Scale `x` to unit L2 norm along `axis`.

    The squared norm is floored at `_EPSILON` before the square root. An
    all-zero slice therefore maps to zeros instead of NaN (0 / 0). Slices
    whose squared norm is below `_EPSILON` end up with a norm smaller than 1.
    '''
    square_sum = T.sum(T.square(x), axis=axis, keepdims=True)
    norm = T.sqrt(T.maximum(square_sum, _EPSILON))
    return x / norm
# CONVOLUTIONS
def _conv2d_kernel_size(kernel, filter_shape):
    '''Return (rows, cols) of a 'th'-ordered 4D kernel.

    Uses `filter_shape` when both spatial entries are known. Otherwise it
    falls back to evaluating `kernel`, which requires a kernel computable
    without inputs (e.g. a shared variable).
    '''
    if filter_shape is not None and len(filter_shape) == 4:
        rows, cols = filter_shape[2], filter_shape[3]
        if rows is not None and cols is not None:
            return rows, cols
    kernel_shape = kernel.eval().shape
    return kernel_shape[2], kernel_shape[3]


def conv2d(x, kernel, strides=(1, 1), border_mode='valid', dim_ordering='th',
           image_shape=None, filter_shape=None):
    '''2D convolution.

    # Arguments
        x: input tensor; (samples, input_depth, rows, cols) for 'th',
            (samples, rows, cols, input_depth) for 'tf'.
        kernel: kernel tensor; (depth, input_depth, rows, cols) for 'th',
            (rows, cols, input_depth, depth) for 'tf'.
        strides: (row_stride, col_stride).
        border_mode: string, "same" or "valid".
        dim_ordering: 'th' or 'tf'.
        image_shape: optional static shape of `x`, in `dim_ordering` order.
        filter_shape: optional static shape of `kernel`, in `dim_ordering`
            order. With border_mode="same", its spatial entries are used as
            the kernel size when known; only otherwise is the kernel
            evaluated.

    # Returns
        The convolution output, in the same `dim_ordering` as `x`.
    '''
    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    if dim_ordering == 'tf':
        # Convert TF (channels-last) layouts to Theano's channels-first ones.
        # TH input shape: (samples, input_depth, rows, cols)
        # TF input shape: (samples, rows, cols, input_depth)
        # TH kernel shape: (depth, input_depth, rows, cols)
        # TF kernel shape: (rows, cols, input_depth, depth)
        x = x.dimshuffle((0, 3, 1, 2))
        kernel = kernel.dimshuffle((3, 2, 0, 1))
        if image_shape:
            image_shape = (image_shape[0], image_shape[3],
                           image_shape[1], image_shape[2])
        if filter_shape:
            filter_shape = (filter_shape[3], filter_shape[2],
                            filter_shape[0], filter_shape[1])

    if border_mode == 'same':
        th_border_mode = 'half'
    elif border_mode == 'valid':
        th_border_mode = 'valid'
    else:
        raise Exception('Border mode not supported: ' + str(border_mode))

    # Theano might not accept like longs
    def int_or_none(value):
        try:
            return int(value)
        except TypeError:
            return None

    if image_shape is not None:
        image_shape = tuple(int_or_none(v) for v in image_shape)
    if filter_shape is not None:
        filter_shape = tuple(int_or_none(v) for v in filter_shape)

    if border_mode == 'same':
        kernel_rows, kernel_cols = _conv2d_kernel_size(kernel, filter_shape)
        assert strides[0] <= kernel_rows, 'strides should be smaller than the convolution window.'
        assert strides[1] <= kernel_cols, 'strides should be smaller than the convolution window.'

    conv_out = T.nnet.conv2d(x, kernel,
                             border_mode=th_border_mode,
                             subsample=strides,
                             input_shape=image_shape,
                             filter_shape=filter_shape)

    if border_mode == 'same':
        # 'half' pads kernel_size // 2 per side, which for an even kernel
        # yields one extra row/col; crop to ceil(input_size / stride).
        if kernel_rows % 2 == 0:
            conv_out = conv_out[:, :, :(x.shape[2] + strides[0] - 1) // strides[0], :]
        if kernel_cols % 2 == 0:
            conv_out = conv_out[:, :, :, :(x.shape[3] + strides[1] - 1) // strides[1]]

    if dim_ordering == 'tf':
        conv_out = conv_out.dimshuffle((0, 2, 3, 1))
    return conv_out
def conv3d(x, kernel, strides=(1, 1, 1),
           border_mode='valid', dim_ordering='th',
           volume_shape=None, filter_shape=None):
    '''3D convolution built on theano.tensor.nnet.conv3d2d.

    border_mode: string, "same" or "valid". "same" is implemented as explicit
    zero padding followed by a "valid" convolution, and requires unit
    strides. Non-unit strides are emulated by slicing the stride-1 output.
    `strides` may be any sequence (tuple or list).

    volume_shape / filter_shape are reordered for 'tf' input but are not
    currently passed to the convolution.

    Whether cuDNN is used depends on Theano's configuration and graph
    optimizer; this function does not select it explicitly.
    '''
    # Normalise so comparisons against (1, 1, 1) also work for lists.
    strides = tuple(strides)

    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))
    if border_mode not in {'same', 'valid'}:
        raise Exception('Invalid border mode: ' + str(border_mode))

    if dim_ordering == 'tf':
        # TF uses the last dimension as channel dimension,
        # instead of the 2nd one.
        # TH input shape: (samples, input_depth, conv_dim1, conv_dim2, conv_dim3)
        # TF input shape: (samples, conv_dim1, conv_dim2, conv_dim3, input_depth)
        # TH kernel shape: (out_depth, input_depth, kernel_dim1, kernel_dim2, kernel_dim3)
        # TF kernel shape: (kernel_dim1, kernel_dim2, kernel_dim3, input_depth, out_depth)
        x = x.dimshuffle((0, 4, 1, 2, 3))
        kernel = kernel.dimshuffle((4, 3, 0, 1, 2))
        if volume_shape:
            volume_shape = (volume_shape[0], volume_shape[4],
                            volume_shape[1], volume_shape[2], volume_shape[3])
        if filter_shape:
            filter_shape = (filter_shape[4], filter_shape[3],
                            filter_shape[0], filter_shape[1], filter_shape[2])

    if border_mode == 'same':
        assert strides == (1, 1, 1), 'border_mode="same" requires strides of (1, 1, 1).'
        # Pad kernel_size - 1 zeros in total per spatial axis (floor half
        # before the data) so a 'valid' convolution keeps the input size.
        pad_dim1 = (kernel.shape[2] - 1)
        pad_dim2 = (kernel.shape[3] - 1)
        pad_dim3 = (kernel.shape[4] - 1)
        output_shape = (x.shape[0], x.shape[1],
                        x.shape[2] + pad_dim1,
                        x.shape[3] + pad_dim2,
                        x.shape[4] + pad_dim3)
        output = T.zeros(output_shape)
        indices = (slice(None), slice(None),
                   slice(pad_dim1 // 2, x.shape[2] + pad_dim1 // 2),
                   slice(pad_dim2 // 2, x.shape[3] + pad_dim2 // 2),
                   slice(pad_dim3 // 2, x.shape[4] + pad_dim3 // 2))
        x = T.set_subtensor(output[indices], x)
        border_mode = 'valid'

    border_mode_3d = (border_mode, border_mode, border_mode)
    # conv3d2d expects (batch, time, channels, rows, cols), so the first
    # spatial axis and the channel axis are swapped in and out.
    conv_out = conv3d2d.conv3d(signals=x.dimshuffle(0, 2, 1, 3, 4),
                               filters=kernel.dimshuffle(0, 2, 1, 3, 4),
                               border_mode=border_mode_3d)
    conv_out = conv_out.dimshuffle(0, 2, 1, 3, 4)

    # support strides by manually slicing the output
    if strides != (1, 1, 1):
        conv_out = conv_out[:, :, ::strides[0], ::strides[1], ::strides[2]]

    if dim_ordering == 'tf':
        conv_out = conv_out.dimshuffle((0, 2, 3, 4, 1))
    return conv_out
def pool2d(x, pool_size, strides=(1, 1), border_mode='valid',
           dim_ordering='th', pool_mode='max'):
    '''2D pooling over the two spatial axes.

    # Arguments
        pool_size: (rows, cols) pooling window.
        strides: (row_stride, col_stride).
        border_mode: "valid" or "same". "same" pads the input and then crops
            the result to ceil(input_size / stride) along each spatial axis.
        dim_ordering: 'th' (channels first) or 'tf' (channels last).
        pool_mode: "max" or "avg" ("avg" uses Theano's 'average_exc_pad').
    '''
    def same_pad(size):
        # Odd windows larger than 2 get size - 2; everything else gets
        # size - 1. A window of 1 therefore gets 0 instead of an invalid -1.
        if size > 2 and size % 2 == 1:
            return size - 2
        return size - 1

    if border_mode == 'same':
        padding = (same_pad(pool_size[0]), same_pad(pool_size[1]))
    elif border_mode == 'valid':
        padding = (0, 0)
    else:
        raise Exception('Invalid border mode: ' + str(border_mode))

    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))

    if dim_ordering == 'tf':
        x = x.dimshuffle((0, 3, 1, 2))

    if pool_mode == 'max':
        pool_out = pool.pool_2d(x, ds=pool_size, st=strides,
                                ignore_border=True,
                                padding=padding,
                                mode='max')
    elif pool_mode == 'avg':
        pool_out = pool.pool_2d(x, ds=pool_size, st=strides,
                                ignore_border=True,
                                padding=padding,
                                mode='average_exc_pad')
    else:
        raise Exception('Invalid pooling mode: ' + str(pool_mode))

    if border_mode == 'same':
        # crop to ceil(input_size / stride)
        expected_width = (x.shape[2] + strides[0] - 1) // strides[0]
        expected_height = (x.shape[3] + strides[1] - 1) // strides[1]
        pool_out = pool_out[:, :,
                            : expected_width,
                            : expected_height]

    if dim_ordering == 'tf':
        pool_out = pool_out.dimshuffle((0, 2, 3, 1))
    return pool_out
def pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid',
           dim_ordering='th', pool_mode='max'):
    '''3D pooling implemented as two 2D pooling passes.

    The first pass pools over (conv_dim1, conv_dim2). The second pools over
    conv_dim3 with a (1, pool_size[2]) window. Only border_mode="valid" is
    supported. pool_mode is "max" or "avg" (Theano's 'average_exc_pad').
    '''
    if border_mode == 'same':
        # TODO: add implementation for border_mode="same"
        raise Exception('border_mode="same" not supported with Theano.')
    if border_mode != 'valid':
        raise Exception('Invalid border mode: ' + str(border_mode))

    if dim_ordering not in {'th', 'tf'}:
        raise Exception('Unknown dim_ordering ' + str(dim_ordering))
    if dim_ordering == 'tf':
        x = x.dimshuffle((0, 4, 1, 2, 3))

    if pool_mode == 'max':
        theano_mode = 'max'
    elif pool_mode == 'avg':
        theano_mode = 'average_exc_pad'
    else:
        raise Exception('Invalid pooling mode: ' + str(pool_mode))

    # Reorder to (..., conv_dim3, conv_dim2, conv_dim1) so pool_2d pools
    # over conv_dim2 and conv_dim1 (the last two axes).
    partial = pool.pool_2d(input=x.dimshuffle(0, 1, 4, 3, 2),
                           ds=(pool_size[1], pool_size[0]),
                           st=(strides[1], strides[0]),
                           ignore_border=True,
                           padding=(0, 0),
                           mode=theano_mode)
    # Restore (..., conv_dim1, conv_dim2, conv_dim3) and pool conv_dim3 only.
    pool_out = pool.pool_2d(input=partial.dimshuffle(0, 1, 4, 3, 2),
                            ds=(1, pool_size[2]),
                            st=(1, strides[2]),
                            ignore_border=True,
                            padding=(0, 0),
                            mode=theano_mode)

    if dim_ordering == 'tf':
        pool_out = pool_out.dimshuffle((0, 2, 3, 4, 1))
    return pool_out
# RANDOMNESS
def random_normal(shape, mean=0.0, std=1.0, dtype=_FLOATX, seed=None):
    '''Symbolic tensor of normal samples with the given mean and std.

    A new MRG random stream is created per call. When `seed` is None, it is
    seeded from numpy's global RNG.
    '''
    stream_seed = np.random.randint(10e6) if seed is None else seed
    stream = RandomStreams(seed=stream_seed)
    return stream.normal(size=shape, avg=mean, std=std, dtype=dtype)
def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
    '''Symbolic tensor of uniform samples between `low` and `high`.

    A new MRG random stream is created per call. When `seed` is None, it is
    seeded from numpy's global RNG.
    '''
    stream_seed = np.random.randint(10e6) if seed is None else seed
    stream = RandomStreams(seed=stream_seed)
    return stream.uniform(shape, low=low, high=high, dtype=dtype)
def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
    '''Symbolic tensor of binomial samples with success probability `p`.

    A new MRG random stream is created per call. When `seed` is None, it is
    seeded from numpy's global RNG.
    '''
    stream_seed = np.random.randint(10e6) if seed is None else seed
    stream = RandomStreams(seed=stream_seed)
    return stream.binomial(shape, p=p, dtype=dtype)
'''
more TODO:
tensordot -> soon to be introduced in TF
batched_tensordot -> reimplement
'''