/*
 *   This program is free software; you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation; either version 2 of the License, or (at
 *   your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program; if not, write to the Free Software
 *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
 */

/**
 * $Id: 3f2145ae9ddcdc921489fa204c8300ebf3cc0624 $
 * @file eap_fast_crypto.c
 * @brief Cryptographic functions for EAP-FAST.
 *
 * @author Alexander Clouter (alex@digriz.org.uk)
 *
 * @copyright 2016 Alan DeKok (aland@freeradius.org)
 * @copyright 2016 The FreeRADIUS server project
 */
RCSID("$Id: 3f2145ae9ddcdc921489fa204c8300ebf3cc0624 $")
USES_APPLE_DEPRECATED_API	/* OpenSSL API has been deprecated by Apple */

#include <stdio.h>
#include <freeradius-devel/tls/base.h>
#include <freeradius-devel/eap/tls.h>

#include <freeradius-devel/tls/openssl_user_macros.h>
#include <openssl/evp.h>
#include <openssl/aes.h>
#include <openssl/err.h>

#include "eap_fast_crypto.h"

/** EAP-FAST Pseudo-Random Function (T-PRF): RFC 4851, Section 5.5
 *
 * Fills out with the first out_len bytes of T1 || T2 || ... where
 *
 *	S  = prf_label || 0x00 || seed
 *	T1 = HMAC-SHA1(secret, S || out_len || 0x01)
 *	Tn = HMAC-SHA1(secret, T(n-1) || S || out_len || n)
 *
 * out_len is encoded as two bytes in network byte order (only the low
 * 16 bits are used) and n is a single byte counter.
 *
 * @param[in] secret		HMAC key.
 * @param[in] secret_len	Length of secret in bytes.
 * @param[in] prf_label		NUL terminated label.  At most 128 bytes of it
 *				are used; the 0x00 separator is always added.
 * @param[in] seed		Seed, may be NULL in which case it is treated as empty.
 * @param[in] seed_len		Length of seed in bytes.
 * @param[out] out		Where to write the generated bytes.
 * @param[in] out_len		Number of bytes to generate.
 */
void T_PRF(unsigned char const *secret, unsigned int secret_len,
	   char const *prf_label,
	   unsigned char const *seed, unsigned int seed_len,
	   unsigned char *out, unsigned int out_len)
{
	size_t		label_len = strlen(prf_label);
	size_t		s_len, buf_len;
	size_t		chunk, remaining;
	uint8_t		*buf, *s, *p, *counter;
	unsigned char	*dst = out;

	if (label_len > 128) label_len = 128;

	/* A NULL seed contributes nothing, whatever length was passed in */
	if (!seed) seed_len = 0;

	/* label + 0x00 separator + seed + 2 byte output length + 1 byte counter */
	s_len = label_len + 1 + seed_len + 2 + 1;
	buf_len = SHA1_DIGEST_LENGTH + s_len;

	/*
	 *	Layout: [ T(n-1) | label | 0x00 | seed | out_len | n ]
	 *
	 *	The function has no way to report failure, so running out of
	 *	memory is fatal, as with every other failure in this file.
	 */
	buf = talloc_size(NULL, buf_len);
	if (!buf) abort();

	s = buf + SHA1_DIGEST_LENGTH;
	p = s;

	memcpy(p, prf_label, label_len);
	p += label_len;
	*p++ = 0x00;		/* explicit separator, also correct for truncated labels */

	if (seed_len > 0) {
		memcpy(p, seed, seed_len);
		p += seed_len;
	}

	/* Output length, big endian.  Written bytewise as p may be unaligned */
	*p++ = (uint8_t)((out_len >> 8) & 0xff);
	*p++ = (uint8_t)(out_len & 0xff);

	counter = p;
	*counter = 1;

	/* T1 has no previous block, so only S || out_len || 0x01 is hashed */
	fr_hmac_sha1(buf, s, s_len, secret, secret_len);

	chunk = (out_len < SHA1_DIGEST_LENGTH) ? out_len : SHA1_DIGEST_LENGTH;
	memcpy(dst, buf, chunk);
	dst += chunk;
	remaining = out_len - chunk;

	/* T2..Tn hash the whole buffer, i.e. the previous block followed by S */
	while (remaining > 0) {
		(*counter)++;

		fr_hmac_sha1(buf, buf, buf_len, secret, secret_len);

		chunk = (remaining < SHA1_DIGEST_LENGTH) ? remaining : SHA1_DIGEST_LENGTH;
		memcpy(dst, buf, chunk);
		dst += chunk;
		remaining -= chunk;
	}

	/* Don't leave derived key material behind in freed memory */
	memset(buf, 0, buf_len);
	talloc_free(buf);
}

