diff --git a/builtin/builtin.mk b/builtin/builtin.mk index 5b358e2bf..0a5c67cfa 100644 --- a/builtin/builtin.mk +++ b/builtin/builtin.mk @@ -7,6 +7,7 @@ LIBMLDB_BUILTIN_SOURCES:= \ sub_dataset.cc \ filtered_dataset.cc \ sampled_dataset.cc \ + union_dataset.cc \ LIBMLDB_BUILTIN_LINK:= mldb_core runner diff --git a/builtin/sub_dataset.cc b/builtin/sub_dataset.cc index fe5eb608e..a6e58df9e 100644 --- a/builtin/sub_dataset.cc +++ b/builtin/sub_dataset.cc @@ -457,7 +457,8 @@ querySubDataset(MldbServer * server, std::vector output = dataset ->queryStructured(select, when, where, orderBy, groupBy, - having, named, offset, limit, "" /* alias */); + having, named, offset, limit, + -1, /* unionIndex */ "" /* alias */); std::vector result; result.reserve(output.size()); diff --git a/builtin/union_dataset.cc b/builtin/union_dataset.cc new file mode 100644 index 000000000..573eba1a2 --- /dev/null +++ b/builtin/union_dataset.cc @@ -0,0 +1,381 @@ +/** -*- C++ -*- + * union_dataset.cc + * Mich, 2016-09-14 + * This file is part of MLDB. Copyright 2016 Datacratic. All rights reserved. + **/ +#include "union_dataset.h" + +#include +#include + +#include "mldb/builtin/id_hash.h" +#include "mldb/builtin/merge_hash_entries.h" +#include "mldb/types/any_impl.h" +#include "mldb/types/structure_description.h" +#include "mldb/types/vector_description.h" + +using namespace std; + + +namespace MLDB { + + +/*****************************************************************************/ +/* UNION DATASET CONFIG */ +/*****************************************************************************/ + +DEFINE_STRUCTURE_DESCRIPTION(UnionDatasetConfig); + +UnionDatasetConfigDescription:: +UnionDatasetConfigDescription() +{ + nullAccepted = true; + + addField("datasets", &UnionDatasetConfig::datasets, + "Datasets to unify together"); +} + +static RegisterDatasetType +regUnion(builtinPackage(), + "union", + "Unify together several datasets", + "datasets/UnionDataset.md.html"); + +std::shared_ptr createUnionDataset( + MldbServer * server, vector > datasets) +{ + return std::make_shared(server, datasets); +} + +struct UnionDataset::Itl + : public MatrixView, public ColumnIndex { + + Lightweight_Hash > rowIndex; + + // Datasets that it was constructed with + vector > datasets; + + Itl(MldbServer * server, vector > datasets) { + if (datasets.empty()) { + throw MLDB::Exception("Attempt to unify no datasets together"); + } + this->datasets = datasets; + int indexWidth = getIndexBinaryWidth(); + if (indexWidth > 31) { + throw MLDB::Exception("Too many datasets in the union"); + } + for (int i = 0; i < datasets.size(); ++i) { + for (const auto & rowPath: datasets[i]->getMatrixView()->getRowPaths()) { + rowIndex[RowHash(PathElement(i) + rowPath)] = + make_pair(i, RowHash(rowPath)); + } + } + } + + int getIndexBinaryWidth() const { + return ceil(log(datasets.size()) / log(2)); + } + + int getIdxFromRowPath(const RowPath & rowPath) const { + // Returns idx > -1 if the index is valid, -1 otherwise + if (rowPath.size() < 2) { + return -1; + } + int idx = static_cast(rowPath.at(0).toIndex()); + if (idx > datasets.size()) { + return -1; + } + ExcAssert(idx == -1 || idx <= datasets.size()); + return idx; + } + + struct UnionRowStream : public RowStream { + + UnionRowStream(const UnionDataset::Itl* source) : source(source) + { + cerr << "UNIMPLEMENTED " << __FILE__ << ":" << __LINE__ << endl; + //throw MLDB::Exception("Unimplemented %s : %d", __FILE__, __LINE__); + } + + virtual std::shared_ptr clone() const + { + return make_shared(source); + } + + /* set where the stream should start*/ + virtual void initAt(size_t start) + { + cerr << "UNIMPLEMENTED " << __FILE__ << ":" << __LINE__ << endl; + //throw MLDB::Exception("Unimplemented %s : %d", __FILE__, __LINE__); + } + + virtual RowPath next() + { + cerr << "UNIMPLEMENTED " << __FILE__ << ":" << __LINE__ << endl; + throw MLDB::Exception("Unimplemented %s : %d", __FILE__, __LINE__); + uint64_t hash = (*it).first; + ++it; + + return source->getRowPath(RowHash(hash)); + } + + virtual const RowPath & rowName(RowPath & storage) const override + { + cerr << "UNIMPLEMENTED " << __FILE__ << ":" << __LINE__ << endl; + throw MLDB::Exception("Unimplemented %s : %d", __FILE__, __LINE__); + uint64_t hash = (*it).first; + return storage = source->getRowPath(RowHash(hash)); + } + + const UnionDataset::Itl* source; + IdHashes::const_iterator it; + + }; + + virtual vector + getRowPaths(ssize_t start = 0, ssize_t limit = -1) const + { + // Row names are idx.rowPath where idx is the index of the dataset + // in the union and rowPath is the original rowPath. + vector result; + for (int i = 0; i < datasets.size(); ++i) { + const auto & d = datasets[i]; + for (const auto & name: d->getMatrixView()->getRowPaths()) { + result.emplace_back(PathElement(i) + name); + } + + } + return result; + } + + virtual vector + getRowHashes(ssize_t start = 0, ssize_t limit = -1) const + { + std::vector result; + for (const auto & it: rowIndex) { + result.emplace_back(it.first); + } + return result; + } + + virtual bool knownRow(const Path & rowPath) const + { + int idx = getIdxFromRowPath(rowPath); + if (idx == -1) { + return false; + } + return datasets[idx]->getMatrixView()->knownRow(rowPath.tail()); + } + + virtual bool knownRowHash(const RowHash & rowHash) const + { + // Unused ? + return rowIndex.find(rowHash) != rowIndex.end(); + //return rowIndex.getDefault(rowHash, 0) != 0; + } + + virtual RowPath getRowPath(const RowHash & rowHash) const + { + const auto & it = rowIndex.find(rowHash); + if (it == rowIndex.end()) { + throw MLDB::Exception("Row not known"); + } + const auto & idxAndHash = it->second; + return datasets[idxAndHash.first]->getMatrixView()->getRowPath(idxAndHash.second); + } + + // DEPRECATED + virtual MatrixNamedRow getRow(const RowPath & rowPath) const + { + throw MLDB::Exception("Unimplemented %s : %d", __FILE__, __LINE__); + } + + virtual bool knownColumn(const Path & column) const + { + for (const auto & d: datasets) { + if (d->getMatrixView()->knownColumn(column)) { + return true; + } + } + return false; + } + + virtual ColumnPath getColumnPath(ColumnHash columnHash) const + { + for (const auto & d: datasets) { + try { + return d->getMatrixView()->getColumnPath(columnHash); + } + catch (const MLDB::Exception & exc) { + } + } + throw MLDB::Exception("Column not known"); + } + + /** Return a list of all columns. */ + virtual vector getColumnPaths() const + { + std::set preResult; + for (const auto & d: datasets) { + auto columnPaths = d->getColumnPaths(); + preResult.insert(columnPaths.begin(), columnPaths.end()); + } + return vector(preResult.begin(), preResult.end()); + } + + virtual MatrixColumn getColumn(const ColumnPath & columnPath) const + { + MatrixColumn result; + result.columnName = columnPath; + result.columnHash = columnPath; + vector > res; + for (int i = 0; i < datasets.size(); ++i) { + const auto & d = datasets[i]; + const auto & subCol = d->getColumnIndex()->getColumn(columnPath); + for (const auto & curr: subCol.rows) { + result.rows.emplace_back(PathElement(i) + std::get<0>(curr), + std::get<1>(curr), + std::get<2>(curr)); + } + } + return result; + } + + /** Return the value of the column for all rows and timestamps. */ + virtual vector > + getColumnValues(const ColumnPath & columnPath, + const std::function & filter) const + { + vector > res; + for (int i = 0; i < datasets.size(); ++i) { + const auto & d = datasets[i]; + for (const auto curr: d->getColumnIndex()->getColumnValues(columnPath)) { + res.emplace_back( + PathElement(i) + std::get<0>(curr).toUtf8String().rawString(), + std::get<1>(curr)); + } + } + return res; + } + + virtual size_t getRowCount() const + { + size_t count = 0; + for (const auto & d: datasets) { + count += d->getRowCount(); + } + return count; + } + + virtual size_t getColumnCount() const + { + return getColumnPaths().size(); + } + + std::pair getTimestampRange() const + { + std::pair result(Date::notADate(), Date::notADate()); + bool first = true; + + for (auto & d: datasets) { + std::pair dsRange = d->getTimestampRange(); + if (!dsRange.first.isADate() || !dsRange.second.isADate()) { + continue; + } + if (first) { + result = dsRange; + first = false; + } + else { + result.first.setMin(dsRange.first); + result.second.setMax(dsRange.second); + } + } + + return result; + } +}; + + +UnionDataset:: +UnionDataset(MldbServer * owner, + PolyConfig config, + const std::function & onProgress) + : Dataset(owner) +{ + auto unionConfig = config.params.convert(); + + vector > datasets; + + for (auto & d: unionConfig.datasets) { + datasets.emplace_back(obtainDataset(owner, d, onProgress)); + } + + itl.reset(new Itl(server, datasets)); +} + +UnionDataset:: +UnionDataset(MldbServer * owner, + vector > datasetsToMerge) + : Dataset(owner) +{ + itl.reset(new Itl(server, datasetsToMerge)); +} + +UnionDataset:: +~UnionDataset() +{ +} + +Any +UnionDataset:: +getStatus() const +{ + vector result; + for (auto & d: itl->datasets) { + result.emplace_back(d->getStatus()); + } + return result; +} + +std::pair +UnionDataset:: +getTimestampRange() const +{ + return itl->getTimestampRange(); +} + +std::shared_ptr +UnionDataset:: +getMatrixView() const +{ + return itl; +} + +std::shared_ptr +UnionDataset:: +getColumnIndex() const +{ + return itl; +} + +std::shared_ptr +UnionDataset:: +getRowStream() const +{ + return make_shared(itl.get()); +} + +ExpressionValue +UnionDataset:: +getRowExpr(const RowPath & rowPath) const +{ + int idx = itl->getIdxFromRowPath(rowPath); + if (idx == -1) { + return ExpressionValue{}; + } + return itl->datasets[idx]->getRowExpr( + Path(rowPath.begin() + 1, rowPath.end())); +} + +} // namespace MLDB diff --git a/builtin/union_dataset.h b/builtin/union_dataset.h new file mode 100644 index 000000000..1b002d657 --- /dev/null +++ b/builtin/union_dataset.h @@ -0,0 +1,62 @@ +/** -*- C++ -*- + * union_dataset.h + * Mich, 2016-09-14 + * This file is part of MLDB. Copyright 2016 Datacratic. All rights reserved. + **/ + +#pragma once + +#include "mldb/core/dataset.h" +#include "mldb/types/value_description_fwd.h" + +namespace MLDB { + + +/*****************************************************************************/ +/* UNION DATASET CONFIG */ +/*****************************************************************************/ + +struct UnionDatasetConfig { + std::vector > datasets; +}; + +DECLARE_STRUCTURE_DESCRIPTION(UnionDatasetConfig); + + +/*****************************************************************************/ +/* UNION DATASET */ +/*****************************************************************************/ + +struct UnionDataset: public Dataset { + + UnionDataset(MldbServer * owner, + PolyConfig config, + const std::function & onProgress); + + /** Constructor used internally when creating a datasets */ + UnionDataset(MldbServer * owner, + std::vector > datasetsToMerge); + + virtual ~UnionDataset() override; + + virtual Any getStatus() const override; + virtual void recordRowItl(const RowPath & rowPath, + const std::vector > & vals) override + { + throw MLDB::Exception("Dataset type doesn't allow recording"); + } + + virtual std::shared_ptr getMatrixView() const override; + virtual std::shared_ptr getColumnIndex() const override; + virtual std::shared_ptr getRowStream() const override; + + virtual std::pair getTimestampRange() const override; + virtual ExpressionValue getRowExpr(const RowPath & rowPath) const override; + +private: + UnionDatasetConfig datasetConfig; + struct Itl; + std::shared_ptr itl; +}; + +} // namespace MLDB diff --git a/container_files/public_html/doc/builtin/datasets/UnionDataset.md b/container_files/public_html/doc/builtin/datasets/UnionDataset.md new file mode 100644 index 000000000..9ef91035e --- /dev/null +++ b/container_files/public_html/doc/builtin/datasets/UnionDataset.md @@ -0,0 +1,31 @@ +# Union Dataset + +The union dataset allows for rows from multiple datasets to be appended +into a single dataset. Columns that match up between the datasets will be +combined together. Row names are altered to reflect the dataset they came +from and avoid having them merged together. + +For example, the row names of to unified datasets will have the following +format. + +``` +0.dataset index 0 row name +... +n.dataset index n row name +``` + +The union is done on the fly which means it is relatively rapid to unify even +large datasets together. + +Aside from the resulting row names, creating a union dataset is equivalent to +the following SQL: + +```sql +SELECT s1.* AS *, s2.* AS * +FROM (SELECT * FROM ds1 ) AS s1 +OUTER JOIN (SELECT * FROM ds2) AS s2 ON false +``` + +## Configuration + +![](%%config dataset union) diff --git a/core/dataset.cc b/core/dataset.cc index 9be65cb43..d85504444 100644 --- a/core/dataset.cc +++ b/core/dataset.cc @@ -561,7 +561,10 @@ getTimestampRange() const TupleExpression(), SqlExpression::TRUE /* having */, SqlExpression::TRUE,/* rowName */ - 0, 1, "" /* alias */); + 0, /* limit */ + 1, /* offset */ + -1, /* unionIndex */ + "" /* alias */); std::pair result; @@ -768,6 +771,7 @@ queryStructured(const SelectExpression & select, const std::shared_ptr rowName, ssize_t offset, ssize_t limit, + int unionIndex, Utf8String alias) const { std::vector output; @@ -781,6 +785,7 @@ queryStructured(const SelectExpression & select, rowName, offset, limit, + unionIndex, alias); for (auto& r : std::get<0>(rows)) { @@ -801,6 +806,7 @@ queryStructuredExpr(const SelectExpression & select, const std::shared_ptr rowName, ssize_t offset, ssize_t limit, + int unionIndex, Utf8String alias) const { ExcAssert(having); @@ -828,7 +834,13 @@ queryStructuredExpr(const SelectExpression & select, auto processor = [&] (NamedRowValue & row_, const std::vector & calc) { - row_.rowName = getValidatedRowName(calc.at(0)); + if (unionIndex == -1) { + row_.rowName = getValidatedRowName(calc.at(0)); + } + else { + row_.rowName = PathElement(unionIndex) + + getValidatedRowName(calc.at(0)); + } row_.rowHash = row_.rowName; output.push_back(std::move(row_)); return true; @@ -2121,15 +2133,46 @@ queryString(const Utf8String & query) const ExcCheck(!stm.from, "FROM clauses are not allowed on dataset queries"); ExcAssert(stm.where && stm.having && stm.rowName); - return queryStructured( - stm.select, - stm.when, - *stm.where, - stm.orderBy, - stm.groupBy, - stm.having, - stm.rowName, - stm.offset, stm.limit); + if (stm.unionStm == nullptr) { + return queryStructured( + stm.select, + stm.when, + *stm.where, + stm.orderBy, + stm.groupBy, + stm.having, + stm.rowName, + stm.offset, + stm.limit, + stm.unionIndex); + } + auto result = queryStructured(stm.select, + stm.when, + *stm.where, + stm.orderBy, + stm.groupBy, + stm.having, + stm.rowName, + stm.offset, + stm.limit, + stm.unionIndex); + SelectStatement * stmtPtr = &stm; + while (stmtPtr->unionStm != nullptr) { + stmtPtr = stmtPtr->unionStm.get(); + auto partial = queryStructured(stmtPtr->select, + stmtPtr->when, + *stmtPtr->where, + stmtPtr->orderBy, + stmtPtr->groupBy, + stmtPtr->having, + stmtPtr->rowName, + stmtPtr->offset, + stmtPtr->limit, + stmtPtr->unionIndex); + result.insert(result.end(), partial.begin(), partial.end()); + } + return result; + } Json::Value diff --git a/core/dataset.h b/core/dataset.h index c0e15008e..f14d1181f 100644 --- a/core/dataset.h +++ b/core/dataset.h @@ -524,6 +524,7 @@ struct Dataset: public MldbEntity { const std::shared_ptr rowName, ssize_t offset, ssize_t limit, + int unionIndex, Utf8String alias = "") const; std::tuple, std::shared_ptr > @@ -536,6 +537,7 @@ struct Dataset: public MldbEntity { const std::shared_ptr rowName, ssize_t offset, ssize_t limit, + int unionIndex, Utf8String alias = "") const; /** Select from the database. */ diff --git a/plugins/continuous_dataset.cc b/plugins/continuous_dataset.cc index 52ea111ed..b3ed9c33d 100644 --- a/plugins/continuous_dataset.cc +++ b/plugins/continuous_dataset.cc @@ -513,6 +513,7 @@ getDatasetConfig(std::shared_ptr datasetsWhere, SqlExpression::parse("rowPath()") /* rowName */, 0 /* offset */, -1 /* limit */, + -1, /* unionIndex */ "" /* alias */); // TODO: diff --git a/server/analytics.cc b/server/analytics.cc index 922099996..28658b5ce 100644 --- a/server/analytics.cc +++ b/server/analytics.cc @@ -378,6 +378,9 @@ queryWithoutDatasetExpr(const SelectStatement& stm, SqlBindingScope& scope) auto boundRowName = stm.rowName->bind(scope); row.rowName = getValidatedRowName(boundRowName(context, GET_ALL)); + if (stm.unionIndex != -1) { + row.rowName = PathElement(stm.unionIndex) + row.rowName; + } row.rowHash = row.rowName; val.mergeToRowDestructive(row.columns); std::vector outputcolumns = {std::move(row)}; @@ -399,6 +402,14 @@ queryFromStatement(const SelectStatement & stm, for (auto& r : std::get<0>(rows)) { output.push_back(r.flattenDestructive()); } +// const SelectStatement * stmtPtr = &stm; +// do { +// auto rows = queryFromStatementExpr(*stmtPtr, scope, params); +// for (auto& r : std::get<0>(rows)) { +// output.push_back(r.flattenDestructive()); +// } +// stmtPtr = stmtPtr->unionStm.get(); +// } while (stmtPtr != nullptr); return output; } @@ -408,14 +419,16 @@ queryFromStatementExpr(const SelectStatement & stm, BoundParameters params) { BoundTableExpression table = stm.from->bind(scope); + tuple, shared_ptr > result; if (table.dataset) { - return table.dataset->queryStructuredExpr(stm.select, stm.when, + result = table.dataset->queryStructuredExpr(stm.select, stm.when, *stm.where, stm.orderBy, stm.groupBy, stm.having, stm.rowName, - stm.offset, stm.limit, + stm.offset, stm.limit, + stm.unionIndex, table.asName); } else if (table.table.runQuery && stm.from) { @@ -458,19 +471,33 @@ queryFromStatementExpr(const SelectStatement & stm, NamedRowValue row; // Second last element is the row name row.rowName = output->values.at(output->values.size() - 2) - .coerceToPath(); + .coerceToPath(); row.rowHash = row.rowName; output->values.back().mergeToRowDestructive(row.columns); rows.emplace_back(std::move(row)); } - return std::make_tuple, - std::shared_ptr >(std::move(rows), std::make_shared()); + result = make_tuple< + vector, + shared_ptr >(std::move(rows), make_shared()); } else { // No from at all - return queryWithoutDatasetExpr(stm, scope); + result = queryWithoutDatasetExpr(stm, scope); + } + + if (stm.unionStm != nullptr) { + auto subResult = queryFromStatementExpr(*stm.unionStm.get(), + scope, + params); + // TODO what happens with ExpressionValueInfo + auto & namedRowValues = std::get<0>(result); + auto & subNamedRowValues = std::get<0>(subResult); + namedRowValues.insert(namedRowValues.end(), + subNamedRowValues.begin(), + subNamedRowValues.end()); } + return result; } /** Select from the given statement. This will choose the most diff --git a/server/dataset_collection.cc b/server/dataset_collection.cc index 2472a121a..e02c531c1 100644 --- a/server/dataset_collection.cc +++ b/server/dataset_collection.cc @@ -570,7 +570,7 @@ queryStructured(const Dataset * dataset, { return dataset->queryStructured (selectParsed, whenParsed, *whereParsed, orderByParsed, - groupByParsed,havingParsed, rowNameParsed, offset, limit); + groupByParsed,havingParsed, rowNameParsed, offset, limit, -1); }; runHttpQuery(runQuery, connection, format, createHeaders,rowNames, rowHashes, sortColumns); diff --git a/server/forwarded_dataset.cc b/server/forwarded_dataset.cc index b1452bcf2..ef39ad8a4 100644 --- a/server/forwarded_dataset.cc +++ b/server/forwarded_dataset.cc @@ -132,11 +132,12 @@ queryStructured(const SelectExpression & select, const std::shared_ptr rowName, ssize_t offset, ssize_t limit, + int unionIndex, Utf8String alias) const { ExcAssert(underlying); return underlying->queryStructured(select, when, where, orderBy, - groupBy, having, rowName, offset, limit, alias); + groupBy, having, rowName, offset, limit, unionIndex, alias); } std::vector diff --git a/server/forwarded_dataset.h b/server/forwarded_dataset.h index c9d0b5fe7..88c12719e 100644 --- a/server/forwarded_dataset.h +++ b/server/forwarded_dataset.h @@ -67,6 +67,7 @@ struct ForwardedDataset: public Dataset { const std::shared_ptr rowName, ssize_t offset, ssize_t limit, + int unionIndex, Utf8String alias = "") const; virtual std::vector diff --git a/sql/builtin_dataset_functions.cc b/sql/builtin_dataset_functions.cc index e2b4a6ff0..b460e368e 100644 --- a/sql/builtin_dataset_functions.cc +++ b/sql/builtin_dataset_functions.cc @@ -122,8 +122,6 @@ BoundTableExpression merge(const SqlBindingScope & context, static RegisterBuiltin registerMerge(merge, "merge"); - - /*****************************************************************************/ /* SAMPLED DATASET */ /*****************************************************************************/ diff --git a/sql/path.cc b/sql/path.cc index f4efc94ad..a2d5a2edb 100644 --- a/sql/path.cc +++ b/sql/path.cc @@ -588,6 +588,16 @@ toIndex() const return val; } +size_t +PathElement:: +requireIndex() const +{ + ssize_t result = toIndex(); + if (result == -1) + throw HttpReturnException(400, "Path was not an index"); + return result; +} + bool PathElement:: hasStringView() const diff --git a/sql/sql_expression.cc b/sql/sql_expression.cc index bbc22fa83..c615c7e6f 100644 --- a/sql/sql_expression.cc +++ b/sql/sql_expression.cc @@ -3895,7 +3895,8 @@ SelectStatement() : having(SelectExpression::TRUE), rowName(SqlExpression::parse("rowPath()")), offset(0), - limit(-1) + limit(-1), + unionIndex(-1) { //TODO - avoid duplication of default values } @@ -3930,7 +3931,7 @@ parse(const char * body) } SelectStatement -SelectStatement::parse(ParseContext& context, bool acceptUtf8) +SelectStatement::parseImpl(ParseContext& context, bool acceptUtf8) { ParseContext::Hold_Token token(context); @@ -4029,6 +4030,24 @@ SelectStatement::parse(ParseContext& context, bool acceptUtf8) statement.surface = ML::trim(token.captured()); skip_whitespace(context); + return statement; +} + +SelectStatement +SelectStatement::parse(ParseContext& context, bool acceptUtf8) +{ + SelectStatement statement = parseImpl(context, acceptUtf8); + + if (matchKeyword(context, "UNION ")) { + SelectStatement * statementPtr = &statement; + statementPtr->unionIndex = 0; + do { + statementPtr->unionStm = + make_shared(parseImpl(context, acceptUtf8)); + statementPtr->unionStm->unionIndex = statementPtr->unionIndex + 1; + statementPtr = statementPtr->unionStm.get(); + } while (matchKeyword(context, "UNION ")); + } //cerr << jsonEncode(statement) << endl; return statement; @@ -4038,14 +4057,15 @@ Utf8String SelectStatement:: print() const { - return select.print() + + return select.print() + rowName->print() + from->print() + when.print() + where->print() + orderBy.print() + groupBy.print() + - having->print(); + having->print() + + (unionStm == nullptr ? "" : " UNION " + unionStm->print()); } UnboundEntities diff --git a/sql/sql_expression.h b/sql/sql_expression.h index a417adb4d..f973d57d7 100644 --- a/sql/sql_expression.h +++ b/sql/sql_expression.h @@ -1738,9 +1738,11 @@ struct SelectStatement TupleExpression groupBy; std::shared_ptr having; std::shared_ptr rowName; + std::shared_ptr unionStm; ssize_t offset; ssize_t limit; + int unionIndex; // Surface form of select statement (original string that was parsed) Utf8String surface; @@ -1749,7 +1751,7 @@ struct SelectStatement static SelectStatement parse(const char * body); static SelectStatement parse(const Utf8String& body); static SelectStatement parse(ParseContext& context, bool allowUtf8); - + static SelectStatement parseImpl(ParseContext& context, bool allowUtf8); UnboundEntities getUnbound() const; Utf8String print() const; diff --git a/testing/testing.mk b/testing/testing.mk index db5fc6ed9..8914f13e5 100644 --- a/testing/testing.mk +++ b/testing/testing.mk @@ -435,6 +435,7 @@ $(eval $(call mldb_unit_test,MLDB-1907-value-description-error.py)) $(eval $(call mldb_unit_test,test_classifier_test_proc.py)) $(eval $(call mldb_unit_test,MLDB-1937-svd-with-complex-select.py)) $(eval $(call mldb_unit_test,fetcher-function.py)) +$(eval $(call mldb_unit_test,union_dataset_test.py)) $(eval $(call mldb_unit_test,MLDB-1950-crash-in-merge.py)) $(eval $(call mldb_unit_test,MLDB-408-task-cancellation.py)) $(eval $(call mldb_unit_test,MLDB-1921_merge_ds_strings.py)) diff --git a/testing/union_dataset_test.py b/testing/union_dataset_test.py new file mode 100644 index 000000000..6d5b29288 --- /dev/null +++ b/testing/union_dataset_test.py @@ -0,0 +1,183 @@ +# +# union_dataset_test.py +# Francois-Michel L'Heureux, 2016-09-20 +# This file is part of MLDB. Copyright 2016 Datacratic. All rights reserved. +# + +mldb = mldb_wrapper.wrap(mldb) # noqa + +class UnionDatasetTest(MldbUnitTest): # noqa + + @classmethod + def setUpClass(cls): + ds = mldb.create_dataset({'id' : 'ds1', 'type' : 'sparse.mutable'}) + ds.record_row('row1', [['colA', 'A', 1]]) + ds.commit() + + ds = mldb.create_dataset({'id' : 'ds2', 'type' : 'sparse.mutable'}) + ds.record_row('row1', [['colB', 'B', 1]]) + ds.commit() + + ds = mldb.create_dataset({'id' : 'ds3', 'type' : 'sparse.mutable'}) + ds.record_row('row1', [['colA', 'AA', 1], ['colB', 'BB', 1]]) + ds.record_row('row2', [['colA', 'A', 1], ['colC', 'C', 1]]) + ds.commit() + + def test_dataset(self): + mldb.put('/v1/datasets/union_ds', { + 'type' : 'union', + 'params' : { + 'datasets' : [{'id' : 'ds1'}, {'id' : 'ds2'}] + } + }) + + res = mldb.query("SELECT colA, colB FROM union_ds ORDER BY rowName()") + self.assertTableResultEquals(res, [ + ['_rowName', 'colA', 'colB'], + ['0.row1', 'A', None], + ['1.row1', None, 'B'] + ]) + + res = mldb.query("SELECT * FROM union_ds ORDER BY rowName() LIMIT 1") + self.assertTableResultEquals(res, [ + ['_rowName', 'colA'], + ['0.row1', 'A'] + ]) + + res = mldb.query("SELECT * FROM union_ds ORDER BY rowName() OFFSET 1") + self.assertTableResultEquals(res, [ + ['_rowName', 'colB'], + ['1.row1', 'B'] + ]) + + mldb.put('/v1/datasets/union_ds2', { + 'type' : 'union', + 'params' : { + 'datasets' : [{'id' : 'ds3'}, {'id' : 'ds3'}] + } + }) + res = mldb.query( + "SELECT colA, colB, colC FROM union_ds2 ORDER BY rowName()") + self.assertTableResultEquals(res, [ + ['_rowName', 'colA', 'colB', 'colC'], + ['0.row1', 'AA', 'BB', None], + ['0.row2', 'A', None, 'C'], + ['1.row1', 'AA', 'BB', None], + ['1.row2', 'A', None, 'C'] + ]) + + def test_query_from_ds(self): + res = mldb.query(""" + SELECT colA FROM ds1 + UNION + SELECT colB FROM ds2 + """) + self.assertTableResultEquals(res, [ + ['_rowName', 'colA', 'colB'], + ['0.row1', 'A', None], + ['1.row1', None, 'B'] + ]) + + def test_query_w_and_wo_ds(self): + res = mldb.query(""" + SELECT 1 + UNION + SELECT colB FROM ds2 + """) + self.assertTableResultEquals(res, [ + ['_rowName', '1', 'colB'], + ['0.result', 1, None], + ['1.row1', None, 'B'] + ]) + + res = mldb.query(""" + SELECT colB FROM ds2 + UNION + SELECT 1 + """) + self.assertTableResultEquals(res, [ + ['_rowName', 'colB', '1'], + ['0.row1', 'B', None], + ['1.result', None, 1] + ]) + + def test_query_wo_ds(self): + res = mldb.query(""" + SELECT 1 + UNION + SELECT 2 + UNION + SELECT 1, 2 + UNION + SELECT 3 + """) + self.assertTableResultEquals(res, [ + ['_rowName', '1', '2', '3'], + ['0.result', 1, None, None], + ['1.result', None, 2, None], + ['2.result', 1, 2, None], + ['3.result', None, None, 3] + ]) + + def test_query_wo_ds_nested(self): + res = mldb.query(""" + SELECT * FROM ( + SELECT 1 + UNION + SELECT 2 + ) + """) + self.assertTableResultEquals(res, [ + ['_rowName', '1', '2'], + ['0.result', 1, None], + ['1.result', None, 2] + ]) + + def test_query_wo_ds_nested_over_union(self): + res = mldb.query(""" + SELECT * FROM ( + SELECT 1 + UNION + SELECT 2 + ) + UNION + SELECT 3 + """) + self.assertTableResultEquals(res, [ + ['_rowName', '1', '2', '3'], + ['0.0.result', 1, None, None], + ['0.1.result', None, 2, None], + ['1.result', None, None, 3] + ]) + + def test_query_over_union(self): + res = mldb.query(""" + SELECT * FROM ( + SELECT 1 + UNION + SELECT 2 + ) + WHERE rowName() = '0.result' + """) + self.assertTableResultEquals(res, [ + ['_rowName', '1'], + ['0.result', 1] + ]) + + res = mldb.query(""" + SELECT count({*}) AS * FROM ( + SELECT 1, 2, 3 + UNION + SELECT 2, 1 + UNION + SELECT 1 + ) + """) + self.assertTableResultEquals(res, [ + ['_rowName', '1', '2', '3'], + ['[]', 3, 2, 1] + ]) + + +if __name__ == '__main__': + mldb.run_tests()