(* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
                         -****************+.
               .-=++*+::*###############*-
           :=+##*+=-::*##**############=
        .=*#*=:    .+#*=+############+.
       =##+.      +#+-=############*-:
     :*#+.      =+-.=#############- =#*.
    -##-      -=..+#############*:.  ---.
   -##:     .: .+################= :+#*-:.
  .##-        :==========*#####= :*#*- -#*
  =##                   -####*:-*#*-   .##:
  +#+                  -###*--*#*-      *#=
  *#+                 +###==*#*:        *#=
  +#*                *##+=*#*-          ##-
  :##.             .*##*##*:           -##.
   +#*            -#####*:             *#=
    *#+          =####+:             .*#+
     +#*.       +###+.              :##=
      -##+.    *##+:              :+#*:
        =*#+..*#+:             .-*#*-
          :-:#+::.        .:-=*##+:
           -+..+*###****####*+-.
          :.      ..:::::..
        ____                   _
       / ___| _ __   __ _ _ __| | __
       \___ \| '_ \ / _` | '__| |/ /
        ___) | |_) | (_| | |  |   <
       |____/| .__/ \__,_|_|  |_|\_\
             |_|   Game Toolkit

Copyright  2024-present tinyBigGAMES LLC
         All Rights Reserved.

Website: https://tinybiggames.com
Email  : support@tinybiggames.com

See LICENSE for license information
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *)

unit SGT.GenAI;

{$I SGT.Defines.inc}

interface

uses
  WinApi.Windows,
  System.Generics.Collections,
  System.SysUtils,
  System.Classes,
  System.IOUtils,
  System.Math,
  VCL.Forms,
  SGT.Deps,
  SGT.Deps.Ext,
  SGT.Core;

const
  // Chat role names. Presumably intended as the ARole argument of
  // TGenAI.AddMessage, whose value is substituted for the '{role}'
  // placeholder of a model template -- confirm against callers.
  ROLE_SYSTEM = 'system';
  ROLE_USER = 'user';
  ROLE_ASSISTANT = 'assistant';
  ROLE_TOOL = 'tool';

type
  // Polled once per iteration of the generation loop in RunInference;
  // returning True stops generation.
  TGenAI_InferenceCancelCallback = function(const AUserData: Pointer): Boolean;
  // Receives the text of each generated token as it is produced.
  TGenAI_InferenceTokenCallback = procedure(const AToken: string; const AUserData: Pointer);
  // Receives log/diagnostic text forwarded from the native log and cerr hooks.
  TGenAI_InfoCallback = procedure(const ALevel: Integer; const AText: string; const AUserData: Pointer);
  // Receives model-load progress; the Boolean result is handed back to the
  // native loader (presumably False aborts the load -- confirm in llama.h).
  TGenAI_LoadModelProgressCallback = function(const AModelName: string; const AProgress: Single; const AUserData: Pointer): Boolean;
  // Called by LoadModel after the model file load attempt with its outcome.
  TGenAI_LoadModelCallback = procedure(const AModelName: string; const ASuccess: Boolean; const AUserData: Pointer);
  // Called by RunInference just before tokenizing/generating.
  TGenAI_InferenceStartCallback = procedure(const AUserData: Pointer);
  // Called by RunInference when generation finishes (also on failure).
  TGenAI_InferenceEndCallback = procedure(const AUserData: Pointer);

  { TGenAI }
  // Local text-generation helper: keeps a registry of model definitions,
  // a chat message list, the loaded native model/context handles, and runs
  // token-by-token inference reporting through optional callbacks.
  TGenAI = class(TBaseObject)
  private type
    // Figures captured from the native timings after an inference run.
    // Speeds are computed as tokens per second (1e3 / milliseconds * count).
    TStats = record
      TokenInputSpeed: Double;
      TokenOutputSpeed: Double;
      InputTokens: Int32;
      OutputTokens: Int32;
      TotalTokens: Int32;
    end;
    // A callback routine paired with the opaque pointer handed back to it.
    TCallback<T> = record
      Handler: T;
      UserData: Pointer;
    end;
    // One slot per user-installable callback.
    TCallbacks = record
      InferenceCancel: TCallback<TGenAI_InferenceCancelCallback>;
      InferenceToken: TCallback<TGenAI_InferenceTokenCallback>;
      Info: TCallback<TGenAI_InfoCallback>;
      LoadModelProgress: TCallback<TGenAI_LoadModelProgressCallback>;
      LoadModel: TCallback<TGenAI_LoadModelCallback>;
      InferenceStart: TCallback<TGenAI_InferenceStartCallback>;
      InferenceEnd: TCallback<TGenAI_InferenceEndCallback>;
    end;
    // Settings persisted by SaveConfig/LoadConfig.
    TConfig = record
      ModelPath: string;      // folder combined with TModel.Filename on load
      NumGPULayers: Int32;    // copied to the native n_gpu_layers parameter
      NumThreads: Int32;      // copied to n_threads and n_threads_batch
    end;
    // One chat message; Role/Content fill the '{role}'/'{content}' placeholders.
    TMessage = record
      Role: string;
      Content: string;
    end;
    // A registered model definition, keyed by Name in the model list.
    TModel = record
      Filename: string;
      Name: string;
      MaxContext: UInt32;     // used as the context size (n_ctx)
      Template: string;       // per-message template with '{role}' and '{content}'
      TemplateEnd: string;    // appended once after all messages
      Stop: TArray<string>;   // stop sequences; stored only, not consulted by RunInference in this unit
    end;
    // State of the current/most recent inference run.
    TInference = record
      Active: Boolean;
      ModelName: string;
      Prompt: string;
      Response: string;
    end;
    TMessageList = TList<TMessage>;
    TModelList = TDictionary<string, TModel>;
  private
    FError: string;                        // last error text (see GetError)
    FCallbacks: TCallbacks;
    FModel: Pllama_model;                  // native model handle, nil when unloaded
    FModelParams: llama_model_params;
    FContext: Pllama_context;              // native context handle, nil when unloaded
    FContextParams: llama_context_params;
    FConfig: TConfig;
    FMessageList: TMessageList;
    FModelList: TModelList;
    FStats: TStats;
    FInference: TInference;
    //FTokenResponse: TTokenResponse;
    FLastUserMessage: string;              // content of the last message whose role contains 'user'
    FModelName: string;                    // name of the currently loaded model
  private
    //procedure OnAppIdle(Sender: TObject; var Done: Boolean);
    // Text <-> token conversion helpers built on the native tokenizer.
    function Tokenize(const AContext: Pllama_context; const AText: string; const AAddSpecial: Boolean; const AParseSpecial: Boolean=False): TVector<llama_token>;
    function TokenToPiece(const AContext: Pllama_context; const AToken: llama_token; const ASpecial: Boolean=True): string;
    //function ShouldAddBOSToken(const AModel: Pllama_model): Boolean;

    // Event dispatchers: invoke the installed callback; some fall back to
    // console behaviour when no callback is installed.
    function  OnInferenceCancel(): Boolean;
    procedure OnInferenceToken(const AToken: string);
    procedure OnInfo(const ALevel: Integer; const AText: string);
    function  OnLoadModelProgress(const AModelName: string; const AProgress: Single): Boolean;
    procedure OnLoadModel(const AModelName: string; const ASuccess: Boolean);
    procedure OnInferenceStart();
    procedure OnInferenceEnd();

  public
    constructor Create(); override;
    destructor Destroy(); override;

    // Error state: methods returning False record a message retrievable here.
    procedure ClearError();
    procedure SetError(const AMsg: string; const AArgs: array of const); overload;
    procedure SetError(const AText: string); overload;
    function  GetError(): string;

    // Callback accessors: each setter stores the handler with its user data;
    // each getter returns only the handler.
    function  GetInferenceCancelCallback(): TGenAI_InferenceCancelCallback;
    procedure SetInferenceCancelCallback(const AHandler: TGenAI_InferenceCancelCallback; const AUserData: Pointer);

    function  GetInferenceTokenCallback(): TGenAI_InferenceTokenCallback;
    procedure SetInferenceTokenCallback(const AHandler: TGenAI_InferenceTokenCallback; const AUserData: Pointer);

    function  GetInfoCallback(): TGenAI_InfoCallback;
    procedure SetInfoCallback(const AHandler: TGenAI_InfoCallback; const AUserData: Pointer);

    function  GetLoadModelProgressCallback(): TGenAI_LoadModelProgressCallback;
    procedure SetLoadModelProgressCallback(const AHandler: TGenAI_LoadModelProgressCallback; const AUserData: Pointer);

    function  GetLoadModelCallback(): TGenAI_LoadModelCallback;
    procedure SetLoadModelCallback(const AHandler: TGenAI_LoadModelCallback; const AUserData: Pointer);

    function  GetInferenceStartCallback(): TGenAI_InferenceStartCallback;
    procedure SetInferenceStartCallback(const AHandler: TGenAI_InferenceStartCallback; const AUserData: Pointer);

    function  GetInferenceEndCallback(): TGenAI_InferenceEndCallback;
    procedure SetInferenceEndCallback(const AHandler: TGenAI_InferenceEndCallback; const AUserData: Pointer);

    // Configuration (model folder, GPU layers, threads) and its JSON persistence.
    procedure InitConfig(const AModelPath: string; const ANumGPULayers, ANumThreads: Int32);
    function  SaveConfig(const AFilename: string): Boolean;
    function  LoadConfig(const AFilename: string): Boolean;

    // Chat message list used to build the inference prompt.
    procedure ClearAllMessages();
    function  AddMessage(const ARole, AContent: string): Int32;
    function  GetLastUserMessage(): string;
    function  BuildMessageInferencePrompt(const AModelName: string): string;

    // Model registry, its JSON persistence, and native model/context lifetime.
    procedure ClearModelDefines();
    function  DefineModel(const AModelFilename, AModelName: string; const AMaxContext: UInt32; const ATemplate, ATemplateEnd: string): Int32;
    function  SaveModelDefines(const AFilename: string): Boolean;
    function  LoadModelDefines(const AFilename: string): Boolean;
    procedure ClearModelStopSequences(const AModelName: string);
    function  AddModelStopSequence(const AModelName, AToken: string): Int32;
    function  GetModelStopSequenceCount(const AModelName: string): Int32;
    function  ResetContext(): Boolean;
    function  LoadModel(const AModelName: string): Boolean;
    function  IsModelLoaded(): Boolean;
    procedure UnloadModel();

    // Inference execution and result/statistics retrieval.
    function  RunInference(const AModelName: string; const AMaxTokens: UInt32): Boolean;
    function  IsInferenceActive(): Boolean;
    function  GetInferenceResponse(): string;
    procedure GetInferenceStats(ATokenInputSpeed: System.PSingle; ATokenOutputSpeed: System.PSingle; AInputTokens: PInt32; AOutputTokens: PInt32; ATotalTokens: PInt32);
  end;

