#include "qptool2.h"
#include "CxiKeyAttributes.h"
#include "MLDSA_KeyGen.h"
#include "MLDSA_Sign.h"
#include "MLDSA_Verify.h"
#include "MLXXX_KeyWrap.h"
#include "qsr2mux.h"
#include "template_util.h"

#include "load_store.h"

/*
 * CK_BBOOL values whose addresses are handed to set_attribute() when
 * building attribute templates below. They are file-scope statics so the
 * pointers stay valid for as long as any template refers to them
 * (presumably set_attribute stores the pointer rather than copying the
 * value — confirm in template_util).
 */
static CK_BBOOL bTrue = 1;
static CK_BBOOL bFalse = 0;

/*
 * When defined, the private-key export/import helpers are compiled in and
 * mldsa_genkey creates private keys with CKA_EXTRACTABLE = TRUE.
 */
#define EXPORTABLE_PRIVATE_KEY 

/* Output length, in bytes, requested from the SHAKE256 digest helper. */
#define SHAKE256_DIGEST_LEN 64

/*
 * ML-DSA parameter-set identifiers used in the vendor mechanism
 * parameters; mldsa_vdx_test maps its 2/3/5 input onto these values.
 */
#define ML_DSA_44 1
#define ML_DSA_65 2
#define ML_DSA_87 3

/* NOTE(review): not used in the code visible here; evaluates an argument twice. */
#ifndef min
#define min(a, b) ((a) < (b) ? (a) : (b))
#endif

#ifdef EXPORTABLE_PRIVATE_KEY 
/*
 * Import (unwrap) an ML-DSA private key blob produced by
 * mldsa_export_privatekey, using the vendor CKM_MECH_MLDSA_UWRAP_AESKWP
 * mechanism.
 *
 * pFunctions, hSession  PKCS#11 function list and open session.
 * kekHandle             Key-encryption key used to unwrap the blob.
 * flags                 Copied into the MLXXX_KEYWRAP mechanism parameter
 *                       (selects what the HSM stores for the new key).
 * keytype               ML-DSA parameter-set identifier.
 * p_export, l_export    The wrapped key blob.
 * l_slabel, p_slabel    Optional CKA_LABEL for the new key (NULL = omit).
 * l_sid, p_sid          Optional CKA_ID for the new key (NULL = omit).
 * asToken               Non-zero to create a token object.
 * keyUnwrapped          Receives the handle of the new private key.
 *
 * Returns 0 on success, E_NO_MEM if the template cannot be allocated,
 * otherwise the error from mlxxx_keywrap_pack() or C_UnwrapKey().
 */
int mldsa_import_privatekey(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE kekHandle, 
	unsigned int flags, unsigned int keytype,
	unsigned char *p_export, unsigned int l_export,
    unsigned long  l_slabel, unsigned char *p_slabel,
    unsigned long  l_sid, unsigned char *p_sid,
    unsigned long asToken,
    CK_OBJECT_HANDLE_PTR keyUnwrapped
)
{
    int err = 0;

    CK_BBOOL             token = (asToken ? bTrue : bFalse);
    unsigned long t = 0;
    MLXXX_KEYWRAP keywrap;
    unsigned int len = 0;
    unsigned char *p_mech = NULL;
    CK_MECHANISM         mechanism = { CKM_MECH_MLDSA_UWRAP_AESKWP, 0, 0 };

    CK_OBJECT_CLASS     prvkeyClass = CKO_PRIVATE_KEY;
    CK_KEY_TYPE         keyType = CKK_VENDOR_UTI_MLDSA_KT;
    CK_UTF8CHAR_PTR     secLabel = (CK_UTF8CHAR_PTR)p_slabel;
    CK_ULONG            l_secLabel = (CK_ULONG)l_slabel;
    CK_UTF8CHAR_PTR     secId = (CK_UTF8CHAR_PTR)p_sid;
    CK_ULONG            l_secId = (CK_ULONG)l_sid;

    CK_ATTRIBUTE_PTR     prvKeyTemplate = NULL;
    /* 7 fixed attributes below, plus the optional label and id. */
    unsigned long prvats = 7;
    if (secLabel != NULL) prvats++;
    if (secId != NULL) prvats++;

    prvKeyTemplate = allocate_template(prvats);
    if (prvKeyTemplate == NULL) {
        return E_NO_MEM;
    }

    // if these are changed, verify the prvats value at initialization
    set_attribute(prvKeyTemplate, t++, CKA_CLASS, &prvkeyClass, sizeof(prvkeyClass));
    set_attribute(prvKeyTemplate, t++, CKA_KEY_TYPE, &keyType, sizeof(keyType));
    set_attribute(prvKeyTemplate, t++, CKA_TOKEN, &token, sizeof(token));
    set_attribute(prvKeyTemplate, t++, CKA_SIGN, &bTrue, sizeof(bTrue));
    set_attribute(prvKeyTemplate, t++, CKA_PRIVATE, &bTrue, sizeof(bTrue));
    set_attribute(prvKeyTemplate, t++, CKA_SENSITIVE, &bTrue, sizeof(bTrue));
    set_attribute(prvKeyTemplate, t++, CKA_EXTRACTABLE, &bTrue, sizeof(bTrue));
    if (secLabel != NULL) set_attribute(prvKeyTemplate, t++, CKA_LABEL, secLabel, l_secLabel);
    if (secId != NULL) set_attribute(prvKeyTemplate, t++, CKA_ID, secId, l_secId);

    memset(&keywrap, 0, sizeof(keywrap));
    keywrap.type = keytype;
    keywrap.flags = flags; // what to store to DB

    // pack allocates p_mech; released in cleanup
    err = mlxxx_keywrap_pack(&keywrap, &len, &p_mech);
    if (err) goto cleanup;

    mechanism.ulParameterLen = len;
    mechanism.pParameter = p_mech;

    entry("qptool2:mldsa_unwrapkey");

    err = pFunctions->C_UnwrapKey(hSession, &mechanism,
        kekHandle,
        (CK_BYTE_PTR)p_export, (CK_ULONG)l_export,
        get_template(prvKeyTemplate),
        get_template_len(prvKeyTemplate),
        keyUnwrapped);

    exuent("qptool2:mldsa_unwrapkey", err);

cleanup:
    free(p_mech);
    /* The template was previously leaked on every call. */
    if (prvKeyTemplate != NULL) free_template(prvKeyTemplate);
    return err;
}

