Skip to content

Commit

Permalink
* Saving progress
Browse files Browse the repository at this point in the history
Signed-off-by: Pradnya Khalate <pkhalate@nvidia.com>
  • Loading branch information
khalatepradnya committed Feb 2, 2025
1 parent 2dd5ae8 commit c7f2849
Showing 1 changed file with 183 additions and 45 deletions.
228 changes: 183 additions & 45 deletions runtime/common/RecordParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,65 +19,203 @@ struct OutputRecord {
std::size_t size;
};

/// Record schema flavours; the parser selects one from a HEADER
/// "schema_name" entry whose value is "labeled" or "ordered".
enum class SchemaType {
  LABELED,
  ORDERED
};

/// Kind of a record line, taken from the first tab-separated field.
enum class RecordType {
  HEADER,
  METADATA,
  OUTPUT,
  START,
  END
};

/// Primitive value kinds an OUTPUT record can carry.
enum class OutputType {
  BOOL,
  INT,
  DOUBLE
};

/// Container kinds an OUTPUT record can announce.
enum class ContainerType {
  ARRAY,
  TUPLE
};

struct ContainerProcessor {
ContainerType container;
std::size_t count = 0;
std::size_t index = 0;
std::vector<OutputType> types;

void initialize(ContainerType c) {
container = c;
count = 0;
index = 0;
types = {};
}
};

