// SDD64 public/private key generator

// os
#include <windows.h>
#include <wincrypt.h>

// cstdlib
#include <stdarg.h>
#include <stdio.h>
#include <string.h>

// defines
// GETBIT(X,B): evaluates to 1 if bit B of X is set, otherwise 0. X is tested
// against a UINT64 mask; B must be in [0,63] (shifting by 64 or more is undefined).
#define GETBIT(X,B) (((X)&((UINT64)1<<(B)))?1:0)

/*************************************************************
 *
 * MISC. FUNCTIONS
 *
 *************************************************************/

// calculates where to look in lsvec for the equations affected by the term x_i*x_j
// Maps the unordered pair of variable indices {i, j} (each 0..63) to the slot
// in lsvec that holds the term x_i*x_j. Slots are laid out row by row in
// upper-triangular order: (0,0),(0,1),...,(0,63),(1,1),...,(63,63), which
// gives 64*65/2 = 2080 slots. Argument order does not matter.
INT IndexFromTerm(INT i, INT j)
{
    INT lo = (i < j) ? i : j;
    INT hi = (i < j) ? j : i;

    // rows 0..lo-1 contribute 64, 63, ..., 65-lo slots: 64*lo - lo*(lo-1)/2 in total
    INT rowStart = 64*lo - lo*(lo-1)/2;

    return rowStart + (hi - lo);
}

// ERRCLEANUP(MSG): reports MSG through ReportError (which also prints the
// last-error text) and jumps to the enclosing function's "cleanup:" label.
// Wrapped in do { } while(0) so "ERRCLEANUP(x);" is a single statement that
// is safe to use in any if/else context.
#define ERRCLEANUP(MSG) do { ReportError(MSG); goto cleanup; } while(0)

void ReportError(char * fmtstr, ...)
{
    CHAR msg_caller[256];
    CHAR msg_gle[256];
    va_list v1;
    va_start(v1,fmtstr);
    vsprintf(msg_caller,fmtstr,v1);
    FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM|FORMAT_MESSAGE_IGNORE_INSERTS,0,GetLastError(),MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),msg_gle,256,0);
    printf("ERROR: %s\nGetLastError() reports: %s",msg_caller,msg_gle);
    va_end(v1);
}

/*************************************************************
 *
 * RANDOMIZING FUNCTIONS/VARIABLES
 *
 *************************************************************/

HCRYPTPROV hCryptProv=0;

BOOL InitRandom()
{
    BOOL bRet=0;

    if(!CryptAcquireContext(&hCryptProv,0,MS_DEF_PROV,PROV_RSA_FULL,0))
        if(!CryptAcquireContext(&hCryptProv,0,MS_DEF_PROV,PROV_RSA_FULL,CRYPT_NEWKEYSET))
            ERRCLEANUP("CryptAcquireContext()");

    bRet=1;
    cleanup:
    return bRet;
}

// Fills *val with 64 random bits from CryptGenRandom, using the global
// hCryptProv (so InitRandom() must have succeeded first).
// Returns 1 on success, 0 on failure (reported through ReportError).
BOOL GetRandom(UINT64 *val)
{
    BOOL bRet=0;

    if(!CryptGenRandom(hCryptProv, (DWORD)sizeof(*val), (PBYTE)val))
        ERRCLEANUP("CryptGenRandom()");

    bRet=1;
    cleanup:
    return bRet;
}

// Releases the CryptoAPI context acquired by InitRandom(), if any, and clears
// the global handle so a repeated call does not release it twice and a later
// GetRandom() fails cleanly instead of using a released handle.
// Always returns 1.
BOOL ReleaseRandom()
{
    if(hCryptProv)
    {
        CryptReleaseContext(hCryptProv, 0);
        hCryptProv = 0;
    }

    return 1;
}

/*************************************************************
 *
 * OUTPUT FUNCTIONS
 *
 *************************************************************/

