#include "fd_ed25519.h"
#include "fd_curve25519.h"

uchar * FD_FN_SENSITIVE
fd_ed25519_public_from_private( uchar         public_key [ static 32 ],
                                uchar const   private_key[ static 32 ],
                                fd_sha512_t * sha ) {

  /* Derives the 32-byte Ed25519 public key for a 32-byte private key,
     following RFC 8032 section 5.1.5 (Key Generation).

     public_key  receives the encoded point and is also the return
                 value.
     private_key is the 32-byte secret seed.
     sha         is used as the hashing context; fd_sha512_clear is
                 invoked on it before returning.

     The intermediate digest lives on the stack and is explicitly
     zeroed before returning. */

  /* RFC step 1: digest = SHA-512( private_key ).  Per the RFC, only
     the lower 32 bytes of the 64-byte digest are used for generating
     the public key. */

  uchar digest[ FD_SHA512_HASH_SZ ];
  fd_sha512_t * ctx = fd_sha512_init( sha );
  ctx = fd_sha512_append( ctx, private_key, 32UL );
  fd_sha512_fini( ctx, digest );

  /* RFC step 2: prune the buffer.  Clear the lowest three bits of the
     first octet, clear the highest bit of octet 31 and set its second
     highest bit. */

  digest[ 0] = (uchar)( digest[ 0] & 0xF8 );
  digest[31] = (uchar)( ( digest[31] & 0x7F ) | 0x40 );

  /* RFC step 3: interpret the pruned buffer as the little-endian
     secret scalar s and perform the fixed-base multiplication [s]B. */

  fd_ed25519_point_t A[1];
  fd_ed25519_scalar_mul_base_const_time( A, digest );

  /* RFC step 4: the public key is the encoding of [s]B (y-coordinate
     as 32 little-endian octets, with the least significant bit of x
     copied into the most significant bit of the final octet). */

  fd_ed25519_point_tobytes( public_key, A );

  /* Wipe secret-derived state before leaving */

  fd_memset_explicit( digest, 0, FD_SHA512_HASH_SZ );
  fd_sha512_clear( sha );

  return public_key;
}

uchar * FD_FN_SENSITIVE
fd_ed25519_sign( uchar         sig[ static 64 ],
                 uchar const   msg[], /* msg_sz */
                 ulong         msg_sz,
                 uchar const   public_key[ static 32 ],
                 uchar const   private_key[ static 32 ],
                 fd_sha512_t * sha ) {

  /* Produces the 64-byte Ed25519 signature of msg, following RFC 8032
     section 5.1.6 (Sign).  No dom2 prefix or prehash is applied here:
     the message bytes are hashed directly.

     sig         receives R (bytes 0..31) followed by S (bytes 32..63)
                 and is also the return value.
     msg/msg_sz  message to sign.
     public_key  supplied by the caller; it is hashed as-is and not
                 re-derived from private_key here.
     private_key 32-byte secret seed.
     sha         hashing context, reused for all three hashes;
                 fd_sha512_clear is invoked on it before returning. */

  /* RFC step 1: hash the private key.  The first half of the digest,
     once pruned, is the secret scalar s; the second half
     (bytes 32..63) is the prefix. */

  uchar az[ FD_SHA512_HASH_SZ ];
  fd_sha512_t * ctx = fd_sha512_init( sha );
  ctx = fd_sha512_append( ctx, private_key, 32UL );
  fd_sha512_fini( ctx, az );

  az[ 0] = (uchar)( az[ 0] & 0xF8 );
  az[31] = (uchar)( ( az[31] & 0x7F ) | 0x40 );

  uchar const * prefix = az + 32;

  /* RFC step 2: nonce = SHA-512( prefix || msg ), interpreted as a
     little-endian integer r. */

  uchar nonce[ FD_SHA512_HASH_SZ ];
  ctx = fd_sha512_init( sha );
  ctx = fd_sha512_append( ctx, prefix, 32UL   );
  ctx = fd_sha512_append( ctx, msg,    msg_sz );
  fd_sha512_fini( ctx, nonce );

  /* RFC step 3: reduce r mod L (the group order of B), compute the
     point [r]B and write its encoding R into the first half of sig. */

  fd_curve25519_scalar_reduce( nonce, nonce );

  fd_ed25519_point_t R[1];
  fd_ed25519_scalar_mul_base_const_time( R, nonce );
  fd_ed25519_point_tobytes( sig, R );

  /* RFC step 4: k = SHA-512( R || A || msg ).  Everything feeding k
     is a public value. */

  uchar k[ FD_SHA512_HASH_SZ ];
  ctx = fd_sha512_init( sha );
  ctx = fd_sha512_append( ctx, sig,        32UL   );
  ctx = fd_sha512_append( ctx, public_key, 32UL   );
  ctx = fd_sha512_append( ctx, msg,        msg_sz );
  fd_sha512_fini( ctx, k );

  /* RFC steps 5 and 6: reduce k mod L, then S = (r + k * s) mod L is
     written as the second half of the signature. */

  fd_curve25519_scalar_reduce( k, k );

  uchar * S = sig + 32;
  fd_curve25519_scalar_muladd( S, k, az, nonce );

  /* Wipe secret-derived state.  k is left alone because it is built
     only from public values. */

  fd_memset_explicit( az,    0, FD_SHA512_HASH_SZ );
  fd_memset_explicit( nonce, 0, FD_SHA512_HASH_SZ );
  fd_sha512_clear( sha );

  return sig;
}

