#include "qptool2.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "hbs_defs.h"
#include "load_store.h"

/*
 * File-scope CK_BBOOL values. Attribute templates in this file take their
 * addresses as pValue (e.g. CKA_DERIVE -> &bTrue, CKA_TOKEN -> &bFalse), so they
 * must have static storage duration rather than being temporaries.
 */
static CK_BBOOL bTrue = 1;
static CK_BBOOL bFalse = 0;

#define HBS_MAX_HSS_LEVELS                8

/*
 * Number of bytes in the HSS key-generation mechanism parameter for a tree
 * with `levels` levels.
 *
 * Layout:
 *   RNG type (1) + level count (1) + 2 bytes per level + aux-data size (2).
 *
 * Returns 0 when `levels` is outside 1..HBS_MAX_HSS_LEVELS, so callers can
 * treat 0 as "invalid configuration".
 */
unsigned long hss_mechanism_len_by_levels(unsigned long levels) {
    const unsigned long rng_type_bytes    = 1;
    const unsigned long level_count_bytes = 1;
    const unsigned long bytes_per_level   = 2;
    const unsigned long aux_size_bytes    = 2;
    unsigned long total = 0;

    if (levels >= 1 && levels <= HBS_MAX_HSS_LEVELS) {
        total = rng_type_bytes + level_count_bytes
              + bytes_per_level * levels
              + aux_size_bytes;
    }
    return total;
}

#ifndef min
/* Smaller of two values. Each argument may be evaluated twice, so never pass
 * expressions with side effects (e.g. i++). */
#define min(lhs, rhs) (((lhs) < (rhs)) ? (lhs) : (rhs))
#endif


// EXAMPLE

/*
 * EXAMPLE: generate an HSS private key object with C_GenerateKey and the
 * vendor mechanism HBS_MECH_HSS_GENKEY.
 *
 * The tree shape comes from the per-level lms_mech / lmots_mech tables. It is
 * serialised into the mechanism parameter as
 *   RNG type (1 byte) | level count L (1 byte) |
 *   L x { LMS type (1 byte), LM-OTS type (1 byte) } |
 *   aux-data size (2 bytes, most significant byte first)
 * which is hss_mechanism_len_by_levels(L) bytes in total.
 *
 * pFunctions  - PKCS#11 function list
 * hSession    - session in which the key object is created
 * p_keyHandle - receives the handle of the new key object
 * asToken     - non-zero sets CKA_TOKEN to true, zero sets it to false
 *
 * Returns CKR_OK (0) on success, CKR_ARGUMENTS_BAD if the level tables do not
 * describe a valid parameter, otherwise the C_GenerateKey result (as int).
 */
