403 lines
12 KiB
Python
403 lines
12 KiB
Python
from parser.MPVisitor import MPVisitor
|
|
from parser.MPParser import MPParser
|
|
from functools import reduce
|
|
|
|
# Import the AST names explicitly: a wildcard (`*`) import is not good practice.
|
|
from utils.AST import (
|
|
IntType,
|
|
FloatType,
|
|
BoolType,
|
|
StringType,
|
|
ArrayType,
|
|
# VoidType,
|
|
Program,
|
|
# Decl,
|
|
VarDecl,
|
|
FuncDecl,
|
|
# Stmt,
|
|
Assign,
|
|
If,
|
|
While,
|
|
For,
|
|
Break,
|
|
Continue,
|
|
Return,
|
|
With,
|
|
CallStmt,
|
|
# Expr,
|
|
BinaryOp,
|
|
UnaryOp,
|
|
CallExpr,
|
|
# LHS,
|
|
Id,
|
|
ArrayCell,
|
|
# Literal,
|
|
IntLiteral,
|
|
FloatLiteral,
|
|
StringLiteral,
|
|
BooleanLiteral
|
|
)
|
|
|
|
|
|
def flatten(listOflist):
    """Concatenate the sub-lists of ``listOflist`` into one new flat list.

    Only one level of nesting is removed; elements that are themselves
    lists are kept as they are.

    Args:
        listOflist: an iterable whose items are lists (any iterable works).

    Returns:
        A new list holding every element of every item, in order.  An empty
        input yields ``[]``.
    """
    # A single comprehension appends each element exactly once; folding with
    # pairwise ``+`` would copy the growing accumulator at every step.
    return [element for sublist in listOflist for element in sublist]
