#!/usr/bin/python3
#
# Copyright (C) 2020-2022 The opuntiaOS Project Authors.
#  + Contributed by Nikita Melekhin <nimelehin@gmail.com>
#
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import os
import code
import argparse
from time import sleep
from ProfNode import ProfNode, Profiler
from QConn import QConn


# Command-line interface: the positional argument is the path to the QEMU
# monitor socket; the options control symbol lookup and the sampling schedule.
parser = argparse.ArgumentParser()
parser.add_argument("path", type=str, help="Monitor address (required)")
# The help text advertises --symbols as mandatory, so enforce it: a missing
# value is then reported by argparse up front instead of reaching the path
# checks as None.
parser.add_argument("--symbols", type=str, required=True,
                    help="Symbols (required)")
# Total sampling time; divided by --frequency to get the number of samples.
parser.add_argument("--duration", type=int, default=10,
                    help="Duration (default=10)")
# Despite the name, this value is used as the sleep interval (in seconds)
# between two consecutive samples.
parser.add_argument("--frequency", type=float, default=0.05,
                    help="Frequency (default=0.05)")
# When set, the trace is printed directly instead of opening a Python console.
parser.add_argument("--nointeractive", action="store_true",
                    help="Not run interactive mode")
args = parser.parse_args()

# Validate the command-line inputs before touching the monitor socket.
# --symbols may be None when it was not supplied, which os.path.exists()
# rejects with a TypeError, so test for that explicitly.
if args.symbols is None or not os.path.exists(args.symbols):
    # str.format() is required here: passing the value as a second print()
    # argument leaves the "{}" placeholder in the output verbatim.
    print("Can't open symbols: {}".format(args.symbols))
    # SystemExit is raised directly so the script does not depend on the
    # exit() helper that is only injected by the site module.
    raise SystemExit(1)

if not os.path.exists(args.path):
    print("Can't open qemu monitor socket: {}".format(args.path))
    raise SystemExit(1)

# The number of samples is computed as duration / frequency, so a zero
# interval would divide by zero and a negative one makes no sense.
if args.frequency <= 0:
    print("Frequency must be positive: {}".format(args.frequency))
    raise SystemExit(1)


# Connection to the QEMU monitor, shared by every sample taken below.
qconn = QConn(args.path)


def get_stacktrace():
    """Take one call-stack sample of the guest through the monitor.

    The guest is stopped, the ``eip``/``esp``/``ebp`` registers are read and
    the page containing the stack pointer is dumped with the monitor's ``x``
    command; the guest is then resumed and the frame-pointer chain is walked
    offline inside the dumped page.

    Returns:
        A list of integer addresses starting with the sampled instruction
        pointer, followed by the word found at ``bp + 4`` of every frame
        reachable through the chain of saved frame pointers. Only frames
        whose slots lie inside the dumped page can be resolved. An empty
        list is returned when the memory dump command reports an error.
    """
    PAGE_SIZE = 0x1000
    WORD_SIZE = 4

    qconn.stop()
    try:
        ip = int(qconn.gpreg("eip"), 16)
        sp = int(qconn.gpreg("esp"), 16)
        bp = int(qconn.gpreg("ebp"), 16)

        # Dump the whole page the stack pointer lives in, one word at a time.
        bottomsp = sp & ~(PAGE_SIZE - 1)
        memdata, err = qconn.human_cmd(
            "x/{}x ".format(PAGE_SIZE // WORD_SIZE) + hex(bottomsp))
    finally:
        # Always resume the guest, including when the dump failed or a
        # register read raised; otherwise the guest would stay paused.
        qconn.cont()

    if err:
        return []

    stacktrace = [ip]

    # Each useful line is parsed as "<addr>: <w0> <w1> <w2> <w3>"; anything
    # that does not split into exactly five fields is ignored.
    memmap = {}
    for line in memdata.split("\r\n"):
        fields = line.split(" ")
        if len(fields) != 5:
            continue
        addr = int(fields[0][:-1], 16)  # drop the trailing ':'
        for index, word in enumerate(fields[1:]):
            memmap[addr + index * WORD_SIZE] = int(word, 0)

    # Follow the saved frame pointers: the word at bp is taken as the next
    # frame pointer and the word at bp + 4 as that frame's return address.
    # Stop when the chain leaves the dumped page or starts to loop.
    visited = set()
    while bp is not None and bp not in visited:
        pc = memmap.get(bp + WORD_SIZE)
        if pc is not None:
            stacktrace.append(pc)
        visited.add(bp)
        bp = memmap.get(bp)

    return stacktrace


# Root of the profile tree; every sampled stack trace is merged into it.
profiler = Profiler(args.symbols)
root = ProfNode(profiler)
root.name = "root"

# Take duration / frequency samples, pausing `frequency` seconds after each.
runs = int(args.duration / args.frequency)
for i in range(runs):
    root.add_stacktrace(get_stacktrace())
    sleep(args.frequency)


# Post-process the collected samples before they are shown.
root.process_node()

if args.nointeractive:
    # Batch mode: just dump the trace.
    root.trace()
else:
    # Interactive mode: hand the module namespace to a Python console so the
    # profile can be explored through the `root` object.
    print("Entering interactive mode, use root object to access profile data.\n"
          " * root.trace() - prints trace")
    code.interact(local=locals())
