#!/bin/bash
#
# ROCm Container Runtime

# Abort on the first failing command (-e), let ERR traps apply inside
# functions (-E), and make a pipeline fail when any of its stages fails.
set -Eeo pipefail

# Optional configuration file. Because it is sourced before the defaults
# below are applied, it can preset ENV_VISIBLE_DEVICES and LOG_FILE.
RUNTIME_CONFIG="/etc/rocm-container-runtime.conf"
if [ -f "${RUNTIME_CONFIG}" ]; then
  # shellcheck source=/dev/null
  source "${RUNTIME_CONFIG}"
fi

# Apply the built-in defaults to anything still unset or empty.
# ENV_VISIBLE_DEVICES: name of the container environment variable to inspect.
# LOG_FILE:            file that log() appends to.
: "${ENV_VISIBLE_DEVICES:=AMD_VISIBLE_DEVICES}"
: "${LOG_FILE:=/var/log/rocm-container-runtime.log}"


# Append one timestamped line to the runtime log.
# Globals:   LOG_FILE (read) - file the line is appended to.
# Arguments: $@ - message parts; written literally, joined by single spaces.
# Outputs:   "[<timestamp>] <message>" appended to LOG_FILE.
log() {
  # Join the arguments with a space even if the caller changed IFS.
  local IFS=' '
  # "$*" keeps the message literal. An unquoted $@ would be word-split and
  # glob-expanded, e.g. a "*" would turn into the names of the files in the
  # current directory.
  printf '%s\n' "[$(date +%Y-%m-%dT%H:%M:%S%z)] $*" >> "${LOG_FILE}"
}

# Make sure the runc config contains the containers this script appends to:
# .process.user.additionalGids and .linux.devices (each created as an empty
# array when jq -e reports it as missing, null or false).
# Arguments: $1 - path to the bundle's config.json
# Returns:   non-zero if jq cannot produce the patched document; in that case
#            the config file is left untouched.
patch_runc_config() {
  local runc_config="$1"
  local patched

  if ! jq -e '.process.user.additionalGids' "${runc_config}" > /dev/null 2>&1; then
    # Capture jq's output before writing: piping jq straight into sponge
    # would replace the config with an empty file whenever jq fails.
    patched="$(jq -c '.process.user += {"additionalGids": []}' "${runc_config}")" \
      || return
    printf '%s\n' "${patched}" | sponge "${runc_config}"
  fi
  if ! jq -e '.linux.devices' "${runc_config}" > /dev/null 2>&1; then
    patched="$(jq -c '.linux += {"devices": []}' "${runc_config}")" || return
    printf '%s\n' "${patched}" | sponge "${runc_config}"
  fi
}

# Append the numeric id of a host group to .process.user.additionalGids.
# Arguments: $1 - path to the bundle's config.json
#            $2 - name of the host group, resolved with getent
# Returns:   non-zero if the group cannot be resolved or jq fails; in both
#            cases the config file is left untouched.
inject_group() {
  local runc_config="$1"
  local group="$2"
  local entry gid patched

  if ! entry="$(getent group "${group}")"; then
    log "Group ${group} not found."
    return 1
  fi
  # getent prints "name:password:gid:members"; the third field is the gid.
  IFS=: read -r _ _ gid _ <<< "${entry}"
  if [[ ! "${gid}" =~ ^[0-9]+$ ]]; then
    log "Group ${group} has no numeric id."
    return 1
  fi

  # Capture jq's output before writing: piping jq straight into sponge
  # would replace the config with an empty file whenever jq fails.
  patched="$(jq --arg gid "${gid}" \
    -c '.process.user.additionalGids += [$gid | tonumber]' "${runc_config}")" \
    || return
  printf '%s\n' "${patched}" | sponge "${runc_config}"
  log "Injected ${group} group id."
}

# Expose one host character device to the container: append a device node to
# .linux.devices and a matching "rwm" allow rule to .linux.resources.devices.
# Arguments: $1 - path to the bundle's config.json
#            $2 - path of the device on the host
# Returns:   0 without touching the config when the device does not exist;
#            non-zero if stat or jq fails (a failed jq never overwrites the
#            config).
inject_device() {
  local runc_config="$1"
  local dev="$2"
  local info major_hex minor_hex perms gid major minor file_mode patched

  if [ ! -e "${dev}" ]; then
    return 0
  fi

  # One stat call: major (%t) and minor (%T) in hex, permissions in octal
  # (%a) and the owning group id (%g).
  info="$(stat -c '%t %T %a %g' "${dev}")" || return
  IFS=' ' read -r major_hex minor_hex perms gid <<< "${info}"
  major=$((0x${major_hex}))
  minor=$((0x${minor_hex}))
  # OR the character-device type bit (S_IFCHR, octal 0020000) with the
  # permissions. Gluing the digits together ("020" + perms) is only correct
  # when stat prints exactly three octal digits.
  file_mode=$((0020000 | 0${perms}))

  # Capture jq's output before writing: piping jq straight into sponge
  # would replace the config with an empty file whenever jq fails.
  patched="$(jq --arg dev "${dev}" --arg major "${major}" --arg minor "${minor}" \
    --arg file_mode "${file_mode}" --arg gid "${gid}" \
    -c '.linux.devices += [{
      "path": $dev,
      "type": "c",
      "major": $major | tonumber,
      "minor": $minor | tonumber,
      "fileMode": $file_mode | tonumber,
      "uid": 0,
      "gid": $gid | tonumber}]' "${runc_config}")" || return
  printf '%s\n' "${patched}" | sponge "${runc_config}"

  patched="$(jq --arg major "${major}" --arg minor "${minor}" \
    -c '.linux.resources.devices += [{
      "allow": true,
      "type": "c",
      "major": $major | tonumber,
      "minor": $minor | tonumber,
      "access": "rwm"}]' "${runc_config}")" || return
  printf '%s\n' "${patched}" | sponge "${runc_config}"
  log "Injected device ${dev}."
}