|
|
|
|
|
|
class ASTGeneration(MPVisitor):
    """Visitor that converts ``MPParser`` contexts into ``utils.AST`` nodes.

    Every ``visit<Rule>`` method receives the context of the matching parser
    rule and returns either one AST node or a list of AST nodes, as stated in
    the method's own docstring.
    """

    def visitProgram(self, ctx: MPParser.ProgramContext):
        """
        return Program(list of Decl)
        where the list is whatever visiting ``manydecl`` produces
        (variable, function and procedure declarations)
        """
        declarations = self.visit(ctx.manydecl())
        return Program(declarations)

    def visitManydecl(self, ctx: MPParser.ManydeclContext):
        """
        return the flat list of declarations: the list for the leading
        ``decl`` followed by the list for the remaining ``manydecl`` (if any)
        """
        head = self.visit(ctx.decl())
        rest = ctx.manydecl()
        if not rest:
            return head
        return head + self.visit(rest)

    def visitDecl(self, ctx: MPParser.DeclContext):
        """
        return a list of declarations
            + var_decl already yields a list, so it is returned unchanged
            + any other declaration is wrapped in a one-element list
        """
        node = self.visit(ctx.getChild(0))
        return node if ctx.var_decl() else [node]

    def visitVar_decl(self, ctx: MPParser.Var_declContext):
        """
        return the result of visiting ``varlist``
        """
        return self.visit(ctx.varlist())

    def visitVarlist(self, ctx: MPParser.VarlistContext):
        """
        return list of VarDecl(iden, mptype) for this ``var`` followed by
        those of the nested ``varlist`` (if any)
        """
        head = self.visit(ctx.var())
        rest = ctx.varlist()
        if not rest:
            return head
        return head + self.visit(rest)

    def visitVar(self, ctx: MPParser.VarContext):
        """
        return list of VarDecl(iden, mptype), one per identifier, all
        sharing the same type object
        """
        var_type = self.visit(ctx.mptype())
        names = self.visit(ctx.idenlist())
        return [VarDecl(name, var_type) for name in names]

    def visitIdenlist(self, ctx: MPParser.IdenlistContext):
        """
        return list of Id, in source order
        """
        names = [Id(ctx.IDENT().getText())]
        rest = ctx.idenlist()
        if not rest:
            return names
        return names + self.visit(rest)

    def visitMptype(self, ctx: MPParser.MptypeContext):
        """
        return the result of visiting the first child
        """
        return self.visit(ctx.getChild(0))

    def visitPrimitive_type(self, ctx: MPParser.Primitive_typeContext):
        """
        return IntType, BoolType, FloatType or StringType depending on which
        token is present; None when none of them is
        """
        if ctx.INTEGER():
            return IntType()
        if ctx.BOOLEAN():
            return BoolType()
        if ctx.REAL():
            return FloatType()
        if ctx.STRING():
            return StringType()
        return None

    def visitCompound_type(self, ctx: MPParser.Compound_typeContext):
        """
        return ArrayType(low, high, type)
        """
        lower, upper = self.visit(ctx.array_value())
        element_type = self.visit(ctx.primitive_type())
        return ArrayType(lower, upper, element_type)

    def visitArray_value(self, ctx: MPParser.Array_valueContext):
        """
        return low, high as ints, negated according to the MINUS tokens
        """
        lower = int(ctx.NUM_INT(0).getText())
        upper = int(ctx.NUM_INT(1).getText())
        minus_count = len(ctx.MINUS())

        if minus_count == 2:
            return -lower, -upper
        if minus_count == 0:
            return lower, upper
        # Otherwise the text of the second child tells which bound carries
        # the sign: a '-' there negates the low bound, else the high bound.
        if ctx.getChild(1).getText() == '-':
            return -lower, upper
        return lower, -upper

    def visitFunc_decl(self, ctx: MPParser.Func_declContext):
        """
        return FuncDecl(name, params, locals, body, return type)
        """
        name = Id(ctx.IDENT().getText())
        params = []
        if ctx.param_list():
            params = self.visit(ctx.param_list())
        return_type = self.visit(ctx.mptype())
        local_vars = flatten([self.visit(item) for item in ctx.var_decl()])
        body = self.visit(ctx.compound_statement())
        return FuncDecl(name, params, local_vars, body, return_type)

    def visitParam_list(self, ctx: MPParser.Param_listContext):
        """
        return list of VarDecl(iden, mptype) for this ``var`` followed by
        those of the nested ``param_list`` (if any)
        """
        head = self.visit(ctx.var())
        rest = ctx.param_list()
        if not rest:
            return head
        return head + self.visit(rest)

    def visitProc_decl(self, ctx: MPParser.Proc_declContext):
        """
        return FuncDecl(name, params, locals, body) with no return type
        argument supplied
        """
        name = Id(ctx.IDENT().getText())
        params = []
        if ctx.param_list():
            params = self.visit(ctx.param_list())
        local_vars = flatten([self.visit(item) for item in ctx.var_decl()])
        body = self.visit(ctx.compound_statement())
        return FuncDecl(name, params, local_vars, body)

    def visitExpression(self, ctx: MPParser.ExpressionContext):
        """
        return the lower-level expression when it is the only child,
        otherwise BinaryOp("andthen" | "orelse", left, right)
        """
        if ctx.getChildCount() == 1:
            return self.visit(ctx.expression_lv1())
        operator = "andthen" if ctx.AND() else "orelse"
        left = self.visit(ctx.expression())
        right = self.visit(ctx.expression_lv1())
        return BinaryOp(operator, left, right)

    def visitExpression_lv1(self, ctx: MPParser.Expression_lv1Context):
        """
        return the single operand, or BinaryOp(op, left, right) where op is
        the text of the second child
        """
        if ctx.getChildCount() == 1:
            return self.visit(ctx.expression_lv2(0))
        operator = ctx.getChild(1).getText()
        left = self.visit(ctx.expression_lv2(0))
        right = self.visit(ctx.expression_lv2(1))
        return BinaryOp(operator, left, right)

    def visitExpression_lv2(self, ctx: MPParser.Expression_lv2Context):
        """
        return the single operand, or BinaryOp(op, left, right) where op is
        the text of the second child
        """
        if ctx.getChildCount() == 1:
            return self.visit(ctx.expression_lv3())
        operator = ctx.getChild(1).getText()
        left = self.visit(ctx.expression_lv2())
        right = self.visit(ctx.expression_lv3())
        return BinaryOp(operator, left, right)

    def visitExpression_lv3(self, ctx: MPParser.Expression_lv3Context):
        """
        return the single operand, or BinaryOp(op, left, right) where op is
        the text of the second child
        """
        if ctx.getChildCount() == 1:
            return self.visit(ctx.expression_lv4())
        operator = ctx.getChild(1).getText()
        left = self.visit(ctx.expression_lv3())
        right = self.visit(ctx.expression_lv4())
        return BinaryOp(operator, left, right)

    def visitExpression_lv4(self, ctx: MPParser.Expression_lv4Context):
        """
        return the index expression when it is the only child, otherwise
        UnaryOp(op, operand) where op is the text of the first child
        """
        if ctx.getChildCount() == 1:
            return self.visit(ctx.index_expression())
        operator = ctx.getChild(0).getText()
        operand = self.visit(ctx.expression_lv4())
        return UnaryOp(operator, operand)

    def visitIndex_expression(self, ctx: MPParser.Index_expressionContext):
        """
        return the factor when it is the only child, otherwise
        ArrayCell(indexed expression, index)
        """
        if ctx.getChildCount() == 1:
            return self.visit(ctx.factor())
        target = self.visit(ctx.index_expression())
        index = self.visit(ctx.expression())
        return ArrayCell(target, index)

    def visitInvocation_expression(
            self, ctx: MPParser.Invocation_expressionContext):
        """
        return the list of call arguments, [] when there is no call_param
        """
        arguments = ctx.call_param()
        return self.visit(arguments) if arguments else []

    def visitFactor(self, ctx: MPParser.FactorContext):
        """
        return the node for whichever alternative is present, checked in
        this order: nested expression, CallExpr, literal, Id, StringLiteral;
        None when none of them is
        """
        if ctx.expression():
            return self.visit(ctx.expression())
        if ctx.invocation_expression():
            callee = Id(ctx.IDENT().getText())
            arguments = self.visit(ctx.invocation_expression())
            return CallExpr(callee, arguments)
        if ctx.literal():
            return self.visit(ctx.literal())
        if ctx.IDENT():
            return Id(ctx.IDENT().getText())
        if ctx.STRING_LITERAL():
            return StringLiteral(ctx.STRING_LITERAL().getText())
        return None

    def visitStatement(self, ctx: MPParser.StatementContext):
        """
        return the result of visiting the first child
        """
        return self.visit(ctx.getChild(0))

    def visitStructured_statement(
            self, ctx: MPParser.Structured_statementContext):
        """
        return a list of statements: a compound statement's result is
        returned unchanged, anything else is wrapped in a one-element list
        """
        child = ctx.getChild(0)
        if ctx.compound_statement():
            return self.visit(child)
        return [self.visit(child)]

    def visitNormal_statement(self, ctx: MPParser.Normal_statementContext):
        """
        return a list of statements: an assignment statement's result is
        returned unchanged, anything else is wrapped in a one-element list
        """
        child = ctx.getChild(0)
        if ctx.assignment_statement():
            return self.visit(child)
        return [self.visit(child)]

    def visitAssignment_statement(
            self, ctx: MPParser.Assignment_statementContext):
        """
        return list of Assign(lhs, exp)

        Each target is assigned from the target that follows it, the last
        one from the expression; the list is returned in reverse order so
        the assignment from the expression comes first.
        """
        value = self.visit(ctx.expression())
        targets = self.visit(ctx.assignment_lhs_list())

        # Pair every target with its successor; the final target gets the
        # expression itself.
        sources = targets[1:] + [value]
        assignments = [
            Assign(target, source)
            for target, source in zip(targets, sources)
        ]
        assignments.reverse()
        return assignments

    def visitAssignment_lhs_list(
            self, ctx: MPParser.Assignment_lhs_listContext):
        """
        return list of lhs, in source order
        """
        targets = [self.visit(ctx.lhs())]
        rest = ctx.assignment_lhs_list()
        if not rest:
            return targets
        return targets + self.visit(rest)

    def visitLhs(self, ctx: MPParser.LhsContext):
        """
        return Id when an IDENT is present, otherwise the result of
        visiting the index expression
        """
        identifier = ctx.IDENT()
        if identifier:
            return Id(identifier.getText())
        return self.visit(ctx.index_expression())

    def visitIf_statement(self, ctx: MPParser.If_statementContext):
        """
        return If(condition, then) or, when ELSE is present,
        If(condition, then, else)
        """
        condition = self.visit(ctx.expression())
        has_else = ctx.ELSE()
        then_branch = self.visit(ctx.statement(0))
        if not has_else:
            return If(condition, then_branch)
        else_branch = self.visit(ctx.statement(1))
        return If(condition, then_branch, else_branch)

    def visitWhile_statement(self, ctx: MPParser.While_statementContext):
        """
        return While(condition, body)
        """
        condition = self.visit(ctx.expression())
        body = self.visit(ctx.statement())
        return While(condition, body)

    def visitFor_statement(self, ctx: MPParser.For_statementContext):
        """
        return For(Id, first expression, second expression, up, body)
        where up is True when the TO token is present
        """
        counting_up = bool(ctx.TO())
        counter = Id(ctx.IDENT().getText())
        first = self.visit(ctx.expression(0))
        second = self.visit(ctx.expression(1))
        body = self.visit(ctx.statement())
        return For(counter, first, second, counting_up, body)

    def visitBreak_statement(self, ctx: MPParser.Break_statementContext):
        """
        return Break()
        """
        return Break()

    def visitContinue_statement(self, ctx: MPParser.Continue_statementContext):
        """
        return Continue()
        """
        return Continue()

    def visitReturn_statement(self, ctx: MPParser.Return_statementContext):
        """
        return Return(expression) when an expression is present,
        otherwise Return()
        """
        if not ctx.expression():
            return Return()
        return Return(self.visit(ctx.expression()))

    def visitCompound_statement(self, ctx: MPParser.Compound_statementContext):
        """
        return the flat list of all contained statements, [] when empty
        """
        statements = ctx.statement()
        if not statements:
            return []
        return flatten([self.visit(item) for item in statements])

    def visitWith_statement(self, ctx: MPParser.With_statementContext):
        """
        return With(varlist result, statement result)
        """
        declarations = self.visit(ctx.varlist())
        body = self.visit(ctx.statement())
        return With(declarations, body)

    def visitCall_statement(self, ctx: MPParser.Call_statementContext):
        """
        return CallStmt(Id, list of arguments); the list is [] when there
        is no call_param
        """
        arguments = []
        if ctx.call_param():
            arguments = self.visit(ctx.call_param())
        return CallStmt(Id(ctx.IDENT().getText()), arguments)

    def visitCall_param(self, ctx: MPParser.Call_paramContext):
        """
        return list of argument expressions, in source order
        """
        arguments = [self.visit(ctx.expression())]
        rest = ctx.call_param()
        if not rest:
            return arguments
        return arguments + self.visit(rest)

    def visitEmpty(self, ctx: MPParser.EmptyContext):
        """
        return None
        """
        return None

    def visitLiteral(self, ctx: MPParser.LiteralContext):
        """
        return the number node, a BooleanLiteral or a StringLiteral;
        None when none of them is present
        """
        if ctx.number():
            return self.visit(ctx.number())
        if ctx.BOOL_LIT():
            # Compared case-insensitively against 'true'.
            is_true = ctx.BOOL_LIT().getText().lower() == 'true'
            return BooleanLiteral(is_true)
        if ctx.STRING_LITERAL():
            return StringLiteral(ctx.STRING_LITERAL().getText())
        return None

    def visitNumber(self, ctx: MPParser.NumberContext):
        """
        return IntLiteral for NUM_INT, FloatLiteral for NUM_REAL,
        None when neither token is present
        """
        if ctx.NUM_INT():
            return IntLiteral(int(ctx.NUM_INT().getText()))
        if ctx.NUM_REAL():
            return FloatLiteral(float(ctx.NUM_REAL().getText()))
        return None
|