int
fd_ed25519_verify( uchar const   msg[], /* msg_sz */
                   ulong         msg_sz,
                   uchar const   sig[ static 64 ],
                   uchar const   public_key[ static 32 ],
                   fd_sha512_t * sha ) {

  /* Verifies an Ed25519 signature, following RFC 8032 section 5.1.7
     (Verify) with the additional small order rejections described
     below.

     Returns
       FD_ED25519_SUCCESS    if the group equation holds,
       FD_ED25519_ERR_SIG    if S fails scalar validation, or R fails
                             to decompress or is small order,
       FD_ED25519_ERR_PUBKEY if the public key fails to decompress or
                             is small order,
       FD_ED25519_ERR_MSG    if all decodings pass but the group
                             equation does not hold. */

  /* RFC step 1: split the signature into the encoded point R (first
     32 bytes) and the scalar S (last 32 bytes), and decode both along
     with the public key.  Any decoding failure makes the signature
     invalid. */

  uchar const * r = sig;
  uchar const * S = sig + 32;

  /* S must pass scalar validation (RFC: 0 <= S < L) */
  if( FD_UNLIKELY( !fd_curve25519_scalar_validate( S ) ) ) return FD_ED25519_ERR_SIG;

  /* Decompress the public key and the point r in a single call */
  fd_ed25519_point_t Aprime[1];
  fd_ed25519_point_t R[1];
  int err = fd_ed25519_point_frombytes_2x( Aprime, public_key,   R, r );

  /* Checks applied to the public key and the point r:
     1. both must decompress successfully (RFC)
     2. neither is allowed to be small order (verify_strict)

     A further check is deliberately NOT enforced at present: whether
     the public key and the point r are canonically encoded.
     Dalek 2.x (currently used by Agave) performs no such check.
     Dalek 4.x checks that the point r is canonical, but accepts a
     non canonical public key.

     Note from the original author: no test with non canonical points
     could be found (all known tests are non canonical + low order,
     and those are already excluded by the verify_strict rule).  The
     reason is that writing such a test requires knowing the discrete
     log of a non canonical point.

     The following code checks that r is canonical (it can be added
     when Agave switches to dalek 4.x):

        uchar compressed[ 32 ];
        fd_ed25519_affine_tobytes( compressed, R );
        if( FD_UNLIKELY( !fd_memeq( compressed, r, 32 ) ) ) {
          return FD_ED25519_ERR_SIG;
        }
    */

  if( FD_UNLIKELY( err ) ) {
    /* A result of 1 is reported as a public key error, any other
       non-zero result as a signature error. */
    if( err==1 ) return FD_ED25519_ERR_PUBKEY;
    return FD_ED25519_ERR_SIG;
  }
  if( FD_UNLIKELY( fd_ed25519_affine_is_small_order( Aprime ) ) ) return FD_ED25519_ERR_PUBKEY;
  if( FD_UNLIKELY( fd_ed25519_affine_is_small_order( R      ) ) ) return FD_ED25519_ERR_SIG;

  /* RFC step 2: k = SHA-512( R || A || msg ), interpreted as a
     little-endian integer and reduced. */

  uchar k[ 64 ];
  fd_sha512_t * ctx = fd_sha512_init( sha );
  ctx = fd_sha512_append( ctx, r,          32UL   );
  ctx = fd_sha512_append( ctx, public_key, 32UL   );
  ctx = fd_sha512_append( ctx, msg,        msg_sz );
  fd_sha512_fini( ctx, k );
  fd_curve25519_scalar_reduce( k, k );

  /* RFC step 3: check the group equation.  The RFC states
     [8][S]B = [8]R + [8][k]A' and notes it is sufficient, but not
     required, to instead check [S]B = R + [k]A'.

     Here Rcmp = [k](-A') + [S]B is computed, with B the base point.
     The point is negated rather than the scalar: this is not the same
     as [-k]A' + [S]B, because the order of A' is 8l (computing
     -k mod 8l would work). */

  fd_ed25519_point_t Rcmp[1];
  fd_ed25519_point_neg( Aprime, Aprime );
  fd_ed25519_double_scalar_mul_base( Rcmp, k, Aprime, S );

  /* Compare the computed point Rcmp against R decoded from the
     signature.  Many implementations instead compress Rcmp and compare
     the bytes against r as it appears in the signature; that
     implicitly rejects a non-canonical R, but also hides a field
     inversion needed to compress Rcmp.  This implementation compares
     the points (see the canonicality discussion earlier in this
     function). */

  if( FD_UNLIKELY( !fd_ed25519_point_eq_z1( Rcmp, R ) ) ) return FD_ED25519_ERR_MSG;
  return FD_ED25519_SUCCESS;
}

