Skip to content

Commit

Permalink
Add function to Toy DSL examples (buddy-compiler#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
linuxlonelyeagle authored Aug 1, 2022
1 parent 14170de commit 2f3677c
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 35 deletions.
23 changes: 19 additions & 4 deletions examples/ToyDSL/Toy.g4
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ module

expression
: Number
| tensorLiteral
| tensorLiteral
{
tensorDataBuffer.clear();
}
| identifierExpr
| expression Mul expression
| expression Add expression
;

identifierExpr
Expand Down Expand Up @@ -62,14 +67,16 @@ funDefine
: prototype block
;

prototype
: Def Identifier ParentheseOpen declList ParentheseClose
prototype returns [std::string idName]
: Def Identifier ParentheseOpen declList? ParentheseClose
{
$idName = $Identifier.text;
}
;

declList
: Identifier
| Identifier Comma declList
|
;

block
Expand Down Expand Up @@ -144,6 +151,14 @@ Comma
: ','
;

Add
: '+'
;

Mul
: '*'
;

WS
: [ \r\n\t] -> skip
;
Expand Down
15 changes: 15 additions & 0 deletions examples/ToyDSL/function.toy
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
def fun(a) {
print(a + a);
print(a * a);
print(transpose(a));
return a * transpose(a);
}

def main() {
var a = [1, 2, 3, 4];
var b = fun(a);
print(b);
var c<2,2> = [1, 2, 3, 4];
var d = fun(c);
print(d);
}
181 changes: 150 additions & 31 deletions examples/ToyDSL/include/MLIRToyVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,9 @@ class MLIRToyVisitor : public ToyBaseVisitor {
/// The builder helps create MLIR operations when traversing the AST.
mlir::OpBuilder builder;
/// The Symbol Table
/// [TODO][LOW] make the symbol table support function prototype.
llvm::ScopedHashTable<llvm::StringRef, mlir::Value> symbolTable;
/// Return Status Flag
/// The syntax supports omitting the return expression.
bool returnFlag = false;
llvm::ScopedHashTable<llvm::StringRef, int> funSymbolTable;
llvm::StringMap<mlir::toy::FuncOp> functionMap;
// Register the filename for the string attribute in MLIR location object.
std::string fileName;

Expand All @@ -71,6 +69,15 @@ class MLIRToyVisitor : public ToyBaseVisitor {
symbolTable.insert(var, value);
return mlir::success();
}
// Declear a function in the current module
/// - Check the parameter number of the function.
mlir::LogicalResult funcDeclare(llvm::StringRef functionName,
int argsNumber) {
if (funSymbolTable.count(functionName))
return mlir::failure();
funSymbolTable.insert(functionName, argsNumber);
return mlir::success();
}

/// Location
/// Get the MLIR location object with the current line and row of the toy
Expand All @@ -89,24 +96,54 @@ class MLIRToyVisitor : public ToyBaseVisitor {

// Get the tensor value from the tensor literal node.
std::any getTensor(ToyParser::TensorLiteralContext *ctx) {
// [TODO][HIGH] find a better way to define the `dims`.
std::vector<int64_t> dims;
// get dimensions.
dims.push_back(ctx->Comma().size() + 1);
if (ctx->tensorLiteral(0)->tensorLiteral(0)) {
dims.push_back(ctx->tensorLiteral(0)->Comma().size() + 1);
ToyParser::TensorLiteralContext *list = ctx->tensorLiteral(0);
while (list) {
dims.push_back(list->Comma().size() + 1);
if (list->tensorLiteral(0) && list->tensorLiteral(0)->Comma().size())
list = list->tensorLiteral(0);
else
break;
}
}
mlir::Type elementType = builder.getF64Type();
auto type = getType(dims);
mlir::Type type = getType(dims);
auto dataType = mlir::RankedTensorType::get(dims, elementType);
auto dataAttribute =
mlir::DenseElementsAttr dataAttribute =
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(ctx->data));
auto loaction =
mlir::Location loaction =
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
mlir::Value value =
builder.create<mlir::toy::ConstantOp>(loaction, type, dataAttribute);
return value;
}
// Module Visitor
// - Visitor all function asts to get the number of function parameter.
// - Visitor childrens.
virtual std::any visitModule(ToyParser::ModuleContext *ctx) override {
llvm::ScopedHashTableScope<llvm::StringRef, int> protoTypeSymbolTable(
funSymbolTable);
for (auto &function : ctx->funDefine()) {
ToyParser::PrototypeContext *protoType = function->prototype();
std::string functionName = protoType->Identifier()->toString();
int declNumber = 0;
if (protoType->declList()) {
ToyParser::DeclListContext *list = protoType->declList();
while (list) {
declNumber++;
if (list->declList())
list = list->declList();
else
break;
}
}
funcDeclare(function->prototype()->idName, declNumber);
}
return visitChildren(ctx);
}

/// Function Definition Visitor
/// - Register the function name, argument list, and return value into the
Expand All @@ -115,35 +152,70 @@ class MLIRToyVisitor : public ToyBaseVisitor {
/// - Visit fucntion block.
/// - Process the return operation.
virtual std::any visitFunDefine(ToyParser::FunDefineContext *ctx) override {
returnFlag = false;
// [TODO] make the function support argument list and return value.
llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> varScope(
symbolTable);
builder.setInsertionPointToEnd(theModule.getBody());
// Visit function prototype.
visit(ctx->prototype());
mlir::toy::FuncOp function =
std::any_cast<mlir::toy::FuncOp>(visit(ctx->prototype()));
mlir::Block &entryBlock = function.front();

// Set the insertion point in the builder to the beginning of the function
// body, it will be used throughout the codegen to create operations in this
// function.
builder.setInsertionPointToStart(&entryBlock);

std::vector<std::string> args;
if (ctx->prototype()->declList()) {
ToyParser::DeclListContext *list = ctx->prototype()->declList();
while (list->Identifier()) {
args.push_back(list->Identifier()->toString());
if (list->declList())
list = list->declList();
else
break;
}
}
// Declare all the function arguments in the symbol table.
llvm::ArrayRef<std::string> protoArgs = args;
for (auto value : llvm::zip(protoArgs, entryBlock.getArguments())) {
declare(std::get<0>(value), std::get<1>(value));
}

// Visit fucntion block.
visit(ctx->block());
// Check the return status.
// If there is no return expression at the end of the function, it will
// generate a return operation automatically.
if (!returnFlag) {
auto location =
mlir::toy::ReturnOp returnOp;
if (!entryBlock.empty())
returnOp = llvm::dyn_cast<mlir::toy::ReturnOp>(entryBlock.back());
if (!returnOp) {
mlir::Location location =
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
builder.create<mlir::toy::ReturnOp>(location,
llvm::ArrayRef<mlir::Value>());
builder.create<mlir::toy::ReturnOp>(location);
} else if (returnOp.hasOperand()) {
// Otherwise, if this return operation has an operand then add a result to
// the function.
std::vector<int64_t> shape;
function.setType(builder.getFunctionType(
function.getFunctionType().getInputs(), getType(shape)));
}
// If this function isn't main, then set the visibility to private.
if (ctx->prototype()->Identifier()->toString() != "main")
function.setPrivate();
functionMap.insert({function.getName(), function});
return 0;
}

/// Prototype Visitor
virtual std::any visitPrototype(ToyParser::PrototypeContext *ctx) override {
mlir::Location location =
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
auto varNumber = 0;
int varNumber = 0;
// Get the number of arguments.
if (ctx->declList()) {
auto list = ctx->declList();
ToyParser::DeclListContext *list = ctx->declList();
while (list->Identifier()) {
varNumber++;
if (list->declList())
Expand All @@ -152,26 +224,37 @@ class MLIRToyVisitor : public ToyBaseVisitor {
break;
}
}

llvm::SmallVector<mlir::Type, 4> argTypes(
varNumber, mlir::UnrankedTensorType::get(builder.getF64Type()));
auto funType = builder.getFunctionType(argTypes, llvm::None);
mlir::FunctionType funType = builder.getFunctionType(argTypes, llvm::None);
auto func = builder.create<mlir::toy::FuncOp>(
location, ctx->Identifier()->toString(), funType);
mlir::Block &entryblock = func.front();
builder.setInsertionPointToStart(&entryblock);
return 0;
return func;
}

/// Expression Visitor
/// - If the expression is tensor literal, return the tensor MLIR value.
/// - If the expression is function call or variable, visit the identifier.
/// - If the expression is add expression or mul expression return add or mul
/// value.
virtual std::any visitExpression(ToyParser::ExpressionContext *ctx) override {
mlir::Value value;
if (ctx->tensorLiteral()) {
return getTensor(ctx->tensorLiteral());
} else if (ctx->identifierExpr()) {
return visit(ctx->identifierExpr());
} else if (ctx->Add() || ctx->Mul()) {
// Derive the operation name from the binary operator. At the moment we
// only support '+' and '*'.
mlir::Value lhs = std::any_cast<mlir::Value>(visit(ctx->expression(0)));
mlir::Value rhs = std::any_cast<mlir::Value>(visit(ctx->expression(1)));
mlir::Location loaction =
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
if (ctx->Add())
value = builder.create<mlir::toy::AddOp>(loaction, lhs, rhs);
else
value = builder.create<mlir::toy::MulOp>(loaction, lhs, rhs);
return value;
}
return value;
}
Expand All @@ -188,7 +271,7 @@ class MLIRToyVisitor : public ToyBaseVisitor {
std::vector<int64_t> v0;
auto v1 = ctx->type()->Number();
for (auto i : v1) {
auto j = atoi(i->toString().c_str());
int64_t j = atoi(i->toString().c_str());
v0.push_back(j);
}
mlir::Location location =
Expand All @@ -208,28 +291,65 @@ class MLIRToyVisitor : public ToyBaseVisitor {
virtual std::any
visitIdentifierExpr(ToyParser::IdentifierExprContext *ctx) override {
mlir::Value value;
int argsNumber = 0;
mlir::Location location =
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
// If the identifier is a function call, visit and register all the
// arguments. [TODO][LOW] add the semantic check (look up the symbol table)
// for the function call.
if (ctx->ParentheseOpen()) {
auto location =
mlir::Location location =
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
llvm::SmallVector<mlir::Value, 4> oprands;
for (auto i : ctx->expression()) {
for (ToyParser::ExpressionContext *i : ctx->expression()) {
mlir::Value arg = std::any_cast<mlir::Value>(visit(i));
oprands.push_back(arg);
argsNumber++;
}
// If function call is a built-in operation, create the corresponding
// operation.
if (ctx->Identifier()->toString() == "print") {
auto arg = oprands[0];
if (argsNumber != 1) {
mlir::emitError(location)
<< "mismatch of function parameters 'print'";
return nullptr;
}
mlir::Value arg = oprands[0];
builder.create<mlir::toy::PrintOp>(location, arg);
return 0;
} else if (ctx->Identifier()->toString() == "transpose") {
if (argsNumber != 1) {
mlir::emitError(location)
<< "mlismatch of function parameters 'transpose'";
return nullptr;
}
mlir::Value arg = oprands[0];
value = builder.create<mlir::toy::TransposeOp>(location, arg);
return value;
}
// Otherwise this is a call to a user-defined function. Calls to
// user-defined functions are mapped to a custom call that takes the
// callee name as an attribute.
auto callee = functionMap.find(ctx->Identifier()->toString());
if (callee == functionMap.end()) {
mlir::emitError(location) << "error: no defined function '"
<< ctx->Identifier()->toString() << "'";
return nullptr;
}
int numberdecl = funSymbolTable.lookup(ctx->Identifier()->toString());
if (numberdecl != argsNumber) {
mlir::emitError(location) << "error: mismatch of function parameters '"
<< ctx->Identifier()->toString() << "'";
return nullptr;
}
// If the function call cannot be mapped to the built-in operation, create
// the GenericCallOp.
mlir::toy::FuncOp calledFunc = callee->second;
value = builder.create<mlir::toy::GenericCallOp>(
location, ctx->Identifier()->toString(), oprands);
location, calledFunc.getFunctionType().getResult(0),
mlir::SymbolRefAttr::get(builder.getContext(),
ctx->Identifier()->toString()),
oprands);
return value;
} else {
// If the identifier is a variable, return the MLIR value from the symbol
Expand All @@ -241,12 +361,11 @@ class MLIRToyVisitor : public ToyBaseVisitor {

/// Return Expression Visitor
virtual std::any visitReturnExpr(ToyParser::ReturnExprContext *ctx) override {
returnFlag = true;
auto location =
mlir::Location location =
loc(ctx->start->getLine(), ctx->start->getCharPositionInLine());
mlir::Value expr = nullptr;
if (ctx->expression()) {
expr = std::any_cast<mlir::Value>(ctx->expression());
expr = std::any_cast<mlir::Value>(visit(ctx->expression()));
}
// Generate return operation based on whether the function has the return
// value.
Expand Down
21 changes: 21 additions & 0 deletions examples/ToyDSL/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,24 @@ buddy-toy-constant-translate:

buddy-toy-constant-run:
@${BUDDY_TOY_DSL} ./constant.toy -emit=jit

toyc-function-run:
@${MLIR_TOYC} ./function.toy -emit=jit

buddy-toy-function-ast:
@${BUDDY_TOY_DSL} ./function.toy -emit=ast

buddy-toy-function-mlir:
@${BUDDY_TOY_DSL} ./function.toy -emit=mlir

buddy-toy-function-affine:
@${BUDDY_TOY_DSL} ./function.toy -emit=mlir-affine

buddy-toy-function-llvm:
@${BUDDY_TOY_DSL} ./function.toy -emit=mlir-llvm

buddy-toy-function-translate:
@${BUDDY_TOY_DSL} ./function.toy -emit=llvm

buddy-toy-function-run:
@${BUDDY_TOY_DSL} ./function.toy -emit=jit

0 comments on commit 2f3677c

Please sign in to comment.