import logging
import uuid
import copy
import importlib.metadata
from .flat_model import *
from .model import *
from .model_utils import *

logger = logging.getLogger(__name__)
__version__ = importlib.metadata.version("ufsm-generate")

def _emit(f, indent=0, output=""):
    f.write(" " * indent * 4 + output + "\n")


def _nl(f):
    """Write a single empty line to ``f``."""
    f.write("\n")

def _emit_sq_push(f, indent, signal, sq_length):
    """Emit C code that appends ``signal`` to the machine's signal queue.

    The generated code returns ``-UFSM_SIGNAL_QUEUE_FULL`` when the queue has
    no free slot, otherwise it stores the signal and bumps ``sq_count``.

    Args:
        f: Output file object for the generated C source.
        indent: Indentation level of the emitted statements.
        signal: Value written into the queue slot (callers pass the signal
            index).
        sq_length: Capacity of the queue, i.e. the size used for the ``sq``
            array in the generated machine struct.
    """
    # Valid slots are 0 .. sq_length - 1, so the queue is already full when
    # sq_count == sq_length; '>' would let one write land past the array.
    _emit(f, indent, f"if(m->sq_count >= {sq_length})")
    _emit(f, indent + 1, "return -UFSM_SIGNAL_QUEUE_FULL;")
    _emit(f, indent, f"m->sq[m->sq_count++] = {signal};")

def _generate_header_head(hmodel, fmodel, fh):
    """Emit the top of the generated header file.

    Writes the version banner, opens the include guard derived from the
    upper-cased model name and defines the uFSM status codes.
    """
    _emit(fh, 0, f"/* Autogenerated with uFSM {__version__} */")
    guard_macro = f"UFSM_{hmodel.name.upper()}"
    remaining_lines = (
        f"#ifndef {guard_macro}",
        f"#define {guard_macro}",
        "",
        "#define UFSM_OK 0",
        "#define UFSM_BAD_ARGUMENT 1",
        "#define UFSM_SIGNAL_QUEUE_FULL 2",
    )
    for text in remaining_lines:
        _emit(fh, 0, text)


def _generate_header_foot(hmodel, fmodel, fh):
    """Emit a blank line and the ``#endif`` that closes the include guard."""
    _nl(fh)
    guard_macro = f"UFSM_{hmodel.name.upper()}"
    _emit(fh, 0, f"#endif  // {guard_macro}")


def _gen_c_head(hmodel, fmodel, sq_length, f):
    """Emit the top of the generated C file: version banner and header include."""
    banner = f"/* Autogenerated with uFSM {__version__} */"
    _emit(f, 0, banner)
    _nl(f)
    include_directive = f'#include "{hmodel.name}.h"'
    _emit(f, 0, include_directive)
    _nl(f)


def _gen_events(hmodel, f):
    """Emit ``#define`` lines for the built-in events and every model event."""
    _nl(f)
    builtin_lines = (
        "/* Events */",
        "#define UFSM_RESET 0",
        "#define UFSM_AUTO_TRANSITION 1",
    )
    for text in builtin_lines:
        _emit(f, 0, text)
    # One define per event of the model: name -> index.
    for ev in hmodel.events.values():
        _emit(f, 0, f"#define {ev.name} {ev.index}")


def _gen_signals(hmodel, f):
    """Emit one ``#define <name> <index>`` line per signal of the model."""
    _nl(f)
    _emit(f, 0, "/* Signals */")
    for sig in hmodel.signals.values():
        define_line = f"#define {sig.name} {sig.index}"
        _emit(f, 0, define_line)


def _gen_guard_protos(hmodel, f):
    """Emit a C prototype ``int <name>(void *user);`` for every guard."""
    _nl(f)
    _emit(f, 0, "/* Guard prototypes */")
    for guard_fn in hmodel.guards.values():
        prototype = f"int {guard_fn.name}(void *user);"
        _emit(f, 0, prototype)


def _gen_action_protos(hmodel, f):
    """Emit a C prototype ``void <name>(void *user);`` for every action."""
    _nl(f)
    _emit(f, 0, "/* Action prototypes */")
    for action_fn in hmodel.actions.values():
        prototype = f"void {action_fn.name}(void *user);"
        _emit(f, 0, prototype)


