using Microsoft.Diagnostics.Tracing;
using Microsoft.Diagnostics.Tracing.Etlx;
using Microsoft.Diagnostics.Tracing.Parsers;
using Microsoft.Diagnostics.Tracing.Parsers.MicrosoftAntimalwareAMFilter;
using Microsoft.Diagnostics.Tracing.Parsers.MicrosoftAntimalwareEngine;
using Microsoft.Diagnostics.Tracing.Stacks;
using System.Collections.Generic;
using System.Linq;

namespace PerfView
{
    /// <summary>
    /// Connects the Microsoft Antimalware AMFilter and Engine event parsers to a
    /// <see cref="FileScanOperationCollection"/> and then processes the event source,
    /// so that the collection's handlers can add samples to the supplied stack source.
    /// </summary>
    public sealed class RealtimeAntimalwareComputer
    {
        private readonly TraceLogEventSource _eventSource;
        private readonly MutableTraceEventStackSource _stackSource;

        /// <summary>
        /// Stores the event source to read from and the stack source to write to.
        /// </summary>
        /// <param name="eventSource">Event source whose events are parsed by <see cref="Execute"/>.</param>
        /// <param name="stackSource">Stack source handed to the scan-operation collection.</param>
        public RealtimeAntimalwareComputer(TraceLogEventSource eventSource, MutableTraceEventStackSource stackSource)
        {
            _eventSource = eventSource;
            _stackSource = stackSource;
        }

        /// <summary>
        /// Registers the antimalware event handlers and then calls
        /// <c>Process()</c> on the event source.
        /// </summary>
        public void Execute()
        {
            var scans = new FileScanOperationCollection(_eventSource.TraceLog, _stackSource);

            // AMFilter events: a file scan was requested.
            var amFilterEvents = new MicrosoftAntimalwareAMFilterTraceEventParser(_eventSource);
            amFilterEvents.AMFilter_FileScan += scans.CreateFileScanOperation;

            // Engine events: scan start/stop and skipped-file notifications.
            var engineEvents = new MicrosoftAntimalwareEngineTraceEventParser(_eventSource);
            engineEvents.StreamscanrequestStart += scans.StartScan;
            engineEvents.StreamscanrequestStop += scans.StopScan;
            engineEvents.Skippedfile += scans.MarkSkipped;

            // All handlers are registered; start dispatching events to them.
            _eventSource.Process();
        }
    }

    /// <summary>
    /// Outcome recorded for a file scan operation.
    /// </summary>
    internal enum FileScanResult
    {
        /// <summary>Default outcome; used unless the operation is marked as skipped.</summary>
        Scanned = 0,

        /// <summary>The operation was marked as skipped.</summary>
        Skipped = 1,
    }

    /// <summary>
    /// Mutable record describing one file scan: which file, why it was requested,
    /// its outcome, its start/stop timestamps and the stack that requested it.
    /// </summary>
    internal sealed class FileScanOperation
    {
        /// <summary>
        /// Creates an operation whose result is <see cref="FileScanResult.Scanned"/>
        /// and whose requestor stack is <see cref="StackSourceCallStackIndex.Invalid"/>.
        /// </summary>
        internal FileScanOperation()
        {
            Result = FileScanResult.Scanned;
            RequestorStack = StackSourceCallStackIndex.Invalid;
        }

        /// <summary>Name of the file being scanned.</summary>
        internal string File { get; set; }

        /// <summary>Reason reported for the scan; may be null.</summary>
        internal string Reason { get; set; }

        /// <summary>Outcome of the scan.</summary>
        internal FileScanResult Result { get; set; }

        /// <summary>Timestamp, relative to trace start in milliseconds, at which the scan started.</summary>
        internal double StartTimeRelativeMSec { get; set; }

        /// <summary>Timestamp, relative to trace start in milliseconds, at which the scan stopped.</summary>
        internal double StopTimeRelativeMSec { get; set; }

        /// <summary>Call stack captured when the scan was requested.</summary>
        internal StackSourceCallStackIndex RequestorStack { get; set; }
    }