var
  // Unit-level TTokenResponse instance (type declared in another unit).
  // NOTE(review): not referenced anywhere in this unit -- confirm it is
  // still used by other units before keeping it exported.
  TokenResponse: TTokenResponse;

implementation


{ TGenAI }
// Native (cdecl) model-load progress hook. AUserData carries the owning
// TGenAI instance; progress is forwarded to its OnLoadModelProgress together
// with the model name recorded in FInference. Returns True when no owner
// is attached.
function TGenAI_ModelLoadProgressCallback(AProgress: single; AUserData: pointer): Boolean; cdecl;
var
  LOwner: TGenAI;
begin
  Result := True;
  LOwner := AUserData;
  if LOwner <> nil then
    Result := LOwner.OnLoadModelProgress(LOwner.FInference.ModelName, AProgress);
end;

// Native (cdecl) log hook: forwards the UTF-8 log text and its level to the
// TGenAI instance carried in AUserData; does nothing without an instance.
procedure TGenAI_LogCallback(ALevel: ggml_log_level; const AText: PUTF8Char; AUserData: Pointer); cdecl;
begin
  if AUserData = nil then
    Exit;

  TGenAI(AUserData).OnInfo(ALevel, Utf8ToString(AText));
end;

// Native (cdecl) hook for redirected cerr output: reports the UTF-8 text to
// the TGenAI instance in AUserData at error log level; no-op without one.
procedure TGenAI_CErrCallback(const AText: PUTF8Char; AUserData: Pointer); cdecl;
begin
  if AUserData = nil then
    Exit;

  TGenAI(AUserData).OnInfo(GGML_LOG_LEVEL_ERROR, Utf8ToString(AText));
end;

(*
procedure TGenAI.OnAppIdle(Sender: TObject; var Done: Boolean);
begin
  Done := False;
  //Application.ProcessMessages();
end;
*)

