#include "qptool2.h"
#include "CxiKeyAttributes.h"
#include "MLKEM_KeyGen.h"
#include "MLKEM_Encap.h"
#include "MLKEM_Decap.h"
#include "MLKEM_Response.h"

#include "qsr2mux.h"
#include "template_util.h"

#include "load_store.h"


/*
 * Shared CK_BBOOL storage. The attribute templates in this file point at
 * these through &bFalse / &bTrue rather than each keeping a local copy.
 */
static CK_BBOOL bFalse = 0;
static CK_BBOOL bTrue = 1;

/*
 * Numeric selectors for the three ML-KEM parameter sets (ML-KEM-512, -768,
 * -1024). NOTE(review): presumably the values the vendor mechanism expects
 * in its `type` field — confirm against the MLKEM_* headers.
 */
#define ML_KEM_512   (1)
#define ML_KEM_768   (2)
#define ML_KEM_1024  (3)

/*
 * Smaller of two values, unless a platform header already provides min().
 * Each argument may be evaluated twice; do not pass expressions with side
 * effects.
 */
#if !defined(min)
#define min(a, b) (((a) < (b)) ? (a) : (b))
#endif

/*
 * DESTROY_IF(keha): if the object handle `keha` is valid, destroy the object
 * with C_DestroyObject; in every case reset `keha` to CK_INVALID_HANDLE.
 * The expansion site must have `pFunctions` and `hSession` in scope. The
 * C_DestroyObject status is not checked.
 *
 * The do { } while (0) wrapper makes the macro a single statement, so it is
 * safe under an unbraced if/else. Without it, the handle reset would escape
 * the `if`, and an `else` would fail to compile.
 */
#define DESTROY_IF(keha) \
    do { \
        if ((keha) != CK_INVALID_HANDLE) { \
            pFunctions->C_DestroyObject(hSession, (keha)); \
        } \
        (keha) = CK_INVALID_HANDLE; \
    } while (0)


/*
 * mlkem_genkey - generate an ML-KEM key pair with C_GenerateKeyPair and read
 * back the raw public key bytes.
 *
 *   pFunctions, hSession  PKCS#11 function list and open session.
 *   l_plabel / p_plabel   CKA_LABEL for the public key; omitted if p_plabel is NULL.
 *   l_pid / p_pid         CKA_ID for the public key; omitted if p_pid is NULL.
 *   l_slabel / p_slabel   CKA_LABEL for the private key; omitted if p_slabel is NULL.
 *   l_sid / p_sid         CKA_ID for the private key; omitted if p_sid is NULL.
 *   mlkem_keytype         stored in MLKEM_KEYGEN.type (parameter-set selector).
 *   asToken               nonzero sets CKA_TOKEN = TRUE on both keys.
 *   p_PubkeyHandle, p_PrvkeyHandle
 *                         receive the new object handles.
 *   p_raw_public_key, l_raw_public_key
 *                         receive the CKA_UTI_CUSTOM_DATA value of the new
 *                         private-key object via util_get_attribute_value.
 *                         NOTE(review): presumably a heap buffer the caller
 *                         must free() — confirm.
 *
 * Returns 0 on success, E_NO_MEM if a template cannot be allocated, or the
 * first non-zero status from mlkem_keygen_pack, C_GenerateKeyPair or
 * util_get_attribute_value. Every path releases the templates and the packed
 * mechanism parameter.
 */
