#!/usr/bin/env python3
from argparse import ArgumentParser
from asyncio.subprocess import DEVNULL, PIPE
from asyncio import (get_event_loop, gather, BoundedSemaphore, create_subprocess_exec,
                     IncompleteReadError)
from enum import Enum, auto
from logging import getLogger, basicConfig, INFO
from os import cpu_count
from pathlib import Path
from sys import exit


logger = getLogger(__name__)


class Mode(Enum):
    """Extraction strategies; each member name doubles as a CLI flag name."""

    # Values match what auto() would assign (1, 2, ...); only the names
    # and identities are used by the rest of the script.
    identical = 1
    nonconflicting = 2


def parse_args(args=None):
    """Parse the command line and return the resulting namespace.

    Exactly one ``--<mode>`` flag (one per ``Mode`` member) is required and
    is stored as the ``Mode`` member itself in ``mode``.  The namespace also
    carries ``dry_run``, ``common_dir`` and ``targets`` (one or more).

    :param args: argument list to parse; ``None`` lets argparse use
        ``sys.argv[1:]``.
    """
    cli = ArgumentParser(description='Extract common files.')

    # One store_const flag per mode; the group makes them mutually exclusive.
    mode_flags = cli.add_mutually_exclusive_group(required=True)
    for mode in Mode:
        mode_flags.add_argument(
            f'--{mode.name}',
            dest='mode',
            action='store_const',
            const=mode,
            help=f'extract {mode.name} common files',
        )

    cli.add_argument(
        '-n', '--dry-run',
        action='store_true',
        help='Dry run (do nothing)',
    )
    cli.add_argument(
        'common_dir',
        metavar='COMMON_DIR',
        help='common files directory to move to',
    )
    cli.add_argument(
        'targets',
        nargs='+',
        metavar='INPUT_DIR',
        help='directory to move from',
    )

    return cli.parse_args(args)


class bounded_exec:
    """Async context manager that runs a subprocess under a shared concurrency cap.

    Positional and keyword arguments are forwarded unchanged to
    ``create_subprocess_exec``.  Entering the context waits for a slot on the
    class-wide semaphore, spawns the process and yields it; leaving the
    context kills the process if it has not exited yet, waits for it, and
    returns the slot.
    """

    # Shared by every instance.  cpu_count() may return None when the CPU
    # count cannot be determined, so fall back to 1 instead of failing at
    # import time with ``None + 1``.
    semaphore = BoundedSemaphore(value=(cpu_count() or 1) + 1)

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        self.process = None

    async def __aenter__(self):
        await self.semaphore.acquire()
        try:
            self.process = await create_subprocess_exec(*self.args, **self.kwargs)
        except BaseException:
            # __aexit__ is not called when __aenter__ raises (spawn failure
            # or cancellation), so the slot has to be handed back here or it
            # would be lost for the rest of the run.
            self.semaphore.release()
            raise
        return self.process

    async def __aexit__(self, exc_type, exc, tb):
        p = self.process
        try:
            if p is not None and p.returncode is None:
                try:
                    p.kill()
                except OSError:
                    # Best effort: the process may have exited between the
                    # returncode check and the kill (ProcessLookupError).
                    pass
                await p.wait()
        finally:
            # Always return the slot, even if waiting is interrupted.
            self.semaphore.release()


class DiffError(RuntimeError):
    """Raised by ``diff`` when the external ``diff`` command exits non-zero."""


async def diff(file1, file2):
    """Run ``diff -q`` on two paths and raise ``DiffError`` unless it exits 0.

    All three standard streams of the child are connected to the null
    device, so only the exit status is observed.  Returns ``None`` on
    success.
    """
    command = ('diff', '-q', str(file1), str(file2))
    silenced = dict(stdin=DEVNULL, stdout=DEVNULL, stderr=DEVNULL)

    async with bounded_exec(*command, **silenced) as proc:
        status = await proc.wait()

    if status == 0:
        return
    raise DiffError()


async def identical(files):
    """Return True when every path in *files* compares equal to the first one.

    Fewer than two paths are trivially identical.  Otherwise the first path
    is compared against each of the others concurrently via ``diff``; the
    first ``DiffError`` makes the result False.

    :param files: sequence of paths to compare.
    :returns: ``True`` if all comparisons succeed, ``False`` otherwise.
    """
    if len(files) < 2:
        return True
    first, *rest = files

    # Wrap each comparison in its own task so that the ones still running
    # can be cancelled individually.  Calling cancel() on the gather()
    # future is a no-op once that future has completed with an exception,
    # and gather() itself does not cancel the remaining awaitables.
    loop = get_event_loop()
    tasks = [loop.create_task(diff(first, f)) for f in rest]
    try:
        await gather(*tasks)
    except DiffError:
        return False
    else:
        return True
    finally:
        stragglers = [t for t in tasks if not t.done()]
        for t in stragglers:
            t.cancel()
        if stragglers:
            # Let the cancelled comparisons unwind before returning;
            # return_exceptions keeps their outcome from propagating.
            await gather(*stragglers, return_exceptions=True)

