package com.utimaco.aut;

import java.nio.ByteBuffer;
import java.nio.charset.Charset;

import com.utimaco.SimpleArgs;
import com.utimaco.cs2.mdl.ErrList;
import com.utimaco.cs2.mdl.SdkBuffer;
import com.utimaco.cs2.mdl.SerializationError;
import com.utimaco.cs2.mdl.any.CxiKeyAttributes;
import com.utimaco.cs2.mdl.pqmi.MLKEM_Decap;
import com.utimaco.cs2.mdl.pqmi.MLKEM_Encap;
import com.utimaco.cs2.mdl.pqmi.MLKEM_KeyGen;
import com.utimaco.cs2.mdl.pqmi.MLKEM_Response;
import com.utimaco.cs2.mdl.pqmi.PqmiKeyStore;
import com.utimaco.cs2.mdl.pqmi.Pqmi;

import CryptoServerAPI.CryptoServerException;
import CryptoServerAPI.CryptoServerUtil;
import CryptoServerCXI.CryptoServerCXI;

/**
 * Automated test of the PQMI module's ML-KEM commands (key generation, encapsulation and
 * decapsulation) executed through a {@link CryptoServerCXI} connection.
 */
public class MLKEMTest implements AutTest {
	/**
	 * PQMI module id. 0xa6 is only the initial value; every constructor call replaces it with
	 * {@code AutTest.getModuleId(cxi, "PQMI")}.
	 * NOTE(review): this is a static field written from an instance constructor, so it is shared
	 * by all instances.
	 */
	static int module_id = 0xa6;

	static final int sfc_id_gen_kyb = Pqmi.SFC_KYBER_KEYGEN;
	static final int sfc_id_enc_kyb = Pqmi.SFC_KYBER_ENC;
	static final int sfc_id_dec_kyb = Pqmi.SFC_KYBER_DEC;

	/** Sub-function code used for PqmiKeyStore commands (fetching a stored key back). */
	static final int sfc_id_keystore = Pqmi.SFC_KEYSTORE;

	/** Connection every command is executed on. */
	CryptoServerCXI cxi;
	/** Parsed command-line arguments. */
	SimpleArgs args;
	/** Always false; NOTE(review): not read in this class (genencdec takes its own ctors argument). */
	boolean ctors;
	/** Randomness mode from -drbg:trng. NOTE(review): computed but not read by go() in this class. */
	int rnd_touse;
	/** Storage flags from -intks / -overwrite. NOTE(review): computed but not read by go() in this class. */
	int sto_touse;
	/** NOTE(review): stored but not read in this class. */
	boolean asynch;
	/** When true, genencdec returns right after the key-generation stage. */
	boolean genkeyonly;

	public MLKEMTest(CryptoServerCXI cxi, SimpleArgs cla) {
		this(cxi, cla, false);
	}

	public MLKEMTest(CryptoServerCXI cxi, SimpleArgs cla, boolean asynch) {
		this(cxi, cla, asynch, false);
	}

	/**
	 * @param cxi    connection to run the commands on
	 * @param cla    command-line arguments (-drbg:trng, -intks, -overwrite, -keytype, -suite, -full)
	 * @param asynch stored in {@link #asynch}
	 * @param genkey if true, only the key-generation stage is executed
	 */
	public MLKEMTest(CryptoServerCXI cxi, SimpleArgs cla, boolean asynch, boolean genkey) {
		this.cxi = cxi;
		this.args = cla;
		this.ctors = false;
		this.rnd_touse = cla.hasArg("drbg:trng") ? Pqmi.MODE_REAL_RND : Pqmi.MODE_PSEUDO_RND;
		this.sto_touse = cla.hasArg("intks") ? 0 : Pqmi.KEY_EXTERNAL;
		this.sto_touse |= cla.hasArg("overwrite") ? Pqmi.KEY_OVERWRITE : 0;
		this.asynch = asynch;
		this.genkeyonly = genkey;
		if (genkeyonly) {
			this.sto_touse = Pqmi.KEY_OVERWRITE;
		}

		module_id = AutTest.getModuleId(cxi, "PQMI");
	}