static int mlkem_genkey(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long  l_plabel, unsigned char *p_plabel,
    unsigned long  l_pid, unsigned char *p_pid,
    unsigned long  l_slabel, unsigned char *p_slabel,
    unsigned long  l_sid, unsigned char *p_sid,
    unsigned long mlkem_keytype, unsigned long asToken,
    CK_OBJECT_HANDLE_PTR p_PubkeyHandle,
    CK_OBJECT_HANDLE_PTR p_PrvkeyHandle,
    CK_BYTE_PTR *p_raw_public_key,
    CK_ULONG    *l_raw_public_key
)
{
    int err = 0x0;

    CK_MECHANISM   mechanism;
    MLKEM_KEYGEN   keygen;
    unsigned int   len = 0;
    unsigned char *p_mech = NULL;   /* packed mechanism parameter; allocated by mlkem_keygen_pack */

    CK_BBOOL        token = (asToken ? 1 : 0);
    CK_KEY_TYPE     keyType = CKK_VENDOR_UTI_MLKEM_KT;
    CK_OBJECT_CLASS pubkeyClass = CKO_PUBLIC_KEY;
    CK_OBJECT_CLASS prvkeyClass = CKO_PRIVATE_KEY;

    CK_UTF8CHAR_PTR pubLabel = (CK_UTF8CHAR_PTR)p_plabel;
    CK_ULONG        l_pubLabel = (CK_ULONG)l_plabel;
    CK_UTF8CHAR_PTR pubId = (CK_UTF8CHAR_PTR)p_pid;
    CK_ULONG        l_pubId = (CK_ULONG)l_pid;
    CK_UTF8CHAR_PTR secLabel = (CK_UTF8CHAR_PTR)p_slabel;
    CK_ULONG        l_secLabel = (CK_ULONG)l_slabel;
    CK_UTF8CHAR_PTR secId = (CK_UTF8CHAR_PTR)p_sid;
    CK_ULONG        l_secId = (CK_ULONG)l_sid;

    CK_ATTRIBUTE_PTR pubKeyTemplate = NULL;
    CK_ATTRIBUTE_PTR prvKeyTemplate = NULL;
    unsigned long    pubats = 0;
    unsigned long    prvats = 0;
    unsigned long    t = 0;

    entry("qptool2:mlkem_genkey");
    memset(&mechanism, 0, sizeof(mechanism));

    /* Public key: 5 fixed attributes plus optional label and id.
       Keep this count in step with the set_attribute calls below. */
    pubats = 5;
    if (pubLabel != NULL) pubats++;
    if (pubId != NULL) pubats++;
    pubKeyTemplate = allocate_template(pubats);
    if (pubKeyTemplate == NULL) {
        err = E_NO_MEM;
        goto cleanup;
    }

    t = 0;
    set_attribute(pubKeyTemplate, t++, CKA_CLASS, &pubkeyClass, sizeof(pubkeyClass));
    set_attribute(pubKeyTemplate, t++, CKA_KEY_TYPE, &keyType, sizeof(keyType));
    set_attribute(pubKeyTemplate, t++, CKA_TOKEN, &token, sizeof(token));
    set_attribute(pubKeyTemplate, t++, CKA_DERIVE, &bTrue, sizeof(bTrue));
    set_attribute(pubKeyTemplate, t++, CKA_EXTRACTABLE, &bTrue, sizeof(bTrue));
    if (pubLabel != NULL) set_attribute(pubKeyTemplate, t++, CKA_LABEL, pubLabel, l_pubLabel);
    if (pubId != NULL)    set_attribute(pubKeyTemplate, t++, CKA_ID, pubId, l_pubId);

    /* Private key: 6 fixed attributes plus optional label and id.
       Keep this count in step with the set_attribute calls below. */
    prvats = 6;
    if (secLabel != NULL) prvats++;
    if (secId != NULL) prvats++;
    prvKeyTemplate = allocate_template(prvats);
    if (prvKeyTemplate == NULL) {
        /* goto cleanup (not return) so pubKeyTemplate is released as well */
        err = E_NO_MEM;
        goto cleanup;
    }

    t = 0;
    set_attribute(prvKeyTemplate, t++, CKA_CLASS, &prvkeyClass, sizeof(prvkeyClass));
    set_attribute(prvKeyTemplate, t++, CKA_KEY_TYPE, &keyType, sizeof(keyType));
    set_attribute(prvKeyTemplate, t++, CKA_TOKEN, &token, sizeof(token));
    set_attribute(prvKeyTemplate, t++, CKA_DERIVE, &bTrue, sizeof(bTrue));
    set_attribute(prvKeyTemplate, t++, CKA_PRIVATE, &bTrue, sizeof(bTrue));
    set_attribute(prvKeyTemplate, t++, CKA_SENSITIVE, &bTrue, sizeof(bTrue));
    if (secLabel != NULL) set_attribute(prvKeyTemplate, t++, CKA_LABEL, secLabel, l_secLabel);
    if (secId != NULL)    set_attribute(prvKeyTemplate, t++, CKA_ID, secId, l_secId);

    /* Prepare the vendor mechanism parameters. */
    memset(&keygen, 0, sizeof(keygen));
    keygen.type = mlkem_keytype;
    keygen.flags = 1; // 1 is pseudo, 0 is true

    err = mlkem_keygen_pack(&keygen, &len, &p_mech);
    if (err) goto cleanup;

    mechanism.mechanism = CKM_MECH_MLKEM_GENKEY;
    mechanism.ulParameterLen = len;
    mechanism.pParameter = p_mech;

    err = pFunctions->C_GenerateKeyPair(hSession,
        &mechanism,
        get_template(pubKeyTemplate),
        get_template_len(pubKeyTemplate),
        get_template(prvKeyTemplate),
        get_template_len(prvKeyTemplate),
        p_PubkeyHandle, p_PrvkeyHandle);
    if (err != 0x0) {
        printf("[genkey]: C_GenerateKeyPair returned 0x%08x\n", err);
        goto cleanup;
    }

    /* Handles are CK_ULONG (unsigned), so print them with %lu. */
    printf("mlkem keygen private key id %lu public key id %lu\n",
        (unsigned long)*p_PrvkeyHandle, (unsigned long)*p_PubkeyHandle);

    /* NOTE(review): the raw public key is read from the private-key object's
       vendor attribute; confirm this is where the HSM exposes it. */
    err = util_get_attribute_value(pFunctions, hSession, *p_PrvkeyHandle,
        CKA_UTI_CUSTOM_DATA, p_raw_public_key, l_raw_public_key);

cleanup:
    free(p_mech);
    if (pubKeyTemplate != NULL) free_template(pubKeyTemplate);
    if (prvKeyTemplate != NULL) free_template(prvKeyTemplate);
    exuent("qptool2:mlkem_genkey", err);
    return err;
}