// Formats a printf-style message and writes it to hFile.
//
// hFile: handle opened with write access.
// fmt:   printf-style format; the remaining arguments are consumed according to it.
// Returns 1 when the complete formatted text was written, 0 otherwise. Output
// that would not fit in the 1024-byte buffer is rejected rather than truncated,
// because a truncated line would silently corrupt the generated header.
BOOL Printf_File(HANDLE hFile, PCHAR fmt, ...)
{
    BOOL bRet=0;
    CHAR buff[1024];
    DWORD dwRet=0;
    INT len;

    va_list v1;
    va_start(v1, fmt);

    len = vsnprintf(buff, sizeof(buff), fmt, v1);
    if(len < 0 || len >= (INT)sizeof(buff))
    {
        SetLastError(ERROR_INSUFFICIENT_BUFFER);
        ERRCLEANUP("Printf_File(): formatted output too long");
    }

    if(!WriteFile(hFile, buff, (DWORD)len, &dwRet, 0))
        ERRCLEANUP("WriteFile()");

    // a partial write would leave an incomplete line in the output file
    if(dwRet != (DWORD)len)
        ERRCLEANUP("WriteFile(): short write");

    bRet=1;
    cleanup:
    va_end(v1);
    return bRet;
}

// Writes len UINT64 values from arr to hFile as hex literals, each followed by
// ", ", with a newline after every 4 values (text intended to go between the
// braces of a C array initializer).
// Returns 1 if every write succeeded, 0 as soon as one fails.
BOOL PrintA_File(HANDLE hFile, UINT64 *arr, UINT len)
{
    BOOL bRet=0;

    // unsigned index to match len (avoids a signed/unsigned comparison)
    for(UINT i=0; i<len; ++i)
    {
        if(!Printf_File(hFile, "0x%016I64X, ", arr[i]))
            ERRCLEANUP("Printf_File()");

        if(!((i+1)%4))
        {
            if(!Printf_File(hFile, "\n"))
                ERRCLEANUP("Printf_File()");
        }
    }

    // set before the label so that error paths keep returning 0
    bRet=1;
    cleanup:
    return bRet;
}

/*************************************************************
 *
 * INVERTIBLE TRANSFORM FUNCTIONS
 *
 *************************************************************/

// Applies a random sequence of row additions (p[dst] ^= p[src], src != dst)
// to the first n 64-bit rows of p, then replays the same additions in reverse
// order on q. Each addition is its own inverse over GF(2), so if p and q start
// as the same matrix M they end as A*M and A^-1*M for one random invertible A.
//
// n: number of rows to mix. It must be in [2, 256]: src and dst must be able
//    to differ, and each step is recorded as two 8-bit row indices.
// Returns 1 on success, 0 on failure. On failure p may already be mixed while
// q is not, so callers must not use either matrix.
BOOL MixRows(UINT64 *p, UINT64 *q, UINT n)
{
    enum { MIX_STEPS = 65536 };
    BOOL bRet=0;

    // record steps so they can be played backwards on q
    WORD history[MIX_STEPS];
    INT history_i=0;

    // n < 2 would make the do/while below spin forever (or divide by zero),
    // and n > 256 does not fit the 8-bit fields packed into history[]
    if(n < 2 || n > 256)
    {
        SetLastError(ERROR_INVALID_PARAMETER);
        ERRCLEANUP("MixRows(): n must be in [2,256]");
    }

    // choose combinations of rows to add together
    for(INT i=0; i<MIX_STEPS; ++i)
    {
        UINT64 src=0,dst=0;

        do
        {
            if(!GetRandom(&src) || !GetRandom(&dst))
                ERRCLEANUP("GetRandom()");

            src %= n;
            dst %= n;
        } while(src==dst);

        p[dst] ^= p[src];
        history[history_i++] = (WORD)((dst<<8) | src);
    }

    // playing steps backwards on q
    while(--history_i >= 0)
        q[history[history_i] >> 8] ^= q[history[history_i] & 0xFF];

    bRet=1;
    cleanup:
    return bRet;
}