/*
 * Export (wrap) an ML-DSA private key under an AES key-encryption key using
 * the vendor CKM_MECH_MLDSA_WRAP_AESKWP mechanism.
 *
 * kekHandle     Key-encryption key.
 * prvKeyHandle  Private key to wrap.
 * flags         Copied into the MLXXX_KEYWRAP parameter (selects what is
 *               exported, e.g. ML_MODE_STORE_SEED).
 * keytype       ML-DSA parameter-set identifier.
 * p_export      On success receives a malloc'd buffer with the wrapped blob;
 *               the caller owns it and must free() it.
 * l_export      On success receives the blob length.
 *
 * Uses the PKCS#11 two-call convention: first query the length, then wrap.
 * Returns 0 on success, -1 on allocation failure or if the length does not
 * fit in an unsigned int, otherwise the error from mlxxx_keywrap_pack() or
 * C_WrapKey(). The outputs are untouched on failure.
 */
int mldsa_export_privatekey(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE kekHandle, CK_OBJECT_HANDLE prvKeyHandle,
	unsigned int flags, unsigned int keytype,
    unsigned char **p_export, unsigned int *l_export
)
{
    int err = 0;
    unsigned char *p_buffer = NULL;
    /*
     * C_WrapKey writes a CK_ULONG through pulWrappedKeyLen. The length used
     * to be an unsigned int cast to CK_ULONG_PTR, which overruns the variable
     * wherever CK_ULONG is wider than unsigned int.
     */
    CK_ULONG l_buffer = 0;

    MLXXX_KEYWRAP keywrap;

    CK_MECHANISM         mechanism = { CKM_MECH_MLDSA_WRAP_AESKWP, 0, 0 };
    unsigned int len = 0;
    unsigned char *p_mech = NULL;

    memset(&keywrap, 0, sizeof(keywrap));
    keywrap.type = keytype;
    keywrap.flags = flags; // what to export

    // pack allocates p_mech; released in cleanup
    err = mlxxx_keywrap_pack(&keywrap, &len, &p_mech);
    if (err) goto cleanup;

    mechanism.ulParameterLen = len;
    mechanism.pParameter = p_mech;

    /* Length query: a NULL output buffer makes C_WrapKey report the size. */
    err = pFunctions->C_WrapKey(hSession, &mechanism,
        kekHandle, prvKeyHandle,
        NULL, &l_buffer);
    if (err) goto cleanup;

    /* The length is returned through an unsigned int; refuse to truncate. */
    if ((unsigned int)l_buffer != l_buffer) {
        err = -1;
        goto cleanup;
    }

    p_buffer = malloc(l_buffer);
    if (p_buffer == NULL) {
        err = -1;
        goto cleanup;
    }

    err = pFunctions->C_WrapKey(hSession, &mechanism,
        kekHandle, prvKeyHandle,
        (CK_BYTE_PTR)p_buffer, &l_buffer);
    if (err) goto cleanup;

    *p_export = p_buffer;
    *l_export = (unsigned int)l_buffer;
    p_buffer = NULL; /* ownership transferred to the caller */

cleanup:
    free(p_buffer); /* non-NULL only on failure after allocation */
    free(p_mech);
    return err;
}
#endif

/*
 * Generate a 32-byte AES key (CKM_AES_KEY_GEN) usable as a key-encryption
 * key for wrapping and unwrapping ML-DSA private keys.
 *
 * l_plabel, p_plabel  Optional CKA_LABEL (NULL = omit).
 * l_pid, p_pid        Optional CKA_ID (NULL = omit).
 * asToken             Non-zero to create a token object.
 * p_KEKHandle         Receives the new key handle.
 *
 * Returns 0 on success, E_NO_MEM if the template cannot be allocated,
 * otherwise the C_GenerateKey error.
 */
int mldsa_wrapkey_gen(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long  l_plabel, unsigned char *p_plabel,
    unsigned long  l_pid, unsigned char *p_pid, 
    unsigned long asToken, 
    CK_OBJECT_HANDLE_PTR p_KEKHandle
)
{
    int err = 0x0;

    CK_MECHANISM         mechanism = { CKM_AES_KEY_GEN, NULL, 0 };

    CK_BBOOL             token = (asToken ? bTrue : bFalse);
    unsigned long t = 0;

    CK_OBJECT_CLASS     keyClass = CKO_SECRET_KEY;
    CK_KEY_TYPE         keyType = CKK_AES;
    CK_UTF8CHAR_PTR     keyLabel = (CK_UTF8CHAR_PTR)p_plabel;
    CK_ULONG            l_keyLabel = (CK_ULONG)l_plabel;
    CK_UTF8CHAR_PTR     keyId = (CK_UTF8CHAR_PTR)p_pid;
    CK_ULONG            l_keyId = (CK_ULONG)l_pid;
    CK_ULONG            keyLen = (CK_ULONG)32;

    CK_ATTRIBUTE_PTR     keyTemplate = NULL;
    /* 8 fixed attributes below, plus the optional label and id. */
    unsigned long secats = 8;
    if (keyLabel != NULL) secats++;
    if (keyId != NULL) secats++;

    keyTemplate = allocate_template(secats);
    if (keyTemplate == NULL) {
        return E_NO_MEM;
    }

    // if these are changed, verify the secats value at initialization still agrees 
    set_attribute(keyTemplate, t++, CKA_CLASS, &keyClass, sizeof(keyClass));
    set_attribute(keyTemplate, t++, CKA_KEY_TYPE, &keyType, sizeof(keyType));
    set_attribute(keyTemplate, t++, CKA_VALUE_LEN, &keyLen, sizeof(keyLen));
    set_attribute(keyTemplate, t++, CKA_TOKEN, &token, sizeof(token));
    set_attribute(keyTemplate, t++, CKA_WRAP, &bTrue, sizeof(bTrue));
    set_attribute(keyTemplate, t++, CKA_UNWRAP, &bTrue, sizeof(bTrue));
    set_attribute(keyTemplate, t++, CKA_EXTRACTABLE, &bTrue, sizeof(bTrue));
    set_attribute(keyTemplate, t++, CKA_SENSITIVE, &bTrue, sizeof(bTrue));
    if (keyLabel != NULL)     set_attribute(keyTemplate, t++, CKA_LABEL, keyLabel, l_keyLabel);
    if (keyId != NULL)        set_attribute(keyTemplate, t++, CKA_ID, keyId, l_keyId);

    entry("qptool2:mldsa_genkey - aes wrapping key");

    err = pFunctions->C_GenerateKey(hSession, &mechanism,
        get_template(keyTemplate),
        get_template_len(keyTemplate),
        p_KEKHandle);
    if (err != 0x0) {
        printf("[genkey]: C_GenerateKey returned 0x%08x\n", err);
        goto cleanup;
    }

cleanup:
    /* The template was previously leaked on every call. */
    free_template(keyTemplate);
    exuent("qptool2:mldsa_genkey - kek wrapping key", err);
    return err;
}