/*
 * mlkem_findkey - locate objects of class `keyClass` by label and/or ID.
 *
 * Exactly one or two of p_plabel, p_pid, p_slabel and p_sid must be non-NULL.
 * Each non-NULL label is matched as CKA_LABEL and each non-NULL id as CKA_ID,
 * using the paired length. The p* and s* pairs map to the same attributes.
 * CKA_CLASS = keyClass is always added to the search template.
 *
 * On entry, *l_retcount is the capacity of the p_KeyHandle array. On return it
 * holds the number of handles written; it is set to 0 whenever an error is
 * returned.
 *
 * Returns 0 on success, E_TMPL_INVALID for an unsupported number of
 * selectors, E_NO_MEM if the template cannot be allocated, or the first
 * failing PKCS#11 status.
 */
static int mlkem_findkey(
    CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long  keyClass,
    unsigned long  l_plabel, unsigned char *p_plabel,
    unsigned long  l_pid, unsigned char *p_pid,
    unsigned long  l_slabel, unsigned char *p_slabel,
    unsigned long  l_sid, unsigned char *p_sid,
    unsigned long *l_retcount,
    CK_OBJECT_HANDLE_PTR p_KeyHandle
)
{
    int err = 0x0;
    int final_err = 0x0;
    int find_active = 0;   /* set once C_FindObjectsInit has succeeded */
    CK_OBJECT_CLASS objClass = (CK_OBJECT_CLASS)keyClass;
    unsigned long maxObjCount = *l_retcount;

    CK_ATTRIBUTE_PTR findTemplate = NULL;
    unsigned long selectors = 0;
    unsigned long t = 0;

    entry("mlkem_findkey");

    if (p_plabel != NULL) selectors++;
    if (p_pid != NULL) selectors++;
    if (p_slabel != NULL) selectors++;
    if (p_sid != NULL) selectors++;
    if ((selectors < 1) || (selectors > 2)) {
        err = E_TMPL_INVALID;
        goto cleanup;
    }

    /* One extra slot for CKA_CLASS, so keyClass actually restricts the search. */
    findTemplate = allocate_template(selectors + 1);
    if (findTemplate == NULL) {
        err = E_NO_MEM;
        goto cleanup;
    }

    set_attribute(findTemplate, t++, CKA_CLASS, &objClass, sizeof(objClass));
    if (p_plabel != NULL) set_attribute(findTemplate, t++, CKA_LABEL, p_plabel, l_plabel);
    if (p_pid != NULL)    set_attribute(findTemplate, t++, CKA_ID, p_pid, l_pid);
    if (p_slabel != NULL) set_attribute(findTemplate, t++, CKA_LABEL, p_slabel, l_slabel);
    if (p_sid != NULL)    set_attribute(findTemplate, t++, CKA_ID, p_sid, l_sid);

    err = pFunctions->C_FindObjectsInit(hSession, get_template(findTemplate), get_template_len(findTemplate));
    if (err != 0x0) goto cleanup;
    find_active = 1;

    err = pFunctions->C_FindObjects(hSession, p_KeyHandle, maxObjCount, l_retcount);

cleanup:
    /* Close a started search even if C_FindObjects failed. Otherwise the
       session keeps an active find operation, and per the PKCS#11 spec the
       next C_FindObjectsInit fails with CKR_OPERATION_ACTIVE. */
    if (find_active) {
        final_err = pFunctions->C_FindObjectsFinal(hSession);
        if (err == 0x0) err = final_err;
    }
    if (err != 0x0) *l_retcount = 0;
    if (findTemplate != NULL) free_template(findTemplate);
    exuent("mlkem_findkey", err);
    return err;
}

