#!/usr/bin/env python3

import argparse
import importlib
import json
import os
import shlex
import subprocess
import sys
import tempfile
import traceback

from elftools.elf.elffile import ELFFile

from librw_arm64.util.logging import info

def load_analysis_cache(loader, outfile):
    """Populate per-function analysis results from an on-disk JSON cache.

    The cache is read from ``outfile + ".analysis_cache"``. Its top-level
    keys are converted with ``int()`` and used to index
    ``loader.container.functions``; each nested analysis key replaces the
    corresponding entry of that function's ``analysis`` mapping. Innermost
    keys that parse as integers are stored as ``int``; all others are kept
    as strings.

    Raises whatever ``open()`` / ``json.load()`` raise when the cache file
    is missing or unreadable.
    """
    def _as_key(raw):
        # JSON object keys are always strings; restore integer keys where
        # possible and leave everything else untouched.
        try:
            return int(raw)
        except ValueError:
            return raw

    with open(outfile + ".analysis_cache") as fd:
        cached = json.load(fd)
    print("[*] Loading analysis cache")

    for func_addr, per_func in cached.items():
        for key, entries in per_func.items():
            function = loader.container.functions[int(func_addr)]
            bucket = dict()
            function.analysis[key] = bucket
            for raw, value in entries.items():
                bucket[_as_key(raw)] = value

def save_analysis_cache(loader, outfile):
    """Write every function's ``free_registers`` analysis to a JSON cache.

    For each ``(addr, func)`` pair in ``loader.container.functions`` the
    ``func.analysis["free_registers"]`` mapping is copied, with each value
    converted to a list so it can be serialized. The result is dumped to
    ``outfile + ".analysis_cache"``.
    """
    snapshot = {
        addr: {
            "free_registers": {
                key: list(regs)
                for key, regs in func.analysis["free_registers"].items()
            }
        }
        for addr, func in loader.container.functions.items()
    }

    with open(outfile + ".analysis_cache", "w") as fd:
        json.dump(snapshot, fd)


def analyze_registers(loader, args):
    """Run stack-frame analysis, then obtain register analysis results.

    Without ``args.cache`` the register analysis is always computed. With
    ``args.cache`` the results are first loaded from the cache belonging to
    ``args.outfile``; if that raises ``IOError`` the analysis is computed
    and the cache is written for the next run.
    """
    StackFrameAnalysis.analyze(loader.container)

    if not args.cache:
        RegisterAnalysis.analyze(loader.container)
        return

    try:
        load_analysis_cache(loader, args.outfile)
    except IOError:
        # No usable cache file: compute from scratch and persist the result.
        RegisterAnalysis.analyze(loader.container)
        save_analysis_cache(loader, args.outfile)

def asan(rw, loader, args):
    """Apply the ``Instrument`` pass to ``rw`` after register analysis.

    Runs ``analyze_registers`` first, then instruments the rewriter and
    dumps the instrumenter's statistics. The caller is responsible for
    dumping the rewriter output afterwards.
    """
    analyze_registers(loader, args)

    asan_pass = Instrument(rw)
    asan_pass.do_instrument()
    asan_pass.dump_stats()


def asank(rw, loader, args):
    """Load analysis results produced by the external ``cftool`` program.

    Stack-frame analysis is run first. The rewriter's control-flow
    information is then dumped to a temporary file, ``cftool`` is invoked
    with that file and a second temporary file as arguments, and the JSON
    read back from the second file is stored in the ``analysis`` mapping of
    each function, looked up by name. Innermost keys that parse as integers
    are stored as ``int``.

    Returns ``rw`` unchanged. Raises ``subprocess.CalledProcessError`` if
    ``cftool`` exits with a non-zero status.
    """
    def _as_key(raw):
        # JSON object keys are strings; recover integer keys where possible.
        try:
            return int(raw)
        except ValueError:
            return raw

    StackFrameAnalysis.analyze(loader.container)

    with tempfile.NamedTemporaryFile(mode='w') as cf_out, \
            tempfile.NamedTemporaryFile(mode='r') as regs_in:
        rw.dump_cf_info(cf_out)
        # Make sure cftool sees the complete dump before it starts reading.
        cf_out.flush()

        subprocess.check_call(['cftool', cf_out.name, regs_in.name])

        results = json.load(regs_in)

        for func_name, per_func in results.items():
            for key, entries in per_func.items():
                target = loader.container.get_function_by_name(func_name)
                bucket = dict()
                target.analysis[key] = bucket
                for raw, value in entries.items():
                    bucket[_as_key(raw)] = value
    return rw

