﻿using System.IO;
using static Vanara.PInvoke.AMSI;

namespace Vanara.PInvoke.Diagnostics;

/// <summary>Describes the possible outcomes of an antimalware scan.</summary>
public enum ScanResult : uint
{
	/// <summary>
	/// The content is known to be good: nothing was detected, and a later definition update is unlikely to change that.
	/// </summary>
	Clean = AMSI_RESULT.AMSI_RESULT_CLEAN,

	/// <summary>Nothing was detected, although a later definition update could change the outcome.</summary>
	NotDetected = AMSI_RESULT.AMSI_RESULT_NOT_DETECTED,

	/// <summary>
	/// A threat level below the maximum was reported, so the content may be malware. Numerically this is the value immediately
	/// following <see cref="NotDetected"/>.
	/// </summary>
	PotentialDetected = NotDetected + 1,

	/// <summary>A detection was made. The content is considered malware and should be blocked.</summary>
	Detected = AMSI_RESULT.AMSI_RESULT_DETECTED,
}

/// <summary>Provides scanning of strings and buffers to detect malware using either the system provider or a custom provider.</summary>
public static class AntimalwareScan
{
	// Shared AMSI context for scans that go through the system provider. Starts out as the null handle and is
	// initialized on demand by EnsureContext.
	private static SafeHAMSICONTEXT hCtx = SafeHAMSICONTEXT.Null;

	/// <summary>
	/// Gets or sets the provider to use for Antimalware scans. If <see langword="null"/>, the system default provider is used.
	/// </summary>
	/// <value>The Antimalware scan provider.</value>
	public static IAntimalwareProvider? Provider { get; set; }

	/// <summary>Scans a buffer-full of content for malware.</summary>
	/// <param name="buffer">The buffer from which to read the data to be scanned. The whole array is scanned.</param>
	/// <param name="contentName">The filename, URL, unique script ID, or similar of the content being scanned.</param>
	/// <returns>The result of the scan.</returns>
	/// <exception cref="ArgumentNullException"><paramref name="buffer"/> is <see langword="null"/>.</exception>
	public static ScanResult Scan(byte[] buffer, string? contentName = null)
	{
		// Fail with a descriptive exception instead of a NullReferenceException on buffer.Length below.
		if (buffer is null)
		{
			throw new ArgumentNullException(nameof(buffer));
		}

		unsafe
		{
			// Pin the array for the duration of the scan so its address can be handed to the pointer-based overload.
			fixed (byte* bufferPtr = buffer)
			{
				return Scan((IntPtr)bufferPtr, (uint)buffer.Length, contentName);
			}
		}
	}

	/// <summary>Scans a buffer-full of content for malware.</summary>
	/// <param name="buffer">The buffer from which to read the data to be scanned.</param>
	/// <param name="bufferLen">The length, in bytes, of the data to be read from <c>buffer</c>.</param>
	/// <param name="contentName">The filename, URL, unique script ID, or similar of the content being scanned.</param>
	/// <returns>The result of the scan.</returns>
	public static ScanResult Scan(IntPtr buffer, uint bufferLen, string? contentName = null)
	{
		AMSI_RESULT result;

		// Read the property once so the instance that was null-checked is the instance that gets called.
		if (Provider is { } provider)
		{
			// NOTE(review): contentName is not forwarded on this path; only the raw memory is wrapped in the stream.
			// The trailing 'false' arguments presumably mark the handle/stream as non-owning — verify against the
			// SafeCoTaskMemHandle and AmsiStream constructors.
			using AmsiStream stream = new(new SafeCoTaskMemHandle(buffer, bufferLen, false), false);
			provider.Scan(stream, out result).ThrowIfFailed();
		}
		else
		{
			// System provider: scan within a short-lived session opened on the shared context.
			EnsureContext();
			using SafeHAMSISESSION session = new(hCtx);
			AmsiScanBuffer(session.Context, buffer, bufferLen, contentName, session, out result).ThrowIfFailed();
		}

		return result.Convert();
	}

	/// <summary>Scans a string for malware.</summary>
	/// <param name="str">The string to be scanned.</param>
	/// <param name="contentName">The filename, URL, unique script ID, or similar of the content being scanned.</param>
	/// <returns>The result of the scan.</returns>
	public static ScanResult Scan(string str, string? contentName = null)
	{
		AMSI_RESULT result;

		// Read the property once so the instance that was null-checked is the instance that gets called.
		if (Provider is { } provider)
		{
			// NOTE(review): contentName is not forwarded on this path; the string is copied into a
			// SafeCoTaskMemString and exposed to the provider through the stream.
			using AmsiStream stream = new(new SafeCoTaskMemString(str), false);
			provider.Scan(stream, out result).ThrowIfFailed();
		}
		else
		{
			// System provider: scan within a short-lived session opened on the shared context.
			EnsureContext();
			using SafeHAMSISESSION session = new(hCtx);
			AmsiScanString(session.Context, str, contentName, session, out result).ThrowIfFailed();
		}

		return result.Convert();
	}

	/// <summary>Scans a file for malware.</summary>
	/// <param name="file">The file from which to read the data to be scanned.</param>
	/// <returns>The result of the scan.</returns>
	/// <remarks>The entire file is read into memory and scanned as a buffer, using the file's full path as the content name.</remarks>
	/// <exception cref="ArgumentNullException"><paramref name="file"/> is <see langword="null"/>.</exception>
	public static ScanResult Scan(FileInfo file)
	{
		// Fail with a descriptive exception instead of a NullReferenceException on file.FullName below.
		if (file is null)
		{
			throw new ArgumentNullException(nameof(file));
		}

		string path = file.FullName;
		return Scan(File.ReadAllBytes(path), path);
	}

	/// <summary>Maps a raw <see cref="AMSI_RESULT"/> onto the coarser <see cref="ScanResult"/> values.</summary>
	/// <param name="result">The raw result reported by the scan.</param>
	/// <returns>
	/// <see cref="ScanResult.Clean"/> or <see cref="ScanResult.NotDetected"/> for the matching raw values,
	/// <see cref="ScanResult.Detected"/> for anything at or above <c>AMSI_RESULT_DETECTED</c>, and
	/// <see cref="ScanResult.PotentialDetected"/> for every other value.
	/// </returns>
	private static ScanResult Convert(this AMSI_RESULT result) => result switch
	{
		AMSI_RESULT.AMSI_RESULT_CLEAN => ScanResult.Clean,
		AMSI_RESULT.AMSI_RESULT_NOT_DETECTED => ScanResult.NotDetected,
		>= AMSI_RESULT.AMSI_RESULT_DETECTED => ScanResult.Detected,
		_ => ScanResult.PotentialDetected,
	};

	/// <summary>Initializes the shared AMSI context if it does not yet hold a valid handle.</summary>
	private static void EnsureContext()
	{
		if (hCtx.IsInvalid)
		{
			// A freshly generated GUID string is passed as the first argument to AmsiInitialize.
			AmsiInitialize(Guid.NewGuid().ToString(), out hCtx).ThrowIfFailed();
		}
	}
}