/*
 * DestroyObjects - find at most one object whose CKA_LABEL / CKA_ID match
 * (l_label, p_label) and (l_id, p_id), and destroy it. `keyclass` is
 * forwarded to mlkem_findkey. For CKO_PUBLIC_KEY the label/id go in the
 * public-key selector slots; for any other class they go in the secret slots.
 *
 * Returns 1 only when an object was found AND C_DestroyObject succeeded, so
 * callers may loop until 0 to remove duplicates. Returns 0 when nothing
 * matched, the search failed, or the destroy failed. The destroy-failed case
 * must return 0; otherwise a "while (found)" caller would keep finding the
 * same undeletable object forever.
 */
static int DestroyObjects(
    CK_FUNCTION_LIST_PTR  pFunctions,
    CK_SESSION_HANDLE     hSession,
    CK_ULONG  keyclass,
    CK_ULONG    l_label,
    CK_BYTE_PTR p_label,
    CK_ULONG    l_id,
    CK_BYTE_PTR p_id
)
{
    int err;
    unsigned long retcount = 1;   /* capacity: look for a single object */
    CK_OBJECT_HANDLE keyHandle = CK_INVALID_HANDLE;

    if (keyclass == CKO_PUBLIC_KEY) {
        err = mlkem_findkey(pFunctions, hSession,
            keyclass, l_label, p_label, l_id, p_id, 0, NULL, 0, NULL,
            &retcount, &keyHandle);
    } else {
        err = mlkem_findkey(pFunctions, hSession,
            keyclass, 0, NULL, 0, NULL, l_label, p_label, l_id, p_id,
            &retcount, &keyHandle);
    }

    if (err != 0x0 || retcount == 0 || keyHandle == CK_INVALID_HANDLE) {
        return 0;
    }

    if (pFunctions->C_DestroyObject(hSession, keyHandle) != 0x0) {
        return 0;
    }
    return 1;
}


/*
 * mlkem_encap - ML-KEM encapsulation through C_DeriveKey with the vendor
 * CKM_MECH_MLKEM_ENCAP mechanism.
 *
 *   keytype          parameter-set selector copied into MLKEM_ENCAP.type.
 *   prvKeyHandle     base key passed to C_DeriveKey.
 *   l_pubkey, p_pubkey
 *                    raw peer public key placed in the mechanism parameter.
 *   derivedKey       on success, receives the new session secret
 *                    (CKK_GENERIC_SECRET, 32 bytes, non-extractable, derivable).
 *   pulCyphertext, p_cyphertext
 *                    receive the derived key's CKA_UTI_CUSTOM_DATA value (the
 *                    ciphertext) via util_get_attribute_value.
 *
 * Returns 0 on success, otherwise the failing status. If the ciphertext cannot
 * be read after the key was derived, the derived key is destroyed and
 * *derivedKey is left unchanged, so no orphaned session object remains.
 */
int mlkem_encap(
    CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    CK_ULONG    keytype,
    CK_OBJECT_HANDLE prvKeyHandle,
    CK_ULONG    l_pubkey,
    CK_BYTE_PTR p_pubkey,
    CK_OBJECT_HANDLE_PTR derivedKey,
    CK_ULONG    *pulCyphertext,
    CK_BYTE_PTR *p_cyphertext
)
{
    int err = 0;
    CK_OBJECT_HANDLE fmKey = CK_INVALID_HANDLE;  // fm = ephemeral

    CK_MECHANISM mechanism;
    MLKEM_ENCAP mech;

    unsigned int l_packed = 0;
    unsigned char *p_packed = NULL;   /* allocated by mlkem_encap_pack */

    CK_ULONG keyClass = CKO_SECRET_KEY;
    CK_ULONG keyType = CKK_GENERIC_SECRET;
    CK_ULONG byteLength = 32;

    CK_ATTRIBUTE     fmktemplate[] = {
    { CKA_CLASS, &keyClass, sizeof(keyClass) },
    { CKA_KEY_TYPE, &keyType, sizeof(keyType) },
    { CKA_TOKEN, &bFalse, sizeof(bFalse) },
    { CKA_EXTRACTABLE, &bFalse, sizeof(bFalse) },
    { CKA_DERIVE, &bTrue, sizeof(bTrue) },
    { CKA_VALUE_LEN, &byteLength, sizeof(byteLength) }
    };
    /* derive the count from the array so it cannot drift from the entries */
    const unsigned long l_fmkt = sizeof(fmktemplate) / sizeof(fmktemplate[0]);

    entry("mlkem_encap");

    memset(&mech, 0, sizeof(mech));
    mech.l_publickey = l_pubkey;
    mech.p_publickey = p_pubkey;
    mech.flags = 0;
    mech.type = keytype;

    err = mlkem_encap_pack(&mech, &l_packed, &p_packed);
    if (err) goto cleanup;

    mechanism.mechanism = CKM_MECH_MLKEM_ENCAP;
    mechanism.ulParameterLen = l_packed;
    mechanism.pParameter = p_packed;

    err = pFunctions->C_DeriveKey(hSession, &mechanism, prvKeyHandle, fmktemplate, l_fmkt, &fmKey);
    if (err) {
        printf("[encap]: C_Derive returned 0x%08x\n", err);
        goto cleanup;
    }

    /* The secret is in fmKey; the ciphertext is in fmKey's UTI attribute. */
    err = util_get_attribute_value(pFunctions, hSession, fmKey, CKA_UTI_CUSTOM_DATA, p_cyphertext, pulCyphertext);
    if (err) {
        /* Do not leave an unreachable session key behind on failure. */
        pFunctions->C_DestroyObject(hSession, fmKey);
        goto cleanup;
    }

    *derivedKey = fmKey;

cleanup:
    free(p_packed);
    exuent("mlkem_encap", err);
    return err;
}

