package com.utimaco.bench;

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicBoolean;

import com.utimaco.Benchmarks;
import com.utimaco.cs2.mdl.ErrList;
import com.utimaco.cs2.mdl.SdkBuffer;
import com.utimaco.cs2.mdl.SerializationError;
import com.utimaco.cs2.mdl.Utils;
import com.utimaco.cs2.mdl.any.CxiKeyAttributes;
import com.utimaco.cs2.mdl.pqmi.MLDSA_KeyGen;
import com.utimaco.cs2.mdl.pqmi.MLDSA_Sign;
import com.utimaco.cs2.mdl.pqmi.MLDSA_Verify;
import com.utimaco.cs2.mdl.pqmi.PqmiKeyStore;
import com.utimaco.cs2.mdl.pqmi.Pqmi;

import CryptoServerAPI.CryptoServerException;
import CryptoServerCXI.CryptoServerCXI;

public class MLDSABench extends Bench {
	
	/* MLDSA HELPERS */
	/**
	 * Resolves the ML-DSA key type requested on the command line.
	 * <p>
	 * The value is read from the {@code keytype} argument, falling back to
	 * {@code kt} when {@code keytype} is absent. It may be given either as a
	 * security level (2, 3 or 5) or as the matching parameter-set number
	 * (44, 65 or 87). Any other value prints a usage message and terminates
	 * the JVM with exit status -1.
	 *
	 * @param args command-line arguments to search
	 * @return {@code Pqmi.ML_DSA_KT_FLAG} OR-ed with the selected
	 *         {@code Pqmi.ML_DSA_KT_*} constant
	 */
	public static int getMldsaKeytype(String [] args) {
		int kt = Benchmarks.getArg(args, "keytype", -1);
		if (kt == -1) kt = Benchmarks.getArg(args, "kt", -1);

		int keytype = Pqmi.ML_DSA_KT_FLAG;

		switch (kt) {
		case 44:
		case 2: keytype |= Pqmi.ML_DSA_KT_44; break;
		case 65:
		case 3: keytype |= Pqmi.ML_DSA_KT_65; break;
		case 87:
		case 5: keytype |= Pqmi.ML_DSA_KT_87; break;
		default:
			// println rather than printf: the two halves of the usage text used to
			// run together on one line with no trailing newline before the exit.
			System.out.println("Invalid MLDSA key type, please supply -kt 2 3 or 5 (Security Levels)");
			System.out.println("or -kt 44 65 or 87 (Lattice dimensions, ie 4x4 6x5 and 8x7)");
			System.exit(-1);
		}

		return keytype;
	}
	
	/**
	 * Announces an ML-DSA benchmark run under a label built from the RNG mode
	 * and the key dimensions.
	 * <p>
	 * The label has the form {@code MLDSA-<REAL|PSEU>-EXT-<4X4|6X5|8X7>}. If
	 * {@code keytype} (with {@code Pqmi.ML_DSA_KT_FLAG} masked off) is not one
	 * of the three known key types, an error is printed and the JVM exits
	 * with status -1.
	 *
	 * @param rng     {@code true} selects the "REAL" label, {@code false} "PSEU"
	 * @param keytype ML-DSA key type, with or without {@code Pqmi.ML_DSA_KT_FLAG}
	 */
	public static void announce(boolean rng, int keytype) {
		final int variant = keytype & ~Pqmi.ML_DSA_KT_FLAG;

		String dims = "";
		if (variant == Pqmi.ML_DSA_KT_44) {
			dims = "4X4";
		} else if (variant == Pqmi.ML_DSA_KT_65) {
			dims = "6X5";
		} else if (variant == Pqmi.ML_DSA_KT_87) {
			dims = "8X7";
		} else {
			System.out.printf("Invalid MLDSA key type.");
			System.exit(-1);
		}

		final String rngLabel = rng ? "REAL" : "PSEU";
		announce(String.format("MLDSA-%s-EXT-%s", rngLabel, dims));
	}
	
