#include "qptool2.h"
#include "hbs_defs.h"
#include "xmss_defs.h"
#include "load_store.h"

/* Shared CK_BBOOL values whose addresses are placed directly into attribute
 * templates below (bTrue for CKA_DERIVE, bFalse for CKA_TOKEN). */
static CK_BBOOL bTrue = 1;
static CK_BBOOL bFalse = 0;

/* Fallback min() when no other header provides one. Note: it is a macro and
 * evaluates the chosen argument twice, so pass only side-effect-free
 * expressions. */
#ifndef min
#define min(a, b) ((a) < (b) ? (a) : (b))
#endif
/* Defined in another translation unit; xmss_sign_all calls it as
 * note2(label, current_iteration, total_iterations) — presumably a progress
 * logger, confirm against its definition. */
extern void note2(char *t, int a, int b);

// EXAMPLE

/*
 * Generate an XMSS private key on the HSM with C_GenerateKey.
 *
 * The key object is requested as CKO_SECRET_KEY / CKK_GENERIC_SECRET with
 * CKA_DERIVE set and label "xmss_key". The XMSS-specific settings travel in
 * a fixed 5-byte mechanism parameter:
 *
 *   byte 0     RNG type (HBS_RNG_TYPE_PSEUDO)
 *   byte 1     multi-tree flag, as reported by xmss_str_to_oid
 *   byte 2     OID code, as reported by xmss_str_to_oid
 *   bytes 3-4  size of the auxiliary data, most significant byte first
 *
 * pFunctions   PKCS#11 function table
 * hSession     open session handle
 * p_keyHandle  receives the handle of the generated key
 * asToken      non-zero creates a token object, zero a session object
 *
 * Returns 0 on success, the xmss_str_to_oid result if the OID name is not
 * recognised, or the C_GenerateKey return code.
 *
 * Warning: XMSS key generation can take a LONG time. SHA2_10_256 (used here)
 * is the recommended parameter set for testing. With larger sets generation
 * can take days and the HSM connection times out. The key is STILL being
 * generated when that happens. The HSM stays unresponsive until it finishes,
 * after which the key is available for signing.
 */
int xmss_int_gen_key(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession, CK_OBJECT_HANDLE_PTR p_keyHandle, unsigned long asToken)
{
    CK_CHAR_PTR       oidName = (CK_CHAR_PTR)"XMSS-SHA2_10_256";
    CK_ULONG          isMultiTree = CK_FALSE; /* filled in by xmss_str_to_oid */
    CK_BYTE           oidCode = 0;            /* filled in by xmss_str_to_oid */

    /* Aux data size; higher == fewer keys, lower == longer times (16k max).
     * NOTE(review): 8196 may have been intended as 8192 — confirm. */
    unsigned short    auxDataSize = 8196;

    CK_BYTE           mechArgs[5];
    unsigned char    *cursor = mechArgs;
    CK_MECHANISM      genMech;

    CK_UTF8CHAR_PTR   label = (CK_UTF8CHAR_PTR)"xmss_key";
    CK_BBOOL          onToken = asToken ? 1 : 0;
    CK_OBJECT_CLASS   objClass = CKO_SECRET_KEY;
    CK_KEY_TYPE       objType = CKK_GENERIC_SECRET;
    CK_ULONG          valueLen = 32; /* value is irrelevant, but C_GenerateKey wants it */

    CK_ATTRIBUTE      attrs[] = {
        { CKA_CLASS,     &objClass, sizeof(objClass) },
        { CKA_KEY_TYPE,  &objType,  sizeof(objType) },
        { CKA_TOKEN,     &onToken,  sizeof(onToken) },
        { CKA_LABEL,     label,     (CK_ULONG)strlen((char *)label) },
        { CKA_DERIVE,    &bTrue,    sizeof(bTrue) },
        { CKA_VALUE_LEN, &valueLen, sizeof(valueLen) },
    };
    int               rc;

    /* Translate the human-readable parameter-set name into MT flag + OID. */
    rc = xmss_str_to_oid(&isMultiTree, &oidCode, oidName);
    if (rc) {
        printf("Invalid oid:  %s - not found\n", oidName);
        return rc;
    }

    /* Serialise the mechanism parameter (layout documented above). */
    store_int1(HBS_RNG_TYPE_PSEUDO, cursor++);
    store_int1(isMultiTree, cursor++);
    store_int1(oidCode, cursor++);
    *cursor++ = (unsigned char)((auxDataSize & 0xFF00) >> 8);
    *cursor++ = (unsigned char)(auxDataSize & 0x00FF);

    genMech.mechanism = HBS_MECH_XMSS_GENKEY;
    genMech.pParameter = mechArgs;
    genMech.ulParameterLen = (CK_ULONG)sizeof(mechArgs);

    entry("xmss_int_gen_key:C_GenerateKey");
    rc = pFunctions->C_GenerateKey(hSession,
                                   &genMech,
                                   attrs,
                                   sizeof(attrs) / sizeof(attrs[0]),
                                   p_keyHandle);
    exuent("xmss_int_gen_key:C_GenerateKey", rc);
    if (rc != 0x0) {
        printf("[genkey]: C_GenerateKey returned 0x%08x\n", rc);
    }

    return rc;
}



