Skip to content

Commit

Permalink
Support for GPU/sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
wilrich-msft authored and mahilleb-msft committed Nov 25, 2016
1 parent a6bc7a8 commit bf2ce56
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 119 deletions.
2 changes: 1 addition & 1 deletion Source/CNTKv2LibraryDll/CompositeFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1408,7 +1408,7 @@ namespace CNTK
{
// Ensure that only a subset of this function's outputs are being asked to be evaluated
if (functionOutputs.find(outputVarValuePair.first) == functionOutputs.end())
InvalidArgument("Requested output is not an Ouptut of the Function");
InvalidArgument("Requested output is not an Output of the Function");

auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVarValuePair.first);
requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());
Expand Down
12 changes: 10 additions & 2 deletions Source/Math/Matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,15 @@ Matrix<ElemType>& Matrix<ElemType>::DoGatherColumnsOf(ElemType beta, const Matri
{ m_CPUMatrix->DoGatherColumnsOf(beta, *idx.m_CPUMatrix, *a.m_CPUMatrix, alpha); },
{ m_GPUMatrix->DoGatherColumnsOf(beta, *idx.m_GPUMatrix, *a.m_GPUMatrix, alpha); },
{ m_CPUSparseMatrix->DoGatherColumnsOf(beta, *idx.m_CPUMatrix, *a.m_CPUSparseMatrix, alpha); },
{ NOT_IMPLEMENTED; });
{
// TODO replace by more performant version directly on GPU that does not require the round-trip over CPU.
Matrix<ElemType> tempIdx(CPUDEVICE); tempIdx.AssignValuesOf(idx);
CPUSparseMatrix<ElemType> tempA(a.GetFormat(), a.GetNumRows(), a.GetNumCols(), a.m_GPUSparseMatrix->GetNumNZElements());

a.m_GPUSparseMatrix->CopyToCPUSparseMatrix(tempA);
tempA.DoGatherColumnsOf(beta, *tempIdx.m_CPUMatrix, tempA, alpha);
m_GPUSparseMatrix->SetValue(tempA);
});