// Converts AText into model tokens using the tokenizer of AContext's model.
//   AAddSpecial   - passed through to llama_tokenize (add special tokens).
//   AParseSpecial - passed through to llama_tokenize (parse special tokens).
// Returns a newly created vector owned by the caller, or nil on failure
// (the error text is recorded via SetError).
function TGenAI.Tokenize(const AContext: Pllama_context; const AText: string; const AAddSpecial: Boolean; const AParseSpecial: Boolean): TVector<llama_token>;
var
  LNumTokens: Integer;
  LResult: TVector<llama_token>;
  LText: UTF8String;
  LTokens: TArray<llama_token>;
begin
  Result := nil;

  try
    LText := UTF8Encode(AText);

    // Upper limit for the number of tokens. Keep at least one slot so that
    // @LTokens[0] is a valid address even for an empty text.
    LNumTokens := Length(LText) + 2 * Ord(AAddSpecial);
    SetLength(LTokens, Max(LNumTokens, 1));

    LNumTokens := llama_tokenize(llama_get_model(AContext), PUTF8Char(LText), Length(LText), @LTokens[0], Length(LTokens), AAddSpecial, AParseSpecial);

    if LNumTokens < 0 then
    begin
      // A negative result is the required capacity; retry with that size.
      SetLength(LTokens, -LNumTokens);
      LNumTokens := llama_tokenize(llama_get_model(AContext), PUTF8Char(LText), Length(LText), @LTokens[0], Length(LTokens), AAddSpecial, AParseSpecial);

      // The retry must fill the buffer exactly (positive count).
      if LNumTokens <> Length(LTokens) then
        raise Exception.CreateFmt('[Tokenize] Unexpected token count %d (expected %d)', [LNumTokens, Length(LTokens)]);
    end
    else
    begin
      SetLength(LTokens, LNumTokens);
    end;

    LResult := TVector<llama_token>.Create;
    try
      LResult.Resize(LNumTokens);
      // Nothing to copy (and no LTokens[0]) when the text produced no tokens.
      if LNumTokens > 0 then
        Move(LTokens[0], LResult.Data^, LNumTokens * SizeOf(llama_token));
    except
      // Do not leak the vector when filling it fails.
      LResult.Free();
      raise;
    end;

    Result := LResult;
  except
    on E: Exception do
    begin
      SetError(E.Message);
      Result := nil;
    end;
  end;
end;

// Returns the text piece for AToken using the model of AContext.
// A small buffer is tried first; a negative result from the native call is
// the required size, in which case the call is repeated with a larger buffer.
// Returns an empty string on failure (error text recorded via SetError).
function TGenAI.TokenToPiece(const AContext: Pllama_context; const AToken: llama_token; const ASpecial: Boolean): string;
var
  LTokens: Int32;
  LCheck: Int32;
  LBuffer: TArray<UTF8Char>;
begin
  // A string Result is not guaranteed to start empty; make the failure
  // path return '' instead of whatever the caller's variable held.
  Result := '';

  try
    // 8 payload bytes plus room for the terminating #0.
    SetLength(LBuffer, 9);
    LTokens := llama_token_to_piece(llama_get_model(AContext), AToken, @LBuffer[0], 8, 0, ASpecial);
    if LTokens < 0 then
      begin
        SetLength(LBuffer, (-LTokens)+1);
        LCheck := llama_token_to_piece(llama_get_model(AContext), AToken, @LBuffer[0], -LTokens, 0, ASpecial);
        Assert(LCheck = -LTokens);
        LBuffer[-LTokens] := #0;
      end
    else
      begin
        LBuffer[LTokens] := #0;
      end;
    Result := UTF8ToString(@LBuffer[0]);
  except
    on E: Exception do
    begin
      SetError(E.Message);
      Result := '';
    end;
  end;
end;

(*
function TLMEngine.ShouldAddBOSToken(const AModel: Pllama_model): Boolean;
var
  LAddBOS: Integer;
begin
  LAddBOS := llama_add_bos_token(AModel);
  if LAddBOS <> -1 then
    Result := Boolean(LAddBOS)
  else
    Result := llama_vocab_type(AModel) = LLAMA_VOCAB_TYPE_SPM;
end;
*)

// Asks whether the running inference should stop. Uses the installed cancel
// callback when present; otherwise reports whether ESC was released on the
// console.
function TGenAI.OnInferenceCancel(): Boolean;
begin
  if not Assigned(FCallbacks.InferenceCancel.Handler) then
    Exit(Console.WasKeyReleased(VK_ESCAPE));

  Result := FCallbacks.InferenceCancel.Handler(FCallbacks.InferenceCancel.UserData);
end;

// Delivers one generated token: to the installed token callback when there
// is one, otherwise printed to the console.
procedure TGenAI.OnInferenceToken(const AToken: string);
begin
  if not Assigned(FCallbacks.InferenceToken.Handler) then
  begin
    Console.Print(AToken);
    Exit;
  end;

  FCallbacks.InferenceToken.Handler(AToken, FCallbacks.InferenceToken.UserData);
end;

// Forwards a log/diagnostic line to the info callback; silently dropped when
// no callback is installed.
procedure TGenAI.OnInfo(const ALevel: Integer; const AText: string);
begin
  if not Assigned(FCallbacks.Info.Handler) then
    Exit;

  FCallbacks.Info.Handler(ALevel, AText, FCallbacks.Info.UserData);
end;

// Reports model-load progress (AProgress is multiplied by 100 for display).
// With a callback installed its result is returned; otherwise a progress
// line is printed to the console, cleared once progress reaches 1, and True
// is returned.
function  TGenAI.OnLoadModelProgress(const AModelName: string; const AProgress: Single): Boolean;
begin
  if Assigned(FCallbacks.LoadModelProgress.Handler) then
    Exit(FCallbacks.LoadModelProgress.Handler(AModelName, AProgress, FCallbacks.LoadModelProgress.UserData));

  Console.Print(Console.CR+'Loading model "%s" (%3.2f%s)...', [AModelName, AProgress*100, '%'], Console.FG_CYAN);
  if AProgress >= 1 then
    Console.ClearLine(Console.FG_WHITE);

  Result := True;
end;

// Notifies the load-model callback, if any, of the outcome of a model load.
procedure TGenAI.OnLoadModel(const AModelName: string; const ASuccess: Boolean);
begin
  if not Assigned(FCallbacks.LoadModel.Handler) then
    Exit;

  FCallbacks.LoadModel.Handler(AModelName, ASuccess, FCallbacks.LoadModel.UserData);
end;

// Notifies the inference-start callback, if one is installed.
procedure TGenAI.OnInferenceStart();
begin
  if not Assigned(FCallbacks.InferenceStart.Handler) then
    Exit;

  FCallbacks.InferenceStart.Handler(FCallbacks.InferenceStart.UserData);
end;

