#!/usr/bin/python3

import BitVector
import string
import sys

# Sizes expressed in bits; BYTE is the slice width used to walk a bit vector.
BIT = 1
BYTE = 8 * BIT


def bruteforce_xor(encrypted_msg, key_size, only_printable=False, include_null=False):
    """Yield candidate XOR decryptions of ``encrypted_msg`` for every key.

    Args:
        encrypted_msg: Bit-vector-like ciphertext. ``len()`` must return its
            size in bits, and ``int()`` of an 8-bit slice must return that
            byte's value.
        key_size: Key length in bytes. Every key in the ``key_size``-byte
            space is tried; the key's bytes (big-endian) are applied
            cyclically over the ciphertext bytes.
        only_printable: When true, a key whose decryption contains any
            non-printable character is skipped entirely.
        include_null: When true, the all-zero key (which leaves the
            ciphertext unchanged) is tried as well.

    Yields:
        ``(decrypted_text, key)`` tuples, where ``key`` is the integer key.
    """
    # optimization: the key 0x00 is useless, so start from 0x01 unless
    # specified otherwise
    key_start = 0x00 if include_null else 0x01
    key_end = 0x100 ** key_size

    # Decode the ciphertext into byte values once instead of once per key.
    cipher_bytes = [
        int(encrypted_msg[i:i + BYTE])
        for i in range(0, len(encrypted_msg), BYTE)
    ]

    for key in range(key_start, key_end):
        # max(..., 1) keeps the degenerate key_size == 0 case (only key 0)
        # working without a modulo-by-zero below.
        key_bytes = key.to_bytes(max(key_size, 1), 'big')
        key_len = len(key_bytes)

        chars = []
        for position, cipher_byte in enumerate(cipher_bytes):
            xored_char = chr(cipher_byte ^ key_bytes[position % key_len])
            # optimization: only printable strings should be taken as results
            if only_printable and not xored_char.isprintable():
                break
            chars.append(xored_char)
        else:
            # Reached only when the loop was not broken, so a truncated
            # prefix of a rejected decryption is never yielded.
            yield ''.join(chars), key
    
def words_in_message(msg):
    """Return the number of whitespace-separated tokens in ``msg``."""
    tokens = msg.split()
    return len(tokens)

def get_letter_freq_deviation(message):
    """Score how far ``message`` is from typical English letter usage.

    For each letter a-z, the share of ``message`` made up of that letter
    (case-insensitive, in percent of all characters) is compared with the
    expected English frequency; the absolute differences are summed.

    Args:
        message: Text to score.

    Returns:
        The summed absolute deviation in percentage points. Lower means
        closer to English. An empty message is scored as containing no
        letters at all.
    """
    # Relative letter frequencies in English text, in percent (from wikipedia)
    expected_frequency = {
        'a': 8.167, 'b': 1.492, 'c': 2.782, 'd': 4.253, 'e': 12.702,
        'f': 2.228, 'g': 2.015, 'h': 6.094, 'i': 6.966, 'j': 0.153,
        'k': 0.772, 'l': 4.025, 'm': 2.406, 'n': 6.749, 'o': 7.507,
        'p': 1.929, 'q': 0.095, 'r': 5.987, 's': 6.327, 't': 9.056,
        'u': 2.758, 'v': 0.978, 'w': 2.360, 'x': 0.150, 'y': 1.974,
        'z': 0.074,
    }
    message_len = len(message)
    # Letter frequency is case-insensitive; the table only has lowercase keys.
    lowered = message.lower()

    deviation = 0
    for letter, expected in expected_frequency.items():
        if message_len:
            # Scale to percent so it is comparable with the table above.
            observed = 100 * lowered.count(letter) / message_len
        else:
            observed = 0
        deviation += abs(observed - expected)
    return deviation

# Maximum number of candidate plaintexts reported (lowest deviation first).
PROBABLE_MESSAGES_Q = 3

if __name__ == "__main__":

    # ./single-byte-xor 1b37373331363f78151b7f2b783431333d78397828372d363c78373e783a393b3736
    # should return
    # Cooking MC's like a pound of bacon

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <hex-encoded ciphertext>")

    input_str = sys.argv[1]
    input_bv = BitVector.BitVector(hexstring=input_str)

    probable_messages = []

    for xored_msg, key in bruteforce_xor(input_bv, key_size=1, only_printable=True):
        # optimization: count words per message. A wrong decryption tends to
        # produce too few whitespace-separated words.
        if words_in_message(xored_msg) < 3:
            continue

        probable_messages.append({
            'msg': xored_msg,
            'key': key,
            'deviation': get_letter_freq_deviation(xored_msg),
        })

    # Rank every surviving candidate and report only the best few. Comparing
    # against a running "best so far" while collecting would not keep the
    # PROBABLE_MESSAGES_Q lowest-deviation candidates.
    probable_messages.sort(key=lambda analysis: analysis['deviation'])
    for message in probable_messages[:PROBABLE_MESSAGES_Q]:
        print(message)


