#!/usr/bin/env python3
import os
import socket
import sys
import getpass
import logging
import argparse

import six

from tacacs_plus.client import TACACSClient
from tacacs_plus.flags import (
    TAC_PLUS_AUTHEN_TYPES, TAC_PLUS_PRIV_LVL_MIN, TAC_PLUS_VIRTUAL_REM_ADDR,
    TAC_PLUS_VIRTUAL_PORT, TAC_PLUS_ACCT_FLAGS, TAC_PLUS_AUTHEN_TYPE_CHAP
)

# Module-level logger for this script. It stays unconfigured unless
# parse_args() sees --debug, which calls logging.basicConfig(level=DEBUG).
log = logging.getLogger(__name__)


def parse_args():
    """Build the command-line parser, parse ``sys.argv`` and return the result.

    The environment variables ``TACACS_PLUS_KEY`` and ``TACACS_PLUS_PWD`` are
    read first. When one of them is set, the matching command-line option
    (``--key`` / ``authenticate --password``) is not registered at all and the
    environment value is stored on the namespace instead.

    Side effects:
        When ``--debug`` is given, logging is configured at DEBUG level and a
        warning about sensitive data being logged is emitted.

    Returns:
        argparse.Namespace: the parsed arguments. ``authen_type`` holds the
        value looked up in ``TAC_PLUS_AUTHEN_TYPES`` rather than the name that
        was typed on the command line.
    """
    tacacs_key = os.environ.get('TACACS_PLUS_KEY')
    user_password = os.environ.get('TACACS_PLUS_PWD')
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="""
        Tacacs+ client with full AAA support:

            * Authentication supports ascii, pap and chap
            * Authorization supports AV pairs and single commands.
            * Accounting supports AV pairs and single commands.

        NOTE: shared encryption key can be set via environment variable TACACS_PLUS_KEY or via argument.
        NOTE: user password can be setup via environment variable TACACS_PLUS_PWD or via argument.
        """
    )
    parser.add_argument('-u', '--username', required=True, help="user name")
    parser.add_argument('-H', '--host', required=True, help="tacacs+ server address")
    parser.add_argument('-p', '--port', type=int, default=49, help="tacacs+ server port (default 49)")
    parser.add_argument('-l', '--priv-lvl', type=int, default=TAC_PLUS_PRIV_LVL_MIN, help="user privilege level")
    parser.add_argument('-t', '--authen-type', choices=TAC_PLUS_AUTHEN_TYPES, default='ascii',
                        help="authentication type")
    parser.add_argument('-r', '--rem-addr', default=TAC_PLUS_VIRTUAL_REM_ADDR,
                        help="remote address (logged by tacacs server)")
    parser.add_argument('-P', '--virtual-port', default=TAC_PLUS_VIRTUAL_PORT,
                        help="console port used in connection (logged by tacacs server)")
    parser.add_argument('--timeout', type=int, default=10)
    parser.add_argument('-d', '--debug', action='store_true', help="enable debugging output")
    parser.add_argument('-v', '--verbose', action='store_true', help="print responses")
    parser.add_argument('-6', '--v6', action='store_true', help="use IPv6 addresses")
    parser.add_argument('-s', '--session-id', type=int, required=False, help="set session_id")
    parser.add_argument('-e', '--return-0-if-failed', action='store_true', help="If failed, don't change the exit code.")

    # The key option is only offered when the environment did not supply one.
    if not tacacs_key:
        parser.add_argument('-k', '--key', required=False, help="tacacs+ shared encryption key")

    command = parser.add_subparsers(dest='action', help="action to perform over the tacacs+ server")
    # Set after creation so that omitting the sub-command is an error.
    command.required = True

    # authentication parser
    authentication = command.add_parser('authenticate', help="authenticate against a tacacs+ server")
    if not user_password:
        # '-p' here belongs to the sub-command parser, so it does not clash
        # with the top-level '-p/--port' option.
        authentication.add_argument('-p', '--password', required=False, help="user password")
    authentication.add_argument('-i', '--chap-id', required=False, help="CHAP Identity")
    authentication.add_argument('-c', '--chap-challenge', required=False, help="CHAP Challenge")

    # authorization parser
    authorization = command.add_parser('authorize', help="authorize a command against a tacacs+ server")
    authorization.add_argument('-c', '--cmds', required=True, nargs='+', help="list of cmds to authorize")

    # accounting parser
    accounting = command.add_parser('account',
                                    help="account commands with accounting flags against a tacacs+ server")
    accounting.add_argument('-c', '--cmds', required=True, nargs='+', help="list of cmds to account")
    accounting.add_argument('-f', '--flag', required=True, choices=TAC_PLUS_ACCT_FLAGS,
                            help="accounting flag")

    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
        log.warning("\033[93mTACACS+ --debug will log raw packet data INCLUDING PASSWORDS; "
                    "proceed at your own risk!\033[00m")

    # Environment-provided secrets are stored on the namespace; the matching
    # command-line options were never registered in that case.
    if tacacs_key:
        args.key = tacacs_key

    if user_password:
        args.password = user_password

    # Replace the user-facing name with the value stored in the mapping.
    args.authen_type = TAC_PLUS_AUTHEN_TYPES[args.authen_type]

    return args


def strings_to_bytes(strings):
    """Return a new list holding the UTF-8 encoding of each item in *strings*."""
    encoded = []
    for text in strings:
        encoded.append(text.encode('utf-8'))
    return encoded


