/**************************************************************************************************
 *
 * Filename           : main_3.3.3_GenerateKeyPair.c
 * Author             : Utimaco GmbH
 * Description        : PKCS#11 main example out of the guide:
 *                      "Learning PKCS#11 in Half a Day - using the Utimaco HSM Simulator"
 * Dependencies       : pkcs11_handson.c
 * Creation Date      : 27.01.2016
 * Version            : 1.3.0
 *
 *************************************************************************************************/
#include "qptool2.h"
#include "stdlib.h"
// Test-suite entry points defined in other translation units. main() calls
// exactly one of them, with the loaded PKCS#11 function list, a session that
// is already logged in as CKU_USER, and the options parsed from the command
// line. A non-zero return is passed to pkcs_err_text() by main().
extern int test_case_lms(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession, struct PPARMS *parms);
extern int test_case_hss(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession, struct PPARMS *parms);
extern int test_case_xmss(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession, struct PPARMS *parms);

// Round-3 (pre-standard) lattice schemes; only built when IF_INCLUDE_LATTICE_IPD is set.
#ifdef IF_INCLUDE_LATTICE_IPD
extern int test_case_dilithium(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession, struct PPARMS *parms);
extern int test_case_kyber(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession, struct PPARMS *parms);
#endif

extern int test_case_mldsa(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession, struct PPARMS *parms);
extern int test_case_mlkem(CK_FUNCTION_LIST_PTR pFunctions, CK_SESSION_HANDLE hSession, struct PPARMS *parms);

// Presumably prints a textual description of a PKCS#11 error code — defined elsewhere, confirm.
extern void pkcs_err_text(unsigned int errv);
// Prints the command-line syntax; (void) makes this a real prototype.
void help(void);


//
// BEGIN BOILERPLATE BEGIN
//

/******************************************************************************
 *
 * main_3.3.3
 *
 ******************************************************************************/
