diff --git a/Dockerfile b/Dockerfile
index f2c7f7b787950e5d960a4203b17bee750b3fbdd1..9390111bb7c8fddfa24f7f68b366015e89efa1b3 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,7 +1,9 @@
 FROM texlive/texlive
-RUN export DEBIAN_FRONTEND=noninteractive && apt-get update && apt-get -y upgrade && apt-get -y install openjdk-11-jdk-headless hunspell hunspell-hu hunspell-en-gb hunspell-en-us
+RUN export DEBIAN_FRONTEND=noninteractive && apt-get update && apt-get -y upgrade && apt-get -y install openjdk-11-jdk-headless hunspell hunspell-hu hunspell-en-gb hunspell-en-us python3-pip && pip3 install javalang
 COPY tikz-uml.sty /usr/local/texlive/2020/texmf-dist/tex/latex/tikz-uml/tikz-uml.sty
 COPY jsonDoclet.jar /root/jsonDoclet.jar
 COPY ./plab/out/plab /usr/bin/plab
 #ADD robinbird/build/distributions/robinbird.tar /
 RUN mktexlsr
+COPY gen_seq_diag.py /
+RUN chmod +x /gen_seq_diag.py
\ No newline at end of file
diff --git a/gen_seq_diag.py b/gen_seq_diag.py
new file mode 100755
index 0000000000000000000000000000000000000000..9f9cc33df9b7fa9df7c4c104bb534abee02c7be2
--- /dev/null
+++ b/gen_seq_diag.py
@@ -0,0 +1,548 @@
+#!/usr/bin/env python3
+
+# Imports
+import javalang
+import javalang.tree as tr
+import os
+import sys
+import math
+import traceback
+
+# types: the main data structure (key: name, val: a Type object)
+types = {}
+iden = 4 * ' '
+
+# walk_dir finds and parses all java files (recursively), and stores them in
+# tree
+trees = []
+
+
+def fetch_file(p):
+    with open(p, 'r', encoding='utf-8') as f:
+        s = f.read()
+    tree = javalang.parse.parse(s)
+    return tree
+
+
+def walk_dir(p):
+    for root, dirs, files in os.walk(p):
+        for f in files:
+            assert(f.endswith('.java'))
+            fp = os.path.join(root, f)
+            trees.append(fetch_file(fp))
+
+# Stringify an tr.Expression (very limited, intentionally)
+def expr_to_str(e):
+    if type(e) is tr.BinaryOperation:
+        return (f'{expr_to_str(e.operandl)} ' +
+            f'{e.operator} {expr_to_str(e.operandr)}')
+    elif type(e) is tr.MemberReference:
+        return f'{"".join(e.prefix_operators)}{e.member}'
+    elif type(e) is tr.Cast:
+        return expr_to_str(e.expression) # ignore casts
+    elif type(e) is tr.ReferenceType:
+        return e.name
+    elif type(e) is tr.Literal:
+        return e.value
+    else:
+        assert False, str(e)
+
+def latex_escape(s):
+    return (s.replace("_", "\\_")
+        .replace("&", "\\&")
+        .replace("{", "\\{")
+        .replace("}", "\\}"))
+def assert_latex_safe(*args):
+    for arg in args:
+        for c in '_&{}':
+            if c in arg:
+                print(f"Error: '{arg}' has invalid latex characters")
+                assert(False)
+
+# Context for sequence diagram generation
+class Ctx:
+    def __init__(self, repl):
+        self.objs = {} # collect any touched objects
+        self.s = [] # result string
+        self.fragments = [] # fragment stack
+        self.vis = [] # visited list
+        self.repl = repl # type replacement dict
+        self.args = [] # function argument stack
+        self.comm_edges = {}
+        self.comm_label_stack = []
+    def append(self, lvl, a):
+        self.s.append(lvl * iden + a + '\n')
+    def comment(self, lvl, a):
+        self.s.append(lvl * iden + f'% {a}\n')
+
+    def push_fragment(self):
+        self.fragments.append(self.s)
+        self.s = []
+    def has_any(self):
+        return any([i.strip()[0] != '%' for i in self.s])
+    def append_fragment(self, lvl, a):
+        self.fragments[-1].append(lvl * iden + a + '\n')
+    def pop_fragment(self):
+        fragment = self.fragments.pop()
+        fragment.extend(self.s)
+        self.s = fragment
+
+    # edge is a tuple with the start and end object names
+    def add_comm_edge(self, edge, msg):
+        rev = (edge[1], edge[0])
+        if edge in self.comm_edges:
+            self.comm_edges[edge][0].append(msg)
+        elif rev in self.comm_edges:
+            self.comm_edges[rev][1].append(msg)
+        else:
+            self.comm_edges[edge] = ([msg],[])
+    def get_comm_label(self):
+        return '.'.join([str(i) for i in self.comm_label_stack])
+    def push_comm_label(self):
+        self.comm_label_stack.append(1)
+    def inc_comm_label(self):
+        self.comm_label_stack[-1] += 1
+    def pop_comm_label(self):
+        self.comm_label_stack.pop()
+
+
+# Represents a statement
+class Stmt:
+    def gen_seq(self, ctx, met, obj, lvl):
+        assert(False)
+
+# Represents a method invocation
+class Call(Stmt):
+    # args: list of javalang expressions
+    # ret: Var representing where the returned value goes
+    # qual: the object the call is being made on (str)
+    # member: the member function we are calling (str)
+    def __init__(self, qual, member, args, ret):
+        self.qual = qual
+        self.member = member
+        self.args = args
+        self.ret = ret
+    # convert each argument to a sensible str
+    def clean_args(self, this_ref):
+        res = []
+        for e in self.args:
+            if type(e) is tr.This:
+                res.append(this_ref)
+            else:
+                res.append(expr_to_str(e))
+        return res
+
+    # met: the Method object we are in the body of
+    # obj: the object the call is being made on (Var)
+    # lvl: indentation level
+    def gen_seq(self, ctx, met, obj, lvl):
+        call = self
+        var = None # Var we are making the call on
+        met2 = None # The method we are calling (Method)
+        if call.qual == '': # self call
+            var = obj
+            met2 = met.t.get_method(call.member)
+            if met2 is None:
+                ctx.comment(lvl, f"can't find {call.member}");
+                return
+        else: # not self call
+            # find out what we are calling this method on!
+            member = met.t.get_member(call.qual)
+            if member is not None: # is it our own member variable?
+                var = Var(member.t, f'{obj.name}/{call.qual}')
+            elif call.qual in met.vars: # is it a local variable?
+                var = met.vars[call.qual]
+            elif len(ctx.args) > 0: # is it a function argument?
+                sargs = ctx.args[-1]
+                for i,(t,n) in enumerate(met.params):
+                    if n == call.qual:
+                        if (i >= len(sargs)):
+                            break
+                        var = Var(t, sargs[i])
+                        break
+            if var is None:
+                if call.qual in types: # is it a static method of a type?
+                    var = Var(call.qual, f'{call.qual}')
+            if var is None:
+                ctx.comment(lvl, f'cant find var {call.qual}.{call.member}')
+                return
+            t = var.t
+            if t in ctx.repl: t = ctx.repl[t]
+            if t in types:
+                met2 = types[t].get_method(call.member)
+            else:
+                ctx.comment(lvl,
+                    f'cant find type {var.t} ({call.qual}.{call.member})')
+                return
+
+        # check for recursion
+        id_ = (met2.t.name, met2.name)
+        if id_ in ctx.vis:
+            ctx.comment(lvl, f'recur {call.qual}.{call.member}')
+            return
+        ctx.vis.append(id_)
+
+        # track referenced objects
+        t = var.t
+        if t in ctx.repl: t = ctx.repl[t]
+        o = (t, var.name)
+        if o not in ctx.objs:
+            ctx.objs[o] = lvl
+
+        edge = (obj.name, var.name)
+        args = call.clean_args(obj.name)
+        args_str = latex_escape(', '.join(args))
+
+        lab = ctx.get_comm_label()
+        ctx.add_comm_edge(edge, f'{lab}: {call.member}({args_str})')
+
+        ret = ''
+        if call.ret: ret = call.ret.name
+        if ret == '': ret = '\\ '
+
+        assert_latex_safe(edge[0], edge[1])
+        ctx.args.append(args)
+        ctx.append(lvl, '\\begin{umlcall}' +
+                '[' +
+                    f'op={{{call.member}({args_str})}},' +
+                    f'return={{{ret}}},' +
+                    f'dt=7' +
+                ']' +
+                f'{{{edge[0]}}}{{{edge[1]}}}')
+        met2.gen_seq(var, lvl + 1, ctx)
+        ctx.append(lvl, '\\end{umlcall}')
+
+        if call.ret is not None:
+            ctx.inc_comm_label()
+            lab = ctx.get_comm_label()
+            ctx.add_comm_edge((edge[1], edge[0]), f'{lab}: {call.ret.name}')
+
+        ctx.args.pop()
+        ctx.vis.pop()
+
+class CreateCall(Stmt):
+    def __init__(self, e, ret):
+        self.e = e
+        self.ret = ret
+    def gen_seq(self, ctx, met, obj, lvl):
+        if self.ret is not None and self.ret.t in types:
+            assert_latex_safe(obj.name, self.ret.name)
+            ctx.append(lvl, '\\umlcreatecall' +
+                    f'[dt=7,class={{{self.ret.t}}}]' +
+                    f'{{{obj.name}}}{{{self.ret.name}}}')
+            o = (self.ret.t, self.ret.name)
+            ctx.objs[o] = -1
+
+            type_name = self.e.type.name
+            met2 = types[type_name].get_method(type_name)
+            if met2 is not None:
+                met2.gen_seq(obj, lvl + 1, ctx)
+            else:
+                ctx.comment(lvl, f'cant find {type_name} constructor');
+
+            ctx.inc_comm_label()
+            lab = ctx.get_comm_label()
+            ctx.add_comm_edge((obj.name, self.ret.name), f'{lab}: <<create>>')
+
+# Represents a variable declaration
+class Var:
+    # t: str, the type of this variable
+    # name: str, the name of this variable
+    def __init__(self, t, name):
+        self.t = t
+        self.name = name
+
+# Represents a control flow structure
+class Control(Stmt):
+    def __init__(self):
+        pass
+
+# Represents an if statement
+class If(Control):
+    # body_if, body_else: list of Stmt objects
+    def __init__(self, cond):
+        self.cond = cond
+        self.body_if = []
+        self.body_else = []
+
+    def gen_seq(self, ctx, met, obj, lvl):
+        cond = expr_to_str(self.cond)
+
+        ctx.push_fragment()
+        for stmt in self.body_if:
+            stmt.gen_seq(ctx, met, obj, lvl + 1)
+            ctx.inc_comm_label()
+
+        pre = '' if ctx.has_any() else '% '
+        ctx.append_fragment(lvl, f'{pre}\\begin{{umlfragment}}' +
+                f'[type=alt, label={{{latex_escape(cond)}}}, inner xsep=20]')
+        ctx.pop_fragment()
+        if pre != '':
+            ctx.append(lvl, '% \\end{umlfragment}')
+            return
+
+        ctx.push_fragment()
+        for stmt in self.body_else:
+            stmt.gen_seq(ctx, met, obj, lvl + 1)
+            ctx.inc_comm_label()
+        pre = '' if ctx.has_any() else '% '
+        ctx.append_fragment(lvl, f'{pre}\\umlfpart[else]')
+        ctx.pop_fragment()
+
+        ctx.append(lvl, '\\end{umlfragment}')
+
+# Represents a for loop
+class For(Control):
+    def __init__(self, control):
+        self.control = control
+        self.body = []
+    def gen_seq(self, ctx, met, obj, lvl):
+        c = self.control
+        if type(c) is tr.EnhancedForControl:
+            var = c.var.declarators[0].name
+            control = f'for {var} in {expr_to_str(c.iterable)}'
+        elif type(c) is tr.ForControl:
+            control = c.init.declarators[0].name
+        else:
+            assert(False)
+
+        ctx.push_fragment()
+        for stmt in self.body:
+            stmt.gen_seq(ctx, met, obj, lvl + 1)
+            ctx.inc_comm_label()
+        pre = '' if ctx.has_any() else '% '
+        ctx.append_fragment(lvl, f'{pre}\\begin{{umlfragment}}' +
+                f'[type=loop, label={{{control}}}, inner xsep=20]')
+        ctx.pop_fragment()
+        ctx.append(lvl, f'{pre}\\end{{umlfragment}}')
+
+# Represents a method declaration
+class Method:
+    # name: name of the method
+    # vars: local variables declared inside (name, Var)
+    # body: list of Stmt objects
+    # t: containing class
+    # doc: doc string
+    def __init__(self, name):
+        self.name = name
+        self.vars = {}
+        self.params = []
+        self.body = []
+        self.t = None
+        self.doc = None
+        self.annotations = None
+
+    def excl(self):
+        """Returns True if this methods calls should be exclued from sequence
+        diagrams"""
+        for ann in self.annotations:
+            if ann.name == "NoSequenceDiagramGeneration": return True
+        return False
+
+    # generate sequence for method
+    # obj is a Var
+    def gen_seq(self, obj, lvl, ctx):
+        # generate for each statement
+        if not self.excl():
+            ctx.push_comm_label()
+            for i, stmt in enumerate(self.body):
+                stmt.gen_seq(ctx, self, obj, lvl)
+                ctx.inc_comm_label()
+            ctx.pop_comm_label()
+
+# Represents a class declaration
+class Type:
+    def __init__(self, name):
+        self.name = name
+        self.vars = {}
+        self.methods = {}
+        self.extends = None
+
+    # Get method by name (recursively looks at super classes)
+    def get_method(self, name):
+        if name in self.methods:
+            return self.methods[name]
+        if self.extends is not None:
+            if self.extends in types:
+                return types[self.extends].get_method(name)
+        return None
+    # Get member by name (recursively looks at super classes)
+    def get_member(self, name):
+        if name in self.vars:
+            return self.vars[name]
+        if self.extends is not None:
+            if self.extends in types:
+                return types[self.extends].get_member(name)
+
+# Processes a parsed java file, and puts the results in the types dict
+def dump_tree(t):
+    # dumps call Stmt objects in the expression e into the list body
+    # ret: a Var object
+    def dump_expr_calls(e, body, ret):
+        if type(e) is tr.MethodInvocation:
+            body.append(Call(e.qualifier, e.member, e.arguments, ret))
+        elif type(e) is tr.ClassCreator:
+            body.append(CreateCall(e, ret))
+
+    # dump Stmt objects from the statement s into the list body
+    # dump variable declarations into the dictionary v
+    def dump_stmt_calls(s, body, v):
+        if s is None:
+            pass
+        elif type(s) is tr.IfStatement:
+            stmt = If(s.condition)
+            dump_stmt_calls(s.then_statement, stmt.body_if, v)
+            dump_stmt_calls(s.else_statement, stmt.body_else, v)
+            body.append(stmt)
+        elif type(s) is tr.WhileStatement or type(s) is tr.DoStatement:
+            dump_expr_calls(s.condition, body, None)
+            dump_stmt_calls(s.body, body, v)
+        elif type(s) is tr.ForStatement:
+            stmt = For(s.control)
+            dump_stmt_calls(s.body, stmt.body, v)
+            if type(s.control) is tr.EnhancedForControl:
+                dump_stmt_calls(s.control.var, stmt.body, v) # dump var decl
+            body.append(stmt)
+        elif type(s) is tr.ReturnStatement or type(s) is tr.ThrowStatement:
+            dump_expr_calls(s.expression, body, None)
+        # TODO: SynchronizedStatement
+        elif type(s) is tr.TryStatement:
+            for st in s.block:
+                dump_stmt_calls(st, body, v)
+            # TODO: catches, finally_block
+        elif type(s) is tr.SwitchStatement:
+            dump_expr_calls(s.expression, body, None)
+            for case in s.cases:
+                for st in case.statements:
+                    dump_stmt_calls(st, body, v)
+        elif type(s) is tr.BlockStatement:
+            for st in s.statements:
+                dump_stmt_calls(st, body, v)
+        elif type(s) is tr.StatementExpression:
+            dump_expr_calls(s.expression, body, None)
+        elif issubclass(type(s), tr.VariableDeclaration):
+            for vd in s.declarators:
+                v[vd.name] = Var(s.type.name, vd.name)
+                dump_expr_calls(vd.initializer, body, v[vd.name])
+
+    # dump method into Method object met
+    def dump_method(method, met):
+        if method.body is None: return
+        for s in method.body:
+            dump_stmt_calls(s, met.body, met.vars)
+
+    for type_ in t.types:
+        if type(type_) is not tr.ClassDeclaration:
+            continue
+        t = Type(type_.name)
+        types[type_.name] = t
+        if type_.extends is not None:
+            t.extends = type_.extends.name
+        for decl in type_.fields:
+            for vd in decl.declarators:
+                t.vars[vd.name] = Var(decl.type.name, vd.name)
+        for method in type_.body:
+            if isinstance(method, tr.MethodDeclaration) or isinstance(method,
+                    tr.ConstructorDeclaration):
+                met = Method(method.name)
+                met.t = t
+                met.doc = method.documentation
+                met.annotations = method.annotations
+                t.methods[method.name] = met
+                for param in method.parameters:
+                    tname = param.type.name
+                    vname = param.name
+                    met.params.append((tname, vname))
+                dump_method(method, met)
+
+ctx = None
+def main():
+    global ctx
+    # fetch & process java files
+    walk_dir(sys.argv[1])
+    for tree in trees:
+        dump_tree(tree)
+
+    # parse arguments
+    method = None
+    repl = {}
+    # gen_seq_diag.py <project root> diag:Type:method:obj; a -> b; ...
+    def parse_args():
+        nonlocal method
+        params = list([i.strip() for i in sys.argv[2].split(';')])
+        method = params[0].split(':')
+        params = params[1:]
+        for i in params:
+            a = i.split('->')
+            repl[a[0].strip()] = a[1].strip()
+    parse_args()
+
+    spec_diag = method[0]
+    spec_type = method[1]
+    spec_method = method[2]
+    spec_obj = method[3]
+
+    ctx = Ctx(repl)
+    (types[spec_type]
+        .methods[spec_method]
+        .gen_seq(Var(spec_type, spec_obj), 1, ctx)
+    )
+    ctx.objs[(spec_type, spec_obj)] = 0
+    sorted_objs = sorted(ctx.objs.items(), key = lambda i: i[1])
+
+    if spec_diag == 'seq':
+        print('\\begin{tikzpicture}\n\\begin{umlseqdiag}')
+        print('\\umlobject[class=Actor]{act}')
+        for (t, n), lvl in sorted_objs:
+            if lvl < 0: continue
+            # if t in ctx.repl: t = ctx.repl[t]
+            if n in types: # static object hack
+                print(f'\\umlbasicobject[fill=green!20]{{{t}}}')
+            else:
+                print(f'\\umlobject[class={t}]{{{n}}}')
+
+        args_str = latex_escape(', '.join(
+            [p[1] for p in types[spec_type].methods[spec_method].params]))
+        print(f'\\begin{{umlcall}}[op={{{spec_method}({args_str})}},' +
+            f'return={{\\ }}]{{act}}{{{spec_obj}}}')
+        print(''.join(ctx.s))
+        print('\\end{umlcall}')
+        print('\\end{umlseqdiag}\n\\end{tikzpicture}')
+    elif spec_diag == 'comm':
+        print('\\begin{tikzpicture}')
+        print(r"""
+        \newcommand{\umlcomm}[4]{
+            \draw (#1) -- node [above,align=center,sloped]{#3} node
+                [below,align=center,sloped]{#4} (#2);
+        }
+        """)
+        l = len(sorted_objs)
+        pos = {}
+        for i, ((t, n), lvl) in enumerate(sorted_objs):
+            ang = (0.005 + 0.5 + i / l) * math.pi * 2
+            r = l * 1.5
+            x = math.cos(ang) * r
+            y = math.sin(ang) * r
+            pos[n] = (x, y)
+            if t in ctx.repl: t = ctx.repl[t]
+            print(f'\\draw ({x},{y}) node[draw] ({n}) {{{n}:{t}}};')
+        for (a, b), (msgs, msgs_rev) in ctx.comm_edges.items():
+            if a == b: continue
+            msg = r'\\ '.join(msgs)
+            msg_rev = r'\\ '.join(msgs_rev)
+            if pos[a][0] > pos[b][0]: msg,msg_rev = msg_rev,msg
+            if len(msg) > 0:
+                msg += r'\\ $\rightarrow$'
+            if len(msg_rev) > 0:
+                msg_rev = r'$\leftarrow$\\ ' + msg_rev
+            print(f'\\umlcomm{{{a}}}{{{b}}}{{{msg}}}{{{msg_rev}}}')
+        print('\\end{tikzpicture}')
+
+try:
+    main()
+except:
+    s = traceback.format_exc()
+    print("ERROR!\n")
+    print('%' + s.replace('\n', '\n%'))
+    print('% ctx dump:')
+    print('%' + ''.join(ctx.s).replace('\n', '\n%'))
diff --git a/plab/Makefile b/plab/Makefile
index 5222e679154eef6ac30b38788ec8e800dfaf7bbb..d177cf43ab8c60da5f0ac0d5d536b5dd52caec5b 100644
--- a/plab/Makefile
+++ b/plab/Makefile
@@ -2,5 +2,7 @@ build:
 	mkdir -p out
 	go build -v -race -ldflags "-linkmode external -extldflags '-static'" -a -o out/plab
 
+pwd = $(shell pwd)
+
 podman:
-	podman run --rm -it -v $PWD:/build golang:alpine sh -c "apk add make build-base; cd /build; make"
\ No newline at end of file
+	podman run --rm -it -v ${pwd}:/build golang:alpine sh -c "apk add make build-base; cd /build; make"
\ No newline at end of file