// Builds a random invertible 64x64 bit matrix in p and its inverse in q
// (64 UINT64 rows each). Both start with unit rows, where row r has only
// bit 63-r set. MixRows then applies random row additions to p and the same
// additions in reverse order to q.
// Returns 1 on success, 0 if MixRows fails.
BOOL RandomInputTransform(UINT64 *p, UINT64 *q)
{
    UINT64 unit = (UINT64)1 << 63;

    for(INT row = 0; row < 64; ++row, unit >>= 1)
    {
        p[row] = unit;
        q[row] = unit;
    }

    return MixRows(p, q, 64) ? 1 : 0;
}

// Builds a random invertible mixing of bits 59..0 in p and its inverse in q
// (64 UINT64 rows each).
//  - Rows 0..3 are left all-zero.
//  - Rows 4..63 start as unit rows: row r has only bit 63-r set, i.e. the
//    identity on bits 59..0.
//  - Only those 60 rows are mixed by MixRows.
// Returns 1 on success, 0 if MixRows fails.
BOOL RandomOutputTransform(UINT64 *p, UINT64 *q)
{
    INT row = 0;

    // the four rows that stay empty
    for(; row < 4; ++row)
    {
        p[row] = 0;
        q[row] = 0;
    }

    // identity rows for bits 59..0
    for(; row < 64; ++row)
    {
        UINT64 unit = (UINT64)1 << (63 - row);
        p[row] = unit;
        q[row] = unit;
    }

    return MixRows(p + 4, q + 4, 60) ? 1 : 0;
}

/*************************************************************
 *
 * MAIN
 *
 *************************************************************/