def _gen_machine_struct(hmodel, sq_length, f):
    """Emit the machine struct and the prototypes of its init/process functions.

    The struct holds the ``csv`` and ``wsv`` arrays (one slot per entry in
    ``hmodel.regions``), the signal queue members when ``sq_length`` is
    positive, and the opaque ``user`` pointer.
    """
    region_count = len(hmodel.regions.items())
    machine_type = f"struct {hmodel.name}_machine"

    members = [
        f"unsigned int csv[{region_count}];",
        f"unsigned int wsv[{region_count}];",
    ]
    if sq_length > 0:
        # Signal queue storage is only present when a queue length is set.
        members.append(f"unsigned int sq[{sq_length}];")
        members.append("unsigned int sq_count;")
    members.append("void *user;")

    _nl(f)
    _emit(f, 0, f"{machine_type} {{")
    for member in members:
        _emit(f, 1, member)
    _emit(f, 0, "};")

    _nl(f)
    _emit(f, 0, f"int {hmodel.name}_init({machine_type} *m, void *user);")
    _emit(f, 0, f"int {hmodel.name}_process({machine_type} *m, unsigned int event);")


def _sc_expr_helper(rule):
    result = " && ".join(
        f"(m->csv[{s.parent.index}] == {s.index})" for s in rule.csv_states
    )
    result += " && ".join(
        f"(m->wsv[{s.parent.index}] == {s.index})" for s in rule.wsv_states
    )
    if rule.invert:
        result = f"!({result})"
    return result


def _guard_expr_inner(g):
    """Return the C expression for a single guard.

    The guard function is called with ``m->user``; depending on ``g.kind``
    the call is used as-is, compared against ``0`` or compared against
    ``g.value``. Raises ``Exception`` for an unrecognised kind.
    """
    kind = g.kind
    if kind == UFSMM_GUARD_TRUE:
        comparison = ""
    elif kind == UFSMM_GUARD_FALSE:
        comparison = " == 0"
    else:
        # Remaining kinds compare the guard's return value against g.value.
        value_operators = (
            (UFSMM_GUARD_EQ, "=="),
            (UFSMM_GUARD_GT, ">"),
            (UFSMM_GUARD_GTE, ">="),
            (UFSMM_GUARD_LT, "<"),
            (UFSMM_GUARD_LTE, "<="),
        )
        for candidate, operator in value_operators:
            if kind == candidate:
                comparison = f" {operator} {g.value}"
                break
        else:
            raise Exception("Unknown guard kind")
    return f"{g.guard.name}(m->user){comparison}"


def _guard_expr_helper(guards):
    """AND together the C expressions of all ``guards`` into one string."""
    return " && ".join(map(_guard_expr_inner, guards))


def _gen_transition_exits(hmodel, fmodel, f, ft, sq_length, indent):
    """Emit the exit actions of flat transition ``ft``.

    Exit steps without actions are skipped. When a step carries a non-empty
    rule, its actions are wrapped in an ``if`` on that rule's expression.
    """
    _emit(f, indent, "/* Exit actions */")
    for exit_step in ft.exits:
        if not len(exit_step.actions):
            continue

        conditional = len(exit_step.rule) > 0
        body_indent = indent + 1 if conditional else indent
        if conditional:
            _emit(f, indent, f"if ({_sc_expr_helper(exit_step.rule)}) {{")

        for act in exit_step.actions:
            if isinstance(act, ActionFunction):
                _emit(f, body_indent, f"{act.action.name}(m->user);")
            elif isinstance(act, ActionSignal):
                _emit_sq_push(f, body_indent, act.signal.index, sq_length)

        if conditional:
            _emit(f, indent, "}")


def _gen_transition_actions(hmodel, fmodel, f, ft, sq_length, indent):
    """Emit the transition's own actions: function calls and signal pushes."""
    _emit(f, indent, "/* Actions */")
    for act in ft.actions:
        if isinstance(act, ActionFunction):
            call = f"{act.action.name}(m->user);"
            _emit(f, indent, call)
            continue
        if isinstance(act, ActionSignal):
            _emit_sq_push(f, indent, act.signal.index, sq_length)