# Entry point: inspect the runc command line, and for "create" add the
# requested AMD GPUs to the bundle's config.json before handing over to runc.
log "Running $0."

# Keep the original arguments in an array so each one reaches runc exactly as
# received, including arguments that contain whitespace or glob characters.
runc_args=("$@")
bundle_path=""
cmd=""
while [ $# -gt 0 ]; do
  case "$1" in
    -b|--bundle)
      bundle_path="${2-}"
      shift
      # Tolerate a trailing option without a value instead of aborting here.
      if [ $# -gt 0 ]; then
        shift
      fi
      ;;
    --bundle=*)
      bundle_path="${1#--bundle=}"
      shift
      ;;
    create)
      cmd="$1"
      shift
      ;;
    *)
      shift
      ;;
  esac
done

if [ "${cmd}" = "create" ]; then
  # runc uses the current directory as the bundle when --bundle is absent.
  runc_config="${bundle_path:-.}/config.json"

  # Read the value of the "<ENV_VISIBLE_DEVICES>=..." entry with jq rather
  # than grepping the JSON, so only an entry whose name matches exactly is
  # considered. A missing or unreadable config yields an empty value.
  env_value="$(jq -r --arg prefix "${ENV_VISIBLE_DEVICES}=" \
    '[.process.env[]? | select(startswith($prefix))][0] // "" | ltrimstr($prefix)' \
    "${runc_config}" || true)"
  # First run of "[-]N[,N...]" inside the value; empty when there is none.
  device_list_re='-?[0-9]+(,[0-9]+)*'
  visible_devices=""
  if [[ "${env_value}" =~ ${device_list_re} ]]; then
    visible_devices="${BASH_REMATCH[0]}"
  fi

  case "${visible_devices}" in
    ""|-*)
      log "${ENV_VISIBLE_DEVICES}=${visible_devices}, skip AMD GPU."
      ;;
    *)
      log "${ENV_VISIBLE_DEVICES}=${visible_devices}, start to add AMD GPU."
      patch_runc_config "${runc_config}"
      inject_group "${runc_config}" "video"

      log "Searching AMD GPU devices."
      IFS="," read -r -a gpu_requests <<< "${visible_devices}"
      # Entries bound to the amdgpu driver, sorted so that request N refers
      # to the N-th entry.
      gpu_listing="$(find /sys/module/amdgpu/drivers/pci:amdgpu/ \
        -mindepth 1 -maxdepth 1 \
        -name "[0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F]*" \
        | sort)" || {
        log "Failed to list AMD GPU devices."
        exit 1
      }
      gpu_pci_paths=()
      if [ -n "${gpu_listing}" ]; then
        mapfile -t gpu_pci_paths <<< "${gpu_listing}"
      fi
      log "Found ${gpu_pci_paths[*]}."

      log "Injecting AMD GPU devices."
      for request in "${gpu_requests[@]}"; do
        # Force base 10 so an index such as "08" is not parsed as octal.
        request=$((10#${request}))
        # Requests beyond the number of detected GPUs are ignored.
        if [ "${request}" -lt "${#gpu_pci_paths[@]}" ]; then
          drm_listing="$(find "${gpu_pci_paths[request]}/drm" \
            -mindepth 1 -maxdepth 1 -type d -printf "%f\n")" || {
            log "Failed to list DRM nodes of ${gpu_pci_paths[request]}."
            exit 1
          }
          devices=()
          if [ -n "${drm_listing}" ]; then
            mapfile -t devices <<< "${drm_listing}"
          fi
          log "Found ${devices[*]}."
          for dev in "${devices[@]}"; do
            inject_device "${runc_config}" "/dev/dri/${dev}"
          done
        fi
      done
      inject_device "${runc_config}" "/dev/kfd"

      log "Updated runc config successfully."
      ;;
  esac
fi

log "Running runc ${runc_args[*]}."
# Replace this shell with runc so its exit status and any signals sent to
# this process apply to runc directly.
exec runc "${runc_args[@]}"
