#!/usr/bin/env python3
# SPDX-License-Identifier: WTFPL

import sys
import os
import sqlite3
import json
import argparse
import re


def escape_name(name):
	"""Return *name* wrapped in single quotes, with embedded single quotes doubled."""
	doubled = name.replace("'", "''")
	return "'" + doubled + "'"


# A valid name starts with an ASCII letter or underscore, followed by word
# characters only.  The pattern ends with \Z rather than $ because $ also
# matches just before a trailing newline, which would let "name\n" through.
RE_VALIDATE = re.compile(r'^[a-zA-Z_]\w*\Z')
def validate_name(name):
	"""Return True if *name* is acceptable as a column name, False otherwise.

	The whole string must match RE_VALIDATE; the empty string, names starting
	with a digit, and names containing whitespace or punctuation are rejected.
	"""
	return bool(RE_VALIDATE.match(name))


def ForbiddenKeyError():
	"""Build the ValueError raised when a dictionary key is not a valid name."""
	return ValueError('Keys may only contain letters, digits and underscores')


def UnsupportedTypeError():
	"""Build the TypeError raised when the input data has an unsupported shape."""
	return TypeError('Only dictionaries and lists are supported')


class Inserter(object):
	"""Insert or update rows of a single SQLite table from JSON-decoded data.

	Column names taken from dictionary keys are checked with validate_name()
	before being interpolated into SQL; values are always bound as parameters.
	The table name is interpolated into SQL as given, without escaping.

	Nothing is committed by this class itself: callers are expected to commit
	through ``self.db`` (for example with ``with inserter.db:``).
	"""

	# Declared column type for the exact Python type of a sample value.
	# The lookup is by exact type, so True/False map to BOOLEAN, not INTEGER.
	_COLUMN_TYPES = {str: 'TEXT', int: 'INTEGER', float: 'REAL', bool: 'BOOLEAN'}

	def __init__(self, dbfile, table, create=False, alter=False):
		"""Open *dbfile* and remember the target *table*.

		:param dbfile: path handed to sqlite3.connect()
		:param table: name of the table to insert into / update
		:param create: create the table (if missing) from the first inserted dict
		:param alter: add a column for every dict key the table does not have yet
		"""
		self.db = sqlite3.connect(dbfile)
		self.table = table
		self.cursor = self.db.cursor()
		self.create = create
		self.alter = alter
		# Cache of known column names, filled lazily by _alter_table()
		self.columns = []

	def insert_data(self, data):
		"""Insert *data*, whose shape decides how rows are built.

		* dict whose values are all lists: treated as columns; one row is
		  inserted per position (stopping at the shortest list)
		* any other dict: one row, keys are column names
		* list of dicts: one row per dict
		* list of lists: one positional row per inner list
		* any other non-empty list: a single positional row
		* empty list: nothing is done

		Only the first element of a list is inspected to pick the shape.

		:raises TypeError: if *data* is neither a dict nor a list
		"""
		if isinstance(data, dict):
			if all(isinstance(val, list) for val in data.values()):
				for line in zip(*data.values()):
					line_dict = dict(zip(data.keys(), line))
					self._insert_dict(line_dict)
			else:
				self._insert_dict(data)
		elif isinstance(data, list):
			if not data:
				return
			elif isinstance(data[0], dict):
				for line in data:
					self._insert_dict(line)
			elif isinstance(data[0], list):
				for line in data:
					self._insert_list(line)
			else:
				self._insert_list(data)
		else:
			raise UnsupportedTypeError()

	def _insert_dict(self, d):
		"""Insert one row whose column names are the keys of *d*.

		Honors the ``create`` flag once (on the first dict) and the ``alter``
		flag on every dict.

		:raises ValueError: if a key is not a valid column name
		"""
		if not all(validate_name(n) for n in d):
			raise ForbiddenKeyError()

		if self.create:
			self._create_table(d)
			# The table only has to be created once
			self.create = False
		if self.alter:
			self._alter_table(d)

		keys = list(d)
		marks = ','.join('?' * len(keys))
		names = ','.join(keys)
		ordered_values = [d[k] for k in keys]
		self.cursor.execute('INSERT INTO %s(%s) VALUES(%s)' % (self.table, names, marks), ordered_values)

	def _insert_list(self, li):
		"""Insert one row from the positional values in *li*."""
		marks = ','.join('?' * len(li))
		self.cursor.execute('INSERT INTO %s VALUES(%s)' % (self.table, marks), li)

	def update_data(self, ids, data):
		"""Update existing rows from a dict or a list of dicts.

		:param ids: names of the columns used to match the rows to update;
			each one must also be a key of every dict in *data*
		:param data: a dict (one update) or a list of dicts (one update each);
			an empty list is a no-op
		:raises TypeError: if *data* is not a dict or a list of dicts
		"""
		if isinstance(data, dict):
			self._update_dict(ids, data)
		elif isinstance(data, list):
			# Nothing to match on, or nothing to update: avoid indexing an empty list
			if not ids or not data:
				return
			elif isinstance(data[0], dict):
				for line in data:
					self._update_dict(ids, line)
			else:
				raise UnsupportedTypeError()
		else:
			raise UnsupportedTypeError()

	def _update_dict(self, ids, d):
		"""Run one UPDATE: set the non-id keys of *d* on rows matching its id keys.

		:raises ValueError: if a key of *d* is not a valid column name
		:raises KeyError: if a name in *ids* is not a key of *d*
		"""
		if not all(validate_name(n) for n in d):
			raise ForbiddenKeyError()

		# Copy to a list so any sequence of names (e.g. a tuple) can be concatenated below
		keys_match = list(ids)
		keys_set = [k for k in d if k not in keys_match]
		# Values in placeholder order: SET columns first, then WHERE columns.
		# Looking up the ids in d also guarantees they are validated keys of d.
		ordered_values = [d[k] for k in keys_set + keys_match]

		req_match = ' and '.join('%s = ?' % name for name in keys_match)
		req_set = ', '.join('%s = ?' % name for name in keys_set)

		self.cursor.execute('UPDATE %s SET %s WHERE %s' % (self.table, req_set, req_match), ordered_values)

	def close(self):
		"""Close the database connection."""
		self.db.close()

	def _create_table(self, d):
		"""Create the table if it does not exist, with one column per key of *d*."""
		columns = ','.join('%s %s' % (k, self._column_type(d[k])) for k in d)
		self.cursor.execute('CREATE TABLE IF NOT EXISTS %s (%s)' % (self.table, columns))

	def _alter_table(self, d):
		"""Add a column for every key of *d* that the table does not have yet."""
		if not self.columns:
			self._fetch_columns()

		for k in d:
			if k not in self.columns:
				self.cursor.execute('ALTER TABLE %s ADD COLUMN %s %s' % (self.table, k, self._column_type(d[k])))
				self.columns.append(k)

	def _fetch_columns(self):
		"""Load the table's current column names into ``self.columns``."""
		# Field 1 of each table_info row is read as the column name
		self.columns = [row[1] for row in self.cursor.execute('PRAGMA table_info(%s)' % self.table)]

	def _column_type(self, obj):
		"""Return the column type to declare for a column holding values like *obj*.

		Values whose exact type is not str, int, float or bool (None, lists,
		dicts, ...) yield 'NONE'.
		"""
		# NOTE(review): 'NONE' is not a type keyword SQLite recognizes; by its
		# affinity rules such a column presumably gets NUMERIC affinity - confirm
		# before relying on it. Kept as-is to preserve existing schemas.
		return self._COLUMN_TYPES.get(type(obj), 'NONE')


