Skip to content

Commit

Permalink
revise Source and Constant
Browse files Browse the repository at this point in the history
  • Loading branch information
andersbll committed Mar 21, 2016
1 parent 28adea8 commit 9fb1895
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions deeppy/expr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,28 @@ def bprop(self):
self.x.grad_array = self.grad_array


class Source(Op, NoBPropMixin, NoFPropMixin):
bpropable = False

def __init__(self, shape):
self.shape = shape
self.array = ca.zeros(shape)

@classmethod
def from_array(cls, array):
if isinstance(array, np.ndarray):
array = ca.array(array)
obj = cls(array.shape)
obj.array = array
return obj


class Constant(Op, NoBPropMixin, NoFPropMixin):
bpropable = False

def __init__(self, value):
if isinstance(value, np.ndarray):
value = ca.array(value)
self.value = value
self.array = value
if isinstance(value, (float, int)):
self.shape = (1,)
Expand Down Expand Up @@ -232,29 +247,13 @@ def setup(self):
self.inputs = [self.lhs, self.rhs]
self.rhs.setup()
except ValueError:
raise
raise ValueError('Shape mismatch: %s and %s for %s. LHS: %s RHS: '
'%s.' % (self.lhs.shape, self.rhs.shape,
self, self.lhs, self.rhs))
self.array = ca.zeros(self.shape)
self.grad_array = ca.zeros(self.shape)


class Source(Op, NoBPropMixin, NoFPropMixin):
bpropable = False

def __init__(self, shape):
self.shape = shape

def setup(self):
if not (isinstance(self.array, ca.ndarray)
and self.array.shape == self.shape):
self.array = ca.zeros(self.shape)
if not (isinstance(self.grad_array, ca.ndarray)
and self.grad_array.shape == self.shape):
self.grad_array = ca.zeros(self.shape)


class Variable(Op):
def __init__(self, parameter):
self.parameter = parameter
Expand Down

0 comments on commit 9fb1895

Please sign in to comment.