#!/usr/bin/env python3
# SPDX-License-Identifier: WTFPL
# Tool to parse a KeepassX password database and print its content
# Warning: doesn't parse KeepassX2/KeepassXC format (.kdbx)

import struct
from struct import unpack
from getpass import getpass
import hashlib
import io
from argparse import ArgumentParser
from fnmatch import fnmatch

from Crypto.Cipher import AES

# parser lib

def from_cstring(cstr):
	"""Decode a NUL-terminated byte string into text.

	Everything from the first NUL byte onwards is discarded. The remaining
	bytes are decoded as UTF-8 so that non-ASCII text comes out correctly;
	byte sequences that are not valid UTF-8 fall back to latin-1, which
	accepts any byte, so this function never raises on undecodable data.

	:param cstr: bytes, optionally NUL-terminated
	:return: the decoded str, without the terminator
	"""
	text = cstr.split(b'\x00', 1)[0]
	try:
		return text.decode('utf-8')
	except UnicodeDecodeError:
		# Not UTF-8 (legacy or binary data): keep the historical latin-1 decoding.
		return text.decode('latin-1')


def sha256(data):
	"""Return the raw (binary, 32-byte) SHA-256 digest of `data`."""
	return hashlib.sha256(data).digest()


class BaseCipher(io.RawIOBase):
	"""Raw, read-only stream yielding the decrypted content of `file`.

	Data is read from `file` in multiples of the cipher block size and passed
	through `cipher.decrypt()`. PKCS#7 padding is removed from the final block
	only: the most recently decrypted block is withheld until the next read
	tells whether more data follows, so a block in the middle of the stream
	that merely looks padded is returned untouched.

	:param cipher: object exposing `block_size` and `decrypt(bytes)`
	:param file: binary file-like object positioned at the encrypted data
	"""

	def __init__(self, cipher, file):
		super().__init__()
		self.cipher = cipher
		self.file = file
		self.blocksize = cipher.block_size
		# Last decrypted block, not handed out yet because it may be the
		# final (padded) one.
		self._pending = b''
		# Set once the underlying file has reported end of data.
		self._finished = False

	def _strip_padding(self, block):
		"""Return `block` without its trailing PKCS#7 padding, if it has any."""
		if not block:
			return block
		pad_n = block[-1]
		# Indexing bytes gives an int, so the expected padding has to be
		# built as bytes too (comparing against a str never matches).
		if 0 < pad_n <= self.blocksize and block.endswith(bytes([pad_n]) * pad_n):
			return block[:-pad_n]
		return block

	def read(self, size=-1):
		"""Read and decrypt up to `size` bytes.

		`size` is rounded down to a multiple of the block size; a request
		smaller than one block therefore returns b''. A negative or None
		`size` reads everything up to end of file.

		:return: decrypted bytes, b'' at end of stream
		"""
		if size is None or size < 0:
			return self.readall()
		size -= size % self.blocksize
		if size <= 0:
			return b''

		while not self._finished:
			chunk = self.file.read(size)
			if not chunk:
				# End of data: what was withheld is the final block.
				self._finished = True
				last, self._pending = self._pending, b''
				return self._strip_padding(last)
			plain = self._pending + self.cipher.decrypt(chunk)
			self._pending = plain[-self.blocksize:]
			ready = plain[:-self.blocksize]
			# On the very first block nothing can be released yet; read on
			# rather than return b'', which callers would take for EOF.
			if ready:
				return ready
		return b''

	def readinto(self, b):
		"""Fill `b` with decrypted data and return the number of bytes stored."""
		dat = self.read(len(b))
		b[:len(dat)] = dat
		return len(dat)

	def readable(self):
		return True


class CipherFile(io.BufferedReader):
	"""Buffered reader over the decrypted content of `file`.

	Wraps a BaseCipher built from `cipher` and `file`, so callers can read
	arbitrary byte counts instead of whole cipher blocks.
	"""

	def __init__(self, cipher, file):
		raw_stream = BaseCipher(cipher, file)
		super().__init__(raw_stream)


