#!/usr/bin/python3

import monkeyhex
import sys
import pprint
import logging
import itertools
from tqdm import tqdm
import base64

import cryp_lib

# Lower and upper key-size bounds handed to cryp_lib.get_keysizes().
keysize_limits = [1, 500]

# Number of top-ranked entries kept at every stage: probable key sizes,
# per-block single-byte key guesses, and the final scored keys.
DEEP = 3

# Show INFO-level progress messages from the root logger.
logging.basicConfig(level=logging.INFO)

def main():
    """Break a base64-encoded repeating-key XOR ciphertext and log the result.

    Usage: ./decrypt 6.txt

    The ciphertext path is taken from ``sys.argv[1]``; any further
    command-line arguments are ignored.

    Steps:
      1. Read the file and base64-decode it.
      2. Ask ``cryp_lib.get_keysizes`` for candidate key sizes and keep the
         first ``DEEP`` of them.
      3. For each key size, transpose the ciphertext into one block per key
         byte, brute-force every block as a single-byte XOR and keep the
         first ``DEEP`` guesses per block.
      4. Build one full key per rank (all 1st guesses, all 2nd guesses, ...),
         decrypt with it and score the output with ``cryp_lib.english_score``.
      5. Log the highest-scoring key and its cleartext.

    Exits with an error message when no input path is given or when no
    candidate key could be produced.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <base64-ciphertext-file>")

    with open(sys.argv[1], "rb") as input_file:
        # decode base64
        input_ciphertext = base64.decodebytes(input_file.read())
    logging.info(f"Read {len(input_ciphertext)} bytes from file")

    # find probable key sizes (presumably ranked by hamming distance inside
    # cryp_lib.get_keysizes -- confirm there) and keep the first DEEP
    keysizes = cryp_lib.get_keysizes(keysize_limits[0], keysize_limits[1], input_ciphertext)[0:DEEP]
    logging.info(f"Top {DEEP} Probable keysizes (from {keysize_limits[0]} to {keysize_limits[1]}) are:")
    pprint.pprint(keysizes)

    scored = []
    for size_number, size_entry in enumerate(keysizes, start=1):
        keysize = size_entry['size']
        logging.info(f"Testing keysize {keysize} ({size_number}/{len(keysizes)})")

        # transpose into blocks corresponding to each key byte
        blocks = cryp_lib.transpose(input_ciphertext, keysize)

        # key_matrix[b] holds the first DEEP single-byte key guesses for block b
        key_matrix = []
        for block_number, block in enumerate(blocks, start=1):
            # bruteforce each block like a one-byte xor
            logging.info(f"Bruteforcing block {block_number}/{len(blocks)} (size={len(block)})")
            key_matrix.append(cryp_lib.bruteforce_xor(block)[:DEEP])

        # A block may yield fewer than DEEP guesses; only build the ranks
        # that every block can supply so indexing never runs off the end.
        usable_ranks = min([DEEP, *(len(candidates) for candidates in key_matrix)])
        key_candidates = [
            bytes(candidates[rank]['key'] for candidates in key_matrix)
            for rank in range(usable_ranks)
        ]

        for key in key_candidates:
            logging.info(f"Testing probable key {key}")
            out, _ = cryp_lib.xor_cipher(key, input_ciphertext)
            scored.append({'key': key, 'score': cryp_lib.english_score(out)})

    scored = sorted(scored, key=lambda x: x['score'], reverse=True)[0:DEEP]
    logging.info(f"Top {DEEP} scored keys across {DEEP} keysizes")
    pprint.pprint(scored)

    if not scored:
        sys.exit("No candidate keys were produced; nothing to decrypt")

    cleartext, key = cryp_lib.xor_cipher(scored[0]['key'], input_ciphertext)
    # The best guess may still be wrong and produce bytes that are not valid
    # UTF-8; substitute those instead of crashing while printing the result.
    decoded_text = cleartext.decode(errors="replace")
    logging.info(f'''
        
        Key used :      {key}
        Key Length :    {len(key)}
        Result (Binary) [Length={len(cleartext)}]:

        {cleartext}
        

        Result (String) [Length={len(cleartext)}]:
                        
        {decoded_text}

    ''')


if __name__ == "__main__":
    main()
                
