MPC/astgen/ASTGeneration.py
Nguyễn Anh Khoa bed5a47d55 init
2018-12-04 02:05:06 +07:00

403 lines
12 KiB
Python

from parser.MPVisitor import MPVisitor
from parser.MPParser import MPParser
from functools import reduce
# Import AST node classes explicitly rather than with `from utils.AST import *`.
from utils.AST import (
IntType,
FloatType,
BoolType,
StringType,
ArrayType,
# VoidType,
Program,
# Decl,
VarDecl,
FuncDecl,
# Stmt,
Assign,
If,
While,
For,
Break,
Continue,
Return,
With,
CallStmt,
# Expr,
BinaryOp,
UnaryOp,
CallExpr,
# LHS,
Id,
ArrayCell,
# Literal,
IntLiteral,
FloatLiteral,
StringLiteral,
BooleanLiteral
)
def flatten(listOflist):
    """Concatenate a sequence of lists into one new flat list.

    Only one level of nesting is removed, e.g. ``[[1, 2], [3], []]`` becomes
    ``[1, 2, 3]``. An empty input yields ``[]``.

    Built in a single pass: the previous ``reduce(lambda x, i: x + i, ...)``
    copied the growing accumulator on every step, which is quadratic in the
    total number of items.
    """
    return [item for sublist in listOflist for item in sublist]
