#include "qptool2.h"
#include "hbs_defs.h"
#include "load_store.h"

#include <stdlib.h>

// Addressable boolean values: PKCS#11 attribute templates in this file take a
// pointer to their value, so CK_BBOOL attributes point at these.
static CK_BBOOL bTrue = 1;
static CK_BBOOL bFalse = 0;

// Random number generator selector; lms_int_gen_key writes one of these as the
// first byte of the key generation mechanism parameters.
#define HBS_RNG_TYPE_PSEUDO                0
#define HBS_RNG_TYPE_REAL                1

// Fallback min() for toolchains that do not already provide one.
// Function-like macro: each argument may be evaluated more than once, so do
// not pass expressions with side effects.
#ifndef min
#define min(a, b) ((a) < (b) ? (a) : (b))
#endif
// Defined elsewhere. lms_sign_all calls it with a label, the 1-based iteration
// number and the total iteration count.
extern void note2(char *t, int a, int b);

// EXAMPLE

/**
 * Generate an LMS key with C_GenerateKey and the CKM_HBS_LMS_GENKEY mechanism.
 *
 * The object is requested as CKO_SECRET_KEY / CKK_GENERIC_SECRET, labelled
 * "lms_key", with CKA_DERIVE set to true.
 *
 * @param pFunctions   PKCS#11 function list providing C_GenerateKey.
 * @param hSession     Session in which the key is generated.
 * @param p_keyHandle  Receives the handle of the generated key.
 * @param asToken      Non-zero requests CKA_TOKEN = true, zero CKA_TOKEN = false.
 * @return The C_GenerateKey return code; 0 means success.
 */
int lms_int_gen_key(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE_PTR p_keyHandle, unsigned long asToken)
{
    CK_UTF8CHAR_PTR     label = (CK_UTF8CHAR_PTR)"lms_key";
    CK_ULONG            labelLen = (CK_ULONG)strlen((char *)label);
    CK_BBOOL            isToken = (asToken != 0) ? 1 : 0;
    CK_OBJECT_CLASS     objClass = CKO_SECRET_KEY;
    CK_KEY_TYPE         objType = CKK_GENERIC_SECRET;
    // The value itself is not important; the length is only supplied because
    // the generate call requires it (per the original author's note).
    CK_ULONG            valueLen = 32;
    // Size of the auxiliary data. Per the original author's note this is the
    // documented maximum; higher == fewer keys, lower == longer times.
    unsigned short      auxSize = 10916;
    // One level, so 1 + 1 + (2 * 1) + 2 = 6 bytes.
    CK_BYTE             genParams[6];
    CK_MECHANISM        genMech;
    int                 rv;

    CK_ATTRIBUTE attrs[] =
    {
        { CKA_CLASS,     &objClass, sizeof(objClass) },
        { CKA_KEY_TYPE,  &objType,  sizeof(objType) },
        { CKA_TOKEN,     &isToken,  sizeof(isToken) },
        { CKA_LABEL,     label,     labelLen },
        { CKA_DERIVE,    &bTrue,    sizeof(bTrue) },
        { CKA_VALUE_LEN, &valueLen, sizeof(valueLen) },
    };

    // Mechanism parameter layout, one byte per field unless noted.
    // (The original author notes that store_int1/store_int2 would produce the
    // same bytes more concisely.)
    genParams[0] = HBS_RNG_TYPE_PSEUDO;         // type of random number generator
    genParams[1] = 1;                           // number of LMS/OTS levels
    genParams[2] = LMS_SHA256_N32_H10;          // LMS type
    genParams[3] = LMOTS_SHA256_N32_W4;         // OTS type
    genParams[4] = ((auxSize & 0xFF00) >> 8);   // optional aux data size, high byte
    genParams[5] = (auxSize & 0x00FF);          // optional aux data size, low byte

    genMech.mechanism = CKM_HBS_LMS_GENKEY;
    genMech.pParameter = genParams;
    genMech.ulParameterLen = sizeof(genParams);

    entry("lms_int_gen_key:C_GenerateKey");
    rv = pFunctions->C_GenerateKey(hSession,
        &genMech,
        attrs,
        sizeof(attrs) / sizeof(attrs[0]),
        p_keyHandle);
    exuent("lms_int_gen_key:C_GenerateKey", rv);

    if (rv != 0x0) {
        printf("[genkey]: C_GenerateKey returned 0x%08x\n", rv);
    }
    return rv;
}

// EXAMPLE