	/**
	 * Benchmarks ML-DSA signature verification on concurrent workers.
	 * <p>
	 * Each worker prepares and generates its own key, signs the message once
	 * (untimed) and then times {@code countperthread} verify calls against that
	 * signature. Workers are assigned to the supplied connections round-robin.
	 * The calling thread polls once per second until every scheduled worker has
	 * finished, then prints the sum of the per-worker throughput figures.
	 * <p>
	 * A worker that cannot be set up (no message supplied, or the key generation
	 * request could not be built) is reported and not scheduled.
	 *
	 * @param args           command-line arguments (message, key type, key options)
	 * @param cxi            connections to spread the workers over; must not be empty
	 * @param threads        number of workers to start
	 * @param countperthread number of verify calls each worker performs
	 * @return a flag that has been set to {@code true} by the time this method returns
	 */
	public static AtomicBoolean mldsa_verifying (String [] args, ArrayList<CryptoServerCXI> cxi, int threads, int countperthread) {
		AtomicBoolean outerdone = new AtomicBoolean(false);

		class Runt implements Runnable {
			final MLDSA_Sign signer = new MLDSA_Sign();
			final MLDSA_Verify verfer = new MLDSA_Verify();
			final AtomicBoolean done = new AtomicBoolean(false);
			final CryptoServerCXI cxi;

			double duration = 0.0;
			int counter = 0;
			// True only when the constructor fully configured signer and verfer.
			boolean ready = false;

			Runt (CryptoServerCXI cxi, int instance) throws SerializationError {
				this.cxi = cxi;

				// Check the message first so that a worker which cannot run does
				// not execute a key generation beforehand.
				byte [] msga = getMsg(args);
				if (msga == null) {
					System.out.format("Expected a -msg <input> or -msg <hexenc> or -msg <b64enc> or -artifact <filename>\n");
					return;
				}

				MLDSA_KeyGen gner = mldsa_generate_prep(args, cxi, instance);
				if (gner == null) {
					// mldsa_generate_prep has already printed the reason.
					return;
				}
				System.out.println("Generating key...");
				mldsa_generate(cxi, gner);

				int flags = Pqmi.DILITHIUM_MODE_SIG_RAW;
				int keytype = getMldsaKeytype(args);

				signer.flags(flags);
				signer.ctxt(new byte[0]);
				signer.key(gner.resp);
				signer.type(keytype);
				signer.msg(msga);

				verfer.flags(flags);
				verfer.ctxt(new byte[0]);
				verfer.key(gner.resp);
				verfer.type(keytype);
				verfer.msg(msga);

				ready = true;
			}

			@Override
			public void run() {
				// The finally block guarantees 'done' is set on every exit path;
				// otherwise a failed signing step would leave the caller polling forever.
				try {
					try {
						signer.exec(cxi, module_id, sfc_id_sign_dil);
						verfer.sig(signer.resp);
					} catch (IOException e1) {
						e1.printStackTrace();
						return;
					} catch (CryptoServerException e1) {
						e1.printStackTrace();
						return;
					}

					Instant start = timehack();
					try {
						for (int i = 0; i < countperthread; i++) {
							verfer.exec(cxi, module_id, sfc_id_verf_dil);
							// Count only calls that returned, so a failing call
							// does not inflate the reported throughput.
							counter++;
						}
					} catch (CryptoServerException e) {
						System.err.println(e);
					} catch (Exception e) {
						e.printStackTrace();
					}
					Instant end = timehack();
					// NOTE(review): the label is spelled "mdlsa"; kept as-is in case
					// consumers of the report key on it.
					duration = report(args, start, end, countperthread, counter, "mdlsa_verify");
				} finally {
					done.set(true);
				}
			}
			public void go() { Benchmarks.xekr.execute(this); }
			public boolean isDone() { return done.get(); }
			public double getTPS() {
				if (duration == 0) return 0.0; else return counter / duration;
			}
		}

		ArrayList<Runt> interns = new ArrayList<>();
		for (int i = 0; i < threads; i++) {
			// Round-robin over the available connections.
			CryptoServerCXI tcxi = cxi.get(i % cxi.size());
			try {
				Runt r = new Runt(tcxi, i);
				if (r.ready) {
					interns.add(r);
					r.go();
				}
			} catch (SerializationError se) {
				System.err.println("Skipping verify worker " + i + ": " + se);
			}
		}

		// Keep waiting across interrupts (the workers are still running), but
		// remember the interrupt and restore it for the caller afterwards.
		boolean interrupted = false;
		do {
			boolean chk = true;
			try {
				Thread.sleep(1000);
			} catch (InterruptedException ie) {
				interrupted = true;
			}
			for (Runt r : interns) {
				if (!r.isDone()) {
					chk = false;
					break;
				}
			}
			outerdone.set(chk);
		} while (!outerdone.get());
		if (interrupted) {
			Thread.currentThread().interrupt();
		}

		double tps = 0.0;
		for (Runt r : interns) {
			tps += r.getTPS();
		}

		int kt = getMldsaKeytype(args);
		kt &= ~Pqmi.ML_DSA_KT_FLAG;

		System.out.format("MLDSA VERF KT:%d Th:%d C:%d Aggregate TPS %.04f\n\n", kt, threads, countperthread, tps);
		return outerdone;
	}

