#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2018-present ScyllaDB
#

#
# SPDX-License-Identifier: AGPL-3.0-or-later

import os
import argparse
import distutils.util
import pwd
import grp
import sys
import stat
import logging
import pyudev
import psutil
from pathlib import Path
from scylla_util import *
from subprocess import run, SubprocessError

LOGGER = logging.getLogger(__name__)

class UdevInfo:
    """Expose selected udev properties and /dev/disk symlinks of one device.

    The device is looked up through pyudev from the device file path given
    to the constructor (for example ``/dev/md0``).
    """

    def __init__(self, device_file):
        """Look up the udev device that backs ``device_file``."""
        self.context = pyudev.Context()
        self.device = pyudev.Devices.from_device_file(self.context, device_file)

    def verify(self):
        """Log an error for every udev filesystem property that looks wrong.

        Nothing is raised or returned; the checks are purely diagnostic.
        """
        if not self.id_fs_uuid:
            LOGGER.error('ID_FS_UUID is not found')
        if self.id_fs_type != 'xfs':
            LOGGER.error('ID_FS_TYPE is not "xfs"')
        if self.id_fs_usage != 'filesystem':
            LOGGER.error('ID_FS_USAGE is not "filesystem"')

    def dump_variables(self):
        """Log (at error level) the device attributes and all udev properties."""
        device = self.device
        # Attributes are read one at a time, right before they are logged, so
        # that a failure on a later attribute does not hide the earlier ones.
        for name in ('sys_path', 'sys_name', 'sys_number', 'device_path'):
            LOGGER.error(f'    {name}: {getattr(device, name)}')
        LOGGER.error(f'    tags: {list(device.tags)}')
        for name in ('subsystem', 'driver', 'device_type', 'device_node', 'device_number'):
            LOGGER.error(f'    {name}: {getattr(device, name)}')
        LOGGER.error(f'    device_links: {list(device.device_links)}')
        for name in ('is_initialized', 'time_since_initialized'):
            LOGGER.error(f'    {name}: {getattr(device, name)}')
        for key, value in device.properties.items():
            LOGGER.error(f'    {key}: {value}')

    def _first_link(self, prefix):
        """Return the first device symlink starting with ``prefix``, or None."""
        return next(
            (link for link in self.device.device_links if link.startswith(prefix)),
            None,
        )

    @property
    def id_fs_uuid(self):
        """Value of the ``ID_FS_UUID`` udev property, or None if unset."""
        return self.device.properties.get('ID_FS_UUID')

    @property
    def id_fs_type(self):
        """Value of the ``ID_FS_TYPE`` udev property, or None if unset."""
        return self.device.properties.get('ID_FS_TYPE')

    @property
    def id_fs_usage(self):
        """Value of the ``ID_FS_USAGE`` udev property, or None if unset."""
        return self.device.properties.get('ID_FS_USAGE')

    @property
    def uuid_link(self):
        """First ``/dev/disk/by-uuid/`` symlink of the device, or None."""
        return self._first_link('/dev/disk/by-uuid/')

    @property
    def label_link(self):
        """First ``/dev/disk/by-label/`` symlink of the device, or None."""
        return self._first_link('/dev/disk/by-label/')

    @property
    def partuuid_link(self):
        """First ``/dev/disk/by-partuuid/`` symlink of the device, or None."""
        return self._first_link('/dev/disk/by-partuuid/')

    @property
    def path_link(self):
        """First ``/dev/disk/by-path/`` symlink of the device, or None."""
        return self._first_link('/dev/disk/by-path/')

    @property
    def id_links(self):
        """List of all device symlinks that start with ``/dev/disk/by-id``."""
        return [link for link in self.device.device_links if link.startswith('/dev/disk/by-id')]


def is_selinux_enabled():
    """Tell whether a mounted selinuxfs exposes an ``enforce`` file.

    Every partition reported by psutil (including pseudo filesystems, hence
    ``all=True``) is inspected; the function returns True as soon as one with
    filesystem type ``selinuxfs`` has ``<mountpoint>/enforce`` present, and
    False when none does.
    """
    return any(
        os.path.exists(part.mountpoint + '/enforce')
        for part in psutil.disk_partitions(all=True)
        if part.fstype == 'selinuxfs'
    )

