From 7661e81777360d3222a26dcb969973ce1d4c513f Mon Sep 17 00:00:00 2001 From: Willi Richert Date: Mon, 3 Apr 2017 06:52:43 +0200 Subject: [PATCH] Fix Value constructor to only take NDArrayView batches Conflicts: bindings/python/cntk/core.py --- bindings/python/cntk/core.py | 22 +++++----------------- bindings/python/cntk/tests/misc_test.py | 7 +------ 2 files changed, 6 insertions(+), 23 deletions(-) diff --git a/bindings/python/cntk/core.py b/bindings/python/cntk/core.py index fcec64f83bbc..08252392fc49 100644 --- a/bindings/python/cntk/core.py +++ b/bindings/python/cntk/core.py @@ -216,10 +216,6 @@ class Value(cntk_py.Value): Internal representation of minibatch data. Args: - shape (tuple): shape of the value - value (None or value that can be cast to NumPy array): the value to - be converted - dtype: data type (np.float32 or np.float64) batch: batch input for `var`. It can be: @@ -235,23 +231,15 @@ class Value(cntk_py.Value): should be put on ''' - def __init__(self, shape=None, dtype=None, batch=None, seq_starts=None, device=None): + def __init__(self, batch, seq_starts=None, device=None): if device is None: device = use_default_device() - if shape and dtype: - # FIXME is this needed? - ndav = NDArrayView(shape, dtype, device) - - elif batch: - if isinstance(batch, np.ndarray): - ndav = NDArrayView.from_dense(batch, device) - else: - ndav = batch - + if isinstance(batch, np.ndarray): + ndav = NDArrayView.from_dense(batch, device) else: - raise ValueError('either shape and dtype or batch must be' - 'provided') + ndav = batch + if seq_starts: super(Value, self).__init__(ndav, seq_starts) diff --git a/bindings/python/cntk/tests/misc_test.py b/bindings/python/cntk/tests/misc_test.py index 80a27d817d70..03e42728e920 100644 --- a/bindings/python/cntk/tests/misc_test.py +++ b/bindings/python/cntk/tests/misc_test.py @@ -36,11 +36,6 @@ def test_callstack2(): cntk.io.MinibatchSource(cntk.io.CTFDeserializer("", streams={})) assert '[CALL STACK]' in str(excinfo.value) -def test_Value_raises(): - from cntk import NDArrayView, Value - with pytest.raises(ValueError): - nd = NDArrayView.from_dense(np.asarray([[[4,5]]], dtype=np.float32)) - val = Value(nd) def test_cpu_and_gpu_devices(): device = cpu() @@ -120,4 +115,4 @@ def test_set_excluded_devices(): set_excluded_devices([cpu()]) assert not try_set_default_device(cpu(), False) set_excluded_devices([]) - assert try_set_default_device(cpu(), False) \ No newline at end of file + assert try_set_default_device(cpu(), False)