#!/usr/bin/python3

from BitVector import BitVector
import string
import sys
import pprint
import logging
import itertools
# Configure the root logger so messages of level DEBUG and above are emitted.
logging.basicConfig(level=logging.DEBUG)

# Size units expressed in bits; BYTE is used when slicing bit vectors
# into byte-sized pieces.
BIT = 1
BYTE = 8 * BIT

def encrypt_xor(input_bv, key_bv):
    """XOR each bit of ``input_bv`` with the key, repeating the key as needed.

    Args:
        input_bv: Bit sequence to encrypt.
        key_bv: Key bit sequence. Input bit ``i`` is combined with key bit
            ``i`` modulo the key length, so a short key wraps around.

    Returns:
        A new BitVector with the same length as ``input_bv`` holding the
        XOR of the input and the repeated key.
    """
    key_size = len(key_bv)
    result_bv = BitVector(size=len(input_bv))

    for position, input_bit in enumerate(input_bv):
        # The modulo wraps the key index back to the start of the key.
        result_bv[position] = input_bit ^ key_bv[position % key_size]
    return result_bv

def main(argv):
    """Encrypt a file with repeating-key XOR and write the result as hex.

    Usage: ./repeating-key-xor clear.txt key.txt cipher.txt

    Each cleartext byte is XORed with the key byte at the same position
    modulo the key length, and written to the ciphertext file as two hex
    digits.

    Args:
        argv: Command-line argument list. ``argv[1]``, ``argv[2]`` and
            ``argv[3]`` are the cleartext, key and ciphertext file paths;
            any further arguments are ignored.

    Returns:
        The process exit status: 0 on success, 1 when the key file holds
        no key bytes, 2 when too few arguments were supplied.
    """
    if len(argv) < 4:
        prog = argv[0] if argv else "repeating-key-xor"
        print(f"usage: {prog} clear.txt key.txt cipher.txt", file=sys.stderr)
        return 2

    cleartext_file_path, key_file_path, ciphertext_file_path = argv[1:4]

    # NOTE(review): the key file is read verbatim in text mode, so a trailing
    # newline in the file becomes part of the key -- confirm this is intended.
    with open(key_file_path, "rt") as key_file:
        key_bv = BitVector(textstring=key_file.read())
    logging.info("Key: %s", key_bv.get_bitvector_in_ascii())

    # Reject an empty key up front: it would otherwise cause a division by
    # zero in the key index arithmetic, after the output file had already
    # been truncated.
    key_byte_count = len(key_bv) // BYTE
    if key_byte_count == 0:
        logging.error("Key file %s is empty; cannot encrypt", key_file_path)
        return 1

    # Cut the key into byte-sized bit vectors once instead of re-slicing it
    # for every input byte.
    key_blocks = [
        BitVector(bitlist=key_bv[start:start + BYTE])
        for start in range(0, key_byte_count * BYTE, BYTE)
    ]

    with open(cleartext_file_path, "rb") as input_file, open(ciphertext_file_path, "wt") as output_file:
        # itertools.cycle repeats the key blocks for as long as input remains.
        output_file.writelines(
            (BitVector(intVal=input_b, size=BYTE) ^ key_block).get_bitvector_in_hex()
            for input_b, key_block in zip(input_file.read(), itertools.cycle(key_blocks))
        )
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv))