	/**
	 * Benchmarks ML-DSA signing on concurrent workers.
	 * <p>
	 * Each worker prepares and generates its own key, serializes one signing
	 * request and then times {@code countperthread} submissions of that same
	 * serialized request. Workers are assigned to the supplied connections
	 * round-robin. The calling thread polls once per second until every
	 * scheduled worker has finished, then prints the sum of the per-worker
	 * throughput figures.
	 * <p>
	 * A worker that cannot be set up (no message supplied, or the key generation
	 * request could not be built) is reported and not scheduled.
	 *
	 * @param args           command-line arguments (message, key type, key options)
	 * @param cxi            connections to spread the workers over; must not be empty
	 * @param threads        number of workers to start
	 * @param countperthread number of signing calls each worker performs
	 * @return a flag that has been set to {@code true} by the time this method returns
	 */
	public static AtomicBoolean mldsa_signing (String [] args, ArrayList<CryptoServerCXI> cxi, int threads, int countperthread) {
		AtomicBoolean outerdone = new AtomicBoolean(false);

		class Runt implements Runnable {
			final MLDSA_Sign signer = new MLDSA_Sign();
			final AtomicBoolean done = new AtomicBoolean(false);
			final CryptoServerCXI cxi;

			double duration = 0.0;
			int counter = 0;
			// True only when the constructor fully configured the signer.
			boolean ready = false;

			Runt (CryptoServerCXI cxi, int instance) throws SerializationError {
				this.cxi = cxi;

				// Check the message first so that a worker which cannot run does
				// not execute a key generation beforehand.
				byte [] msga = getMsg(args);
				if (msga == null) {
					System.out.format("Expected a -msg <input> or -msg <hexenc> or -msg <b64enc> or -artifact <filename>\n");
					return;
				}

				MLDSA_KeyGen gner = mldsa_generate_prep(args, cxi, instance);
				if (gner == null) {
					// mldsa_generate_prep has already printed the reason.
					return;
				}
				System.out.println("Generating key...");
				mldsa_generate(cxi, gner);

				signer.flags(Pqmi.DILITHIUM_MODE_SIG_RAW);
				signer.ctxt(new byte[0]);
				signer.key(gner.resp);
				signer.type(getMldsaKeytype(args));
				signer.msg(msga);

				ready = true;
			}

			@Override
			public void run() {
				// The finally block guarantees 'done' is set on every exit path;
				// otherwise a serialization failure would leave the caller polling forever.
				try {
					byte[] serialized;
					try {
						serialized = signer.serialize();
					} catch (SerializationError e) {
						e.printStackTrace();
						return;
					}

					Instant start = timehack();
					try {
						for (int i = 0; i < countperthread; i++) {
							cxi.exec(module_id, sfc_id_sign_dil, serialized);
							// Count only calls that returned, so a failing call
							// does not inflate the reported throughput.
							counter++;
						}
					} catch (CryptoServerException e) {
						System.err.println(e);
					} catch (Exception e) {
						e.printStackTrace();
					}
					Instant end = timehack();
					// NOTE(review): the label is spelled "mdlsa"; kept as-is in case
					// consumers of the report key on it.
					duration = report(args, start, end, countperthread, counter, "mdlsa_signing");
				} finally {
					done.set(true);
				}
			}
			public void go() { Benchmarks.xekr.execute(this); }
			public boolean isDone() { return done.get(); }
			public double getTPS() {
				if (duration == 0) return 0.0; else return counter / duration;
			}
		}

		ArrayList<Runt> interns = new ArrayList<>();
		for (int i = 0; i < threads; i++) {
			// Round-robin over the available connections.
			CryptoServerCXI tcxi = cxi.get(i % cxi.size());
			try {
				Runt r = new Runt(tcxi, i);
				if (r.ready) {
					interns.add(r);
					r.go();
				}
			} catch (SerializationError se) {
				System.err.println("Skipping signing worker " + i + ": " + se);
			}
		}

		// Keep waiting across interrupts (the workers are still running), but
		// remember the interrupt and restore it for the caller afterwards.
		boolean interrupted = false;
		do {
			boolean chk = true;
			try {
				Thread.sleep(1000);
			} catch (InterruptedException ie) {
				interrupted = true;
			}
			for (Runt r : interns) {
				if (!r.isDone()) {
					chk = false;
					break;
				}
			}
			outerdone.set(chk);
		} while (!outerdone.get());
		if (interrupted) {
			Thread.currentThread().interrupt();
		}

		double tps = 0.0;
		for (Runt r : interns) {
			tps += r.getTPS();
		}

		int kt = getMldsaKeytype(args);
		kt &= ~Pqmi.ML_DSA_KT_FLAG;

		System.out.format("MLDSA SIGN KT:%d Th:%d C:%d Aggregate TPS %.04f\n\n", kt, threads, countperthread, tps);
		return outerdone;
	}

