#!/usr/bin/env python

import sys
import json
import onnx
import glob
import os

# Reduce native log noise from TensorFlow, gRPC and glog. These variables are
# set before the TensorFlow import below; presumably the libraries read them
# when they are first loaded, so the ordering matters.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["GRPC_VERBOSITY"] = 'error'
os.environ["GLOG_minloglevel"] = '2'

import tf2onnx
import argparse
import numpy as np
import tensorflow as tf
import onnxruntime as ort
import concurrent.futures
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import (
    Input, LSTM, RepeatVector, TimeDistributed, Masking, Dense,
    Dropout, BatchNormalization
)

# Script metadata; shown by the CLI in the program name, --version output
# and the help epilog.
__author__ = '@c3rb3ru5d3d53c'
__version__ = '1.0.0'

class OnnxClassifier:
    """Score feature sequences against one or more ONNX models.

    Each model is expected to reconstruct its own input (for example the
    autoencoders exported by ``OnnxTrainer``). The reconstruction error is
    turned into a similarity score between 0 and 1 and averaged over all
    loaded models.
    """

    def __init__(self, model_paths: list):
        """Open one inference session per model file.

        Args:
            model_paths: Paths of the ONNX model files to load.

        Raises:
            ValueError: If ``model_paths`` is empty.
        """
        model_paths = list(model_paths)
        if not model_paths:
            raise ValueError('at least one ONNX model path is required')

        self.sessions = [ort.InferenceSession(path) for path in model_paths]

        # The first model's input describes the tensor layout that is used for
        # every session: (batch, sequence length, feature dimensions).
        input_info = self.sessions[0].get_inputs()[0]
        self.input_name = input_info.name
        _, self.max_sequence_length, input_dim = input_info.shape
        # Fall back to a single feature dimension when the model reports none.
        self.input_dimensions = input_dim or 1

    def classify(self, feature: list, threshold: float = 0.05, scale: float = 0.01):
        """Return the mean similarity of ``feature`` across all models.

        Args:
            feature: Numeric sequence to score. It is zero-padded at the end
                (or truncated) to the model's sequence length.
            threshold: Normalized reconstruction error at which the
                similarity is exactly 0.5.
            scale: Steepness of the sigmoid around ``threshold``; smaller
                values give a sharper transition.

        Returns:
            The average over all sessions of ``1 - sigmoid((error -
            threshold) / scale)``, where ``error`` is the mean squared
            reconstruction error divided by the mean squared input.
        """
        padded_feature = pad_sequences(
            [feature],
            maxlen=self.max_sequence_length,
            padding='post',
            dtype='float32'
        )

        input_tensor = padded_feature.reshape(
            1, self.max_sequence_length, self.input_dimensions
        ).astype(np.float32)

        # Mean squared input, used to normalize the error. It does not depend
        # on the model, so compute it once. The epsilon keeps the division
        # defined for an all-zero input.
        max_mse = np.mean(input_tensor ** 2) + 1e-6

        similarity_scores = []

        for session in self.sessions:
            output = session.run(None, {self.input_name: input_tensor})[0]

            mse = np.mean((input_tensor - output) ** 2)
            normalized_mse = mse / max_mse

            exponent = (normalized_mse - threshold) / scale
            # For errors far below the threshold exp() saturates to inf, which
            # still yields the correct limit (anomaly score 0). Only the
            # overflow warning is silenced; the value is unchanged.
            with np.errstate(over='ignore'):
                anomaly_score = 1 / (1 + np.exp(-exponent))

            similarity_scores.append(1 - anomaly_score)

        return np.mean(similarity_scores)