def mount_unit_text(mount_dev, mount_at, md_unit='', online_discard=True):
    """Render the content of the systemd mount unit for the Scylla volume.

    mount_dev      -- device path or symlink written to ``What=``
    mount_at       -- mount point written to ``Where=``
    md_unit        -- name of the systemd unit the mount is ordered after and
                      wants; an empty string leaves ``After=``/``Wants=`` empty
    online_discard -- when true, ``,discard`` is appended to the mount options

    Returns the unit file text without a trailing newline.
    """
    opt_discard = ',discard' if online_discard else ''
    return f'''
[Unit]
Description=Scylla data directory
Before=local-fs.target scylla-server.service
After={md_unit}
Wants={md_unit}
DefaultDependencies=no

[Mount]
What={mount_dev}
Where={mount_at}
Type=xfs
Options=noatime{opt_discard}

[Install]
WantedBy=local-fs.target
'''[1:-1]


def logical_block_size_path(rdev):
    """Return the sysfs ``logical_block_size`` path for device number ``rdev``.

    ``rdev`` is a raw device number such as ``os.stat(dev).st_rdev``. It is
    decoded with os.major()/os.minor() because a plain ``// 256`` / ``% 256``
    split gives wrong results for minor numbers above 255.
    """
    return f'/sys/dev/block/{os.major(rdev)}:{os.minor(rdev)}/queue/logical_block_size'