/*
 * mlkem_decap - ML-KEM decapsulation through C_DeriveKey with the vendor
 * CKM_MECH_MLKEM_DECAP mechanism.
 *
 *   keytype          parameter-set selector copied into MLKEM_DECAP.type.
 *   prvKeyHandle     base key passed to C_DeriveKey. The mechanism itself
 *                    carries no private key.
 *   l_cyphertext, p_cyphertext
 *                    ciphertext placed in the mechanism parameter.
 *   derivedKey       on success, receives the new session secret
 *                    (CKK_GENERIC_SECRET, 32 bytes, non-extractable, derivable).
 *
 * Returns 0 on success or the status from mlkem_decap_pack / C_DeriveKey.
 */
int mlkem_decap(
    CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    CK_ULONG    keytype,
    CK_OBJECT_HANDLE prvKeyHandle,
    CK_ULONG    l_cyphertext,
    CK_BYTE_PTR p_cyphertext,
    CK_OBJECT_HANDLE_PTR derivedKey
)
{
    int err = 0;
    CK_OBJECT_HANDLE fmKey = CK_INVALID_HANDLE;  // fm = ephemeral

    CK_MECHANISM mechanism;
    MLKEM_DECAP mech;

    unsigned int l_packed = 0;
    unsigned char *p_packed = NULL;   /* allocated by mlkem_decap_pack */

    CK_ULONG keyClass = CKO_SECRET_KEY;
    CK_ULONG keyType = CKK_GENERIC_SECRET;
    CK_ULONG byteLength = 32;

    CK_ATTRIBUTE     fmktemplate[] = {
    { CKA_CLASS, &keyClass, sizeof(keyClass) },
    { CKA_KEY_TYPE, &keyType, sizeof(keyType) },
    { CKA_TOKEN, &bFalse, sizeof(bFalse) },
    { CKA_EXTRACTABLE, &bFalse, sizeof(bFalse) },
    { CKA_DERIVE, &bTrue, sizeof(bTrue) },
    { CKA_VALUE_LEN, &byteLength, sizeof(byteLength) }
    };
    /* derive the count from the array so it cannot drift from the entries */
    const unsigned long l_fmkt = sizeof(fmktemplate) / sizeof(fmktemplate[0]);

    entry("mlkem_decap");

    memset(&mech, 0, sizeof(mech));
    // mech private key will be taken from the vdm inputs, not from the mech
    mech.l_cyphertext = l_cyphertext;
    mech.p_cyphertext = p_cyphertext;
    mech.flags = 0;
    mech.type = keytype;

    err = mlkem_decap_pack(&mech, &l_packed, &p_packed);
    if (err) goto cleanup;

    mechanism.mechanism = CKM_MECH_MLKEM_DECAP;
    mechanism.ulParameterLen = l_packed;
    mechanism.pParameter = p_packed;

    err = pFunctions->C_DeriveKey(hSession,
        &mechanism, prvKeyHandle, fmktemplate, l_fmkt, &fmKey);
    if (err) {
        printf("[decap]: C_Derive returned 0x%08x\n", err);
        goto cleanup;
    }

    /* At this point the shared secret is in fmKey. */
    *derivedKey = fmKey;

cleanup:
    free(p_packed);
    exuent("mlkem_decap", err);
    return err;
}