class OnnxTrainer:
    """Train an LSTM autoencoder on feature sequences and export it to ONNX."""

    def __init__(
        self,
        features: list,
        batch_size: int = 64,
        epochs: int = 10,
        neurons: int = 256,
        opset: int = 13
    ):
        """Store the hyper-parameters and pad the training sequences.

        Args:
            features: Numeric sequences to train on.
            batch_size: Number of sequences per training batch.
            epochs: Number of passes over the training data.
            neurons: Width of the outer LSTM layers; the inner layers use
                half of it.
            opset: ONNX opset used when exporting the model.

        Raises:
            ValueError: If ``features`` is empty or contains only empty
                sequences.
        """
        if len(features) == 0:
            raise ValueError('features must contain at least one sequence')

        longest = max(len(f) for f in features)
        if longest == 0:
            raise ValueError('features must contain at least one non-empty sequence')

        self.features = features
        self.batch_size = batch_size
        self.epochs = epochs
        self.neurons = neurons
        self.opset = opset

        self.input_dimensions = 1
        # Sequence length is capped at 100 steps.
        self.max_sequence_length = min(100, longest)
        # Zero-pad at the end and add the trailing feature axis:
        # (samples, max_sequence_length, input_dimensions).
        self.padded_features = pad_sequences(
            features,
            maxlen=self.max_sequence_length,
            padding='post',
            dtype='float32'
        ).reshape(-1, self.max_sequence_length, self.input_dimensions)

    def get_dataset(self):
        """Return a batched, prefetched dataset whose targets equal its inputs."""
        dataset = tf.data.Dataset.from_tensor_slices((self.padded_features, self.padded_features))
        dataset = dataset.batch(self.batch_size).prefetch(tf.data.AUTOTUNE)
        return dataset

    def train(self, output: str):
        """Fit the autoencoder and write it to ``output`` as an ONNX model.

        Exits the process with status 1 when the padded training data
        contains NaN or infinite values.

        Args:
            output: Path of the ONNX file to write.
        """
        # Validate the data first so invalid input fails before the model
        # graph is built.
        if not np.isfinite(self.padded_features).all():
            print("model training failed with NaN or Inf values.", file=sys.stderr)
            sys.exit(1)

        timesteps = self.max_sequence_length
        inputs = Input(shape=(timesteps, self.input_dimensions), dtype=tf.float32)
        # Zero is the padding value, so zero time steps are masked out.
        masked_inputs = Masking(mask_value=0.0)(inputs)

        # Encoder: compress the sequence into a single vector.
        encoded = LSTM(self.neurons, return_sequences=True)(masked_inputs)
        encoded = Dropout(0.2)(encoded)
        encoded = LSTM(self.neurons // 2)(encoded)
        encoded = Dropout(0.2)(encoded)
        encoded = BatchNormalization()(encoded)

        # Decoder: expand the vector back to one value set per time step.
        decoded = RepeatVector(timesteps)(encoded)
        decoded = LSTM(self.neurons // 2, return_sequences=True)(decoded)
        decoded = Dropout(0.2)(decoded)
        decoded = LSTM(self.neurons, return_sequences=True)(decoded)
        decoded = Dropout(0.2)(decoded)
        decoded = BatchNormalization()(decoded)
        # The output width follows the input width so the reconstruction has
        # the same shape as the input.
        outputs = TimeDistributed(Dense(self.input_dimensions))(decoded)

        autoencoder = Model(inputs, outputs)

        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, clipnorm=1.0)
        autoencoder.compile(optimizer=optimizer, loss='mean_squared_error')

        autoencoder.fit(
            self.get_dataset(),
            epochs=self.epochs
        )

        # Export with a dynamic batch axis and a fixed sequence layout.
        spec = (tf.TensorSpec((None, timesteps, self.input_dimensions), tf.float32, name='input'),)
        tf2onnx.convert.from_keras(
            autoencoder,
            input_signature=spec,
            opset=self.opset,
            output_path=output
        )

def classify_line(line, onnx, threshold, scale):
    """Score the feature stored in one JSON record.

    ``line`` is parsed as JSON and its ``chromosome.feature`` value is passed
    to ``onnx.classify``. Returns the score, or ``None`` when the chromosome
    or its feature is null or when any error occurs; errors are reported on
    stderr rather than raised.
    """
    try:
        chromosome = json.loads(line)['chromosome']
        if chromosome is None:
            return None

        feature = chromosome['feature']
        if feature is None:
            return None

        return onnx.classify(feature, threshold=threshold, scale=scale)
    except json.JSONDecodeError as error:
        print(f"JSON decode error: {error}", file=sys.stderr)
        return None
    except Exception as error:
        print(f"Error: {error}", file=sys.stderr)
        return None

def _build_parser():
    """Create the argument parser with the filter, train and classify modes."""
    parser = argparse.ArgumentParser(
        prog=f'bltensor v{__version__}',
        description='A Tensorflow Binlex Training and Filtering Tool',
        epilog=f'Author: {__author__}'
    )
    parser.add_argument('--version', action='version', version=f'v{__version__}')
    subparsers = parser.add_subparsers(dest="mode", required=True)

    parser_filter = subparsers.add_parser('filter', help='Filter Mode')
    parser_filter.add_argument('-t', '--threshold', type=float, required=True, help='Threshold')
    parser_filter.add_argument('-s', '--scale', type=float, default=0.01, help='Scale for anomaly score calculation')
    parser_filter.add_argument('-fm', '--filter-mode', choices=['gte', 'lte', 'gt', 'lt'], default='gte', help="Filter Mode")
    parser_filter.add_argument('-f', '--filter', type=float, required=True, help="Filter by Score")
    parser_filter.add_argument('-i', '--input', type=str, required=True, help='Input ONNX Model Directory')

    parser_train = subparsers.add_parser('train', help='Train ONNX Model')
    parser_train.add_argument('-t', '--threads', type=int, default=1, help='Threads')
    parser_train.add_argument('-e', '--epochs', type=int, default=10, help='Epochs')
    parser_train.add_argument('-b', '--batch-size', type=int, default=64, help='Batch Size')
    parser_train.add_argument('-n', '--neurons', type=int, default=256, help='Number of neurons')
    parser_train.add_argument('-o', '--output', type=str, required=True, help='Output model path')

    parser_classify = subparsers.add_parser('classify', help='Classify Sample Mode')
    parser_classify.add_argument('--threshold', type=float, required=True, help='Threshold')
    parser_classify.add_argument('--threads', type=int, default=1, help='Threads')
    parser_classify.add_argument('--scale', type=float, default=0.01, help='Scale for anomaly score calculation')
    parser_classify.add_argument('--input', type=str, required=True, help='Input ONNX Model Directory')

    return parser


def _load_classifier(model_directory):
    """Load every ``*.onnx`` file found recursively under ``model_directory``.

    Exits with status 1 when no model file is found.
    """
    pattern = os.path.join(model_directory, '**', '*.onnx')
    # Sorted so the order in which models are loaded does not depend on the
    # filesystem's listing order.
    model_paths = sorted(glob.glob(pattern, recursive=True))
    if not model_paths:
        print(f'No ONNX models found in: {model_directory}', file=sys.stderr)
        sys.exit(1)
    return OnnxClassifier(model_paths=model_paths)


def _run_classify(args):
    """Print the mean similarity of all valid stdin records to the models."""
    classifier = _load_classifier(args.input)

    classification_values = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = [
            executor.submit(classify_line, line, classifier, args.threshold, args.scale)
            for line in sys.stdin
        ]
        for future in concurrent.futures.as_completed(futures):
            classification_value = future.result()
            if classification_value is not None:
                classification_values.append(classification_value)

    # Without this guard the average below would divide by zero.
    if not classification_values:
        print('No features found for classification.', file=sys.stderr)
        sys.exit(1)

    similarity = sum(classification_values) / len(classification_values)
    print(f'Similarity: {similarity}')


def _run_filter(args):
    """Echo the stdin records whose similarity passes the requested comparison."""
    classifier = _load_classifier(args.input)

    threshold_check = {
        'gte': lambda x: x >= args.filter,
        'lte': lambda x: x <= args.filter,
        'gt': lambda x: x > args.filter,
        'lt': lambda x: x < args.filter
    }[args.filter_mode]

    for line in sys.stdin:
        classification_value = classify_line(line, classifier, args.threshold, args.scale)
        if classification_value is None:
            continue

        if threshold_check(classification_value):
            print(line.strip())


def _run_train(args):
    """Train a model on the features read from stdin and write it to args.output."""
    tf.config.threading.set_intra_op_parallelism_threads(args.threads)
    tf.config.threading.set_inter_op_parallelism_threads(args.threads)
    tf.config.optimizer.set_jit(False)

    features = []
    for line in sys.stdin:
        try:
            data = json.loads(line)
            if data['chromosome'] is None:
                continue
            if data['chromosome']['feature'] is None:
                continue
            features.append(data['chromosome']['feature'])
        except json.JSONDecodeError as error:
            print(f"JSON decode error: {error}", file=sys.stderr)
        except Exception as error:
            # Best effort: report the malformed record and keep reading.
            print(f"Error: {error}", file=sys.stderr)

    if not features:
        print('No features found for training.', file=sys.stderr)
        sys.exit(1)

    trainer = OnnxTrainer(
        features=features,
        batch_size=args.batch_size,
        epochs=args.epochs,
        neurons=args.neurons
    )

    trainer.train(args.output)


def main():
    """Parse the command line and run the selected mode.

    Modes:
        classify: print the mean similarity of the stdin records.
        filter: print the stdin records whose similarity passes the filter.
        train: train a model on the stdin records and export it to ONNX.

    Every mode reads JSON records from stdin, one per line. The process
    exits with status 1 when no models or no usable features are found.
    """
    args = _build_parser().parse_args()

    handlers = {
        'classify': _run_classify,
        'filter': _run_filter,
        'train': _run_train,
    }
    handlers[args.mode](args)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