def _gen_transition_entries(hmodel, fmodel, f, ft, sq_length, indent):
    """Emit the entry part of flat transition ``ft``.

    For every entry step in ``ft.entries`` this writes the ``m->wsv``
    assignments for the step's targets followed by the step's entry actions.
    A step with a non-empty rule is wrapped in an ``if`` on that rule; a step
    whose rule has ``history`` set additionally guards the assignment with a
    check of the corresponding ``m->csv`` slot.

    If any target has an outgoing transition triggered by
    ``AutoTransitionTrigger`` (directly, or via the first transition of a
    ``Join`` destination), ``process_auto_transition = 1;`` is emitted last.
    """
    _emit(f, indent, "/* Entry actions */")
    # Becomes True as soon as one entered target needs trigger-less processing.
    auto_transition = False

    for en in ft.entries:
        # Extra nesting relative to 'indent' for the code emitted in this step.
        indent_extra = 0
        if len(en.rule) > 0:
            _emit(f, indent, f"if ({_sc_expr_helper(en.rule)}) {{")
            indent_extra = 1
        else:
            indent_extra = 0
        for t in en.targets:
            if en.rule.history:
                # History: only re-enter the target if it is the state
                # currently recorded in csv for its region.
                _emit(
                    f,
                    indent + indent_extra,
                    f"if (m->csv[{t.parent.index}] == {t.index}) {{",
                )
                # NOTE(review): this opens one block per target, but only one
                # closing brace is emitted after the actions below - confirm
                # that history steps never carry more than one target.
                indent_extra += 1
                _emit(
                    f, indent + indent_extra, f"m->wsv[{t.parent.index}] = {t.index};"
                )
            else:
                _emit(
                    f,
                    indent + indent_extra,
                    f"m->wsv[{t.parent.index}] = {t.index}; // {t.parent} = {t}",
                )
            # If the target state has an outbound, trigger-less transition,
            # remember it so that 'process_auto_transition' gets set and the
            # trigger-less loop in the process function runs.
            for trans in t.transitions:
                if isinstance(trans.trigger, AutoTransitionTrigger):
                    auto_transition = True
                if isinstance(trans.dest, Join):
                    # NOTE(review): indexes transitions[0] unconditionally -
                    # presumably a Join always has an outgoing transition.
                    if isinstance(trans.dest.transitions[0].trigger, AutoTransitionTrigger):
                        logger.debug("Target is join with auto-trigger")
                        auto_transition = True

        # Entry actions run at the current nesting level (inside the history
        # check when one was opened).
        for a in en.actions:
            if isinstance(a, ActionFunction):
                _emit(f, indent + indent_extra, f"{a.action.name}(m->user);")
            elif isinstance(a, ActionSignal):
                _emit_sq_push(f, indent + indent_extra, a.signal.index, sq_length)
        if en.rule.history:
            # Close the history check.
            indent_extra -= 1
            _emit(f, indent + indent_extra, "}")
        if len(en.rule) > 0:
            # Close the rule condition.
            _emit(f, indent, "}")

    if auto_transition:
        _emit(f, indent, "process_auto_transition = 1;")


def _gen_transition_inner(hmodel, fmodel, f, ft, rules, sq_length, indent):
    """Emit nested ``if`` blocks for ``rules`` and, innermost, the transition body.

    Each rule opens one ``if``; the function recurses for the remaining
    rules. Inside the last rule an optional guard ``if`` is emitted, followed
    by the exit actions, transition actions and entry actions of ``ft``.
    """
    _emit(f, indent, f"if ({_sc_expr_helper(rules[0])}) {{")

    if len(rules) > 1:
        # More rules to test: nest them one level deeper.
        _gen_transition_inner(
            hmodel, fmodel, f, ft, rules[1:], sq_length, indent + 1
        )
    else:
        guarded = len(ft.guards) > 0
        guard_level = indent + 1
        body_level = guard_level
        if guarded:
            _emit(f, guard_level, f"if ({_guard_expr_helper(ft.guards)}) {{")
            body_level = guard_level + 1

        for emit_part in (
            _gen_transition_exits,
            _gen_transition_actions,
            _gen_transition_entries,
        ):
            emit_part(hmodel, fmodel, f, ft, sq_length, body_level)

        if guarded:
            _emit(f, guard_level, "}")

    _emit(f, indent, "}")


