#!/usr/bin/python3

import BitVector
import string
import sys
import pprint
import logging
# INFO level: the logging.debug() traces emitted by the candidate filters
# stay silent unless the level is lowered to DEBUG.
logging.basicConfig(level=logging.INFO)

# Sizes expressed in bits; BYTE is the step used to walk a bit vector one
# byte at a time.
BIT = 1
BYTE = 8 * BIT

def bruteforce_xor(encrypted_bv, key_size, only_printable=False, include_null=False):
    """Yield every candidate decryption of a repeating-key XOR ciphertext.

    Each key of ``key_size`` bytes is tried in ascending numeric order. The
    key is applied big-endian and repeated over the ciphertext, so with
    ``key_size=1`` every ciphertext byte is XORed with the same key byte.

    Args:
        encrypted_bv: Bit vector holding the ciphertext. It must support
            ``len()`` (size in bits), slicing, and ``int()`` on a slice.
        key_size: Key length in bytes; must be at least 1.
        only_printable: When true, a candidate is dropped entirely as soon
            as one decoded character is neither printable nor whitespace.
        include_null: When true, the all-zero key (which leaves the input
            unchanged) is tried as well.

    Yields:
        ``(message, key)`` tuples, where ``message`` is the decoded string
        and ``key`` is the key as an integer.

    Raises:
        ValueError: If the ciphertext is shorter than one byte or
            ``key_size`` is smaller than 1. Because this is a generator,
            the error surfaces on the first iteration.
    """
    if len(encrypted_bv) < BYTE:
        raise ValueError(f'bruteforce_xor received an encrypted_bv of size {len(encrypted_bv)} bits ({len(encrypted_bv)/BYTE} bytes) with content {encrypted_bv}')
    if key_size < 1:
        raise ValueError(f'bruteforce_xor received an invalid key_size {key_size}')

    # The all-zero key is the identity transformation, so it is skipped
    # unless the caller explicitly asks for it.
    key_start = 0x00 if include_null else 0x01
    key_end = 0x100 ** key_size

    # Slicing the bit vector does not depend on the key: do it once.
    cipher_bytes = [
        int(encrypted_bv[offset:offset + BYTE])
        for offset in range(0, len(encrypted_bv), BYTE)
    ]

    for key in range(key_start, key_end):
        key_bytes = key.to_bytes(key_size, 'big')
        decoded = []
        for position, cipher_byte in enumerate(cipher_bytes):
            xored_char = chr(cipher_byte ^ key_bytes[position % key_size])
            # Whitespace such as '\n' is legitimate in text even though
            # str.isprintable() rejects it, so it is accepted explicitly.
            if only_printable and not (
                xored_char.isprintable() or xored_char in string.whitespace
            ):
                logging.debug("Discarted by printable filter")
                break
            decoded.append(xored_char)
        else:
            # Reached only when no character was rejected: a truncated
            # prefix of a discarded candidate is never yielded.
            yield ''.join(decoded), key
    
def words_in_message(msg):
    """Return the number of whitespace-separated words in *msg*."""
    tokens = msg.split()
    return len(tokens)

def get_letter_freq_deviation(message):
    """Score how far *message* is from typical English letter frequencies.

    For every letter a-z the observed frequency (case-insensitive, as a
    percentage of all characters in the message) is compared with the
    expected English frequency; the absolute differences are summed.

    Args:
        message: Text to score.

    Returns:
        The summed absolute deviation in percentage points; lower means
        closer to English. A message with no characters scores as if every
        letter had an observed frequency of zero.
    """
    # Relative letter frequencies in English text, in percent
    # (Wikipedia, "Letter frequency").
    expected_frequency = {
        'a': 8.167, 'b': 1.492, 'c': 2.782, 'd': 4.253, 'e': 12.702,
        'f': 2.228, 'g': 2.015, 'h': 6.094, 'i': 6.966, 'j': 0.153,
        'k': 0.772, 'l': 4.025, 'm': 2.406, 'n': 6.749, 'o': 7.507,
        'p': 1.929, 'q': 0.095, 'r': 5.987, 's': 6.327, 't': 9.056,
        'u': 2.758, 'v': 0.978, 'w': 2.360, 'x': 0.150, 'y': 1.974,
        'z': 0.074,
    }
    # The table is keyed by lowercase letters, so fold case before counting.
    lowered = message.lower()
    # Avoid dividing by zero on an empty message.
    message_len = len(lowered) or 1
    # Observed counts are scaled to percent to match the table's unit.
    return sum(
        abs(100 * lowered.count(letter) / message_len - expected)
        for letter, expected in expected_frequency.items()
    )

# Number of best-scoring candidate messages kept while scanning the input
# and printed at the end.
PROBABLE_MESSAGES_Q = 3

if __name__ == "__main__":

    # Usage: ./detect-single-byte-xor 4.txt
    # For the reference input the best candidate should be:
    #   Now that the party is jumping

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <file-with-hex-encoded-lines>")

    file_path = sys.argv[1]

    # Best candidates found so far, across all lines, ordered by ascending
    # deviation and capped at PROBABLE_MESSAGES_Q entries after each line.
    probable_messages = []

    with open(file_path, "rt") as input_file:
        for line_index, input_line in enumerate(input_file):

            hex_line = input_line.strip()
            if not hex_line:
                # Nothing to decrypt on a blank line, and bruteforce_xor
                # rejects inputs shorter than one byte.
                continue

            print(f"Analysing line {line_index}")
            print(f"Original content {input_line}")

            line_bv = BitVector.BitVector(hexstring=hex_line)

            for xored_msg, key in bruteforce_xor(line_bv, key_size=1, only_printable=True):
                # A wrong key rarely produces several whitespace-separated
                # words, so very short word counts are discarded early.
                if words_in_message(xored_msg) < 3:
                    logging.debug("Discarted by word count filter")
                    continue

                # Every surviving candidate is recorded; ranking happens in
                # the trim below, so a second- or third-best message found
                # after the best one is still kept.
                probable_messages.append({
                    'cyphertext': input_line,
                    'line': line_index,
                    'msg': xored_msg,
                    'key': key,
                    'deviation': get_letter_freq_deviation(xored_msg),
                })

            # Keep only the best candidates so the list stays small.
            probable_messages = sorted(
                probable_messages,
                key=lambda analysis: analysis['deviation'],
            )[:PROBABLE_MESSAGES_Q]
            print("="*50)

    pprint.pprint(probable_messages)