	// 
	/**
	 * Benchmarks ML-DSA key generation on concurrent workers.
	 * <p>
	 * Each worker serializes one key generation request up front and then times
	 * {@code countperthread} submissions of that same serialized request.
	 * Workers are assigned to the supplied connections round-robin. The calling
	 * thread polls once per second until every scheduled worker has finished,
	 * then prints the sum of the per-worker throughput figures.
	 * <p>
	 * A worker whose request could not be built or serialized is reported and
	 * not scheduled.
	 *
	 * @param args           command-line arguments (key type, key options)
	 * @param cxi            connections to spread the workers over; must not be empty
	 * @param threads        number of workers to start
	 * @param countperthread number of key generation calls each worker performs
	 * @return a flag that has been set to {@code true} by the time this method returns
	 */
	public static AtomicBoolean mldsa_genkey (String [] args, ArrayList<CryptoServerCXI> cxi, int threads, int countperthread) {
		AtomicBoolean outerdone = new AtomicBoolean(false);
		class Runt implements Runnable {
			final AtomicBoolean done = new AtomicBoolean(false);
			final CryptoServerCXI cxi;
			// Serialized request; null when the request could not be built.
			final byte [] gner;

			double duration = 0.0;
			int counter = 0;

			Runt (CryptoServerCXI cxi, int instance) throws SerializationError {
				this.cxi = cxi;
				// mldsa_generate_prep returns null (after printing the reason) when
				// the request cannot be constructed; do not dereference it blindly.
				MLDSA_KeyGen prepared = mldsa_generate_prep(args, cxi, instance);
				this.gner = (prepared == null) ? null : prepared.serialize();
			}
			@Override
			public void run() {
				// The finally block guarantees 'done' is set on every exit path,
				// so the caller's polling loop always terminates.
				try {
					Instant start = timehack();
					try {
						for (int i = 0; i < countperthread; i++) {
							cxi.exec(module_id, sfc_id_gen_dil, gner);
							// Count only calls that returned, so a failing call
							// does not inflate the reported throughput.
							counter++;
						}
					} catch (CryptoServerException e) {
						System.err.println(e);
					} catch (Exception e) {
						e.printStackTrace();
					}
					Instant end = timehack();
					// NOTE(review): the label is spelled "mdlsa"; kept as-is in case
					// consumers of the report key on it.
					duration = report(args, start, end, countperthread, counter, "mdlsa_genkey");
				} finally {
					done.set(true);
				}
			}
			public void go() { Benchmarks.xekr.execute(this); }
			public boolean isDone() { return done.get(); }
			public double getTPS() {
				if (duration == 0) return 0.0; else return counter / duration;
			}
		}

		ArrayList<Runt> interns = new ArrayList<>();
		for (int i = 0; i < threads; i++) {
			// Round-robin over the available connections.
			CryptoServerCXI tcxi = cxi.get(i % cxi.size());
			try {
				Runt r = new Runt(tcxi, i);
				if (r.gner != null) {
					interns.add(r);
					r.go();
				}
			} catch (SerializationError se) {
				System.err.println("Skipping keygen worker " + i + ": " + se);
			}
		}

		// Keep waiting across interrupts (the workers are still running), but
		// remember the interrupt and restore it for the caller afterwards.
		boolean interrupted = false;
		do {
			boolean chk = true;
			try {
				Thread.sleep(1000);
			} catch (InterruptedException ie) {
				interrupted = true;
			}
			for (Runt r : interns) {
				if (!r.isDone()) {
					chk = false;
					break;
				}
			}
			outerdone.set(chk);
		} while (!outerdone.get());
		if (interrupted) {
			Thread.currentThread().interrupt();
		}

		double tps = 0.0;
		for (Runt r : interns) {
			tps += r.getTPS();
		}

		int kt = getMldsaKeytype(args);
		kt &= ~Pqmi.ML_DSA_KT_FLAG;

		System.out.format("MLDSA GENK KT:%d Th:%d C:%d Aggregate %d TPS %.04f\n\n", kt, threads, countperthread, threads * countperthread, tps);

		return outerdone;
	}