// EXAMPLE

/*
 * Obtain the public part of an XMSS key held in the HSM.
 *
 * Two representations are produced:
 *  1. C_DeriveKey with HBS_MECH_XMSS_GET_PUBKEY creates a temporary,
 *     non-token AES object from baseKey. Its own value is random, but it
 *     carries the public data. Its handle is returned in *p_pubHandle.
 *  2. The raw public data is then read back from that object's
 *     CKA_UTI_CUSTOM_DATA attribute (at most 21504 bytes). This is the form
 *     that could be handed to a trading partner. Only a short prefix of it
 *     is printed here.
 *
 * Returns 0 on success, the C_DeriveKey / C_GetAttributeValue return code on
 * failure, or CKR_ATTRIBUTE_VALUE_INVALID if the attribute length is
 * reported as CK_UNAVAILABLE_INFORMATION. If the failure happens after
 * C_DeriveKey succeeded, *p_pubHandle still holds the derived object.
 */
int xmss_int_get_pubkey(
    CK_FUNCTION_LIST_PTR pFunctions, 
    CK_SESSION_HANDLE hSession, 
    CK_OBJECT_HANDLE baseKey, 
    CK_OBJECT_HANDLE_PTR p_pubHandle)
{
    int                 err = 0;
    CK_OBJECT_CLASS     keyClass = CKO_SECRET_KEY;

    // Any symmetric key type works: the derived object is only a temporary
    // carrier for the public data.
    CK_KEY_TYPE         keyType = CKK_AES;
    CK_ULONG            byteLength = 256 / 8;
    CK_UTF8CHAR_PTR     keyLabel = (CK_UTF8CHAR_PTR)"tmp_xmss_key";
    CK_ULONG            l_keyLabel = (CK_ULONG)strlen((char *)keyLabel);
    CK_ATTRIBUTE        keyTemplate[] =
    {
            { CKA_CLASS,         &keyClass,         sizeof(keyClass) },
            { CKA_KEY_TYPE,      &keyType,          sizeof(keyType) },
            { CKA_TOKEN,         &bFalse,           sizeof(bFalse) },
            { CKA_LABEL,         keyLabel,          l_keyLabel },
            { CKA_VALUE_LEN,     &byteLength,       sizeof(byteLength) }
    };

    CK_ULONG            l_keyTemplate = sizeof(keyTemplate) / sizeof(CK_ATTRIBUTE);
    CK_MECHANISM        mechanism;

    // Buffer for the raw public data (second representation). Its length is
    // taken from the array itself so the two can never drift apart.
    CK_BYTE             p_pubData[21504]; // maximum size of the public data
    CK_ULONG            l_pubData = (CK_ULONG)sizeof(p_pubData);

    CK_ATTRIBUTE        pubDataFetch[] =
    {
        { CKA_UTI_CUSTOM_DATA, p_pubData, l_pubData }
    };

    CK_ULONG            l_pubDataFetch = sizeof(pubDataFetch) / sizeof(CK_ATTRIBUTE);

    // First representation: derive a key object that carries the public data.
    mechanism.mechanism = HBS_MECH_XMSS_GET_PUBKEY;
    mechanism.pParameter = NULL;
    mechanism.ulParameterLen = 0;

    entry("xmss_int_get_pubkey: C_DeriveKey");
    err = pFunctions->C_DeriveKey(hSession, &mechanism,
        baseKey,
        keyTemplate,
        l_keyTemplate,
        p_pubHandle);
    exuent("xmss_int_get_pubkey: C_DeriveKey", err);
    if (err != CKR_OK) {
        printf("[get_pub_key]: C_DeriveKey returned 0x%08x\n", err);
        goto cleanup;
    }

    // Second representation: read the raw public data off the derived object.
    entry("xmss_int_get_pubkey: C_GetAttributeValue");
    err = pFunctions->C_GetAttributeValue(hSession,
        *p_pubHandle,
        pubDataFetch,
        l_pubDataFetch);
    exuent("xmss_int_get_pubkey: C_GetAttributeValue", err);
    if (err != CKR_OK) {
        printf("[get_pub_key]: C_GetAttributeValue returned 0x%08x\n", err);
        goto cleanup;
    }

    l_pubData = pubDataFetch[0].ulValueLen;

    if (l_pubData == CK_UNAVAILABLE_INFORMATION)
    {
        err = CKR_ATTRIBUTE_VALUE_INVALID;
        printf("[get_pub_key]: Couldn't extract attribute.\n");
        goto cleanup;
    }

    // Show at most 32 bytes, and never more than the token actually wrote:
    // the remainder of p_pubData is uninitialized stack memory.
    cs_xprint("Verifier (pubkey) information", pubDataFetch[0].pValue,
              (l_pubData > 32) ? 32 : l_pubData);
    printf("\n");

cleanup:
    // nothing to clean up here
    return err;
}