def verbose_response(response):
    """Print the contents of a server reply to stdout.

    The status line is always printed. ``data``, ``server_msg`` and ``flags``
    are printed only when truthy, and each entry of ``arguments`` is decoded
    as UTF-8 and listed under an ``av-pairs:`` heading when any are present.
    """
    print("status: %s" % response.human_status)

    # (format string, attribute name) for the optional single-line fields,
    # in the order they are shown.
    optional_fields = (
        ("data: %s", "data"),
        ("server_msg: %s", "server_msg"),
        ("flags: %s", "flags"),
    )
    for template, attr_name in optional_fields:
        value = getattr(response, attr_name)
        if value:
            print(template % value)

    av_pairs = response.arguments
    if not av_pairs:
        return
    print("av-pairs:")
    for raw_pair in av_pairs:
        print("  %s" % raw_pair.decode('utf-8'))


def authenticate(cli, args):
    """Run the ``authenticate`` action and return a process exit code.

    Args:
        cli: client object exposing an ``authenticate`` method.
        args: parsed command-line namespace. ``args.password`` is filled in
            from an interactive prompt when it is missing or empty.

    Returns:
        int: 0 when the reply's ``valid`` attribute is truthy, otherwise 1.
    """
    # 'password' may be absent from the namespace altogether, hence vars().
    if not vars(args).get('password'):
        args.password = getpass.getpass('password for %s: ' % args.username)

    chap_ppp_id = args.chap_id
    chap_challenge = args.chap_challenge
    if args.authen_type == TAC_PLUS_AUTHEN_TYPE_CHAP:
        # Only prompt for CHAP values that were not given on the command line.
        if chap_ppp_id is None:
            chap_ppp_id = six.moves.input('chap PPP ID: ')
        if chap_challenge is None:
            chap_challenge = six.moves.input('chap challenge: ')
    auth = cli.authenticate(args.username, args.password, priv_lvl=args.priv_lvl,
                            authen_type=args.authen_type, chap_ppp_id=chap_ppp_id,
                            chap_challenge=chap_challenge, rem_addr=args.rem_addr,
                            port=args.virtual_port)
    if args.verbose:
        verbose_response(auth)

    return 0 if auth.valid else 1


def authorize(cli, args):
    """Run the ``authorize`` action and return a process exit code.

    The commands in ``args.cmds`` are UTF-8 encoded and passed to
    ``cli.authorize`` as its ``arguments``. The reply is printed when
    ``args.verbose`` is set.

    Returns:
        int: 0 when the reply's ``valid`` attribute is truthy, otherwise 1.
    """
    encoded_cmds = strings_to_bytes(args.cmds)
    reply = cli.authorize(
        args.username,
        arguments=encoded_cmds,
        priv_lvl=args.priv_lvl,
        authen_type=args.authen_type,
        rem_addr=args.rem_addr,
        port=args.virtual_port,
    )

    if args.verbose:
        verbose_response(reply)

    if reply.valid:
        return 0
    return 1


def account(cli, args):
    """Run the ``account`` action and return a process exit code.

    The commands in ``args.cmds`` are UTF-8 encoded, and ``args.flag`` is
    looked up in ``TAC_PLUS_ACCT_FLAGS`` to get the value passed as ``flags``
    to ``cli.account``. The reply is printed when ``args.verbose`` is set.

    Returns:
        int: 0 when the reply's ``valid`` attribute is truthy, otherwise 1.
    """
    encoded_cmds = strings_to_bytes(args.cmds)
    acct_flag = TAC_PLUS_ACCT_FLAGS[args.flag]
    reply = cli.account(
        args.username,
        flags=acct_flag,
        arguments=encoded_cmds,
        priv_lvl=args.priv_lvl,
        authen_type=args.authen_type,
        rem_addr=args.rem_addr,
        port=args.virtual_port,
    )

    if args.verbose:
        verbose_response(reply)

    if reply.valid:
        return 0
    return 1


def get_cli(args):
    """Create the ``TACACSClient`` described by the parsed arguments.

    When ``args.key`` is empty the shared key is requested interactively and
    stored back on ``args``. The socket family is ``AF_INET6`` when
    ``args.v6`` is set and ``AF_INET`` otherwise.

    Args:
        args: parsed command-line namespace (host, port, key, timeout,
            session_id and v6 are read).

    Returns:
        TACACSClient: a client built from those values.
    """
    if not args.key:
        args.key = getpass.getpass("tacacs+ shared key: ")
    family = socket.AF_INET6 if args.v6 else socket.AF_INET
    return TACACSClient(args.host, args.port, args.key, timeout=args.timeout,
                        session_id=args.session_id, family=family)


def main():
    """Parse the command line, run the selected action and exit.

    This function always ends the process through ``sys.exit``:

    * with the action's return code (forced to 0 when
      ``--return-0-if-failed`` was given), or
    * with 1 after logging the traceback when any ``Exception`` is raised
      while building the client or running the action.
    """
    args = parse_args()
    try:
        cli = get_cli(args)

        # Map each sub-command name to the function implementing it.
        handlers = {
            'authenticate': authenticate,
            'authorize': authorize,
            'account': account,
        }
        handler = handlers.get(args.action)
        outcome = handler(cli, args) if handler is not None else 0

        # SystemExit is not an Exception subclass, so it passes through the
        # handler below untouched.
        sys.exit(0 if args.return_0_if_failed else outcome)
    except Exception as e:
        log.exception("error: %s", e)
        sys.exit(1)


if __name__ == '__main__':
    # Script entry point: exit with status 1 when interrupted from the
    # keyboard instead of letting KeyboardInterrupt propagate.
    try:
        main()
    except KeyboardInterrupt:
        log.info("quitting...")
        sys.exit(1)
    # Only reached if main() returns without raising.
    sys.exit(0)
