# Source: https://github.com/Fyusion/LLFF
import numpy as np
import os
import sys
import pdb
import json
import imageio
import skimage.transform

from .colmap_wrapper import run_colmap
from . import colmap_read_model as read_model


def load_colmap_data_nerfstudio(basedir):
    """Load a nerfstudio-layout COLMAP reconstruction as LLFF-style poses.

    Reads ``<basedir>/colmap/sparse/0/{cameras,images,points3D}.bin`` and
    ``<basedir>/camera_path.json``.

    Args:
        basedir: Scene directory containing a ``colmap`` subdirectory and a
            ``camera_path.json`` file.

    Returns:
        A tuple ``(poses, train_num, pts3d, perm, names)``:

        * ``poses`` -- array of shape (3, 5, N), one slice per image in
          ``images.bin`` iteration order. Each slice is the 3x4 inverse of the
          world-to-camera matrix ``[R | t]``, with its first two columns
          swapped and its third column negated, followed by an ``[h, w, f]``
          column taken from the first camera.
        * ``train_num`` -- number of images read from ``images.bin`` (N).
        * ``pts3d`` -- dict returned by ``read_model.read_points3d_binary``.
        * ``perm`` -- ``np.argsort(names)``, the indices that sort ``names``.
        * ``names`` -- image file names in ``images.bin`` iteration order.
    """
    colmap_base_dir = os.path.join(basedir, 'colmap')
    camerasfile = os.path.join(colmap_base_dir, 'sparse/0/cameras.bin')
    camdata = read_model.read_cameras_binary(camerasfile)
    # Only the first camera's intrinsics are used for every image.
    cam = camdata[list(camdata.keys())[0]]
    print('Cameras', len(camdata))

    # params[0] is assumed to be the focal length of the camera model.
    h, w, f = cam.height, cam.width, cam.params[0]
    hwf = np.array([h, w, f]).reshape([3, 1])

    imagesfile = os.path.join(colmap_base_dir, 'sparse/0/images.bin')
    imdata = read_model.read_images_binary(imagesfile)

    names = [imdata[k].name for k in imdata]
    print('Images #', len(names))
    perm = np.argsort(names)

    bottom = np.array([0, 0, 0, 1.]).reshape([1, 4])
    w2c_mats = []
    for k in imdata:
        im = imdata[k]
        R = im.qvec2rotmat()
        t = im.tvec.reshape([3, 1])
        w2c_mats.append(np.concatenate([np.concatenate([R, t], 1), bottom], 0))

    c2w_mats = np.linalg.inv(np.stack(w2c_mats, 0))
    poses = c2w_mats[:, :3, :4].transpose([1, 2, 0])
    poses = np.concatenate([poses, np.tile(hwf[..., np.newaxis], [1, 1, poses.shape[-1]])], 1)

    # Render trajectory exported by nerfstudio. It is parsed (so a missing or
    # malformed camera_path.json still raises) but is not returned: appending
    # it to `poses` is currently disabled (see the commented line below).
    with open(os.path.join(basedir, 'camera_path.json'), encoding='utf-8') as fp:
        render_cam_paths = json.load(fp)
    render_h, render_w = render_cam_paths['render_height'], render_cam_paths['render_width']
    # Each 'camera_to_world' entry is reshaped into a 4x4 matrix.
    render_poses = np.array(
        [p['camera_to_world'] for p in render_cam_paths['camera_path']]
    ).reshape(-1, 4, 4)
    render_hwf = np.array([render_h, render_w, f]).reshape([3, 1])
    render_poses = render_poses[:, :3, :4].transpose([1, 2, 0])
    render_poses = np.concatenate(
        [render_poses, np.tile(render_hwf[..., np.newaxis], [1, 1, render_poses.shape[-1]])], 1)

    train_num = len(c2w_mats)
    # poses = np.concatenate([poses, render_poses], 2)
    points3dfile = os.path.join(colmap_base_dir, 'sparse/0/points3D.bin')
    pts3d = read_model.read_points3d_binary(points3dfile)

    # NOTE(review): the original author marked this conversion "todo: validate".
    # Swap the first two pose columns and negate the third; the original
    # comment describes this as going from [r, -u, t] to [-u, r, -t]
    # (NOT [r, u, -t]). Columns 3 (translation) and 4 (hwf) are unchanged.
    poses = np.concatenate([poses[:, 1:2, :], poses[:, 0:1, :], -poses[:, 2:3, :],
                            poses[:, 3:4, :], poses[:, 4:5, :]], 1)
    return poses, train_num, pts3d, perm, names