def _gen_reset_vector(hmodel, fmodel, f, sq_length):
    """Emit the body of the reset case from ``fmodel.isv``.

    Writes one ``m->wsv`` assignment per state in ``fmodel.isv.states`` and
    then the actions listed in ``fmodel.isv.actions``.
    """
    level = 3  # nesting depth of a 'case' body inside the process function
    for state in fmodel.isv.states:
        region = state.parent
        _emit(f, level, f"m->wsv[{region.index}] = {state.index}; // {region} = {state}")
    for act in fmodel.isv.actions:
        if isinstance(act, ActionFunction):
            _emit(f, level, f"{act.action.name}(m->user);")
        elif isinstance(act, ActionSignal):
            _emit_sq_push(f, level, act.signal.index, sq_length)

def _gen_transition(hmodel, fmodel, f, ft, sq_length, indent):
    """Emit one flat transition: a source/destination comment, then its code."""
    label = f"/* {ft.source.name} -> {ft.dest.name} */"
    _emit(f, indent, label)
    _gen_transition_inner(
        hmodel, fmodel, f, ft, ft.rules, sq_length, indent
    )


def _gen_process_func(hmodel, fmodel, sq_length, f):
    """Emit the C definitions of ``<name>_init`` and ``<name>_process``.

    ``<name>_init`` clears the ``wsv``/``csv`` arrays, resets the signal queue
    counter when a queue is configured, and stores the user pointer.

    ``<name>_process`` clears ``wsv``, dispatches on the event with a
    ``switch`` (reset vector, one case per model event, plus the
    auto-transition case when the model has auto transitions), copies every
    non-zero ``wsv`` slot into ``csv``, loops back for trigger-less
    transitions, and finally walks the signal queue when the model declares
    signals.

    Args:
        hmodel: Model providing name, events, signals and region counts.
        fmodel: Flat model providing the reset vector (``isv``) and
            ``transition_schedule``.
        sq_length: Signal queue capacity; ``0`` means no queue members exist
            in the machine struct.
        f: Output file object for the generated C source.
    """
    region_count = hmodel.no_of_regions
    has_auto_transitions = hmodel.no_of_auto_transitions > 0

    # ---- <name>_init -----------------------------------------------------
    _nl(f)
    _emit(f, 0, f"int {hmodel.name}_init(struct {hmodel.name}_machine *m, void *user)")
    _emit(f, 0, "{")
    _emit(f, 1, f"for (unsigned int i = 0; i < {region_count}; i++) {{")
    _emit(f, 2, "m->wsv[i] = 0;")
    _emit(f, 2, "m->csv[i] = 0;")
    _emit(f, 1, "}")
    if sq_length > 0:
        _emit(f, 1, "m->sq_count = 0;")
    # The user pointer is passed to every guard and action, so it must be
    # stored regardless of whether a signal queue is configured.
    _emit(f, 1, "m->user = user;")
    _emit(f, 1, "return 0;")
    _emit(f, 0, "}")

    # ---- <name>_process --------------------------------------------------
    _nl(f)
    _emit(
        f,
        0,
        f"int {hmodel.name}_process(struct {hmodel.name}_machine *m, unsigned int event)",
    )
    _emit(f, 0, "{")
    if has_auto_transitions:
        _emit(f, 1, "unsigned int process_auto_transition = 0;")
        # Label the generated code jumps back to for trigger-less transitions.
        _emit(f, 0, "process_more:")
        _nl(f)

    # Events
    _emit(f, 1, f"for (unsigned int i = 0; i < {region_count}; i++)")
    _emit(f, 2, "m->wsv[i] = 0;")
    _nl(f)
    _emit(f, 1, "switch(event) {")
    _emit(f, 2, "case UFSM_RESET:")
    _gen_reset_vector(hmodel, fmodel, f, sq_length)
    _emit(f, 2, "break;")
    for event in hmodel.events.values():
        _emit(f, 2, f"case {event.name}:")
        # Auto-transition, reserved UUID, special event
        if isinstance(event, AutoTransitionTrigger):
            _emit(f, 3, "process_auto_transition = 0;")
        for ft in fmodel.transition_schedule:
            if isinstance(ft.trigger, Event) and ft.trigger.id == event.id:
                _gen_transition(hmodel, fmodel, f, ft, sq_length, 3)
        _emit(f, 2, "break;")

    if has_auto_transitions:
        _emit(f, 2, "case UFSM_AUTO_TRANSITION:")
        _emit(f, 3, "process_auto_transition = 0;")
        for ft in fmodel.transition_schedule:
            if isinstance(ft.trigger, AutoTransitionTrigger):
                _gen_transition(hmodel, fmodel, f, ft, sq_length, 3)
        _emit(f, 2, "break;")

    _emit(f, 2, "default:")
    _emit(f, 3, "return -UFSM_BAD_ARGUMENT;")
    _emit(f, 1, "}")

    # Commit: copy every region that was assigned a new state into csv.
    _emit(f, 1, f"for (unsigned int i = 0; i < {region_count}; i++)")
    _emit(f, 2, "if(m->wsv[i] != 0)")
    _emit(f, 3, "m->csv[i] = m->wsv[i];")
    _nl(f)

    # Trigger-less transitions
    if has_auto_transitions:
        _emit(f, 1, "if (process_auto_transition == 1) {")
        _emit(f, 2, "event = UFSM_AUTO_TRANSITION;")
        _emit(f, 2, "goto process_more;")
        _emit(f, 1, "}")

    # Signals
    # NOTE(review): the generated loop reads m->sq_count but nothing emitted
    # here resets it after the queue has been walked - confirm whether the
    # queue is meant to be drained (e.g. 'm->sq_count = 0;') once processed.
    if len(hmodel.signals) > 0:
        _nl(f)
        _emit(f, 1, "for (unsigned int n = 0; n < m->sq_count; n++) {")
        _emit(f, 2, f"for (unsigned int i = 0; i < {region_count}; i++)")
        _emit(f, 3, "m->wsv[i] = 0;")
        _nl(f)
        _emit(f, 2, "switch(m->sq[n]) {")
        for signal in hmodel.signals.values():
            _emit(f, 3, f"case {signal.name}:")
            for ft in fmodel.transition_schedule:
                # Only transitions triggered by this signal are relevant.
                if isinstance(ft.trigger, AutoTransitionTrigger):
                    continue
                if isinstance(ft.trigger, CompletionTrigger):
                    continue
                if ft.trigger.id == signal.id:
                    _gen_transition(hmodel, fmodel, f, ft, sq_length, 4)
            _emit(f, 3, "break;")
        _emit(f, 3, "default:")
        _emit(f, 4, "return -UFSM_BAD_ARGUMENT;")
        _emit(f, 2, "}")
        _emit(f, 2, f"for (unsigned int i = 0; i < {region_count}; i++)")
        _emit(f, 3, "if(m->wsv[i] != 0)")
        _emit(f, 4, "m->csv[i] = m->wsv[i];")
        _emit(f, 1, "}")

        # Trigger-less transitions
        if has_auto_transitions:
            _emit(f, 1, "if (process_auto_transition == 1) {")
            _emit(f, 2, "event = UFSM_AUTO_TRANSITION;")
            _emit(f, 2, "goto process_more;")
            _emit(f, 1, "}")

    _emit(f, 1, "return 0;")
    _emit(f, 0, "}")