/*
 * Generate an ML-DSA key pair with the vendor CKM_MECH_MLDSA_GENKEY
 * mechanism.
 *
 * l_plabel/p_plabel, l_pid/p_pid  Optional CKA_LABEL / CKA_ID for the
 *                                 public key (NULL = omit).
 * l_slabel/p_slabel, l_sid/p_sid  Optional CKA_LABEL / CKA_ID for the
 *                                 private key (NULL = omit).
 * mldsa_keytype                   Parameter-set identifier placed in the
 *                                 MLDSA_KEYGEN mechanism parameter.
 * asToken                         Non-zero to create token objects.
 * p_PubkeyHandle, p_PrvkeyHandle  Receive the new key handles.
 *
 * Returns 0 on success, E_NO_MEM if a template cannot be allocated,
 * otherwise the error from mldsa_keygen_pack() or C_GenerateKeyPair().
 */
int mldsa_genkey(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long  l_plabel, unsigned char *p_plabel,
    unsigned long  l_pid, unsigned char *p_pid,
    unsigned long  l_slabel, unsigned char *p_slabel,
    unsigned long  l_sid, unsigned char *p_sid,
    unsigned long mldsa_keytype, unsigned long asToken,
    CK_OBJECT_HANDLE_PTR p_PubkeyHandle,
    CK_OBJECT_HANDLE_PTR p_PrvkeyHandle
)
{
    int err = 0x0;

    CK_MECHANISM         mechanism;
    memset(&mechanism, 0, sizeof(mechanism));
    unsigned int len = 0;
    unsigned char *p_mech = NULL;

    MLDSA_KEYGEN keygen;

    CK_BBOOL             token = (asToken ? bTrue : bFalse);
    unsigned long t = 0;

    CK_OBJECT_CLASS     prvkeyClass = CKO_PRIVATE_KEY;
    CK_OBJECT_CLASS     pubkeyClass = CKO_PUBLIC_KEY;
    CK_KEY_TYPE         keyType = CKK_VENDOR_UTI_MLDSA_KT;
    CK_UTF8CHAR_PTR     secLabel = (CK_UTF8CHAR_PTR)p_slabel;
    CK_ULONG            l_secLabel = (CK_ULONG)l_slabel;
    CK_UTF8CHAR_PTR     secId = (CK_UTF8CHAR_PTR)p_sid;
    CK_ULONG            l_secId = (CK_ULONG)l_sid;
    CK_UTF8CHAR_PTR     pubLabel = (CK_UTF8CHAR_PTR)p_plabel;
    CK_ULONG            l_pubLabel = (CK_ULONG)l_plabel;
    CK_UTF8CHAR_PTR     pubId = (CK_UTF8CHAR_PTR)p_pid;
    CK_ULONG            l_pubId = (CK_ULONG)l_pid;

    CK_ATTRIBUTE_PTR     pubKeyTemplate = NULL;
    /* 4 fixed public attributes, plus the optional label and id. */
    unsigned long pubats = 4;
    if (pubLabel != NULL) pubats++;
    if (pubId != NULL) pubats++;

    CK_ATTRIBUTE_PTR     prvKeyTemplate = NULL;
    /*
     * 8 fixed private attributes (CLASS, KEY_TYPE, TOKEN, SIGN, PRIVATE,
     * SENSITIVE, DERIVE, EXTRACTABLE), plus the optional label and id.
     * The count was previously 7, one short of what is written below.
     */
    unsigned long prvats = 8;
    if (secLabel != NULL) prvats++;
    if (secId != NULL) prvats++;

    pubKeyTemplate = allocate_template(pubats);
    if (pubKeyTemplate == NULL) {
        return E_NO_MEM;
    }
    t = 0;
    // if these are changed, verify the pubats value at initialization
    set_attribute(pubKeyTemplate, t++, CKA_CLASS, &pubkeyClass, sizeof(pubkeyClass));
    set_attribute(pubKeyTemplate, t++, CKA_KEY_TYPE, &keyType, sizeof(keyType));
    set_attribute(pubKeyTemplate, t++, CKA_TOKEN, &token, sizeof(token));
    set_attribute(pubKeyTemplate, t++, CKA_VERIFY, &bTrue, sizeof(bTrue));
    if (pubLabel != NULL)     set_attribute(pubKeyTemplate, t++, CKA_LABEL, pubLabel, l_pubLabel);
    if (pubId != NULL)        set_attribute(pubKeyTemplate, t++, CKA_ID, pubId, l_pubId);

    prvKeyTemplate = allocate_template(prvats);
    if (prvKeyTemplate == NULL) {
        /* Do not leak the already-built public template. */
        free_template(pubKeyTemplate);
        return E_NO_MEM;
    }
    t = 0;
    // if these are changed, verify the prvats value at initialization
    set_attribute(prvKeyTemplate, t++, CKA_CLASS, &prvkeyClass, sizeof(prvkeyClass));
    set_attribute(prvKeyTemplate, t++, CKA_KEY_TYPE, &keyType, sizeof(keyType));
    set_attribute(prvKeyTemplate, t++, CKA_TOKEN, &token, sizeof(token));
    set_attribute(prvKeyTemplate, t++, CKA_SIGN, &bTrue, sizeof(bTrue));
    set_attribute(prvKeyTemplate, t++, CKA_PRIVATE, &bTrue, sizeof(bTrue));
    set_attribute(prvKeyTemplate, t++, CKA_SENSITIVE, &bTrue, sizeof(bTrue));
    set_attribute(prvKeyTemplate, t++, CKA_DERIVE, &bTrue, sizeof(bTrue));
#ifdef EXPORTABLE_PRIVATE_KEY
    set_attribute(prvKeyTemplate, t++, CKA_EXTRACTABLE, &bTrue, sizeof(bTrue));
#else
    set_attribute(prvKeyTemplate, t++, CKA_EXTRACTABLE, &bFalse, sizeof(bFalse));
#endif

    if (secLabel != NULL) set_attribute(prvKeyTemplate, t++, CKA_LABEL, secLabel, l_secLabel);
    if (secId != NULL) set_attribute(prvKeyTemplate, t++, CKA_ID, secId, l_secId);

    entry("qptool2:mldsa_genkey");

    // Prepare mechanism parameters
    memset(&keygen, 0, sizeof(keygen));
    /*
     * NOTE(review): an older comment said mldsa_keytype should be 0x32, 0x33
     * or 0x35 (for 44, 65, 87), but mldsa_vdx_test passes ML_DSA_44/65/87
     * (1/2/3). Confirm which encoding the HSM expects.
     */
    keygen.type = mldsa_keytype;
    keygen.flags = 1; // need to use ML's idea for what is pseudo, what is real. 1 is Pseudo

    // pack allocates p_mech; released in cleanup
    err = mldsa_keygen_pack(&keygen, &len, &p_mech);
    if (err) goto cleanup;

    mechanism.mechanism = CKM_MECH_MLDSA_GENKEY;
    mechanism.ulParameterLen = len;
    mechanism.pParameter = p_mech;

    err = pFunctions->C_GenerateKeyPair(hSession,
        &mechanism,
        get_template(pubKeyTemplate),
        get_template_len(pubKeyTemplate),
        get_template(prvKeyTemplate),
        get_template_len(prvKeyTemplate),
        p_PubkeyHandle, p_PrvkeyHandle);
    if (err != 0x0) {
        printf("[genkey]: C_GenerateKeyPair returned 0x%08x\n", err);
        goto cleanup;
    }

cleanup:
    free(p_mech);
    if (pubKeyTemplate != NULL) free_template(pubKeyTemplate);
    if (prvKeyTemplate != NULL) free_template(prvKeyTemplate);
    exuent("qptool2:mldsa_genkey", err);
    return err;
}