/** Print every queued OpenSSL error to stderr, then abort the process
 *
 * Based on http://stackoverflow.com/a/29838852
 *
 * Does not return.
 */
static NEVER_RETURNS void handleErrors(void)
{
	unsigned long	err_code;
	char		err_buf[256];	/* ERR_error_string() requires at least 256 bytes */

	fprintf(stderr, "An error occurred\n");

	/*
	 *	Format into a local buffer.  Passing NULL makes OpenSSL use a
	 *	shared static buffer, which is not safe with multiple threads.
	 */
	while ((err_code = ERR_get_error()) != 0) {
		ERR_error_string(err_code, err_buf);
		fprintf(stderr, "%s\n", err_buf);
	}

	abort();
}

/** Encrypt and authenticate a buffer with AES-256-GCM
 *
 * Follows the OpenSSL wiki example:
 * https://wiki.openssl.org/index.php/EVP_Authenticated_Encryption_and_Decryption#Authenticated_Encryption_using_GCM_mode
 *
 * The IV length is set to 16 bytes and a 16 byte tag is produced.
 * Any OpenSSL failure is handed to handleErrors(), which does not return.
 *
 * @param[in] plaintext		Data to encrypt.
 * @param[in] plaintext_len	Length of plaintext in bytes.
 * @param[in] aad		Additional authenticated data.
 * @param[in] aad_len		Length of aad in bytes.
 * @param[in] key		AES-256-GCM key.
 * @param[in] iv		16 byte IV.
 * @param[out] ciphertext	Where to write the encrypted data.
 * @param[out] tag		Where to write the 16 byte authentication tag.
 * @return the number of ciphertext bytes written.
 */
int eap_fast_encrypt(uint8_t const *plaintext, size_t plaintext_len,
		     uint8_t const *aad, size_t aad_len,
		     uint8_t const *key, uint8_t *iv, unsigned char *ciphertext,
		     uint8_t *tag)
{
	EVP_CIPHER_CTX	*cipher_ctx = EVP_CIPHER_CTX_new();
	int		written = 0;
	int		total;

	if (cipher_ctx == NULL) handleErrors();

	/* Select the cipher first; key and IV follow once the IV length is known */
	if (EVP_EncryptInit_ex(cipher_ctx, EVP_aes_256_gcm(), NULL, NULL, NULL) != 1) {
		handleErrors();
	}

	/* The GCM default is a 12 byte (96 bit) IV, we use 16 */
	if (EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_IVLEN, 16, NULL) != 1) {
		handleErrors();
	}

	if (EVP_EncryptInit_ex(cipher_ctx, NULL, NULL, key, iv) != 1) {
		handleErrors();
	}

	/* A NULL output buffer marks the input as AAD: authenticated, not encrypted */
	if (EVP_EncryptUpdate(cipher_ctx, NULL, &written, aad, aad_len) != 1) {
		handleErrors();
	}

	/* Encrypt the payload */
	if (EVP_EncryptUpdate(cipher_ctx, ciphertext, &written, plaintext, plaintext_len) != 1) {
		handleErrors();
	}
	total = written;

	/* GCM emits nothing more here, but the call is needed to finish the tag */
	if (EVP_EncryptFinal_ex(cipher_ctx, ciphertext + total, &written) != 1) {
		handleErrors();
	}
	total += written;

	/* Fetch the authentication tag */
	if (EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_GET_TAG, 16, tag) != 1) {
		handleErrors();
	}

	EVP_CIPHER_CTX_free(cipher_ctx);

	return total;
}

/** Decrypt and verify a buffer protected with AES-256-GCM
 *
 * Counterpart of eap_fast_encrypt(): 16 byte IV, 16 byte tag.
 *
 * Set-up failures inside OpenSSL are handed to handleErrors(), which does
 * not return.  A tag mismatch is reported to the caller instead.
 *
 * @param[in] ciphertext	Data to decrypt.
 * @param[in] ciphertext_len	Length of ciphertext in bytes.
 * @param[in] aad		Additional authenticated data.
 * @param[in] aad_len		Length of aad in bytes.
 * @param[in] tag		16 byte authentication tag to verify against.
 * @param[in] key		AES-256-GCM key.
 * @param[in] iv		16 byte IV.
 * @param[out] plaintext	Where to write the decrypted data.
 * @return
 *	- the number of plaintext bytes written on success.
 *	- -1 if verification failed.  Bytes may already have been written to
 *	  plaintext, they are unauthenticated and must not be used.
 */