// ----------------------------------------------------
// EXAMPLE Single Data Block

/*
 * Sign p_msg with a one-shot C_SignInit/C_Sign pair, repeated `counter` times.
 *
 * Every iteration starts a fresh signing operation with the same key and
 * message. Only the signature from the last iteration remains in p_sig.
 * (XMSS is stateful, so each signature presumably consumes a one-time key
 * index — confirm with the HSM documentation.)
 *
 * p_msg/l_msg  message to sign; more than 250 KiB is rejected with 0x111
 * p_sig        output buffer for the signature
 * ll_sig       in: capacity of p_sig; out: length of the last signature.
 *              *ll_sig == 0 is a size query: it is set to the maximum
 *              signature size (27688) and 0 is returned without signing.
 * counter      number of signatures to produce; <= 0 signs nothing and
 *              leaves *ll_sig unchanged
 *
 * Returns CKR_OK, CKR_BUFFER_TOO_SMALL if the buffer is smaller than the
 * maximum signature size, 0x111 for an oversized message, or the first
 * failing C_SignInit / C_Sign return code.
 */
int xmss_sign_all(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE         hKey,
    CK_BYTE_PTR             p_msg,
    CK_ULONG                l_msg,
    CK_BYTE_PTR                p_sig,
    CK_ULONG_PTR            ll_sig,
    int counter)
{
    int                     err = 0;
    unsigned char           params[] = { 0 };
    CK_ULONG                sig_len = (CK_ULONG)27688; // longest OID signature
    CK_ULONG                sig_capacity;
    CK_MECHANISM            mechanism;

    if (*ll_sig < sig_len) {
        if (*ll_sig == 0) {
            // size query
            *ll_sig = sig_len;
            return 0;
        }
        return CKR_BUFFER_TOO_SMALL;
    }
    if (l_msg > 250 * 1024) {
        return 0x111L; // ERROR_BUFFER_OVERFLOW;
    }

    // C_Sign replaces *ll_sig with the length it produced. Remember the real
    // buffer capacity so every iteration passes it, not the previous result.
    sig_capacity = *ll_sig;

    printf("\nLength of message:  %lu\n", (unsigned long)l_msg);
    cs_xprint("S: Message (one-shot) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    mechanism.mechanism = HBS_MECH_XMSS_SIGN;
    mechanism.pParameter = params;
    mechanism.ulParameterLen = 0;

    for (int i = 0; i < counter; i++) {
        note2("Sign loop count", i + 1, counter);
        entry("sign_all:C_SignInit");
        err = pFunctions->C_SignInit(hSession, &mechanism, hKey);
        exuent("sign_all:C_SignInit", err);
        if (err != CKR_OK) {
            printf("[sign_all]: C_SignInit returned 0x%08x\n", err);
            goto cleanup;
        }

        *ll_sig = sig_capacity;
        entry("sign_all:C_Sign");
        err = pFunctions->C_Sign(hSession, p_msg, l_msg, p_sig, ll_sig);
        exuent("sign_all:C_Sign", err);
        if (err != CKR_OK) {
            printf("[sign_all]: C_Sign returned 0x%08x\n", err);
            goto cleanup;
        }

        cs_xprint("S: Signature (one-shot) (truncated)", p_sig, (*ll_sig > 64) ? 64 : *ll_sig);
        printf("\n");
    }

cleanup:
    // nothing to clean up here
    return err;
}

/*
 * Verify a signature over p_msg with a one-shot C_VerifyInit/C_Verify pair.
 *
 * p_msg/l_msg  signed message; more than 250 KiB is rejected with 0x111
 * p_sig/l_sig  signature to check
 *
 * Returns CKR_OK when the signature verifies, 0x111 for an oversized
 * message, or the failing C_VerifyInit / C_Verify return code.
 */
int xmss_verify_all(
    CK_FUNCTION_LIST_PTR pFunctions,
        CK_SESSION_HANDLE hSession,
        CK_OBJECT_HANDLE         hKey,
        CK_BYTE_PTR             p_msg,
        CK_ULONG                l_msg,
        CK_BYTE_PTR                p_sig,
        CK_ULONG                l_sig)
{
    int                     err = 0;
    unsigned char           params[] = { 0 };
    CK_MECHANISM            mechanism;

    mechanism.mechanism = HBS_MECH_XMSS_VERIFY;
    mechanism.pParameter = params;
    mechanism.ulParameterLen = 0;

    if (l_msg > 250 * 1024) {
        return 0x111L; // ERROR_BUFFER_OVERFLOW;
    }

    cs_xprint("V: Signature (one-shot) (truncated)", p_sig, (l_sig > 64) ? 64 : l_sig);
    printf("Length of message:  %lu\n", (unsigned long)l_msg);
    cs_xprint("V: Message (one-shot) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    entry("verify_all:C_VerifyInit");
    err = pFunctions->C_VerifyInit(hSession, &mechanism, hKey);
    exuent("verify_all:C_VerifyInit", err);
    if (err != CKR_OK) {
        printf("[verify_all]: C_VerifyInit returned 0x%08x\n", err);
        goto cleanup;
    }

    entry("verify_all:C_Verify");
    err = pFunctions->C_Verify(hSession, p_msg, l_msg, p_sig, l_sig);
    exuent("verify_all:C_Verify", err);
    if (err != CKR_OK) {
        printf("[verify_all]: C_Verify returned 0x%08x\n", err);
        goto cleanup;
    }

    printf("Signature verified.\n");
    printf("\n");

cleanup:
    return err;
}


/*
 * Sign p_msg as a multi-part operation:
 * C_SignInit, then C_SignUpdate in 4096-byte chunks, then C_SignFinal.
 *
 * p_msg/l_msg  message to sign (an empty message goes straight to Final)
 * p_sig        output buffer for the signature
 * ll_sig       in: capacity of p_sig; out: signature length.
 *              *ll_sig == 0 is a size query: it is set to the maximum
 *              signature size (27688) and 0 is returned without signing.
 *
 * Returns CKR_OK, CKR_BUFFER_TOO_SMALL if the buffer is smaller than the
 * maximum signature size, or the first failing PKCS#11 return code.
 */
int xmss_sign_chunked(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE         hKey,
    CK_BYTE_PTR             p_msg,
    CK_ULONG                l_msg,
    CK_BYTE_PTR                p_sig,
    CK_ULONG_PTR            ll_sig)
{
    int                     err = 0;
    unsigned char           params[] = { 0 };
    CK_ULONG                sig_len = 27688; // longest OID signature
    CK_ULONG                sent = 0, chunk_size = 4096;
    CK_MECHANISM            mechanism;

    if (*ll_sig < sig_len) {
        if (*ll_sig == 0) {
            // size query
            *ll_sig = sig_len;
            return 0;
        }
        return CKR_BUFFER_TOO_SMALL;
    }

    printf("\nLength of message:  %lu\n", (unsigned long)l_msg);
    cs_xprint("S: Message (multipass) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    mechanism.mechanism = HBS_MECH_XMSS_SIGN;
    mechanism.pParameter = params;
    mechanism.ulParameterLen = 0;

    entry("sign_chunked:C_SignInit");
    err = pFunctions->C_SignInit(hSession, &mechanism, hKey);
    exuent("sign_chunked:C_SignInit", err);
    if (err != CKR_OK) {
        printf("[sign_chunked]: C_SignInit returned 0x%08x\n", err);
        goto cleanup;
    }

    // Feed the message in fixed-size pieces; the last piece may be shorter.
    while (sent < l_msg) {
        CK_ULONG len = min(chunk_size, l_msg - sent);
        entry("sign_chunked:C_SignUpdate");
        err = pFunctions->C_SignUpdate(hSession, p_msg + sent, len);
        exuent("sign_chunked:C_SignUpdate", err);
        if (err != CKR_OK) {
            printf("[sign_chunked]: C_SignUpdate returned 0x%08x\n", err);
            goto cleanup;
        }
        sent += len;
    }

    entry("sign_chunked:C_SignFinal");
    err = pFunctions->C_SignFinal(hSession, p_sig, ll_sig);
    exuent("sign_chunked:C_SignFinal", err);
    if (err != CKR_OK) {
        printf("[sign_chunked]: C_SignFinal returned 0x%08x\n", err);
        goto cleanup;
    }

    cs_xprint("S: Signature (multipass) (truncated)", p_sig, (*ll_sig > 64) ? 64 : *ll_sig);
    printf("\n");

cleanup:
    return err;
}

/*
 * Verify a signature over p_msg as a multi-part operation:
 * C_VerifyInit, then C_VerifyUpdate in 4096-byte chunks, then C_VerifyFinal.
 *
 * p_msg/l_msg  signed message
 * p_sig/l_sig  signature to check; it is passed both as the mechanism
 *              parameter and to C_VerifyFinal
 *
 * Returns CKR_OK when the signature verifies, otherwise the first failing
 * PKCS#11 return code.
 */
int xmss_verify_chunked(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE         hKey,
    CK_BYTE_PTR             p_msg,
    CK_ULONG                l_msg,
    CK_BYTE_PTR                p_sig,
    CK_ULONG                l_sig)
{
    int                     err = 0;
    CK_ULONG                sent = 0, chunk_size = 4096;
    CK_MECHANISM            mechanism;

    // NOTE(review): unlike the one-shot verify, the signature is also handed
    // over as the mechanism parameter here. It is presumably needed by the HSM
    // up front for multi-part XMSS verification — confirm with vendor docs.
    mechanism.mechanism = HBS_MECH_XMSS_VERIFY;
    mechanism.pParameter = p_sig;
    mechanism.ulParameterLen = l_sig;

    cs_xprint("V: Signature (multipass) (truncated)", p_sig, (l_sig > 64) ? 64 : l_sig);
    printf("Length of message:  %lu\n", (unsigned long)l_msg);
    // Dump length must be bounded by the message length, not the signature
    // length, or short messages are over-read.
    cs_xprint("V: Message (multipass) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    entry("verify_chunked:C_VerifyInit");
    err = pFunctions->C_VerifyInit(hSession, &mechanism, hKey);
    exuent("verify_chunked:C_VerifyInit", err);
    if (err != CKR_OK) {
        printf("[verify_chunked]: C_VerifyInit returned 0x%08x\n", err);
        goto cleanup;
    }

    // Feed the message in fixed-size pieces; the last piece may be shorter.
    while (sent < l_msg) {
        CK_ULONG len = min(chunk_size, l_msg - sent);
        entry("verify_chunked:C_VerifyUpdate");
        err = pFunctions->C_VerifyUpdate(hSession, p_msg + sent, len);
        exuent("verify_chunked:C_VerifyUpdate", err);
        if (err != CKR_OK) {
            printf("[verify_chunked]: C_VerifyUpdate returned 0x%08x\n", err);
            goto cleanup;
        }
        sent += len;
    }

    entry("verify_chunked:C_VerifyFinal");
    err = pFunctions->C_VerifyFinal(hSession, p_sig, l_sig);
    exuent("verify_chunked:C_VerifyFinal", err);
    if (err != CKR_OK) {
        printf("[verify_chunked]: C_VerifyFinal returned 0x%08x\n", err);
        goto cleanup;
    }

    printf("Signature verified.\n");
    printf("\n");

cleanup:
    return err;
}

/*
 * End-to-end XMSS exercise on one session:
 *   1. generate a key (token or session object per asToken)
 *   2. derive/fetch its public data
 *   3. one-shot sign (repeated `counter` times) and verify
 *   4. chunked sign and verify of a 10000-byte buffer filled with 0x42
 *
 * Key generation, public-key retrieval and signing failures abort the run.
 * Verification failures are reported and the run continues. The first
 * verification failure is still returned if nothing later fails, so it is
 * never masked by a later success.
 *
 * NOTE(review): gentest, signtest, verifytest, signverifytest, lms_keytype
 * and count are accepted for interface compatibility but not consulted; the
 * one-shot signing count is fixed at 5.
 *
 * Returns CKR_OK when every step succeeds, CKR_HOST_MEMORY if the test
 * buffer cannot be allocated, otherwise the relevant failure code.
 */
int xmss_vdx_test(
    CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long gentest,
    unsigned long signtest,
    unsigned long verifytest,
    unsigned long signverifytest,
    unsigned long lms_keytype,
    unsigned long asToken,
    unsigned long count
)
{
    int err = 0x0;
    int verify_err = CKR_OK; // first verification failure, if any
    int counter = 5;
    CK_OBJECT_HANDLE hXMSSPrivate = 0;
    CK_OBJECT_HANDLE hXMSSPublic = 0;
    CK_CHAR_PTR  p_msg = (CK_CHAR_PTR)"Frog blast the vent core!";
    CK_ULONG     l_msg = (CK_ULONG)strlen((const char *)p_msg);
    CK_BYTE_PTR  p_lmsg = NULL;
    CK_ULONG     l_lmsg = 10000;
    CK_BYTE      p_sig[27688]; // max sig size for largest oid
    CK_ULONG     l_sig = sizeof(p_sig);

    // Unused selector parameters (see header comment).
    (void)gentest;
    (void)signtest;
    (void)verifytest;
    (void)signverifytest;
    (void)lms_keytype;
    (void)count;

    printf("\n\n--> Test Entry <--\n");

    printf("--> Generate XMSS key\n");
    err = xmss_int_gen_key(pFunctions, hSession, &hXMSSPrivate, asToken);
    if (err != CKR_OK) {
        printf("Failed to generate key.\n");
        goto cleanup;
    }

    printf("--> Get XMSS public key\n");
    err = xmss_int_get_pubkey(pFunctions, hSession, hXMSSPrivate, &hXMSSPublic);
    if (err != CKR_OK) {
        printf("Failed to retrieve public key.\n");
        goto cleanup;
    }

    printf("--> Sign message (C_SignInit+C_Sign)\n");
    err = xmss_sign_all(pFunctions, hSession, hXMSSPrivate, p_msg, l_msg, p_sig, &l_sig, counter);
    if (err != CKR_OK) {
        printf("Failed to generate sig (one-shot).\n");
        goto cleanup;
    }

    printf("--> Verify message (C_VerifyInit+C_Verify) - complete key\n");
    err = xmss_verify_all(pFunctions, hSession, hXMSSPrivate, p_msg, l_msg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (one-shot) [using original private key].\n");
        verify_err = err; // keep going, but remember the failure
    }

    // reset length value and build the large test message
    l_sig = sizeof(p_sig);
    p_lmsg = malloc(l_lmsg);
    if (p_lmsg == NULL) {
        err = CKR_HOST_MEMORY;
        goto cleanup;
    }
    memset(p_lmsg, 0x42, l_lmsg);

    printf("--> Sign message (chunked)\n");
    err = xmss_sign_chunked(pFunctions, hSession, hXMSSPrivate, p_lmsg, l_lmsg, p_sig, &l_sig);
    if (err != CKR_OK) {
        printf("Failed to generate sig (multipass).\n");
        goto cleanup;
    }

    printf("--> Verify message (Chunked) - complete key\n");
    err = xmss_verify_chunked(pFunctions, hSession, hXMSSPrivate, p_lmsg, l_lmsg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (multipass) [using original private key].\n");
    }

cleanup:
    free(p_lmsg); // free(NULL) is a no-op on the early-exit paths
    if (err == CKR_OK) {
        err = verify_err;
    }
    return err;
}