def load_colmap_data(realdir):
    """Load COLMAP output under ``realdir`` and convert it to LLFF-style poses.

    Cameras and images are read from ``<realdir>/dense/sparse``; 3D points
    are read from ``<realdir>/sparse/0``.

    Args:
        realdir: Scene directory on which COLMAP was run.

    Returns:
        A tuple ``(poses, pts3d, perm, names)``:

        * ``poses`` -- array of shape (3, 5, N), one slice per image in
          ``images.bin`` iteration order. Each slice is the 3x4 inverse of the
          world-to-camera matrix ``[R | t]``, with its first two columns
          swapped and its third column negated, followed by an ``[h, w, f]``
          column taken from the first camera.
        * ``pts3d`` -- dict returned by ``read_model.read_points3d_binary``.
        * ``perm`` -- ``np.argsort(names)``, the indices that sort ``names``.
        * ``names`` -- image file names in ``images.bin`` iteration order.
    """
    camerasfile = os.path.join(realdir, 'dense/sparse/cameras.bin')
    camdata = read_model.read_cameras_binary(camerasfile)
    # Only the first camera's intrinsics are used for every image.
    cam = camdata[list(camdata.keys())[0]]
    print('Cameras', len(camdata))

    # params[0] is assumed to be the focal length of the camera model.
    h, w, f = cam.height, cam.width, cam.params[0]
    hwf = np.array([h, w, f]).reshape([3, 1])

    imagesfile = os.path.join(realdir, 'dense/sparse/images.bin')
    imdata = read_model.read_images_binary(imagesfile)

    names = [imdata[k].name for k in imdata]
    print('Images #', len(names))
    perm = np.argsort(names)

    bottom = np.array([0, 0, 0, 1.]).reshape([1, 4])
    w2c_mats = []
    for k in imdata:
        im = imdata[k]
        R = im.qvec2rotmat()
        t = im.tvec.reshape([3, 1])
        w2c_mats.append(np.concatenate([np.concatenate([R, t], 1), bottom], 0))

    c2w_mats = np.linalg.inv(np.stack(w2c_mats, 0))
    poses = c2w_mats[:, :3, :4].transpose([1, 2, 0])
    poses = np.concatenate([poses, np.tile(hwf[..., np.newaxis], [1, 1, poses.shape[-1]])], 1)

    # NOTE(review): points come from sparse/0 while cameras/images come from
    # dense/sparse. A 'dense/sparse/points3D.bin' path was previously used
    # here; confirm the two models share image ids before relying on this.
    points3dfile = os.path.join(realdir, 'sparse/0/points3D.bin')
    pts3d = read_model.read_points3d_binary(points3dfile)

    # Swap the first two pose columns and negate the third; the original
    # comment describes this as going from [r, -u, t] to [-u, r, -t]
    # (NOT [r, u, -t]). Columns 3 (translation) and 4 (hwf) are unchanged.
    poses = np.concatenate([poses[:, 1:2, :], poses[:, 0:1, :], -poses[:, 2:3, :],
                            poses[:, 3:4, :], poses[:, 4:5, :]], 1)
    return poses, pts3d, perm, names


