#!/bin/bash

# Defaults for the options this wrapper consumes itself; each of them can be
# overridden from the command line.
n_predict=128
threads=8
model_family=""

# The model binaries are looked up in the "libs" directory that sits next to
# the file Python reports for the imported ipex_llm package.
llm_dir="$(
  dirname "$(python -c "import ipex_llm;print(ipex_llm.__file__)")"
)"
lib_dir="${llm_dir}/libs"

# Print the usage summary for this wrapper on stdout.
# Each printf argument below becomes one output line.
function display_help {
  printf '%s\n' \
    "usage: ./llm-cli.sh -x MODEL_FAMILY [-h] [args]" \
    "" \
    "options:" \
    "  -h, --help           show this help message" \
    "  -x, --model_family {llama,bloom,gptneox,starcoder,chatglm}" \
    "                       family name of model" \
    "  -t N, --threads N    number of threads to use during computation (default: 8)" \
    "  -n N, --n_predict N  number of tokens to predict (default: 128, -1 = infinity)" \
    "  args                 parameters passed to the specified model function"
}

# Run the llama binary found in $lib_dir with the configured thread count and
# token budget, followed by the pass-through arguments.
# Globals: lib_dir, threads, n_predict, filteredArguments (all read).
# Outputs: the assembled command line, then whatever the binary prints.
# Returns: the status of the executed command.
function llama {
  local cmd_line
  # filteredArguments entries arrive already shell-quoted from the argument
  # parser, so they are spliced in verbatim. The binary path and the -t/-n
  # values are quoted here with %q so that a library path containing spaces
  # (or an unusual option value) stays one word when the string is eval'ed.
  printf -v cmd_line '%q -t %q -n %q %s' \
    "$lib_dir/main-llama" "$threads" "$n_predict" "${filteredArguments[*]}"
  echo "$cmd_line"
  eval "$cmd_line"
}

# Run the bloom binary found in $lib_dir with the configured thread count and
# token budget, followed by the pass-through arguments.
# Globals: lib_dir, threads, n_predict, filteredArguments (all read).
# Outputs: the assembled command line, then whatever the binary prints.
# Returns: the status of the executed command.
function bloom {
  local cmd_line
  # filteredArguments entries arrive already shell-quoted from the argument
  # parser, so they are spliced in verbatim. The binary path and the -t/-n
  # values are quoted here with %q so that a library path containing spaces
  # (or an unusual option value) stays one word when the string is eval'ed.
  printf -v cmd_line '%q -t %q -n %q %s' \
    "$lib_dir/main-bloom" "$threads" "$n_predict" "${filteredArguments[*]}"
  echo "$cmd_line"
  eval "$cmd_line"
}

# Run the gptneox binary found in $lib_dir with the configured thread count
# and token budget, followed by the pass-through arguments.
# Globals: lib_dir, threads, n_predict, filteredArguments (all read).
# Outputs: the assembled command line, then whatever the binary prints.
# Returns: the status of the executed command.
function gptneox {
  local cmd_line
  # filteredArguments entries arrive already shell-quoted from the argument
  # parser, so they are spliced in verbatim. The binary path and the -t/-n
  # values are quoted here with %q so that a library path containing spaces
  # (or an unusual option value) stays one word when the string is eval'ed.
  printf -v cmd_line '%q -t %q -n %q %s' \
    "$lib_dir/main-gptneox" "$threads" "$n_predict" "${filteredArguments[*]}"
  echo "$cmd_line"
  eval "$cmd_line"
}