	/**
	 * Builds, without executing, an ML-DSA key generation request from the
	 * command-line options.
	 * <p>
	 * Options read: {@code kname} (default "mldsa") and {@code kgrp} (default
	 * "testseries") for the key name and group; {@code ksinternal} to omit
	 * {@code Pqmi.KEY_EXTERNAL}; {@code ksoverwrite} to add
	 * {@code Pqmi.KEY_OVERWRITE} (only honoured together with
	 * {@code ksinternal}); {@code rng} to select {@code Pqmi.MODE_REAL_RND}
	 * instead of {@code Pqmi.MODE_PSEUDO_RND}. The key type comes from
	 * {@link #getMldsaKeytype(String[])}, which exits the JVM on an invalid value.
	 *
	 * @param args     command-line arguments
	 * @param cxi      not used by this method
	 * @param instance value stored as the key attributes' spec
	 * @return the prepared request, or {@code null} if constructing it raised a
	 *         {@code SerializationError} (a message is printed in that case)
	 */
	static private MLDSA_KeyGen mldsa_generate_prep(String [] args, CryptoServerCXI cxi, int instance) {
		// Encode with a fixed charset so the bytes sent for a given name do not
		// depend on the platform default encoding.
		final java.nio.charset.Charset utf8 = java.nio.charset.StandardCharsets.UTF_8;

		int flags = 0x0;
		byte [] name = Benchmarks.getArg(args, "kname", "mldsa").getBytes(utf8);
		byte [] group = Benchmarks.getArg(args, "kgrp", "testseries").getBytes(utf8);
		int spec = instance;
		int algo = CryptoServerCXI.KEY_ALGO_RAW;
		int usage = CryptoServerCXI.KEY_USAGE_SIGN | CryptoServerCXI.KEY_USAGE_VERIFY;
		int export = 0; // ALLOW_BACKUP only, no export wrapped
		boolean kti = (Benchmarks.hasArg(args, "ksinternal"));
		boolean ovw = (Benchmarks.hasArg(args, "ksoverwrite"));
		int rnd = (Benchmarks.hasArg(args, "rng") ? Pqmi.MODE_REAL_RND : Pqmi.MODE_PSEUDO_RND);
		// Already carries KEY_OVERWRITE for an internal key when requested, so
		// nothing further needs to be OR-ed into the request flags below.
		int sto = (kti ? (ovw ? Pqmi.KEY_OVERWRITE : 0) : Pqmi.KEY_EXTERNAL);
		CxiKeyAttributes attributes = new CxiKeyAttributes (flags, name, group, spec, algo, usage, export, new byte[0]);
		int outerflags = rnd | sto;
		int kt = getMldsaKeytype(args);
		try {
			return new MLDSA_KeyGen (outerflags, kt, attributes, new byte[0]);
		} catch (SerializationError ser) {
			System.out.println("Serialization error on typed parameter in MLDSA_KeyGen.\n");
			return null;
		}
	}
	