def c_generator(fmodel: FlatModel, hmodel: Model, c_output, h_output, args):
    """Generate the C source and header files for a state machine.

    ``h_output`` receives the event/signal defines, guard and action
    prototypes and the machine struct; ``c_output`` receives the init and
    process function definitions. ``args.sq_length`` sets the signal queue
    length used in both files.
    """
    with open(c_output, "w") as source_file, open(h_output, "w") as header_file:
        logger.debug(f"Signal queue length = {args.sq_length}")
        queue_length = args.sq_length

        # Header: guard, defines, prototypes, machine struct, closing guard.
        _generate_header_head(hmodel, fmodel, header_file)
        header_sections = (
            _gen_events,
            _gen_signals,
            _gen_guard_protos,
            _gen_action_protos,
        )
        for emit_section in header_sections:
            emit_section(hmodel, header_file)
        _gen_machine_struct(hmodel, queue_length, header_file)
        _generate_header_foot(hmodel, fmodel, header_file)

        # Source: banner/include followed by the init and process functions.
        _gen_c_head(hmodel, fmodel, queue_length, source_file)
        _gen_process_func(hmodel, fmodel, queue_length, source_file)


def c_generator_argparser(subparser):
    """Register the ``c`` sub-command and its ``-l/--sq_length`` option."""
    c_parser = subparser.add_parser("c")
    option_flags = ("-l", "--sq_length")
    c_parser.add_argument(
        *option_flags,
        help="Signal queue length",
        default=0,
        type=int,
    )