def assemble(filename, outfile):
    """Assemble and link a symbolized assembly file into a binary.

    Link metadata is read from the last 4096 bytes of ``filename``. Lines
    containing the following markers are recognized:

    * ``DEPENDENCY: <lib>``    -- linked with ``-l:<lib>`` (entries that
      contain ``ld-linux`` are skipped)
    * ``LIB_PATH: <dir>``      -- added as ``-L<dir>`` (``.`` is always
      searched first)
    * ``CC: <compiler>``       -- compiler driver to run (default ``g++``)
    * ``SECTION: <name> - <addr>`` -- passed as ``--section-start``
    * ``SYMBOLIC`` / ``NOPIE`` / ``SHARED`` -- select ``-Bsymbolic``,
      ``-no-pie`` and ``-shared`` respectively (``-shared`` overrides the
      PIE choice; ``-pie`` is used when neither applies)

    Args:
        filename: Path of the assembly file to assemble.
        outfile: Path of the binary to produce.

    Raises:
        subprocess.CalledProcessError: If the compiler exits with a
            non-zero status.
    """
    print(f"Assembling {filename} into {outfile}")
    size = os.stat(filename).st_size

    # The metadata lives at the end of the file, so only the tail is read.
    # The file is opened in binary mode because seeking a text-mode file to
    # an arbitrary offset is undefined and may land inside a multi-byte
    # character; undecodable bytes at the cut are replaced, not fatal.
    with open(filename, 'rb') as f:
        f.seek(max(0, size - 4096))
        tail = f.read()
    last_lines = tail.decode("utf-8", errors="replace").splitlines()

    deps, paths, secs, cc = [], ["."], {}, "g++"
    for line in last_lines:
        if "DEPENDENCY:" in line:
            deps.append(line.split(": ")[1].strip())
        if "LIB_PATH:" in line:
            paths.append(line.split(": ")[1].strip())
        if "CC:" in line:
            cc = line.split(": ")[1].strip()
        if "SECTION:" in line:
            name, addr = line.split(": ")[1].split(" - ")
            secs[name] = addr.strip()

    info("Will link with the following dependencies:")
    print(",".join(deps))

    info("Will link with the following sections:")
    print(secs)

    # The dynamic loader itself is never linked explicitly.
    lflags = ["-l:" + d.strip() for d in deps if "ld-linux" not in d]

    linker_opts = ["--section-start=" + sec + "=" + addr
                   for sec, addr in secs.items()]
    if any("SYMBOLIC" in x for x in last_lines):
        linker_opts.append('-Bsymbolic')
    wlflags = "-Wl," + ",".join(linker_opts)
    # wlflags += ",--no-check-sections"

    if any("SHARED" in x for x in last_lines):
        link_mode = "-shared"
    elif any("NOPIE" in x for x in last_lines):
        link_mode = "-no-pie"
    else:
        link_mode = "-pie"

    # The command still runs through the shell (``cc`` may carry its own
    # arguments), so every path-like component is quoted to survive spaces
    # and shell metacharacters. Tokens made only of safe characters are
    # left untouched by shlex.quote().
    command = [cc, shlex.quote(filename), link_mode, "-nostartfiles"]
    command += [shlex.quote(flag) for flag in lflags]
    command.append(shlex.quote(wlflags))
    command += [shlex.quote("-L" + p) for p in paths]
    # NOTE(review): this -march value is AArch64-specific -- confirm that
    # assemble() is never expected to link x64 output.
    command.append("-march=armv8.1-a+crc+crypto+sve")
    command += ["-o", shlex.quote(outfile)]
    asmline = " ".join(command)

    info("Assembling with the following command:")
    print(asmline)
    subprocess.check_output(asmline, shell=True)

    # NOTE: a post-processing step that patched the section names
    # .oh_frame / .oh_frame_hdr / .occ_except_table back to .eh_frame /
    # .eh_frame_hdr / .gcc_except_table in the output binary used to live
    # here; it was disabled and has been dropped.

    print("Done!\n")