/*
 * Search the session for objects matching a label and/or id.
 *
 * Exactly one or two of p_plabel, p_pid, p_slabel, p_sid may be non-NULL.
 * Each non-NULL one becomes a CKA_LABEL or CKA_ID search attribute.
 *
 * keyClass     NOTE(review): accepted but NOT part of the search template,
 *              so objects of any class match. Callers pass
 *              CKO_PUBLIC_KEY / CKO_PRIVATE_KEY, which suggests filtering
 *              was intended; confirm before adding it.
 * l_retcount   In: capacity of p_KeyHandle (max objects to return).
 *              Out: number of handles actually returned. Must be non-NULL.
 * p_KeyHandle  Receives up to *l_retcount handles.
 *
 * Returns 0 on success, E_TMPL_INVALID for a bad attribute count, E_NO_MEM
 * on allocation failure, otherwise the PKCS#11 error.
 */
int mldsa_findkey(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long  keyClass,
    unsigned long  l_plabel, unsigned char *p_plabel,
    unsigned long  l_pid, unsigned char *p_pid,
    unsigned long  l_slabel, unsigned char *p_slabel,
    unsigned long  l_sid, unsigned char *p_sid,
    unsigned long  *l_retcount,
    CK_OBJECT_HANDLE_PTR p_KeyHandle
)
{
    int err = 0x0;
    int final_err = 0x0;
    unsigned long t = 0;
    unsigned long maxObjCount = 0;
    CK_ATTRIBUTE_PTR     findTemplate = NULL;
    unsigned long ats = 0;

    entry("mldsa_findkey");

    (void)keyClass; /* see NOTE(review) above */
    maxObjCount = *l_retcount;

    if (p_plabel != NULL) ats++;
    if (p_pid != NULL) ats++;
    if (p_slabel != NULL) ats++;
    if (p_sid != NULL) ats++;
    if ((ats < 1) || (ats > 2)) {
        err = E_TMPL_INVALID;
        goto cleanup; /* keep entry()/exuent() paired */
    }

    findTemplate = allocate_template(ats);
    if (findTemplate == NULL) {
        err = E_NO_MEM;
        goto cleanup;
    }
    if (p_plabel != NULL) set_attribute(findTemplate, t++, CKA_LABEL, p_plabel, l_plabel);
    if (p_pid != NULL) set_attribute(findTemplate, t++, CKA_ID, p_pid, l_pid);
    if (p_slabel != NULL) set_attribute(findTemplate, t++, CKA_LABEL, p_slabel, l_slabel);
    if (p_sid != NULL) set_attribute(findTemplate, t++, CKA_ID, p_sid, l_sid);

    err = pFunctions->C_FindObjectsInit(hSession, get_template(findTemplate), get_template_len(findTemplate));
    if (err != 0x0) goto cleanup;

    err = pFunctions->C_FindObjects(hSession, p_KeyHandle, maxObjCount, l_retcount);

    /*
     * Always end the search once it has been started, even if C_FindObjects
     * failed. Otherwise the session keeps an active find operation and the
     * next C_FindObjectsInit fails. The first error wins.
     */
    final_err = pFunctions->C_FindObjectsFinal(hSession);
    if (err == 0x0) err = final_err;

cleanup:
    if (findTemplate != NULL) free_template(findTemplate);
    exuent("mldsa_findkey", err);
    return err;
}


/*
 * Hash p_msg[0..l_msg) on the token with the vendor CKM_MECH_MLDSA_SHAKE256
 * digest and write the result to p_hash.
 *
 * p_hash must have room for SHAKE256_DIGEST_LEN bytes; that value is passed
 * to C_Digest as the output buffer capacity.
 * NOTE(review): the length C_Digest reports back is not checked.
 *
 * Returns 0 on success, otherwise the C_DigestInit / C_Digest error
 * (also printed to stdout).
 */
int hash_shake256(
	CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
	CK_ULONG    l_msg, CK_BYTE_PTR p_msg,
	CK_BYTE_PTR p_hash
)
{
	CK_MECHANISM shake = { CKM_MECH_MLDSA_SHAKE256, NULL, 0 };
	/* In: capacity of p_hash. Out: bytes written by the token. */
	CK_ULONG digest_len = SHAKE256_DIGEST_LEN;
	int rv;

	rv = pFunctions->C_DigestInit(hSession, &shake);
	if (rv) {
		printf("[hash_shake256]: C_DigestInit returned 0x%08x\n", rv);
		return rv;
	}

	rv = pFunctions->C_Digest(hSession, p_msg, l_msg, p_hash, &digest_len);
	if (rv) {
		printf("[hash_shake256]: C_Digest returned 0x%08x\n", rv);
	}
	return rv;
}

