Skip to content

Commit

Permalink
Merge pull request #3 from GranLte/lukezhuz/MBBDataset_import_script
Browse files Browse the repository at this point in the history
Lukezhuz/mbb dataset import script
  • Loading branch information
9Tempest authored Nov 27, 2023
2 parents 95840e5 + 8062d15 commit 8d644f6
Show file tree
Hide file tree
Showing 16 changed files with 139,081 additions and 33 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
/requirements.txt

/compile_commands.json

.vscode
5 changes: 2 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ FROM ubuntu:22.04
RUN apt-get update && apt-get install -y clang python3 python3-pip git curl
ARG bazelisk_version=1.17.0
RUN curl -L https://github.com/bazelbuild/bazelisk/releases/download/v${bazelisk_version}/bazelisk-linux-amd64 > /usr/bin/bazelisk && chmod +x /usr/bin/bazelisk && ln -s /usr/bin/bazelisk /usr/bin/bazel
WORKDIR /gematria
WORKDIR /granlte
COPY . .
RUN pip3 install -r requirements.in

RUN pip3 install -r requirements.in
2 changes: 2 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ http_archive(
sha256 = LLVM_SHA256,
strip_prefix = "llvm-project-" + LLVM_COMMIT,
urls = ["https://github.com/llvm/llvm-project/archive/{commit}.zip".format(commit = LLVM_COMMIT)],
patches = ["//:mir_parser.patch"], # Hack to make the MIR parser work
patch_args = ["-p1"],
)