if __name__ == "__main__":
    # NOTE: this driver intentionally stays at module level instead of being
    # wrapped in a main() function. The architecture-specific imports below
    # bind module globals (StackFrameAnalysis, RegisterAnalysis, Instrument)
    # that the helper functions defined above look up by name when called.
    argp = argparse.ArgumentParser(description='')

    argp.add_argument("bin", type=str, help="Input binary to load")
    argp.add_argument("outfile", type=str, help="Symbolized ASM output")

    argp.add_argument("-a", "--assemble", action='store_true',
                      help="Assemble instrumented assembly file into instrumented binary")

    argp.add_argument("-A", "--asan", action='store_true',
                      help="Add binary address sanitizer instrumentation")

    argp.add_argument("-m", "--module", type=str,
                      help="Use specified instrumentation module in rwtools directory")

    argp.add_argument("-d", "--data-text", dest="data_text_support", action="store_true",
                      help="Enable support for data interleaved with code (aarch64 only, uses segfault handler, 30%+ larger binary)")
    argp.add_argument("-o", "--optimize-dt-layout", dest="optimize_dtsupport_layout", action="store_true",
                      help="Will optimize the layout to reduce calls to the segfault handler when --data-text is on. Experimental.")

    # argp.add_argument("-s", "--assembly", action="store_true",
                      # help="Generate Symbolized Assembly")
    # python3 -m librw_x64.rw </path/to/binary> <path/to/output/asm/files>
    argp.add_argument("-k", "--kernel", action='store_true',
                      help="Instrument a kernel module (x64 only)")
    argp.add_argument(
        "--kcov", action='store_true', help="Instrument the kernel module with kcov (x64 only)")
    argp.add_argument("-c", "--cache", action='store_true',
                      help="Save/load register analysis cache (only used with --asan)")
    argp.add_argument("--ignore-no-pie", dest="ignore_no_pie", action='store_true',
                      help="Ignore position-independent-executable check (only needed for x64, use with caution)")
    argp.add_argument("--ignore-stripped", dest="ignore_stripped", action='store_true',
                      help="Ignore stripped executable check (only needed for x64, use with caution)")
    argp.add_argument("--optimize-exceptions", dest="optimize_exceptions", action="store_true",
                      help="Parse and rewrite exception handler information. Will speed up the program, but use with care.")
    argp.add_argument("--no-exceptions", dest="no_exceptions", action="store_true",
                      help="Avoid rewriting exception handling information.")
    argp.add_argument("--ghidra", dest="use_ghidra", action="store_true",
                      help="Use ghidra to determine between data and text (aarch64 only)")
    argp.add_argument("-v", "--verbose", action="store_true",
                      help="Verbose output")

    argp.set_defaults(optimize_dtsupport_layout=False)
    argp.set_defaults(data_text_support=False)
    argp.set_defaults(ignore_no_pie=False)
    argp.set_defaults(ignore_stripped=False)
    argp.set_defaults(optimize_exceptions=False)
    argp.set_defaults(no_exceptions=False)
    argp.set_defaults(use_ghidra=False)

    args = argp.parse_args()

    # --assemble is a separate mode: "bin" is then the assembly input and
    # "outfile" the binary to produce; no rewriting takes place.
    if args.assemble:
        assemble(args.bin, args.outfile)
        sys.exit(0)

    print(f"Rewriting {args.bin} into {args.outfile}")

    # Only the machine architecture is needed from the ELF header, so the
    # file handle is released as soon as it has been read.
    with open(args.bin, "rb") as elf_fd:
        elffile = ELFFile(elf_fd)
        arch = elffile.get_machine_arch()

    # Prefix of the package that --module instrumentation passes are
    # imported from; the architecture suffix is appended below.
    rwtools_path = "rwtools"

    if arch == "AArch64":
        from librw_arm64.rw import Rewriter
        from librw_arm64.analysis.register import RegisterAnalysis
        from librw_arm64.analysis.stackframe import StackFrameAnalysis
        from rwtools_arm64.asan.instrument import Instrument
        from librw_arm64.loader import Loader
        from librw_arm64.analysis import register
        import librw_arm64.util.logging

        rwtools_path = rwtools_path + "_arm64."

        # Command-line switches are forwarded as class-level settings on
        # the AArch64 rewriter.
        if args.verbose:
            librw_arm64.util.logging.DEBUG_LOG = True
        if args.optimize_exceptions:
            Rewriter.emulate_calls = False
        if args.data_text_support:
            Rewriter.data_text_support = True
        if args.optimize_dtsupport_layout:
            Rewriter.optimize_dtsupport_layout = True
        if args.no_exceptions:
            Rewriter.no_exceptions = True
        if args.use_ghidra:
            Rewriter.use_ghidra = True

    elif arch == "x64":
        rwtools_path = rwtools_path + "_x64."
        if args.kernel:
            from librw_x64.krw import Rewriter
            from librw_x64 import krw
            from librw_x64.analysis.kregister import RegisterAnalysis
            from librw_x64.analysis.kstackframe import StackFrameAnalysis
            from librw_x64.kloader import Loader
            from librw_x64.analysis import kregister
        else:
            from librw_x64.rw import Rewriter
            from librw_x64.analysis.register import RegisterAnalysis
            from librw_x64.analysis.stackframe import StackFrameAnalysis
            from rwtools_x64.asan.instrument import Instrument
            from librw_x64.loader import Loader
            from librw_x64.analysis import register
    else:
        print(f"Architecture {arch} not supported!")
        sys.exit(1)

    if args.asan or (args.module and args.module != "counter"):
        Rewriter.detailed_disasm = True

    loader = Loader(args.bin)

    # x64 supports only PIE binaries with symbols (for now)
    if arch == "x64" and loader.is_pie() == False and args.ignore_no_pie == False:
        print("***** The x64 version of RetroWrite requires a position-independent executable. *****")
        print("It looks like %s is not position independent" % args.bin)
        print("If you really want to continue, because you think retrowrite has made a mistake, pass --ignore-no-pie.")
        sys.exit(1)
    if arch == "x64" and loader.is_stripped() == True and args.ignore_stripped == False:
        print("The x64 version of RetroWrite requires a non-stripped executable.")
        print("It looks like %s is stripped" % args.bin)
        print("If you really want to continue, because you think retrowrite has made a mistake, pass --ignore-stripped.")
        sys.exit(1)

    slist = loader.slist_from_symtab()

    if ".gcc_except_table" not in slist:
        # if there are no exceptions, we never emulate calls
        Rewriter.emulate_calls = False

    # this if is due to small architectural implementation differences. Hopefully in the future it will be unified.
    if arch == "x64":
        loader.identify_imports()
        flist = loader.flist_from_symtab()  # before loading data sections
        loader.load_functions(flist)
        if args.kernel:
            loader.load_data_sections(slist, krw.is_data_section)
        else:
            loader.load_data_sections(slist, lambda x: x in Rewriter.DATASECTIONS)
    else:
        loader.load_sections(slist, lambda x: x not in Rewriter.IGNORE_SECTIONS)
        flist = loader.flist_from_symtab()  # after loading all sections
        loader.load_functions(flist, args.use_ghidra)

    reloc_list = loader.reloc_list_from_symtab()
    loader.load_relocations(reloc_list)

    global_list = loader.global_data_list_from_symtab()
    loader.load_globals_from_glist(global_list)

    loader.container.attach_loader(loader)

    rw = Rewriter(loader.container, args.outfile)
    rw.symbolize()

    if args.asan:
        if args.kernel:
            # NOTE(review): the kernel import branch above does not bind
            # `Instrument`, and `KcovInstrument` is not imported anywhere in
            # the visible part of this file, so this path raises NameError
            # as written -- confirm the intended modules and import them.
            rewriter = asank(rw, loader, args)
            instrumenter = Instrument(rewriter)
            instrumenter.do_instrument()

            if args.kcov:
                kcov_instrumenter = KcovInstrument(rewriter)
                kcov_instrumenter.do_instrument()
            rewriter.dump()
        else:
            asan(rw, loader, args)
            rw.dump()
    elif args.module:
        try:
            module = importlib.import_module(rwtools_path + args.module + ".instrument")
            instrument = getattr(module, "Instrument")

            # The "counter" module is the only one that runs without the
            # register analysis results.
            if args.module != "counter":
                analyze_registers(loader, args)
            instrumenter = instrument(rw)
            instrumenter.do_instrument()
            rw.dump()
        except Exception:
            # Nothing was written in this case; report the failure through
            # the exit status as well as the traceback so that calling
            # scripts do not mistake it for success.
            traceback.print_exc()
            sys.exit(1)
    else:
        rw.dump()