	/**
	 * Executes a prepared ML-DSA key generation request.
	 * <p>
	 * A {@code CryptoServerException} is reported on stderr as its hex error
	 * code followed by either the exception's localized message or the text
	 * from {@code ErrList.errtext}, depending on the masked code value. Any
	 * other exception has its stack trace printed.
	 *
	 * @param cxi       connection to execute on
	 * @param tobj_ctor the prepared request
	 * @return the value returned by the request's {@code exec}, or
	 *         {@code false} if any exception was thrown
	 */
	static boolean mldsa_generate (CryptoServerCXI cxi, MLDSA_KeyGen tobj_ctor) {
		try {
			return tobj_ctor.exec(cxi, module_id, sfc_id_gen_dil);
		} catch (CryptoServerException cse) {
			System.err.format("Error 0x%08X\n", cse.ErrorCode);
			final boolean useLocalized = (cse.ErrorCode & (int)0xBFFF0000) < 0xB1000000;
			if (useLocalized) {
				System.err.print(cse.getLocalizedMessage());
			} else {
				System.err.println(ErrList.errtext(cse.ErrorCode));
			}
		} catch (Exception other) {
			other.printStackTrace();
		}
		return false;
	}

	/**
	 * Runs one generate / find / sign / verify round trip, printing each stage.
	 * <p>
	 * Stages run in order: GEN creates the key; FIND (only when {@code sto} is not
	 * {@code Pqmi.KEY_EXTERNAL}) loads it back from the key store; SIGN signs a
	 * message; VERIFY checks that signature. A stage runs only if the previous
	 * one succeeded. The message is "Test Message" unless overridden through the
	 * {@code -msg} or {@code -obj} (file name) arguments.
	 *
	 * @param args    command-line arguments
	 * @param cxi     connection to execute on
	 * @param ctors   {@code true} to build the request objects through their full
	 *                constructors, {@code false} to use empty objects plus setters
	 * @param rnd     random source flag, {@code Pqmi.MODE_REAL_RND} or {@code Pqmi.MODE_PSEUDO_RND}
	 * @param sto     storage flag; {@code Pqmi.KEY_EXTERNAL} returns a key blob,
	 *                anything else stores the key internally
	 * @param keytype one of the ML-DSA key types, e.g. {@code Pqmi.ML_DSA_KT_65}
	 * @param spec    spec value placed in the key attributes
	 * @return {@code true} only if every stage that ran completed without an
	 *         exception and its {@code exec} call returned {@code true}
	 */
	boolean mldsa_gensignverify (String [] args, CryptoServerCXI cxi, boolean ctors, int rnd, int sto, int keytype, int spec) { 
		// This is the "HSM-aware POJO" used to generate a key.  It is built either
		// through its full constructor (ctors == true) or as an empty object that is
		// then populated through setters.
		MLDSA_KeyGen tobj_ctor = null;

		// One of the fields of the generator object is a "packed" CXI Key Attributes template.
		// This is the template, it must be serialized for use, however the MLDSA_KeyGen POJO 
		// setter can take the object, and will call the object's serializer to get the serialized
		// buffer.  NOTE: This happens on the call, subsequent changes to the attributes template
		// have NO effect on anything that serialized it previously.
		CxiKeyAttributes attributes;

		// CxiKeyAttributes - typed field attributes
		int flags = 0x0;
		// CXI stores its keys, indexed by the MD5 of the key {NAME||GROUP||SPEC} (|| here means 
		// 'concatenated with', so NAME concatenated with GROUP etc).
		byte [] name = "mldsa".getBytes();
		byte [] group = "testseries".getBytes();

		// If you need a full 'CryptoServerCXI.KeyAttributes()' payload, you can stash it in 
		// the CxiKeyAttributes mdata field.  This is treated as raw, user-relevant info, the
		// module itself makes no use of it, other than storing it for retrieval later.
		/*
		 CryptoServerCXI.KeyAttributes katt = new CryptoServerCXI.KeyAttributes();
		 ... populate katt using its setters, then
		 
 		 attributes.mdata(katt.toByteArray());

		 */

		int algo = CryptoServerCXI.KEY_ALGO_RAW;
		int usage = CryptoServerCXI.KEY_USAGE_SIGN | CryptoServerCXI.KEY_USAGE_VERIFY;
		int export = 0; // ALLOW_BACKUP only, no export wrapped

		if (ctors) {
			attributes = new CxiKeyAttributes (
					flags,
					name,
					group,
					spec,
					algo,
					usage,
					export, 
					new byte[0]
					);
		} else {
			attributes = new CxiKeyAttributes (cxi);
			attributes.flags (flags);
			attributes.name (name);
			attributes.group (group);
			attributes.spec (spec);
			attributes.algo (algo);
			attributes.usage (usage);
			attributes.export (export);
		}

		// Seed is ignored.  Theoretically seed is used as a seed into 
		// the HSM's random number generator.  In CAVP automated tests, it's used as the actual seed
		// into the hashing algorithm used to elaborate the key -- ie it results in a deterministic
		// public/private key pair.
		byte [] seed = new byte[0]; // ignored

		// rnd -- whether use the DRBG TRNG or use the PRNG (only)
		//     -- Pqmi.MODE_REAL_RND or Pqmi.MODE_PSEUDO_RND
		// sto -- whether to store the key internal on the HSM or return a key blob
		//     -- if there is already an internal key with this {NAME||GROUP||Spec}, KEY_OVERWRITE
		//     -- will cause it to be overwritten.
		boolean ovw = (Benchmarks.hasArg(args, "ksoverwrite"));
		int outerflags = rnd | sto | (sto == Pqmi.KEY_EXTERNAL ? 0 : ovw ? Pqmi.KEY_OVERWRITE : 0);

		System.out.format("Keytype is %d, flags is 0x%08x\n", keytype, outerflags);

		try {
			if (ctors) {
				tobj_ctor = new MLDSA_KeyGen (
						outerflags,     // the rnd/sto flags above
						keytype,		// One of the MLDSA keytypes, ie Pqmi.ML_DSA_KT_65 etc
						attributes,     // The attributes object from above
						seed            // ignored
						);
			} else {
				tobj_ctor = new MLDSA_KeyGen (cxi);
				tobj_ctor.flags (outerflags);
				tobj_ctor.type (keytype);
				tobj_ctor.attributes (attributes);
				tobj_ctor.seed (seed);
			}
		} catch (SerializationError ser) {
			System.out.println("Serialization error on typed parameter in MLDSA_KeyGen.\n");
			return false;
		}

		/* ************************************* */
		boolean ret;
		System.out.println("GEN STAGE");
		try {
			System.out.println(tobj_ctor.toString());
			// Each of the "HSM-Aware POJO" classes can be "executed".  When one of the class 'exec' methods
			// are called, the class instance will serialize itself, and then call the HSM using the cxi 
			// instance in the exec call (other exec calls use a stored cxi).  In its simplest form,
			// with predefined module_id and sub-function code ID, and a stored CXI, the class supports
			// foo.exec(), but we use a passed-in cxi, mid and sfc id here for clarity.
			ret = tobj_ctor.exec(cxi, module_id, sfc_id_gen_dil);
			System.out.println(ret ? "PASSED" : "FAILED"); 
			// When the call succeeds, any result byte-buffer is placed into the object's .resp field.
			// A suitable "HSM-Aware POJO" class can deserialize these responses into a class 
			// instance, when the system is set up that way.  
			// If you selected an external key, the .resp is the key blob.  For an internal key,
			// it's a key handle.
			System.out.println(SdkBuffer.xtrace("Result: ", tobj_ctor.resp));
		} catch (CryptoServerException e) {
			printCryptoServerError(e);
			return false;
		} catch (Exception e) {
			e.printStackTrace();
			return false;
		}

		PqmiKeyStore key = null;
		if (ret && sto != Pqmi.KEY_EXTERNAL) {
			// test getkey
			System.out.println("FIND STAGE");
			try {
				// If you are familiar with PKCS11, this is like C_FindObjectsInit/FindObjects/FindObjectsFinal
				key = new PqmiKeyStore(Pqmi.KEY_LOAD_FROM_STORE, keytype, attributes);
				System.out.println(key.toString());
				ret = key.exec(cxi, module_id, sfc_id_keystore);
				System.out.println(ret ? "PASSED" : "FAILED"); 
				// Trace the key-store lookup's own response (this used to print the
				// generator's response a second time).
				System.out.println(SdkBuffer.xtrace("Result: ", key.resp));
			} catch (CryptoServerException e) {
				printCryptoServerError(e);
				return false;
			} catch (Exception e) {
				e.printStackTrace();
				return false;
			}
		}

		MLDSA_Sign tobj_signer = new MLDSA_Sign();
		MLDSA_Verify tobj_verfer = new MLDSA_Verify();
		if (ret) {
			System.out.println("SIGN STAGE");
			ret = false;
			tobj_signer.flags(Pqmi.DILITHIUM_MODE_SIG_RAW | Pqmi.MODE_PSEUDO_RND);
			if (sto == Pqmi.KEY_EXTERNAL) { 	
				// if it was an external key blob, it's contained in the Gen object's .resp field.
				tobj_signer.key(tobj_ctor.resp);
			} else {
				// it's an internal key, so we retrieved it above.
				tobj_signer.key(key.resp);
				// pro tip:  tobj_signer.key(tobj_ctor.resp) would work here also -- the value contained in that
				// object's .resp is a key handle as used by CXI internally, so it is still seen as a valid
				// key!
			}
			tobj_signer.msg("Test Message");
			// NOTE(review): other options in this file are looked up without a leading
			// dash (e.g. "ksoverwrite"); confirm "-msg"/"-obj" match what Benchmarks expects.
			if (Benchmarks.hasArg(args, "-msg")) {
				tobj_signer.msg(Benchmarks.getArg(args, "-msg", tobj_signer.msg()));
			}
			if (Benchmarks.hasArg(args, "-obj")) {
				String fname = Benchmarks.getArg(args, "-obj", (String)null);
				if (fname != null) {
					byte[] data = null;
					try {
						data = Utils.readFileAsBytes(fname);
						if (data.length > 250*1024) {
							System.err.format("No can do: Limit is one 256Kb (with overhead) size with this release.\n");
							return false;
						}
					} catch (IOException e) {
						// Unreadable file: report it and fall back to the default message below.
						e.printStackTrace();
					}
					if (data == null) { data = "Test Message".getBytes(); }
					tobj_signer.msg(data);
				}
			}
			tobj_signer.type(tobj_ctor.type());  // just grabbing the keytype out of the Gen object.
			try {
				System.out.println(tobj_signer.toString());
				boolean res = tobj_signer.exec(cxi, module_id, sfc_id_sign_dil);
				System.out.println(res ? "PASSED" : "FAILED"); 
				// The signer object's .resp is the signature.
				System.out.println(SdkBuffer.xtrace("Result: ", tobj_signer.resp));
				ret = res;
			} catch (CryptoServerException e) {
				printCryptoServerError(e);
			} catch (Exception e) {
				e.printStackTrace();
			}
		}

		if (ret) {
			System.out.println("VERIFY STAGE");
			ret = false;
			// same flags et al as the sign op
			tobj_verfer.ctxt(tobj_signer.ctxt());
			tobj_verfer.flags(tobj_signer.flags());
			tobj_verfer.msg(tobj_signer.msg());
			tobj_verfer.state(tobj_signer.state());
			// The key is either a full blob or a key handle, either will work.
			tobj_verfer.key(tobj_signer.key());
			// Same key type the signer was given.
			tobj_verfer.type(tobj_ctor.type());

			try {
				// Move the sign op's resp into the signature field; without this the
				// verifier is never handed the signature it is supposed to check.
				tobj_verfer.sig(tobj_signer.resp);
				System.out.println(tobj_verfer.toString());
				boolean res = tobj_verfer.exec(cxi, module_id, sfc_id_verf_dil);
				System.out.println(res ? "PASSED" : "FAILED"); 
				System.out.println(SdkBuffer.xtrace("Result: ", tobj_verfer.resp));
				ret = res;
			} catch (CryptoServerException e) {
				printCryptoServerError(e);
			} catch (Exception e) {
				e.printStackTrace();
			}
		}

		return ret;
	}

	/**
	 * Prints a CryptoServerException to stderr: the hex error code, then either
	 * the exception's localized message or the text from ErrList.errtext,
	 * selected by the masked code value.
	 */
	private static void printCryptoServerError (CryptoServerException e) {
		System.err.format("Error 0x%08X\n", e.ErrorCode);
		if ((e.ErrorCode & (int)0xBFFF0000) < 0xB1000000) {
			System.err.print(e.getLocalizedMessage());
		} else {
			System.err.println(ErrList.errtext(e.ErrorCode));
		}
	}
}