	/**
	 * Runs the ML-KEM test(s) selected by the command-line arguments.
	 *
	 * <p>Without -suite, one round trip is run for the parameter set given by -keytype; the
	 * process exits with status -1 if -keytype is missing or unrecognized. With -suite, all three
	 * parameter sets are tested, and -full adds the DRBG and internal-storage variants.
	 *
	 * NOTE(review): the single-test path always uses MODE_PSEUDO_RND / KEY_EXTERNAL and ignores
	 * rnd_touse / sto_touse computed in the constructor — confirm this is intended.
	 *
	 * @return true if every executed test passed
	 */
	public boolean go() {
		if (!args.hasArg("suite")) {
			int keytype = Pqmi.ML_KEM_KT_FLAG;
			if (args.hasArg("keytype")) {
				keytype |= toMlkemParameterSet(args.getArg("keytype", 0));
			}
			if (keytype == Pqmi.ML_KEM_KT_FLAG) {
				System.err.println("-keytype [2|3|5] (security level) or -keytype [512|768|1024] (key size in bits) required for -mlkem");
				System.exit(-1);
			}

			AutTest.announce(String.format("MLKEM-PSEU-EXT-%d-%d", keytype, 0));
			return genencdec(false, Pqmi.MODE_PSEUDO_RND, Pqmi.KEY_EXTERNAL, keytype, 0);
		}

		boolean allPassed = true;
		int[] levels = {2, 3, 5};
		for (int i = 0; i < levels.length; i++) {
			int keytype = toMlkemParameterSet(levels[i]) | Pqmi.ML_KEM_KT_FLAG;

			// The basic external-key case gates the remaining variants for this parameter set.
			if (!runMlkemCase("MLKEM-PSEU-EXT", Pqmi.MODE_PSEUDO_RND, Pqmi.KEY_EXTERNAL, keytype, i)) {
				allPassed = false;
				continue;
			}

			if (args.hasArg("full")) {
				// Non-short-circuit &= so every variant runs even after a failure.
				allPassed &= runMlkemCase("MLKEM-DRBG-EXT", Pqmi.MODE_REAL_RND, Pqmi.KEY_EXTERNAL, keytype, i);
				allPassed &= runMlkemCase("MLKEM-PSEU-INT", Pqmi.MODE_PSEUDO_RND, Pqmi.KEY_OVERWRITE, keytype, i);
				allPassed &= runMlkemCase("MLKEM-DRBG-INT", Pqmi.MODE_REAL_RND, Pqmi.KEY_OVERWRITE, keytype, i);
			}
		}
		return allPassed;
	}

	/**
	 * Maps a -keytype value (security level 2/3/5 or size 512/768/1024) to the matching Pqmi
	 * parameter-set constant, or 0 when the value is not recognized.
	 */
	private static int toMlkemParameterSet(int kt) {
		switch (kt) {
		case 2:
		case 512:
			return Pqmi.ML_KEM_KT_512;
		case 3:
		case 768:
			return Pqmi.ML_KEM_KT_768;
		case 5:
		case 1024:
			return Pqmi.ML_KEM_KT_1024;
		default:
			return 0;
		}
	}

	/** Announces "label-keytype-index", runs one round trip and announces Complete/Failed. */
	private boolean runMlkemCase(String label, int rnd, int sto, int keytype, int index) {
		AutTest.announce(String.format("%s-%d-%d", label, keytype, index));
		boolean ok = genencdec(false, rnd, sto, keytype, 0);
		AutTest.announce(ok ? "Complete\n" : "Failed\n");
		return ok;
	}

	/**
	 * Extracts the public key record from a key-generation or key-store response.
	 *
	 * <p>The record is located by searching for the ASCII tag "PK". It is laid out as 'P', 'K',
	 * a 4-byte length read with {@link ByteBuffer#getInt()}, and then that many key bytes.
	 *
	 * @param aliceprivatekey response buffer to search
	 * @param noheader        true to return only the raw key bytes; false to return the whole
	 *                        PK&lt;len&gt;&lt;key&gt; record
	 * @return the requested bytes, or null if no well-formed record is found
	 */
	private byte [] getPublicKey(byte[] aliceprivatekey, boolean noheader) {
		ByteBuffer bb = SdkBuffer.bbseek(aliceprivatekey, "PK".getBytes(Charset.forName("US-ASCII")), 0);
		if (bb == null) {
			System.err.println("Internal err - getPublicKey didn't find a public key in the response\n");
			return null;
		}
		// Tag (2 bytes) + length (4 bytes) must be present before reading them.
		if (bb.remaining() < 6) {
			return null;
		}
		if (bb.get() != (byte)'P' || bb.get() != (byte)'K') {
			return null;
		}
		int l_pk = bb.getInt(); // len of PK
		// Reject a length that is negative or runs past the end of the response.
		if (l_pk < 0 || l_pk > bb.remaining()) {
			return null;
		}
		byte [] retbuf = new byte[l_pk];
		bb.get(retbuf);
		if (noheader) {
			return retbuf;
		}
		// Rebuild the record using the byte order the length was read with, so the header bytes
		// match the response rather than always being big-endian.
		ByteBuffer bb2 = ByteBuffer.allocate(6 + l_pk).order(bb.order());
		bb2.put((byte)'P').put((byte)'K').putInt(l_pk).put(retbuf);
		return bb2.array();
	}