    /// <summary>
    /// Tracks file scan operations across the AMFilter and Engine antimalware events and
    /// adds one sample to the stack source each time a matching scan-stop event is seen.
    /// </summary>
    /// <remarks>
    /// Operations are created per (process, thread) by <see cref="CreateFileScanOperation"/>,
    /// matched by process and file path in <see cref="StartScan"/> / <see cref="StopScan"/>,
    /// and associated with the thread that logged the start event so that
    /// <see cref="MarkSkipped"/> can find them.
    /// NOTE(review): completed operations are never removed, so a later scan of the same
    /// path in the same process may match a stale entry — confirm whether that is intended.
    /// </remarks>
    internal sealed class FileScanOperationCollection
    {
        private readonly TraceLog _traceLog;
        private readonly MutableTraceEventStackSource _stackSource;

        // A single sample instance that is re-populated before every AddSample call.
        private readonly StackSourceSample _sample;

        // Lookup Key 1: Process
        // Lookup Key 2: Thread
        private readonly Dictionary<ProcessIndex, Dictionary<ThreadIndex, FileScanOperation>> _inProcessScanOperations = new Dictionary<ProcessIndex, Dictionary<ThreadIndex, FileScanOperation>>();

        // Maps the thread that logged a scan-start event to the operation it started.
        // A dictionary is used rather than an array sized at construction time because
        // new threads can appear after this collection has been created.
        private readonly Dictionary<ThreadIndex, FileScanOperation> _engineThreadToScanMap = new Dictionary<ThreadIndex, FileScanOperation>();

        /// <summary>
        /// Creates an empty collection bound to the given trace log and stack source.
        /// </summary>
        /// <param name="traceLog">Trace log used to resolve process IDs to processes.</param>
        /// <param name="stackSource">Stack source that receives the generated samples.</param>
        internal FileScanOperationCollection(TraceLog traceLog, MutableTraceEventStackSource stackSource)
        {
            _traceLog = traceLog;
            _stackSource = stackSource;
            _sample = new StackSourceSample(_stackSource);
        }

        /// <summary>
        /// Records a new scan operation for the process and thread that logged the event,
        /// replacing any operation previously stored for that thread.
        /// </summary>
        /// <param name="data">The AMFilter file scan event.</param>
        internal void CreateFileScanOperation(AMFilter_FileScanArgsTraceData data)
        {
            // If we can't get the process or thread index, bail.
            TraceProcess process = data.Process();
            if (process == null || process.ProcessIndex == ProcessIndex.Invalid)
            {
                return;
            }

            ThreadIndex threadIndex;
            if (!TryGetThreadIndex(data, out threadIndex))
            {
                return;
            }

            // Get the process container.
            Dictionary<ThreadIndex, FileScanOperation> processContainer = GetOrCreateProcessContainer(process.ProcessIndex);

            // Create a new file scan operation.
            // This happens when the scan is requested inside the user process.
            FileScanOperation scan = new FileScanOperation()
            {
                File = data.FileName,
                Reason = data.Reason,
                RequestorStack = _stackSource.GetCallStack(data.CallStackIndex(), data)
            };

            processContainer[threadIndex] = scan;
        }

        /// <summary>
        /// Stamps the start time on the operation matching the event's PID and path, and
        /// remembers which thread logged the event so a later skip notification can find it.
        /// Does nothing when no matching operation exists.
        /// </summary>
        /// <param name="data">The engine scan-start event.</param>
        internal void StartScan(StreamscanrequestStartArgs_V1TraceData data)
        {
            FileScanOperation operation = FindOperation(data.PID, data.TimeStampRelativeMSec, data.Path);
            if (operation == null)
            {
                return;
            }

            operation.StartTimeRelativeMSec = data.TimeStampRelativeMSec;

            // The thread may be unknown; in that case the operation simply
            // cannot be marked as skipped later on.
            ThreadIndex engineThreadIndex;
            if (TryGetThreadIndex(data, out engineThreadIndex))
            {
                _engineThreadToScanMap[engineThreadIndex] = operation;
            }
        }