// Notifies the inference-end callback, if one is installed.
procedure TGenAI.OnInferenceEnd();
begin
  if not Assigned(FCallbacks.InferenceEnd.Handler) then
    Exit;

  FCallbacks.InferenceEnd.Handler(FCallbacks.InferenceEnd.UserData);
end;

// Creates the instance with an empty model registry and message list.
constructor TGenAI.Create();
begin
  inherited;

  // Containers owned by this instance; released in Destroy.
  FModelList := TModelList.Create();
  FMessageList := TMessageList.Create();

  // Application idle hook is intentionally not installed:
  //Application.Initialize;
  //Application.OnIdle := OnAppIdle;
end;

// Unloads any loaded model and releases the owned containers.
destructor TGenAI.Destroy();
begin
  UnloadModel();

  // Free is nil-safe, so no Assigned checks are needed.
  FModelList.Free();
  FMessageList.Free();

  // The global Application.OnIdle handler is deliberately left untouched:
  // the constructor does not install one (that code is disabled), so
  // clearing it here would wipe a handler owned by the host application.

  inherited;
end;

// Resets the stored error text to an empty string.
procedure TGenAI.ClearError();
begin
  Self.FError := string.Empty;
end;

// Stores an error text built by formatting AMsg with AArgs (Format syntax).
procedure TGenAI.SetError(const AMsg: string; const AArgs: array of const);
var
  LText: string;
begin
  LText := Format(AMsg, AArgs);
  FError := LText;
end;

// Stores AText verbatim as the current error text.
procedure TGenAI.SetError(const AText: string);
begin
  Self.FError := AText;
end;

// Returns the currently stored error text ('' when none is recorded).
function  TGenAI.GetError(): string;
begin
  Exit(FError);
end;

// Returns the installed inference-cancel handler (nil when none is set).
function  TGenAI.GetInferenceCancelCallback(): TGenAI_InferenceCancelCallback;
begin
  Result := Self.FCallbacks.InferenceCancel.Handler;
end;

// Installs the inference-cancel handler together with the user data pointer
// that is passed back to it on every call.
procedure TGenAI.SetInferenceCancelCallback(const AHandler: TGenAI_InferenceCancelCallback; const AUserData: Pointer);
begin
  FCallbacks.InferenceCancel.UserData := AUserData;
  FCallbacks.InferenceCancel.Handler := AHandler;
end;

// Returns the installed inference-token handler (nil when none is set).
function  TGenAI.GetInferenceTokenCallback(): TGenAI_InferenceTokenCallback;
begin
  Result := Self.FCallbacks.InferenceToken.Handler;
end;

// Installs the inference-token handler together with the user data pointer
// that is passed back to it on every call.
procedure TGenAI.SetInferenceTokenCallback(const AHandler: TGenAI_InferenceTokenCallback; const AUserData: Pointer);
begin
  FCallbacks.InferenceToken.UserData := AUserData;
  FCallbacks.InferenceToken.Handler := AHandler;
end;

// Returns the installed info handler (nil when none is set).
function  TGenAI.GetInfoCallback(): TGenAI_InfoCallback;
begin
  Result := Self.FCallbacks.Info.Handler;
end;

// Installs the info handler together with the user data pointer that is
// passed back to it on every call.
procedure TGenAI.SetInfoCallback(const AHandler: TGenAI_InfoCallback; const AUserData: Pointer);
begin
  FCallbacks.Info.UserData := AUserData;
  FCallbacks.Info.Handler := AHandler;
end;

// Returns the installed load-progress handler (nil when none is set).
function  TGenAI.GetLoadModelProgressCallback(): TGenAI_LoadModelProgressCallback;
begin
  Result := Self.FCallbacks.LoadModelProgress.Handler;
end;

// Installs the load-progress handler together with the user data pointer
// that is passed back to it on every call.
procedure TGenAI.SetLoadModelProgressCallback(const AHandler: TGenAI_LoadModelProgressCallback; const AUserData: Pointer);
begin
  FCallbacks.LoadModelProgress.UserData := AUserData;
  FCallbacks.LoadModelProgress.Handler := AHandler;
end;

// Returns the installed load-model handler (nil when none is set).
function  TGenAI.GetLoadModelCallback(): TGenAI_LoadModelCallback;
begin
  Result := Self.FCallbacks.LoadModel.Handler;
end;

// Installs the load-model handler together with the user data pointer that
// is passed back to it on every call.
procedure TGenAI.SetLoadModelCallback(const AHandler: TGenAI_LoadModelCallback; const AUserData: Pointer);
begin
  FCallbacks.LoadModel.UserData := AUserData;
  FCallbacks.LoadModel.Handler := AHandler;
end;

// Returns the installed inference-start handler (nil when none is set).
function  TGenAI.GetInferenceStartCallback(): TGenAI_InferenceStartCallback;
begin
  Result := Self.FCallbacks.InferenceStart.Handler;
end;

// Installs the inference-start handler together with the user data pointer
// that is passed back to it on every call.
procedure TGenAI.SetInferenceStartCallback(const AHandler: TGenAI_InferenceStartCallback; const AUserData: Pointer);
begin
  FCallbacks.InferenceStart.UserData := AUserData;
  FCallbacks.InferenceStart.Handler := AHandler;
end;

// Returns the installed inference-end handler (nil when none is set).
function  TGenAI.GetInferenceEndCallback(): TGenAI_InferenceEndCallback;
begin
  Result := Self.FCallbacks.InferenceEnd.Handler;
end;

// Installs the inference-end handler together with the user data pointer
// that is passed back to it on every call.
procedure TGenAI.SetInferenceEndCallback(const AHandler: TGenAI_InferenceEndCallback; const AUserData: Pointer);
begin
  FCallbacks.InferenceEnd.UserData := AUserData;
  FCallbacks.InferenceEnd.Handler := AHandler;
end;

// Stores the runtime configuration.
//   AModelPath    - folder that model filenames are resolved against.
//   ANumGPULayers - negative means "as many as possible" (MaxInt).
//   ANumThreads   - negative means "as many as possible"; the stored value is
//                   clamped to 1..physical processor count.
procedure TGenAI.InitConfig(const AModelPath: string; const ANumGPULayers, ANumThreads: Int32);
var
  LLayers: Int32;
  LThreads: Int32;
begin
  FConfig.ModelPath := AModelPath;

  // Negative inputs request the maximum before clamping.
  LLayers := ANumGPULayers;
  if LLayers < 0 then
    LLayers := MaxInt;

  LThreads := ANumThreads;
  if LThreads < 0 then
    LThreads := MaxInt;

  FConfig.NumGPULayers := EnsureRange(LLayers, 0, MaxInt);
  FConfig.NumThreads := EnsureRange(LThreads, 1, Utils.GetPhysicalProcessorCount());
