#!/bin/bash

# This script converts torch's *.bin files and safetensors' *.safetensors files to bf16, creating a new checkpoint under the sub-dir bf16
#
# usage:
# cd checkpoint
# bash torch-checkpoint-convert-to-bf16

# set destination dir
target_dir=bf16

echo "creating a new checkpoint under dir $target_dir"
mkdir -p -- "$target_dir" || exit 1

# cp config and other files - adapt if needed - could also do `cp * "$target_dir"`
# Only patterns that matched regular files are handed to cp, so a checkpoint
# without e.g. any *model file does not trigger a spurious
# "cannot stat '*model'" error from the unexpanded glob.
aux_files=()
for f in *json *model; do
    if [[ -f "$f" ]]; then
        aux_files+=("$f")
    fi
done
if (( ${#aux_files[@]} > 0 )); then
    cp -- "${aux_files[@]}" "$target_dir"/ || exit 1
fi

# convert *bin
# Each *bin file in the current dir is loaded, its floating-point tensors are
# cast to bf16, and the result is saved under the same name in $target_dir.
echo "converting *bin torch files"
if ! compgen -G "*bin" > /dev/null; then
    # without this guard the unmatched glob would be passed to python literally
    echo "error: no *bin files found in $PWD - nothing to convert" >&2
    exit 1
fi
python - "$target_dir" *bin <<'PY' || exit 1
import sys
import torch

target_dir, files = sys.argv[1], sys.argv[2:]
for f in files:
    # map_location='cpu' so tensors saved from an accelerator load without that device
    state = torch.load(f, map_location='cpu')
    # Cast only floating-point tensors: integer/bool tensors (e.g. index buffers)
    # would lose exactness in bf16, and non-tensor values have no .to() at all.
    state = {
        k: v.to(torch.bfloat16) if torch.is_tensor(v) and v.is_floating_point() else v
        for k, v in state.items()
    }
    torch.save(state, f'{target_dir}/{f}')
PY

# convert *safetensors
# Only done when the source checkpoint ships *.safetensors files. The new
# safetensors files are generated from the *bin files found inside $target_dir
# (not by converting the original *.safetensors files directly).
if compgen -G "*.safetensors" > /dev/null; then
    echo "converting *safetensors files"
    # Run in a subshell: the directory change stays local to it, and a failed cd
    # aborts instead of letting the following commands run in the source dir.
    (
        cd -- "$target_dir" || exit 1
        python - *bin <<'PY' || exit 1
import re
import sys
import torch
from safetensors.torch import save_file

for f in sys.argv[1:]:
    # e.g. pytorch_model-00001-of-00002.bin -> model-00001-of-00002.safetensors
    out = re.sub(r'.*?(model.*?)\.bin', r'\1.safetensors', f)
    save_file(torch.load(f, map_location='cpu'), out, metadata={'format': 'pt'})
PY
        # derive the safetensors shard index from the torch one by renaming the
        # shard file names it references
        if [[ -e "pytorch_model.bin.index.json" ]]; then
            cp -- pytorch_model.bin.index.json model.safetensors.index.json || exit 1
            perl -pi -e 's|pytorch_||; s|\.bin|.safetensors|' model.safetensors.index.json || exit 1
        fi
    ) || {
        echo "error: safetensors conversion failed in $target_dir" >&2
        exit 1
    }
fi

echo "the dir $target_dir now contains a copy of the original checkpoint with bf16 weights"