/**
 * Derive a public-key object from an LMS key with C_DeriveKey and the
 * CKM_HBS_LMS_GET_PUBKEY mechanism (no mechanism parameters).
 *
 * The derived object is requested as a session (CKA_TOKEN = false) secret key
 * of type CKK_AES, 32 bytes long, labelled "tmp_lms_key". Per the original
 * author's notes the key type is an arbitrary symmetric type used only
 * temporarily: the value is random and the object's public data carries the
 * public key.
 *
 * @param pFunctions   PKCS#11 function list providing C_DeriveKey.
 * @param hSession     Session in which the derivation runs.
 * @param baseKey      Handle of the LMS key to derive from.
 * @param p_pubHandle  Receives the handle of the derived object.
 * @return The C_DeriveKey return code; 0 means success.
 */
int lms_int_get_pubkey(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE baseKey,
    CK_OBJECT_HANDLE_PTR p_pubHandle)
{
    CK_OBJECT_CLASS     objClass = CKO_SECRET_KEY;
    CK_KEY_TYPE         objType = CKK_AES;      // any symmetric key type, only temporary
    CK_ULONG            valueLen = 256 / 8;
    CK_UTF8CHAR_PTR     label = (CK_UTF8CHAR_PTR)"tmp_lms_key";
    CK_ULONG            labelLen = (CK_ULONG)strlen((char *)label);
    CK_MECHANISM        pubMech;
    int                 rv;

    CK_ATTRIBUTE attrs[] =
    {
        { CKA_CLASS,     &objClass, sizeof(objClass) },
        { CKA_KEY_TYPE,  &objType,  sizeof(objType) },
        { CKA_TOKEN,     &bFalse,   sizeof(bFalse) },
        { CKA_LABEL,     label,     labelLen },
        { CKA_VALUE_LEN, &valueLen, sizeof(valueLen) }
    };
    CK_ULONG            attrCount = sizeof(attrs) / sizeof(attrs[0]);

    // "Retrieve the public key" can mean two things:
    //   1. get the public key so that we can verify inside the HSM, or
    //   2. get the public key so that it can be handed to a trading partner.
    // The difference is in how the public key is used afterwards. This
    // function implements the first concept only.

    /* Reference material for the second concept (not compiled):

    CK_OBJECT_CLASS     pubclass = CKO_PUBLIC_KEY;
    CK_ULONG            keytype = CKO_VENDOR_DEFINED;

    // Get the public data, which is the information needed to validate the signature internally
    CK_BYTE             p_pubData[21504]; // Maximum size of data
    CK_ULONG             l_pubData = 21504;
    CK_ATTRIBUTE         pubDataFetch[] =
    {
        { CKA_UTI_CUSTOM_DATA, p_pubData, l_pubData }
    };
    CK_ULONG             l_pubDataFetch = sizeof(pubDataFetch) / sizeof(CK_ATTRIBUTE);

    CK_ATTRIBUTE         pubkeyTemplate[] =
    {
        { CKA_CLASS,        &pubclass, sizeof(pubclass) },
        { CKA_KEY_TYPE,        &keytype, sizeof(keytype) },
        { CKA_TOKEN,        &bFalse, sizeof(bFalse) },
        { CKA_VALUE,        p_pubData, 0},
        { CKA_VERIFY,        &bTrue, sizeof(bTrue) },
    };
    CK_ULONG             l_pubkeyTemplate = sizeof(pubkeyTemplate) / sizeof(CK_ATTRIBUTE);
    CK_ULONG ckavalue = 3; // offset into pubkeyTemplate
    */

    pubMech.mechanism = CKM_HBS_LMS_GET_PUBKEY;
    pubMech.pParameter = NULL;
    pubMech.ulParameterLen = 0;

    // Create the new key object.
    entry("lms_int_get_pubkey: C_DeriveKey");
    rv = pFunctions->C_DeriveKey(hSession, &pubMech,
        baseKey,
        attrs,
        attrCount,
        p_pubHandle);
    exuent("lms_int_get_pubkey: C_DeriveKey", rv);

    if (rv != 0x0) {
        printf("[get_pub_key]: C_DeriveKey returned 0x%08x\n", rv);
    }
    return rv;
}

// ----------------------------------------------------
// EXAMPLE Single Data Block