def save_poses(basedir, poses, pts3d, perm, names):
    """Write ``poses_bounds.npy`` and ``poses_names.npy`` into ``basedir``.

    Args:
        basedir: Output directory.
        poses: Pose array whose last axis indexes images, as returned by
            ``load_colmap_data`` (shape (3, 5, N)). ``poses[:3, 3, i]`` is
            used as the camera centre and ``poses[:3, 2, i]`` as the axis
            along which depth is measured.
        pts3d: Mapping of point id -> record with ``xyz`` and ``image_ids``.
            An image id ``k`` is assumed to correspond to pose column ``k - 1``.
        perm: Order in which image columns are written (typically
            ``np.argsort(names)``).
        names: Image names; saved sorted to ``poses_names.npy``.

    Each output row is ``poses[..., i].ravel()`` followed by the 0.1th and
    99.9th percentiles of the depths of points visible in image ``i``.

    Returns:
        None. If some point references an image id beyond the number of pose
        columns, an error is printed and nothing is written.

    Raises:
        RuntimeError: If ``pts3d`` is empty.
    """
    num_views = poses.shape[-1]

    points = []
    visibility = []
    for point_id in pts3d:
        point = pts3d[point_id]
        points.append(point.xyz)
        seen_by = [0] * num_views
        for image_id in point.image_ids:
            column = image_id - 1
            if column >= num_views:
                print('ERROR: the correct camera poses for current points cannot be accessed')
                return
            seen_by[column] = 1
        visibility.append(seen_by)

    pts_arr = np.array(points)
    vis_arr = np.array(visibility)
    print('Points', pts_arr.shape, 'Visibility', vis_arr.shape)
    if not len(pts_arr):
        raise RuntimeError("Points has zero shape!")

    centers = poses[:3, 3:4, :]                          # (3, 1, N)
    view_axes = poses[:3, 2:3, :]                        # (3, 1, N)
    offsets = pts_arr.T[:, :, np.newaxis] - centers      # (3, P, N)
    zvals = np.sum(-offsets * view_axes, axis=0)         # (P, N)

    valid_z = zvals[vis_arr == 1]
    print('Depth stats', valid_z.min(), valid_z.max(), valid_z.mean())

    rows = []
    for i in perm:
        depths = zvals[vis_arr[:, i] == 1, i]
        near = np.percentile(depths, .1)
        far = np.percentile(depths, 99.9)
        rows.append(np.append(poses[..., i].ravel(), [near, far]))

    np.save(os.path.join(basedir, 'poses_bounds.npy'), np.array(rows))
    np.save(os.path.join(basedir, 'poses_names.npy'), sorted(names))