class CompositeKey:
	"""Key derived from one or more key parts (bytes) by repeated AES encryption.

	The parts are concatenated and hashed with SHA-256; each 16-byte half of
	that hash is encrypted `rounds` times with AES-ECB keyed by
	`transform_seed`, and the two results are hashed together again.
	"""

	def __init__(self, transform_seed, rounds, keys):
		self.keys = keys
		self.rounds = rounds
		self.transform_seed = transform_seed

	def value(self):
		"""Return the transformed key as a 32-byte SHA-256 digest."""
		raw_key = self._key()
		return self._transform(raw_key, self.transform_seed, self.rounds)

	def _key(self):
		"""Return the SHA-256 digest of all key parts joined together."""
		joined = b''.join(self.keys)
		return sha256(joined)

	def _transform(self, key, seed, rounds):
		"""Transform both halves of `key` and hash the concatenated result."""
		first_half = self._transform_raw(key[:16], seed, rounds)
		second_half = self._transform_raw(key[16:], seed, rounds)
		return sha256(first_half + second_half)

	def _transform_raw(self, chunk, seed, rounds=1):
		"""Encrypt `chunk` `rounds` times in a row with AES-ECB keyed by `seed`."""
		encrypt = AES.new(seed, AES.MODE_ECB).encrypt
		for _ in range(rounds):
			chunk = encrypt(chunk)
		return chunk


class DbReader:
	"""Reader for KeePass 1.x / KeePassX password databases (.kdb).

	Create it with the database file name, call read() with the master
	password, then use `groups` and `entries`.
	"""
	# TODO handle writing
	# TODO handle keepass2

	# Header values a supported file must carry.
	_SIGNATURE = (0x9aa2d903, 0xb54bfb65)
	_VERSION = 0x00030002
	# Header flag bit announcing AES (Rijndael), the only cipher handled here.
	_FLAG_RIJNDAEL = 2
	# Sanity limit on the size of a single field.
	_MAX_ITEM_SIZE = 16 * 1024 * 1024
	# Entry field types holding NUL-terminated strings -> key in the entry dict.
	_ENTRY_STRING_FIELDS = {4: 'title', 5: 'url', 6: 'username', 7: 'password'}

	def __init__(self, filename):
		self.filename = filename
		# Stream currently being parsed; only set while read() is running.
		self.f = None
		self.groups = []
		self.entries = []

	def read_struct(self, fmt):
		"""Read and unpack one `struct` format `fmt` from the current stream."""
		return unpack(fmt, self.f.read(struct.calcsize(fmt)))

	def read_uint32(self):
		return self.read_struct('<L')[0]

	def read_int32(self):
		return self.read_struct('<l')[0]

	def read_uint16(self):
		return self.read_struct('<H')[0]

	def read(self, db_password):
		"""Open the database, decrypt it and load `groups` and `entries`.

		:param db_password: master password, as bytes
		:raises ValueError: if the file is not a supported database, or if
			the password is wrong / the content is corrupted
		"""
		with open(self.filename, 'rb') as raw_file:
			self.f = raw_file
			try:
				self._read_content(db_password)
			finally:
				# The stream is unusable once the file is closed.
				self.f = None

	def _read_content(self, db_password):
		"""Parse the header found on `self.f`, then the encrypted records."""
		# Explicit checks rather than assert: they must survive `python -O`.
		if (self.read_uint32(), self.read_uint32()) != self._SIGNATURE:
			raise ValueError('%s: not a KeePass 1.x database (bad signature)' % self.filename)

		# TODO handle different algorithms
		enc_flags = self.read_uint32()
		if not enc_flags & self._FLAG_RIJNDAEL:
			raise ValueError('unsupported encryption algorithm (flags 0x%x), only AES is handled' % enc_flags)

		# TODO handle different versions
		version = self.read_uint32()
		if version != self._VERSION:
			raise ValueError('unsupported database version 0x%08x' % version)

		master_seed = self.f.read(16)
		iv = self.f.read(16)

		n_groups = self.read_uint32()
		n_entries = self.read_uint32()

		content_hash_header = self.f.read(32)

		transform_seed = self.f.read(32)
		transform_rounds = self.read_uint32()

		key = CompositeKey(transform_seed, transform_rounds, [db_password])
		cipher_key = sha256(master_seed + key.value())

		# First pass over the content only checks its hash; verify() rewinds
		# the file so the second pass can parse it with a fresh cipher state.
		self.verify(self._build_cipher(cipher_key, iv), content_hash_header)

		self.f = self._build_cipher(cipher_key, iv)

		self.groups = [self.read_group() for _ in range(n_groups)]

		self.entries = []
		for _ in range(n_entries):
			entry = self.read_entry()
			# Skip the internal bookkeeping entries.
			if entry.get('title') == 'Meta-Info' and entry.get('username') == 'SYSTEM':
				continue
			self.entries.append(entry)

	def _build_cipher(self, cipher_key, iv):
		"""Return a decrypting stream (AES-CBC) over the current `self.f`."""
		cipher = AES.new(cipher_key, AES.MODE_CBC, iv)
		return CipherFile(cipher, self.f)

	def verify(self, cipher, expected_hash):
		"""Check that the decrypted content matches the hash from the header.

		Reads `cipher` to its end, then puts `self.f` back where it was.
		The plaintext is accepted either as delivered or with trailing
		PKCS#7 padding removed, so the check does not depend on whether the
		stream already stripped the padding.

		:param cipher: stream returning the decrypted content
		:param expected_hash: SHA-256 digest stored in the file header
		:raises ValueError: if the content does not match `expected_hash`
		"""
		content_pos = self.f.tell()
		hasher = hashlib.new('sha256')
		# Hold back the last 16 bytes (one AES block): padding, if still
		# present, lies entirely within them.
		tail = b''
		try:
			while True:
				buf = cipher.read(1024)
				if not buf:
					break
				pending = tail + buf
				hasher.update(pending[:-16])
				tail = pending[-16:]
		finally:
			self.f.seek(content_pos)

		candidates = [tail]
		pad_n = tail[-1] if tail else 0
		if 0 < pad_n <= len(tail) and tail.endswith(bytes([pad_n]) * pad_n):
			candidates.append(tail[:-pad_n])

		for candidate in candidates:
			attempt = hasher.copy()
			attempt.update(candidate)
			if attempt.digest() == expected_hash:
				return
		raise ValueError('wrong password or corrupted database')

	def password_to_key(self, password):
		return sha256(password)

	def read_group(self):
		"""Consume one group record.

		Group fields are currently skipped, not decoded: returns None.
		"""
		while True:
			item_type, _data = self.read_item()
			if item_type == 0xffff:
				break

	def read_entry(self):
		"""Read one entry record and return its decoded fields as a dict.

		Possible keys: 'uuid' (hex string), 'groupid' (raw bytes), 'title',
		'url', 'username' and 'password' (str). Other field types are ignored.
		"""
		fields = {}
		while True:
			item_type, data = self.read_item()
			if item_type == 0xffff:
				# End-of-record marker.
				return fields
			if item_type == 1:
				# Binary identifier: treating it as a C string would cut it
				# at the first zero byte.
				fields['uuid'] = data.hex()
			elif item_type == 2:
				fields['groupid'] = data
			elif item_type in self._ENTRY_STRING_FIELDS:
				fields[self._ENTRY_STRING_FIELDS[item_type]] = from_cstring(data)

	def read_item(self):
		"""Read one field: a uint16 type, an int32 size, then `size` data bytes.

		:return: tuple (type, data)
		:raises ValueError: if the announced size is negative or too large
		"""
		item_type = self.read_uint16()
		size = self.read_int32()
		if not 0 <= size < self._MAX_ITEM_SIZE:
			raise ValueError('invalid field size %d, database is probably corrupted' % size)
		data = self.f.read(size)
		return (item_type, data)