/*
 * mlkem_validate - check that two shared secrets agree by using them.
 *
 * Each secret is turned into an AES-256 session key (CKK_AES, 32 bytes) with
 * CKM_SHA256_KEY_DERIVATION. A fixed message is encrypted under the key from
 * sharedSecretA (CKM_AES_CBC_PAD, fixed IV) and decrypted under the key from
 * BsharedSecret. On a match, the message is printed.
 *
 * Returns 0 when the round trip reproduces the message, E_NO_MEM if a buffer
 * cannot be allocated, CKR_FUNCTION_FAILED if the decrypted text differs, or
 * the status of the failing PKCS#11 call. Both derived AES keys and all
 * buffers are released on every path.
 */
int mlkem_validate(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE sharedSecretA,
    CK_OBJECT_HANDLE BsharedSecret
)
{
    int err = 0;
    CK_OBJECT_HANDLE aDerivedKey = CK_INVALID_HANDLE;
    CK_OBJECT_HANDLE bDerivedKey = CK_INVALID_HANDLE;
    unsigned char message[] = "If you see this, it appears to work.\n";
    unsigned long l_message = (unsigned long)(sizeof(message) - 1);   /* exclude the NUL */
    unsigned char *cyphergram = NULL;
    unsigned long l_cyphergram = 0;
    unsigned char *plain = NULL;
    unsigned long l_plain = 0;

    /* Fixed 16-byte test IV (no NUL terminator is stored). */
    CK_BYTE iv[16] = "\x1\x2\x3\x4\xa\xb\xc\xd\x1\x2\x3\x4\xa\xb\xc\xd";

    CK_MECHANISM mechanism;
    CK_MECHANISM aesmech;

    CK_ULONG keyClass = CKO_SECRET_KEY;
    CK_ULONG keyType = CKK_AES;
    CK_ULONG byteLength = 32;

    CK_ATTRIBUTE     fmktemplate[] = {
    { CKA_CLASS,       &keyClass,   sizeof(keyClass) },
    { CKA_KEY_TYPE,    &keyType,    sizeof(keyType) },
    { CKA_TOKEN,       &bFalse,     sizeof(bFalse) },
    { CKA_EXTRACTABLE, &bFalse,     sizeof(bFalse) },
    { CKA_ENCRYPT,     &bTrue,      sizeof(bTrue) },
    { CKA_DECRYPT,     &bTrue,      sizeof(bTrue) },
    { CKA_VALUE_LEN,   &byteLength, sizeof(byteLength) }
    };
    const unsigned long l_fmkt = sizeof(fmktemplate) / sizeof(fmktemplate[0]);

    mechanism.mechanism = CKM_SHA256_KEY_DERIVATION;
    mechanism.ulParameterLen = 0;
    mechanism.pParameter = NULL;

    /* The AES-CBC-PAD mechanism parameter is the IV. */
    aesmech.mechanism = CKM_AES_CBC_PAD;
    aesmech.ulParameterLen = sizeof(iv);
    aesmech.pParameter = iv;

    /* Derive an AES key from each shared secret. */
    err = pFunctions->C_DeriveKey(hSession, &mechanism, sharedSecretA, fmktemplate, l_fmkt, &aDerivedKey);
    if (err) goto cleanup;

    err = pFunctions->C_DeriveKey(hSession, &mechanism, BsharedSecret, fmktemplate, l_fmkt, &bDerivedKey);
    if (err) goto cleanup;

    /* Encrypt with A's key: first call (NULL buffer) reports the needed size. */
    err = pFunctions->C_EncryptInit(hSession, &aesmech, aDerivedKey);
    if (err) goto cleanup;

    err = pFunctions->C_Encrypt(hSession, message, l_message, NULL, &l_cyphergram);
    if (err) goto cleanup;

    cyphergram = malloc(l_cyphergram);
    if (cyphergram == NULL) {
        err = E_NO_MEM;
        goto cleanup;
    }

    err = pFunctions->C_Encrypt(hSession, message, l_message, cyphergram, &l_cyphergram);
    if (err) goto cleanup;

    /* Decrypt with B's key, again sizing the buffer first. */
    err = pFunctions->C_DecryptInit(hSession, &aesmech, bDerivedKey);
    if (err) goto cleanup;

    err = pFunctions->C_Decrypt(hSession, cyphergram, l_cyphergram, NULL, &l_plain);
    if (err) goto cleanup;

    plain = malloc(l_plain + 1);   /* +1 for the C-string terminator */
    if (plain == NULL) {
        err = E_NO_MEM;
        goto cleanup;
    }

    err = pFunctions->C_Decrypt(hSession, cyphergram, l_cyphergram, plain, &l_plain);
    if (err) goto cleanup;

    plain[l_plain] = 0x0;
    /* Compare lengths first: memcmp over l_plain bytes must not run past message. */
    if ((l_plain == l_message) && (memcmp(message, plain, l_plain) == 0)) {
        printf("%s", plain);
    } else {
        printf("[validate]: decrypted text does not match the original message\n");
        err = CKR_FUNCTION_FAILED;
    }

cleanup:
    if (aDerivedKey != CK_INVALID_HANDLE) pFunctions->C_DestroyObject(hSession, aDerivedKey);
    if (bDerivedKey != CK_INVALID_HANDLE) pFunctions->C_DestroyObject(hSession, bDerivedKey);
    free(cyphergram);
    free(plain);
    return err;
}