/*
 * Sign p_msg with an ML-DSA private key using the vendor
 * CKM_MECH_MLDSA_SIGN mechanism.
 *
 * The mechanism parameter is 8 bytes: a 4-byte flags word (always 0 here;
 * whether the HSM reads it is unknown) followed by the 4-byte keytype.
 *
 * pulSigLen  Receives the signature length.
 * pSig       On success receives a malloc'd signature buffer; the caller
 *            owns it and must free() it. Left untouched on failure.
 *
 * Uses the PKCS#11 two-call convention (length query, then sign).
 * Returns 0 on success, -1 if the signature buffer cannot be allocated,
 * otherwise the C_SignInit / C_Sign error.
 */
int mldsa_sign(
    CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE prvKeyHandle,
    CK_ULONG    keytype,
    CK_ULONG    l_msg,
    CK_BYTE_PTR p_msg,
    CK_ULONG   *pulSigLen,
    CK_BYTE_PTR *pSig
) 
{
    int err = 0;
    entry("mldsa_sign");
    CK_BYTE_PTR p_sig = NULL;

    unsigned char p_mech[8];
    store_int4(0, p_mech); // flags. used?
    store_int4(keytype, p_mech + 4);

    CK_MECHANISM mechanism;
    mechanism.mechanism = CKM_MECH_MLDSA_SIGN;
    mechanism.ulParameterLen = sizeof(p_mech);
    mechanism.pParameter = p_mech;

    err = pFunctions->C_SignInit(hSession, &mechanism, prvKeyHandle);
    if (err) {
        printf("[sign]: C_SignInit returned 0x%08x\n", err);
        goto cleanup;
    }

    // get len
    err = pFunctions->C_Sign(hSession, p_msg, l_msg, NULL, pulSigLen);
    if (err) {
        printf("[sign]: C_Sign returned 0x%08x\n", err);
        goto cleanup;
    }

    /*
     * If this allocation fails and is not caught, the next C_Sign gets a
     * NULL buffer. That is just another length query, so the function would
     * "succeed" with *pSig = NULL.
     */
    p_sig = malloc(*pulSigLen);
    if (p_sig == NULL) {
        printf("[sign]: Couldn't allocate memory.\n");
        err = -1;
        goto cleanup;
    }

    err = pFunctions->C_Sign(hSession, p_msg, l_msg, p_sig, pulSigLen);
    if (err) {
        printf("[sign]: C_Sign returned 0x%08x\n", err);
        goto cleanup;
    }

    *pSig = p_sig;
    p_sig = NULL; /* ownership transferred to the caller */

cleanup:
    free(p_sig); /* non-NULL only when signing failed after allocation */
    exuent("mldsa_sign", err);
    return err;
}

/*
 * Sign with an ML-DSA private key by computing the message representative
 * mu on the host and passing it to the vendor CKM_MECH_MLDSA_EXTMU_SIGN
 * mechanism.
 *
 * mu = SHAKE256(SHAKE256(pubkey) || 0x00 || 0x00 || msg), with 64-byte
 * outputs. The two zero bytes follow the FIPS 204 M' framing (pure
 * ML-DSA, empty context).
 *
 * l_pubkey, p_pubkey  Encoded public key used to derive tr = H(pk).
 * pulSigLen           Receives the signature length.
 * pSig                On success receives a malloc'd signature buffer; the
 *                     caller owns it and must free() it. Left untouched on
 *                     failure.
 *
 * Returns 0 on success, -1 on allocation failure, otherwise the error from
 * hash_shake256() / C_SignInit() / C_Sign().
 */
int mldsa_sign_external_mu(
	CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
	CK_OBJECT_HANDLE prvKeyHandle,
	CK_ULONG    l_pubkey,
	CK_BYTE_PTR p_pubkey,
	CK_ULONG    keytype,
	CK_ULONG    l_msg,
	CK_BYTE_PTR p_msg,
	CK_ULONG   *pulSigLen,
	CK_BYTE_PTR *pSig
)
{
	int err = 0;
	entry("mldsa_sign_external_mu");
	CK_BYTE_PTR p_sig = NULL;
	CK_BYTE_PTR p_buff = NULL;
	CK_ULONG l_buff;
	CK_BYTE mu[SHAKE256_DIGEST_LEN];

	unsigned char p_mech[8];
	store_int4(0, p_mech); // flags. used?
	store_int4(keytype, p_mech + 4);

	CK_MECHANISM mechanism;
	mechanism.mechanism = CKM_MECH_MLDSA_EXTMU_SIGN;
	mechanism.ulParameterLen = sizeof(p_mech);
	mechanism.pParameter = p_mech;

	// calculate mu: buffer holds H(pk) || 0 || |ctx|=0 || msg
	l_buff = SHAKE256_DIGEST_LEN + 2 + l_msg;

	p_buff = malloc(l_buff);
	if (p_buff == NULL) {
		printf("[mldsa_sign_external_mu]: Couldn't allocate memory.\n");
		err = -1;
		goto cleanup;
	}

	// H(pubkey)
	err = hash_shake256(pFunctions, hSession, l_pubkey, p_pubkey, p_buff);
	if (err != 0) goto cleanup;

	// ctx (empty)
	p_buff[SHAKE256_DIGEST_LEN] = 0;
	p_buff[SHAKE256_DIGEST_LEN + 1] = 0;

	// msg
	memcpy(p_buff + SHAKE256_DIGEST_LEN + 2, p_msg, l_msg);

	err = hash_shake256(pFunctions, hSession, l_buff, p_buff, mu);
	if (err != 0) goto cleanup;

	err = pFunctions->C_SignInit(hSession, &mechanism, prvKeyHandle);
	if (err) {
		printf("[mldsa_sign_external_mu]: C_SignInit returned 0x%08x\n", err);
		goto cleanup;
	}

	// get len
	err = pFunctions->C_Sign(hSession, mu, SHAKE256_DIGEST_LEN, NULL, pulSigLen);
	if (err) {
		printf("[mldsa_sign_external_mu]: C_Sign returned 0x%08x\n", err);
		goto cleanup;
	}

	/* Unchecked, a NULL buffer would turn the next C_Sign into a length query. */
	p_sig = malloc(*pulSigLen);
	if (p_sig == NULL) {
		printf("[mldsa_sign_external_mu]: Couldn't allocate memory.\n");
		err = -1;
		goto cleanup;
	}

	err = pFunctions->C_Sign(hSession, mu, SHAKE256_DIGEST_LEN, p_sig, pulSigLen);
	if (err) {
		printf("[mldsa_sign_external_mu]: C_Sign returned 0x%08x\n", err);
		goto cleanup;
	}

	*pSig = p_sig;
	p_sig = NULL; /* ownership transferred to the caller */

cleanup:
	free(p_sig); /* non-NULL only when signing failed after allocation */
	free(p_buff);
	exuent("mldsa_sign_external_mu", err);
	return err;
}


