diff --git a/Dockerfile b/Dockerfile
index f2c7f7b787950e5d960a4203b17bee750b3fbdd1..9390111bb7c8fddfa24f7f68b366015e89efa1b3 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,7 +1,9 @@
 FROM texlive/texlive
-RUN export DEBIAN_FRONTEND=noninteractive && apt-get update && apt-get -y upgrade && apt-get -y install openjdk-11-jdk-headless hunspell hunspell-hu hunspell-en-gb hunspell-en-us
+RUN export DEBIAN_FRONTEND=noninteractive && apt-get update && apt-get -y upgrade && apt-get -y install openjdk-11-jdk-headless hunspell hunspell-hu hunspell-en-gb hunspell-en-us python3-pip && pip3 install javalang
 COPY tikz-uml.sty /usr/local/texlive/2020/texmf-dist/tex/latex/tikz-uml/tikz-uml.sty
 COPY jsonDoclet.jar /root/jsonDoclet.jar
 COPY ./plab/out/plab /usr/bin/plab
 #ADD robinbird/build/distributions/robinbird.tar /
 RUN mktexlsr
+COPY gen_seq_diag.py /
+RUN chmod +x /gen_seq_diag.py
\ No newline at end of file
diff --git a/gen_seq_diag.py b/gen_seq_diag.py
new file mode 100755
index 0000000000000000000000000000000000000000..9f9cc33df9b7fa9df7c4c104bb534abee02c7be2
--- /dev/null
+++ b/gen_seq_diag.py
@@ -0,0 +1,595 @@
+#!/usr/bin/env python3
+"""Generate tikz-uml sequence/communication diagrams from Java sources.
+
+Usage: gen_seq_diag.py <project root> 'diag:Type:method:obj; A -> B; ...'
+
+diag is 'seq' or 'comm'; each 'A -> B' pair substitutes type B for type A
+when methods are resolved. The LaTeX is printed to standard output.
+"""
+
+# Imports
+import javalang
+import javalang.tree as tr
+import os
+import sys
+import math
+import traceback
+
+# types: the main data structure (key: name, val: a Type object)
+types = {}
+iden = 4 * ' '
+
+# walk_dir finds and parses all java files (recursively), and stores them in
+# trees
+trees = []
+
+
+def fetch_file(p):
+    """Parse the Java source file at path p and return its syntax tree."""
+    with open(p, 'r', encoding='utf-8') as f:
+        s = f.read()
+    tree = javalang.parse.parse(s)
+    return tree
+
+
+def walk_dir(p):
+    """Parse every .java file below directory p into the trees list."""
+    for root, dirs, files in os.walk(p):
+        for f in files:
+            # Skip resources/build files instead of aborting on them.
+            if not f.endswith('.java'): continue
+            fp = os.path.join(root, f)
+            trees.append(fetch_file(fp))
+
+# Stringify a tr.Expression (very limited, intentionally)
+def expr_to_str(e):
+    if type(e) is tr.BinaryOperation:
+        return (f'{expr_to_str(e.operandl)} ' +
+                f'{e.operator} {expr_to_str(e.operandr)}')
+    elif type(e) is tr.MemberReference:
+        return f'{"".join(e.prefix_operators)}{e.member}'
+    elif type(e) is tr.Cast:
+        return expr_to_str(e.expression) # ignore casts
+    elif type(e) is tr.ReferenceType:
+        return e.name
+    elif type(e) is tr.Literal:
+        return e.value
+    else:
+        assert False, str(e)
+
+# Escape the characters in s that are special to LaTeX
+def latex_escape(s):
+    return (s.replace("_", "\\_")
+            .replace("&", "\\&")
+            .replace("{", "\\{")
+            .replace("}", "\\}"))
+# Abort if any argument contains a character latex_escape would alter
+def assert_latex_safe(*args):
+    for arg in args:
+        for c in '_&{}':
+            if c in arg:
+                print(f"Error: '{arg}' has invalid latex characters")
+                assert(False)
+
+# Context for sequence diagram generation
+class Ctx:
+    def __init__(self, repl):
+        self.objs = {} # collect any touched objects
+        self.s = [] # result string
+        self.fragments = [] # fragment stack
+        self.vis = [] # visited list
+        self.repl = repl # type replacement dict
+        self.args = [] # function argument stack
+        self.comm_edges = {} # (from, to) -> (messages, reverse messages)
+        self.comm_label_stack = [] # nested message numbers
+    def append(self, lvl, a):
+        self.s.append(lvl * iden + a + '\n')
+    def comment(self, lvl, a):
+        self.s.append(lvl * iden + f'% {a}\n')
+
+    # Start buffering output so a header can be inserted in front of it
+    def push_fragment(self):
+        self.fragments.append(self.s)
+        self.s = []
+    # True if the current buffer holds anything but LaTeX comments
+    def has_any(self):
+        return any(i.strip()[:1] not in ('', '%') for i in self.s)
+    # Append a line to the enclosing buffer, i.e. before the current one
+    def append_fragment(self, lvl, a):
+        self.fragments[-1].append(lvl * iden + a + '\n')
+    # Merge the current buffer back into the enclosing one
+    def pop_fragment(self):
+        fragment = self.fragments.pop()
+        fragment.extend(self.s)
+        self.s = fragment
+
+    # edge is a tuple with the start and end object names
+    def add_comm_edge(self, edge, msg):
+        rev = (edge[1], edge[0])
+        if edge in self.comm_edges:
+            self.comm_edges[edge][0].append(msg)
+        elif rev in self.comm_edges:
+            self.comm_edges[rev][1].append(msg)
+        else:
+            self.comm_edges[edge] = ([msg],[])
+    def get_comm_label(self):
+        return '.'.join([str(i) for i in self.comm_label_stack])
+    def push_comm_label(self):
+        self.comm_label_stack.append(1)
+    def inc_comm_label(self):
+        self.comm_label_stack[-1] += 1
+    def pop_comm_label(self):
+        self.comm_label_stack.pop()
+
+
+# Represents a statement
+class Stmt:
+    def gen_seq(self, ctx, met, obj, lvl):
+        assert(False)
+
+# Represents a method invocation
+class Call(Stmt):
+    # args: list of javalang expressions
+    # ret: Var representing where the returned value goes
+    # qual: the object the call is being made on (str)
+    # member: the member function we are calling (str)
+    def __init__(self, qual, member, args, ret):
+        self.qual = qual
+        self.member = member
+        self.args = args
+        self.ret = ret
+    # convert each argument to a sensible str
+    def clean_args(self, this_ref):
+        res = []
+        for e in self.args:
+            if type(e) is tr.This:
+                res.append(this_ref)
+            else:
+                res.append(expr_to_str(e))
+        return res
+
+    # met: the Method object we are in the body of
+    # obj: the object the call is being made on (Var)
+    # lvl: indentation level
+    def gen_seq(self, ctx, met, obj, lvl):
+        call = self
+        var = None # Var we are making the call on
+        met2 = None # The method we are calling (Method)
+        if call.qual == '': # self call
+            var = obj
+            met2 = met.t.get_method(call.member)
+            if met2 is None:
+                ctx.comment(lvl, f"can't find {call.member}")
+                return
+        else: # not self call
+            # find out what we are calling this method on!
+            member = met.t.get_member(call.qual)
+            if member is not None: # is it our own member variable?
+                var = Var(member.t, f'{obj.name}/{call.qual}')
+            elif call.qual in met.vars: # is it a local variable?
+                var = met.vars[call.qual]
+            elif len(ctx.args) > 0: # is it a function argument?
+                sargs = ctx.args[-1]
+                for i,(t,n) in enumerate(met.params):
+                    if n == call.qual:
+                        if (i >= len(sargs)):
+                            break
+                        var = Var(t, sargs[i])
+                        break
+            if var is None:
+                if call.qual in types: # is it a static method of a type?
+                    var = Var(call.qual, f'{call.qual}')
+            if var is None:
+                ctx.comment(lvl, f'cant find var {call.qual}.{call.member}')
+                return
+            t = var.t
+            if t in ctx.repl: t = ctx.repl[t]
+            if t in types:
+                met2 = types[t].get_method(call.member)
+            else:
+                ctx.comment(lvl,
+                    f'cant find type {var.t} ({call.qual}.{call.member})')
+                return
+            # the type is known but has no such method (e.g. inherited
+            # from a class that was not parsed)
+            if met2 is None:
+                ctx.comment(lvl, f"can't find {call.qual}.{call.member}")
+                return
+
+        # check for recursion
+        id_ = (met2.t.name, met2.name)
+        if id_ in ctx.vis:
+            ctx.comment(lvl, f'recur {call.qual}.{call.member}')
+            return
+        ctx.vis.append(id_)
+
+        # track referenced objects
+        t = var.t
+        if t in ctx.repl: t = ctx.repl[t]
+        o = (t, var.name)
+        if o not in ctx.objs:
+            ctx.objs[o] = lvl
+
+        edge = (obj.name, var.name)
+        args = call.clean_args(obj.name)
+        args_str = latex_escape(', '.join(args))
+
+        lab = ctx.get_comm_label()
+        ctx.add_comm_edge(edge, f'{lab}: {call.member}({args_str})')
+
+        ret = ''
+        if call.ret: ret = call.ret.name
+        if ret == '': ret = '\\ '
+
+        assert_latex_safe(edge[0], edge[1])
+        ctx.args.append(args)
+        ctx.append(lvl, '\\begin{umlcall}' +
+                '[' +
+                f'op={{{call.member}({args_str})}},' +
+                f'return={{{ret}}},' +
+                f'dt=7' +
+                ']' +
+                f'{{{edge[0]}}}{{{edge[1]}}}')
+        met2.gen_seq(var, lvl + 1, ctx)
+        ctx.append(lvl, '\\end{umlcall}')
+
+        if call.ret is not None:
+            ctx.inc_comm_label()
+            lab = ctx.get_comm_label()
+            ctx.add_comm_edge((edge[1], edge[0]), f'{lab}: {call.ret.name}')
+
+        ctx.args.pop()
+        ctx.vis.pop()
+
+# Represents an object creation (new) expression
+class CreateCall(Stmt):
+    # e: the javalang ClassCreator expression
+    # ret: Var the new object is assigned to (None if it is discarded)
+    def __init__(self, e, ret):
+        self.e = e
+        self.ret = ret
+    def gen_seq(self, ctx, met, obj, lvl):
+        if self.ret is not None and self.ret.t in types:
+            assert_latex_safe(obj.name, self.ret.name)
+            ctx.append(lvl, '\\umlcreatecall' +
+                    f'[dt=7,class={{{self.ret.t}}}]' +
+                    f'{{{obj.name}}}{{{self.ret.name}}}')
+            o = (self.ret.t, self.ret.name)
+            ctx.objs[o] = -1
+
+            type_name = self.e.type.name
+            # the created class may differ from the declared type and be
+            # unknown, so look it up without raising KeyError
+            met2 = None
+            if type_name in types:
+                met2 = types[type_name].get_method(type_name)
+            if met2 is not None:
+                met2.gen_seq(obj, lvl + 1, ctx)
+            else:
+                ctx.comment(lvl, f'cant find {type_name} constructor')
+
+            ctx.inc_comm_label()
+            lab = ctx.get_comm_label()
+            ctx.add_comm_edge((obj.name, self.ret.name), f'{lab}: <<create>>')
+
+# Represents a variable declaration
+class Var:
+    # t: str, the type of this variable
+    # name: str, the name of this variable
+    def __init__(self, t, name):
+        self.t = t
+        self.name = name
+
+# Represents a control flow structure
+class Control(Stmt):
+    def __init__(self):
+        pass
+
+# Represents an if statement
+class If(Control):
+    # cond: javalang expression of the condition
+    # body_if, body_else: list of Stmt objects
+    def __init__(self, cond):
+        self.cond = cond
+        self.body_if = []
+        self.body_else = []
+
+    def gen_seq(self, ctx, met, obj, lvl):
+        cond = expr_to_str(self.cond)
+
+        ctx.push_fragment()
+        for stmt in self.body_if:
+            stmt.gen_seq(ctx, met, obj, lvl + 1)
+            ctx.inc_comm_label()
+
+        # comment the fragment out if the branch produced no calls
+        pre = '' if ctx.has_any() else '% '
+        ctx.append_fragment(lvl, f'{pre}\\begin{{umlfragment}}' +
+                f'[type=alt, label={{{latex_escape(cond)}}}, inner xsep=20]')
+        ctx.pop_fragment()
+        if pre != '':
+            ctx.append(lvl, '% \\end{umlfragment}')
+            return
+
+        ctx.push_fragment()
+        for stmt in self.body_else:
+            stmt.gen_seq(ctx, met, obj, lvl + 1)
+            ctx.inc_comm_label()
+        pre = '' if ctx.has_any() else '% '
+        ctx.append_fragment(lvl, f'{pre}\\umlfpart[else]')
+        ctx.pop_fragment()
+
+        ctx.append(lvl, '\\end{umlfragment}')
+
+# Represents a for loop
+class For(Control):
+    # control: javalang ForControl or EnhancedForControl
+    # body: list of Stmt objects
+    def __init__(self, control):
+        self.control = control
+        self.body = []
+    def gen_seq(self, ctx, met, obj, lvl):
+        c = self.control
+        if type(c) is tr.EnhancedForControl:
+            var = c.var.declarators[0].name
+            control = f'for {var} in {expr_to_str(c.iterable)}'
+        elif type(c) is tr.ForControl:
+            # init may be None ('for (;;)') or a plain list of expressions
+            # when no variable is declared, so only read declarators if present
+            decls = getattr(c.init, 'declarators', None)
+            control = decls[0].name if decls else 'for'
+        else:
+            assert(False)
+
+        ctx.push_fragment()
+        for stmt in self.body:
+            stmt.gen_seq(ctx, met, obj, lvl + 1)
+            ctx.inc_comm_label()
+        pre = '' if ctx.has_any() else '% '
+        ctx.append_fragment(lvl, f'{pre}\\begin{{umlfragment}}' +
+                f'[type=loop, label={{{latex_escape(control)}}}, inner xsep=20]')
+        ctx.pop_fragment()
+        ctx.append(lvl, f'{pre}\\end{{umlfragment}}')
+
+# Represents a method declaration
+class Method:
+    # name: name of the method
+    # vars: local variables declared inside (name, Var)
+    # params: list of (type name, parameter name) tuples
+    # body: list of Stmt objects
+    # t: containing class
+    # doc: doc string
+    # annotations: javalang annotations of the declaration (or None)
+    def __init__(self, name):
+        self.name = name
+        self.vars = {}
+        self.params = []
+        self.body = []
+        self.t = None
+        self.doc = None
+        self.annotations = None
+
+    def excl(self):
+        """Returns True if this method's calls should be excluded from sequence
+        diagrams"""
+        # annotations stays None until dump_tree fills it in
+        for ann in self.annotations or []:
+            if ann.name == "NoSequenceDiagramGeneration": return True
+        return False
+
+    # generate sequence for method
+    # obj is a Var
+    def gen_seq(self, obj, lvl, ctx):
+        # generate for each statement
+        if not self.excl():
+            ctx.push_comm_label()
+            for stmt in self.body:
+                stmt.gen_seq(ctx, self, obj, lvl)
+                ctx.inc_comm_label()
+            ctx.pop_comm_label()
+
+# Represents a class declaration
+class Type:
+    # name: name of the class
+    # vars: member variables (name, Var)
+    # methods: methods and constructors (name, Method)
+    # extends: name of the super class (None if there is none)
+    def __init__(self, name):
+        self.name = name
+        self.vars = {}
+        self.methods = {}
+        self.extends = None
+
+    # Get method by name (recursively looks at super classes)
+    def get_method(self, name):
+        if name in self.methods:
+            return self.methods[name]
+        if self.extends is not None:
+            if self.extends in types:
+                return types[self.extends].get_method(name)
+        return None
+    # Get member by name (recursively looks at super classes)
+    def get_member(self, name):
+        if name in self.vars:
+            return self.vars[name]
+        if self.extends is not None:
+            if self.extends in types:
+                return types[self.extends].get_member(name)
+        return None
+
+# Processes a parsed java file, and puts the results in the types dict
+def dump_tree(t):
+    # dumps call Stmt objects in the expression e into the list body
+    # ret: a Var object
+    def dump_expr_calls(e, body, ret):
+        if type(e) is tr.MethodInvocation:
+            body.append(Call(e.qualifier, e.member, e.arguments, ret))
+        elif type(e) is tr.ClassCreator:
+            body.append(CreateCall(e, ret))
+
+    # dump Stmt objects from the statement s into the list body
+    # dump variable declarations into the dictionary v
+    def dump_stmt_calls(s, body, v):
+        if s is None:
+            pass
+        elif type(s) is tr.IfStatement:
+            stmt = If(s.condition)
+            dump_stmt_calls(s.then_statement, stmt.body_if, v)
+            dump_stmt_calls(s.else_statement, stmt.body_else, v)
+            body.append(stmt)
+        elif type(s) is tr.WhileStatement or type(s) is tr.DoStatement:
+            dump_expr_calls(s.condition, body, None)
+            dump_stmt_calls(s.body, body, v)
+        elif type(s) is tr.ForStatement:
+            stmt = For(s.control)
+            dump_stmt_calls(s.body, stmt.body, v)
+            if type(s.control) is tr.EnhancedForControl:
+                dump_stmt_calls(s.control.var, stmt.body, v) # dump var decl
+            body.append(stmt)
+        elif type(s) is tr.ReturnStatement or type(s) is tr.ThrowStatement:
+            dump_expr_calls(s.expression, body, None)
+        # TODO: SynchronizedStatement
+        elif type(s) is tr.TryStatement:
+            for st in s.block:
+                dump_stmt_calls(st, body, v)
+            # TODO: catches, finally_block
+        elif type(s) is tr.SwitchStatement:
+            dump_expr_calls(s.expression, body, None)
+            for case in s.cases:
+                for st in case.statements:
+                    dump_stmt_calls(st, body, v)
+        elif type(s) is tr.BlockStatement:
+            for st in s.statements:
+                dump_stmt_calls(st, body, v)
+        elif type(s) is tr.StatementExpression:
+            dump_expr_calls(s.expression, body, None)
+        elif issubclass(type(s), tr.VariableDeclaration):
+            for vd in s.declarators:
+                v[vd.name] = Var(s.type.name, vd.name)
+                dump_expr_calls(vd.initializer, body, v[vd.name])
+
+    # dump method into Method object met
+    def dump_method(method, met):
+        if method.body is None: return
+        for s in method.body:
+            dump_stmt_calls(s, met.body, met.vars)
+
+    for type_ in t.types:
+        if type(type_) is not tr.ClassDeclaration:
+            continue
+        cls = Type(type_.name)
+        types[type_.name] = cls
+        if type_.extends is not None:
+            cls.extends = type_.extends.name
+        for decl in type_.fields:
+            for vd in decl.declarators:
+                cls.vars[vd.name] = Var(decl.type.name, vd.name)
+        for method in type_.body:
+            if isinstance(method, (tr.MethodDeclaration,
+                                   tr.ConstructorDeclaration)):
+                met = Method(method.name)
+                met.t = cls
+                met.doc = method.documentation
+                met.annotations = method.annotations
+                cls.methods[method.name] = met
+                for param in method.parameters:
+                    tname = param.type.name
+                    vname = param.name
+                    met.params.append((tname, vname))
+                dump_method(method, met)
+
+ctx = None
+def main():
+    global ctx
+    # fetch & process java files
+    walk_dir(sys.argv[1])
+    for tree in trees:
+        dump_tree(tree)
+
+    # parse arguments
+    method = None
+    repl = {}
+    # gen_seq_diag.py <project root> diag:Type:method:obj; a -> b; ...
+    def parse_args():
+        nonlocal method
+        params = [i.strip() for i in sys.argv[2].split(';')]
+        method = params[0].split(':')
+        params = params[1:]
+        for i in params:
+            if not i: continue # tolerate a trailing ';'
+            a = i.split('->')
+            repl[a[0].strip()] = a[1].strip()
+    parse_args()
+
+    spec_diag = method[0]
+    spec_type = method[1]
+    spec_method = method[2]
+    spec_obj = method[3]
+
+    ctx = Ctx(repl)
+    (types[spec_type]
+        .methods[spec_method]
+        .gen_seq(Var(spec_type, spec_obj), 1, ctx)
+    )
+    ctx.objs[(spec_type, spec_obj)] = 0
+    sorted_objs = sorted(ctx.objs.items(), key = lambda i: i[1])
+
+    if spec_diag == 'seq':
+        print('\\begin{tikzpicture}\n\\begin{umlseqdiag}')
+        print('\\umlobject[class=Actor]{act}')
+        for (t, n), lvl in sorted_objs:
+            if lvl < 0: continue
+            # if t in ctx.repl: t = ctx.repl[t]
+            if n in types: # static object hack
+                print(f'\\umlbasicobject[fill=green!20]{{{t}}}')
+            else:
+                print(f'\\umlobject[class={t}]{{{n}}}')
+
+        args_str = latex_escape(', '.join(
+            [p[1] for p in types[spec_type].methods[spec_method].params]))
+        print(f'\\begin{{umlcall}}[op={{{spec_method}({args_str})}},' +
+                f'return={{\\ }}]{{act}}{{{spec_obj}}}')
+        print(''.join(ctx.s))
+        print('\\end{umlcall}')
+        print('\\end{umlseqdiag}\n\\end{tikzpicture}')
+    elif spec_diag == 'comm':
+        print('\\begin{tikzpicture}')
+        print(r"""
+        \newcommand{\umlcomm}[4]{
+            \draw (#1) -- node [above,align=center,sloped]{#3} node
+            [below,align=center,sloped]{#4} (#2);
+        }
+        """)
+        l = len(sorted_objs)
+        pos = {}
+        for i, ((t, n), lvl) in enumerate(sorted_objs):
+            ang = (0.005 + 0.5 + i / l) * math.pi * 2
+            r = l * 1.5
+            x = math.cos(ang) * r
+            y = math.sin(ang) * r
+            pos[n] = (x, y)
+            if t in ctx.repl: t = ctx.repl[t]
+            print(f'\\draw ({x},{y}) node[draw] ({n}) {{{n}:{t}}};')
+        for (a, b), (msgs, msgs_rev) in ctx.comm_edges.items():
+            if a == b: continue
+            msg = r'\\ '.join(msgs)
+            msg_rev = r'\\ '.join(msgs_rev)
+            if pos[a][0] > pos[b][0]: msg,msg_rev = msg_rev,msg
+            if len(msg) > 0:
+                msg += r'\\ $\rightarrow$'
+            if len(msg_rev) > 0:
+                msg_rev = r'$\leftarrow$\\ ' + msg_rev
+            print(f'\\umlcomm{{{a}}}{{{b}}}{{{msg}}}{{{msg_rev}}}')
+        print('\\end{tikzpicture}')
+
+# Report failures on stdout; the traceback is emitted as LaTeX comments
+try:
+    main()
+except Exception:
+    s = traceback.format_exc()
+    print("ERROR!\n")
+    print('%' + s.replace('\n', '\n%'))
+    print('% ctx dump:')
+    # ctx is still None when the failure happened before generation
+    if ctx is not None:
+        print('%' + ''.join(ctx.s).replace('\n', '\n%'))
diff --git a/plab/Makefile b/plab/Makefile
index 5222e679154eef6ac30b38788ec8e800dfaf7bbb..d177cf43ab8c60da5f0ac0d5d536b5dd52caec5b 100644
--- a/plab/Makefile
+++ b/plab/Makefile
@@ -2,5 +2,7 @@ build:
 	mkdir -p out
 	go build -v -race -ldflags "-linkmode external -extldflags '-static'" -a -o out/plab
 
+pwd = $(shell pwd)
+
 podman:
-	podman run --rm -it -v $PWD:/build golang:alpine sh -c "apk add make build-base; cd /build; make"
\ No newline at end of file
+	podman run --rm -it -v ${pwd}:/build golang:alpine sh -c "apk add make build-base; cd /build; make"
\ No newline at end of file