From ce521ae5e6999b6f920165b2fe7cb0b69fec1ae0 Mon Sep 17 00:00:00 2001 From: Keqiu Hu Date: Sat, 1 May 2021 09:12:49 -0700 Subject: [PATCH] Implement ORC dataset reader (#1383) * Implement ORC dataset reader * support double, float and string types * add sample keras unit tests * reset unintended changes * add more datatypes * fix type in macOS, add test * address comments * fix a typo in float conversion --- tensorflow_io/core/BUILD | 30 ++- tensorflow_io/core/kernels/orc/orc_kernels.cc | 246 ++++++++++++++++++ tensorflow_io/core/ops/orc_ops.cc | 60 +++++ tensorflow_io/core/python/ops/io_dataset.py | 16 ++ .../core/python/ops/orc_dataset_ops.py | 102 ++++++++ tests/test_orc.py | 97 +++++++ tests/test_orc/iris.orc | Bin 0 -> 3328 bytes 7 files changed, 538 insertions(+), 13 deletions(-) create mode 100644 tensorflow_io/core/kernels/orc/orc_kernels.cc create mode 100644 tensorflow_io/core/ops/orc_ops.cc create mode 100644 tensorflow_io/core/python/ops/orc_dataset_ops.py create mode 100644 tests/test_orc.py create mode 100644 tests/test_orc/iris.orc diff --git a/tensorflow_io/core/BUILD b/tensorflow_io/core/BUILD index f2902c8df..763840091 100644 --- a/tensorflow_io/core/BUILD +++ b/tensorflow_io/core/BUILD @@ -368,6 +368,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "orc_ops", + srcs = [ + "kernels/orc/orc_kernels.cc", + "ops/orc_ops.cc", + ], + copts = tf_io_copts(), + linkstatic = True, + deps = [ + "//tensorflow_io/core:dataset_ops", + "@liborc", + "@local_config_tf//:libtensorflow_framework", + "@local_config_tf//:tf_header_lib", + ], + alwayslink = 1, +) + cc_library( name = "text_ops", srcs = [ @@ -531,19 +548,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "orc_ops", - srcs = [ - ], - copts = tf_io_copts(), - linkstatic = True, - deps = [ - "//tensorflow_io/core:dataset_ops", - "@liborc", - ], - alwayslink = 1, -) - cc_library( name = "numpy_ops", srcs = [ diff --git a/tensorflow_io/core/kernels/orc/orc_kernels.cc b/tensorflow_io/core/kernels/orc/orc_kernels.cc new file mode 100644 index 000000000..c3b86ce9c --- /dev/null +++ b/tensorflow_io/core/kernels/orc/orc_kernels.cc @@ -0,0 +1,246 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "orc/orc-config.hh" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow_io/core/kernels/io_interface.h" +#include "tensorflow_io/core/kernels/io_stream.h" + +namespace tensorflow { +namespace data { + +class ORCReadable : public IOReadableInterface { + public: + ORCReadable(Env* env) : env_(env) {} + ~ORCReadable() {} + Status Init(const std::vector& input, + const std::vector& metadata, const void* memory_data, + const int64 memory_size) override { + if (input.size() > 1) { + return errors::InvalidArgument("more than 1 filename is not supported"); + } + const string& filename = input[0]; + // read packet data + orc::RowReaderOptions row_reader_opts; + orc::ReaderOptions reader_opts; + std::unique_ptr reader = + orc::createReader(orc::readFile(filename), reader_opts); + + row_reader_ = reader->createRowReader(row_reader_opts); + LOG(INFO) << "ORC file schema:" << reader->getType().toString(); + + // Parse columns. We assume the orc record file is a flat array + auto row_count = reader->getNumberOfRows(); + for (uint64_t i = 0; i < reader->getType().getSubtypeCount(); ++i) { + auto field_name = reader->getType().getFieldName(i); + auto subtype = reader->getType().getSubtype(i); + DataType dtype; + switch (static_cast(subtype->getKind())) { + case orc::SHORT: + dtype = DT_INT16; + break; + case orc::INT: + dtype = DT_INT32; + break; + case orc::LONG: + dtype = DT_INT64; + break; + case orc::STRING: + dtype = DT_STRING; + break; + case orc::DOUBLE: + dtype = DT_DOUBLE; + break; + case orc::FLOAT: + dtype = DT_FLOAT; + break; + default: + return errors::InvalidArgument("data type is not supported: ", + subtype->toString()); + } + columns_.push_back(field_name); + shapes_.push_back(TensorShape({static_cast(row_count)})); + dtypes_.push_back(dtype); + columns_index_[field_name] = i; + tensors_.emplace_back( + Tensor(dtype, TensorShape({static_cast(row_count)}))); + } + // Fill in the values + std::unique_ptr batch = + row_reader_->createRowBatch(10); + auto* fields = dynamic_cast(batch.get()); + int64_t record_index = 0; +// Template type conversions between ORC and TensorFlow DT +#define PROCESS_TYPE(VTYPE, VDTYPE, TDTYPE) \ + { \ + auto* col = dynamic_cast(fields->fields[column_index]); \ + VDTYPE* buffer1 = col->data.data(); \ + tensors_[column_index].flat()(record_index) = (TDTYPE)buffer1[r]; \ + } + while (row_reader_->next(*batch)) { + for (uint32_t r = 0; r < batch->numElements; ++r) { + for (size_t column_index = 0; column_index < columns_.size(); + column_index++) { + switch (dtypes_[column_index]) { + case DT_DOUBLE: + PROCESS_TYPE(orc::DoubleVectorBatch*, double, double); + break; + case DT_FLOAT: + PROCESS_TYPE(orc::DoubleVectorBatch*, double, float); + break; + case DT_INT16: + PROCESS_TYPE(orc::LongVectorBatch*, int64, int16); + break; + case DT_INT32: + PROCESS_TYPE(orc::LongVectorBatch*, int64, int32); + break; + case DT_INT64: + PROCESS_TYPE(orc::LongVectorBatch*, int64, int64); + break; + case DT_STRING: { + auto* string_col = dynamic_cast( + fields->fields[column_index]); + char** buffer = string_col->data.data(); + int64_t* lengths = string_col->length.data(); + tensors_[column_index].flat()(record_index) = + std::string(buffer[r], lengths[r]); + break; + } + default: + return errors::InvalidArgument( + "data type is not supported: ", + DataTypeString(dtypes_[column_index])); + } + } + record_index++; + } + } + + return Status::OK(); + } + + Status Read(const int64 start, const int64 stop, const string& component, + int64* record_read, Tensor* value, Tensor* label) override { + if (columns_index_.find(component) == columns_index_.end()) { + return errors::InvalidArgument("component ", component, " is invalid"); + } + int64 column_index = columns_index_[component]; + + (*record_read) = 0; + if (start >= shapes_[column_index].dim_size(0)) { + return Status::OK(); + } + const string& column = component; + int64 element_start = start < shapes_[column_index].dim_size(0) + ? start + : shapes_[column_index].dim_size(0); + int64 element_stop = stop < shapes_[column_index].dim_size(0) + ? stop + : shapes_[column_index].dim_size(0); + if (element_start > element_stop) { + return errors::InvalidArgument("dataset ", column, + " selection is out of boundary"); + } + if (element_start == element_stop) { + return Status::OK(); + } + +#define PROCESS_VALUE(VTYPE) \ + { \ + value->flat().data()[i] = \ + tensors_[column_index].flat().data()[i]; \ + } + for (int i = element_start; i < element_stop; i++) { + switch (dtypes_[column_index]) { + case DT_DOUBLE: + PROCESS_VALUE(double); + break; + case DT_FLOAT: + PROCESS_VALUE(float); + break; + case DT_INT16: + PROCESS_VALUE(int16); + break; + case DT_INT32: + PROCESS_VALUE(int32); + break; + case DT_INT64: + PROCESS_VALUE(int64); + break; + case DT_STRING: { + PROCESS_VALUE(tstring); + break; + } + default: + return errors::InvalidArgument("data type is not supported: ", + DataTypeString(dtypes_[column_index])); + } + } + (*record_read) = element_stop - element_start; + + return Status::OK(); + } + + Status Components(std::vector* components) override { + components->clear(); + for (size_t i = 0; i < columns_.size(); i++) { + components->push_back(columns_[i]); + } + return Status::OK(); + } + + Status Spec(const string& component, PartialTensorShape* shape, + DataType* dtype, bool label) override { + if (columns_index_.find(component) == columns_index_.end()) { + return errors::InvalidArgument("component ", component, " is invalid"); + } + int64 column_index = columns_index_[component]; + *shape = shapes_[column_index]; + *dtype = dtypes_[column_index]; + return Status::OK(); + } + + string DebugString() const override { + mutex_lock l(mu_); + return strings::StrCat("ORCReadable"); + } + + private: + mutable mutex mu_; + Env* env_ TF_GUARDED_BY(mu_); + std::unique_ptr file_ TF_GUARDED_BY(mu_); + std::unique_ptr row_reader_ TF_GUARDED_BY(mu_); + std::vector tensors_; + + std::vector dtypes_; + std::vector shapes_; + std::vector columns_; + std::unordered_map columns_index_; +}; +REGISTER_KERNEL_BUILDER(Name("IO>ORCReadableInit").Device(DEVICE_CPU), + IOInterfaceInitOp); +REGISTER_KERNEL_BUILDER(Name("IO>ORCReadableSpec").Device(DEVICE_CPU), + IOInterfaceSpecOp); +REGISTER_KERNEL_BUILDER(Name("IO>ORCReadableRead").Device(DEVICE_CPU), + IOReadableReadOp); +} // namespace data +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow_io/core/ops/orc_ops.cc b/tensorflow_io/core/ops/orc_ops.cc new file mode 100644 index 000000000..9bc292459 --- /dev/null +++ b/tensorflow_io/core/ops/orc_ops.cc @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +REGISTER_OP("IO>ORCReadableInit") + .Input("input: string") + .Output("resource: resource") + .Output("components: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + c->set_output(1, c->MakeShape({})); + return Status::OK(); + }); + +REGISTER_OP("IO>ORCReadableSpec") + .Input("input: resource") + .Output("shape: int64") + .Output("dtype: int64") + .Attr("component: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->MakeShape({c->UnknownDim()})); + c->set_output(1, c->MakeShape({})); + return Status::OK(); + }); + +REGISTER_OP("IO>ORCReadableRead") + .Input("input: resource") + .Input("start: int64") + .Input("stop: int64") + .Output("value: dtype") + .Attr("component: string") + .Attr("shape: shape") + .Attr("dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + shape_inference::ShapeHandle entry; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &entry)); + c->set_output(0, entry); + return Status::OK(); + }); +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow_io/core/python/ops/io_dataset.py b/tensorflow_io/core/python/ops/io_dataset.py index 84720994d..de5ed0f0d 100644 --- a/tensorflow_io/core/python/ops/io_dataset.py +++ b/tensorflow_io/core/python/ops/io_dataset.py @@ -26,6 +26,7 @@ from tensorflow_io.core.python.ops import parquet_dataset_ops from tensorflow_io.core.python.ops import pcap_dataset_ops from tensorflow_io.core.python.ops import mnist_dataset_ops +from tensorflow_io.core.python.ops import orc_dataset_ops class IODataset(io_dataset_ops._IODataset): # pylint: disable=protected-access @@ -308,6 +309,21 @@ def from_pcap(cls, filename, **kwargs): with tf.name_scope(kwargs.get("name", "IOFromPcap")): return pcap_dataset_ops.PcapIODataset(filename, internal=True, **kwargs) + @classmethod + def from_orc(cls, filename, **kwargs): + """Creates an `IODataset` from an ORC file. + + Args: + filename: A string, the filename of an ORC file. + name: A name prefix for the IOTensor (optional). + + Returns: + A `IODataset`. + + """ + with tf.name_scope(kwargs.get("name", "IOFromORC")): + return orc_dataset_ops.ORCIODataset(filename, internal=True, **kwargs) + class StreamIODataset( io_dataset_ops._StreamIODataset diff --git a/tensorflow_io/core/python/ops/orc_dataset_ops.py b/tensorflow_io/core/python/ops/orc_dataset_ops.py new file mode 100644 index 000000000..05425f3de --- /dev/null +++ b/tensorflow_io/core/python/ops/orc_dataset_ops.py @@ -0,0 +1,102 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ORCDataset""" + +import sys +import uuid + +import tensorflow as tf +from tensorflow_io.core.python.ops import core_ops + + +class _ORCIODatasetFunction: + def __init__(self, function, resource, component, shape, dtype): + self._function = function + self._resource = resource + self._component = component + self._shape = tf.TensorShape([None]).concatenate(shape[1:]) + self._dtype = dtype + + def __call__(self, start, stop): + return self._function( + self._resource, + start=start, + stop=stop, + component=self._component, + shape=self._shape, + dtype=self._dtype, + ) + + +class ORCIODataset(tf.data.Dataset): + """ORCIODataset""" + + def __init__(self, filename, columns=None, internal=True, **kwargs): + if not internal: + raise ValueError( + "ORCIODataset constructor is private; please use one " + "of the factory methods instead (e.g., " + "IODataset.from_orc())" + ) + with tf.name_scope("ORCIODataset") as scope: + capacity = 4096 + resource, columns_v = core_ops.io_orc_readable_init( + filename, + container=scope, + shared_name="{}/{}".format(filename, uuid.uuid4().hex), + ) + columns = columns if columns is not None else columns_v.numpy() + columns_dataset = [] + columns_function = [] + for column in columns: + shape, dtype = core_ops.io_orc_readable_spec(resource, column) + shape = tf.TensorShape([None if e < 0 else e for e in shape.numpy()]) + dtype = tf.as_dtype(dtype.numpy()) + function = _ORCIODatasetFunction( + core_ops.io_orc_readable_read, resource, column, shape, dtype + ) + columns_function.append(function) + + for (column, function) in zip(columns, columns_function): + column_dataset = tf.compat.v2.data.Dataset.range( + 0, sys.maxsize, capacity + ) + column_dataset = column_dataset.map( + lambda index: function(index, index + capacity) + ) + column_dataset = column_dataset.apply( + tf.data.experimental.take_while( + lambda v: tf.greater(tf.shape(v)[0], 0) + ) + ) + columns_dataset.append(column_dataset) + if len(columns_dataset) == 1: + dataset = columns_dataset[0] + else: + dataset = tf.compat.v2.data.Dataset.zip(tuple(columns_dataset)) + dataset = dataset.unbatch() + + self._function = columns_function + self._dataset = dataset + super().__init__( + self._dataset._variant_tensor + ) # pylint: disable=protected-access + + def _inputs(self): + return [] + + @property + def element_spec(self): + return self._dataset.element_spec diff --git a/tests/test_orc.py b/tests/test_orc.py new file mode 100644 index 000000000..d4cbd5f10 --- /dev/null +++ b/tests/test_orc.py @@ -0,0 +1,97 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +""" +Test ORCDataset +""" + +import os +import numpy as np + +import tensorflow as tf +import tensorflow_io as tfio + + +def test_orc_input(): + """test_pcap_input + """ + print("Testing ORCDataset") + orc_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_orc", "iris.orc" + ) + + dataset = tfio.IODataset.from_orc(orc_filename, capacity=15).batch(1) + packets_total = 0 + for v in dataset: + if packets_total == 0: + sepal_length, _, _, _, species = v + assert sepal_length.dtype == tf.float32 + assert species.dtype == tf.string + assert tf.math.less(tf.math.abs(sepal_length - 5.0999999), 0.0001) + assert tf.math.equal(species, "setosa") + packets_total += 1 + + assert packets_total == 150 + + +def test_orc_keras(): + """Test case for ORCDataset with Keras""" + orc_filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_orc", "iris.orc" + ) + + feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"] + label_cols = ["species"] + + feature_dataset = tfio.IODataset.from_orc(orc_filename, columns=feature_cols) + + label_dataset = tfio.IODataset.from_orc(orc_filename, columns=label_cols) + + @tf.function + def species_float_conversion(x): + if x == "virginica": + return 1.0 + if x == "versicolor": + return 2.0 + if x == "setosa": + return 3.0 + return 4.0 + + label_dataset = label_dataset.map(species_float_conversion) + dataset = tf.data.Dataset.zip((feature_dataset, label_dataset)) + dataset = dataset.batch(1) + + def pack_features_vector(features, labels): + """Pack the features into a single array.""" + features = tf.stack(list(features), axis=1) + return features, labels + + dataset = dataset.map(pack_features_vector) + + model = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + 10, activation=tf.nn.relu, input_shape=(4,) + ), # input shape required + tf.keras.layers.Dense(10, activation=tf.nn.relu), + tf.keras.layers.Dense(3), + ] + ) + + model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]) + model.fit(dataset, epochs=5) + + +if __name__ == "__main__": + test.main() diff --git a/tests/test_orc/iris.orc b/tests/test_orc/iris.orc new file mode 100644 index 0000000000000000000000000000000000000000..717948c05a6a4a77dbf7e7a35aee7fb7bc594734 GIT binary patch literal 3328 zcmai%PiS0q5XXP-?VoJcMUoAbO{z3RYFqHo@=ghuEEGZy!ItJyPa5KG-4dFR4V6L< zQ6Uh_Ay5dur9vPe2t5daLZM*5Qc*TgEEEep2ueWkAVhK!>U?J2Z+BzT^f1i)=J%cN z%>3T(&Az>O`Fy^sn8~dLmt6i-K5)e$u?LQKJLk$#)LZTqV?B;fl>41qeDPxX%kn{> zK%58m0(6b|+gpbN{~k*ij|2SftN82B4gpHW`wA4UG{851Zr-W_d^;aHSA0u~9C{<4 ztydOi>r=(<8@2hXwd=L%snWx~OD>8w%G=F534bC0#AM*Ndd=|M{0v%pLjIfa5Um)iYK1_sU0B} zA3fRL(zadj6LQ)ak`JyZ`d#DcBZ_vlherIT=B`$EgyXT+WcA$B7;3b> z_XM{mHuXGDa_uP}y%JxolGFNPz1{pv zWw9qD*SJVkFgHa$NMu<3+szqWgEd){m=(ps5k++s;=sZaMU$-<{NQZga$!f&7lK_% za!e?OSaSH7DC*Ptd>;15N6q9SkJ-ixK8m75sG(nESTOGjuT$>rW)JEJpqT%jOV&S*C;iL}MaDc@|TR$E!rPC@y1~_z*W`ytRLH> zM$P~{XG0!(!ROBcygw`J4XpKJ7@nb>o3EE=;LpeI2X|y#csVP$CtE#|+nyP`n`4N~ zjWM~CYRf%T?bm2T-Us!|eHxQFq(0(^CkBkyGO^W`cTMAn^?cZV?R^#-YaN~S!0*NK z?l}0Kj_5_{>$a|K4o~Wu_l3L3JqE*`J);%Ha>wDY_bOJ5pXH5sS-5$R z@UJV*-*bECZc2`g#J{TC;EmJXL92<{Cbb58pO%DkS>kvW>!x4~3%(tLe`3@*C-=wmo3FjbkI!>?ocde& zI(cO-a34GJv}3r7+%bfE=(XU*UToc0Y{M<;{RG>NwLRV&`E6+NNDZxcu$}J?SUAC0 zE_?rce)tjgEtdDu>hZknTTbd8lxMdqF@z~(E&dCX!3q(RU9$s6u zudo}H`Fd@7c4l_I^);5y=CVDm)U7%wolzYoI+f^jqBEtF`L0~h6O?{Us}Xu5<&88( zRxp`TL3AjZbxL%9G0eFGl9s(x{$Y?$zHm>z`s4PBq-evW?ZZE955Nxpu-y-P>1Qpl z6Gwj2UO6MJ{BdmHby2?rrRBcCrQo!D1YgJ%ikToxr85JCdSz~E=EIrF^{We?47_Hm zPir5`qA*uk=-9HAZFSY>D$}(}eJIRj{wG~$*Mbo{xsmpf$Tc|pZ-F$?0XypC_QPI! paviXvPHyB?Cs$s5|ANbXANIJR&=pSq7@W<1{!c1(@$&g&{{cXQXSe_W literal 0 HcmV?d00001