int main(int argc, char *argv[])
{
    int                   err = 0;

    char *cfg = getenv("CS_PKCS11_R3_CFG");
    printf("Using: %s\n", cfg);
    if (cfg == NULL) {
        printf("CS_PKCS11_R3_CFG unset. Stopping.\n");
    }

    CK_FUNCTION_LIST_PTR  pFunctions = NULL;
#ifdef OSYS_win
    HMODULE               hModule = NULL;
#else
    void                  *lib_handle = NULL;
#endif   

    // -pin xyzab [default 12345678]
    char                  *userPIN = "12345678";
    CK_ULONG              lenUserPIN = (CK_ULONG)strlen(userPIN);

    // -lib .../lib/cs_PKCS11_R2.dll [no default]
    char                  *lib = "cs_pkcs11_r3.dll";

    // -slot N [default 0]
    CK_ULONG              slotID = 0;

    // -test T [default 0]
    CK_ULONG              testID = 1;
    CK_ULONG              algo = 0;

    CK_SESSION_HANDLE     hSession = 0;
    // CK_OBJECT_HANDLE      hPublicKey = 0;
    // CK_OBJECT_HANDLE      hPrivateKey = 0;

    struct PPARMS parms;
    memset(&parms, 0, sizeof(parms));
    // set default parms by testID
    parms.astoken = 0; // hbs testID 1
    parms.curvedes = 3; // secp256r1 via the dotted OID
    parms.keytype = 1; // security levels, valid are 1/2, 3, and 5
    parms.r3_variant = 0; // dilithium and kyber, if -aes or -90s

    // initialize
    for (int i = 1; i < argc; i++ /* one half of i+=2 */)
    {
        if (strcmp(argv[i], "-sign") == 0) { parms.sign = 1; continue; }
        if (strcmp(argv[i], "-encap") == 0) { parms.encap = 1; continue; }
        if (strcmp(argv[i], "-verify") == 0) { parms.verify = 1; continue; }
        if (strcmp(argv[i], "-decap") == 0) { parms.decap = 1; continue; }
        if (strcmp(argv[i], "-gen") == 0) { parms.gen = 1;  continue; }
        if (strcmp(argv[i], "-all") == 0) {
            parms.gen = 1; parms.sign = 1; parms.verify = 1;
            parms.encap = 1; parms.decap = 1;
            continue;
    }
        if (strcmp(argv[i], "-signverify") == 0) { parms.signverify = 1;  continue; }
        if (strcmp(argv[i], "-encapdecap") == 0) { parms.encdec = 1;  continue; }

        if (strcmp(argv[i], "-h") == 0) { help(); return 0; }
        if (strcmp(argv[i], "-lms") == 0) { algo = LMS; testID = 1; continue; }
#ifdef IF_INCLUDE_LATTICE_IPD
        if (strcmp(argv[i], "-dil") == 0) { algo = DILITHIUM; testID = 2;  continue; }
        if (strcmp(argv[i], "-dilithium") == 0) { algo = DILITHIUM; testID = 2; continue; }
        if (strcmp(argv[i], "-kyber") == 0) { algo = KYBER; testID = 3; continue; }
        if (strcmp(argv[i], "-kyb") == 0) { algo = KYBER; testID = 3; continue; }
        if ((strcmp(argv[i], "-90s") == 0) && (algo == KYBER)) { parms.r3_variant = 3; continue; }
        if ((strcmp(argv[i], "-aes") == 0) && (algo == DILITHIUM)) { parms.r3_variant = 1; continue; }
#endif
        if (strcmp(argv[i], "-mldsa") == 0) { algo = MLDSA; testID = 4; continue; }
        if (strcmp(argv[i], "-mlkem") == 0) { algo = MLKEM; testID = 5; continue; }
        if (strcmp(argv[i], "-xmss") == 0) { algo = XMSS; testID = 6; continue; }
        if (strcmp(argv[i], "-hss") == 0) { algo = LMS; testID = 7; continue; }

        // parms targets
        if (strcmp(argv[i], "-token") == 0) { parms.astoken = 1; continue; }
        if (strcmp(argv[i], "-tkn") == 0) { parms.astoken = 1; continue; }

        if (i + 1 >= argc) {
            printf("Value not supplied for flag %s at end of input.\n", argv[i]);
            goto syntax;
        }

        if (strcmp(argv[i], "-lib") == 0) { lib = argv[++i]; continue; }
        if (strcmp(argv[i], "-l") == 0) { lib = argv[++i]; continue; }

        if (strcmp(argv[i], "-pin") == 0) { userPIN = argv[++i]; continue; }
        if (strcmp(argv[i], "-p") == 0) { userPIN = argv[++i]; continue; }

        if (strcmp(argv[i], "-slot") == 0) { slotID = atol(argv[++i]); continue; }
        if (strcmp(argv[i], "-s") == 0) { slotID = atol(argv[++i]); continue; }

        if (strcmp(argv[i], "-test") == 0) { testID = atol(argv[++i]); continue; }
        if (strcmp(argv[i], "-t") == 0) { testID = atol(argv[++i]); continue; }

        // parms targets
        if (strcmp(argv[i], "-curve") == 0) { parms.curvedes = atol(argv[++i]); continue; }
        if (strcmp(argv[i], "-c") == 0) { parms.curvedes = atol(argv[++i]); continue; }
        if (strcmp(argv[i], "-keytype") == 0) { parms.keytype = atol(argv[++i]); continue; }
        if (strcmp(argv[i], "-loop") == 0) { parms.loopcount = atol(argv[++i]); continue; }
        if (strcmp(argv[i], "-count") == 0) { parms.loopcount = atol(argv[++i]); continue; }

        printf("Unrecognized flag %s at %d\n", argv[i], i);
        goto syntax;
    }

    printf("Slot %ld\n", slotID);
    switch (testID) {
    /*
    case 0: // AES/ECC key wrapping
        printf("Curve switch %ld\n", parms.curvedes);
        break;
    */
    case 1: // hbs test suite - gen/get/sign/verify
        printf("As Token %ld\n", parms.astoken);
        break;
#ifdef IF_INCLUDE_LATTICE_IPD
    case 2: // dilithium test suite - gen/get/sign/verify
    case 3: // kyber test suite - gen/get/encap/decap
        printf("As Token %ld\n", parms.astoken);
        printf("Keytype %ld\n", parms.keytype);
        printf("Varient key type %ld\n", parms.r3_variant);
#endif
    case 4: // mldsa test suite gen/get/encap/decap
    case 5: // mlkem test suite gen/get/sign/verify
    case 6: // xmss test suite
    case 7: // hss
        printf("As Token %ld\n", parms.astoken);
        printf("Keytype %ld\n", parms.keytype);
        break;
    default:
        printf("Unrecognized -test option %ld\n", testID);
        printf("-test [id] or one of the -lms, -xmss, -mlkem, ...\n");
        // printf("  0 - key(EC) wrap(AES)\n");
        printf("  1 | -hbs\n");
#ifdef IF_INCLUDE_LATTICE_IPD
        printf("  2 | -dil\n");
        printf("  3 | -kyb\n");
#endif
        printf("  4 | -mldsa\n");
        printf("  5 | -mlkem\n");
        printf("  6 | -xmss\n");
        printf("  7 | -hss\n");
	printf("\n");
	help();
        return -1;
    }

    printf("Algorithm:  ");
    switch (algo) {
#ifdef IF_INCLUDE_LATTICE_IPD
    case DILITHIUM: printf("Dilithium\n"); break;
    case KYBER: printf("Kyber\n"); break;
#endif
    case XMSS: printf("XMSS\n"); break;
    case LMS: printf("LMS\n"); break;
    case HSS: printf("HSS\n"); break;
    case MLDSA:  printf("MLDSA\n");  break;
    case MLKEM: printf("MLKEM\n"); break;
    default:
        printf("Unrecognized.  Stopping.\n");
	help();
        return -1;
        break;
    }


#ifdef OSYS_win
    err = Initialize(&pFunctions, &hModule, lib);
#else
    err = Initialize(&pFunctions, &lib_handle, lib);
#endif
    if ((err != 0) || (pFunctions == NULL)) goto cleanup;
    printf("Token initialized.\n\n");


    // check for users on slot 0
    err = EnsureUserExistence(pFunctions, userPIN, slotID);
    if (err != 0) goto cleanup;

    // open session 
    err = pFunctions->C_OpenSession(slotID, CKF_SERIAL_SESSION | CKF_RW_SESSION, NULL, NULL, &hSession);
    if (err != CKR_OK)
    {
        printf("[main]: C_OpenSession returned 0x%08x\n", err);
        goto cleanup;
    }
    printf("\nOpened session on slot %lu.\n", slotID);

    // login as user
    err = pFunctions->C_Login(hSession, CKU_USER, (CK_UTF8CHAR_PTR)userPIN, lenUserPIN);
    if (err != CKR_OK)
    {
        printf("[main]: C_Login (USER) returned 0x%08x\n", err);
        goto cleanup;
    }
    printf("-> Normal user logged in.\n");
    //
    // END BOILERPLATE BEGIN
    //

    switch (algo) {
    case XMSS:
        err = test_case_xmss(pFunctions, hSession, &parms);
        break;
    case LMS:
        err = test_case_lms(pFunctions, hSession, &parms);
        break;
    case HSS:
        err = test_case_hss(pFunctions, hSession, &parms);
        break;
#ifdef IF_INCLUDE_LATTICE_IPD
    case DILITHIUM:
        err = test_case_dilithium(pFunctions, hSession, &parms);
        break;
    case KYBER:
        err = test_case_kyber(pFunctions, hSession, &parms);
        break;
#endif
    case MLDSA:
        err = test_case_mldsa(pFunctions, hSession, &parms);
        break;
    case MLKEM:
        err = test_case_mlkem(pFunctions, hSession, &parms);
        break;
    default:
        printf("whut.\n");
	help();
        return 0;
    }

    if (err != 0) {
        pkcs_err_text(err);
    }

    //
    // BEGIN BOILERPLATE END
    //
      // logout 
    err = pFunctions->C_Logout(hSession);
    if (err != CKR_OK)
    {
        printf("[main]: C_Logout (USER) returned 0x%08x\n", err);
        goto cleanup;
    }
    printf("-> Normal user logged out.\n");

    // close session
    err = pFunctions->C_CloseSession(hSession);
    if (err != CKR_OK)
    {
        printf("[main]: C_CloseSession returned 0x%08x\n", err);
        goto cleanup;
    }
    printf("Closed session on slot %lu.\n\n", slotID);

cleanup:

    if (err != 0 && hSession != CK_INVALID_HANDLE) {
        pFunctions->C_Logout(hSession);
        pFunctions->C_CloseSession(hSession);
    }
    if (pFunctions != NULL)
    {
        pFunctions->C_Finalize(NULL);
        printf("\nToken finalized.\n");
    }
#ifdef OSYS_win 
    if (hModule != NULL) FreeLibrary(hModule);
#else
    if (lib_handle != NULL) dlclose(lib_handle);
#endif

syntax:

    return err;
}