int hss_int_gen_key(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession, CK_OBJECT_HANDLE_PTR p_keyHandle, unsigned long asToken)
{
    int                 err = 0;

    CK_UTF8CHAR_PTR     keyLabel = (CK_UTF8CHAR_PTR)"lms_key";
    CK_ULONG            l_keyLabel = (CK_ULONG)strlen((char *)keyLabel);
    CK_BBOOL            token = (asToken ? 1 : 0);

    CK_OBJECT_CLASS     keyClass = CKO_SECRET_KEY;
    CK_KEY_TYPE         keyType = CKK_GENERIC_SECRET;
    CK_ULONG            byteLength = 32; // not important, required by C_Generate VDM

    // One LMS / LM-OTS pair per HSS level (ordering presumably top level first - confirm with the HBS doc).
    CK_CHAR             lms_mech[] = { LMS_SHA256_N32_H5, LMS_SHA256_N32_H5, LMS_SHA256_N32_H10 };
    CK_CHAR             lmots_mech[] = { LMOTS_SHA256_N32_W4, LMOTS_SHA256_N32_W4, LMOTS_SHA256_N32_W4 };
    // Derive the level count from the table so the two can never disagree.
    CK_ULONG            levels = sizeof(lms_mech) / sizeof(lms_mech[0]);
    CK_ULONG            paramLen = hss_mechanism_len_by_levels(levels);
    CK_ULONG            i = 0;
    CK_MECHANISM        mechanism;
    CK_BYTE             mechParams[20]; // max mechanism parameter length for 8 levels 1+1+(2*8)+2 = 20
    CK_BYTE            *pmp = mechParams;

    unsigned short      auxsize = 10916; // per the doc

    CK_ATTRIBUTE        keyTemplate[] =
    {
            { CKA_CLASS,        &keyClass,      sizeof(keyClass) },
            { CKA_KEY_TYPE,     &keyType,       sizeof(keyType) },
            { CKA_TOKEN,        &token,         sizeof(token) },
            { CKA_LABEL,        keyLabel,       l_keyLabel },
            { CKA_DERIVE,       &bTrue,         sizeof(bTrue) },
            { CKA_VALUE_LEN,    &byteLength,    sizeof(byteLength) },
    };

    // Both tables must have one entry per level, and the encoded parameter must fit the buffer.
    if (sizeof(lmots_mech) != sizeof(lms_mech) || paramLen == 0 || paramLen > sizeof(mechParams)) {
        printf("[genkey]: invalid HSS level configuration\n");
        return CKR_ARGUMENTS_BAD;
    }

    // Prepare mechanism parameters
    // 1 byte - type of random number generator
    *pmp++ = (CK_BYTE)HBS_RNG_TYPE_PSEUDO;
    // 1 byte - number of HSS levels; must match the number of pairs written below
    // (previously hard-coded to 1 while three pairs followed).
    *pmp++ = (CK_BYTE)levels;

    for (i = 0; i < levels; i++) {
        // 1 byte - LMS type
        *pmp++ = lms_mech[i];
        // 1 byte - OTS type
        *pmp++ = lmots_mech[i];
    }

    // 2 bytes - optional, size of auxiliary data (1 - 16k), most significant byte first
    *pmp++ = (CK_BYTE)((auxsize & 0xFF00) >> 8);
    *pmp++ = (CK_BYTE)(auxsize & 0x00FF);

    mechanism.mechanism = HBS_MECH_HSS_GENKEY;
    mechanism.pParameter = mechParams;
    mechanism.ulParameterLen = paramLen;

    entry("hss_int_gen_key:C_GenerateKey");
    err = pFunctions->C_GenerateKey(hSession,
        &mechanism,
        keyTemplate,
        sizeof(keyTemplate) / sizeof(CK_ATTRIBUTE),
        p_keyHandle);
    exuent("hss_int_gen_key:C_GenerateKey", err);
    if (err != 0x0) {
        printf("[genkey]: C_GenerateKey returned 0x%08x\n", (unsigned int)err);
    }

    return err;
}

// EXAMPLE

/*
 * EXAMPLE: obtain the public key belonging to HSS key `baseKey` by calling
 * C_DeriveKey with the vendor mechanism HBS_MECH_HSS_GET_PUBKEY (no mechanism
 * parameter).
 *
 * The derived object is a non-token 32-byte CKK_AES secret key labelled
 * "tmp_lms_key". The key type is only a placeholder symmetric type; per the
 * original notes, its value is random and the public key is carried alongside
 * it by the token.
 *
 * On success *p_pubHandle receives the new object's handle. Returns the
 * C_DeriveKey result (0 on success).
 */
int hss_int_get_pubkey(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE baseKey,
    CK_OBJECT_HANDLE_PTR p_pubHandle)
{
    CK_OBJECT_CLASS  derivedClass = CKO_SECRET_KEY;
    CK_KEY_TYPE      derivedType = CKK_AES;   // any symmetric type will do; object is temporary
    CK_ULONG         derivedBytes = 256 / 8;
    CK_UTF8CHAR_PTR  derivedLabel = (CK_UTF8CHAR_PTR)"tmp_lms_key";
    CK_ULONG         derivedLabelLen = (CK_ULONG)strlen((char *)derivedLabel);
    int              rv;

    CK_ATTRIBUTE tmpl[] = {
        { CKA_CLASS,     &derivedClass, sizeof(derivedClass) },
        { CKA_KEY_TYPE,  &derivedType,  sizeof(derivedType) },
        { CKA_TOKEN,     &bFalse,       sizeof(bFalse) },
        { CKA_LABEL,     derivedLabel,  derivedLabelLen },
        { CKA_VALUE_LEN, &derivedBytes, sizeof(derivedBytes) }
    };
    CK_ULONG tmplCount = sizeof(tmpl) / sizeof(tmpl[0]);

    CK_MECHANISM mech = {
        .mechanism      = HBS_MECH_HSS_GET_PUBKEY,
        .pParameter     = NULL,
        .ulParameterLen = 0
    };

    entry("hss_int_get_pubkey: C_DeriveKey");
    rv = pFunctions->C_DeriveKey(hSession, &mech, baseKey, tmpl, tmplCount, p_pubHandle);
    exuent("hss_int_get_pubkey: C_DeriveKey", rv);

    if (rv != 0x0) {
        printf("[get_pub_key]: C_DeriveKey returned 0x%08x\n", rv);
    }
    return rv;
}