# Run the starcoder binary found in $lib_dir with the configured thread count
# and token budget, followed by the pass-through arguments.
# Globals: lib_dir, threads, n_predict, filteredArguments (all read).
# Outputs: the assembled command line, then whatever the binary prints.
# Returns: the status of the executed command.
function starcoder {
  local cmd_line
  # filteredArguments entries arrive already shell-quoted from the argument
  # parser, so they are spliced in verbatim. The binary path and the -t/-n
  # values are quoted here with %q so that a library path containing spaces
  # (or an unusual option value) stays one word when the string is eval'ed.
  printf -v cmd_line '%q -t %q -n %q %s' \
    "$lib_dir/main-starcoder" "$threads" "$n_predict" "${filteredArguments[*]}"
  echo "$cmd_line"
  eval "$cmd_line"
}

# Run the chatglm binary found in $lib_dir with the configured thread count
# and token budget, followed by the pass-through arguments.
# The *_amx build is chosen when lscpu output mentions "amx_int8"; otherwise
# (including when lscpu is unavailable) the *_vnni build is used.
# Globals: lib_dir, threads, n_predict, filteredArguments (all read).
# Outputs: the assembled command line, then whatever the binary prints.
# Returns: the status of the executed command.
function chatglm {
  local binary cmd_line
  if lscpu | grep -q "amx_int8"; then
    binary="main-chatglm_amx"
  else
    binary="main-chatglm_vnni"
  fi
  # filteredArguments entries arrive already shell-quoted from the argument
  # parser, so they are spliced in verbatim. The binary path and the -t/-n
  # values are quoted here with %q so that a library path containing spaces
  # (or an unusual option value) stays one word when the string is eval'ed.
  printf -v cmd_line '%q -t %q -n %q %s' \
    "$lib_dir/$binary" "$threads" "$n_predict" "${filteredArguments[*]}"
  echo "$cmd_line"
  eval "$cmd_line"
}

# Entry point: parse the wrapper's own options, then hand off to the function
# for the requested model family.
#
# -x/-t/-n (and their long spellings) are consumed here. Every other argument
# is shell-quoted and collected in filteredArguments, which the model
# functions splice into the command line they eval.
# Globals:   model_family, threads, n_predict, filteredArguments (written)
# Arguments: the script's command-line arguments
# Returns:   the status of the model function for a supported family;
#            0 when only help was requested; 1 for an unsupported family;
#            2 when an option is missing its value
function main {
  local help_requested=0
  local quoted

  filteredArguments=()
  while [[ $# -gt 0 ]]; do
    case "$1" in
    -x | --model_family | --model-family | -t | --threads | -n | --n_predict | --n-predict)
      # `shift 2` fails without shifting when only the flag is left, which
      # would spin this loop forever, so reject a missing value up front.
      if [[ $# -lt 2 ]]; then
        echo "Missing value for option: $1" >&2
        display_help >&2
        return 2
      fi
      case "$1" in
      -x | --model_family | --model-family) model_family="$2" ;;
      -t | --threads) threads="$2" ;;
      *) n_predict="$2" ;;
      esac
      shift 2
      ;;
    *)
      # Help is printed once by the wrapper; the flag is still forwarded so
      # the model binary receives it as well.
      if [[ "$1" == "-h" || "$1" == "--help" ]]; then
        if ((help_requested == 0)); then
          display_help
        fi
        help_requested=1
      fi
      # %q keeps arguments containing quotes, spaces or shell metacharacters
      # intact (and inert) when the command string is later eval'ed.
      printf -v quoted '%q' "$1"
      filteredArguments+=("$quoted")
      shift
      ;;
    esac
  done

  # Perform actions based on the model_family
  case "$model_family" in
  llama | bloom | gptneox | starcoder | chatglm)
    # The name was just checked against the supported list, so this calls
    # the function of the same name.
    "$model_family"
    ;;
  *)
    # Plain `-h` with no family: help has already been shown, so stop here
    # instead of reporting an error and printing it a second time.
    if ((help_requested)) && [[ -z "$model_family" ]]; then
      return 0
    fi
    echo "Invalid model_family: $model_family" >&2
    display_help >&2
    return 1
    ;;
  esac
}

# Last command in the file, so the script exits with main's status.
main "$@"