INT main(INT ac, PCHAR *av)
{
    // file handling vars
    WIN32_FIND_DATA wfd;
    HANDLE hFFFa=INVALID_HANDLE_VALUE, hFFFb=INVALID_HANDLE_VALUE;
    HANDLE hPub=INVALID_HANDLE_VALUE, hPri=INVALID_HANDLE_VALUE;

    // others
    CHAR buff[1024];
    UINT64 lsvec_temp[2080];

    // for study, args can optionally toggle security
    //
    BOOL bInputTransform=1;
    BOOL bOutputTransform=1;

    for(INT i=1; i<ac; ++i)
    {
        if(!strcmp(av[i],"-L1"))
            bInputTransform=0;
        if(!strcmp(av[i],"-L2"))
            bOutputTransform=0;
    }

    printf("doing input transform? %d\n", bInputTransform);
    printf("doing output transform? %d\n", bOutputTransform);

    // see if public_key.h or private_key.h already exist
    //
    hFFFa=FindFirstFile("public_key.h", &wfd);
    hFFFb=FindFirstFile("private_key.h", &wfd);

    if((hFFFa != INVALID_HANDLE_VALUE) || (hFFFb != INVALID_HANDLE_VALUE))
        ERRCLEANUP("public_key.h or private_key.h already exists, exiting...\n");

    hPub = CreateFile("public_key.h", GENERIC_WRITE, 0, 0, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0);
    if(hPub == INVALID_HANDLE_VALUE) 
        ERRCLEANUP("creating public_key.h");

    hPri = CreateFile("private_key.h", GENERIC_WRITE, 0, 0, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, 0);
    if(hPub == INVALID_HANDLE_VALUE) 
        ERRCLEANUP("creating private_key.h");

    // 
    //
    InitRandom();

    UINT64 lsvec[2080];

    // generate random equations
    //
    for(INT i=0; i<64; ++i)
        for(INT j=i; j<64; ++j)
        {
            UINT64 equsAffected;
            if(!GetRandom(&equsAffected))
                ERRCLEANUP("GetRandom()");

            if(i>=48 || j>=48) 
                equsAffected &= 0x0FFFE00000000000; 
            else if(i>=32 || j>=32)
                equsAffected &= 0x0FFFFFFFC0000000;
            else if(i>=16 || j>=16)
                equsAffected &= 0x0FFFFFFFFFFF8000;
            else
                equsAffected &= 0x0FFFFFFFFFFFFFFF;

            lsvec[IndexFromTerm(i,j)] = equsAffected;
        }

    Printf_File(hPri, "UINT64 lsvec[2080] = {\n");
    PrintA_File(hPri, lsvec, 2080);
    Printf_File(hPri, "};\n");
 
    // generate the input transform
    // 
    if(bInputTransform)
    {
        Printf_File(hPri, "#define DO_INPUT_TRANSFORM 1\n");

        UINT64 L1[64];
        UINT64 L1_inv[64];
        RandomInputTransform(L1, L1_inv);

        Printf_File(hPri, "UINT64 L1_inv[64] = {\n");
        PrintA_File(hPri, L1_inv, 64);
        Printf_File(hPri, "};\n");

        // apply the input transformation to the MQ system
        //
        memset(lsvec_temp, 0, 2080*sizeof(UINT64));

        // for every term x_i*x_j
        for(INT i=0; i<64; ++i)
            for(INT j=i; j<64; ++j)
            {
                // for every x_a*x_b that can result from the transforming expansion of x_i and x_j
                for(INT a=0; a<64; ++a)
                    for(INT b=0; b<64; ++b)
                    {
                        // if x_a*x_b results, it is added to every equation where x_i*x_j was present
                        if(GETBIT(L1[64-i-1],a) && GETBIT(L1[64-j-1],b))
                        {
                            lsvec_temp[IndexFromTerm(a,b)] ^= lsvec[IndexFromTerm(i,j)];
                        }
                    }
            }

        memcpy(lsvec, lsvec_temp, 2080*sizeof(UINT64));
    }
    else
        Printf_File(hPri, "#undef DO_INPUT_TRANSFORM\n");


    // generate the output transform
    //
    if(bOutputTransform)
    {
        Printf_File(hPri, "#define DO_OUTPUT_TRANSFORM 1\n");

        UINT64 L2[64];
        UINT64 L2_inv[64];
        RandomOutputTransform(L2, L2_inv);

        Printf_File(hPri, "UINT64 L2_inv[64] = {\n");
        PrintA_File(hPri, L2_inv, 64);
        Printf_File(hPri, "};\n");

        // apply it to the MQ system 
        //
        memset(lsvec_temp, 0, 2080*sizeof(UINT64));

        // for each equation y_i
        for(INT i=0; i<64; ++i)
            // for each summand in the transformation for y_i
            for(INT j=0; j<64; ++j)
            {
                // if y_j is a summand
                if(GETBIT(L2[64-i-1],j))
                {
                    // then for every term
                    for(INT k=0; k<2080; ++k)
                    {
                        // if the term is present in y_j
                        if(GETBIT(lsvec[k], j))
                        {
                            // change its presence instead to y_i
                            lsvec_temp[k] ^= (UINT64)1<<i;
                        }
                    }
                }
            }

        memcpy(lsvec, lsvec_temp, 2080*sizeof(UINT64));
    }
    else
        Printf_File(hPri, "#undef DO_OUTPUT_TRANSFORM\n");

    // final system is the public key!
    //
    Printf_File(hPub, "UINT64 lsvec_obf[2080] = {\n");
    PrintA_File(hPub, lsvec, 2080);
    Printf_File(hPub, "};\n");

    cleanup:
    ReleaseRandom();

    if(hFFFa != INVALID_HANDLE_VALUE)
        FindClose(hFFFa);
    if(hFFFb != INVALID_HANDLE_VALUE)
        FindClose(hFFFb);

    if(hPub != INVALID_HANDLE_VALUE)
        CloseHandle(hPub);
    if(hPri != INVALID_HANDLE_VALUE)
        CloseHandle(hPri);

    return 0;
}