def main():
	"""Command-line entry point: load JSON and insert it into (or update) a table.

	The JSON comes from a file (-f), stdin (-f -) or directly from the command
	line (-j). With one or more -u COLUMN options, rows matching those columns
	are updated; otherwise the data is inserted.

	All arguments are checked, and the JSON is loaded, before the database is
	opened, so invalid input does not touch the database file.
	"""
	parser = argparse.ArgumentParser()
	parser.add_argument(
		"-f", "--file", dest="file", metavar="INPUT_JSON_FILE",
		help="Input JSON filename (or '-' for stdin)",
	)
	parser.add_argument(
		"-j", "--json", dest="json", metavar="INPUT_JSON_DATA",
		help="Input JSON data (JSON passed as argument, not a filename!)",
	)
	parser.add_argument('-d', '--database', dest='database', metavar='DATABASE_FILE')
	parser.add_argument('-t', '--table', dest='table', metavar='TABLE')
	parser.add_argument('-u', '--update', dest='update_ids', action='append', metavar='COLUMN_MATCH_FOR_UPDATE')
	parser.add_argument('--create', dest='create', action='store_true')
	parser.add_argument('--alter', dest='alter', action='store_true')

	args = parser.parse_args()
	# parser.error() prints the message and exits with status 2
	if not args.database:
		parser.error('Missing -d DATABASE_FILE')
	elif not args.file and not args.json:
		parser.error('Missing -f INPUT_JSON_FILE or -j INPUT_JSON_DATA')
	elif not args.table:
		parser.error('Missing -t TABLE')
	elif args.json and args.file:
		parser.error('-f and -j are mutually exclusive')

	if args.json:
		js = json.loads(args.json)
	elif args.file == '-':
		js = json.load(sys.stdin)
	else:
		# Binary mode: json.load() detects the encoding of the bytes itself
		with open(args.file, 'rb') as fd:
			js = json.load(fd)

	inserter = Inserter(args.database, args.table, create=args.create, alter=args.alter)
	try:
		# The connection used as a context manager commits on success and
		# rolls back if an exception escapes the block
		with inserter.db:
			if args.update_ids:
				inserter.update_data(args.update_ids, js)
			else:
				inserter.insert_data(js)
	finally:
		inserter.close()


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
	main()