def minify(basedir, factors=None, resolutions=None):
    """Write downscaled PNG copies of ``<basedir>/images`` using ImageMagick.

    Each entry ``r`` of ``factors`` produces ``<basedir>/images_<r>``, resized
    to ``int(100 / r)`` percent. Each entry ``r`` of ``resolutions`` produces
    ``<basedir>/images_<r[1]>x<r[0]>``, resized with the ImageMagick geometry
    ``<r[1]>x<r[0]>`` (so ``r`` is presumably ``(height, width)``). Output
    directories that already exist are skipped.

    Only files sharing the extension of the first (sorted) source image are
    resized; their originals are removed once a distinct ``.png`` has been
    written. Requires the ``magick`` executable on ``PATH``.

    Args:
        basedir: Directory containing the source ``images`` subdirectory.
        factors: Iterable of downscale factors; ``None`` means none.
        resolutions: Iterable of 2-element sizes; ``None`` means none.

    Raises:
        RuntimeError: If a target must be built but ``images`` holds no
            ``JPG``/``jpg``/``png``/``jpeg``/``PNG`` files, or if copying or
            ImageMagick fails. A failed target directory is removed so a later
            call retries it instead of skipping it.
    """
    import subprocess
    from shutil import copy, rmtree

    factors = [] if factors is None else list(factors)
    resolutions = [] if resolutions is None else list(resolutions)

    # (requested value, output directory name, is_factor)
    targets = [(r, 'images_{}'.format(r), True) for r in factors]
    targets += [(r, 'images_{}x{}'.format(r[1], r[0]), False) for r in resolutions]

    if all(os.path.exists(os.path.join(basedir, name)) for _, name, _ in targets):
        return

    imgdir_orig = os.path.join(basedir, 'images')
    image_exts = ('JPG', 'jpg', 'png', 'jpeg', 'PNG')
    imgs = [f for f in sorted(os.listdir(imgdir_orig)) if f.endswith(image_exts)]
    if not imgs:
        raise RuntimeError('No images found in {}'.format(imgdir_orig))
    ext = imgs[0].split('.')[-1]

    for r, name, is_factor in targets:
        imgdir = os.path.join(basedir, name)
        if os.path.exists(imgdir):
            continue
        if is_factor:
            resizearg = '{}%'.format(int(100. / r))
        else:
            resizearg = '{}x{}'.format(r[1], r[0])

        print('Minifying', r, basedir)
        os.makedirs(imgdir)
        try:
            # Same file set as `cp images/* <imgdir>`: top-level, non-hidden
            # regular files (subdirectories no longer abort the copy).
            for fname in sorted(os.listdir(imgdir_orig)):
                src = os.path.join(imgdir_orig, fname)
                if not fname.startswith('.') and os.path.isfile(src):
                    copy(src, imgdir)

            to_resize = [f for f in sorted(os.listdir(imgdir))
                         if f.endswith('.' + ext) and not f.startswith('.')]
            print(' '.join(['magick mogrify', '-resize', resizearg, '-format', 'png',
                            '*.{}'.format(ext)]))
            # Argument list (no shell) run with cwd=imgdir: paths with spaces
            # work and the process-wide working directory is never changed.
            subprocess.run(['magick', 'mogrify', '-resize', resizearg, '-format', 'png'] + to_resize,
                           cwd=imgdir, check=True)
        except (OSError, subprocess.CalledProcessError) as err:
            rmtree(imgdir, ignore_errors=True)
            raise RuntimeError('Failed to build {}: {}'.format(imgdir, err)) from err

        removed = False
        for fname in to_resize:
            src = os.path.join(imgdir, fname)
            out = os.path.splitext(src)[0] + '.png'
            # Delete an original only when mogrify wrote a distinct PNG beside
            # it; on case-insensitive filesystems 'x.PNG' and 'x.png' are the
            # same file and deleting it would lose the resized image.
            if (out != src and os.path.exists(src) and os.path.exists(out)
                    and not os.path.samefile(src, out)):
                os.remove(src)
                removed = True
        if removed:
            print('Removed duplicates')
        print('Done')


def gen_poses(basedir, match_type, factors=None):
    """Run COLMAP on ``basedir`` and write LLFF pose files into ``basedir/dense``.

    Steps: run COLMAP unconditionally, load the reconstruction with
    ``load_colmap_data``, save ``poses_bounds.npy``/``poses_names.npy`` via
    ``save_poses`` into ``<basedir>/dense``, and, when ``factors`` is given,
    build downscaled image folders there with ``minify``.

    Args:
        basedir: Scene directory.
        match_type: Matcher type forwarded to ``run_colmap``.
        factors: Optional downscale factors forwarded to ``minify``.

    Returns:
        Always ``True``.
    """
    out_dir = os.path.join(basedir, 'dense')

    print("Force run COLMAP!")
    # COLMAP is always re-run. If the reconstruction was produced offline,
    # comment out this call (an earlier version only ran COLMAP when
    # sparse/0/{cameras,images,points3D}.bin were missing).
    run_colmap(basedir, match_type)

    print('Post-colmap')

    poses, pts3d, perm, names = load_colmap_data(basedir)
    save_poses(out_dir, poses, pts3d, perm, names)

    if factors is not None:
        print('Factors:', factors)
        minify(out_dir, factors)

    print('Done with imgs2poses')
    return True

