#!/bin/bash

# Default values
model_family=""
threads=8
# Number of tokens to predict (made it larger than default because we want a long interaction)
n_predict=512

# Arguments appended to every model command line. The entries are joined into
# a string that is executed with `eval`, so each entry must already be quoted
# the way it should appear on a shell command line.
EXTRA_ARGS=('--color')

# Ask Python where the ipex_llm package lives. Fail fast when the import does
# not work: otherwise `dirname ""` would yield "." and the script would go on
# to look for the model binaries under ./libs.
if ! llm_init_file="$(python -c "import ipex_llm;print(ipex_llm.__file__)")" ||
  [[ -z "$llm_init_file" ]]; then
  echo "llm-chat: unable to locate the ipex_llm Python package using 'python'" >&2
  exit 1
fi

llm_dir="$(dirname "$llm_init_file")"
# Directory holding the main-<family> executables.
lib_dir="$llm_dir/libs"
# Directory holding the prompt template files.
prompts_dir="$llm_dir/cli/prompts"

# Print the usage/help text for llm-chat to stdout.
# Takes no arguments and reads no globals; the defaults quoted in the text
# mirror the default values assigned at the top of this script
# (threads=8, n_predict=512).
function display_help {
  printf '%s\n' \
    "usage: ./llm-chat -x MODEL_FAMILY [-h] [args]" \
    "" \
    "options:" \
    "  -h, --help           show this help message" \
    "  -x, --model_family   {llama,gptneox}" \
    "                       family name of model" \
    "  -t N, --threads N    number of threads to use during computation (default: 8)" \
    "  -n N, --n_predict N  number of tokens to predict (default: 512, -1 = infinity)" \
    "  args                 parameters passed to the specified model function"
}

# Start an interactive chat with a llama-family model.
# Globals read:    lib_dir, prompts_dir, threads, n_predict, filteredArguments
# Globals written: EXTRA_ARGS (the chat-specific options are appended)
# Outputs:         echoes the assembled command line, then runs it with eval.
# Returns:         the exit status of the evaluated command.
function llama {
  local prompt_template="$prompts_dir/chat-with-llm.txt"
  local quoted_head quoted_template cmdline

  # The command is executed through eval, so every value spliced into it must
  # be shell-quoted; otherwise a path containing spaces splits into several
  # words and a value such as "4; cmd" would run `cmd`.
  printf -v quoted_head '%q ' "$lib_dir/main-llama" -t "$threads" -n "$n_predict"
  printf -v quoted_template '%q' "$prompt_template"

  EXTRA_ARGS+=('-i' '--file' "$quoted_template" '--reverse-prompt' "'USER:'" '--in-prefix' "' '")

  # filteredArguments and EXTRA_ARGS already hold eval-ready (quoted) words.
  cmdline="${quoted_head}${filteredArguments[*]} ${EXTRA_ARGS[*]}"
  echo "$cmdline"
  eval "$cmdline"
}

# Start an instruct-mode chat with a gptneox-family model.
# Globals read:    lib_dir, threads, n_predict, filteredArguments
# Globals written: EXTRA_ARGS (the chat-specific options are appended)
# Outputs:         echoes the assembled command line, then runs it with eval.
# Returns:         the exit status of the evaluated command.
function gptneox {
  # The backslash-newline inside the double quotes joins both source lines
  # into a single-line prompt.
  local PROMPT="A chat between a curious human <human> and an artificial intelligence assistant <bot>.\
  The assistant gives helpful, detailed, and polite answers to the human's questions."
  local quoted_head cmdline

  # The command is executed through eval, so the binary path and the numeric
  # options are shell-quoted to survive spaces and shell metacharacters.
  printf -v quoted_head '%q ' "$lib_dir/main-gptneox" -t "$threads" -n "$n_predict"

  # The single quotes are intentional: the literal text "$PROMPT" is stored
  # and only expanded, as one word, when eval runs the command below.
  EXTRA_ARGS+=('--instruct' '-p' '"$PROMPT"')

  # filteredArguments and EXTRA_ARGS already hold eval-ready (quoted) words.
  cmdline="${quoted_head}${filteredArguments[*]} ${EXTRA_ARGS[*]}"
  echo "$cmdline"
  eval "$cmdline"
}

# Parse the command line: consume the options llm-chat handles itself
# (-h, -x, -t, -n and their long forms) and collect every other argument in
# filteredArguments for the model binary.
# Each collected argument is stored shell-quoted (printf %q) because the model
# functions join the array into a string that is executed with `eval`.
filteredArguments=()

# Report an option that was given without its value and abort with status 2.
#   $1 - the option as typed by the user
function missing_value {
  echo "llm-chat: option $1 requires a value" >&2
  exit 2
}

# Process the given argument list.
# Globals written: model_family, threads, n_predict, filteredArguments
# Exits 0 after printing help for -h/--help; exits 2 when an option that
# needs a value is the last argument.
function parse_args {
  local quoted
  while [[ $# -gt 0 ]]; do
    case "$1" in
    -h | --help)
      # Stop here: continuing would fall through to the model dispatch.
      display_help
      exit 0
      ;;
    -x | --model_family | --model-family)
      # `shift 2` fails without shifting when only one argument is left,
      # which would keep this loop spinning forever.
      [[ $# -ge 2 ]] || missing_value "$1"
      model_family="$2"
      shift 2
      ;;
    -t | --threads)
      [[ $# -ge 2 ]] || missing_value "$1"
      threads="$2"
      shift 2
      ;;
    -n | --n_predict | --n-predict)
      [[ $# -ge 2 ]] || missing_value "$1"
      n_predict="$2"
      shift 2
      ;;
    *)
      # %q yields a word that eval turns back into exactly "$1", even when
      # the argument contains quotes, spaces or other shell metacharacters.
      printf -v quoted '%q' "$1"
      filteredArguments+=("$quoted")
      shift
      ;;
    esac
  done
}

parse_args "$@"

# Perform actions based on the model_family.
# A supported family runs its chat function and the script ends with that
# function's exit status. Anything else (including no family at all) is a
# usage error: report it on stderr and exit non-zero so callers can detect it.
case "$model_family" in
llama)
  llama
  ;;
gptneox)
  gptneox
  ;;
*)
  echo "llm-chat does not support model_family $model_family for now." >&2
  display_help >&2
  exit 1
  ;;
esac
