#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2018-present ScyllaDB
#

#
# SPDX-License-Identifier: AGPL-3.0-or-later

import os
import sys
import argparse
import re
import distro
import shutil

from scylla_util import *
from subprocess import run
from pathlib import Path

def get_chrony_unit():
    """Return the systemd unit wrapper for the chrony daemon.

    The unit is looked up as 'chronyd' first; if constructing that unit
    fails, 'chrony' is used instead (presumably the service name differs
    between distributions -- confirm against scylla_util.systemd_unit).

    Any exception raised while constructing the fallback unit propagates to
    the caller.
    """
    try:
        unit = systemd_unit('chronyd')
    except Exception:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit are not swallowed by the fallback.
        unit = systemd_unit('chrony')
    return unit

def get_chrony_conf():
    """Locate the chrony configuration file.

    Checks '/etc/chrony/chrony.conf' and then '/etc/chrony.conf', returning
    the first one that exists as a Path.

    Raises:
        FileNotFoundError: if neither location exists.
    """
    candidates = ('/etc/chrony/chrony.conf', '/etc/chrony.conf')
    for location in candidates:
        conf_path = Path(location)
        if conf_path.exists():
            return conf_path
    raise FileNotFoundError('chrony.conf not found')

def get_ntp_unit():
    """Return the systemd unit wrapper for the NTP daemon.

    The unit is looked up as 'ntpd' first; if constructing that unit fails,
    'ntp' is used instead (presumably the service name differs between
    distributions -- confirm against scylla_util.systemd_unit).

    Any exception raised while constructing the fallback unit propagates to
    the caller.
    """
    try:
        unit = systemd_unit('ntpd')
    except Exception:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit are not swallowed by the fallback.
        unit = systemd_unit('ntp')
    return unit

if __name__ == '__main__':
    # Script entry point: pick a time-synchronization implementation
    # (systemd-timesyncd, chrony or ntp), install it if needed, optionally
    # rewrite its pool.ntp.org subdomain, and enable it.
    if os.getuid() > 0:
        print('Requires root permission.')
        sys.exit(1)

    parser = argparse.ArgumentParser(description='Optimize NTP client setting for Scylla.')
    parser.add_argument('--subdomain',
                        help='specify subdomain of pool.ntp.org (ex: centos, fedora or amazon)')
    args = parser.parse_args()

    # Target selection: if any installed implementation is already active
    # there is nothing to do. Otherwise the first installed implementation
    # wins, in the order systemd-timesyncd, chrony, ntp.
    target = None
    if os.path.exists('/lib/systemd/systemd-timesyncd'):
        if systemd_unit('systemd-timesyncd').is_active() == 'active':
            print('ntp is already configured, skip setup')
            sys.exit(0)
        target = 'systemd-timesyncd'
    if shutil.which('chronyd'):
        if get_chrony_unit().is_active() == 'active':
            print('ntp is already configured, skip setup')
            sys.exit(0)
        if not target:
            target = 'chrony'
    if shutil.which('ntpd'):
        if get_ntp_unit().is_active() == 'active':
            print('ntp is already configured, skip setup')
            sys.exit(0)
        if not target:
            target = 'ntp'
    if not target:
        # Nothing is installed: fall back to a per-distribution default.
        if is_redhat_variant():
            target = 'chrony'
        else:
            target = 'systemd-timesyncd'

    print(f'Use {target} to configure ntp')
    if target == 'systemd-timesyncd':
        if not os.path.exists('/lib/systemd/systemd-timesyncd'):
            # On some distributions systemd-timesyncd ships inside the
            # systemd package itself; that case never reaches this branch
            # because systemd is installed by default there.
            pkg_install('systemd-timesyncd')
        run('timedatectl set-ntp true', shell=True, check=True)
    elif target == 'chrony':
        if not shutil.which('chronyd'):
            pkg_install('chrony')
        confp = get_chrony_conf()
        if args.subdomain:
            with confp.open() as f:
                conf = f.read()
            # Replace the vendor subdomain of every "pool N.<vendor>.pool.ntp.org"
            # entry with the requested one, keeping the numeric prefix.
            conf = re.sub(r'pool\s+([0-9]+)\.(\S+)\.pool\.ntp\.org', 'pool \\1.{}.pool.ntp.org'.format(args.subdomain), conf, flags=re.MULTILINE)
            # Path.open() takes the mode as its first argument.
            with confp.open('w') as f:
                f.write(conf)
        chronyd = get_chrony_unit()
        chronyd.enable()
        chronyd.restart()
        run('chronyc -a makestep', shell=True, check=True)
    elif target == 'ntp':
        if not shutil.which('ntpd'):
            pkg_install('ntp')
        if not shutil.which('ntpdate'):
            pkg_install('ntpdate')
        confp = Path('/etc/ntp.conf')
        with confp.open() as f:
            conf = f.read()
        if args.subdomain:
            # Replace the vendor subdomain of every server/pool entry that
            # points at pool.ntp.org, keeping the keyword and numeric prefix.
            conf = re.sub(r'(server|pool)\s+([0-9]+)\.(\S+)\.pool\.ntp\.org', '\\1 \\2.{}.pool.ntp.org'.format(args.subdomain), conf, flags=re.MULTILINE)
            # Path.open() takes the mode as its first argument.
            with confp.open('w') as f:
                f.write(conf)
        # The first server/pool entry is used for the one-shot ntpdate sync.
        match = re.search(r'^(server|pool)\s+(\S*)(\s+\S+)?', conf, flags=re.MULTILINE)
        if match is None:
            print(f'No server or pool entry found in {confp}')
            sys.exit(1)
        server = match.group(2)
        ntpd = get_ntp_unit()
        # ntpd is stopped before ntpdate runs and started again afterwards.
        ntpd.stop()
        # The host name comes from a config file, so pass it as an argv
        # element instead of interpolating it into a shell command line.
        run(['ntpdate', server], check=True)
        ntpd.enable()
        ntpd.start()
