import sys
import threading
import time
from io import StringIO
from typing import Optional

from labml.internal.app import AppTracker, AppTrackDataSource, Packet

# Number of initial commits during which `AppConsoleLogs.check_and_flush` uses
# a shortened flush interval: the configured frequency is divided by two for
# every warm-up commit that is still remaining.
WARMUP_COMMITS = 5


class AppConsoleLogs(AppTrackDataSource):
    """Buffer captured console text and offer it to an ``AppTracker`` in batches.

    Text is accumulated per channel (``'stdout'``, ``'stderr'``, ``'logger'``)
    in ``self.data``. Whenever new text arrives, `check_and_flush` decides
    whether the tracker should be told (via ``has_data``) that a packet is
    ready; the tracker then obtains it through `get_data_packet`.

    All access to the buffered state is guarded by ``self.lock``.
    """
    app_tracker: Optional[AppTracker]
    frequency: float

    def __init__(self):
        super().__init__()

        # No tracker until `set_app_tracker` is called; text is buffered
        # in the meantime.
        self.app_tracker = None
        # Target number of seconds between two flushes.
        self.frequency = 1
        self.last_committed = time.time()
        # How many times the tracker has been notified so far.
        self.commits_count = 0
        # Pending text keyed by channel name.
        self.data = {}
        self.lock = threading.Lock()

    def set_app_tracker(self, app_tracker: AppTracker, *, frequency: float):
        """Attach the tracker and flush interval, then try to flush buffered text.

        :param app_tracker: tracker to be notified when data is pending
        :param frequency: target number of seconds between flushes
        """
        self.app_tracker = app_tracker
        self.frequency = frequency
        self.check_and_flush()

    def check_and_flush(self):
        """Notify the tracker if there is pending text and a flush is due.

        A flush is due when any of the following holds:

        * there is pending ``'stderr'`` text,
        * the tracker has never been notified before, or
        * more than the current interval has elapsed since the last packet
          was built. During the first ``WARMUP_COMMITS`` notifications the
          interval is ``frequency / 2 ** (WARMUP_COMMITS - commits_count)``.

        Does nothing when no tracker is attached or nothing is buffered.
        """
        if self.app_tracker is None:
            return

        # Decide and bump the counter under the lock so the decision is based
        # on a consistent view of `data`, `last_committed` and `commits_count`.
        with self.lock:
            if not self.data:
                return

            has_stderr = self.data.get('stderr', '') != ''
            freq = self.frequency
            if self.commits_count < WARMUP_COMMITS:
                freq /= 2 ** (WARMUP_COMMITS - self.commits_count)
            overdue = time.time() - self.last_committed > freq

            if not (has_stderr or self.commits_count == 0 or overdue):
                return
            self.commits_count += 1

        # Call out only after releasing the lock: the lock is not reentrant,
        # so a tracker that calls `get_data_packet` from inside `has_data`
        # would otherwise deadlock.
        self.app_tracker.has_data(self)

    def _clean(self, data: str):
        """Drop text that a carriage return would have overwritten.

        * ``'\\r\\n'``: only the ``'\\r'`` is removed.
        * A lone ``'\\r'`` that follows an earlier ``'\\n'`` or ``'\\r'`` in
          ``data``: everything after that earlier character up to and
          including this ``'\\r'`` is removed.
        * A lone ``'\\r'`` with no earlier line break in ``data`` is kept
          unchanged.

        :param data: raw captured text
        :return: ``data`` with the overwritten spans removed
        """
        last_newline = None
        remove = []  # inclusive (start, end) index pairs, in ascending order
        for i, ch in enumerate(data):
            if ch == '\r':
                if i + 1 < len(data) and data[i + 1] == '\n':
                    remove.append((i, i))
                elif last_newline is not None:
                    remove.append((last_newline + 1, i))
                # A carriage return also starts a fresh "line" for the
                # purpose of later removals.
                last_newline = i
            elif ch == '\n':
                last_newline = i

        res = []
        offset = 0
        for start, end in remove:
            if offset < start:
                res.append(data[offset:start])
            offset = end + 1

        res.append(data[offset:])
        return ''.join(res)

    def get_data_packet(self) -> Packet:
        """Build a packet from the pending text and reset the buffer.

        The packet payload is the pending channel dict plus a ``'time'``
        entry; the ``'stdout'`` and ``'logger'`` channels are passed through
        `_clean` first (``'stderr'`` is sent as captured).

        :return: a ``Packet`` wrapping the pending data
        """
        with self.lock:
            now = time.time()
            self.last_committed = now
            self.data['time'] = now
            for type_ in ('stdout', 'logger'):
                if type_ in self.data:
                    self.data[type_] = self._clean(self.data[type_])
            packet = Packet(self.data)
            # Start a fresh dict; the packet keeps the old one.
            self.data = {}
            return packet

    def outputs(self, *,
                stdout_: str = '',
                stderr_: str = '',
                logger_: str = ''):
        """Append captured text to the matching channel and maybe flush.

        :param stdout_: text written to standard output
        :param stderr_: text written to standard error
        :param logger_: text for the ``'logger'`` channel
        """
        channels = (('stdout', stdout_), ('stderr', stderr_), ('logger', logger_))
        with self.lock:
            for key, text in channels:
                if text != '':
                    self.data[key] = self.data.get(key, '') + text

        self.check_and_flush()