load(
Expand Down
3 changes: 3 additions & 0 deletions gematria/basic_block/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ struct Instruction {
uint64_t address = 0;
// The size of the instruction.
size_t size = 0;

// Whether the instruction is valid.
bool is_valid = true;
};

std::ostream& operator<<(std::ostream& os, const Instruction& instruction);
Expand Down
135 changes: 108 additions & 27 deletions gematria/datasets/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,18 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/Error.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG

// Debug logging helper. When DEBUG is defined, LOG(X) streams X followed by a
// newline to llvm::errs(); otherwise it does nothing.
//
// X is deliberately NOT parenthesized: call sites pass a `<<` chain such as
// LOG("MBB is " << *MBB), which has to be spliced directly into the stream
// expression. Both variants are wrapped in do { } while (0) so that `LOG(...);`
// is always exactly one statement, e.g. as the body of an unbraced if/else.
#ifdef DEBUG
#define LOG(X)                 \
  do {                         \
    llvm::errs() << X << "\n"; \
  } while (0)
#else
#define LOG(X) \
  do {         \
  } while (0)
#endif

namespace gematria {
namespace {
Expand All @@ -66,7 +73,8 @@ BHiveImporter::BHiveImporter(const Canonicalizer* canonicalizer)
mc_inst_printer_(target_machine_.getTarget().createMCInstPrinter(
target_machine_.getTargetTriple(), kDefaultSyntax,
*target_machine_.getMCAsmInfo(), *target_machine_.getMCInstrInfo(),
*target_machine_.getMCRegisterInfo())) {}
*target_machine_.getMCRegisterInfo())),
MMI_(dynamic_cast<const llvm::LLVMTargetMachine*>(&target_machine_)) {}

absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMachineCode(
llvm::ArrayRef<uint8_t> machine_code, uint64_t base_address /*= 0*/) {
Expand Down Expand Up @@ -149,50 +157,127 @@ absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseBHiveCsvLine(
return proto;
}

// Builds a BasicBlockProto from the machine basic block registered under
// `MBB_name` in `name_to_mbb_`.
//
// Inline assembly, terminators and EH labels are left out of the resulting
// proto. The whole block is rejected with an InvalidArgumentError when
//   * no block with the given name is registered,
//   * the block contains a call instruction, or
//   * the canonicalizer reports an instruction as not valid.
//
// `base_address` is currently not used by this function.
absl::StatusOr<BasicBlockProto> BHiveImporter::BasicBlockProtoFromMBBName(
    std::string_view MBB_name, uint64_t base_address /*= 0*/) {
  const llvm::StringRef MBB_name_ref(MBB_name.data(), MBB_name.size());

  // One lookup only. find() never inserts, whereas operator[] would add a
  // null entry if it were ever reached with a missing key.
  const auto mbb_it = name_to_mbb_.find(MBB_name_ref);
  if (mbb_it == name_to_mbb_.end()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Could not find MBB with name ", MBB_name));
  }

  llvm::MachineBasicBlock& MBB = *mbb_it->second;
  LOG("MBB is " << MBB);

  BasicBlockProto basic_block_proto;
  for (llvm::MachineInstr& MI : MBB) {
    // Inline asm, terminators (ret/branch/jmp) and EH labels are not added to
    // the proto.
    if (MI.isInlineAsm() || MI.isTerminator() || MI.isEHLabel()) {
      continue;
    }

    // Blocks containing a call are rejected as a whole.
    if (MI.isCall()) {
      LOG("MI is a CALL instruction, abort this BB " << MI);
      return absl::InvalidArgumentError("Cannot handle CALL instruction ");
    }

    const auto instruction = canonicalizer_.InstructionFromMachineInstr(MI);
    if (!instruction.is_valid) {
      // The block is abandoned here (not just this instruction), so say so.
      LOG("MI is not valid, abort this BB " << MI);
      return absl::InvalidArgumentError("Could not parse MachineInstr ");
    }
    *basic_block_proto.add_canonicalized_instructions() =
        ProtoFromInstruction(instruction);
  }
  return basic_block_proto;
}

// Parses one CSV line that refers to a machine basic block by name.
//
// `line` is split on ','. The column at `BB_name_index` is the name of the
// block, which is resolved through BasicBlockProtoFromMBBName(); the column at
// `throughput_column_index` is the inverse throughput in text form. The parsed
// throughput is multiplied by `throughput_scaling` and recorded under
// `source_name`. `base_address` is forwarded to BasicBlockProtoFromMBBName().
//
// Returns InvalidArgumentError when the line has too few columns, when both
// column indices are equal, or when the throughput cannot be parsed; errors
// from BasicBlockProtoFromMBBName() are returned unchanged.
absl::StatusOr<BasicBlockWithThroughputProto> BHiveImporter::ParseMIRCsvLine(
    std::string_view source_name, std::string_view line, size_t BB_name_index,
    size_t throughput_column_index, double throughput_scaling /*= 1.0*/,
    uint64_t base_address /*= 0*/) {
  const absl::InlinedVector<std::string_view, 2> columns =
      absl::StrSplit(line, ',');

  // Keep the arithmetic in size_t: narrowing to int would mix signed and
  // unsigned in the comparison below, and comparing against the largest index
  // (rather than index + 1) cannot wrap around.
  const size_t max_column_index =
      std::max(BB_name_index, throughput_column_index);
  if (columns.size() <= max_column_index) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Expected `line` to have at least %d columns, found %d: %s",
        max_column_index + 1, columns.size(), line));
  }
  if (BB_name_index == throughput_column_index) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Expected BB name column and throughput column indices to be "
        "different, but were both %d: %s",
        BB_name_index, line));
  }
  const std::string_view BB_unique_name = columns[BB_name_index];
  const std::string_view throughput_str = columns[throughput_column_index];

  BasicBlockWithThroughputProto proto;

  absl::StatusOr<BasicBlockProto> block_proto_or_status =
      BasicBlockProtoFromMBBName(BB_unique_name, base_address);
  if (!block_proto_or_status.ok()) return block_proto_or_status.status();
  *proto.mutable_basic_block() = std::move(block_proto_or_status).value();

  double throughput_cycles = 0.0;
  if (!absl::SimpleAtod(throughput_str, &throughput_cycles)) {
    return absl::InvalidArgumentError(
        absl::StrCat("Could not parse throughput value ", throughput_str));
  }

  ThroughputWithSourceProto& throughput = *proto.add_inverse_throughputs();
  throughput.set_source(source_name);
  throughput.add_inverse_throughput_cycles(throughput_cycles *
                                           throughput_scaling);
  LOG(proto.DebugString());

  return proto;
}