// ----------------------------------------------------
// EXAMPLE Single Data Block

/*
 * EXAMPLE (single data block): sign p_msg in one shot with
 * C_SignInit + C_Sign using HSS private key hKey.
 *
 * ll_sig is in/out: on entry the capacity of p_sig, on success the signature
 * length. If *ll_sig is 0, it is set to the 4096-byte capacity this example
 * requires and 0 is returned without signing. Any other capacity below 4096
 * yields CKR_BUFFER_TOO_SMALL.
 *
 * Returns CKR_OK on success, 0x111 (ERROR_BUFFER_OVERFLOW) if l_msg exceeds
 * 250 KiB, CKR_ARGUMENTS_BAD for NULL pointers, otherwise the PKCS#11 error.
 */
int hss_sign_all(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE    hKey,
    CK_BYTE_PTR         p_msg,
    CK_ULONG            l_msg,
    CK_BYTE_PTR         p_sig,
    CK_ULONG_PTR        ll_sig)
{
    int                 err = 0;
    unsigned char       params[] = { 0 };
    CK_ULONG            sig_len = (CK_ULONG)4096;
    CK_MECHANISM        mechanism;

    if (ll_sig == NULL) {
        return CKR_ARGUMENTS_BAD;
    }
    if (*ll_sig < sig_len) {
        if (*ll_sig == 0) {
            *ll_sig = sig_len;
            return 0;
        }
        return CKR_BUFFER_TOO_SMALL;
    }
    if (l_msg > 250 * 1024) {
        return 0x111L; // ERROR_BUFFER_OVERFLOW;
    }
    // A NULL p_sig would turn C_Sign into a length query and then be dumped below.
    if (p_sig == NULL || (p_msg == NULL && l_msg != 0)) {
        return CKR_ARGUMENTS_BAD;
    }

    printf("\nLength of message:  %lu\n", (unsigned long)l_msg);
    cs_xprint("S: Message (one-shot) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    mechanism.mechanism = HBS_MECH_HSS_SIGN;
    mechanism.pParameter = params;
    mechanism.ulParameterLen = 0;

    entry("sign_all:C_SignInit");
    err = pFunctions->C_SignInit(hSession, &mechanism, hKey);
    exuent("sign_all:C_SignInit", err);
    if (err != CKR_OK) {
        printf("[sign_all]: C_SignInit returned 0x%08x\n", (unsigned int)err);
        goto cleanup;
    }

    entry("sign_all:C_Sign");
    err = pFunctions->C_Sign(hSession, p_msg, l_msg, p_sig, ll_sig);
    exuent("sign_all:C_Sign", err);
    if (err != CKR_OK) {
        printf("[sign_all]: C_Sign returned 0x%08x\n", (unsigned int)err);
        goto cleanup;
    }

    cs_xprint("S: Signature (one-shot) (truncated)", p_sig, (*ll_sig > 64) ? 64 : *ll_sig);
    printf("\n");

cleanup:
    // nothing to clean up here
    return err;
}

/*
 * Verify signature p_sig (l_sig bytes) over p_msg (l_msg bytes) in one shot
 * with C_VerifyInit + C_Verify, using key hKey. hKey may be the private key
 * object or the derived public key.
 *
 * Returns CKR_OK when the signature verifies, 0x111 (ERROR_BUFFER_OVERFLOW)
 * if l_msg exceeds 250 KiB, CKR_ARGUMENTS_BAD for a NULL buffer with a
 * non-zero length, otherwise the PKCS#11 error.
 */
int hss_verify_all(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE    hKey,
    CK_BYTE_PTR         p_msg,
    CK_ULONG            l_msg,
    CK_BYTE_PTR         p_sig,
    CK_ULONG            l_sig)
{
    int                 err = 0;
    unsigned char       params[] = { 0 };
    CK_MECHANISM        mechanism;

    mechanism.mechanism = HBS_MECH_HSS_VERIFY;
    mechanism.pParameter = params;
    mechanism.ulParameterLen = 0;

    if (l_msg > 250 * 1024) {
        return 0x111L; // ERROR_BUFFER_OVERFLOW;
    }
    // Both buffers are dumped below, so reject NULL pointers with non-zero lengths.
    if ((p_msg == NULL && l_msg != 0) || (p_sig == NULL && l_sig != 0)) {
        return CKR_ARGUMENTS_BAD;
    }

    cs_xprint("V: Signature (one-shot) (truncated)", p_sig, (l_sig > 64) ? 64 : l_sig);
    printf("Length of message:  %lu\n", (unsigned long)l_msg);
    cs_xprint("V: Message (one-shot) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    entry("verify_all:C_VerifyInit");
    err = pFunctions->C_VerifyInit(hSession, &mechanism, hKey);
    exuent("verify_all:C_VerifyInit", err);
    if (err != CKR_OK) {
        printf("[verify_all]: C_VerifyInit returned 0x%08x\n", (unsigned int)err);
        goto cleanup;
    }

    entry("verify_all:C_Verify");
    err = pFunctions->C_Verify(hSession, p_msg, l_msg, p_sig, l_sig);
    exuent("verify_all:C_Verify", err);
    if (err != CKR_OK) {
        printf("[verify_all]: C_Verify returned 0x%08x\n", (unsigned int)err);
        goto cleanup;
    }

    printf("Signature verified.\n");
    printf("\n");

cleanup:
    return err;
}

/*
 * Sign p_msg with HSS private key hKey using the multi-part API:
 * C_SignInit, then C_SignUpdate in chunks of at most 4096 bytes, then
 * C_SignFinal.
 *
 * ll_sig is in/out: on entry the capacity of p_sig, on success the signature
 * length. If *ll_sig is 0, it is set to the 4096-byte capacity this example
 * requires and 0 is returned without signing. Any other capacity below 4096
 * yields CKR_BUFFER_TOO_SMALL. Unlike hss_sign_all there is no message-size
 * cap, since the message is streamed.
 *
 * Returns CKR_OK on success, CKR_ARGUMENTS_BAD for NULL pointers, otherwise
 * the PKCS#11 error from the failing call.
 */
int hss_sign_chunked(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE    hKey,
    CK_BYTE_PTR         p_msg,
    CK_ULONG            l_msg,
    CK_BYTE_PTR         p_sig,
    CK_ULONG_PTR        ll_sig)
{
    int                 err = 0;
    unsigned char       params[] = { 0 };
    CK_ULONG            sig_len = 4096;
    CK_ULONG            sent = 0, chunk_size = 4096;
    CK_MECHANISM        mechanism;

    if (ll_sig == NULL) {
        return CKR_ARGUMENTS_BAD;
    }
    if (*ll_sig < sig_len) {
        if (*ll_sig == 0) {
            *ll_sig = sig_len;
            return 0;
        }
        return CKR_BUFFER_TOO_SMALL;
    }
    // A NULL p_sig would turn C_SignFinal into a length query and then be dumped below.
    if (p_sig == NULL || (p_msg == NULL && l_msg != 0)) {
        return CKR_ARGUMENTS_BAD;
    }

    printf("\nLength of message:  %lu\n", (unsigned long)l_msg);
    // "S:" marks the signing side, matching hss_sign_all (was mislabelled "V:").
    cs_xprint("S: Message (multipass) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    mechanism.mechanism = HBS_MECH_HSS_SIGN;
    mechanism.pParameter = params;
    mechanism.ulParameterLen = 0;

    entry("sign_chunked:C_SignInit");
    err = pFunctions->C_SignInit(hSession, &mechanism, hKey);
    exuent("sign_chunked:C_SignInit", err);
    if (err != CKR_OK) {
        printf("[sign_chunked]: C_SignInit returned 0x%08x\n", (unsigned int)err);
        goto cleanup;
    }

    while (sent < l_msg) {
        CK_ULONG len = min(chunk_size, l_msg - sent);
        entry("sign_chunked:C_SignUpdate");
        err = pFunctions->C_SignUpdate(hSession, p_msg + sent, len);
        exuent("sign_chunked:C_SignUpdate", err);
        if (err != CKR_OK) {
            printf("[sign_chunked]: C_SignUpdate returned 0x%08x\n", (unsigned int)err);
            goto cleanup;
        }
        sent += len;
    }

    entry("sign_chunked:C_SignFinal");
    err = pFunctions->C_SignFinal(hSession, p_sig, ll_sig);
    exuent("sign_chunked:C_SignFinal", err);
    if (err != CKR_OK) {
        printf("[sign_chunked]: C_SignFinal returned 0x%08x\n", (unsigned int)err);
        goto cleanup;
    }

    cs_xprint("S: Signature (multipass) (truncated)", p_sig, (*ll_sig > 64) ? 64 : *ll_sig);
    printf("\n");

cleanup:
    return err;
}

/*
 * Verify signature p_sig (l_sig bytes) over p_msg (l_msg bytes) with key hKey
 * using the multi-part API. The sequence is C_VerifyInit, then C_VerifyUpdate
 * in chunks of at most 4096 bytes, then C_VerifyFinal.
 *
 * Returns CKR_OK when the signature verifies, CKR_ARGUMENTS_BAD for a NULL
 * buffer with a non-zero length, otherwise the PKCS#11 error.
 */
int hss_verify_chunked(
    CK_FUNCTION_LIST_PTR pFunctions,
    CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE    hKey,
    CK_BYTE_PTR         p_msg,
    CK_ULONG            l_msg,
    CK_BYTE_PTR         p_sig,
    CK_ULONG            l_sig)
{
    int                 err = 0;
    CK_ULONG            sent = 0, chunk_size = 4096;
    CK_MECHANISM        mechanism;

    // Both buffers are dumped below, so reject NULL pointers with non-zero lengths.
    if ((p_msg == NULL && l_msg != 0) || (p_sig == NULL && l_sig != 0)) {
        return CKR_ARGUMENTS_BAD;
    }

    // SIGNATURE VERIFICATION
    // The HBS (vendor) doc passes the signature as the VerifyInit mechanism
    // parameter; the OASIS doc uses NULL/0. Presumably the token needs the
    // signature up front because LMS message hashing starts with the randomizer
    // stored in the signature (RFC 8554). Confirm against the HBS doc.
    mechanism.mechanism = HBS_MECH_HSS_VERIFY;
    mechanism.pParameter = p_sig;
    mechanism.ulParameterLen = l_sig;

    cs_xprint("V: Signature (multipass) (truncated)", p_sig, (l_sig > 64) ? 64 : l_sig);
    printf("Length of message:  %lu\n", (unsigned long)l_msg);
    // Bound the dump by l_msg: using l_sig here read past the end of short messages.
    cs_xprint("V: Message (multipass) (truncated)", p_msg, (l_msg > 16) ? 16 : l_msg);
    printf("\n");

    entry("verify_chunked:C_VerifyInit");
    err = pFunctions->C_VerifyInit(hSession, &mechanism, hKey);
    exuent("verify_chunked:C_VerifyInit", err);
    if (err != CKR_OK) {
        printf("[verify_chunked]: C_VerifyInit returned 0x%08x\n", (unsigned int)err);
        goto cleanup;
    }

    while (sent < l_msg) {
        CK_ULONG len = min(chunk_size, l_msg - sent);
        entry("verify_chunked:C_VerifyUpdate");
        err = pFunctions->C_VerifyUpdate(hSession, p_msg + sent, len);
        exuent("verify_chunked:C_VerifyUpdate", err);
        if (err != CKR_OK) {
            printf("[verify_chunked]: C_VerifyUpdate returned 0x%08x\n", (unsigned int)err);
            goto cleanup;
        }
        sent += len;
    }

    entry("verify_chunked:C_VerifyFinal");
    err = pFunctions->C_VerifyFinal(hSession, p_sig, l_sig);
    exuent("verify_chunked:C_VerifyFinal", err);
    if (err != CKR_OK) {
        printf("[verify_chunked]: C_VerifyFinal returned 0x%08x\n", (unsigned int)err);
        goto cleanup;
    }

    printf("Signature verified.\n");
    printf("\n");

cleanup:
    return err;
}

/*
 * End-to-end HSS example:
 *   1. generate a key and derive its public key;
 *   2. sign a short message one-shot, then verify it with both the private key
 *      object and the derived public key;
 *   3. sign a 10000-byte message in chunks and verify it the same way.
 *
 * asToken is forwarded to hss_int_gen_key. gentest, signtest, verifytest,
 * signverifytest, lms_keytype and count are accepted for interface
 * compatibility but are not used by this example.
 *
 * Key generation, public-key retrieval and signing failures abort the run.
 * Verification failures are reported and the remaining steps still run. The
 * first such failure is remembered so later successful steps cannot mask it.
 *
 * Returns CKR_OK only if every step succeeded, otherwise a non-zero error code.
 */
int hss_vdx_test(
    CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long gentest,
    unsigned long signtest,
    unsigned long verifytest,
    unsigned long signverifytest,
    unsigned long lms_keytype,
    unsigned long asToken,
    unsigned long count
)
{
    int err = 0x0;
    int verify_err = CKR_OK;             // first verification failure, if any
    CK_OBJECT_HANDLE hLMSPrivate = 0;    // 0 == CK_INVALID_HANDLE
    CK_OBJECT_HANDLE hLMSPublic = 0;
    CK_BYTE_PTR         p_msg = (unsigned char *)"Frog blast the vent core!";
    CK_ULONG            l_msg = (CK_ULONG)strlen((char *)p_msg);
    CK_BYTE_PTR         p_lmsg = NULL;
    CK_ULONG            l_lmsg = 10000;
    CK_BYTE             p_sig[4096];
    CK_ULONG            l_sig = sizeof(p_sig);

    (void)gentest;
    (void)signtest;
    (void)verifytest;
    (void)signverifytest;
    (void)lms_keytype;
    (void)count;

    printf("\n\n--> Test Entry <--\n");

    printf("--> Generate LMS key\n");
    err = hss_int_gen_key(pFunctions, hSession, &hLMSPrivate, asToken);
    if (err != CKR_OK) {
        printf("Failed to generate key.\n");
        goto cleanup;
    }

    printf("--> Get LMS public key\n");
    err = hss_int_get_pubkey(pFunctions, hSession, hLMSPrivate, &hLMSPublic);
    if (err != CKR_OK) {
        printf("Failed to retrieve public key.\n");
        goto cleanup;
    }

    printf("--> Sign message (C_SignInit+C_Sign)\n");
    err = hss_sign_all(pFunctions, hSession, hLMSPrivate, p_msg, l_msg, p_sig, &l_sig);
    if (err != CKR_OK) {
        printf("Failed to generate sig (one-shot).\n");
        goto cleanup;
    }

    printf("--> Verify message (C_VerifyInit+C_Verify) - complete key\n");
    err = hss_verify_all(pFunctions, hSession, hLMSPrivate, p_msg, l_msg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (one-shot) [using original private key].\n");
        if (verify_err == CKR_OK) {
            verify_err = err;
        }
    }

    printf("--> Verify message (C_VerifyInit+C_Verify) - retrieved public key\n");
    err = hss_verify_all(pFunctions, hSession, hLMSPublic, p_msg, l_msg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (one-shot) [using retrieved public key].\n");
        if (verify_err == CKR_OK) {
            verify_err = err;
        }
    }

    // reset length values
    l_sig = sizeof(p_sig);
    p_lmsg = malloc(l_lmsg);
    if (p_lmsg == NULL) {
        err = CKR_HOST_MEMORY;
        goto cleanup;
    }
    memset(p_lmsg, 0x42, l_lmsg);

    printf("--> Sign message (chunked)\n");
    err = hss_sign_chunked(pFunctions, hSession, hLMSPrivate, p_lmsg, l_lmsg, p_sig, &l_sig);
    if (err != CKR_OK) {
        printf("Failed to generate sig (multipass).\n");
        goto cleanup;
    }

    printf("--> Verify message (Chunked) - complete key\n");
    err = hss_verify_chunked(pFunctions, hSession, hLMSPrivate, p_lmsg, l_lmsg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (multipass) [using original private key].\n");
        if (verify_err == CKR_OK) {
            verify_err = err;
        }
    }

    printf("--> Verify message (Chunked) - retrieved public key\n");
    err = hss_verify_chunked(pFunctions, hSession, hLMSPublic, p_lmsg, l_lmsg, p_sig, l_sig);
    if (err != CKR_OK) {
        printf("Failed to verify sig (multipass) [using retrieved public key].\n");
        if (verify_err == CKR_OK) {
            verify_err = err;
        }
    }

cleanup:
    free(p_lmsg);   // free(NULL) is a no-op
    // Do not let a later successful step hide an earlier verification failure.
    if (err == CKR_OK) {
        err = verify_err;
    }
    return err;
}