# Module-level instance that `OutputStream` and the stdout/stderr write hooks
# below feed captured text into.
APP_CONSOLE_LOGS = AppConsoleLogs()


class OutputStream(StringIO):
    """In-memory stream that mirrors writes to ``APP_CONSOLE_LOGS`` and to
    another stream.

    Every write is stored in this ``StringIO``, recorded through
    ``APP_CONSOLE_LOGS.outputs`` and forwarded to ``original``.
    """

    def __init__(self, original, type_):
        """
        :param original: stream object whose ``write`` receives every write
        :param type_: keyword name passed to ``APP_CONSOLE_LOGS.outputs``
            (its parameters are ``'stdout_'``, ``'stderr_'`` and ``'logger_'``)
        """
        super().__init__()
        self.type_ = type_
        self.original = original

    def write(self, *args, **kwargs):
        """Write to this buffer, record the text, and forward it to ``original``.

        :return: the number of characters written, as reported by
            ``StringIO.write``
        """
        written = super().write(*args, **kwargs)
        # Replay the call on a scratch buffer to recover the text regardless
        # of how the arguments were passed.
        save = StringIO()
        save.write(*args, **kwargs)
        APP_CONSOLE_LOGS.outputs(**{self.type_: save.getvalue()})
        self.original.write(*args, **kwargs)
        return written


# The `write` methods of the streams bound to `sys.stdout` / `sys.stderr` when
# this module is imported; the hooks below forward every write to them.
_original_stdout_write = sys.stdout.write
_original_stderr_write = sys.stderr.write


def _write_stdout(*args, **kwargs):
    """Replacement for ``sys.stdout.write`` that also records the text.

    Forwards the call to the original ``write`` first, then passes the written
    text to ``APP_CONSOLE_LOGS.outputs`` as ``stdout_``.

    :return: whatever the original ``write`` returned (for text streams, the
        number of characters written)
    """
    written = _original_stdout_write(*args, **kwargs)
    # Replay the call on a scratch buffer to recover the text regardless of
    # how the arguments were passed.
    save = StringIO()
    save.write(*args, **kwargs)
    APP_CONSOLE_LOGS.outputs(stdout_=save.getvalue())
    return written


def _write_stderr(*args, **kwargs):
    """Replacement for ``sys.stderr.write`` that also records the text.

    Forwards the call to the original ``write`` first, then passes the written
    text to ``APP_CONSOLE_LOGS.outputs`` as ``stderr_``.

    :return: whatever the original ``write`` returned (for text streams, the
        number of characters written)
    """
    written = _original_stderr_write(*args, **kwargs)
    # Replay the call on a scratch buffer to recover the text regardless of
    # how the arguments were passed.
    save = StringIO()
    save.write(*args, **kwargs)
    APP_CONSOLE_LOGS.outputs(stderr_=save.getvalue())
    return written


def capture():
    """Install the recording hooks as the ``write`` attribute of the current
    ``sys.stdout`` and ``sys.stderr`` objects."""
    hooks = (
        (sys.stdout, _write_stdout),
        (sys.stderr, _write_stderr),
    )
    for stream, hook in hooks:
        stream.write = hook


# Import-time side effect: hook stdout/stderr as soon as this module is loaded.
capture()