if __name__ == '__main__':
    if os.getuid() > 0:
        print('Requires root permission.')
        sys.exit(1)
    parser = argparse.ArgumentParser(description='Configure RAID volume for Scylla.')
    parser.add_argument('--disks', required=True,
                        help='specify disks for RAID')
    parser.add_argument('--raiddev',
                        help='MD device name for RAID')
    parser.add_argument('--enable-on-nextboot', '--update-fstab', action='store_true', default=False,
                        help='mount RAID on next boot')
    parser.add_argument('--root', default='/var/lib/scylla',
                        help='specify the root of the tree')
    parser.add_argument('--volume-role', default='all',
                        help='specify how will this device be used (data, commitlog, or all)')
    parser.add_argument('--force-raid', action='store_true', default=False,
                        help='force constructing RAID when only one disk is specified')
    parser.add_argument('--raid-level', default='0',
                        help='specify RAID level')
    parser.add_argument('--online-discard', default="True",
                        help='Enable XFS online discard (trim SSD cells after file deletion)')

    args = parser.parse_args()

    # Allow args.online_discard to be used as a boolean value
    # NOTE(review): distutils was removed from the standard library in
    # Python 3.12 (PEP 632); this call and the file-level import need a
    # replacement before the script can run there.
    args.online_discard = distutils.util.strtobool(args.online_discard)

    # Resolve the mount point from the requested volume role.
    root = args.root.rstrip('/')
    if args.volume_role == 'all':
        mount_at=root
    elif args.volume_role == 'data':
        mount_at='{}/data'.format(root)
    elif args.volume_role == 'commitlog':
        mount_at='{}/commitlog'.format(root)
    else:
        print('Invalid role specified ({})'.format(args.volume_role))
        parser.print_help()
        sys.exit(1)

    # Every given disk must exist, be a block device and be unused.
    disks = args.disks.split(',')
    for disk in disks:
        if not os.path.exists(disk):
            print('{} is not found'.format(disk))
            sys.exit(1)
        if not stat.S_ISBLK(os.stat(disk).st_mode):
            print('{} is not block device'.format(disk))
            sys.exit(1)
        if not is_unused_disk(disk):
            print('{} is busy'.format(disk))
            sys.exit(1)

    # A single disk is formatted directly unless --force-raid is given;
    # otherwise pick the MD device that will carry the filesystem.
    if len(disks) == 1 and not args.force_raid:
        raid = False
        fsdev = disks[0]
    else:
        raid = True
        if args.raiddev is None:
            raiddevs_to_try = [f'/dev/md{i}' for i in range(10)]
        else:
            raiddevs_to_try = [args.raiddev, ]
        for fsdev in raiddevs_to_try:
            raiddevname = os.path.basename(fsdev)
            array_state = Path(f'/sys/block/{raiddevname}/md/array_state')
            # mdX is not allocated
            if not array_state.exists():
                break
            with array_state.open() as f:
                # allocated, but no devices, not running
                if f.read().strip() == 'clear':
                    break
            print(f'{fsdev} is already using')
        else:
            # No candidate was free (the loop never hit `break`).
            if args.raiddev is None:
                print("Can't find unused /dev/mdX")
            sys.exit(1)
        print(f'{fsdev} will be used to setup a RAID')

    if os.path.ismount(mount_at):
        print('{} is already mounted'.format(mount_at))
        sys.exit(1)

    mntunit_bn = out('systemd-escape -p --suffix=mount {}'.format(mount_at))
    mntunit = Path('/etc/systemd/system/{}'.format(mntunit_bn))
    if mntunit.exists():
        print('mount unit {} already exists'.format(mntunit))
        sys.exit(1)

    if not shutil.which('mkfs.xfs'):
        pkg_install('xfsprogs')
    if not shutil.which('mdadm'):
        pkg_install('mdadm')
    # Remember the name of the MD monitoring unit next to the unit object so
    # the mount unit written below can reference the real unit name.
    md_unit_name = ''
    if args.raid_level != '0':
        md_unit_name = 'mdmonitor.service'
        try:
            md_service = systemd_unit(md_unit_name)
        except SystemdException:
            md_unit_name = 'mdadm.service'
            md_service = systemd_unit(md_unit_name)

    print('Creating {type} for scylla using {nr_disk} disk(s): {disks}'.format(type=f'RAID{args.raid_level}' if raid else 'XFS volume', nr_disk=len(disks), disks=args.disks))
    # Discard all disks in parallel; only disks whose sysfs entry reports a
    # non-zero discard granularity are passed to blkdiscard.
    procs=[]
    for disk in disks:
        d = disk.replace('/dev/', '')
        discard_path = '/sys/block/{}/queue/discard_granularity'.format(d)
        if os.path.exists(discard_path):
            with open(discard_path) as f:
                discard = f.read().strip()
            if discard != '0':
                proc = subprocess.Popen(['blkdiscard', disk])
                procs.append(proc)
    for proc in procs:
        proc.wait()
    for disk in disks:
        run(f'wipefs -a {disk}', shell=True, check=True)
    if raid:
        run('udevadm settle', shell=True, check=True)
        run('mdadm --create --verbose --force --run {raid} --level={level} -c1024 --raid-devices={nr_disk} {disks}'.format(raid=fsdev, level=args.raid_level, nr_disk=len(disks), disks=args.disks.replace(',', ' ')), shell=True, check=True)
        run(f'wipefs -a {fsdev}', shell=True, check=True)
        run('udevadm settle', shell=True, check=True)

    # Read the logical sector size of the target device from sysfs.
    with open(logical_block_size_path(os.stat(fsdev).st_rdev)) as f:
        sector_size = int(f.read())
    # We want smaller block sizes to allow smaller commitlog writes without
    # stalling. The minimum block size for crc enabled filesystems is 1024,
    # see https://git.kernel.org/pub/scm/fs/xfs/xfsprogs-dev.git/tree/mkfs/xfs_mkfs.c .
    # and it also cannot be smaller than the sector size.
    block_size = max(1024, sector_size)
    run('udevadm settle', shell=True, check=True)
    run(f'mkfs.xfs -b size={block_size} {fsdev} -K', shell=True, check=True)
    run('udevadm settle', shell=True, check=True)

    if is_debian_variant():
        confpath = '/etc/mdadm/mdadm.conf'
    else:
        confpath = '/etc/mdadm.conf'

    # Persist the array definition reported by mdadm.
    if raid:
        res = out('mdadm --detail --scan')
        with open(confpath, 'w') as f:
            f.write(res)
            f.write('\nMAILADDR root')

    os.makedirs(mount_at, exist_ok=True)

    # Prefer the by-uuid symlink for the mount unit; fall back through the
    # other /dev/disk symlinks and finally to the device path itself.
    udev_info = UdevInfo(fsdev)
    mount_dev = None
    if udev_info.uuid_link:
        mount_dev = udev_info.uuid_link
    else:
        if udev_info.label_link:
            mount_dev = udev_info.label_link
            dev_type = 'label'
        elif udev_info.partuuid_link:
            mount_dev = udev_info.partuuid_link
            dev_type = 'partuuid'
        elif udev_info.path_link:
            mount_dev = udev_info.path_link
            dev_type = 'path'
        elif udev_info.id_links:
            mount_dev = udev_info.id_links[0]
            dev_type = 'id'
        else:
            mount_dev = fsdev
            dev_type = 'realpath'
        LOGGER.error(f'Failed to detect uuid, using {dev_type}: {mount_dev}')

    # Only a real RAID with a level other than 0 depends on the MD unit.
    unit_data = mount_unit_text(
        mount_dev,
        mount_at,
        md_unit=md_unit_name if raid else '',
        online_discard=args.online_discard,
    )
    with open(f'/etc/systemd/system/{mntunit_bn}', 'w') as f:
        f.write(unit_data)
    # Make scylla-server.service require the new mount point.
    mounts_conf = '/etc/systemd/system/scylla-server.service.d/mounts.conf'
    if not os.path.exists(mounts_conf):
        os.makedirs('/etc/systemd/system/scylla-server.service.d/', exist_ok=True)
        with open(mounts_conf, 'w') as f:
            f.write(f'[Unit]\nRequiresMountsFor={mount_at}\n')
    else:
        with open(mounts_conf, 'a') as f:
            f.write(f'RequiresMountsFor={mount_at}\n')

    systemd_unit.reload()
    if args.raid_level != '0':
        md_service.start()
    try:
        mount = systemd_unit(mntunit_bn)
        mount.start()
    except SubprocessError:
        # Log as much context as possible, then re-raise the original error.
        if mount_dev != fsdev:
            if not os.path.islink(mount_dev):
                LOGGER.error(f'{mount_dev} is not found')
            elif not os.path.exists(mount_dev):
                LOGGER.error(f'{mount_dev} is broken link')
        if not os.path.exists(fsdev):
            LOGGER.error(f'{fsdev} is not found')
        elif not stat.S_ISBLK(os.stat(fsdev).st_mode):
            # Only stat the device when it exists, so the diagnostics cannot
            # raise and mask the mount failure.
            LOGGER.error(f'{fsdev} is not block device')
        LOGGER.error(f'Error detected, dumping udev env parameters on {fsdev}')
        udev_info.verify()
        udev_info.dump_variables()
        raise

    if args.enable_on_nextboot:
        mount.enable()
    uid = pwd.getpwnam('scylla').pw_uid
    gid = grp.getgrnam('scylla').gr_gid
    os.chown(root, uid, gid)

    for d in ['coredump', 'data', 'commitlog', 'hints', 'view_hints', 'saved_caches']:
        dpath = '{}/{}'.format(root, d)
        os.makedirs(dpath, exist_ok=True)
        os.chown(dpath, uid, gid)

    if is_debian_variant():
        if not shutil.which('update-initramfs'):
            pkg_install('initramfs-tools')
        run('update-initramfs -u', shell=True, check=True)

    # The mount succeeded without a by-uuid link; dump udev state anyway so
    # the missing UUID can be investigated.
    if not udev_info.uuid_link:
        LOGGER.error(f'Error detected, dumping udev env parameters on {fsdev}')
        udev_info.verify()
        udev_info.dump_variables()

    if is_redhat_variant():
        offline_skip_relabel = False
        has_semanage = True
        if not shutil.which('matchpathcon'):
            offline_skip_relabel = True
            pkg_install('libselinux-utils', offline_exit=False)
        if not shutil.which('restorecon'):
            offline_skip_relabel = True
            pkg_install('policycoreutils', offline_exit=False)
        if not shutil.which('semanage'):
            if is_offline():
                has_semanage = False
            else:
                pkg_install('policycoreutils-python-utils')
        if is_offline() and offline_skip_relabel:
            print('Unable to find SELinux tools, skip relabeling.')
            sys.exit(0)

        # Give <root>/coredump the same SELinux context that the policy
        # assigns to /var/lib/systemd/coredump.
        selinux_context = out('matchpathcon -n /var/lib/systemd/coredump')
        selinux_type = selinux_context.split(':')[2]
        if has_semanage:
            run(f'semanage fcontext -a -t {selinux_type} "{root}/coredump(/.*)?"', shell=True, check=True)
        else:
            # without semanage, we need to update file_contexts directly,
            # and compile it to binary format (.bin file)
            try:
                with open('/etc/selinux/targeted/contexts/files/file_contexts.local', 'a') as f:
                    spacer = ''
                    if f.tell() != 0:
                        spacer = '\n'
                    f.write(f'{spacer}{root}/coredump(/.*)?   {selinux_context}\n')
            except FileNotFoundError:
                print('Unable to find SELinux policy files, skip relabeling.')
                sys.exit(0)
            run('sefcontext_compile /etc/selinux/targeted/contexts/files/file_contexts.local', shell=True, check=True)
        if is_selinux_enabled():
            run(f'restorecon -F -v -R {root}', shell=True, check=True)
        else:
            # Relabeling is deferred by creating /.autorelabel.
            Path('/.autorelabel').touch(exist_ok=True)