return *this;
}
Expand Down Expand Up @@ -3621,7 +3629,7 @@ void Matrix<ElemType>::DecideAndMoveToRightDevice(const Matrix<ElemType>& a, con
template <class ElemType>
void Matrix<ElemType>::DecideAndMoveToRightDevice(const Matrix<ElemType>& a, const Matrix<ElemType>& b, const Matrix<ElemType>& c, const Matrix<ElemType>& d)
{
// this function is only called for one operator, so for now we keep it imple
// this function is only called for one operator, so for now we keep it simple
DecideAndMoveToRightDevice(a, b, c);
d._transferToDevice(a.GetDeviceId()); // BUGBUG: Is this correct in case a,b,c share the same preferredDevice?
}
Expand Down
43 changes: 24 additions & 19 deletions bindings/python/cntk/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def shape(self):
def mask(self):
'''
The mask object of the minibatch. In it, `2` marks the beginning of a
sequence, `1` marks a sequence element as valid, and `0` markse it as
sequence, `1` marks a sequence element as valid, and `0` marks it as
invalid.
'''
return self.m_data.mask().to_numpy()
Expand Down Expand Up @@ -296,6 +296,7 @@ class Deserializer(dict):
Deserializer type Description
========================== ============
:class:`ImageDeserializer` Deserializer for images that uses OpenCV
:class:`CTFDeserializer` Deserializer for text of the `CNTKTextReader format <https://github.com/microsoft/cntk/wiki/CNTKTextFormat-Reader>`_
========================== ============
Args:
Expand All @@ -312,16 +313,19 @@ def __init__(self, type):
class ImageDeserializer(Deserializer):
'''
This class configures the image reader that reads images and corresponding
labels from a file of the form
labels from a file of the form::
<full path to image><tab><numerical label (0-based class id)>
<full path to image> <tab> <numerical label (0-based class id)>
or::
sequenceId <tab> path <tab> label
Args:
filename (str): file name of the map file that associates images to
classes
See also:
https://github.com/microsoft/cntk/wiki/Image-reader
`Image reader definition <https://github.com/microsoft/cntk/wiki/Image-reader>`_
'''

def __init__(self, filename, streams=None):
Expand Down Expand Up @@ -447,24 +451,22 @@ def mean(filename):

# TODO color transpose

#
# CNTKTextFormatReader
# TODO get away from cntk_py.text_format_minibatch_source and set it up
# similarly to ImageDeserializer
#


#class TextFormatDeserializer(Deserializer): # TODO: either call it CNTKTextFormat or CTF. TextFormat is confusable with plain text
class CTFDeserializer(Deserializer):
'''
This class configures the text reader that reads text-encoded files from a file with lines of the form
[Sequence_Id](Sample)+
where
Sample=|Input_Name (Value )*
This class configures the text reader that reads text-encoded files from a
file with lines of the form::
[Sequence_Id](Sample)+
where::
Sample=|Input_Name (Value )*
Args:
filename (str): file name containing the text input
See also:
https://github.com/Microsoft/CNTK/wiki/CNTKTextFormat-Reader
`CNTKTextReader format <https://github.com/microsoft/cntk/wiki/CNTKTextFormat-Reader>`_
'''

def __init__(self, filename, streams=None):
Expand All @@ -483,8 +485,11 @@ def map_input(self, node, dim, format="dense", alias=None):
'''
Maps node (either node instance or node name) to a part of the text input,
either specified by the node name or the alias in the text file.
Example: for node name 'Apples' an input line could look like this:
|Apples 0 1 2 3 4 5 6 7 8 9
Example: for node name 'input0' an input line could look like this::
|input0 3 7 1 0 2
Args:
node (str or input node): node or its name
dim (int): specifies the dimension of the input value vector
Expand All @@ -493,7 +498,7 @@ def map_input(self, node, dim, format="dense", alias=None):
format (str, default 'dense'): 'dense' or 'sparse'. Specifies the input type.
alias (str, default None): None or alias name. Optional abbreviated name that
is used in the text file to avoid repeating long input names. For details please
see https://github.com/Microsoft/CNTK/wiki/CNTKTextFormat-Reader
see `CNTKTextReader format <https://github.com/microsoft/cntk/wiki/CNTKTextFormat-Reader>`_
'''
if not isinstance(node, str):
node = node.name()
Expand Down
3 changes: 2 additions & 1 deletion bindings/python/cntk/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ def momentum_sgd(parameters, lr, momentum,
with truncation
Returns:
Instance of a :class:`cntk.learner.Learner` that can be passed to the :class:`cntk.trainer.Trainer`
Instance of a :class:`~cntk.learner.Learner` that can be passed to the
:class:`~cntk.trainer.Trainer`
'''
_verify_learning_rate_type(lr)
_verify_momentum_type(momentum)
Expand Down
6 changes: 3 additions & 3 deletions bindings/python/cntk/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def pooling(operand, pooling_type, pooling_window_shape, strides=(1,), auto_padd
Args:
operand: pooling input
pooling_type: one of :const:`cntk.ops.MAX_POOLING` or :const:`cntk.ops.AVG_POOLING`
pooling_type: one of :const:`~cntk.ops.MAX_POOLING` or :const:`~cntk.ops.AVG_POOLING`
pooling_window_shape: dimensions of the pooling window
strides (default 1): strides.
auto_padding: automatic padding flags for each input dimension.
Expand Down Expand Up @@ -1071,7 +1071,7 @@ def softmax(x, name=''):
'''
from cntk.cntk_py import softmax
x = sanitize_input(x)
return softmax(x)
return softmax(x, name)


@typemap
Expand All @@ -1095,7 +1095,7 @@ def hardmax(x, name=''):
'''
from cntk.cntk_py import hardmax
x = sanitize_input(x)
return hardmax(x)
return hardmax(x, name)


@typemap
Expand Down
5 changes: 4 additions & 1 deletion bindings/python/cntk/ops/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,10 @@ def __init__(self, value=None, shape=None, dtype=None, device=None, name=''):
if np.isscalar(value):
super().__init__(utils.sanitize_shape(shape), sanitize_dtype_cntk(dtype), value)
else:
ndav = sanitize_value(shape, value, dtype, device)
if isinstance(value, cntk_py.Value):
ndav = value.data()
else:
ndav = sanitize_value(shape, value, dtype, device)
super().__init__(ndav, name)


Expand Down
165 changes: 98 additions & 67 deletions bindings/python/cntk/tests/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# for full license information.
# ==============================================================================

import os
import math
import numpy as np
from .. import Function
Expand Down Expand Up @@ -65,6 +66,64 @@ def test_output_to_retain():

assert np.allclose(var_map[z_output], np.asarray(in1_value)+20)

def test_eval_sparse_dense(tmpdir, device_id):
    '''
    Checks that evaluating ``times(raw_input, eye)`` on a sparse input gives the
    same values regardless of how the input is supplied: as a minibatch read
    by the CTF reader, as a list of SciPy CSR matrices, or as the result of
    ``one_hot``.

    Args:
        tmpdir: directory object used to build the path of the temporary CTF
         file (joined via ``tmpdir/'2seqtest.txt'``)
        device_id: device identifier forwarded to ``cntk_device``
    '''
    from cntk import Axis
    from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs
    from cntk.device import cpu, gpu, set_default_device
    from cntk.ops import input_variable, times
    from scipy.sparse import csr_matrix

    input_vocab_dim = label_vocab_dim = 69

    # Two sequences (ids 0 and 2) in CNTK text format; the '|#' fields are
    # per-sample annotations next to the sparse S0/S1 streams.
    ctf_data = '''\
0 |S0 3:1 |# <s> |S1 3:1 |# <s>
0 |S0 4:1 |# A |S1 32:1 |# ~AH
0 |S0 5:1 |# B |S1 36:1 |# ~B
0 |S0 4:1 |# A |S1 31:1 |# ~AE
0 |S0 7:1 |# D |S1 38:1 |# ~D
0 |S0 12:1 |# I |S1 47:1 |# ~IY
0 |S0 1:1 |# </s> |S1 1:1 |# </s>
2 |S0 60:1 |# <s> |S1 3:1 |# <s>
2 |S0 61:1 |# A |S1 32:1 |# ~AH
'''
    ctf_file = str(tmpdir/'2seqtest.txt')
    with open(ctf_file, 'w') as f:
        f.write(ctf_data)

    mbs = MinibatchSource(CTFDeserializer(ctf_file, StreamDefs(
        features = StreamDef(field='S0', shape=input_vocab_dim, is_sparse=True),
        labels = StreamDef(field='S1', shape=label_vocab_dim, is_sparse=True)
    )), randomize=False, epoch_size = 2)

    batch_axis = Axis.default_batch_axis()
    input_seq_axis = Axis('inputAxis')
    label_seq_axis = Axis('labelAxis')

    input_dynamic_axes = [batch_axis, input_seq_axis]
    raw_input = input_variable(
        shape=input_vocab_dim, dynamic_axes=input_dynamic_axes,
        name='raw_input', is_sparse=True)

    mb_valid = mbs.next_minibatch(minibatch_size_in_samples=100,
                                  input_map={raw_input : mbs.streams.features})

    # Multiplying by the identity turns the sparse input into a dense output
    # that can be compared element-wise.
    z = times(raw_input, np.eye(input_vocab_dim))
    e_reader = z.eval(mb_valid)

    # CSR with the raw_input encoding in ctf_data
    one_hot_data = [
        [3, 4, 5, 4, 7, 12, 1],
        [60, 61]
    ]
    data = [csr_matrix(np.eye(input_vocab_dim, dtype=np.float32)[d]) for d in
            one_hot_data]
    e_csr = z.eval({raw_input: data}, device=cntk_device(device_id))
    # Use the builtin all(): np.all() applied to a generator expression does
    # not iterate it and is always truthy, so the comparison would never run.
    # NOTE(review): zip() stops at the shorter operand, so a differing number
    # of sequences would not be detected here.
    assert all(np.allclose(a, b) for a, b in zip(e_reader, e_csr))

    # One-hot with the raw_input encoding in ctf_data
    data = one_hot(one_hot_data, num_classes=input_vocab_dim)
    e_hot = z.eval({raw_input: data}, device=cntk_device(device_id))
    assert all(np.allclose(a, b) for a, b in zip(e_reader, e_hot))

@pytest.mark.parametrize("batch_index_data", [
[2,3],
Expand All @@ -73,93 +132,65 @@ def test_output_to_retain():
def test_eval_sparse_no_seq(batch_index_data, device_id):
dim = 10
multiplier = 2
in1 = input_variable(shape=(dim,), is_sparse=True)
z = times(in1, np.eye(dim).astype(np.float32))
z *= multiplier
batch = (np.eye(dim)[batch_index_data]).astype(np.float32)
expected = batch * multiplier
sparse_val = csr(batch)
result = z.eval({in1: sparse_val}, device=cntk_device(device_id))
assert np.allclose(result, [expected])

@pytest.mark.parametrize("batch_index_data", [
[[2,3], [0,1,6]],
])
def test_eval_sparse_seq_0(batch_index_data, device_id):
if cntk_device(device_id)!=cpu(): # FIXME
pytest.skip("sparse is not yet supported on GPU")
dim = 10
multiplier = 2
in1 = input_variable(shape=(dim,), is_sparse=True)
z = times(in1, np.eye(dim).astype(np.float32))
z *= multiplier
batch = [(np.eye(dim)[seq_index_data]).astype(np.float32) for
seq_index_data in batch_index_data]
expected = batch * multiplier
sparse_val = [csr(seq) for seq in batch]
result = z.eval({in1: sparse_val}, device=cntk_device(device_id))
assert np.all(np.allclose(a,b) \
for a,b in zip(result, expected))
for var_is_sparse in [True, False]:
in1 = input_variable(shape=(dim,), is_sparse=var_is_sparse)
z = times(in1, multiplier*np.eye(dim))
batch = np.eye(dim)[batch_index_data]
expected = batch * multiplier
sparse_val = csr(batch)
result = z.eval({in1: sparse_val}, device=cntk_device(device_id))
assert np.allclose(result, [expected])

@pytest.mark.parametrize("batch", [
#[[csr([0,1,2,0])]],
[
[csr([0, 2, 0, 7]), csr([10, 20, 0, 0])],
[csr([0, 0, 0, 3])]
[[csr([0,1,2,0])]],
[
[csr([0, 2, 0, 7]), csr([10, 20, 0, 0])],
[csr([0, 0, 0, 3])]
],
# same as before, but sequence being encoded as one matrix
[
csr([[0, 2, 0, 7], [10, 20, 0, 0]]),
csr([0, 0, 0, 3])
]
])
])
def test_eval_sparse_seq_1(batch, device_id):
if cntk_device(device_id)!=cpu(): # FIXME
pytest.skip("sparse is not yet supported on GPU")
dim = 4
multiplier = 2
# FIXME
in1 = input_variable(shape=(dim,), is_sparse=True)
# in1 = input_variable(shape=(dim,))
z = times(in1, multiplier*np.eye(dim))#np.eye(dim).astype(np.float32))

expected = [[m.todense() * multiplier for m in seq] for seq in batch]

result = z.eval({in1: batch}, device=cntk_device(device_id))
for var_is_sparse in [True, False]:
in1 = input_variable(shape=(dim,), is_sparse=var_is_sparse)
z = times(in1, multiplier*np.eye(dim))
expected = [[m.todense() * multiplier for m in seq] for seq in batch]
result = z.eval({in1: batch}, device=cntk_device(device_id))

assert np.all(np.allclose(a,b) \
for a,b in zip(result, expected))
assert np.all(np.allclose(a,b) \
for a,b in zip(result, expected))


@pytest.mark.parametrize("one_hot_batch", [
([[2,5],
[0,1,6]]),
([[1],
[1],[2],[3]]),
([[1],[1],[2],[3]]),
([[1,5],
[4]]),
])
def test_eval_one_hot_seq(one_hot_batch, device_id):
if cntk_device(device_id)!=cpu(): # FIXME
pytest.skip("sparse is not yet supported on GPU")
dim = 10
multiplier = 2
# FIXME
# in1 = input_variable(shape=(dim,), is_sparse=True)
in1 = input_variable(shape=(dim,))
# Convert CNTK node value to dense so that we can compare it later
z = times(in1, np.eye(dim).astype(np.float32))
z *= multiplier
# Convert expectation to dense
expected = [np.eye(dim)[seq]*multiplier for seq in one_hot_batch]
batch = one_hot(one_hot_batch, num_classes=dim, device=cntk_device(device_id))
assert np.all(np.allclose(a,b) \
for a,b in zip(z.eval({in1: batch}, device=cntk_device(device_id)), expected))
for var_is_sparse in [True, False]:
in1 = input_variable(shape=(dim,), is_sparse=var_is_sparse)
# Convert CNTK node value to dense so that we can compare it later
z = times(in1, np.eye(dim)*multiplier)
# Convert expectation to dense
expected = [np.eye(dim)[seq]*multiplier for seq in one_hot_batch]
batch = one_hot(one_hot_batch, num_classes=dim, device=cntk_device(device_id))
assert np.all(np.allclose(a,b) \
for a,b in zip(z.eval({in1: batch}, device=cntk_device(device_id)), expected))

@pytest.mark.parametrize("one_hot_batch, dim", [
([[11]], 10),
([[0, 1]], 1),
])
# FIXME
def _test_eval_one_hot_bad(one_hot_batch, dim, device_id):
in1 = input_variable(shape=dim)
# Convert CNTK node value to dense so that we can compare it later
z = times(in1, np.eye(dim).astype(np.float32))
# Convert expectation to dense
batch = one_hot(one_hot_batch, num_classes=dim, device=cntk_device(device_id))
def test_eval_one_hot_bad(one_hot_batch, dim, device_id):
with pytest.raises(ValueError):
z.eval({in1: batch})
batch = one_hot(one_hot_batch, num_classes=dim, device=cntk_device(device_id))

Loading

0 comments on commit bf2ce56

Please sign in to comment.