/*
 * Verify an ML-DSA signature over p_msg with a public key using the vendor
 * CKM_MECH_MLDSA_VERIFY mechanism.
 *
 * The mechanism parameter is 8 bytes: a 4-byte flags word (always 0 here)
 * followed by the 4-byte keytype.
 *
 * Returns 0 if the signature is valid, CKR_SIGNATURE_INVALID if it is not,
 * otherwise the C_VerifyInit / C_Verify error. Any failure is also printed.
 */
int mldsa_verify(
    CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    CK_OBJECT_HANDLE pubKeyHandle,
    CK_ULONG    keytype,
    CK_ULONG    l_msg,
    CK_BYTE_PTR p_msg,
    CK_ULONG    pulSigLen,
    CK_BYTE_PTR pSig
)
{
    int rv;
    unsigned char params[8];
    CK_MECHANISM mech;

    entry("mldsa_verify");

    store_int4(0, params);           /* flags word */
    store_int4(keytype, params + 4); /* parameter set */

    mech.mechanism = CKM_MECH_MLDSA_VERIFY;
    mech.pParameter = params;
    mech.ulParameterLen = sizeof(params);

    rv = pFunctions->C_VerifyInit(hSession, &mech, pubKeyHandle);
    if (rv) {
        printf("[sign]: C_VerifyInit returned 0x%08x\n", rv);
    } else {
        /* Single-part verification of the whole message. */
        rv = pFunctions->C_Verify(hSession, p_msg, l_msg, pSig, pulSigLen);
        if (rv == CKR_SIGNATURE_INVALID) {
            printf("[sign]: C_Verify returned CKR_SIGNATURE_INVALID\n");
        } else if (rv) {
            printf("[sign]: C_Verify returned 0x%08x\n", rv);
        }
    }

    exuent("mldsa_verify", rv);
    return rv;
}

/*
 * Verify an ML-DSA signature by computing the message representative mu on
 * the host and passing it to the vendor CKM_MECH_MLDSA_EXTMU_VERIFY
 * mechanism.
 *
 * mu = SHAKE256(SHAKE256(pubkey) || 0x00 || 0x00 || msg), with 64-byte
 * outputs (FIPS 204 M' framing with an empty context).
 *
 * Returns 0 if the signature verifies, -1 on allocation failure, otherwise
 * the error from hash_shake256() / C_VerifyInit() / C_Verify().
 */
int mldsa_verify_external_mu(
	CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
	CK_OBJECT_HANDLE pubKeyHandle,
	CK_ULONG    l_pubkey,
	CK_BYTE_PTR p_pubkey,
	CK_ULONG    keytype,
	CK_ULONG    l_msg,
	CK_BYTE_PTR p_msg,
	CK_ULONG    pulSigLen,
	CK_BYTE_PTR pSig
)
{
	int rv = 0;
	unsigned char params[8];
	CK_MECHANISM mech;
	CK_BYTE mu[SHAKE256_DIGEST_LEN];
	/* Bytes in front of the message: H(pk) plus two framing bytes. */
	const CK_ULONG prefix_len = SHAKE256_DIGEST_LEN + 2;
	CK_ULONG mprime_len;
	CK_BYTE_PTR mprime = NULL;

	entry("mldsa_verify_external_mu");

	store_int4(0, params);           /* flags word */
	store_int4(keytype, params + 4); /* parameter set */
	mech.mechanism = CKM_MECH_MLDSA_EXTMU_VERIFY;
	mech.pParameter = params;
	mech.ulParameterLen = sizeof(params);

	mprime_len = prefix_len + l_msg;
	mprime = malloc(mprime_len);
	if (mprime == NULL) {
		printf("[mldsa_verify_external_mu]: Couldn't allocate memory.\n");
		rv = -1;
		goto done;
	}

	/* tr = H(pk) fills the first SHAKE256_DIGEST_LEN bytes. */
	rv = hash_shake256(pFunctions, hSession, l_pubkey, p_pubkey, mprime);
	if (rv != 0) goto done;

	/* Two zero framing bytes (empty context), then the message itself. */
	mprime[SHAKE256_DIGEST_LEN] = 0;
	mprime[SHAKE256_DIGEST_LEN + 1] = 0;
	memcpy(mprime + prefix_len, p_msg, l_msg);

	rv = hash_shake256(pFunctions, hSession, mprime_len, mprime, mu);
	if (rv != 0) goto done;

	rv = pFunctions->C_VerifyInit(hSession, &mech, pubKeyHandle);
	if (rv) {
		printf("[mldsa_verify_external_mu]: C_VerifyInit returned 0x%08x\n", rv);
		goto done;
	}

	rv = pFunctions->C_Verify(hSession, mu, SHAKE256_DIGEST_LEN, pSig, pulSigLen);
	if (rv) {
		printf("[mldsa_verify_external_mu]: C_Verify returned 0x%08x\n", rv);
	}

done:
	free(mprime);
	exuent("mldsa_verify_external_mu", rv);
	return rv;
}

/*
 * End-to-end self test of the ML-DSA helpers in this file.
 *
 * Steps:
 *   - Delete any leftover "mldsa_pub" / "mldsa_prv" objects from earlier
 *     runs.
 *   - Map mldsa_keytype 2/3/5 to ML_DSA_44/65/87. Any other value returns
 *     CKR_KEY_SIZE_RANGE.
 *   - Run `count` iterations. Each one:
 *       - generates a key pair (on every iteration if gentest, otherwise
 *         only the first time);
 *       - with EXPORTABLE_PRIVATE_KEY, also creates a KEK, exports the key
 *         seed and re-imports it;
 *       - optionally signs (plain, re-imported key, external mu);
 *       - optionally verifies every combination.
 *   - On exit, release every buffer and destroy every object the test
 *     created.
 *
 * Returns 0 on success, -1 if key generation returned no handle, otherwise
 * the first error from a helper.
 */