int fd_ed25519_verify_batch_single_msg( uchar const   msg[], /* msg_sz */
                                        ulong const   msg_sz,
                                        uchar const   signatures[ static 64 ], /* 64 * batch_sz */
                                        uchar const   pubkeys[ static 32 ],    /* 32 * batch_sz */
                                        fd_sha512_t * shas[ 1 ],               /* batch_sz */
                                        uchar const   batch_sz ) {
  /* Verifies batch_sz signatures over the same message msg.

     Signature j is read from signatures + 64*j, public key j from
     pubkeys + 32*j, and shas[j] is the hashing context used for
     entry j.

     Returns FD_ED25519_SUCCESS if every entry verifies.  Returns
     FD_ED25519_ERR_SIG if batch_sz is 0 or exceeds MAX.  Otherwise
     returns the error of the first failing check: all decoding and
     small order checks run (in entry order) before any group equation
     is evaluated. */

#define MAX 16
  if( FD_UNLIKELY( !batch_sz || batch_sz > MAX ) ) return FD_ED25519_ERR_SIG;

#if 0
  /* Naive */
  for( uchar i=0; i<batch_sz; i++ ) {
    int res = fd_ed25519_verify( msg, msg_sz, &signatures[ i*64 ], &pubkeys[ i*32 ], shas[0] );
    if( FD_UNLIKELY( res != FD_ED25519_SUCCESS ) ) {
      return res;
    }
  }
  return FD_ED25519_SUCCESS;
#else

  /* Per-entry state: decoded points R_j and A'_j, and the reduced
     scalar k_j stored in the 32-byte slot at offset 32*j. */

  fd_ed25519_point_t Rpt [MAX];
  fd_ed25519_point_t Apt [MAX];
  uchar              kred[MAX * 32];

  ulong const cnt = (ulong)batch_sz;

  /* Pass 1: validate scalars, decompress public keys and points R_j,
     reject small order points, and compute k_j.
     TODO: optimize, this is 20% of the total time. */

  for( ulong j=0UL; j<cnt; j++ ) {

    uchar const * sig_j      = signatures + 64UL*j;
    uchar const * r          = sig_j;
    uchar const * S          = sig_j + 32;
    uchar const * public_key = pubkeys + 32UL*j;

    fd_ed25519_point_t * A_j = Apt + j;
    fd_ed25519_point_t * R_j = Rpt + j;

    /* Scalar S_j must pass validation */
    if( FD_UNLIKELY( !fd_curve25519_scalar_validate( S ) ) ) return FD_ED25519_ERR_SIG;

    /* Decompress public key j and point r_j in a single call.  A
       result of 1 is reported as a public key error, any other
       non-zero result as a signature error. */
    int err = fd_ed25519_point_frombytes_2x( A_j, public_key,   R_j, r );
    if( FD_UNLIKELY( err ) ) {
      if( err==1 ) return FD_ED25519_ERR_PUBKEY;
      return FD_ED25519_ERR_SIG;
    }

    /* Neither point is allowed to be small order */
    if( FD_UNLIKELY( fd_ed25519_affine_is_small_order( A_j ) ) ) return FD_ED25519_ERR_PUBKEY;
    if( FD_UNLIKELY( fd_ed25519_affine_is_small_order( R_j ) ) ) return FD_ED25519_ERR_SIG;

    /* k_j = SHA-512( r_j || public_key_j || msg ), reduced into
       slot j of kred */
    uchar digest[ 64 ];
    fd_sha512_t * ctx = fd_sha512_init( shas[j] );
    ctx = fd_sha512_append( ctx, r,          32UL   );
    ctx = fd_sha512_append( ctx, public_key, 32UL   );
    ctx = fd_sha512_append( ctx, msg,        msg_sz );
    fd_sha512_fini( ctx, digest );
    fd_curve25519_scalar_reduce( kred + 32UL*j, digest );
  }

  /* Pass 2: for each entry compute [k_j](-A'_j) + [S_j]B and compare
     it against the decoded R_j. */

  fd_ed25519_point_t Rcmp[1];
  for( ulong j=0UL; j<cnt; j++ ) {
    uchar const *        S   = signatures + 64UL*j + 32;
    fd_ed25519_point_t * A_j = Apt + j;

    fd_ed25519_point_neg( A_j, A_j );
    fd_ed25519_double_scalar_mul_base( Rcmp, kred + 32UL*j, A_j, S );

    if( FD_UNLIKELY( !fd_ed25519_point_eq_z1( Rcmp, Rpt + j ) ) ) return FD_ED25519_ERR_MSG;
  }

  return FD_ED25519_SUCCESS;
#endif
#undef MAX
}

char const *
fd_ed25519_strerror( int err ) {
  /* Maps an FD_ED25519_* result code to a constant, human readable
     string.  Any code not listed here yields "unknown". */
  if( err==FD_ED25519_SUCCESS    ) return "success";
  if( err==FD_ED25519_ERR_SIG    ) return "bad signature";
  if( err==FD_ED25519_ERR_PUBKEY ) return "bad public key";
  if( err==FD_ED25519_ERR_MSG    ) return "bad message";
  return "unknown";
}