# main

def entries_matching(db, pattern):
	"""Yield the entries of `db` whose title matches the fnmatch-style `pattern`."""
	yield from (item for item in db.entries if fnmatch(item['title'], pattern))


def list_entries(db, pattern='*'):
	"""Print the title of every entry of `db` matching `pattern`, one per line."""
	for match in entries_matching(db, pattern):
		title = match['title']
		print('title: %s' % title)

def print_entries(db, pattern, show_password=False):
	"""Print the details of every entry of `db` whose title matches `pattern`.

	The password is printed only when `show_password` is true.
	"""
	matches = entries_matching(db, pattern)
	for found in matches:
		print_entry(found, show_password)

def print_nonempty(entry, name, label):
	"""Print `label: value` for field `name` of `entry`, unless it is missing or empty."""
	value = entry.get(name)
	if not value:
		return
	print('%s: %s' % (label, value))

def print_entry(entry, show_password=False):
	"""Print one entry: title, then url and username when set, then a blank line.

	The password line is printed only when `show_password` is true.
	"""
	print('title: %s' % entry['title'])
	for field in ('url', 'username'):
		print_nonempty(entry, field, field)
	if show_password:
		secret = entry.get('password')
		print('password: %s' % secret)
	print()


def main():
	"""Command-line entry point.

	Asks for the master password, reads the database given with -f, then
	prints the entries matching the title pattern, or lists all titles when
	no pattern is given.
	"""
	argp = ArgumentParser()
	argp.add_argument('title_pattern', default=None, nargs='?')
	argp.add_argument('-f', metavar='FILE', dest='file', default='passwords.kdb')
	argp.add_argument('--show-password', dest='show_password', default=False, action='store_true')
	args = argp.parse_args()

	database = DbReader(args.file)
	database.read(getpass().encode('utf-8'))

	if not args.title_pattern:
		list_entries(database)
		return
	print_entries(database, args.title_pattern, args.show_password)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
	main()