absl::StatusOr<bool> BHiveImporter::LoadMIRModule(std::string_view file_name){
// clear previous loaded module
name_to_mbb_.clear();
if (mir_module_){
for (llvm::Function &F : mir_module_->functions()) {
MMI_.deleteMachineFunctionFor(F);
}
}

// create MIR Parser and read all MBB to the map based on their unique name
llvm::LLVMContext context;
llvm::SMDiagnostic diag;

// Set attributes on functions as loaded from MIR from command line arguments.
// auto setMIRFunctionAttributes = [&CPUStr, &FeaturesStr](Function &F) {
// llvm::codegen::setFunctionAttributes(CPUStr, FeaturesStr, F);
// };

std::unique_ptr<llvm::MIRParser> mir_parser = llvm::createMIRParserFromFile(file_name, diag, context);
if (!mir_parser) {
diag.print("test ", llvm::WithColor::error(llvm::errs(), "test"));
mir_parser_ = llvm::createMIRParserFromFile(file_name, diag, llvm_context_);
if (!mir_parser_) {
return absl::InvalidArgumentError(
absl::StrCat("Could not create MIR parser for file ", file_name));
}

// Parse the LLVM IR module (if any)
std::unique_ptr<llvm::Module> mir_module = mir_parser->parseIRModule();
if (!mir_module) {
mir_module_ = mir_parser_->parseIRModule();
if (!mir_module_) {
// Handle error
return absl::InvalidArgumentError(
absl::StrCat("Could not parse MIR module for file ", file_name));
}

// Prepare MachineModuleInfo
auto *llvmTargetMachine = dynamic_cast<const llvm::LLVMTargetMachine*>(&target_machine_);
if (llvmTargetMachine == nullptr) {
return absl::InvalidArgumentError(
absl::StrCat("Could not cast target machine for file ", file_name));
}
llvm::MachineModuleInfo MMI(llvmTargetMachine);
MMI_.initialize();

// Parse the MachineFunctions and add them to MMI
if (mir_parser->parseMachineFunctions(*mir_module, MMI)) {
if (mir_parser_->parseMachineFunctions(*mir_module_, MMI_)) {
// Handle error
return absl::InvalidArgumentError(
absl::StrCat("Could not parse MachineFunctions for file ", file_name));
}

// Now iterate over the MachineFunctions and their MachineBasicBlocks
for (auto &F : *mir_module) {
for (auto &F : *mir_module_) {
if (F.isDeclaration()) continue;
llvm::MachineFunction &MF = MMI.getOrCreateMachineFunction(F);
llvm::MachineFunction &MF = MMI_.getOrCreateMachineFunction(F);
for (auto &MBB : MF) {
// assert name is unique
if (name_to_mbb_.find(MBB.getName()) != name_to_mbb_.end()) {
Expand All @@ -201,10 +286,6 @@ absl::StatusOr<bool> BHiveImporter::LoadMIRModule(std::string_view file_name){
} else {
name_to_mbb_[MBB.getName()] = &MBB;
}
// // Pretty print the machine block with its name
// llvm::outs() << "MachineBasicBlock: " << MBB.getName() << "\n";
// MBB.print(llvm::outs());
// llvm::outs() << "\n";
}
}

Expand Down
22 changes: 22 additions & 0 deletions gematria/datasets/bhive_importer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/CodeGen/MIRParser/MIRParser.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/MachineModuleInfo.h"

namespace gematria {

Expand Down Expand Up @@ -63,6 +64,9 @@ class BHiveImporter {
// corresponds to a three-byte sequence {0xAA, 0xBB, 0x11}.
absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMachineCodeHex(
std::string_view machine_code_hex, uint64_t base_address = 0);

// Creates a basic block proto from the machine basic block that was registered
// under the name `MBB_name` by LoadMIRModule(). Inline assembly, terminators
// and EH labels of the block are not added to the proto. Returns an
// InvalidArgumentError when no block with that name is known, when the block
// contains a call instruction, or when one of its instructions is reported as
// not valid by the canonicalizer.
// NOTE(review): `base_address` does not appear to be used by the
// implementation - confirm before relying on it.
absl::StatusOr<BasicBlockProto> BasicBlockProtoFromMBBName(
std::string_view MBB_name, uint64_t base_address = 0);

// Parses a basic block with throughput from one BHive CSV line. Expects that
// the line has the format "{machine_code},{throughput}" where {machine_code}
Expand All @@ -80,13 +84,31 @@ class BHiveImporter {
absl::StatusOr<bool> LoadMIRModule(
std::string_view file_name
);

// Parses a MIR basic block with throughput from one CSV line. The line is
// split on ','; the column at `BB_name_index` holds the name of a machine
// basic block, which is resolved with BasicBlockProtoFromMBBName(), and the
// column at `throughput_column_index` holds the inverse throughput of the
// basic block in text format. The two column indices must be different.
// Optionally applies `throughput_scaling` to the throughput value;
// `base_address` is forwarded to BasicBlockProtoFromMBBName(). The throughput
// is recorded under `source_name`.
// Returns an InvalidArgumentError when the line has too few columns, when the
// two indices are equal, when the basic block cannot be created, or when the
// throughput value cannot be parsed.
// NOTE: LoadMIRModule() must be called before calling this function.
absl::StatusOr<BasicBlockWithThroughputProto> ParseMIRCsvLine(
std::string_view source_name, std::string_view line,
size_t BB_name_index, size_t throughput_column_index,
double throughput_scaling = 1.0, uint64_t base_address = 0);

private:
const Canonicalizer& canonicalizer_;
const llvm::TargetMachine& target_machine_;
std::unique_ptr<llvm::MCContext> context_;
std::unique_ptr<llvm::MCDisassembler> disassembler_;
std::unique_ptr<llvm::MCInstPrinter> mc_inst_printer_;
// Maps the name of each machine basic block found by LoadMIRModule() to the
// block itself; cleared at the start of every LoadMIRModule() call.
// NOTE(review): the pointers (and presumably the StringRef keys) refer to data
// owned by the loaded machine functions - confirm they stay valid for as long
// as the map is used.
llvm::DenseMap<llvm::StringRef, llvm::MachineBasicBlock*> name_to_mbb_;
// LLVM context handed to the MIR parser in LoadMIRModule().
llvm::LLVMContext llvm_context_;
// IR module returned by the MIR parser for the most recently loaded file.
std::unique_ptr<llvm::Module> mir_module_;
// Receives the machine functions parsed from the loaded MIR file.
llvm::MachineModuleInfo MMI_;
// MIR parser created for the most recently loaded file.
std::unique_ptr<llvm::MIRParser> mir_parser_;
};

} // namespace gematria
Expand Down
15 changes: 13 additions & 2 deletions gematria/datasets/bhive_importer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,19 @@ TEST_F(BHiveImporterTest, NonStandardColumns) {
})pb")));
}

TEST_F(BHiveImporterTest, LoadMIRModule) {
EXPECT_THAT(x86_bhive_importer_->LoadMIRModule("/u9/z277zhu/research/gematria/sample_dataset/data.mir"),
// Loads sample_dataset/data.mir and checks that a CSV line whose column 2
// names the block BB_13 and whose column 3 holds the throughput is accepted.
TEST_F(BHiveImporterTest, MIRDatasetBasicTest) {
  const auto load_result =
      x86_bhive_importer_->LoadMIRModule("sample_dataset/data.mir");
  EXPECT_THAT(load_result, IsOk());

  const auto parsed_line = x86_bhive_importer_->ParseMIRCsvLine(
      kSourceName, "a,b,BB_13,2.37", /*BB_name_index=*/2,
      /*throughput_column_index=*/3, kScaling);
  EXPECT_THAT(parsed_line, IsOk());
}

// Loads sample_dataset/native_test.mir and checks that a CSV line whose
// column 2 names the block BB_299 and whose column 3 holds the throughput is
// accepted.
TEST_F(BHiveImporterTest, MIRDatasetTest2) {
  const auto load_result =
      x86_bhive_importer_->LoadMIRModule("sample_dataset/native_test.mir");
  EXPECT_THAT(load_result, IsOk());

  const auto parsed_line = x86_bhive_importer_->ParseMIRCsvLine(
      kSourceName, "a,b,BB_299,2.37", /*BB_name_index=*/2,
      /*throughput_column_index=*/3, kScaling);
  EXPECT_THAT(parsed_line, IsOk());
}

Expand Down
11 changes: 11 additions & 0 deletions gematria/datasets/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,14 @@ gematria_py_binary(
"//gematria/utils/python:pybind11_abseil_status",
],
)

# Python binary built from import_from_mir.py. It depends on the
# :bhive_importer target of this package and on the canonicalizer, LLVM
# architecture support and pybind11 abseil-status targets listed below.
gematria_py_binary(
name = "import_from_mir",
srcs = ["import_from_mir.py"],
deps = [
":bhive_importer",
"//gematria/llvm/python:canonicalizer",
"//gematria/llvm/python:llvm_architecture_support",
"//gematria/utils/python:pybind11_abseil_status",
],
)
20 changes: 19 additions & 1 deletion gematria/datasets/python/bhive_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,25 @@ PYBIND11_MODULE(bhive_importer, m) {
Raises:
StatusNotOk: When parsing the CSV line or extracting data from the
machine code fails.)");
machine code fails.)")
.def( //
"LoadMIRModule",
&BHiveImporter::LoadMIRModule, py::arg("file_name"),
R"(Load a mir module given a mir file
)")
.def(
"ParseMIRCsvLine",
&BHiveImporter::ParseMIRCsvLine,
py::arg("source_name"), py::arg("line"),py::arg("BB_name_index"), py::arg("throughput_column_index"),
py::arg("throughput_scaling") = 1.0, py::arg("base_address") = uint64_t{0},
R"(Creates a BasicBlockWithThroughputProto from a MIR CSV line.)"
)
.def(
"BasicBlockProtoFromMBBName",
&BHiveImporter::BasicBlockProtoFromMBBName,
py::arg("MBB_name"), py::arg("base_address") = uint64_t{0},
R"(Creates a BasicBlockProto from a MIR CSV line.)"
);
}

} // namespace gematria
Loading

0 comments on commit 8d644f6

Please sign in to comment.