#!/usr/bin/env python

import json
import logging
import multiprocessing
import os
import shlex
import sys

import boto3
import click

# AWS credentials are read from the environment. When a variable is unset the
# value is None, and it is handed unchanged to every boto3.client() call below.
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY")
aws_secret_access_key = os.environ.get("AWS_ACCESS_SECRET")

# Regions that every command in this script iterates over.
REGIONS = ["us-west-1"]

def highlight(x):
    """Print *x* in green.

    Strings are printed as-is; any other value is first rendered as
    indented JSON with sorted keys.
    """
    text = x if isinstance(x, str) else json.dumps(x, sort_keys=True, indent=2)
    click.secho(text, fg='green')

# Maps the number of times -v was given (0..3) to a logging level; position in
# the tuple is the verbosity count.
DEBUG_LOGGING_MAP = dict(enumerate((
    logging.CRITICAL,
    logging.WARNING,
    logging.INFO,
    logging.DEBUG,
)))


@click.group()
@click.option('--verbose', '-v',
              help="Sets the debug noise level, specify multiple times "
                   "for more verbosity.",
              type=click.IntRange(0, 3, clamp=True),
              count=True)
@click.pass_context
def cli(ctx, verbose):
    # Root command group. Attaches a stderr handler to the root logger and
    # sets its level from the -v count (falling back to DEBUG for counts not
    # present in DEBUG_LOGGING_MAP). ``ctx`` is supplied by click.pass_context
    # and is not used here.
    level = DEBUG_LOGGING_MAP.get(verbose, logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.addHandler(stderr_handler)
    root_logger.setLevel(level)


def get_clients():
    """Build one EC2 and one Auto Scaling client per region in REGIONS.

    Every client gets a ``region`` attribute holding the region name it was
    created for.

    Returns:
        A ``zip`` iterator of ``(ec2_client, autoscaling_client)`` pairs, one
        pair per region, in REGIONS order.
    """
    def _clients_for(service):
        # One client of the given service per region, tagged with its region.
        built = []
        for region_name in REGIONS:
            svc_client = boto3.client(
                service,
                region_name=region_name,
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
            )
            svc_client.region = region_name
            built.append(svc_client)
        return built

    ec2_clients = _clients_for("ec2")
    autoscaling_clients = _clients_for("autoscaling")
    return zip(ec2_clients, autoscaling_clients)


def _collect_instances(region):
    """Return every running instance tagged ``es_dist_role=master`` in *region*.

    All instances of every reservation are returned (a reservation may hold
    more than one instance), and the result set is read through a paginator
    so that responses spanning several pages are not truncated.

    Args:
        region: AWS region name, e.g. ``"us-west-1"``.

    Returns:
        A list of instance dicts as returned by EC2 ``describe_instances``,
        each with an extra ``"Region"`` key set to *region*.
    """
    client = boto3.client(
        "ec2",
        region_name=region,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )
    print("Collecting instances in region", region)
    pages = client.get_paginator("describe_instances").paginate(
        Filters=[
            {
                'Name': 'tag:es_dist_role',
                'Values': ["master"],
            },
            {
                'Name': 'instance-state-name',
                'Values': ['running'],
            },
        ]
    )
    instances = []
    for page in pages:
        for reservation in page['Reservations']:
            # Keep all instances of the reservation, not only the first one.
            instances.extend(reservation['Instances'])
    for instance in instances:
        # Remember the region so later commands can pick the right client.
        instance['Region'] = region
    return instances

def _collect_scaling_groups(region):
    """Return the Auto Scaling groups in *region* that carry an ``es_dist_role`` tag.

    Results are read through a paginator, so accounts with more groups than
    fit in a single response are handled instead of tripping an assertion.

    Args:
        region: AWS region name, e.g. ``"us-west-1"``.

    Returns:
        A list of group dicts as returned by ``describe_auto_scaling_groups``,
        each with an extra ``"Region"`` key set to *region*.
    """
    client = boto3.client(
        "autoscaling",
        region_name=region,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )
    print("Collecting scaling groups in region", region)
    pages = client.get_paginator("describe_auto_scaling_groups").paginate()
    groups = []
    for page in pages:
        for group in page["AutoScalingGroups"]:
            # Only groups tagged with es_dist_role (any value) are of interest.
            if any(tag["Key"] == "es_dist_role" for tag in group.get("Tags", ())):
                group["Region"] = region
                groups.append(group)
    return groups


def get_all_instances():
    """Return the instances collected from every region in REGIONS as one list."""
    # Regions are queried one after another, not in parallel.
    collected = []
    for region in REGIONS:
        collected = collected + _collect_instances(region)
    return collected

def get_all_scaling_groups():
    """Return the scaling groups collected from every region in REGIONS as one list."""
    # Regions are queried one after another, not in parallel.
    collected_groups = []
    for region in REGIONS:
        collected_groups = collected_groups + _collect_scaling_groups(region)
    return collected_groups

def get_all_x(name):
    """Call the ``get_all_<name>`` collector and return its result.

    This is a module-level dispatcher so that it can be handed to
    ``multiprocessing.Pool.map``. Dispatch is explicit rather than going
    through ``eval``, so an unexpected *name* can never execute arbitrary code.

    Args:
        name: ``"instances"`` or ``"scaling_groups"``.

    Returns:
        Whatever the selected collector returns (a list of resource dicts).

    Raises:
        ValueError: if *name* is not one of the supported collector names.
    """
    if name == "instances":
        return get_all_instances()
    if name == "scaling_groups":
        return get_all_scaling_groups()
    raise ValueError("unknown collection: %r" % (name,))

def get_tag(key, stuff):
    """Return the value of the tag named *key* on a resource dict.

    Args:
        key: Tag key to look for.
        stuff: Resource dict; its ``'Tags'`` entry, when present, is a list of
            ``{'Key': ..., 'Value': ...}`` dicts.

    Returns:
        The ``'Value'`` of the first tag whose ``'Key'`` equals *key*, or
        ``None`` when there is no ``'Tags'`` entry or no tag matches.
    """
    if 'Tags' not in stuff:
        return None
    try:
        first_match = [tag for tag in stuff['Tags'] if tag['Key'] == key][0]
        return first_match['Value']
    except IndexError:
        # No tag with this key.
        return None

def get_name_tag(instance):
    """Return the value of the resource's ``Name`` tag, or None if it has none."""
    return get_tag(key="Name", stuff=instance)


def get_exp_prefix_tag(instance):
    """Return the value of the resource's ``exp_prefix`` tag, or None if it has none."""
    return get_tag(key="exp_prefix", stuff=instance)

def get_exp_name_tag(instance):
    """Return the value of the resource's ``exp_name`` tag, or None if it has none."""
    return get_tag(key="exp_name", stuff=instance)

@cli.command()
def jobs():
    """List running jobs with their current and desired worker counts."""
    # Instances and scaling groups are fetched concurrently in two processes.
    with multiprocessing.Pool(2) as pool:
        master_instances, groups = pool.map(get_all_x, ["instances", "scaling_groups"])
    group_by_exp_name = {get_exp_name_tag(group): group for group in groups}

    lines = []
    for instance in master_instances:
        name = get_exp_name_tag(instance)
        group = group_by_exp_name.get(name)
        if group is None:
            # A master without a matching scaling group must not abort the
            # whole listing; report it and carry on.
            lines.append("{} (no scaling group found)".format(name))
            continue
        lines.append(
            "{} (#workers: {}/{})".format(
                name, len(group["Instances"]),
                group["DesiredCapacity"],
            )
        )
    for line in sorted(lines):
        click.secho(line, fg='green')


def get_instances_by_pattern(job):
    """Yield the collected instances whose ``Name`` tag contains *job*.

    Matching is a plain substring test. Instances without a ``Name`` tag are
    skipped instead of raising ``TypeError`` on the ``in`` test.

    Args:
        job: Substring to look for in the ``Name`` tag.

    Yields:
        Instance dicts as returned by ``get_all_instances``.
    """
    for instance in get_all_instances():
        name = get_name_tag(instance)
        if name is not None and job in name:
            yield instance

def get_groups_by_pattern(job):
    """Yield the collected scaling groups whose ``Name`` tag contains *job*.

    Matching is a plain substring test. Groups without a ``Name`` tag are
    skipped instead of raising ``TypeError`` on the ``in`` test.

    Args:
        job: Substring to look for in the ``Name`` tag.

    Yields:
        Group dicts as returned by ``get_all_scaling_groups``.
    """
    for group in get_all_scaling_groups():
        name = get_name_tag(group)
        if name is not None and job in name:
            yield group

@cli.command()
@click.argument('pattern')
def ssh(pattern):
    """Open an SSH session on the first reachable instance matching PATTERN."""
    for instance in get_instances_by_pattern(pattern):
        name = get_name_tag(instance)
        ip_addr = instance.get('PublicIpAddress')
        if ip_addr is None:
            # The instance dict has no public address to connect to; try the
            # next match rather than failing with KeyError.
            print("Skipping %s: no public IP address" % name)
            continue
        print("Connecting to %s" % name)

        command = " ".join([
            "ssh",
            "-oStrictHostKeyChecking=no",
            "-oConnectTimeout=10",
            "-t",
            "ubuntu@" + ip_addr,
        ])
        print(command)
        os.system(command)
        # Only the first reachable match is used.
        return
    print("Not found!")

@cli.command()
@click.argument('pattern')
def tail(pattern):
    """Follow user_data.log on the first reachable instance matching PATTERN."""
    for instance in get_instances_by_pattern(pattern):
        name = get_name_tag(instance)
        ip_addr = instance.get('PublicIpAddress')
        if ip_addr is None:
            # The instance dict has no public address to connect to; try the
            # next match rather than failing with KeyError.
            print("Skipping %s: no public IP address" % name)
            continue
        print("Connecting to %s" % name)
        command = " ".join([
            "ssh",
            "-oStrictHostKeyChecking=no",
            "-oConnectTimeout=10",
            "-t",
            "ubuntu@" + ip_addr,
            # Once tail exits successfully, replace it with a login shell so
            # the session stays open.
            "'tail -f -n 2000 user_data.log && exec bash -l'"
        ])
        print(command)
        os.system(command)
        # Only the first reachable match is used.
        return
    # Same feedback as the ssh command when nothing matched.
    print("Not found!")

@cli.command()
@click.argument('pattern')
@click.argument('size')
def resize(pattern, size):
    """Set min, max and desired capacity of scaling groups matching PATTERN to SIZE."""
    size = int(size)
    groups_to_resize = list(get_groups_by_pattern(pattern))
    if not groups_to_resize:
        print("No match found")
        return

    # Keep each display name paired with its own group while sorting, so the
    # progress message below always names the group actually being resized.
    # Groups without an exp_name tag are shown by their scaling group name.
    labelled_groups = sorted(
        (
            (get_exp_name_tag(group) or group["AutoScalingGroupName"], group)
            for group in groups_to_resize
        ),
        key=lambda pair: pair[0],
    )
    print("This will resize the following jobs to {}:".format(size))
    click.secho("\n".join(name for name, _ in labelled_groups), fg="blue")
    click.confirm('Continue?', abort=True)

    for name, group in labelled_groups:
        print("Resizing %s to %s" % (name, size))
        client = boto3.client(
            "autoscaling",
            region_name=group["Region"],
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )
        client.update_auto_scaling_group(
            AutoScalingGroupName=group["AutoScalingGroupName"],
            MinSize=size,
            MaxSize=size,
            DesiredCapacity=size,
        )


@cli.command()
@click.argument('pattern')
def kill(pattern):
    """Terminate instances matching PATTERN and delete their scaling groups and launch configurations."""
    to_kill = []
    to_kill_instances = {}  # region name -> list of instance dicts
    for instance in get_instances_by_pattern(pattern):
        to_kill_instances.setdefault(instance['Region'], []).append(instance)
        to_kill.append(get_exp_name_tag(instance))

    if not to_kill:
        # Nothing matched: do not ask for confirmation of an empty list.
        print("No match found")
        return

    print("This will kill the following jobs:")
    click.secho("\n".join(sorted(to_kill)), fg="red")
    click.confirm('Continue?', abort=True)

    for ec2_client, scaling_client in get_clients():
        instances = to_kill_instances.get(ec2_client.region, [])
        if not instances:
            continue
        print("Terminating instances in region", ec2_client.region)
        for instance in instances:
            # The scaling group and its launch configuration are both looked
            # up by the instance's exp_name tag.
            exp_name = get_exp_name_tag(instance)
            print("Cleaning up scaling group and launch config for %s" % exp_name)
            scaling_client.delete_auto_scaling_group(
                AutoScalingGroupName=exp_name,
                ForceDelete=True,
            )
            scaling_client.delete_launch_configuration(
                LaunchConfigurationName=exp_name
            )
        ec2_client.terminate_instances(
            InstanceIds=[instance["InstanceId"] for instance in instances]
        )


@cli.command()
@click.argument('pattern')
@click.argument('dest')
@click.option('--dry-run', is_flag=True)
def sync(pattern, dest, dry_run):
    """Rsync the remote home directory of every instance matching PATTERN into DEST."""
    for instance in get_instances_by_pattern(pattern):
        name = get_name_tag(instance)
        highlight(name)
        ip_addr = instance.get('PublicIpAddress')
        if ip_addr is None:
            # No public address to rsync from; move on to the next match
            # rather than failing with KeyError.
            print("Skipping %s: no public IP address" % name)
            continue
        # Each instance gets its own sub-directory, named after its exp_name
        # tag (or its Name tag when exp_name is missing).
        local_dir = os.path.join(dest, get_exp_name_tag(instance) or name)
        cmd = (
            "rsync {dry_run} -zav --progress "
            "-e 'ssh -oStrictHostKeyChecking=no -oConnectTimeout=10' "
            "ubuntu@{ip}:~ {dest}"
        ).format(
            dry_run='--dry-run' if dry_run else '',
            ip=ip_addr,
            # The command runs through a shell, so the local path must be
            # quoted to survive spaces and shell metacharacters.
            dest=shlex.quote(local_dir),
        )
        highlight(cmd)
        os.system(cmd)


# Invoke the click command group only when this file is run as a script.
if __name__ == '__main__':
    cli()