end;

// Writes the current configuration (model_path, gpu_layers, threads) as
// UTF-8 JSON. The file extension of AFilename is forced to '.json'.
// Returns True when the file exists after writing; on failure returns False
// and records the error text.
function  TGenAI.SaveConfig(const AFilename: string): Boolean;
var
  LRoot: TJsonObject;
  LTarget: string;
begin
  if AFilename.IsEmpty then
  begin
    SetError('[%s] %s', ['SaveConfig', 'Filename can not be blank']);
    Exit(False);
  end;

  Result := False;
  LTarget := TPath.ChangeExtension(AFilename, 'json');

  try
    LRoot := TJsonObject.Create();
    try
      LRoot.S['model_path'] := FConfig.ModelPath;
      LRoot.I['gpu_layers'] := FConfig.NumGPULayers;
      LRoot.I['threads'] := FConfig.NumThreads;

      TFile.WriteAllText(LTarget, LRoot.Format(), TEncoding.UTF8);

      Result := TFile.Exists(LTarget);
    finally
      LRoot.Free();
    end;
  except
    on E: Exception do
    begin
      SetError('[%s] %s', ['SaveConfig', E.Message]);
      Result := False;
    end;
  end;
end;

// Reads a configuration JSON file (extension forced to '.json') and applies
// it through InitConfig. The fields "model_path", "gpu_layers" and "threads"
// are all required. Returns False and records the error text when the file
// is missing, a field is absent, or reading/parsing raises.
function  TGenAI.LoadConfig(const AFilename: string): Boolean;
var
  LFilename: string;
  LJson: TJsonObject;
  LConfig: TConfig;
begin
  Result := False;
  LConfig := Default(TConfig);

  LFilename := TPath.ChangeExtension(AFilename, 'json');

  if not TFile.Exists(LFilename) then
  begin
    SetError('[%s] File was not found: %s', ['LoadConfig', LFilename]);
    Exit;
  end;

  try
    LJson := TJsonObject.Parse(TFile.ReadAllText(LFilename, TEncoding.UTF8));

    try
      if not LJson.Contains('model_path') then
      begin
        SetError('[%s] "model_path" field was not found', ['LoadConfig']);
        Exit;
      end;
      LConfig.ModelPath := LJson.S['model_path'];

      if not LJson.Contains('gpu_layers') then
      begin
        SetError('[%s] "gpu_layers" field was not found', ['LoadConfig']);
        Exit;
      end;
      LConfig.NumGPULayers := LJson.I['gpu_layers'];

      if not LJson.Contains('threads') then
      begin
        SetError('[%s] "threads" field was not found', ['LoadConfig']);
        Exit;
      end;
      // Store the thread count in NumThreads (previously it overwrote
      // NumGPULayers and left NumThreads uninitialised).
      LConfig.NumThreads := LJson.I['threads'];

      InitConfig(LConfig.ModelPath, LConfig.NumGPULayers, LConfig.NumThreads);

      Result := True;

    finally
      LJson.Free();
    end;
  except
    on E: Exception do
    begin
      SetError('[%s] %s', ['LoadConfig', E.Message]);
      Result := False;
    end;
  end;
end;

// Removes every message from the chat message list.
procedure TGenAI.ClearAllMessages();
begin
  Self.FMessageList.Clear();
end;

// Appends a message with the given role and content and returns the new
// message count. When the role contains 'user', the content is also
// remembered as the last user message.
function TGenAI.AddMessage(const ARole, AContent: string): Int32;
var
  LEntry: TMessage;
begin
  LEntry.Content := AContent;
  LEntry.Role := ARole;
  FMessageList.Add(LEntry);

  Result := FMessageList.Count;

  if not Utils.ContainsText(ARole, 'user') then
    Exit;

  FLastUserMessage := AContent;
end;

// Returns the content of the most recently added message whose role
// contained 'user'.
function  TGenAI.GetLastUserMessage(): string;
begin
  Exit(FLastUserMessage);
end;

// Builds the prompt for AModelName from the current message list: each
// message is rendered through the model template ('{role}' and '{content}'
// placeholders, result trimmed) and the model's TemplateEnd is appended once.
// The prompt is always built from scratch, stored in FInference.Prompt and
// returned. Returns '' when AModelName is not a defined model.
function  TGenAI.BuildMessageInferencePrompt(const AModelName: string): string;
var
  LModel: TModel;
  LMessage: TMessage;
  LPrompt: string;
begin
  // Start empty so repeated calls do not accumulate earlier prompts and an
  // unknown model does not return a stale prompt.
  LPrompt := '';

  if FModelList.TryGetValue(AModelName, LModel) then
  begin
    for LMessage in FMessageList do
    begin
      LPrompt := LPrompt + LModel.Template.Replace('{role}', LMessage.Role).Replace('{content}', LMessage.Content).Trim;
    end;
    LPrompt := LPrompt + LModel.TemplateEnd;
  end;

  FInference.Prompt := LPrompt;
  Result := LPrompt;
end;

// Removes every model definition from the registry.
procedure TGenAI.ClearModelDefines();
begin
  Self.FModelList.Clear();
end;

// Registers (or replaces) the model definition keyed by AModelName and
// returns the number of definitions in the registry. A replaced definition
// starts with no stop sequences.
function TGenAI.DefineModel(const AModelFilename, AModelName: string; const AMaxContext: UInt32; const ATemplate, ATemplateEnd: string): Int32;
var
  LDefinition: TModel;
begin
  LDefinition := Default(TModel);
  LDefinition.Name := AModelName;
  LDefinition.Filename := AModelFilename;
  LDefinition.MaxContext := AMaxContext;
  LDefinition.TemplateEnd := ATemplateEnd;
  LDefinition.Template := ATemplate;

  FModelList.AddOrSetValue(AModelName, LDefinition);

  Result := FModelList.Count;
end;

// Writes all model definitions as UTF-8 JSON: a "models" array whose entries
// carry filename, name, max_context, template and template_end (stop
// sequences are not written). The extension of AFilename is forced to
// '.json'. Returns True when the file exists after writing; on failure
// returns False and records the error text.
function  TGenAI.SaveModelDefines(const AFilename: string): Boolean;
var
  LTarget: string;
  LJson: TJsonObject;
  LObject: TJsonObject;
  LModel: TPair<string, TModel>;