struct RecordParser {
private:
/// Not read or written by any code visible in this file.
bool labelExpected = false;
/// Schema selected by a HEADER "schema_name" entry; ORDERED until one is seen.
SchemaType schema = SchemaType::ORDERED;
/// Kind of the record line currently being processed. Given a defined initial
/// value so it is never read while indeterminate.
RecordType currentRecord = RecordType::HEADER;
/// Primitive type of the most recent BOOL/INT/DOUBLE OUTPUT record. Given a
/// defined initial value because ARRAY/TUPLE records do not assign it before
/// the value handler switches on it.
OutputType currentOutput = OutputType::BOOL;
/// Set when an ARRAY/TUPLE record is seen; cleared once the container's
/// element index reaches its count.
bool isInContainer = false;
/// Progress through the container currently being filled.
ContainerProcessor containerIterator;

/// Append a single value to `results`.
///
/// The value is copied into a freshly `new`-allocated `T`; the resulting
/// record stores that pointer as `void *` together with `sizeof(T)`.
template <typename T>
void addPrimitiveRecord(T value) {
  T *storage = new T(value);
  results.emplace_back(OutputRecord{static_cast<void *>(storage), sizeof(T)});
}

/// Map a primitive type label to its `OutputType`.
///
/// Labels starting with 'i' are followed by a number parsed with `std::stoi`:
/// 1 yields BOOL, any other number yields INT. Labels starting with 'f' yield
/// DOUBLE.
///
/// @throws std::runtime_error if the label starts with neither 'i' nor 'f'.
/// Exceptions raised by `std::stoi` propagate unchanged.
OutputType extractPrimitiveType(const std::string &label) {
  const char kind = label[0];
  switch (kind) {
  case 'i': {
    const int bitWidth = std::stoi(label.substr(1));
    return (bitWidth == 1) ? OutputType::BOOL : OutputType::INT;
  }
  case 'f':
    return OutputType::DOUBLE;
  default:
    break;
  }
  throw std::runtime_error("Unknown datatype in label");
}

/// Parse string like "array<3 x i32>"
std::pair<std::size_t, OutputType>
extractArrayInfo(const std::string &label) {
auto isArray = label.find("array");
if (isArray == std::string::npos)
throw std::runtime_error("Array label missing keyword");
auto lessThan = label.find('<');
if (lessThan == std::string::npos)
throw std::runtime_error("Array label missing keyword");
auto greaterThan = label.find('>');
if (greaterThan == std::string::npos)
throw std::runtime_error("Array label missing keyword");
auto x = label.find('x');
if (x == std::string::npos)
throw std::runtime_error("Array label missing keyword");
std::size_t arrSize =
std::stoi(label.substr(lessThan + 1, x - lessThan - 2));
OutputType arrType =
extractPrimitiveType(label.substr(x + 2, greaterThan - x - 2));
return std::make_pair(arrSize, arrType);
}

void prcoessSingleRecord(const std::string &recValue,
const std::string &recLabel) {
if ((!recLabel.empty()) &&
(extractPrimitiveType(recLabel) != currentOutput))
throw std::runtime_error("Type mismatch in label");

switch (currentOutput) {
case OutputType::BOOL: {
bool value;
if ("true" == recValue)
value = true;
else if ("false" == recValue)
value = false;
else
throw std::runtime_error("Invalid boolean value");
addPrimitiveRecord(value);
break;
}
case OutputType::INT: {
if (isInContainer) {
if (0 == containerIterator.index) {
int *resArr = new int[containerIterator.count];
results.emplace_back(
OutputRecord{static_cast<void *>(resArr),
sizeof(int) * containerIterator.count});
}
static_cast<int *>(results.back().buffer)[containerIterator.index++] =
std::stoi(recValue);
if (containerIterator.index == containerIterator.count)
isInContainer = false;
} else
addPrimitiveRecord(std::stoi(recValue));
break;
}
case OutputType::DOUBLE: {
addPrimitiveRecord(std::stod(recValue));
break;
}
}
}

public:
std::vector<OutputRecord> results;

/// Parse record text into result entries: `data` is split on '\n' into lines
/// and each line on '\t' into fields, with the first field naming the record
/// kind.
///
/// NOTE(review): this region comes from a rendered diff whose +/- markers were
/// stripped, so removed and added lines appear interleaved and the text is not
/// compilable as shown. The notes below flag spans that look like leftovers of
/// the previous implementation; confirm each against the real diff.
std::vector<OutputRecord> parse(const std::string &data) {
// NOTE(review): this local would shadow the `results` member that the helper
// methods append to; together with arrSize/arrIdx below it is presumably a
// removed line.
std::vector<OutputRecord> results;
std::vector<std::string> lines = cudaq::split(data, '\n');
std::size_t arrSize = 0;
int arrIdx = -1;
if (lines.empty())
return {};

for (auto line : lines) {
std::vector<std::string> entries = cudaq::split(line, '\t');
// Lines that yield no fields are skipped.
if (entries.empty())
continue;
// NOTE(review): condition with no statement of its own; presumably the head
// of a removed filter on non-OUTPUT lines.
if (entries[0] != "OUTPUT")

// Classify the record by its first field; anything else is rejected.
if ("HEADER" == entries[0])
currentRecord = RecordType::HEADER;
else if ("METADATA" == entries[0])
currentRecord = RecordType::METADATA;
else if ("OUTPUT" == entries[0])
currentRecord = RecordType::OUTPUT;
else if ("START" == entries[0])
currentRecord = RecordType::START;
else if ("END" == entries[0])
currentRecord = RecordType::END;
else
throw std::runtime_error("Invalid data");

// NOTE(review): the BOOL/INT handling from here down to the `switch` uses the
// arrSize/arrIdx locals and looks like the previous implementation.
/// TODO: Handle labeled records
if ("BOOL" == entries[1]) {
bool value;
if ("true" == entries[2])
value = true;
else if ("false" == entries[2])
value = false;
results.emplace_back(
OutputRecord{static_cast<void *>(new bool(value)), sizeof(bool)});
} else if ("INT" == entries[1]) {
if (0 != arrSize) {
if (0 == arrIdx) {
int *resArr = new int[arrSize];
results.emplace_back(OutputRecord{static_cast<void *>(resArr),
sizeof(int) * arrSize});
}
static_cast<int *>(results.back().buffer)[arrIdx++] =
std::stoi(entries[2]);
if (arrIdx == arrSize) {
arrSize = 0;
arrIdx = -1;
}
// Dispatch on the record kind determined above.
switch (currentRecord) {
case RecordType::HEADER: {
if ("schema_name" == entries[1]) {
if ("labeled" == entries[2])
schema = SchemaType::LABELED;
else if ("ordered" == entries[2])
schema = SchemaType::ORDERED;
// NOTE(review): as written this throw is not guarded by an `else`, so it
// would run for every "schema_name" entry, including "labeled" and
// "ordered" - confirm whether an `else` is missing.
throw std::runtime_error("Unknown schema type");
}
/// TODO: Check schema version
break;
}
case RecordType::METADATA:
// ignore metadata for now
break;
case RecordType::START:
// indicates start of a shot
break;
case RecordType::END: {
// indicates end of a shot
// The second field carries the shot status; only "0" is accepted.
if (entries.size() < 2)
throw std::runtime_error("Missing shot status");
if ("0" != entries[1])
throw std::runtime_error("Cannot handle unsuccessful shot");
break;
}
case RecordType::OUTPUT: {
// Fields: [1] type name, [2] value, optional [3] label. Under the LABELED
// schema exactly four fields are required.
if (entries.size() < 3)
throw std::runtime_error("Insufficent data in a record");
if ((schema == SchemaType::LABELED) && (entries.size() != 4))
throw std::runtime_error(
"Unexpected record size for a labeled record");
std::string recType = entries[1];
std::string recValue = entries[2];
std::string recLabel = (entries.size() == 4) ? entries[3] : "";

// Primitive type names set the expected type of the value; ARRAY/TUPLE
// start a container whose element count is taken from the value field.
if ("BOOL" == entries[1])
currentOutput = OutputType::BOOL;
else if ("INT" == entries[1])
currentOutput = OutputType::INT;
else if ("DOUBLE" == entries[1])
currentOutput = OutputType::DOUBLE;
else if ("ARRAY" == entries[1]) {
isInContainer = true;
containerIterator.initialize(ContainerType::ARRAY);
containerIterator.count = std::stoi(recValue);
if (0 == containerIterator.count)
throw std::runtime_error("Got empty array");
} else if ("TUPLE" == entries[1]) {
isInContainer = true;
containerIterator.initialize(ContainerType::TUPLE);
containerIterator.count = std::stoi(recValue);
if (0 == containerIterator.count)
throw std::runtime_error("Got empty tuple");
} else
// NOTE(review): the lines from here down to the "Invalid data" throw mix
// what appear to be old INT/FLOAT/DOUBLE/ARRAY branches with the new
// fallback; presumably only the throw belongs to this `else`.
results.emplace_back(
OutputRecord{static_cast<void *>(new int(std::stoi(entries[2]))),
sizeof(int)});
} else if ("FLOAT" == entries[1]) {
results.emplace_back(
OutputRecord{static_cast<void *>(new int(std::stof(entries[2]))),
sizeof(float)});
} else if ("DOUBLE" == entries[1]) {
results.emplace_back(
OutputRecord{static_cast<void *>(new int(std::stod(entries[2]))),
sizeof(double)});
} else if ("ARRAY" == entries[1]) {
arrSize = std::stoi(entries[2]);
if (0 == arrSize)
throw std::runtime_error("Got empty array");
arrIdx = 0;
throw std::runtime_error("Invalid data");

// Hand the value and optional label to the per-record handler.
prcoessSingleRecord(recValue, recLabel);
break;
}
/// TODO: Handle more types
}
default:
throw std::runtime_error("Unknown record type");
}
} // for line
return results;
}
};
Expand Down

0 comments on commit c7f2849

Please sign in to comment.