#!/usr/bin/env python3
# SPDX-License-Identifier: WTFPL

import argparse
import ast
import datetime
import json
import sys


class Evaluator(ast.NodeTransformer):
    """Rewrite whitelisted constructor calls into plain literal nodes.

    A call named ``X`` (either ``X(...)`` or ``something.X(...)``) is handled
    by the method ``eval_X``; the node that method returns replaces the call
    in the tree, so the transformed tree can be passed to
    ``ast.literal_eval``.  Calls without a matching handler raise
    ``ValueError``.
    """

    @staticmethod
    def _single_constant(node, label):
        """Return the sole argument of *node*, which must be a constant.

        Raises:
            ValueError: if *node* does not have exactly one positional
                argument or that argument is not an ``ast.Constant``.
                *label* is used in the message.
        """
        if len(node.args) != 1 or not isinstance(node.args[0], ast.Constant):
            raise ValueError(f"bad {label} args: {node.args!r}")
        return node.args[0]

    def eval_datetime(self, node):
        """Replace ``datetime(...)`` with a constant holding its ``str()``.

        Only positional arguments are read; keyword arguments (for example
        ``tzinfo``) are ignored.

        Raises:
            ValueError: if a positional argument is not a literal.
        """
        try:
            args = [ast.literal_eval(arg) for arg in node.args]
        except ValueError as exc:
            raise ValueError("args should be literals") from exc
        obj = datetime.datetime(*args)
        return ast.Constant(value=str(obj))

    # TODO: timezone(...) and timedelta(...) calls are not supported yet.

    def eval_UUID(self, node):
        """Replace ``UUID(<constant>)`` with the constant itself."""
        return self._single_constant(node, "UUID")

    def eval_Decimal(self, node):
        """Replace ``Decimal(<constant>)`` with the constant itself."""
        return self._single_constant(node, "Decimal")

    def eval_PosixPath(self, node):
        """Replace ``PosixPath(<constant>)`` with the constant itself."""
        return self._single_constant(node, "Path")

    def eval_frozenset(self, node):
        """Replace ``frozenset()`` / ``frozenset({...})`` with a list node.

        Raises:
            ValueError: if the call has arguments other than a single set
                display.
        """
        if not node.args:
            return ast.List(elts=[], ctx=ast.Load())
        if len(node.args) != 1 or not isinstance(node.args[0], ast.Set):
            raise ValueError(f"bad set args: {node.args!r}")
        # The node returned from visit_Call is not traversed again, so the
        # elements have to be visited here or nested calls would survive.
        return ast.List(
            elts=[self.visit(elt) for elt in node.args[0].elts],
            ctx=ast.Load(),
        )

    def eval_namespace(self, node):
        """Replace ``namespace(k=v, ...)`` with a dict node ``{"k": v, ...}``.

        Only keyword arguments are used; positional arguments are ignored.
        """
        return ast.Dict(
            keys=[ast.Constant(value=kw.arg) for kw in node.keywords],
            # Visit the values for the same reason as in eval_frozenset:
            # they may themselves be calls that need rewriting.
            values=[self.visit(kw.value) for kw in node.keywords],
        )

    def visit_Call(self, node):
        """Dispatch a call node to the matching ``eval_<name>`` handler.

        Raises:
            ValueError: if the callee is neither a plain name nor an
                attribute access, if no handler exists for the name, or if
                the call uses ``**`` keyword unpacking.
        """
        callee = node.func
        if isinstance(callee, ast.Name):
            name = callee.id
        elif isinstance(callee, ast.Attribute):
            name = callee.attr
        else:
            raise ValueError(f"cannot find function name of {node.func}")

        if name == "set":
            return node  # ast.literal_eval will parse it fine

        handler = getattr(self, f"eval_{name}", None)
        if handler is None:
            raise ValueError(f"function {name!r} is not allowed")

        # A keyword whose ``arg`` is None is a ``**mapping`` unpacking.
        if any(not kw.arg for kw in node.keywords):
            raise ValueError("starred keyword arguments are not allowed")

        return handler(node)


def convert_to_jsonable(obj):
    """Fallback serializer for ``json.dumps(default=...)``.

    Turns a ``set`` into a list of its elements; every other type is
    rejected with ``TypeError``.
    """
    if not isinstance(obj, set):
        type_name = obj.__class__.__name__
        raise TypeError(f"Object of type {type_name} cannot be converted")
    return [*obj]


if __name__ == "__main__":
    # Command line: one optional positional FILE; stdin is used without it.
    cli = argparse.ArgumentParser(
        description=(
            "Convert python literal dict/list/string to JSON."
            " If FILE is not given, read stdin"
        ),
    )
    cli.add_argument("file", nargs="?")
    options = cli.parse_args()

    # Load the whole input as text.
    if not options.file:
        source = sys.stdin.read()
    else:
        with open(options.file) as stream:
            source = stream.read()

    # Parse as a single expression, rewrite the supported calls into
    # literals, then evaluate the now purely-literal tree.
    tree = ast.parse(source, filename=options.file or "-", mode="eval")
    tree = Evaluator().visit(tree)
    value = ast.literal_eval(tree)

    # Sets are not JSON-serializable; convert_to_jsonable handles them.
    rendered = json.dumps(value, indent=4, default=convert_to_jsonable)
    print(rendered)