begin
  if AFilename.IsEmpty then
  begin
    SetError('[%s] %s', ['SaveModelDefines', 'Filename can not be blank']);
    Exit(False);
  end;

  Result := False;
  LTarget := TPath.ChangeExtension(AFilename, 'json');

  try
    LJson := TJsonObject.Create();
    try
      // One JSON object per registered model, appended to the "models" array.
      with LJson.AddArray('models') do
        for LModel in FModelList do
        begin
          LObject := TJsonObject.Create();
          LObject.S['filename'] := LModel.Value.Filename;
          LObject.S['name'] := LModel.Value.Name;
          LObject.I['max_context'] := LModel.Value.MaxContext;
          LObject.S['template'] := LModel.Value.Template;
          LObject.S['template_end'] := LModel.Value.TemplateEnd;
          Add(LObject);
        end;

      TFile.WriteAllText(LTarget, LJson.Format(), TEncoding.UTF8);

      Result := TFile.Exists(LTarget);
    finally
      LJson.Free();
    end;
  except
    on E: Exception do
    begin
      SetError('[%s] %s', ['SaveModelDefines', E.Message]);
      Result := False;
    end;
  end;
end;

// Replaces the model registry with the definitions read from a JSON file
// (extension forced to '.json'). The file must contain a "models" array whose
// entries each provide filename, name, max_context, template and
// template_end. The whole file is validated first; the existing registry is
// only cleared and replaced once every entry was read successfully, so a
// missing or malformed file leaves the current definitions untouched.
// Returns False and records the error text on any failure.
function  TGenAI.LoadModelDefines(const AFilename: string): Boolean;
var
  LFilename: string;
  LJson: TJsonObject;
  LModel: TModel;
  LEntry: TModel;
  LPending: TList<TModel>;
  I, LCount: Integer;
begin
  Result := False;
  LModel := Default(TModel);

  LFilename := TPath.ChangeExtension(AFilename, 'json');

  if not TFile.Exists(LFilename) then
  begin
    SetError('[%s] File was not found: %s', ['LoadModelDefines', LFilename]);
    Exit;
  end;

  try
    // Definitions are staged here and committed only after full validation.
    LPending := TList<TModel>.Create();
    try
      LJson := TJsonObject.Parse(TFile.ReadAllText(LFilename, TEncoding.UTF8));
      try
        if not LJson.Contains('models') then
        begin
          SetError('[%s] "models" field was not found', ['LoadModelDefines']);
          Exit;
        end;

        LCount := LJson.A['models'].Count;

        for I := 0 to LCount-1 do
        begin
          if LJson.A['models'].Items[I].FindValue('filename') <> nil then
            begin
              LModel.Filename := LJson.A['models'].Items[I].FindValue('filename').Value;
            end
          else
            begin
              SetError('[%s] "filename" field was not found', ['LoadModelDefines']);
              Exit;
            end;

          if LJson.A['models'].Items[I].FindValue('name') <> nil then
            begin
              LModel.Name := LJson.A['models'].Items[I].FindValue('name').Value;
            end
          else
            begin
              SetError('[%s] "name" field was not found', ['LoadModelDefines']);
              Exit;
            end;

          if LJson.A['models'].Items[I].FindValue('max_context') <> nil then
            begin
              LModel.MaxContext := LJson.A['models'].Items[I].FindValue('max_context').Value.ToInt64;
            end
          else
            begin
              SetError('[%s] "max_context" field was not found', ['LoadModelDefines']);
              Exit;
            end;

          if LJson.A['models'].Items[I].FindValue('template') <> nil then
            begin
              LModel.Template := LJson.A['models'].Items[I].FindValue('template').Value;
            end
          else
            begin
              SetError('[%s] "template" field was not found', ['LoadModelDefines']);
              Exit;
            end;

          if LJson.A['models'].Items[I].FindValue('template_end') <> nil then
            begin
              LModel.TemplateEnd := LJson.A['models'].Items[I].FindValue('template_end').Value;
            end
          else
            begin
              SetError('[%s] "template_end" field was not found', ['LoadModelDefines']);
              Exit;
            end;

          LPending.Add(LModel);
        end;
      finally
        LJson.Free();
      end;

      // Everything validated: now replace the registry in one go.
      ClearModelDefines();
      for LEntry in LPending do
        DefineModel(LEntry.Filename, LEntry.Name, LEntry.MaxContext, LEntry.Template, LEntry.TemplateEnd);

      Result := True;
    finally
      LPending.Free();
    end;
  except
    on E: Exception do
    begin
      SetError('[%s] %s', ['LoadModelDefines', E.Message]);
      Result := False;
    end;
  end;
end;

// Removes all stop sequences from the named model definition; does nothing
// when the model is not defined.
procedure TGenAI.ClearModelStopSequences(const AModelName: string);
var
  LDefinition: TModel;
begin
  if not FModelList.TryGetValue(AModelName, LDefinition) then
    Exit;

  // Records are stored by value, so write the modified copy back.
  LDefinition.Stop := nil;
  FModelList.AddOrSetValue(AModelName, LDefinition);
end;

// Appends AToken to the stop sequences of the named model definition.
// Returns the model's new stop-sequence count, or -1 when AModelName is not
// a defined model.
function  TGenAI.AddModelStopSequence(const AModelName, AToken: string): Int32;
var
  LModel: TModel;
  I: Integer;
begin
  Result := -1;
  if FModelList.TryGetValue(AModelName, LModel) then
  begin
    I := Length(LModel.Stop);
    SetLength(LModel.Stop, I + 1);
    LModel.Stop[I] := AToken;
    // Records are stored by value, so write the modified copy back.
    FModelList.AddOrSetValue(AModelName, LModel);
    // Report success (previously -1 was returned even after adding).
    Result := Length(LModel.Stop);
  end;
end;
// Returns how many stop sequences the named model definition holds;
// 0 when the model is not defined.
function  TGenAI.GetModelStopSequenceCount(const AModelName: string): Int32;
var
  LDefinition: TModel;
begin
  if not FModelList.TryGetValue(AModelName, LDefinition) then
    Exit(0);

  Result := Length(LDefinition.Stop);
end;

// Discards the current native context and creates a fresh one for the
// loaded model, using the model's MaxContext and the configured thread count.
// Returns False when no model is loaded, when the loaded model's definition
// no longer exists (the current context is then left intact), or when the
// new context cannot be created (the model is then fully unloaded).
function TGenAI.ResetContext(): Boolean;
var
  LModel: TModel;
