From c7f284926382599dbb3494269d41d6ba2b972c06 Mon Sep 17 00:00:00 2001
From: Pradnya Khalate
Date: Sun, 2 Feb 2025 00:56:55 -0800
Subject: [PATCH] * Saving progress

Signed-off-by: Pradnya Khalate
---
 runtime/common/RecordParser.h | 290 ++++++++++++++++++++++++++++------
 1 file changed, 244 insertions(+), 46 deletions(-)

diff --git a/runtime/common/RecordParser.h b/runtime/common/RecordParser.h
index b784b5b6fb..8312c63ecc 100644
--- a/runtime/common/RecordParser.h
+++ b/runtime/common/RecordParser.h
@@ -19,65 +19,263 @@ struct OutputRecord {
   std::size_t size;
 };
 
+/// Output schema announced by a `HEADER schema_name` record.
+enum struct SchemaType { LABELED, ORDERED };
+/// Kind of a record, given by the first tab-separated field of a line.
+enum struct RecordType { HEADER, METADATA, OUTPUT, START, END };
+/// Primitive value types an `OUTPUT` record can carry.
+enum struct OutputType { BOOL, INT, DOUBLE };
+/// Container kinds an `OUTPUT` record can open.
+enum struct ContainerType { ARRAY, TUPLE };
+
+/// Book-keeping for the container whose elements are being read.
+struct ContainerProcessor {
+  /// Kind of the container being filled.
+  ContainerType container = ContainerType::ARRAY;
+  /// Number of elements announced by the container record.
+  std::size_t count = 0;
+  /// Position at which the next element is stored.
+  std::size_t index = 0;
+  /// Types of the elements stored so far.
+  std::vector<OutputType> types;
+
+  /// Restart the book-keeping for a new container of kind `c`.
+  void initialize(ContainerType c) {
+    container = c;
+    count = 0;
+    index = 0;
+    types = {};
+  }
+};
+
 struct RecordParser {
 private:
-  bool labelExpected = false;
+  /// Schema of the data; `ORDERED` until a header says otherwise.
+  SchemaType schema = SchemaType::ORDERED;
+  /// Kind of the record on the line being parsed.
+  RecordType currentRecord = RecordType::HEADER;
+  /// Value type of the `OUTPUT` record being parsed.
+  OutputType currentOutput = OutputType::INT;
+  /// True while the elements of an array or tuple are still expected.
+  bool isInContainer = false;
+  /// State of the container that is currently being filled.
+  ContainerProcessor containerIterator;
+
+  /// Append a heap-allocated copy of `value` as a stand-alone record.
+  template <typename T>
+  void addPrimitiveRecord(T value) {
+    results.emplace_back(
+        OutputRecord{static_cast<void *>(new T(value)), sizeof(T)});
+  }
+
+  /// Store `value` as the next element of the open container or, when
+  /// no container is open, as a stand-alone record. All elements of a
+  /// container share one buffer, so they must have the same type.
+  template <typename T>
+  void addRecord(T value) {
+    if (!isInContainer) {
+      addPrimitiveRecord<T>(value);
+      return;
+    }
+    if (0 == containerIterator.index) {
+      // The first element decides the element type of the buffer.
+      T *elements = new T[containerIterator.count];
+      results.emplace_back(
+          OutputRecord{static_cast<void *>(elements),
+                       sizeof(T) * containerIterator.count});
+    } else if (containerIterator.types.front() != currentOutput)
+      throw std::runtime_error("Type mismatch in container");
+    containerIterator.types.push_back(currentOutput);
+    static_cast<T *>(results.back().buffer)[containerIterator.index++] =
+        value;
+    if (containerIterator.index == containerIterator.count)
+      isInContainer = false;
+  }
+
+  /// Open a container of kind `kind`; `recValue` holds its element count.
+  void beginContainer(ContainerType kind, const std::string &recValue) {
+    const int count = std::stoi(recValue);
+    // A negative count would turn into a huge allocation size.
+    if (count <= 0)
+      throw std::runtime_error(ContainerType::ARRAY == kind
+                                   ? "Got empty array"
+                                   : "Got empty tuple");
+    containerIterator.initialize(kind);
+    containerIterator.count = static_cast<std::size_t>(count);
+    isInContainer = true;
+  }
+
+  /// Map a type label such as "i1", "i32" or "f64" to the value type.
+  /// Throws std::runtime_error if the label starts with neither `i` nor `f`.
+  OutputType extractPrimitiveType(const std::string &label) {
+    if ('i' == label[0]) {
+      auto digits = std::stoi(label.substr(1));
+      if (1 == digits)
+        return OutputType::BOOL;
+      return OutputType::INT;
+    } else if ('f' == label[0])
+      return OutputType::DOUBLE;
+    throw std::runtime_error("Unknown datatype in label");
+  }
+
+  /// Parse string like "array<3 x i32>"
+  /// Returns the element count and the element type.
+  std::pair<std::size_t, OutputType>
+  extractArrayInfo(const std::string &label) {
+    const auto keyword = label.find("array");
+    if (keyword == std::string::npos)
+      throw std::runtime_error("Array label missing keyword");
+    // Look for each delimiter after the previous one, so that the
+    // offsets computed below cannot refer to an earlier part of the label.
+    const auto lessThan = label.find('<', keyword);
+    if (lessThan == std::string::npos)
+      throw std::runtime_error("Array label missing '<'");
+    const auto x = label.find('x', lessThan);
+    if (x == std::string::npos)
+      throw std::runtime_error("Array label missing 'x'");
+    const auto greaterThan = label.find('>', x);
+    if (greaterThan == std::string::npos)
+      throw std::runtime_error("Array label missing '>'");
+    const int arrSize =
+        std::stoi(label.substr(lessThan + 1, x - lessThan - 2));
+    if (arrSize < 0)
+      throw std::runtime_error("Negative array size in label");
+    OutputType arrType =
+        extractPrimitiveType(label.substr(x + 2, greaterThan - x - 2));
+    return std::make_pair(static_cast<std::size_t>(arrSize), arrType);
+  }
+
+  /// Convert the value of a primitive `OUTPUT` record of type
+  /// `currentOutput` and store it. A non-empty `recLabel` must name the
+  /// same type.
+  void processSingleRecord(const std::string &recValue,
+                           const std::string &recLabel) {
+    if ((!recLabel.empty()) &&
+        (extractPrimitiveType(recLabel) != currentOutput))
+      throw std::runtime_error("Type mismatch in label");
+
+    switch (currentOutput) {
+    case OutputType::BOOL: {
+      bool value;
+      if ("true" == recValue)
+        value = true;
+      else if ("false" == recValue)
+        value = false;
+      else
+        throw std::runtime_error("Invalid boolean value");
+      addRecord<bool>(value);
+      break;
+    }
+    case OutputType::INT:
+      addRecord<int>(std::stoi(recValue));
+      break;
+    case OutputType::DOUBLE:
+      addRecord<double>(std::stod(recValue));
+      break;
+    }
+  }
 
 public:
+  /// Records produced by the most recent call to `parse`.
+  std::vector<OutputRecord> results;
+
+  /// Parse newline-separated records whose fields are separated by tabs.
+  /// Each primitive `OUTPUT` value and each container becomes one
+  /// heap-allocated `OutputRecord`; the buffers are not freed here.
+  /// Throws std::runtime_error on malformed or unsupported input.
   std::vector<OutputRecord> parse(const std::string &data) {
-    std::vector<OutputRecord> results;
+    // `results` used to be a local variable; reset the state so that a
+    // reused parser still returns only the records of this call.
+    results.clear();
+    isInContainer = false;
     std::vector<std::string> lines = cudaq::split(data, '\n');
-    std::size_t arrSize = 0;
-    int arrIdx = -1;
+    if (lines.empty())
+      return {};
+
     for (auto line : lines) {
       std::vector<std::string> entries = cudaq::split(line, '\t');
       if (entries.empty())
         continue;
-      if (entries[0] != "OUTPUT")
+
+      if ("HEADER" == entries[0])
+        currentRecord = RecordType::HEADER;
+      else if ("METADATA" == entries[0])
+        currentRecord = RecordType::METADATA;
+      else if ("OUTPUT" == entries[0])
+        currentRecord = RecordType::OUTPUT;
+      else if ("START" == entries[0])
+        currentRecord = RecordType::START;
+      else if ("END" == entries[0])
+        currentRecord = RecordType::END;
+      else
         throw std::runtime_error("Invalid data");
 
-      /// TODO: Handle labeled records
-      if ("BOOL" == entries[1]) {
-        bool value;
-        if ("true" == entries[2])
-          value = true;
-        else if ("false" == entries[2])
-          value = false;
-        results.emplace_back(
-            OutputRecord{static_cast<void *>(new bool(value)), sizeof(bool)});
-      } else if ("INT" == entries[1]) {
-        if (0 != arrSize) {
-          if (0 == arrIdx) {
-            int *resArr = new int[arrSize];
-            results.emplace_back(OutputRecord{static_cast<void *>(resArr),
-                                              sizeof(int) * arrSize});
-          }
-          static_cast<int *>(results.back().buffer)[arrIdx++] =
-              std::stoi(entries[2]);
-          if (arrIdx == arrSize) {
-            arrSize = 0;
-            arrIdx = -1;
-          }
-        } else
-          results.emplace_back(
-              OutputRecord{static_cast<void *>(new int(std::stoi(entries[2]))),
-                           sizeof(int)});
-      } else if ("FLOAT" == entries[1]) {
-        results.emplace_back(
-            OutputRecord{static_cast<void *>(new int(std::stof(entries[2]))),
-                         sizeof(float)});
-      } else if ("DOUBLE" == entries[1]) {
-        results.emplace_back(
-            OutputRecord{static_cast<void *>(new int(std::stod(entries[2]))),
-                         sizeof(double)});
-      } else if ("ARRAY" == entries[1]) {
-        arrSize = std::stoi(entries[2]);
-        if (0 == arrSize)
-          throw std::runtime_error("Got empty array");
-        arrIdx = 0;
+      switch (currentRecord) {
+      case RecordType::HEADER: {
+        if (entries.size() < 3)
+          throw std::runtime_error("Insufficient data in a record");
+        if ("schema_name" == entries[1]) {
+          if ("labeled" == entries[2])
+            schema = SchemaType::LABELED;
+          else if ("ordered" == entries[2])
+            schema = SchemaType::ORDERED;
+          else
+            throw std::runtime_error("Unknown schema type");
+        }
+        /// TODO: Check schema version
+        break;
+      }
+      case RecordType::METADATA:
+        // ignore metadata for now
+        break;
+      case RecordType::START:
+        // indicates start of a shot
+        break;
+      case RecordType::END: {
+        // indicates end of a shot
+        if (entries.size() < 2)
+          throw std::runtime_error("Missing shot status");
+        if ("0" != entries[1])
+          throw std::runtime_error("Cannot handle unsuccessful shot");
+        break;
+      }
+      case RecordType::OUTPUT: {
+        if (entries.size() < 3)
+          throw std::runtime_error("Insufficient data in a record");
+        if ((schema == SchemaType::LABELED) && (entries.size() != 4))
+          throw std::runtime_error(
+              "Unexpected record size for a labeled record");
+        const std::string &recType = entries[1];
+        const std::string &recValue = entries[2];
+        const std::string recLabel = (entries.size() == 4) ? entries[3] : "";
+
+        // A container record only announces how many elements follow;
+        // it carries no value of its own.
+        if ("ARRAY" == recType) {
+          beginContainer(ContainerType::ARRAY, recValue);
+          break;
+        }
+        if ("TUPLE" == recType) {
+          beginContainer(ContainerType::TUPLE, recValue);
+          break;
+        }
+
+        if ("BOOL" == recType)
+          currentOutput = OutputType::BOOL;
+        else if ("INT" == recType)
+          currentOutput = OutputType::INT;
+        else if ("DOUBLE" == recType)
+          currentOutput = OutputType::DOUBLE;
+        else
+          throw std::runtime_error("Invalid data");
+
+        processSingleRecord(recValue, recLabel);
+        break;
       }
-      /// TODO: Handle more types
-    }
+      default:
+        throw std::runtime_error("Unknown record type");
+      }
+    } // for line
     return results;
   }
 };