/**
 * Sign a message in one shot (C_SignInit + C_Sign) with CKM_HBS_LMS_SIGN,
 * repeating the operation `counter` times into the same output buffer.
 *
 * @param pFunctions  PKCS#11 function list providing C_SignInit / C_Sign.
 * @param hSession    Session in which the signing runs.
 * @param hKey        Handle of the signing key.
 * @param p_msg       Message to sign.
 * @param l_msg       Message length in bytes; more than 250 KiB is rejected.
 * @param p_sig       Buffer receiving the signature.
 * @param ll_sig      In: capacity of p_sig. Out: length written by the last
 *                    C_Sign. If 0 on entry, it is set to the required capacity
 *                    (4096) and the function returns 0 without signing.
 * @param counter     Number of sign operations to perform; values <= 0 perform
 *                    none and return 0.
 * @return 0 on success, CKR_BUFFER_TOO_SMALL if the non-zero capacity is below
 *         4096, 0x111 if the message is too long, otherwise the failing
 *         PKCS#11 call's return code.
 */
int lms_sign_all(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE         hKey,
    CK_BYTE_PTR             p_msg,
    CK_ULONG                l_msg,
    CK_BYTE_PTR                p_sig,
    CK_ULONG_PTR            ll_sig,
    int counter)
{
    int                     err = 0;
    unsigned char             params[] = { 0 };
    CK_ULONG                 sig_len = (CK_ULONG)4096;
    CK_ULONG                 sig_capacity;
    CK_MECHANISM             mechanism;

    if (*ll_sig < sig_len) {
        if (*ll_sig == 0) {
            // Size query: report the capacity this function requires.
            *ll_sig = sig_len;
            return 0;
        }
        return CKR_BUFFER_TOO_SMALL;
    }
    if (l_msg > 250 * 1024) {
        return 0x111L; // ERROR_BUFFER_OVERFLOW;
    }

    // Remember the caller's buffer capacity: C_Sign overwrites *ll_sig with
    // the produced signature length on every call.
    sig_capacity = *ll_sig;

    // l_msg is unsigned, so print it with %lu rather than %ld.
    printf("\nLength of message:  %lu\n", (unsigned long)l_msg);
    cs_xprint("S: Message (one-shot) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    mechanism.mechanism = CKM_HBS_LMS_SIGN;
    // The parameter pointer is set but the length is 0, i.e. no parameters.
    mechanism.pParameter = params;
    mechanism.ulParameterLen = 0;

    for (int i = 0; i < counter; i++) {
        note2("Sign loop count", i + 1, counter);
        entry("sign_all:C_SignInit");
        err = pFunctions->C_SignInit(hSession, &mechanism, hKey);
        exuent("sign_all:C_SignInit", err);
        if (err != CKR_OK) {
            printf("[sign_all]: C_SignInit returned 0x%08x\n", err);
            goto cleanup;
        }

        // Present the full buffer capacity to every C_Sign call, not the
        // length left behind by the previous iteration.
        *ll_sig = sig_capacity;

        entry("sign_all:C_Sign");
        err = pFunctions->C_Sign(hSession, p_msg, l_msg, p_sig, ll_sig);
        exuent("sign_all:C_Sign", err);
        if (err != CKR_OK) {
            printf("[sign_all]: C_Sign returned 0x%08x\n", err);
            goto cleanup;
        }

        cs_xprint("S: Signature (one-shot) (truncated)", p_sig, (*ll_sig > 64) ? 64 : *ll_sig);
        printf("\n");
    }

cleanup:
    // nothing to clean up here
    return err;
}

/**
 * Verify a signature in one shot (C_VerifyInit + C_Verify) with
 * CKM_HBS_LMS_VERIFY.
 *
 * @param pFunctions  PKCS#11 function list providing C_VerifyInit / C_Verify.
 * @param hSession    Session in which the verification runs.
 * @param hKey        Handle of the key used for verification.
 * @param p_msg       Message that was signed.
 * @param l_msg       Message length in bytes; more than 250 KiB is rejected.
 * @param p_sig       Signature to verify.
 * @param l_sig       Signature length in bytes.
 * @return 0 when the signature verifies, 0x111 if the message is too long,
 *         otherwise the failing PKCS#11 call's return code.
 */
int lms_verify_all(
    CK_FUNCTION_LIST_PTR pFunctions,
        CK_SESSION_HANDLE hSession,
        CK_OBJECT_HANDLE         hKey,
        CK_BYTE_PTR             p_msg,
        CK_ULONG                l_msg,
        CK_BYTE_PTR                p_sig,
        CK_ULONG                l_sig)
{
    int                     err = 0;
    unsigned char             params[] = { 0 };

    CK_MECHANISM             mechanism;
    mechanism.mechanism = CKM_HBS_LMS_VERIFY;
    // The parameter pointer is set but the length is 0, i.e. no parameters.
    mechanism.pParameter = params;
    mechanism.ulParameterLen = 0;

    if (l_msg > 250 * 1024) {
        return 0x111L; // ERROR_BUFFER_OVERFLOW;
    }

    cs_xprint("V: Signature (one-shot) (truncated)", p_sig, (l_sig > 64) ? 64 : l_sig);
    // l_msg is unsigned, so print it with %lu rather than %ld.
    printf("Length of message:  %lu\n", (unsigned long)l_msg);
    cs_xprint("V: Message (one-shot) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    entry("verify_all:C_VerifyInit");
    err = pFunctions->C_VerifyInit(hSession, &mechanism, hKey);
    exuent("verify_all:C_VerifyInit", err);
    if (err != CKR_OK) {
        // Tag the message with this function, not the signing one.
        printf("[verify_all]: C_VerifyInit returned 0x%08x\n", err);
        goto cleanup;
    }

    entry("verify_all:C_Verify");
    err = pFunctions->C_Verify(hSession, p_msg, l_msg, p_sig, l_sig);
    exuent("verify_all:C_Verify", err);
    if (err != CKR_OK) {
        printf("[verify_all]: C_Verify returned 0x%08x\n", err);
        goto cleanup;
    }

    printf("Signature verified.\n");
    printf("\n");

cleanup:
    return err;
}


/**
 * Sign a message in multiple parts (C_SignInit, C_SignUpdate per chunk,
 * C_SignFinal) with CKM_HBS_LMS_SIGN. The message is fed in chunks of at
 * most 4096 bytes.
 *
 * @param pFunctions  PKCS#11 function list providing the C_Sign* calls.
 * @param hSession    Session in which the signing runs.
 * @param hKey        Handle of the signing key.
 * @param p_msg       Message to sign.
 * @param l_msg       Message length in bytes.
 * @param p_sig       Buffer receiving the signature.
 * @param ll_sig      In: capacity of p_sig. Out: length written by
 *                    C_SignFinal. If 0 on entry, it is set to the required
 *                    capacity (4096) and the function returns 0 without
 *                    signing.
 * @return 0 on success, CKR_BUFFER_TOO_SMALL if the non-zero capacity is below
 *         4096, otherwise the failing PKCS#11 call's return code.
 */
int lms_sign_chunked(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE         hKey,
    CK_BYTE_PTR             p_msg,
    CK_ULONG                l_msg,
    CK_BYTE_PTR                p_sig,
    CK_ULONG_PTR            ll_sig)
{
    int                     err = 0;
    unsigned char             params[] = { 0 };
    unsigned long            sig_len = 4096;
    CK_ULONG sent = 0, chunk_size = 4096;

    CK_MECHANISM             mechanism;

    if (*ll_sig < sig_len) {
        if (*ll_sig == 0) {
            // Size query: report the capacity this function requires.
            *ll_sig = sig_len;
            return 0;
        }
        return CKR_BUFFER_TOO_SMALL;
    }

    // l_msg is unsigned, so print it with %lu rather than %ld.
    printf("\nLength of message:  %lu\n", (unsigned long)l_msg);
    // "S:" marks the signing path, consistent with the other signing dumps.
    cs_xprint("S: Message (multipass) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    mechanism.mechanism = CKM_HBS_LMS_SIGN;
    // The parameter pointer is set but the length is 0, i.e. no parameters.
    mechanism.pParameter = params;
    mechanism.ulParameterLen = 0;

    entry("sign_chunked:C_SignInit");
    err = pFunctions->C_SignInit(hSession, &mechanism, hKey);
    exuent("sign_chunked:C_SignInit", err);
    if (err != CKR_OK) {
        printf("[sign_chunked]: C_SignInit returned 0x%08x\n", err);
        goto cleanup;
    }

    // Feed the message chunk by chunk; the last chunk may be shorter.
    while (sent < l_msg) {
        CK_ULONG len = min(chunk_size, l_msg - sent);
        entry("sign_chunked:C_SignUpdate");
        err = pFunctions->C_SignUpdate(hSession, p_msg + sent, len);
        exuent("sign_chunked:C_SignUpdate", err);
        if (err != CKR_OK) {
            printf("[sign_chunked]: C_SignUpdate returned 0x%08x\n", err);
            goto cleanup;
        }
        sent += len;
    }

    entry("sign_chunked:C_SignFinal");
    err = pFunctions->C_SignFinal(hSession, p_sig, ll_sig);
    exuent("sign_chunked:C_SignFinal", err);
    if (err != CKR_OK) {
        printf("[sign_chunked]: C_SignFinal returned 0x%08x\n", err);
        goto cleanup;
    }

    cs_xprint("S: Signature (multipass) (truncated)", p_sig, (*ll_sig > 64) ? 64 : *ll_sig);
    printf("\n");

cleanup:
    return err;
}

/**
 * Verify a signature in multiple parts (C_VerifyInit, C_VerifyUpdate per
 * chunk, C_VerifyFinal) with CKM_HBS_LMS_VERIFY. The message is fed in chunks
 * of at most 4096 bytes.
 *
 * @param pFunctions  PKCS#11 function list providing the C_Verify* calls.
 * @param hSession    Session in which the verification runs.
 * @param hKey        Handle of the key used for verification.
 * @param p_msg       Message that was signed.
 * @param l_msg       Message length in bytes.
 * @param p_sig       Signature to verify; also passed as the mechanism
 *                    parameter to C_VerifyInit.
 * @param l_sig       Signature length in bytes.
 * @return 0 when the signature verifies, otherwise the failing PKCS#11 call's
 *         return code.
 */
int lms_verify_chunked(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE         hKey,
    CK_BYTE_PTR             p_msg,
    CK_ULONG                l_msg,
    CK_BYTE_PTR                p_sig,
    CK_ULONG                l_sig)
{
    int                     err = 0;
    CK_ULONG sent = 0, chunk_size = 4096;

    // SIGNATURE VERIFICATION
    CK_MECHANISM             mechanism;

    mechanism.mechanism = CKM_HBS_LMS_VERIFY;
    // NOTE(review): the original author was unsure whether the mechanism
    // parameter should be NULL/0 (OASIS doc) or p_sig/l_sig (HBS doc).
    // The p_sig/l_sig variant is kept unchanged here.
    mechanism.pParameter = p_sig;
    mechanism.ulParameterLen = l_sig;

    cs_xprint("V: Signature (multipass) (truncated)", p_sig, (l_sig > 64) ? 64 : l_sig);
    // l_msg is unsigned, so print it with %lu rather than %ld.
    printf("Length of message:  %lu\n", (unsigned long)l_msg);
    // Dump at most 16 bytes of the message and never more than l_msg bytes.
    // Using l_sig as the fallback length would over-read a short message.
    cs_xprint("V: Message (multipass) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    entry("verify_chunked:C_VerifyInit");
    err = pFunctions->C_VerifyInit(hSession, &mechanism, hKey);
    exuent("verify_chunked:C_VerifyInit", err);
    if (err != CKR_OK) {
        // Tag the messages with this function, not the signing one.
        printf("[verify_chunked]: C_VerifyInit returned 0x%08x\n", err);
        goto cleanup;
    }

    // Feed the message chunk by chunk; the last chunk may be shorter.
    while (sent < l_msg) {
        CK_ULONG len = min(chunk_size, l_msg - sent);
        entry("verify_chunked:C_VerifyUpdate");
        err = pFunctions->C_VerifyUpdate(hSession, p_msg + sent, len);
        exuent("verify_chunked:C_VerifyUpdate", err);
        if (err != CKR_OK) {
            printf("[verify_chunked]: C_VerifyUpdate returned 0x%08x\n", err);
            goto cleanup;
        }
        sent += len;
    }

    entry("verify_chunked:C_VerifyFinal");
    err = pFunctions->C_VerifyFinal(hSession, p_sig, l_sig);
    exuent("verify_chunked:C_VerifyFinal", err);
    if (err != CKR_OK) {
        printf("[verify_chunked]: C_VerifyFinal returned 0x%08x\n", err);
        goto cleanup;
    }

    printf("Signature verified.\n");
    printf("\n");

cleanup:
    return err;
}

/*
 * Refactored to take into account the splitting out of "LMS" artifact vs "HSS" artifact
 * manipulation.
 *
 * You can use test_case_lms for LMS,
 * You can use test_case_hss with an LMS tree (single height tree), but the result artifacts
 * are _HSS_ artifacts, not _LMS_ artifacts.  See the various specifications (RFC, SP800-208, ...)
 * to understand the differences.
 *
 * lms_vdx_test runs the whole LMS example sequence:
 *   1. generate an LMS key and derive its public-key object,
 *   2. sign a short message one-shot (5 times) and verify it with both handles,
 *   3. sign a 10000-byte message in chunks and verify it with both handles.
 *
 * Key generation, public key retrieval and signing failures abort the
 * sequence. A verification failure is reported and the sequence continues;
 * the first such failure is returned if nothing later fails.
 *
 * pFunctions / hSession: PKCS#11 function list and session used for all calls.
 * asToken: forwarded to lms_int_gen_key (non-zero requests a token object).
 * gentest, signtest, verifytest, signverifytest, lms_keytype, count: accepted
 * but not used by this function.
 *
 * Returns 0 when every step succeeded, CKR_HOST_MEMORY if the large message
 * buffer cannot be allocated, otherwise the return code of a failed step.
 */
int lms_vdx_test(
    CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long gentest,
    unsigned long signtest,
    unsigned long verifytest,
    unsigned long signverifytest,
    unsigned long lms_keytype,
    unsigned long asToken,
    unsigned long count
)
{
    int err = 0x0;
    int verify_err = 0x0;   // first verification failure seen, 0 if none
    int counter = 5;        // number of one-shot sign iterations
    CK_OBJECT_HANDLE hLMSPrivate;
    CK_OBJECT_HANDLE hLMSPublic;
    CK_BYTE_PTR         p_msg = (CK_BYTE_PTR)"Frog blast the vent core!";
    CK_ULONG            l_msg = (CK_ULONG)strlen((char *)p_msg);
    CK_BYTE_PTR         p_lmsg = NULL;
    CK_ULONG            l_lmsg = 10000;
    CK_BYTE             p_sig[4096];
    CK_ULONG            l_sig = sizeof(p_sig);

    printf("\n\n--> Test Entry <--\n");

    printf("--> Generate LMS key\n");
    err = lms_int_gen_key(pFunctions, hSession, &hLMSPrivate, asToken);
    if (err != CKR_OK) {
        printf("Failed to generate key.\n");
        goto cleanup;
    }

    printf("--> Get LMS public key\n");
    err = lms_int_get_pubkey(pFunctions, hSession, hLMSPrivate, &hLMSPublic);
    if (err != CKR_OK) {
        printf("Failed to retrieve public key.\n");
        goto cleanup;
    }

    printf("--> Sign message (C_SignInit+C_Sign)\n");
    err = lms_sign_all(pFunctions, hSession, hLMSPrivate, p_msg, l_msg, p_sig, &l_sig, counter);
    if (err != CKR_OK) {
        printf("Failed to generate sig (one-shot).\n");
        goto cleanup;
    }

    printf("--> Verify message (C_SignInit+C_Sign) - complete key\n");
    err = lms_verify_all(pFunctions, hSession, hLMSPrivate, p_msg, l_msg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (one-shot) [using original private key].\n");
        // Keep going, but remember the failure: err is overwritten below.
        if (verify_err == 0x0) {
            verify_err = err;
        }
    }

    printf("--> Verify message (C_SignInit+C_Sign) - retrieved public key\n");
    err = lms_verify_all(pFunctions, hSession, hLMSPublic, p_msg, l_msg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (one-shot) [using retrieved public key].\n");
        if (verify_err == 0x0) {
            verify_err = err;
        }
    }

    // reset length values
    l_sig = sizeof(p_sig);
    p_lmsg = malloc(l_lmsg);
    if (p_lmsg == NULL) {
        err = CKR_HOST_MEMORY;
        goto cleanup;
    }
    memset(p_lmsg, 0x42, l_lmsg);

    printf("--> Sign message (chunked)\n");
    err = lms_sign_chunked(pFunctions, hSession, hLMSPrivate, p_lmsg, l_lmsg, p_sig, &l_sig);
    if (err != CKR_OK) {
        printf("Failed to generate sig (multiposs).\n");
        goto cleanup;
    }

    printf("--> Verify message (Chunked) - complete key\n");
    err = lms_verify_chunked(pFunctions, hSession, hLMSPrivate, p_lmsg, l_lmsg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (multiposs) [using original private key].\n");
        if (verify_err == 0x0) {
            verify_err = err;
        }
    }

    printf("--> Verify message (Chunked) - retrieved public key\n");
    err = lms_verify_chunked(pFunctions, hSession, hLMSPublic, p_lmsg, l_lmsg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (multiposs) [using retrieved public key].\n");
    }

cleanup:
    // p_lmsg is NULL on the early-exit paths; free(NULL) is a no-op.
    free(p_lmsg);

    // Do not report success when an earlier verification failed.
    if (err == 0x0) {
        err = verify_err;
    }
    return err;
}
