diff --git a/Source/CNTKv2LibraryDll/CompositeFunction.cpp b/Source/CNTKv2LibraryDll/CompositeFunction.cpp index 5f365e569775..64cece03526d 100644 --- a/Source/CNTKv2LibraryDll/CompositeFunction.cpp +++ b/Source/CNTKv2LibraryDll/CompositeFunction.cpp @@ -1408,7 +1408,7 @@ namespace CNTK { // Ensure that only a subset of this function's outputs are being asked to be evaluated if (functionOutputs.find(outputVarValuePair.first) == functionOutputs.end()) - InvalidArgument("Requested output is not an Ouptut of the Function"); + InvalidArgument("Requested output is not an Output of the Function"); auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVarValuePair.first); requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end()); diff --git a/Source/Math/Matrix.cpp b/Source/Math/Matrix.cpp index 4b10a05d4785..a41a54ab6106 100644 --- a/Source/Math/Matrix.cpp +++ b/Source/Math/Matrix.cpp @@ -1091,7 +1091,15 @@ Matrix& Matrix::DoGatherColumnsOf(ElemType beta, const Matri { m_CPUMatrix->DoGatherColumnsOf(beta, *idx.m_CPUMatrix, *a.m_CPUMatrix, alpha); }, { m_GPUMatrix->DoGatherColumnsOf(beta, *idx.m_GPUMatrix, *a.m_GPUMatrix, alpha); }, { m_CPUSparseMatrix->DoGatherColumnsOf(beta, *idx.m_CPUMatrix, *a.m_CPUSparseMatrix, alpha); }, - { NOT_IMPLEMENTED; }); + { + // TODO replace by more performant version directly on GPU that does not require the round-trip over CPU. + Matrix tempIdx(CPUDEVICE); tempIdx.AssignValuesOf(idx); + CPUSparseMatrix tempA(a.GetFormat(), a.GetNumRows(), a.GetNumCols(), a.m_GPUSparseMatrix->GetNumNZElements()); + + a.m_GPUSparseMatrix->CopyToCPUSparseMatrix(tempA); + tempA.DoGatherColumnsOf(beta, *tempIdx.m_CPUMatrix, tempA, alpha); + m_GPUSparseMatrix->SetValue(tempA); + }); return *this; } @@ -3621,7 +3629,7 @@ void Matrix::DecideAndMoveToRightDevice(const Matrix& a, con template void Matrix::DecideAndMoveToRightDevice(const Matrix& a, const Matrix& b, const Matrix& c, const Matrix& d) { - // this function is only called for one operator, so for now we keep it imple + // this function is only called for one operator, so for now we keep it simple DecideAndMoveToRightDevice(a, b, c); d._transferToDevice(a.GetDeviceId()); // BUGBUG: Is this correct in case a,b,c share the same preferredDevice? } diff --git a/bindings/python/cntk/io/__init__.py b/bindings/python/cntk/io/__init__.py index 3a147caa45d9..ac4b365672f4 100644 --- a/bindings/python/cntk/io/__init__.py +++ b/bindings/python/cntk/io/__init__.py @@ -51,7 +51,7 @@ def shape(self): def mask(self): ''' The mask object of the minibatch. In it, `2` marks the beginning of a - sequence, `1` marks a sequence element as valid, and `0` markse it as + sequence, `1` marks a sequence element as valid, and `0` marks it as invalid. ''' return self.m_data.mask().to_numpy() @@ -296,6 +296,7 @@ class Deserializer(dict): Deserializer type Description ========================== ============ :class:`ImageDeserializer` Deserializer for images that uses OpenCV + :class:`CTFDeserializer` Deserializer for text of the `CNTKTextReader format `_ ========================== ============ Args: @@ -312,16 +313,19 @@ def __init__(self, type): class ImageDeserializer(Deserializer): ''' This class configures the image reader that reads images and corresponding - labels from a file of the form + labels from a file of the form:: - + + or:: + + sequenceId path label Args: filename (str): file name of the map file that associates images to classes See also: - https://github.com/microsoft/cntk/wiki/Image-reader + `Image reader definition `_ ''' def __init__(self, filename, streams=None): @@ -447,24 +451,22 @@ def mean(filename): # TODO color transpose -# -# CNTKTextFormatReader -# TODO get away from cntk_py.text_format_minibatch_source and set it up -# similarly to ImageDeserializer -# - -#class TextFormatDeserializer(Deserializer): # TODO: either call it CNTKTextFormat or CTF. TextFormat is confusable with plain text class CTFDeserializer(Deserializer): ''' - This class configures the text reader that reads text-encoded files from a file with lines of the form - [Sequence_Id](Sample)+ - where - Sample=|Input_Name (Value )* + This class configures the text reader that reads text-encoded files from a + file with lines of the form:: + + [Sequence_Id](Sample)+ + + where:: + + Sample=|Input_Name (Value )* + Args: filename (str): file name containing the text input See also: - https://github.com/Microsoft/CNTK/wiki/CNTKTextFormat-Reader + `CNTKTextReader format `_ ''' def __init__(self, filename, streams=None): @@ -483,8 +485,11 @@ def map_input(self, node, dim, format="dense", alias=None): ''' Maps node (either node instance or node name) to a part of the text input, either specified by the node name or the alias in the text file. - Example: for node name 'Apples' an input line could look like this: - |Apples 0 1 2 3 4 5 6 7 8 9 + + Example: for node name 'input0' an input line could look like this:: + + |input0 3 7 1 0 2 + Args: node (str or input node): node or its name dim (int): specifies the dimension of the input value vector @@ -493,7 +498,7 @@ def map_input(self, node, dim, format="dense", alias=None): format (str, default 'dense'): 'dense' or 'sparse'. Specifies the input type. alias (str, default None): None or alias name. Optional abbreviated name that is used in the text file to avoid repeating long input names. For details please - see https://github.com/Microsoft/CNTK/wiki/CNTKTextFormat-Reader + see `CNTKTextReader format `_ ''' if not isinstance(node, str): node = node.name() diff --git a/bindings/python/cntk/learner.py b/bindings/python/cntk/learner.py index c3e11e2ad4c7..62e72211ecbf 100644 --- a/bindings/python/cntk/learner.py +++ b/bindings/python/cntk/learner.py @@ -369,7 +369,8 @@ def momentum_sgd(parameters, lr, momentum, with truncation Returns: - Instance of a :class:`cntk.learner.Learner` that can be passed to the :class:`cntk.trainer.Trainer` + Instance of a :class:`~cntk.learner.Learner` that can be passed to the + :class:`~cntk.trainer.Trainer` ''' _verify_learning_rate_type(lr) _verify_momentum_type(momentum) diff --git a/bindings/python/cntk/ops/__init__.py b/bindings/python/cntk/ops/__init__.py index 6d97e4c3a1f7..204b0c9750b3 100644 --- a/bindings/python/cntk/ops/__init__.py +++ b/bindings/python/cntk/ops/__init__.py @@ -338,7 +338,7 @@ def pooling(operand, pooling_type, pooling_window_shape, strides=(1,), auto_padd Args: operand: pooling input - pooling_type: one of :const:`cntk.ops.MAX_POOLING` or :const:`cntk.ops.AVG_POOLING` + pooling_type: one of :const:`~cntk.ops.MAX_POOLING` or :const:`~cntk.ops.AVG_POOLING` pooling_window_shape: dimensions of the pooling window strides (default 1): strides. auto_padding: automatic padding flags for each input dimension. @@ -1071,7 +1071,7 @@ def softmax(x, name=''): ''' from cntk.cntk_py import softmax x = sanitize_input(x) - return softmax(x) + return softmax(x, name) @typemap @@ -1095,7 +1095,7 @@ def hardmax(x, name=''): ''' from cntk.cntk_py import hardmax x = sanitize_input(x) - return hardmax(x) + return hardmax(x, name) @typemap diff --git a/bindings/python/cntk/ops/variables.py b/bindings/python/cntk/ops/variables.py index 7a411470e0b2..c13bed3eeb89 100644 --- a/bindings/python/cntk/ops/variables.py +++ b/bindings/python/cntk/ops/variables.py @@ -203,7 +203,10 @@ def __init__(self, value=None, shape=None, dtype=None, device=None, name=''): if np.isscalar(value): super().__init__(utils.sanitize_shape(shape), sanitize_dtype_cntk(dtype), value) else: - ndav = sanitize_value(shape, value, dtype, device) + if isinstance(value, cntk_py.Value): + ndav = value.data() + else: + ndav = sanitize_value(shape, value, dtype, device) super().__init__(ndav, name) diff --git a/bindings/python/cntk/tests/trainer_test.py b/bindings/python/cntk/tests/trainer_test.py index 627a8245193d..184c1234956c 100644 --- a/bindings/python/cntk/tests/trainer_test.py +++ b/bindings/python/cntk/tests/trainer_test.py @@ -4,6 +4,7 @@ # for full license information. # ============================================================================== +import os import math import numpy as np from .. import Function @@ -65,6 +66,64 @@ def test_output_to_retain(): assert np.allclose(var_map[z_output], np.asarray(in1_value)+20) +def test_eval_sparse_dense(tmpdir, device_id): + from cntk import Axis + from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs + from cntk.device import cpu, gpu, set_default_device + from cntk.ops import input_variable, times + from scipy.sparse import csr_matrix + + input_vocab_dim = label_vocab_dim = 69 + + ctf_data = '''\ +0 |S0 3:1 |# |S1 3:1 |# +0 |S0 4:1 |# A |S1 32:1 |# ~AH +0 |S0 5:1 |# B |S1 36:1 |# ~B +0 |S0 4:1 |# A |S1 31:1 |# ~AE +0 |S0 7:1 |# D |S1 38:1 |# ~D +0 |S0 12:1 |# I |S1 47:1 |# ~IY +0 |S0 1:1 |# |S1 1:1 |# +2 |S0 60:1 |# |S1 3:1 |# +2 |S0 61:1 |# A |S1 32:1 |# ~AH +''' + ctf_file = str(tmpdir/'2seqtest.txt') + with open(ctf_file, 'w') as f: + f.write(ctf_data) + + mbs = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs( + features = StreamDef(field='S0', shape=input_vocab_dim, is_sparse=True), + labels = StreamDef(field='S1', shape=label_vocab_dim, is_sparse=True) + )), randomize=False, epoch_size = 2) + + batch_axis = Axis.default_batch_axis() + input_seq_axis = Axis('inputAxis') + label_seq_axis = Axis('labelAxis') + + input_dynamic_axes = [batch_axis, input_seq_axis] + raw_input = input_variable( + shape=input_vocab_dim, dynamic_axes=input_dynamic_axes, + name='raw_input', is_sparse=True) + + mb_valid = mbs.next_minibatch(minibatch_size_in_samples=100, + input_map={raw_input : mbs.streams.features}) + + z = times(raw_input, np.eye(input_vocab_dim)) + e_reader = z.eval(mb_valid) + + # CSR with the raw_input encoding in ctf_data + one_hot_data = [ + [3, 4, 5, 4, 7, 12, 1], + [60, 61] + ] + data = [csr_matrix(np.eye(input_vocab_dim, dtype=np.float32)[d]) for d in + one_hot_data] + e_csr = z.eval({raw_input: data}, device=cntk_device(device_id)) + assert np.all(np.allclose(a, b) for a,b in zip(e_reader, e_csr)) + + # One-hot with the raw_input encoding in ctf_data + data = one_hot(one_hot_data, num_classes=input_vocab_dim) + e_hot = z.eval({raw_input: data}, device=cntk_device(device_id)) + assert np.all(np.allclose(a, b) for a,b in zip(e_reader, e_hot)) @pytest.mark.parametrize("batch_index_data", [ [2,3], @@ -73,93 +132,65 @@ def test_output_to_retain(): def test_eval_sparse_no_seq(batch_index_data, device_id): dim = 10 multiplier = 2 - in1 = input_variable(shape=(dim,), is_sparse=True) - z = times(in1, np.eye(dim).astype(np.float32)) - z *= multiplier - batch = (np.eye(dim)[batch_index_data]).astype(np.float32) - expected = batch * multiplier - sparse_val = csr(batch) - result = z.eval({in1: sparse_val}, device=cntk_device(device_id)) - assert np.allclose(result, [expected]) - -@pytest.mark.parametrize("batch_index_data", [ - [[2,3], [0,1,6]], - ]) -def test_eval_sparse_seq_0(batch_index_data, device_id): - if cntk_device(device_id)!=cpu(): # FIXME - pytest.skip("sparse is not yet supported on GPU") - dim = 10 - multiplier = 2 - in1 = input_variable(shape=(dim,), is_sparse=True) - z = times(in1, np.eye(dim).astype(np.float32)) - z *= multiplier - batch = [(np.eye(dim)[seq_index_data]).astype(np.float32) for - seq_index_data in batch_index_data] - expected = batch * multiplier - sparse_val = [csr(seq) for seq in batch] - result = z.eval({in1: sparse_val}, device=cntk_device(device_id)) - assert np.all(np.allclose(a,b) \ - for a,b in zip(result, expected)) + for var_is_sparse in [True, False]: + in1 = input_variable(shape=(dim,), is_sparse=var_is_sparse) + z = times(in1, multiplier*np.eye(dim)) + batch = np.eye(dim)[batch_index_data] + expected = batch * multiplier + sparse_val = csr(batch) + result = z.eval({in1: sparse_val}, device=cntk_device(device_id)) + assert np.allclose(result, [expected]) @pytest.mark.parametrize("batch", [ - #[[csr([0,1,2,0])]], - [ - [csr([0, 2, 0, 7]), csr([10, 20, 0, 0])], - [csr([0, 0, 0, 3])] + [[csr([0,1,2,0])]], + [ + [csr([0, 2, 0, 7]), csr([10, 20, 0, 0])], + [csr([0, 0, 0, 3])] + ], + # same as before, but sequence being encoded as one matrix + [ + csr([[0, 2, 0, 7], [10, 20, 0, 0]]), + csr([0, 0, 0, 3]) ] - ]) + ]) def test_eval_sparse_seq_1(batch, device_id): - if cntk_device(device_id)!=cpu(): # FIXME - pytest.skip("sparse is not yet supported on GPU") dim = 4 multiplier = 2 - # FIXME - in1 = input_variable(shape=(dim,), is_sparse=True) - # in1 = input_variable(shape=(dim,)) - z = times(in1, multiplier*np.eye(dim))#np.eye(dim).astype(np.float32)) - - expected = [[m.todense() * multiplier for m in seq] for seq in batch] - - result = z.eval({in1: batch}, device=cntk_device(device_id)) + for var_is_sparse in [True, False]: + in1 = input_variable(shape=(dim,), is_sparse=var_is_sparse) + z = times(in1, multiplier*np.eye(dim)) + expected = [[m.todense() * multiplier for m in seq] for seq in batch] + result = z.eval({in1: batch}, device=cntk_device(device_id)) - assert np.all(np.allclose(a,b) \ - for a,b in zip(result, expected)) + assert np.all(np.allclose(a,b) \ + for a,b in zip(result, expected)) @pytest.mark.parametrize("one_hot_batch", [ ([[2,5], [0,1,6]]), - ([[1], - [1],[2],[3]]), + ([[1],[1],[2],[3]]), + ([[1,5], + [4]]), ]) def test_eval_one_hot_seq(one_hot_batch, device_id): - if cntk_device(device_id)!=cpu(): # FIXME - pytest.skip("sparse is not yet supported on GPU") dim = 10 multiplier = 2 - # FIXME - # in1 = input_variable(shape=(dim,), is_sparse=True) - in1 = input_variable(shape=(dim,)) - # Convert CNTK node value to dense so that we can compare it later - z = times(in1, np.eye(dim).astype(np.float32)) - z *= multiplier - # Convert expectation to dense - expected = [np.eye(dim)[seq]*multiplier for seq in one_hot_batch] - batch = one_hot(one_hot_batch, num_classes=dim, device=cntk_device(device_id)) - assert np.all(np.allclose(a,b) \ - for a,b in zip(z.eval({in1: batch}, device=cntk_device(device_id)), expected)) + for var_is_sparse in [True, False]: + in1 = input_variable(shape=(dim,), is_sparse=var_is_sparse) + # Convert CNTK node value to dense so that we can compare it later + z = times(in1, np.eye(dim)*multiplier) + # Convert expectation to dense + expected = [np.eye(dim)[seq]*multiplier for seq in one_hot_batch] + batch = one_hot(one_hot_batch, num_classes=dim, device=cntk_device(device_id)) + assert np.all(np.allclose(a,b) \ + for a,b in zip(z.eval({in1: batch}, device=cntk_device(device_id)), expected)) @pytest.mark.parametrize("one_hot_batch, dim", [ ([[11]], 10), ([[0, 1]], 1), ]) -# FIXME -def _test_eval_one_hot_bad(one_hot_batch, dim, device_id): - in1 = input_variable(shape=dim) - # Convert CNTK node value to dense so that we can compare it later - z = times(in1, np.eye(dim).astype(np.float32)) - # Convert expectation to dense - batch = one_hot(one_hot_batch, num_classes=dim, device=cntk_device(device_id)) +def test_eval_one_hot_bad(one_hot_batch, dim, device_id): with pytest.raises(ValueError): - z.eval({in1: batch}) + batch = one_hot(one_hot_batch, num_classes=dim, device=cntk_device(device_id)) diff --git a/bindings/python/cntk/utils/__init__.py b/bindings/python/cntk/utils/__init__.py index ce0681a5c8f6..89ca8cbfa7d2 100644 --- a/bindings/python/cntk/utils/__init__.py +++ b/bindings/python/cntk/utils/__init__.py @@ -39,13 +39,13 @@ def sanitize_precision(precision): def cntk_device(device_id): ''' - Converts the legacy device ID as it was used in CNTK 1 to a :class:`cntk.device.DeviceDescriptor` instance. + Converts the legacy device ID as it was used in CNTK 1 to a :class:`~cntk.device.DeviceDescriptor` instance. Args: device_id (int): device id, -1 for CPU, 0 or higher for GPU Returns: - :class:`cntk.device.DeviceDescriptor` + :class:`~cntk.device.DeviceDescriptor` ''' if device_id == -1: return cpu() @@ -72,7 +72,7 @@ def tensors_to_text_format(sample_idx, alias_tensor_map): are assumed to have dynamic axis. Returns: - String representation in CNTKTextReader format + String representation in `CNTKTextReader format `_ ''' max_seq_length = max(len(t) for t in alias_tensor_map.values()) @@ -152,6 +152,17 @@ def one_hot(batch, num_classes, dtype=None, device=None): such that the integer data in ``batch`` is interpreted as the indices representing one-hot vectors. + Example: + >>> num_classes = 6 + >>> sparse_indices = [[1,5],[4]] + >>> C.set_default_device(C.cpu()) + >>> i0 = C.input_variable(shape=num_classes, is_sparse=True) + >>> z = C.times(i0, np.eye(num_classes)) + >>> value = C.one_hot(sparse_indices, num_classes) + >>> z.eval({i0: value}) + [array([[ 0., 1., 0., 0., 0., 0.], + [ 0., 0., 0., 0., 0., 1.]], dtype=float32), array([[ 0., 0., 0., 0., 1., 0.]], dtype=float32)] + Args: batch (list (of lists, if sequence) of index data): batch input data num_classes (int): number of classes @@ -304,6 +315,9 @@ def sanitize_input(arg, fallback_dtype=np.float32, reshape=None): if isinstance(arg, list) and not arg: raise ValueError('input is empty') + if isinstance(arg, (cntk_py.Value, cntk_py.NDArrayView)): + return constant(value=arg) + if not isinstance(arg, np.ndarray) or arg.dtype!=fallback_dtype: arg = np.asarray(arg, dtype=fallback_dtype) if arg.shape == (): @@ -322,7 +336,7 @@ def get_data_type(*args): inputs. Placeholders are ignored in the type determination. Args: - args (number, list, NumPy array, :class:`cntk.ops.variables.Variable`, or :class:`cntk.ops.functions.Function`): input + args (number, list, NumPy array, :class:`~cntk.ops.variables.Variable`, or :class:`~cntk.ops.functions.Function`): input Returns: np.float32, np.float64, or None """ @@ -460,7 +474,10 @@ def sanitize_batch(var, batch, seq_starts=None, dtype=None, device=None): Args: var (:class:`~cntk.ops.variables.Variable`): variable node for which the ``batch`` is meant - batch (list of NumPy arrays): input + batch: batch input for `var`. It can be a pure Python structure (list + of lists, ...), a combination of lists of NumPy arrays or SciPy + sparse CSR matrices. Alternatively, it can also be the output of + :func:`one_hot`. seq_starts (list of bool or None): if None, every sequence is treated as a new sequence. Otherwise, it is interpreted as a list of Booleans that tell whether a sequence is a new sequence (`True`) or a @@ -469,8 +486,7 @@ def sanitize_batch(var, batch, seq_starts=None, dtype=None, device=None): this value should be put on Returns: - :class:`Value`: converted batch that can be passed to the - core API + :class:`Value`: converted batch that can be passed to the core API ''' if isinstance(batch, cntk_py.Value): return batch @@ -659,12 +675,15 @@ def sanitize_var_map(op_arguments, arguments, precision=None, device=None): ''' Sanitizes a dictionary of `Variable` s to input data such that it can be - handed off to the evaluation methods (:meth:`cntk.ops.functions.Function.forward`, :meth:`cntk.ops.functions.Function.backward`, :meth:`cntk.Trainer.train_minibatch` and - :meth:`cntk.Trainer.test_minibatch`). + handed off to the evaluation methods + (:meth:`~cntk.ops.functions.Function.forward`, + :meth:`~cntk.ops.functions.Function.backward`, :meth:`~cntk.Trainer.train_minibatch` and + :meth:`~cntk.Trainer.test_minibatch`). Args: - op_arguments (:class:`cntk.ops.functions.Function`): arguments of the root function. In - :meth:`cntk.ops.functions.Function.forward` pass it is typically `op.arguments`, in :meth:`cntk.ops.functions.Function.backward` pass it is + op_arguments (:class:`~cntk.ops.functions.Function`): arguments of the root function. In + :meth:`~cntk.ops.functions.Function.forward` pass it is typically + `op.arguments`, in :meth:`~cntk.ops.functions.Function.backward` pass it is `op.outputs` arguments: maps variables to their input data. The interpretation depends on the input type: @@ -678,7 +697,7 @@ def sanitize_var_map(op_arguments, arguments, precision=None, be used as a list of bools, denoting whether a sequence is a new one (`True`) or a continuation of the previous one (`False`). Data should be either NumPy arrays or a - :class:`cntk.io.MinibatchData` instance. + :class:`~cntk.io.MinibatchData` instance. precision (str or `np.float32` or `np.float64`): if string it can be one of 'float' 'float32, 'double', 'float64', or None device (:class:`~cntk.device.DeviceDescriptor`, default None): device @@ -836,16 +855,14 @@ def mask(self): elements describing the mask of the element: * 2: beginning of sequence (e.g. an LSTM would be reset) * 1: valid element - # 0: invalid element + * 0: invalid element Example: - A mask of - ```[[2, 1, 1], [1, 1, 0]] - ``` - describes a batch of two sequences. The first has three elements, of - which the first element signals the beginning of a sequence. The second - sequence has two elements, which are both continuations of the first - sequence. + A mask of ``[[2, 1, 1], [1, 1, 0]]`` describes a batch of two + sequences. The first has three elements, of which the first element + (2) signals the beginning of a sequence. The second sequence has two + elements (last element marked 'invalid' by '0'), which are both + continuations of the first sequence. ''' return np.asarray(super(Value, self).mask()) @@ -891,10 +908,10 @@ def sanitize_axis(axis): Sanitizes the axis. Args: - axis (:class:`cntk.axis.Axis` or int or None): the axis to be used. + axis (:class:`~cntk.axis.Axis` or int or None): the axis to be used. - * :class:`cntk.axis.Axis`: use axis instance directly (will convert row- to - col-major in case of static axis. + * :class:`~cntk.axis.Axis`: use axis instance directly (will convert + row- to col-major in case of static axis). * int: if positive, use it as static axis. If negative, count from last to first axis * None: denote all available axes @@ -923,7 +940,7 @@ def get_train_loss(trainer): Fetch the train loss from the last minibatch and copy it to the CPU in case it is on the GPU. Args: - trainer (:class:`Trainer`): the trainer used. + trainer (:class:`~cntk.trainer.Trainer`): the trainer used. Returns: the loss value ''' @@ -997,7 +1014,7 @@ def eval(op, arguments=None, precision=None, device=None, backward_pass=False, e be used as a list of bools, denoting whether a sequence is a new one (`True`) or a continuation of the previous one (`False`). Data should be either NumPy arrays or a - :class:`cntk.io.MinibatchData` instance. + :class:`~cntk.io.MinibatchData` instance. seq_starts (list of `bool`s or None): if None, every sequence is treated as a new sequence. Otherwise, it is interpreted as a list of Booleans that tell whether a sequence is a new sequence (`True`) or a