	/**
	 * Prints a CryptoServer error code to stderr. The code is followed by either the exception's
	 * localized message or the text from {@code ErrList.errtext}, depending on the code range.
	 */
	private static void printCryptoServerError(CryptoServerException e) {
		System.err.format("Error 0x%08X\n", e.ErrorCode);
		// NOTE(review): this is a signed int comparison (0xB1000000 is negative as a Java int).
		if ((e.ErrorCode & (int)0xBFFF0000) < 0xB1000000) {
			System.err.print(e.getLocalizedMessage());
		} else {
			System.err.println(ErrList.errtext(e.ErrorCode));
		}
	}

	/**
	 * Runs one ML-KEM round trip on the module.
	 *
	 * <p>The round trip has four stages:
	 * <ol>
	 * <li>Generate a key pair named "Alice".</li>
	 * <li>Encapsulate a secret against her public key (Bob's side).</li>
	 * <li>Decapsulate Bob's cyphertext with her private key.</li>
	 * <li>Check that both sides obtained the same secret.</li>
	 * </ol>
	 * A stage whose {@code exec()} returns false, or that throws, fails the whole run.
	 *
	 * @param ctors   true to build command objects with their all-argument constructors; false to
	 *                use the {@code (cxi)} constructor followed by setters
	 * @param rnd     randomness mode flag (callers pass Pqmi.MODE_PSEUDO_RND or Pqmi.MODE_REAL_RND)
	 * @param sto     storage flag. Pqmi.KEY_EXTERNAL takes Alice's key from the key-generation
	 *                response. Any other value fetches it back with PqmiKeyStore
	 *                (KEY_GET_PUBLIC_KEY, then KEY_LOAD_FROM_STORE).
	 * @param keytype ML-KEM key type (Pqmi.ML_KEM_KT_FLAG OR-ed with a parameter-set constant)
	 * @param spec    value for the key attributes' spec field
	 * @return true if all stages passed. When genkeyonly is set, true if key generation passed.
	 */
	boolean genencdec (boolean ctors, int rnd, int sto, int keytype, int spec) {
		int mlkemkeytype = keytype;

		// CxiKeyAttributes - typed field attributes
		int flags = 0x0;
		byte [] name = "placeholder".getBytes();
		byte [] group = "testseries".getBytes();
		int algo = CryptoServerCXI.KEY_ALGO_RAW;
		int usage = CryptoServerCXI.KEY_USAGE_KEY_WRAP;
		int export = 0; // ALLOW_BACKUP only, no export wrapped

		CxiKeyAttributes attributes;
		if (ctors) {
			attributes = new CxiKeyAttributes(flags, name, group, spec, algo, usage, export, new byte[0]);
		} else {
			attributes = new CxiKeyAttributes(cxi);
			attributes.flags(flags);
			attributes.name(name);
			attributes.group(group);
			attributes.spec(spec);
			attributes.algo(algo);
			attributes.usage(usage);
			attributes.export(export);
		}

		// Non-external storage always adds KEY_OVERWRITE (presumably so an existing "Alice" key
		// can be replaced).
		int outerflags = rnd | sto | (sto == Pqmi.KEY_EXTERNAL ? 0 : Pqmi.KEY_OVERWRITE);

		System.out.format("Keytype is %d (0x%02x), flags is 0x%08x\n", mlkemkeytype, keytype, outerflags);

		MLKEM_KeyGen tobj_ctor;
		try {
			if (ctors) {
				tobj_ctor = new MLKEM_KeyGen(outerflags, mlkemkeytype, attributes, new byte[0]);
			} else {
				tobj_ctor = new MLKEM_KeyGen(cxi);
				tobj_ctor.flags(outerflags);
				tobj_ctor.type(mlkemkeytype);
				tobj_ctor.attributes(attributes);
			}
		} catch (SerializationError ser) {
			System.out.println("Serialization error on typed parameter in KyberGen.\n");
			return false;
		}

		/* ************************************* */
		/* 1. Alice generates her key pair and sends her public key to Bob.
		 * 2. Bob uses Alice's public key to create a response, which includes his shared secret.
		 * 3. Alice uses Bob's response and her private key to get her copy of the shared secret.
		 */
		byte[] aliceprivatekey = null;
		byte[] alicepublickeyTrue = null;  // raw public key bytes
		byte[] alicepublickeyFalse = null; // PK<len><rawkey>; exercises a different code path

		System.out.print("GEN STAGE:      ");
		try {
			attributes.name("Alice"); // set the name
			tobj_ctor.attributes(attributes); // serialize the data

			// Generate Alice's key
			boolean res = tobj_ctor.exec(cxi, module_id, sfc_id_gen_kyb);
			System.out.println(res ? "PASSED" : "FAILED");
			if (!res) {
				return false;
			}

			if (sto == Pqmi.KEY_EXTERNAL) {
				aliceprivatekey = tobj_ctor.resp.clone();
				alicepublickeyTrue = getPublicKey(aliceprivatekey, true);
				alicepublickeyFalse = getPublicKey(aliceprivatekey, false);
			} else {
				System.out.print("FIND STAGE:    ");
				// get the public key ...
				PqmiKeyStore aliceKey = new PqmiKeyStore(Pqmi.KEY_GET_PUBLIC_KEY, mlkemkeytype, attributes);
				res = aliceKey.exec(cxi, module_id, sfc_id_keystore);
				System.out.print(res ? " [1]PASSED" : " [1]FAILED");
				if (!res) {
					System.out.println();
					return false;
				}
				alicepublickeyTrue = getPublicKey(aliceKey.resp, true);
				alicepublickeyFalse = getPublicKey(aliceKey.resp, false);

				// ... and the private key
				aliceKey = new PqmiKeyStore(Pqmi.KEY_LOAD_FROM_STORE, mlkemkeytype, attributes);
				res = aliceKey.exec(cxi, module_id, sfc_id_keystore);
				System.out.println(res ? " [2]PASSED" : " [2]FAILED");
				if (!res) {
					return false;
				}
				aliceprivatekey = aliceKey.resp.clone();
			}
		} catch (CryptoServerException e) {
			printCryptoServerError(e);
			return false;
		} catch (Exception e) {
			e.printStackTrace();
			return false;
		}

		if (alicepublickeyTrue == null || alicepublickeyFalse == null) {
			System.err.println("Could not extract Alice's public key from the key generation response");
			return false;
		}

		if (genkeyonly) {
			return true;
		}

		// Alice sends her public key to Bob, and Bob runs the encapsulation.
		MLKEM_Response tobj_encResponse = null;
		System.out.print("ENCAP STAGE:    ");
		try {
			MLKEM_Encap tobj_kemenc = new MLKEM_Encap(rnd, mlkemkeytype, alicepublickeyTrue);
			// Pass 1 - Bob with Alice's public key
			boolean res = tobj_kemenc.exec(cxi, module_id, sfc_id_enc_kyb);
			System.out.println(res ? "PASSED" : "FAILED");
			if (!res) {
				return false;
			}
			tobj_encResponse = new MLKEM_Response(tobj_kemenc.resp);
			// The command also accepts the header form: 'true' is just a raw public key,
			// 'false' is PK<len><rawkey>. This only sets the field; the command is not re-executed.
			tobj_kemenc.publickey(alicepublickeyFalse);
		} catch (CryptoServerException e) {
			printCryptoServerError(e);
			return false;
		} catch (Exception e) {
			e.printStackTrace();
			return false;
		}

		// Bob keeps tobj_encResponse.secret() and sends tobj_encResponse.cyphertext() to Alice,
		// who decapsulates it with her private key.
		MLKEM_Response tobj_decResponse = null;
		System.out.print("DECAP STAGE:    ");
		try {
			MLKEM_Decap tobj_kemdec = new MLKEM_Decap(rnd, mlkemkeytype, aliceprivatekey, tobj_encResponse.cyphertext());
			// Pass 2 - Alice with Alice's private key and Bob's cyphertext
			boolean res = tobj_kemdec.exec(cxi, module_id, sfc_id_dec_kyb);
			System.out.println(res ? "PASSED" : "FAILED");
			if (!res) {
				return false;
			}
			tobj_decResponse = new MLKEM_Response(tobj_kemdec.resp);
		} catch (CryptoServerException e) {
			printCryptoServerError(e);
			return false;
		} catch (Exception e) {
			e.printStackTrace();
			return false;
		}

		// For now this assumes plain-text secrets; KB and KH are not yet enabled.
		System.out.print("COMPARE STAGE:  ");
		byte [] encSecret = tobj_encResponse.secret();
		byte [] decSecret = tobj_decResponse.secret();
		CryptoServerUtil.xtrace("ENC SEC", encSecret);
		CryptoServerUtil.xtrace("DEC SEC", decSecret);
		if (encSecret.length != decSecret.length) {
			System.out.println("Lengths differ - failed\n");
			return false;
		}
		for (int i = 0; i < encSecret.length; i++) {
			if (encSecret[i] != decSecret[i]) {
				System.out.println("Secrets differ - failed\n");
				return false;
			}
		}
		System.out.println("Compared correctly");
		return true;
	}
}