begin
  Result := False;
  if not IsModelLoaded() then Exit;

  // Validate the model name BEFORE tearing down the live context, so a
  // failed lookup does not leave a model without a context.
  if not FModelList.TryGetValue(FModelName, LModel) then
  begin
    SetError('[%s] Model not found: "%s"', ['ResetContext', FModelName]);
    Exit;
  end;

  llama_free(FContext);
  FContext := nil;

  FContextParams := llama_context_default_params();
  //FContextParams.flash_attn := true;
  FContextParams.offload_kqv := true;
  FContextParams.seed  := 1234;
  FContextParams.n_ctx := LModel.MaxContext;
  //FContextParams.n_threads := Utils.GetPhysicalProcessorCount();
  FContextParams.n_threads := FConfig.NumThreads;
  FContextParams.n_threads_batch := FContextParams.n_threads;
  FContext := llama_new_context_with_model(FModel, FContextParams);
  if not Assigned(FContext) then
  begin
    SetError('[ResetContext] Failed to reset context for model: "%s"', [LModel.Filename]);

    // Without a context IsModelLoaded() is False and UnloadModel() would
    // skip the cleanup, so release the model here to avoid leaking it.
    llama_free_model(FModel);
    FModel := nil;
    llama_backend_free();
    restore_cerr();
    FModelName := '';
    Exit;
  end;
  Result := True;
end;

// Loads the model defined under AModelName and creates its context.
// - Returns True immediately when that model is already loaded.
// - When a different model is loaded it is unloaded first.
// - The model file is resolved as FConfig.ModelPath + definition filename.
// The load-model callback is notified with the outcome of the file load.
// Returns False and records the error text when the definition or file is
// missing, or the model/context cannot be created.
function  TGenAI.LoadModel(const AModelName: string): Boolean;
var
  LModel: TModel;
  LFilename: string;
begin
  Result := False;

  try
    // check for valid model name
    if not FModelList.TryGetValue(AModelName, LModel) then
    begin
      SetError('[%s] Model not found: "%s"', ['LoadModel', AModelName]);
      Exit;
    end;

    // Model already loaded
    if IsModelLoaded() then
    begin
      // Compare against the name of the model that is actually loaded.
      // (LModel was just looked up by AModelName, so comparing LModel.Name
      // would always match and never switch models.)
      if FModelName = AModelName then
      begin
        Result := True;
        Exit;
      end;

      // currently loaded model is not AModelName, so unload and load requested one
      UnloadModel();
    end;

    LFilename := TPath.Combine(FConfig.ModelPath, LModel.Filename);
    if not TFile.Exists(LFilename) then
    begin
      SetError('[LoadModel] Model file was not found: "%s"', [LFilename]);
      Exit;
    end;
    FInference.ModelName := AModelName;

    redirect_cerr_to_callback(TGenAI_CErrCallback, Self);
    llama_log_set(TGenAI_LogCallback, Self);
    llama_backend_init();
    llama_numa_init(GGML_NUMA_STRATEGY_DISTRIBUTE);

    FModelParams := llama_model_default_params();
    FModelParams.progress_callback_user_data := Self;
    FModelParams.progress_callback := TGenAI_ModelLoadProgressCallback;
    FModelParams.n_gpu_layers := FConfig.NumGPULayers;
    FModel := llama_load_model_from_file(Utils.AsUTF8(LFilename), FModelParams);
    if not Assigned(FModel) then
    begin
      OnLoadModel(FInference.ModelName, False);
      llama_backend_free();
      // Undo the cerr redirection: UnloadModel() will not run for a model
      // that never finished loading.
      restore_cerr();
      SetError('[LoadModel] Failed to load model file: "%s"', [LFilename]);
      Exit;
    end;
    OnLoadModel(FInference.ModelName, True);

    FContextParams := llama_context_default_params();
    FContextParams.flash_attn := true;
    FContextParams.offload_kqv := true;
    FContextParams.seed  := 1234;
    FContextParams.n_ctx := LModel.MaxContext;
    //FContextParams.n_threads := Utils.GetPhysicalProcessorCount();
    FContextParams.n_threads := FConfig.NumThreads;
    FContextParams.n_threads_batch := FContextParams.n_threads;
    FContext := llama_new_context_with_model(FModel, FContextParams);
    if not Assigned(FContext) then
    begin
      llama_free_model(FModel);
      // Do not keep a dangling handle to the freed model.
      FModel := nil;
      llama_backend_free();
      restore_cerr();
      SetError('[LoadModel] Failed to create context for model: "%s"', [LFilename]);
      Exit;
    end;

    FModelName := AModelName;
    Result := True;
  except
    on E: Exception do
    begin
      SetError(E.Message);
      Result := False;
    end;
  end;
end;

// True only when both the native model and its context exist.
function  TGenAI.IsModelLoaded(): Boolean;
begin
  Result := Assigned(FModel);
  if Result then
    Result := Assigned(FContext);
end;

// Releases the loaded model: frees the context, then the model, shuts the
// backend down, restores cerr and clears the loaded model name. Does nothing
// unless a model is fully loaded (see IsModelLoaded).
procedure TGenAI.UnloadModel();
begin
  if IsModelLoaded() then
  begin
    // Context first, then the model it was created from.
    llama_free(FContext);
    FContext := nil;

    llama_free_model(FModel);
    FModel := nil;

    llama_backend_free();
    restore_cerr();

    FModelName := '';
  end;
end;

// Runs a blocking text generation for AModelName using the current message
// list as the prompt.
//   AMaxTokens - maximum number of tokens to generate.
// Loads the model if needed, builds and tokenizes the prompt, feeds it to
// the context in batches, then samples tokens one at a time until an
// end-of-generation token, the token limit, or a cancel request. Each token
// is appended to the response and reported through OnInferenceToken.
// OnInferenceStart/OnInferenceEnd bracket the run; usage statistics are
// captured at the end.
// Returns True when generation ran to completion (including cancel); False
// with the error text recorded when inference is already active, the model
// cannot be loaded, the prompt is empty or too long, tokenizing fails, or
// decoding fails.
function  TGenAI.RunInference(const AModelName: string; const AMaxTokens: UInt32): Boolean;
var
  LPast: UInt32;
  LRemain: UInt32;
  LConsumed: UInt32;
  LSamplingContext: Pointer;
  I: UInt32;
  LPredict: UInt32;
  LBatch: UInt32;
  LEval: UInt32;
  LId: llama_token;
  LMaxEmbedSize: UInt32;
  LSkippedTokens: UInt32;
  LEmbedInput: TVector<llama_token>;
  LEmbed: TVector<llama_token>;
  LTimings: llama_timings;
  LTokenStr: string;
  LFirstToken: Boolean;
  LDecodeFailed: Boolean;

