Skip to content

Commit

Permalink
Fix Value constructor to only take NDArrayView batches
Browse files Browse the repository at this point in the history
Conflicts:
	bindings/python/cntk/core.py
  • Loading branch information
wilrich-msft authored and mahilleb-msft committed Apr 3, 2017
1 parent a0d7351 commit 7661e81
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 23 deletions.
22 changes: 5 additions & 17 deletions bindings/python/cntk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,6 @@ class Value(cntk_py.Value):
Internal representation of minibatch data.
Args:
shape (tuple): shape of the value
value (None or value that can be cast to NumPy array): the value to
be converted
dtype: data type (np.float32 or np.float64)
batch: batch input for `var`.
It can be:
Expand All @@ -235,23 +231,15 @@ class Value(cntk_py.Value):
should be put on
'''

def __init__(self, shape=None, dtype=None, batch=None, seq_starts=None, device=None):
def __init__(self, batch, seq_starts=None, device=None):
if device is None:
device = use_default_device()

if shape and dtype:
# FIXME is this needed?
ndav = NDArrayView(shape, dtype, device)

elif batch:
if isinstance(batch, np.ndarray):
ndav = NDArrayView.from_dense(batch, device)
else:
ndav = batch

if isinstance(batch, np.ndarray):
ndav = NDArrayView.from_dense(batch, device)
else:
raise ValueError('either shape and dtype or batch must be'
'provided')
ndav = batch


if seq_starts:
super(Value, self).__init__(ndav, seq_starts)
Expand Down
7 changes: 1 addition & 6 deletions bindings/python/cntk/tests/misc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@ def test_callstack2():
cntk.io.MinibatchSource(cntk.io.CTFDeserializer("", streams={}))
assert '[CALL STACK]' in str(excinfo.value)

def test_Value_raises():
from cntk import NDArrayView, Value
with pytest.raises(ValueError):
nd = NDArrayView.from_dense(np.asarray([[[4,5]]], dtype=np.float32))
val = Value(nd)

def test_cpu_and_gpu_devices():
device = cpu()
Expand Down Expand Up @@ -120,4 +115,4 @@ def test_set_excluded_devices():
set_excluded_devices([cpu()])
assert not try_set_default_device(cpu(), False)
set_excluded_devices([])
assert try_set_default_device(cpu(), False)
assert try_set_default_device(cpu(), False)

0 comments on commit 7661e81

Please sign in to comment.