int mldsa_vdx_test(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession,
    unsigned long gentest,
    unsigned long signtest,
    unsigned long verifytest,
    unsigned long signverifytest,
    unsigned long mldsa_keytype,
    unsigned long asToken,
    unsigned long count
    )
{
    int                    err = 0;

    CK_BYTE_PTR         someRandomMessage = (unsigned char *)
        "\n'Twas brillig, and the slithy toves\n"
        "    Did gyre and gimble in the wabe:\n"
        "All mimsy were the borogoves,\n"
        "    And the mome raths outgrabe.\n"
        "\n"
        "Beware the Jabberwock, my son!\n"
        "    The jaws that bite, the claws that catch!\n"
        "Beware the Jubjub bird, and shun\n"
        "    The frumious Bandersnatch!\n\n";
    CK_ULONG l_someRandomMessage = (CK_ULONG)strnlen((char *)someRandomMessage, 281);

    CK_BYTE_PTR         pSig = NULL;
    CK_BYTE_PTR         pRekeySig = NULL;
    CK_BYTE_PTR         pMuSig = NULL;
    CK_BYTE_PTR         p_pubkey = NULL;
    CK_ULONG            pulSigLen = 0;
    CK_ULONG            pulRekeySigLen = 0;
    CK_ULONG            pulMuSigLen = 0;
    CK_ULONG            l_pubkey = 0;

    CK_BBOOL             token = (asToken ? bTrue : bFalse);

    CK_OBJECT_CLASS     prvkeyClass = CKO_PRIVATE_KEY;
    CK_OBJECT_CLASS     pubkeyClass = CKO_PUBLIC_KEY;

    CK_OBJECT_HANDLE   pubKeyHandle = CK_INVALID_HANDLE;
    unsigned char pubLabel[] = "mldsa_pub";
    unsigned char pubId[] = { 0x30, 0x30 };

    CK_OBJECT_HANDLE   prvKeyHandle = CK_INVALID_HANDLE;
    unsigned char prvLabel[] = "mldsa_prv";
    unsigned char prvId[] = { 0x31, 0x31 };

#ifdef EXPORTABLE_PRIVATE_KEY
    CK_OBJECT_HANDLE   rekeyHandle = CK_INVALID_HANDLE;
    unsigned char rekeyLabel[] = "rekey label";
    unsigned char rekeyId[] = { 0x33, 0x33 };
    CK_OBJECT_HANDLE   kekHandle = CK_INVALID_HANDLE;
    unsigned char kekLabel[] = "kekhandle";
    unsigned char kekId[] = { 0x32, 0x32 };
    unsigned char *p_exportedKey = NULL;
    unsigned int   l_exportedKey = 0;
    unsigned char *p_exportedSeed = NULL;
    unsigned int   l_exportedSeed = 0;
#endif

    unsigned long test_count = count;
    CK_BBOOL key_generated = 0;

    /* Delete leftover public keys from earlier runs, one per search. */
    unsigned long keyfound = 1;
    while (keyfound) {
        unsigned long retcount = 1;
        keyfound = 0;
        pubKeyHandle = CK_INVALID_HANDLE;

        err = mldsa_findkey(pFunctions, hSession,
            pubkeyClass,
            (unsigned long)strnlen((char *)pubLabel, sizeof(pubLabel)), pubLabel,
            (unsigned long)strnlen((char *)pubId, sizeof(pubId)), pubId,
            0, NULL, 0, NULL,
            &retcount, &pubKeyHandle);
        if (err) goto cleanup;
        if (retcount == 0) break;
        keyfound = 1;
        if (pubKeyHandle != CK_INVALID_HANDLE) {
            pFunctions->C_DestroyObject(hSession, pubKeyHandle);
        }
    }
    /* Same for leftover private keys. */
    keyfound = 1;
    while (keyfound) {
        unsigned long retcount = 1;
        keyfound = 0;
        prvKeyHandle = CK_INVALID_HANDLE;

        err = mldsa_findkey(pFunctions, hSession,
            prvkeyClass,
            0, NULL, 0, NULL,
            (unsigned long)strnlen((char *)prvLabel, sizeof(prvLabel)), prvLabel,
            (unsigned long)strnlen((char *)prvId, sizeof(prvId)), prvId,
            &retcount, &prvKeyHandle);
        if (err) goto cleanup;
        if (retcount == 0) break;
        keyfound = 1;
        if (prvKeyHandle != CK_INVALID_HANDLE) {
            pFunctions->C_DestroyObject(hSession, prvKeyHandle);
        }
    }

    switch (mldsa_keytype) {
    case 2: mldsa_keytype = ML_DSA_44; break;
    case 3: mldsa_keytype = ML_DSA_65; break;
    case 5: mldsa_keytype = ML_DSA_87; break;
    default:
        return CKR_KEY_SIZE_RANGE;
    }

    while (test_count-- > 0) {
        // first time through generate a key
        // or if gentest, always generate a new key
        if (gentest || !key_generated) {
            if (pubKeyHandle != CK_INVALID_HANDLE) pFunctions->C_DestroyObject(hSession, pubKeyHandle);
            if (prvKeyHandle != CK_INVALID_HANDLE) pFunctions->C_DestroyObject(hSession, prvKeyHandle);
            pubKeyHandle = CK_INVALID_HANDLE;
            prvKeyHandle = CK_INVALID_HANDLE;
#ifdef EXPORTABLE_PRIVATE_KEY
            /* The KEK and re-imported key are recreated below; drop the old ones. */
            if (rekeyHandle != CK_INVALID_HANDLE) pFunctions->C_DestroyObject(hSession, rekeyHandle);
            if (kekHandle != CK_INVALID_HANDLE) pFunctions->C_DestroyObject(hSession, kekHandle);
            rekeyHandle = CK_INVALID_HANDLE;
            kekHandle = CK_INVALID_HANDLE;
#endif

            err = mldsa_genkey(pFunctions, hSession,
                (unsigned long)strnlen((char *)pubLabel, sizeof(pubLabel)), pubLabel, 
                (unsigned long)strnlen((char *)pubId, sizeof(pubId)), pubId,
                (unsigned long)strnlen((char *)prvLabel, sizeof(prvLabel)), prvLabel, 
                (unsigned long)strnlen((char *)prvId, sizeof(prvId)), prvId,
                mldsa_keytype, token, &pubKeyHandle, &prvKeyHandle
            );
            if (err) goto cleanup;

            /* Previously this path reported success (err == 0) with no key. */
            if ((pubKeyHandle == CK_INVALID_HANDLE) || (prvKeyHandle == CK_INVALID_HANDLE)) {
                err = -1;
                goto cleanup;
            }

            if (!gentest) {
                /* Exercise mldsa_findkey by looking the public key up again. */
                pubKeyHandle = CK_INVALID_HANDLE;

                CK_ULONG retcount = 1;
                err = mldsa_findkey(pFunctions, hSession,
                    pubkeyClass,
                    (unsigned long)strnlen((char *)pubLabel, sizeof(pubLabel)), pubLabel,
                    (unsigned long)strnlen((char *)pubId, sizeof(pubId)), pubId,
                    0, NULL, 0, NULL,
                    &retcount, &pubKeyHandle);
                if (err) goto cleanup;
            }
#ifdef EXPORTABLE_PRIVATE_KEY
            // Export wrapped key.
            err = mldsa_wrapkey_gen(pFunctions, hSession, 
                (unsigned long)strnlen((char *)kekLabel, sizeof(kekLabel)), kekLabel,
                (unsigned long)strnlen((char *)kekId, sizeof(kekId)), kekId,
                0, &kekHandle);
            if (err) goto cleanup;

            // seed export - key can be recreated from seed
            err = mldsa_export_privatekey(pFunctions, hSession, 
                kekHandle, prvKeyHandle,
                ML_MODE_STORE_SEED,  // export seed
                mldsa_keytype, 
                &p_exportedSeed, &l_exportedSeed);
            if (err) goto cleanup;

            cs_xprint("Exported key seed (aes_kwp wrapped)\n", p_exportedSeed, l_exportedSeed);

            // private key export (disabled; kept for manual testing)
            /*
            err = mldsa_export_privatekey(pFunctions, hSession,
                kekHandle, prvKeyHandle,
                ML_MODE_STORE_SK,  // export private key
                mldsa_keytype,
                &p_exportedKey, &l_exportedKey);
            if (err) goto cleanup;
            cs_xprint("Exported private key (aes_kwp wrapped)\n", p_exportedKey, l_exportedKey);
            */

            // reimport wrapped key as a test
            err = mldsa_import_privatekey(pFunctions, hSession,
                kekHandle, 
                ML_MODE_STORE_SEED | ML_MODE_STORE_SK, // store seed & private key 
                mldsa_keytype,
                p_exportedSeed, l_exportedSeed, 
                (unsigned long)strnlen((char *)rekeyLabel, sizeof(rekeyLabel)), rekeyLabel,
                (unsigned long)strnlen((char *)rekeyId, sizeof(rekeyId)), rekeyId,
                0, &rekeyHandle);
            if (err) goto cleanup;

            /* NULL after free: cleanup frees these too, so this prevents a double free. */
            free(p_exportedKey);
            p_exportedKey = NULL;
            l_exportedKey = 0;
            free(p_exportedSeed);
            p_exportedSeed = NULL;
            l_exportedSeed = 0;
#endif
        }

        // get pubkey. The previous copy is released first so it is not leaked.
        free(p_pubkey);
        p_pubkey = NULL;
        l_pubkey = 0;
        err = util_get_attribute_value(pFunctions, hSession, prvKeyHandle, CKA_UTI_CUSTOM_DATA, &p_pubkey, &l_pubkey);
        if (err) goto cleanup;

        key_generated = 1;

        if (signtest || signverifytest) {
            /* mldsa_sign* allocate fresh buffers; release the previous iteration's. */
            free(pSig);
            pSig = NULL;
            free(pRekeySig);
            pRekeySig = NULL;
            free(pMuSig);
            pMuSig = NULL;

            printf("Sign using generated key\n");
            err = mldsa_sign(pFunctions, hSession,
                prvKeyHandle, mldsa_keytype,
                l_someRandomMessage, someRandomMessage,
                &pulSigLen, &pSig);
            if (err) goto cleanup;

#ifdef EXPORTABLE_PRIVATE_KEY
            printf("Sign using re-imported key\n");
            err = mldsa_sign(pFunctions, hSession,
                rekeyHandle, mldsa_keytype,
                l_someRandomMessage, someRandomMessage,
                &pulRekeySigLen, &pRekeySig);
            if (err) goto cleanup;
#endif

            printf("Sign using external mu\n");
            err = mldsa_sign_external_mu(pFunctions, hSession,
                prvKeyHandle, l_pubkey, p_pubkey, mldsa_keytype,
                l_someRandomMessage, someRandomMessage,
                &pulMuSigLen, &pMuSig);
            if (err) goto cleanup;
        }

        if (verifytest || signverifytest) {
            printf("Verify original using generated key\n");
            err = mldsa_verify(pFunctions, hSession,
                pubKeyHandle, mldsa_keytype,
                l_someRandomMessage, someRandomMessage,
                pulSigLen, pSig);
            if (err) goto cleanup;

#ifdef EXPORTABLE_PRIVATE_KEY
            printf("Verify rekey sig using generated key\n");
            err = mldsa_verify(pFunctions, hSession,
                pubKeyHandle, mldsa_keytype,
                l_someRandomMessage, someRandomMessage,
                pulRekeySigLen, pRekeySig);
            if (err) goto cleanup;
#endif

            printf("Verify external mu sig using mldsa_verify\n");
            err = mldsa_verify(pFunctions, hSession,
                pubKeyHandle, mldsa_keytype,
                l_someRandomMessage, someRandomMessage,
                pulMuSigLen, pMuSig);
            if (err) goto cleanup;

            printf("Verify original sig using mldsa_verify_external_mu\n");
            err = mldsa_verify_external_mu(pFunctions, hSession,
                pubKeyHandle, l_pubkey, p_pubkey, mldsa_keytype,
                l_someRandomMessage, someRandomMessage,
                pulSigLen, pSig);
            if (err) goto cleanup;
        }
    }

cleanup:
    entry("mldsa_vdx_test:Cleanup");
    free(pSig);
    free(pRekeySig);
    free(pMuSig);
    free(p_pubkey);

#ifdef EXPORTABLE_PRIVATE_KEY
    free(p_exportedKey);
    free(p_exportedSeed);
    if (rekeyHandle != CK_INVALID_HANDLE) { pFunctions->C_DestroyObject(hSession, rekeyHandle); }
    if (kekHandle != CK_INVALID_HANDLE) { pFunctions->C_DestroyObject(hSession, kekHandle); }
#endif
    if (pubKeyHandle != CK_INVALID_HANDLE) { pFunctions->C_DestroyObject(hSession, pubKeyHandle); }
    if (prvKeyHandle != CK_INVALID_HANDLE) { pFunctions->C_DestroyObject(hSession, prvKeyHandle); }
    exuent("mldsa_vdx_test:Cleanup", err);

    return err;
}