begin
  Result := False;

  try
    // check if inference is already running
    if FInference.Active then
    begin
      SetError('[%s] Inference already active', ['RunInference']);
      Exit;
    end;

    // start new inference
    FInference := Default(TInference);

    // check if model not loaded
    if not LoadModel(AModelName) then
    begin
      Exit;
    end;

    // build prompt message
    FInference.Prompt := BuildMessageInferencePrompt(AModelName);
    if FInference.Prompt.IsEmpty then
    begin
      SetError('[%s] Inference prompt was empty', ['RunInference']);
      Exit;
    end;

    FInference.Active := True;
    FInference.Response := '';

    OnInferenceStart();
    try
      LEmbedInput := Tokenize(FContext, FInference.Prompt, true, true);
      try
        // Tokenize returns nil on failure and has already recorded the error.
        if not Assigned(LEmbedInput) then
          Exit;

        if LEmbedInput.empty() then
          LEmbedInput.Add(llama_token_bos(FModel));

        // Keep a small margin below the context size.
        LMaxEmbedSize := llama_n_ctx(FContext) - 4;
        if LEmbedInput.Count() > LMaxEmbedSize then
        begin
          LSkippedTokens := LEmbedInput.count() - LMaxEmbedSize;
          SetError('[%s] Input too long: %d tokens over max context of %d', ['RunInference', LSkippedTokens, LMaxEmbedSize]);
          Exit;
        end;

        LEmbed := TVector<llama_token>.Create();
        try
          LSamplingContext := _llama_sampling_init();
          try
            LPredict := AMaxTokens;
            LBatch := FContextParams.n_ubatch;

            LPast := 0;
            LRemain := LPredict;
            LConsumed := 0;
            LFirstToken := True;
            LDecodeFailed := False;

            llama_reset_timings(FContext);
            while LRemain <> 0 do
            begin
              if OnInferenceCancel() then
              begin
                Break;
              end;

              // Evaluate whatever is pending (prompt chunk or last sampled token).
              if LEmbed.Count <> 0 then
              begin
                I := 0;
                while I < LEmbed.Count do
                begin
                  LEval := LEmbed.Count - I;
                  if LEval > LBatch then
                    LEval := LBatch;

                  if llama_decode(FContext, llama_batch_get_one(@LEmbed.FItems[I], LEval, LPast, 0)) <> 0 then
                  begin
                    SetError('Error in llama_decode with Vulkan backend');
                    LDecodeFailed := True;
                    Break;
                  end;

                  Inc(LPast, LEval);
                  Inc(I, LBatch);
                end;
                LEmbed.Clear;

                // A failed decode aborts the whole run instead of sampling
                // from an inconsistent context.
                if LDecodeFailed then
                  Break;
              end;

              if LEmbedInput.Count <= LConsumed then
                begin
                  // Prompt fully consumed: sample the next token.
                  LId := _llama_sampling_sample(LSamplingContext, FContext, nil);
                  if llama_token_is_eog(FModel, LId) then
                  begin
                    Break;
                  end;

                  _llama_sampling_accept(LSamplingContext, FContext, LId, True);
                  LEmbed.Add(LId);
                  Dec(LRemain);

                  LTokenStr := TokenToPiece(FContext, LId, False);
                  if LFirstToken then
                  begin
                    // Drop leading whitespace of the very first token only.
                    LFirstToken := False;
                    LTokenStr := LTokenStr.TrimLeft();
                  end;

                  FInference.Response := FInference.Response + LTokenStr;
                  OnInferenceToken(LTokenStr);

                end
              else
                begin
                  // Queue the next chunk of prompt tokens (at most one batch).
                  while LEmbedInput.Count > LConsumed do
                  begin
                    LEmbed.Add(LEmbedInput[LConsumed]);
                    _llama_sampling_accept(LSamplingContext, FContext, LEmbedInput[LConsumed], False);
                    Inc(LConsumed);
                    if LEmbed.Count >= LBatch then
                    begin
                      Break;
                    end;
                  end;
                end;
            end;

            // get usage
            LTimings := llama_get_timings(FContext);
            FStats.InputTokens := LTimings.n_p_eval;
            FStats.OutputTokens := LTimings.n_eval;

            // Guard against zero durations (e.g. cancelled before any
            // evaluation, or AMaxTokens = 0) to avoid a division by zero.
            if LTimings.t_p_eval_ms > 0 then
              FStats.TokenInputSpeed := 1e3 / LTimings.t_p_eval_ms * LTimings.n_p_eval
            else
              FStats.TokenInputSpeed := 0;

            if LTimings.t_eval_ms > 0 then
              FStats.TokenOutputSpeed := 1e3 / LTimings.t_eval_ms * LTimings.n_eval
            else
              FStats.TokenOutputSpeed := 0;

            FStats.TotalTokens := FStats.InputTokens + FStats.OutputTokens;
            Result := not LDecodeFailed;
          finally
            _llama_sampling_free(LSamplingContext);
          end;
        finally
          LEmbed.Free();
        end;
      finally
        LEmbedInput.Free();
      end;
    finally
      FInference.Active := False;
      OnInferenceEnd();
    end;
  except
    on E: Exception do
    begin
      SetError(E.Message);
      Result := False;
    end;
  end;
end;

// True while RunInference is generating.
function  TGenAI.IsInferenceActive(): Boolean;
begin
  Exit(FInference.Active);
end;

// Returns the text generated so far by the current/most recent inference.
function  TGenAI.GetInferenceResponse(): string;
begin
  Exit(FInference.Response);
end;

// Copies the statistics of the most recent inference into the supplied
// locations. Every pointer is optional: nil pointers are skipped.
// Speeds are stored internally as Double and narrowed to Single here.
procedure TGenAI.GetInferenceStats(ATokenInputSpeed: System.PSingle; ATokenOutputSpeed: System.PSingle; AInputTokens: PInt32; AOutputTokens: PInt32; ATotalTokens: PInt32);
begin
  // Speeds
  if ATokenInputSpeed <> nil then
    ATokenInputSpeed^ := FStats.TokenInputSpeed;
  if ATokenOutputSpeed <> nil then
    ATokenOutputSpeed^ := FStats.TokenOutputSpeed;

  // Token counts
  if AInputTokens <> nil then
    AInputTokens^ := FStats.InputTokens;
  if AOutputTokens <> nil then
    AOutputTokens^ := FStats.OutputTokens;
  if ATotalTokens <> nil then
    ATotalTokens^ := FStats.TotalTokens;
end;

end.