def commonify_file(common_file, files):
    """Move the first of *files* to *common_file* and delete the others.

    The destination's parent directories are created if missing.

    :param common_file: ``Path`` the kept copy is renamed to.
    :param files: non-empty iterable of ``Path`` objects; the first is
        renamed, every remaining one is unlinked.
    """
    destination_dir = common_file.parent
    destination_dir.mkdir(parents=True, exist_ok=True)

    keeper, *duplicates = files
    keeper.rename(common_file)
    for duplicate in duplicates:
        duplicate.unlink()


async def commonify_identical(common_file, files, *, dry_run=False):
    """Move *files* to *common_file* if they are all identical.

    :param common_file: destination path for the shared copy.
    :param files: candidate paths to compare and merge.
    :param dry_run: when true, compare only and leave the files untouched.
    :returns: ``1`` if the files were identical, ``0`` (after logging the
        divergence) otherwise.
    """
    if not await identical(files):
        logger.info('Divergent file: %s', common_file)
        return 0

    if not dry_run:
        commonify_file(common_file, files)
    return 1


async def find_files(path):
    """Asynchronously yield a ``Path`` for every regular file ``find`` reports under *path*.

    ``find`` is run with ``-type f -print0`` and its output is read one
    NUL-terminated record at a time.

    :param path: directory (or file) handed to ``find`` as the start point.
    :raises IncompleteReadError: if the output ends in the middle of a record.
    """
    # Local import: decode names with the filesystem encoding and error
    # handler rather than strict UTF-8, so that file names which are not
    # valid UTF-8 do not abort the scan with UnicodeDecodeError.
    from os import fsdecode

    async with bounded_exec('find', str(path), '-type', 'f', '-print0',
                            stdin=DEVNULL, stdout=PIPE, stderr=DEVNULL) as p:
        while True:
            try:
                record = await p.stdout.readuntil(b'\x00')
            except IncompleteReadError as e:
                if not e.partial:
                    # Clean end of stream: nothing left after the last NUL.
                    break
                # Trailing bytes without a terminator: truncated output.
                raise
            # Drop the trailing NUL separator before decoding.
            yield Path(fsdecode(record[:-1]))


async def scan(target, files_targetfiles):
    """Record every file found under *target* in *files_targetfiles*.

    :param target: directory to scan.
    :param files_targetfiles: dict, mutated in place, mapping each path
        relative to its target to the list of full paths found for it.
    """
    found = 0
    async for path in find_files(target):
        relative = path.relative_to(target)
        if relative not in files_targetfiles:
            files_targetfiles[relative] = []
        files_targetfiles[relative].append(path)
        found += 1
    logger.info('%s: Found %d files', target, found)


def arg_dir(s):
    """Convert *s* to a ``Path``, requiring it to be an existing directory.

    :raises ValueError: if *s* does not name a directory.
    """
    candidate = Path(s)
    if candidate.is_dir():
        return candidate
    raise ValueError(f'{s!r} is not a directory')


async def main(settings):
    """Scan all targets and move their common files into the common directory.

    :param settings: namespace with ``mode``, ``common_dir``, ``targets``
        and ``dry_run`` attributes, as produced by ``parse_args``.
    :returns: ``0``, used by the caller as the process exit status.
    :raises ValueError: if the common directory or any target is not a
        directory.
    """
    # In every mode except "nonconflicting" a file only qualifies when
    # each target has a copy of it.
    need_everywhere = settings.mode != Mode.nonconflicting
    common_dir = arg_dir(settings.common_dir)
    targets = list(map(arg_dir, settings.targets))

    # Relative path -> list of full paths, filled in by the scans.
    files_targetfiles = {}
    scans = [scan(t, files_targetfiles) for t in targets]
    await gather(*scans)

    jobs = []
    for relative, copies in files_targetfiles.items():
        if need_everywhere and len(copies) != len(targets):
            continue
        jobs.append(
            commonify_identical(common_dir / relative, copies,
                                dry_run=settings.dry_run)
        )
    results = await gather(*jobs)

    logger.info('%s: %d %s files',
                common_dir, sum(results), settings.mode.name)
    return 0


if __name__ == '__main__':
    # Script entry point: configure logging, parse the command line, run
    # main() to completion and use its return value as the exit status.
    from asyncio import new_event_loop, set_event_loop

    basicConfig(level=INFO)
    settings = parse_args()
    try:
        loop = get_event_loop()
    except RuntimeError:
        # Newer Python versions no longer create an event loop implicitly
        # when none is current; create and install one explicitly.
        loop = new_event_loop()
        set_event_loop(loop)
    exit(loop.run_until_complete(main(settings)))