class ASTGeneration(MPVisitor):
    """Parse-tree visitor that builds ``utils.AST`` nodes from ``MPParser`` contexts.

    Each ``visitX`` method handles the ``X`` rule. Rules that can expand to
    several AST nodes (declarations, identifier lists, multiple assignment,
    statement sequences) return Python lists so that callers can concatenate
    them. The remaining rules return a single node.

    NOTE(review): ``visitExpression_lv1`` .. ``visitExpression_lv4`` store the
    operator as the token's source text (``getText()``). Keyword operators
    therefore keep whatever letter case the source used. ``visitExpression``,
    by contrast, always emits lowercase ``"andthen"`` / ``"orelse"``. Confirm
    that downstream passes expect this.
    """

    # ---------------------------------------------------------------- decls

    def visitProgram(self, ctx: MPParser.ProgramContext):
        """Return ``Program(decls)``.

        ``decls`` is the flat list produced by ``manydecl``. It contains
        ``VarDecl`` nodes for variable declarations and ``FuncDecl`` nodes
        for both functions and procedures (see ``visitProc_decl``).
        """
        return Program(self.visit(ctx.manydecl()))

    def visitManydecl(self, ctx: MPParser.ManydeclContext):
        """Return this rule's declarations followed by those of the nested
        ``manydecl`` (if any), as a single list."""
        decls = self.visit(ctx.decl())
        if ctx.manydecl():
            return decls + self.visit(ctx.manydecl())
        return decls

    def visitDecl(self, ctx: MPParser.DeclContext):
        """Return the declarations of one ``decl`` as a list.

        A ``var_decl`` already yields a list of ``VarDecl``. A function or
        procedure yields a single ``FuncDecl``, which is wrapped in a list so
        that every branch has the same shape.
        """
        decl = self.visit(ctx.getChild(0))
        if ctx.var_decl():
            return decl
        return [decl]

    def visitVar_decl(self, ctx: MPParser.Var_declContext):
        """Return the list of ``VarDecl`` built from the ``varlist``."""
        return self.visit(ctx.varlist())

    def visitVarlist(self, ctx: MPParser.VarlistContext):
        """Return the ``VarDecl`` list of this ``var`` followed by those of
        the nested ``varlist`` (if any)."""
        decls = self.visit(ctx.var())
        if ctx.varlist():
            return decls + self.visit(ctx.varlist())
        return decls

    def visitVar(self, ctx: MPParser.VarContext):
        """Return one ``VarDecl(ident, mptype)`` per identifier, in source order.

        All declarations from the same ``var`` share a single type node object.
        """
        mptype = self.visit(ctx.mptype())
        idenlist = self.visit(ctx.idenlist())
        return [VarDecl(ident, mptype) for ident in idenlist]

    def visitIdenlist(self, ctx: MPParser.IdenlistContext):
        """Return the identifiers of the list as ``Id`` nodes, in source order."""
        ident = Id(ctx.IDENT().getText())
        if ctx.idenlist():
            return [ident] + self.visit(ctx.idenlist())
        return [ident]

    def visitMptype(self, ctx: MPParser.MptypeContext):
        """Return the type node of the single child.

        The child is presumably ``primitive_type`` or ``compound_type``.
        """
        return self.visit(ctx.getChild(0))

    def visitPrimitive_type(self, ctx: MPParser.Primitive_typeContext):
        """Map a primitive type keyword to its AST type node.

        Returns ``None`` if none of the four keywords is present.
        """
        if ctx.INTEGER():
            return IntType()
        if ctx.BOOLEAN():
            return BoolType()
        if ctx.REAL():
            return FloatType()
        if ctx.STRING():
            return StringType()
        return None

    def visitCompound_type(self, ctx: MPParser.Compound_typeContext):
        """Return ``ArrayType(low, high, element_type)``."""
        low, high = self.visit(ctx.array_value())
        element_type = self.visit(ctx.primitive_type())
        return ArrayType(low, high, element_type)

    def visitArray_value(self, ctx: MPParser.Array_valueContext):
        """Return the ``(low, high)`` integer bounds of an array type.

        Each ``-`` token negates the ``NUM_INT`` that follows it.
        """
        bounds = []
        negate = False
        # Scan the children in source order instead of testing a fixed child
        # index. That way a lone minus is attached to the correct bound no
        # matter which other tokens (brackets, '..') precede it.
        for child in ctx.getChildren():
            text = child.getText()
            if text == '-':
                negate = True
            elif text.isdigit():
                # Assumes NUM_INT is the only digit-only child of this rule.
                value = int(text)
                bounds.append(-value if negate else value)
                negate = False
        low, high = bounds
        return low, high

    def visitFunc_decl(self, ctx: MPParser.Func_declContext):
        """Return ``FuncDecl(name, params, locals, body, return_type)``."""
        ident = Id(ctx.IDENT().getText())
        param_list = self.visit(ctx.param_list()) if ctx.param_list() else []
        mptype = self.visit(ctx.mptype())
        # Each var_decl yields a list of VarDecl; merge them into one list.
        var_decl = flatten([self.visit(decl) for decl in ctx.var_decl()])
        compound_statement = self.visit(ctx.compound_statement())
        return FuncDecl(
            ident,
            param_list,
            var_decl,
            compound_statement,
            mptype,
        )

    def visitParam_list(self, ctx: MPParser.Param_listContext):
        """Return the parameters as a flat list of ``VarDecl``, in source order."""
        params = self.visit(ctx.var())  # already a list of VarDecl
        if ctx.param_list():
            return params + self.visit(ctx.param_list())
        return params

    def visitProc_decl(self, ctx: MPParser.Proc_declContext):
        """Return ``FuncDecl(name, params, locals, body)`` for a procedure.

        No return type is passed, so ``FuncDecl``'s own default is used
        (presumably ``VoidType``; confirm in ``utils.AST``).
        """
        ident = Id(ctx.IDENT().getText())
        param_list = self.visit(ctx.param_list()) if ctx.param_list() else []
        var_decl = flatten([self.visit(decl) for decl in ctx.var_decl()])
        compound_statement = self.visit(ctx.compound_statement())
        return FuncDecl(
            ident,
            param_list,
            var_decl,
            compound_statement,
        )

    # ---------------------------------------------------------- expressions

    def visitExpression(self, ctx: MPParser.ExpressionContext):
        """Return a short-circuit ``BinaryOp`` or the lower-level expression.

        The operator is ``"andthen"`` when an AND token is present, otherwise
        ``"orelse"``. The left operand is the nested ``expression``, so chains
        group to the left.
        """
        if ctx.getChildCount() == 1:
            return self.visit(ctx.expression_lv1())
        op = "andthen" if ctx.AND() else "orelse"
        return BinaryOp(
            op,
            self.visit(ctx.expression()),
            self.visit(ctx.expression_lv1()),
        )

    def visitExpression_lv1(self, ctx: MPParser.Expression_lv1Context):
        """Return ``BinaryOp(op, lhs, rhs)`` over exactly two ``expression_lv2``
        operands, or the single operand when there is no operator."""
        if ctx.getChildCount() == 1:
            return self.visit(ctx.expression_lv2(0))
        return BinaryOp(
            ctx.getChild(1).getText(),
            self.visit(ctx.expression_lv2(0)),
            self.visit(ctx.expression_lv2(1)),
        )

    def visitExpression_lv2(self, ctx: MPParser.Expression_lv2Context):
        """Return a left-recursive ``BinaryOp`` (nested ``expression_lv2`` on
        the left), or the ``expression_lv3`` operand when there is no operator."""
        if ctx.getChildCount() == 1:
            return self.visit(ctx.expression_lv3())
        return BinaryOp(
            ctx.getChild(1).getText(),
            self.visit(ctx.expression_lv2()),
            self.visit(ctx.expression_lv3()),
        )

    def visitExpression_lv3(self, ctx: MPParser.Expression_lv3Context):
        """Return a left-recursive ``BinaryOp`` (nested ``expression_lv3`` on
        the left), or the ``expression_lv4`` operand when there is no operator."""
        if ctx.getChildCount() == 1:
            return self.visit(ctx.expression_lv4())
        return BinaryOp(
            ctx.getChild(1).getText(),
            self.visit(ctx.expression_lv3()),
            self.visit(ctx.expression_lv4()),
        )

    def visitExpression_lv4(self, ctx: MPParser.Expression_lv4Context):
        """Return ``UnaryOp(op, operand)`` for a prefix operator, or the
        ``index_expression`` when there is none."""
        if ctx.getChildCount() == 1:
            return self.visit(ctx.index_expression())
        return UnaryOp(
            ctx.getChild(0).getText(),
            self.visit(ctx.expression_lv4()),
        )

    def visitIndex_expression(self, ctx: MPParser.Index_expressionContext):
        """Return ``ArrayCell(base, index)`` for an indexing, or the ``factor``."""
        if ctx.getChildCount() == 1:
            return self.visit(ctx.factor())
        return ArrayCell(
            self.visit(ctx.index_expression()),
            self.visit(ctx.expression()),
        )

    def visitInvocation_expression(
            self, ctx: MPParser.Invocation_expressionContext):
        """Return the call's argument expressions as a list (``[]`` if none)."""
        if ctx.call_param():
            return self.visit(ctx.call_param())
        return []

    def visitFactor(self, ctx: MPParser.FactorContext):
        """Return the AST node for a factor.

        The alternatives are checked in this order: a nested expression, a
        call (``IDENT`` + arguments), a literal, a bare identifier, and a
        string literal. Returns ``None`` if nothing matches.
        """
        if ctx.expression():
            return self.visit(ctx.expression())
        if ctx.invocation_expression():
            return CallExpr(Id(ctx.IDENT().getText()),
                            self.visit(ctx.invocation_expression()))
        if ctx.literal():
            return self.visit(ctx.literal())
        if ctx.IDENT():
            return Id(ctx.IDENT().getText())
        if ctx.STRING_LITERAL():
            return StringLiteral(ctx.STRING_LITERAL().getText())
        return None

    # ----------------------------------------------------------- statements

    def visitStatement(self, ctx: MPParser.StatementContext):
        """Return the statement list produced by the single child."""
        return self.visit(ctx.getChild(0))

    def visitStructured_statement(
            self, ctx: MPParser.Structured_statementContext):
        """Return a list of statements.

        A compound statement is already a list. Any other structured
        statement yields one node, which is wrapped in a list.
        """
        if ctx.compound_statement():
            return self.visit(ctx.getChild(0))
        return [self.visit(ctx.getChild(0))]

    def visitNormal_statement(self, ctx: MPParser.Normal_statementContext):
        """Return a list of statements.

        An assignment may expand to several ``Assign`` nodes and is already a
        list. Any other statement yields one node, which is wrapped in a list.
        """
        if ctx.assignment_statement():
            return self.visit(ctx.getChild(0))
        return [self.visit(ctx.getChild(0))]

    def visitAssignment_statement(
            self, ctx: MPParser.Assignment_statementContext):
        """Return the ``Assign`` nodes for a (possibly chained) assignment.

        For ``t1 := t2 := ... := tn := e``, each target ``ti`` is assigned
        the value of ``t(i+1)``, and ``tn`` is assigned ``e``. The list is
        returned right to left, so the assignment from ``e`` comes first.
        """
        expression = self.visit(ctx.expression())
        targets = self.visit(ctx.assignment_lhs_list())
        # Each target takes the value of its right neighbour; the last takes e.
        sources = targets[1:] + [expression]
        assigns = [Assign(lhs, rhs) for lhs, rhs in zip(targets, sources)]
        return assigns[::-1]

    def visitAssignment_lhs_list(
            self, ctx: MPParser.Assignment_lhs_listContext):
        """Return the assignment targets in source order."""
        lhs = self.visit(ctx.lhs())
        if ctx.assignment_lhs_list():
            return [lhs] + self.visit(ctx.assignment_lhs_list())
        return [lhs]

    def visitLhs(self, ctx: MPParser.LhsContext):
        """Return an ``Id`` for a plain identifier, otherwise the node built
        from ``index_expression``."""
        if ctx.IDENT():
            return Id(ctx.IDENT().getText())
        return self.visit(ctx.index_expression())

    def visitIf_statement(self, ctx: MPParser.If_statementContext):
        """Return ``If(cond, then_stmts[, else_stmts])``.

        Both branches are statement lists. The else list is passed only when
        an ELSE token is present.
        """
        expression = self.visit(ctx.expression())
        then_statement = self.visit(ctx.statement(0))
        if ctx.ELSE():
            else_statement = self.visit(ctx.statement(1))
            return If(expression, then_statement, else_statement)
        return If(expression, then_statement)

    def visitWhile_statement(self, ctx: MPParser.While_statementContext):
        """Return ``While(cond, body_stmts)``."""
        return While(self.visit(ctx.expression()), self.visit(ctx.statement()))

    def visitFor_statement(self, ctx: MPParser.For_statementContext):
        """Return ``For(var, start, end, up, body_stmts)``.

        ``up`` is true when the loop uses TO (counting up); otherwise it is
        false.
        """
        up = bool(ctx.TO())
        return For(
            Id(ctx.IDENT().getText()),
            self.visit(ctx.expression(0)),
            self.visit(ctx.expression(1)),
            up,
            self.visit(ctx.statement()),
        )

    def visitBreak_statement(self, ctx: MPParser.Break_statementContext):
        """Return a ``Break`` node."""
        return Break()

    def visitContinue_statement(self, ctx: MPParser.Continue_statementContext):
        """Return a ``Continue`` node."""
        return Continue()

    def visitReturn_statement(self, ctx: MPParser.Return_statementContext):
        """Return ``Return(expr)``, or ``Return()`` when no expression is given."""
        if ctx.expression():
            return Return(self.visit(ctx.expression()))
        return Return()

    def visitCompound_statement(self, ctx: MPParser.Compound_statementContext):
        """Return all contained statements as one flat list (``[]`` if empty).

        Each ``statement`` visit yields a list, and the lists are concatenated.
        """
        return flatten([self.visit(stmt) for stmt in ctx.statement()])

    def visitWith_statement(self, ctx: MPParser.With_statementContext):
        """Return ``With(var_decls, body_stmts)``."""
        return With(self.visit(ctx.varlist()), self.visit(ctx.statement()))

    def visitCall_statement(self, ctx: MPParser.Call_statementContext):
        """Return ``CallStmt(name, args)``, with ``args`` being ``[]`` when there
        are no arguments."""
        param = self.visit(ctx.call_param()) if ctx.call_param() else []
        return CallStmt(Id(ctx.IDENT().getText()), param)

    def visitCall_param(self, ctx: MPParser.Call_paramContext):
        """Return the argument expressions in source order."""
        expression = self.visit(ctx.expression())
        if ctx.call_param():
            return [expression] + self.visit(ctx.call_param())
        return [expression]

    def visitEmpty(self, ctx: MPParser.EmptyContext):
        """Return ``None`` for the empty rule."""
        return None

    # ------------------------------------------------------------- literals

    def visitLiteral(self, ctx: MPParser.LiteralContext):
        """Return a number, boolean, or string literal node (``None`` otherwise).

        Boolean literals compare case-insensitively against ``'true'``.
        """
        if ctx.number():
            return self.visit(ctx.number())
        if ctx.BOOL_LIT():
            return BooleanLiteral(ctx.BOOL_LIT().getText().lower() == 'true')
        if ctx.STRING_LITERAL():
            return StringLiteral(ctx.STRING_LITERAL().getText())
        return None

    def visitNumber(self, ctx: MPParser.NumberContext):
        """Return ``IntLiteral`` or ``FloatLiteral`` for the numeric token
        (``None`` if neither token is present)."""
        if ctx.NUM_INT():
            return IntLiteral(int(ctx.NUM_INT().getText()))
        if ctx.NUM_REAL():
            return FloatLiteral(float(ctx.NUM_REAL().getText()))
        return None