int eap_fast_decrypt(uint8_t const *ciphertext, size_t ciphertext_len,
		     uint8_t const *aad, size_t aad_len,
		     uint8_t const *tag, uint8_t const *key, uint8_t const *iv, uint8_t *plaintext)
{
	EVP_CIPHER_CTX *ctx;
	int len;
	int plaintext_len;
	int ret;

	/* Create and initialise the context */
	if (!(ctx = EVP_CIPHER_CTX_new())) handleErrors();

	/*
	 *	Every call below is checked against 1, the same as in
	 *	eap_fast_encrypt().  A plain !ret test would let negative
	 *	returns (EVP_CIPHER_CTX_ctrl can return -1) pass as success.
	 */

	/* Initialise the decryption operation. */
	if (1 != EVP_DecryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, NULL, NULL))
		handleErrors();

	/* Set IV length. Not necessary if this is 12 bytes (96 bits) */
	if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, 16, NULL))
		handleErrors();

	/* Initialise key and IV */
	if (1 != EVP_DecryptInit_ex(ctx, NULL, NULL, key, iv)) handleErrors();

	/* Provide any AAD data. This can be called zero or more times as
	 * required
	 */
	if (1 != EVP_DecryptUpdate(ctx, NULL, &len, aad, aad_len))
		handleErrors();

	/* Provide the message to be decrypted, and obtain the plaintext output.
	 * EVP_DecryptUpdate can be called multiple times if necessary
	 */
	if (1 != EVP_DecryptUpdate(ctx, plaintext, &len, ciphertext, ciphertext_len))
		handleErrors();
	plaintext_len = len;

	{
		unsigned char *tmp;

		/*
		 *	The ctrl takes a non-const pointer.  Copying the pointer
		 *	value drops the qualifier without a cast.
		 */
		memcpy(&tmp, &tag, sizeof(tmp));

		/* Set expected tag value. Works in OpenSSL 1.0.1d and later */
		if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, 16, tmp)) handleErrors();
	}

	/* Finalise the decryption. A positive return value indicates success,
	 * anything else is a failure - the plaintext is not trustworthy.
	 */
	ret = EVP_DecryptFinal_ex(ctx, plaintext + len, &len);

	/* Clean up */
	EVP_CIPHER_CTX_free(ctx);

	/* Verify failed */
	if (ret <= 0) return -1;

	return plaintext_len + len;
}


/** P_hash data expansion function from RFC 4346, Section 5
 *
 *	A(0) = seed
 *	A(i) = HMAC_hash(secret, A(i-1))
 *	P_hash(secret, seed) = HMAC_hash(secret, A(1) + seed) +
 *			       HMAC_hash(secret, A(2) + seed) + ...
 *
 * Blocks are generated until out_len bytes have been written; the final
 * block is truncated as needed.
 *
 * @param[out] out		Where to write the expanded data.
 * @param[in] out_len		Number of bytes to generate.
 * @param[in] evp_md		Digest to use for the HMAC.
 * @param[in] secret		HMAC key.
 * @param[in] secret_len	Length of secret in bytes.
 * @param[in] seed		Seed, i.e. A(0).
 * @param[in] seed_len		Length of seed in bytes.
 */
static void crypto_rfc4346_p_hash(uint8_t *out, size_t out_len,
				  EVP_MD const *evp_md,
				  uint8_t const *secret, size_t secret_len,
				  uint8_t const *seed,  size_t seed_len)
{
	EVP_MD_CTX *ctx_a, *ctx_out;
	EVP_PKEY *pkey;

	uint8_t a[EVP_MAX_MD_SIZE];
	size_t size;

	MEM(ctx_a = EVP_MD_CTX_new());
	MEM(ctx_out = EVP_MD_CTX_new());

	MEM(pkey = EVP_PKEY_new_mac_key(EVP_PKEY_HMAC, NULL, secret, secret_len));

	EVP_DigestSignInit(ctx_a, NULL, evp_md, NULL, pkey);
	EVP_DigestSignInit(ctx_out, NULL, evp_md, NULL, pkey);

	size = EVP_MD_size(evp_md);

	/* Calculate A(1) */
	EVP_DigestSignUpdate(ctx_a, seed, seed_len);

	/*
	 *	OpenSSL <= 1.1.1 requires a non-null pointer for len
	 */
	EVP_DigestSignFinal(ctx_a, a, &(size_t){ sizeof(a) });

	while (1) {
		/* Calculate next part of output: HMAC(secret, A(i) + seed) */
		EVP_DigestSignUpdate(ctx_out, a, size);
		EVP_DigestSignUpdate(ctx_out, seed, seed_len);

		/* Last part: go via a so we never write past the end of out */
		if (out_len < size) {
			EVP_DigestSignFinal(ctx_out, a, &(size_t){ sizeof(a) });
			memcpy(out, a, out_len);
			break;
		}

		/*
		 *	Place digest in output buffer.  At least size bytes
		 *	remain, so pass the real remaining space as the limit.
		 */
		EVP_DigestSignFinal(ctx_out, out, &(size_t){ out_len });
		EVP_MD_CTX_reset(ctx_out);

		EVP_DigestSignInit(ctx_out, NULL, evp_md, NULL, pkey);
		out += size;
		out_len -= size;

		/* Calculate next A(i) */
		EVP_MD_CTX_reset(ctx_a);
		EVP_DigestSignInit(ctx_a, NULL, evp_md, NULL, pkey);
		EVP_DigestSignUpdate(ctx_a, a, size);
		EVP_DigestSignFinal(ctx_a, a, &(size_t){ sizeof(a) });
	}

	EVP_PKEY_free(pkey);

	EVP_MD_CTX_free(ctx_a);
	EVP_MD_CTX_free(ctx_out);

	/*
	 *	Scrub the last A(i) / partial output block.
	 *	memset_s takes (dest, destsz, ch, count).
	 */
#ifdef __STDC_LIB_EXT1__
	memset_s(a, sizeof(a), 0, sizeof(a));
#else
	memset(a, 0, sizeof(a));
#endif
}