        /// <summary>
        /// Stamps the stop time on the operation matching the event's PID and path and adds a
        /// sample whose metric is the elapsed time between start and stop.
        /// Does nothing when no matching operation exists.
        /// </summary>
        /// <param name="data">The engine scan-stop event.</param>
        internal void StopScan(StreamscanrequestStartArgs_V1TraceData data)
        {
            FileScanOperation operation = FindOperation(data.PID, data.TimeStampRelativeMSec, data.Path);
            if (operation == null)
            {
                return;
            }

            operation.StopTimeRelativeMSec = data.TimeStampRelativeMSec;

            // Create the stack.
            StackSourceFrameIndex fileNodeIndex = _stackSource.Interner.FrameIntern($"File ({operation.File})");
            StackSourceFrameIndex reasonNodeIndex = _stackSource.Interner.FrameIntern($"Reason ({operation.Reason ?? "Unknown"})");
            StackSourceFrameIndex resultNodeIndex = _stackSource.Interner.FrameIntern($"Scan Result ({operation.Result})");

            // Each frame is interned on top of the stack built so far:
            // requestor stack, then result, then reason, then file.
            StackSourceCallStackIndex stackIndex = _stackSource.Interner.CallStackIntern(resultNodeIndex, operation.RequestorStack);
            stackIndex = _stackSource.Interner.CallStackIntern(reasonNodeIndex, stackIndex);
            stackIndex = _stackSource.Interner.CallStackIntern(fileNodeIndex, stackIndex);

            _sample.StackIndex = stackIndex;
            _sample.Metric = (float)(operation.StopTimeRelativeMSec - operation.StartTimeRelativeMSec);
            _sample.TimeRelativeMSec = operation.StartTimeRelativeMSec;
            _stackSource.AddSample(_sample);
        }

        /// <summary>
        /// Marks the operation most recently started on the event's thread as skipped.
        /// Does nothing when the thread is unknown or has no associated operation.
        /// </summary>
        /// <param name="data">The engine skipped-file event.</param>
        internal void MarkSkipped(SkippedfileArgsTraceData data)
        {
            ThreadIndex engineThreadIndex;
            if (!TryGetThreadIndex(data, out engineThreadIndex))
            {
                return;
            }

            FileScanOperation inProcessOperation;
            if (_engineThreadToScanMap.TryGetValue(engineThreadIndex, out inProcessOperation))
            {
                inProcessOperation.Result = FileScanResult.Skipped;
            }
        }

        /// <summary>
        /// Resolves the thread index of the thread that logged <paramref name="data"/>.
        /// </summary>
        /// <returns>True when a valid thread index was found; otherwise false.</returns>
        private static bool TryGetThreadIndex(TraceEvent data, out ThreadIndex threadIndex)
        {
            TraceThread thread = data.Thread();
            threadIndex = thread != null ? thread.ThreadIndex : ThreadIndex.Invalid;
            return threadIndex != ThreadIndex.Invalid;
        }

        /// <summary>
        /// Finds the first recorded operation in the process identified by
        /// <paramref name="processId"/> whose file matches <paramref name="path"/>
        /// (ordinal, case-insensitive).
        /// </summary>
        /// <returns>The matching operation, or null when the process is unknown or nothing matches.</returns>
        private FileScanOperation FindOperation(int processId, double timeRelativeMSec, string path)
        {
            // Get the requesting user process based on the PID logged inside the engine.
            TraceProcess process = _traceLog.Processes.GetProcess(processId, timeRelativeMSec);
            if (process == null || process.ProcessIndex == ProcessIndex.Invalid)
            {
                return null;
            }

            // A lookup must not create containers for processes that never requested a scan.
            Dictionary<ThreadIndex, FileScanOperation> processContainer;
            if (!_inProcessScanOperations.TryGetValue(process.ProcessIndex, out processContainer))
            {
                return null;
            }

            // Operations recorded without a file name can never match a path.
            return processContainer.Values.FirstOrDefault(
                s => s.File != null && s.File.Equals(path, System.StringComparison.OrdinalIgnoreCase));
        }

        /// <summary>
        /// Returns the per-thread operation map for the process, creating it on first use.
        /// </summary>
        private Dictionary<ThreadIndex, FileScanOperation> GetOrCreateProcessContainer(ProcessIndex processIndex)
        {
            Dictionary<ThreadIndex, FileScanOperation> processContainer;
            if (!_inProcessScanOperations.TryGetValue(processIndex, out processContainer))
            {
                processContainer = new Dictionary<ThreadIndex, FileScanOperation>();
                _inProcessScanOperations[processIndex] = processContainer;
            }

            return processContainer;
        }
    }
}