int mlkem_vdx_test(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long gentest,
    unsigned long encaptest,
    unsigned long decaptest,
    unsigned long encapdecaptest,
    unsigned long mlkem_keytype,
    unsigned long asToken,
    unsigned long count
)
{
    int                    err = 0;

    CK_BBOOL             token = (asToken ? CK_TRUE : CK_FALSE);

    CK_OBJECT_CLASS     prvkeyClass = CKO_PRIVATE_KEY;
    CK_OBJECT_CLASS     pubkeyClass = CKO_PUBLIC_KEY;
    // CK_KEY_TYPE     keyType = CKK_VENDOR_UTI_KYBER_KT;

    CK_OBJECT_HANDLE   aPubKeyHandle = CK_INVALID_HANDLE;
    CK_OBJECT_HANDLE   aPrvKeyHandle = CK_INVALID_HANDLE;
    CK_OBJECT_HANDLE   bPubKeyHandle = CK_INVALID_HANDLE;
    CK_OBJECT_HANDLE   bPrvKeyHandle = CK_INVALID_HANDLE;
    CK_OBJECT_HANDLE   aDerivedKey = CK_INVALID_HANDLE;
    CK_OBJECT_HANDLE   bDerivedKey = CK_INVALID_HANDLE;

    CK_MECHANISM         mechanism;
    memset(&mechanism, 0, sizeof(mechanism));

    CK_BYTE_PTR  p_cyph = NULL;
    CK_ULONG     l_cyph = 0;

    unsigned long test_count = count;
    CK_BBOOL key_generated = 0;

    unsigned char aPubLabel[] = "mlkem_pub_a";
    unsigned char aPrvLabel[] = "mlkem_prv_a";
    unsigned char aPubId[] = { 0x30, 'a' };
    unsigned char aPrvId[] = { 0x31, 'a' };
    CK_BYTE_PTR p_alicePublicKey = NULL;
    CK_ULONG    l_alicePublicKey = 0;

    unsigned char bPubLabel[] = "mlkem_pub_b";
    unsigned char bPrvLabel[] = "mlkem_prv_b";
    unsigned char bPubId[] = { 0x30, 'b' };
    unsigned char bPrvId[] = { 0x31, 'b' };
    CK_BYTE_PTR p_bobPublicKey = NULL;
    CK_ULONG    l_bobPublicKey = 0;

    unsigned long keyfound = 1;
    printf("--> Pre-clean <--\n");
    while (keyfound) {
        keyfound = DestroyObjects(pFunctions, hSession, pubkeyClass,
            (unsigned long)strnlen((const char *)aPubLabel, sizeof(aPubLabel)), aPubLabel,
            (unsigned long)strnlen((const char *)aPubId, sizeof(aPubId)), aPubId);
    }
    keyfound = 1;
    while (keyfound) {
        keyfound = DestroyObjects(pFunctions, hSession, pubkeyClass,
            (unsigned long)strnlen((const char *)bPubLabel, sizeof(bPubLabel)), bPubLabel,
            (unsigned long)strnlen((const char *)bPubId, sizeof(bPubId)), bPubId);
    }
    keyfound = 1;
    while (keyfound) {
        keyfound = DestroyObjects(pFunctions, hSession, prvkeyClass,
            (unsigned long)strnlen((const char *)aPrvLabel, sizeof(aPrvLabel)), aPrvLabel,
            (unsigned long)strnlen((const char *)aPrvId, sizeof(aPrvId)), aPrvId);
    }
    keyfound = 1;
    while (keyfound) {
        keyfound = DestroyObjects(pFunctions, hSession, prvkeyClass,
            (unsigned long)strnlen((const char *)bPrvLabel, sizeof(bPrvLabel)), bPrvLabel,
            (unsigned long)strnlen((const char *)bPrvId, sizeof(bPrvId)), bPrvId);
    }

    switch (mlkem_keytype) {
    case 2: mlkem_keytype = 1; break;
    case 3: mlkem_keytype = 2; break;
    case 5: mlkem_keytype = 3; break;
    }

    while (test_count-- > 0) {
        // first time through generate a key
        // or if gentest, always generate a new key
        if (gentest || !key_generated) {
            CK_ULONG retcount = 1;
            DESTROY_IF(aPubKeyHandle); DESTROY_IF(aPrvKeyHandle); DESTROY_IF(bPubKeyHandle); DESTROY_IF(bPrvKeyHandle);

            err = mlkem_genkey(pFunctions, hSession,
                (unsigned long)strnlen((const char *)aPubLabel, sizeof(aPubLabel)), aPubLabel, 
                (unsigned long)strnlen((const char *)aPubId, sizeof(aPubId)), aPubId,
                (unsigned long)strnlen((const char *)aPrvLabel, sizeof(aPrvLabel)), aPrvLabel, 
                (unsigned long)strnlen((const char *)aPrvId, sizeof(aPrvId)), aPrvId,
                mlkem_keytype, token, &aPubKeyHandle, &aPrvKeyHandle, &p_alicePublicKey, &l_alicePublicKey
            );
            if (err) goto cleanup;

            err = mlkem_genkey(pFunctions, hSession,
                (unsigned long)strnlen((const char *)bPubLabel, sizeof(bPubLabel)), bPubLabel, 
                (unsigned long)strnlen((const char *)bPubId, sizeof(bPubId)), bPubId,
                (unsigned long)strnlen((const char *)bPrvLabel, sizeof(bPrvLabel)), bPrvLabel, 
                (unsigned long)strnlen((const char *)bPrvId, sizeof(bPrvId)), bPrvId,
                mlkem_keytype, token, &bPubKeyHandle, &bPrvKeyHandle, &p_bobPublicKey, &l_bobPublicKey
            );
            if (err) goto cleanup;

            if (aPubKeyHandle == CK_INVALID_HANDLE) goto cleanup;
            if (aPrvKeyHandle == CK_INVALID_HANDLE) goto cleanup;
            if (bPubKeyHandle == CK_INVALID_HANDLE) goto cleanup;
            if (bPrvKeyHandle == CK_INVALID_HANDLE) goto cleanup;

            aPubKeyHandle = CK_INVALID_HANDLE;
            bPubKeyHandle = CK_INVALID_HANDLE;

            err = mlkem_findkey(pFunctions, hSession,
                pubkeyClass,
                (unsigned long)strnlen((const char *)aPubLabel, sizeof(aPubLabel)), aPubLabel,
                (unsigned long)strnlen((const char *)aPubId, sizeof(aPubId)), aPubId,
                0, NULL, 0, NULL,
                &retcount, &aPubKeyHandle);
            if (err) goto cleanup;

            retcount = 1;
            err = mlkem_findkey(pFunctions, hSession,
                pubkeyClass,
                (unsigned long)strnlen((const char *)bPubLabel, sizeof(bPubLabel)), bPubLabel,
                (unsigned long)strnlen((const char *)bPubId, sizeof(bPubId)), bPubId,
                0, NULL, 0, NULL,
                &retcount, &bPubKeyHandle);
            if (err) goto cleanup;
        }

        key_generated = 1;

        if (encaptest || encapdecaptest) {
            // allocates
            if ((p_cyph == NULL) || encaptest || encapdecaptest) {
                if (p_cyph == NULL) free(p_cyph);
                p_cyph = NULL;

                err = mlkem_encap(pFunctions, hSession,
                    mlkem_keytype, aPrvKeyHandle, l_bobPublicKey, p_bobPublicKey, &aDerivedKey, &l_cyph, &p_cyph);
                if (err) goto cleanup;
            }
        }

        if (decaptest || encapdecaptest) {
            err = mlkem_decap(pFunctions, hSession,
                mlkem_keytype, bPrvKeyHandle, l_cyph, p_cyph, &bDerivedKey);
            if (err) goto cleanup;
        }

#ifdef _DEBUG
        {
            err = mlkem_validate(pFunctions, hSession, aDerivedKey, bDerivedKey);
        }
#endif

    }

cleanup:
    entry("mlkem_vdx_test:Cleanup");
    DESTROY_IF(aPubKeyHandle);
    DESTROY_IF(aPrvKeyHandle);
    DESTROY_IF(bPubKeyHandle);
    DESTROY_IF(bPrvKeyHandle);
    if (p_alicePublicKey != NULL) free(p_alicePublicKey);
    if (p_bobPublicKey != NULL) free(p_bobPublicKey);

    exuent("mlkem_vdx_test:Cleanup", err);

    return err;
}