/******************************************************************************
 *
 * help
 *
 * Prints the command-line syntax understood by main() to stdout.
 * The Dilithium/Kyber entries are only printed when the program is built with
 * IF_INCLUDE_LATTICE_IPD, because main() only recognizes -dil/-kyber (and
 * their -aes/-90s modifiers, which must follow the algorithm flag) then.
 *
 ******************************************************************************/
void help(void) {
    printf("\nSYNTAX : <exe> -lib <library path>\\cs_pkcs11_r3.dll <arguments>\n\nArguments: ");
    printf("\n       : -l[ib] <library> (default: cs_pkcs11_r3.dll)");
    printf("\n       : -p[in] <pin> (default: 12345678)");
    printf("\n       : -s[lot] <slot> (default: 0)");
    printf("\n       : -token | -tkn (default: off)");
    printf("\n       : -h (print this help)");
    printf("\n");
    printf("\n       : -lms");
    printf("\n       : -hss");
    printf("\n       : -xmss");
#ifdef IF_INCLUDE_LATTICE_IPD
    printf("\n       : -dil -keytype [1|3|5] [-aes] [-gen]* [-sign]* [-verify]* [-signverify] -count <n>");
    printf("\n       : -kyber -keytype [1|3|5] [-90s] [-gen]* [-encap]* [-decap]* [-encapdecap] -count <n>");
#endif
    printf("\n       : -mldsa -keytype [2|3|5] [-gen]* [-sign]* [-verify]* [-signverify] -count <n>");
    printf("\n       : -mlkem -keytype [2|3|5] [-gen]* [-encap]* [-decap]* [-encapdecap] -count <n>");
    printf("\n");
    printf("\n       : -all - selects the flags marked by * for the selected algorithm. ");
    printf("\n");
}
//
// END BOILERPLATE END
//