/** PRF built from two P_hash expansions, as in RFC 4346, Section 5
 *
 * The secret is split into two halves of (secret_len + 1) / 2 bytes each,
 * which share the middle byte when secret_len is odd.  The first half is
 * expanded with MD5 straight into out, the second half with SHA1 into
 * scratch, and the two streams are XORed together into out.
 *
 * @param[out] out		Where to write the result.
 * @param[in] out_len		Number of bytes to generate.
 * @param[out] scratch		Work buffer, at least out_len bytes, overwritten.
 * @param[in] secret		Secret to split between the two expansions.
 * @param[in] secret_len	Length of secret in bytes.
 * @param[in] seed		Seed passed to both expansions.
 * @param[in] seed_len		Length of seed in bytes.
 */
static void eap_crypto_rfc4346_prf(uint8_t *out, size_t out_len, uint8_t *scratch,
				   uint8_t const *secret, size_t secret_len,
				   uint8_t const *seed, size_t seed_len)
{
	unsigned int	half_len = (secret_len + 1) / 2;
	uint8_t const	*md5_secret = secret;
	uint8_t const	*sha1_secret = secret + (secret_len - half_len);
	unsigned int	idx = 0;

	crypto_rfc4346_p_hash(out, out_len, EVP_md5(), md5_secret, half_len, seed, seed_len);
	crypto_rfc4346_p_hash(scratch, out_len, EVP_sha1(), sha1_secret, half_len, seed, seed_len);

	/* Fold the SHA1 stream into the MD5 stream */
	while (idx < out_len) {
		out[idx] ^= scratch[idx];
		idx++;
	}
}


/** Generate key material from the master key of a TLS session
 *
 * The PRF seed is built as:
 *
 *	prf_label (at most 128 bytes, no NUL) || server_random || client_random
 *
 * For EAP-FAST the server random comes before the client random.
 *
 * @param[in] s			TLS session to take the randoms and master key from.
 * @param[out] buffer		Where to write the generated bytes.
 * @param[out] scratch		Work buffer, at least size bytes, overwritten.
 * @param[in] size		Number of bytes to generate.
 * @param[in] prf_label		NUL terminated label placed at the start of the seed.
 */
void eap_fast_tls_gen_challenge(SSL *s, uint8_t *buffer, uint8_t *scratch, size_t size, char const *prf_label)
{
	uint8_t		master_key[SSL_MAX_MASTER_KEY_LENGTH];
	uint8_t		seed[128 + (2 * SSL3_RANDOM_SIZE)];
	size_t		label_len = strlen(prf_label);
	size_t		seed_len, master_key_len;

	/* seed only has room for 128 bytes of label */
	if (label_len > 128) label_len = 128;

	memcpy(seed, prf_label, label_len);
	(void) SSL_get_server_random(s, seed + label_len, SSL3_RANDOM_SIZE);
	(void) SSL_get_client_random(s, seed + label_len + SSL3_RANDOM_SIZE, SSL3_RANDOM_SIZE);
	seed_len = label_len + (2 * SSL3_RANDOM_SIZE);

	master_key_len = SSL_SESSION_get_master_key(SSL_get_session(s), master_key, sizeof(master_key));

	eap_crypto_rfc4346_prf(buffer, size, scratch, master_key, master_key_len, seed, seed_len);
}
