package com.utimaco;

/* 
 * 03c autogenerated 2024-01-08T10:52:06.1294572 using 
 * CS_SdkVpl version 0.0.5
 * 
 * Copyright included by reference, please see Utimaco_Demo_License.txt
 */

import java.io.File;
// import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;

import CryptoServerAPI.CryptoServerException;
import CryptoServerAPI.CryptoServerUtil;
import CryptoServerCXI.CryptoServerCXI;

//import com.utimaco.bench.DilithiumBench;
//import com.utimaco.bench.KyberBench;
import com.utimaco.cs2.mdl.*;
// Typed parameters (if any)
import com.utimaco.cs2.mdl.any.CxiKeyAttributes;
//import com.utimaco.cs2.mdl.pqmi.DilithiumGen;
//import com.utimaco.cs2.mdl.pqmi.DilithiumSigVer;
//import com.utimaco.cs2.mdl.pqmi.KyberGen;
//import com.utimaco.cs2.mdl.pqmi.KyberKEMDec;
//import com.utimaco.cs2.mdl.pqmi.KyberKEMEnc;
//import com.utimaco.cs2.mdl.pqmi.KyberKEMResponse;
import com.utimaco.cs2.mdl.pqmi.PqmiKeyStore;
import com.utimaco.cs2.mdl.pqmi.Pqmi;
import com.utimaco.cs2.mdl.pqmi.PqmiErrorList;

/**
 * Command-line driver for the Utimaco CryptoServer PQMI benchmarks (Dilithium / ML-DSA and
 * Kyber / ML-KEM).
 *
 * <p>{@link #main} parses {@code -dev} device specifications and {@code -s}/{@code -p}/{@code -cred}
 * credentials, connects to and logs on to every device, dispatches to the selected benchmark
 * entry points, and finally ends every session. NOTE(review): the actual benchmark calls inside
 * the {@code main_*} methods are commented out, so those methods currently only parse their
 * thread/iteration arguments.
 */
public class Benchmarks {
	/** Verbose tracing flag; switched on by {@code -v}. */
	static boolean dbg = false;
	/** Fixed pool of 16 threads; shut down at the end of {@link #main}. */
	public static ExecutorService xekr = Executors.newFixedThreadPool(16);

	/** Key types run, in order, when {@code -suite} is given. */
	private static final int[] SUITE_KEYTYPES = {1, 3, 5};

	/*
	 * END of TEST SPECIFIC
	 **********************************************************************/


	/**
	 * Connects to every device in {@code devlist} and logs each connection on with every
	 * credential in {@code logins} whose {@code valid} flag is set.
	 *
	 * <p>Each connection uses a 3 s connect timeout and a 15 min command timeout. Any
	 * {@link CryptoServerException}, {@link NumberFormatException} or {@link IOException} while
	 * connecting terminates the JVM via {@link System#exit(int)}; a failed logon is only reported
	 * (see {@link Login#login}) and does not stop the run.
	 *
	 * @param devlist device addresses, e.g. as produced by {@link #parseDevices}
	 * @param logins  credentials to apply to each device, in order
	 * @return one open session per device, in {@code devlist} order
	 */
	public static ArrayList<CryptoServerCXI> connect(ArrayList<String> devlist, ArrayList<Login> logins) {
		ArrayList<CryptoServerCXI> ret = new ArrayList<>(devlist.size());

		for (String dev : devlist) {
			try {
				CryptoServerCXI cxi = new CryptoServerCXI(dev, 3000); // connection timeout is 3 seconds

				cxi.setTimeout(1000 * 60 * 15); // command timeout is 15 minutes

				cxi.setKeepSessionAlive(true); // in theory a no-op as the test code doesn't have any waitstates
				cxi.setEndSessionOnShutdown(true); // in theory a no-op as main() ends the session explicitly

				dbg("Device: " + cxi.getDevice());
				for (Login l : logins) {
					if (!l.valid) { continue; }
					dbg("User: " + l.user);
					l.login(cxi);
					// zero-pad to exactly 8 hex digits
					String authstate = "00000000" + Integer.toHexString(cxi.getAuthState());
					dbg(" -> 0x" + authstate.substring(authstate.length() - 8));
				}
				ret.add(cxi);
			} catch (CryptoServerException e) {
				System.out.format("Device err at %s\n%s\n", dev, e.getMessage());
				System.exit(-4326);
			} catch (NumberFormatException e) {
				System.out.format("NumberFormat err at %s\n%s\n", dev, e.getMessage());
				System.exit(-4327);
			} catch (IOException e) {
				System.out.format("IOException at %s\n%s\n", dev, e.getMessage());
				System.exit(-4327);
			}
		}

		return ret;
	}

	public static boolean isDilithium(String[] args) { return hasArg(args, "dil:dilithium"); }
	public static boolean isKyber(String[] args) { return hasArg(args, "kyb:kyber"); }

	public static boolean isGen(String[] args) { return hasArg(args, "gen:generate:suite"); }
	public static boolean isSign(String[] args) { return hasArg(args, "sign"); }
	public static boolean isVerify(String[] args) { return hasArg(args, "verify:verf"); }
	public static boolean isKem(String[] args) { return hasArg(args, "kem:agree:suite"); }

	/**
	 * Reads the shared thread plan used by every {@code main_*} entry point.
	 *
	 * @return {@code {threads, countPerThread}}: defaults 16/1000 (1/1 with {@code -once}),
	 *         overridden by {@code count} and {@code cpthread}
	 * @throws NumberFormatException if {@code count}/{@code cpthread} is not an integer
	 */
	private static int[] threadPlan(String[] args) {
		boolean once = hasArg(args, "-once");
		int threads = getArg(args, "count", once ? 1 : 16);
		int cperthread = getArg(args, "cpthread", once ? 1 : 1000);
		return new int[] { threads, cperthread };
	}

	/** Kyber key-generation benchmark; no-op unless a gen/generate/suite flag is present. */
	public static void main_KyberGen(String[] args, ArrayList<CryptoServerCXI> devpool) {
		if (!isGen(args)) { return; }
		int[] plan = threadPlan(args);
		// /* AtomicBoolean done = */ KyberBench.mlkem_genkey(args, devpool, plan[0], plan[1]);
	}

	/** Kyber KEM benchmark; no-op unless a kem/agree/suite flag is present. */
	public static void main_KyberKem(String [] args, ArrayList<CryptoServerCXI> devpool) {
		if (!isKem(args)) { return; }
		int[] plan = threadPlan(args);
		// /* AtomicBoolean done = */ KyberBench.mlkem_agreesecrets(args, devpool, plan[0], plan[1]);
	}

	/** Dilithium key-generation benchmark; no-op unless a gen/generate/suite flag is present. */
	public static void main_DilithiumGen(String[] args, ArrayList<CryptoServerCXI> devpool) {
		if (!isGen(args)) { return; }
		int[] plan = threadPlan(args);
		// /* AtomicBoolean done = */ DilithiumBench.mldsa_genkey(args, devpool, plan[0], plan[1]);
	}

	/** Dilithium signing benchmark; no-op unless a sign flag is present. */
	public static void main_DilithiumSign(String[] args, ArrayList<CryptoServerCXI> devpool) {
		if (!isSign(args)) { return; }
		int[] plan = threadPlan(args);
		// /* AtomicBoolean done = */ DilithiumBench.mldsa_signing(args, devpool, plan[0], plan[1]);
	}

	/** Dilithium verification benchmark; no-op unless a verify/verf flag is present. */
	public static void main_DilithiumVerify(String[] args, ArrayList<CryptoServerCXI> devpool) {
		if (!isVerify(args)) { return; }
		int[] plan = threadPlan(args);
		// /* AtomicBoolean done = */ DilithiumBench.mldsa_verifying(args, devpool, plan[0], plan[1]);
	}

	/**
	 * Prints {@code helptext-bm.txt} via {@code Utils.printHelp}. NOTE(review): whether this
	 * actually exits depends on {@code Utils.printHelp}, which is not visible here — confirm.
	 */
	public static void printUsageAndExit() {
		Utils.printHelp("helptext-bm.txt", new HashMap<String, String>());
	}

	/**
	 * Parses every string as a decimal int and returns the values sorted ascending.
	 *
	 * @param in tokens to parse
	 * @return the sorted values, or {@code null} if {@code in} is empty
	 *         (exits the JVM with -4322 on an unparsable token)
	 */
	public static int[] parseToInt(String [] in) {
		if (in.length == 0) { return null; }
		int[] ret = new int[in.length];
		for (int i = 0; i < in.length; i++) {
			try {
				ret[i] = Integer.parseInt(in[i]);
			} catch (NumberFormatException nfe) {
				System.out.format("Failed to parse integer %s in string beginning %s\n", in[i], in[0]);
				System.exit(-4322);
			}
		}
		Arrays.sort(ret);
		return ret;
	}

	/**
	 * Parses a range spec {@code "a"} or {@code "a-b"} into inclusive bounds {@code {lo, hi}}.
	 * A single value yields {@code {a, a}}; reversed bounds are sorted. Exits the JVM with -4321
	 * on an empty range or more than two values.
	 */
	private static int[] parseRange(String dev, String node, String spec) {
		int[] vals = parseToInt(spec.split("-"));
		if (vals == null || vals.length > 2) {
			System.out.format("Failed to parse -dev %s at %s\nBad range %s\n", dev, node, spec);
			System.exit(-4321);
		}
		return new int[] { vals[0], vals[vals.length - 1] };
	}

	/**
	 * Expands {@code <prefix><cards>.<slots>} (each side a value or {@code lo-hi} range) into one
	 * {@code <prefix><card>.<slot>} entry per combination, cards outermost, appended to {@code out}.
	 */
	private static void expandCardSlot(ArrayList<String> out, String dev, String node, String prefix) {
		String[] befaft = node.substring(prefix.length()).split("\\.");
		if (befaft.length != 2) {
			System.out.format("Failed to parse -dev %s at %s\nToo Many . in string\n", dev, node);
			System.exit(-4321);
		}
		int[] cards = parseRange(dev, node, befaft[0]);
		int[] slots = parseRange(dev, node, befaft[1]);
		for (int c = cards[0]; c <= cards[1]; c++) {
			for (int s = slots[0]; s <= slots[1]; s++) {
				out.add(String.format("%s%d.%d", prefix, c, s));
			}
		}
	}

	/*
	 * Accepted -dev syntax, nodes separated by ',' or ';':
	 *   dev=4001-4003@ip,4001-4006@ip2,PCI:0-1.1-5,/dev/cs2.0-1.1-6
	 */

	/**
	 * Expands a {@code -dev} value into individual device addresses.
	 *
	 * <p>Supported node forms: {@code port@host} or {@code lo-hi@host} (port range),
	 * {@code PCI:c.s} and {@code /dev/cs2.c.s} where each of {@code c}/{@code s} may be a
	 * {@code lo-hi} range. Nodes are trimmed and empty nodes ignored. Any unrecognised node
	 * terminates the JVM with -4321.
	 *
	 * @param append addresses to copy to the front of the result; may be {@code null}
	 * @param dev    the raw {@code -dev} value
	 * @return a new list: {@code append} followed by the expanded addresses
	 */
	public static ArrayList<String> parseDevices(ArrayList<String> append, String dev) {
		ArrayList<String> ret = new ArrayList<>();
		if (append != null) {
			ret.addAll(append);
		}
		// character class: split on either ',' or ';' (the old ",;" only matched the pair)
		for (String raw : dev.split("[,;]")) {
			String n = raw.trim();
			if (n.isEmpty()) { continue; }

			if (n.contains("@")) {
				String[] befaft = n.split("@");
				if (befaft.length != 2) {
					System.out.format("Failed to parse -dev %s at %s\nToo Many @ in string\n", dev, n);
					System.exit(-4321);
				}
				String ip = befaft[1];
				if (befaft[0].contains("-")) {
					int[] ports = parseRange(dev, n, befaft[0]);
					for (int p = ports[0]; p <= ports[1]; p++) {
						ret.add(String.format("%d@%s", p, ip));
					}
				} else {
					ret.add(n);
				}
			} else if (n.startsWith("PCI:")) {
				if (n.contains("-")) {
					expandCardSlot(ret, dev, n, "PCI:");
				} else {
					ret.add(n);
				}
			} else if (n.startsWith("/dev/cs2.")) {
				if (n.contains("-")) {
					expandCardSlot(ret, dev, n, "/dev/cs2.");
				} else {
					ret.add(n);
				}
			} else {
				System.out.format("Failed to parse -dev %s at %s\nUnrecognized construction\n", dev, n);
				System.exit(-4321);
			}
		}
		return ret;
	}

	/**
	 * Returns a copy of {@code args} with every existing keytype option (and its value) removed,
	 * every {@code -suite} flag removed when {@code dropSuite} is set, and
	 * {@code -keytype <keytype>} appended. The result is exactly sized (no {@code null} padding),
	 * and since the caller's keytype is stripped, {@link #getArg} sees only the suite's value.
	 */
	private static String[] withKeytype(String[] args, int keytype, boolean dropSuite) {
		ArrayList<String> out = new ArrayList<>(args.length + 2);
		for (int j = 0; j < args.length; j++) {
			String a = args[j];
			if (a.equalsIgnoreCase("-keytype") || a.equals("keytype")) {
				j++; // also skip the value that follows
				continue;
			}
			if (dropSuite && a.equalsIgnoreCase("-suite")) {
				continue;
			}
			out.add(a);
		}
		out.add("-keytype");
		out.add(Integer.toString(keytype));
		return out.toArray(new String[0]);
	}

	/**
	 * Entry point: parses devices and credentials, connects, runs the selected Dilithium and/or
	 * Kyber benchmarks (once per suite keytype with {@code -suite}), then shuts down the pool and
	 * ends every session. Exits with -4328 when no {@code -dev} was given.
	 */
	public static void main(String[] args) {
		dbg = dbg || hasArg(args, "-v");
		if (hasArg(args, "-h")) { printUsageAndExit(); }
		if (hasArg(args, "-help")) { printUsageAndExit(); }

		ArrayList<String> devpool = new ArrayList<String>();
		for (int i = 0; i < args.length - 1; i++) {
			if (args[i].compareToIgnoreCase("-dev") == 0) {
				devpool = parseDevices(devpool, args[++i]);
			}
		}

		// NOTE(review): constructed only for its side effects (presumably registering PQMI error texts) — confirm
		new PqmiErrorList();

		// Collect the idx-th credential until none is left; malformed ones are reported and skipped.
		ArrayList<Login> logins = new ArrayList<Login>();
		for (int idx = 0; ; idx++) {
			Login lgn = new Login(args, idx);
			if (!lgn.found) { break; }
			if (lgn.valid) { logins.add(lgn); }
		}

		ArrayList<CryptoServerCXI> devs = connect(devpool, logins);
		if (devs.size() == 0) {
			System.out.format("No devices (-dev <dev> or -dev <dev>,<dev> or -dev <dev1> -dev <dev2>)\n");
			System.exit(-4328);
		}

		if (isDilithium(args)) {
			// same aliases as isGen/isSign/isVerify (minus "suite")
			boolean genonly = hasArg(args, "gen:generate");
			boolean signonly = hasArg(args, "sign");
			boolean verfonly = hasArg(args, "verify:verf");
			if (!genonly && !signonly && !verfonly) { genonly = true; signonly = true; verfonly = true; }
			if (hasArg(args, "-suite")) {
				for (int kt : SUITE_KEYTYPES) {
					// "-suite" is removed for Dilithium, as before
					String[] arg2 = withKeytype(args, kt, true);
					if (genonly) main_DilithiumGen(arg2, devs);
					if (signonly) main_DilithiumSign(arg2, devs);
					if (verfonly) main_DilithiumVerify(arg2, devs);
				}
			} else {
				if (genonly) main_DilithiumGen(args, devs);
				if (signonly) main_DilithiumSign(args, devs);
				if (verfonly) main_DilithiumVerify(args, devs);
			}
		}

		if (isKyber(args)) {
			boolean genonly = hasArg(args, "gen:generate");
			boolean agreeonly = hasArg(args, "kem:agree");
			if (!genonly && !agreeonly) { genonly = true; agreeonly = true; }
			if (hasArg(args, "-suite")) {
				for (int kt : SUITE_KEYTYPES) {
					// "-suite" is kept for Kyber, as before, so isGen/isKem still match it
					String[] arg2 = withKeytype(args, kt, false);
					if (genonly) main_KyberGen(arg2, devs);
					if (agreeonly) main_KyberKem(arg2, devs);
				}
			} else {
				if (genonly) main_KyberGen(args, devs);
				if (agreeonly) main_KyberKem(args, devs);
			}
		}

		xekr.shutdown();

		for (CryptoServerCXI c : devs) {
			c.endSession();
		}
	}

	/** Prints {@code d} to stdout when verbose mode ({@code -v}) is on. */
	public static void dbg (String d) {
		if (dbg) {
			System.out.println(d);
		}
	}

	/**
	 * Sets the value following {@code arg} (exact, case-sensitive match).
	 *
	 * <p>If {@code arg} is present with a following value, that value is overwritten in place and
	 * {@code args} itself is returned. If {@code arg} is the last element, a new array with
	 * {@code val} appended is returned. Otherwise a new array with {@code arg, val} appended is
	 * returned.
	 */
	@SuppressWarnings("unused")
	public static String [] setArg (String[] args, String arg, String val) {
		for (int i = 0; i < args.length; i++) {
			if (args[i].compareTo(arg) == 0) {
				if (i + 1 < args.length) {
					args[i + 1] = val;
					return args;
				}
				// flag is the last element: previously this wrote past the end of the array
				String[] grown = Arrays.copyOf(args, args.length + 1);
				grown[args.length] = val;
				return grown;
			}
		}
		String[] grown = Arrays.copyOf(args, args.length + 2);
		grown[args.length] = arg;
		grown[args.length + 1] = val;
		return grown;
	}

	/**
	 * Integer variant of {@link #getArg(String[], String, String)}; a {@code 0x}/{@code 0X}
	 * prefix selects hexadecimal.
	 *
	 * @throws NumberFormatException if the value is not a valid int
	 */
	public static int getArg (String[] args, String arg, int def) {
		String res = getArg(args, arg, Integer.toString(def));
		if (res.startsWith("0x") || res.startsWith("0X")) {
			return Integer.parseInt(res.substring(2), 16);
		}
		return Integer.parseInt(res, 10);
	}

	/**
	 * Returns the element following the first {@code arg} or {@code -arg} in {@code args}
	 * (case-sensitive), or {@code def} if absent or last.
	 */
	public static String getArg (String[] args, String arg, String def) {
		String dashed = "-" + arg; // 01d
		boolean getnext = false;
		for (String a : args) {
			if (getnext) {
				return a;
			}
			if (a.equals(arg) || a.equals(dashed)) {
				getnext = true;
			}
		}
		return def;
	}

	/** Byte-array variant of {@link #getArg(String[], String, String)} using the platform charset. */
	@SuppressWarnings("unused")
	public static byte[] getArg (String[] args, String arg, byte[] bytes) {
		String inv = getArg(args, arg, new String(bytes));
		return inv.getBytes();
	}

	/**
	 * Tests whether {@code arg} or {@code -arg} appears in {@code args} (case-sensitive).
	 * A colon-separated {@code arg} such as {@code "dil:dilithium"} matches if any alternative
	 * matches.
	 */
	public static boolean hasArg (String[] args, String arg) {
		if (arg.contains(":")) {
			for (String alt : arg.split(":")) {
				if (hasArg(args, alt)) return true;
			}
		}

		String dashed = "-" + arg; // 01d
		for (String a : args) {
			if (a.equals(arg) || a.equals(dashed)) {
				return true;
			}
		}
		return false;
	}

	/**
	 * One credential taken from the i-th {@code -s}, {@code -p} or {@code -cred} option.
	 *
	 * <pre>
	 * user      Name of the user who wants to authenticate to the CryptoServer.
	 * keyFile   &lt;-- mapped from -cred, or -s or if mistakenly -p with 3 parameters
	 *     Key file user: Path to key file containing user's private key.
	 *     Password user: null.
	 * password  &lt;-- mapped from -cred, or -p
	 *     Key file user: Password of the key file if using an encrypted file, null otherwise.
	 *     Password user: Password of the user.
	 * </pre>
	 */
	private static class Login {
		String user = null;
		String spec = null;
		String pin = null;

		File file = null;
		/** True when a credential option with a value existed at the requested index. */
		boolean found = false;
		/**
		 * True when the value has a usable shape (2 or 3 comma-separated fields). This means it
		 * /may/ represent a valid credential, not that it /is/ a valid login.
		 */
		boolean valid = false;

		/**
		 * Scans {@code args} for the {@code i}-th (0-based) credential option and parses the value
		 * after it. {@code -s}, {@code -p} and {@code -cred} share one counter.
		 */
		Login (String [] args, int i) {
			int c = 0;
			boolean getNext = false;
			for (String a : args) {
				if (getNext) {
					getNext = false;
					found = true;
					parse(a);
					c++;
				}
				/* -s and -p are deprecated aliases of -cred */
				if (a.compareTo("-s") == 0 || a.compareTo("-p") == 0 || a.compareTo("-cred") == 0) {
					if (c != i) { c++; continue; }
					getNext = true;
				}
			}
		}

		/**
		 * Splits one credential value on ','. Forms:
		 * <ul>
		 * <li>{@code name,key} (unencrypted key file), {@code name,:cs2:...} (PIN pad),
		 *     {@code name,pass} (password / HMAC)</li>
		 * <li>{@code name,key,pass} (encrypted key file)</li>
		 * </ul>
		 */
		private void parse(String a) {
			String[] nkp = a.split(",");
			if (nkp.length == 2) {
				user = nkp[0];
				spec = nkp[1];
				pin = nkp[1]; // deliberately index 1 as well: for a password user the secret lives here
				file = new File(spec);
			} else if (nkp.length == 3) {
				user = nkp[0];
				spec = nkp[1];
				pin = nkp[2];
				file = new File(spec);
			}
			/*
			 * A spec that is neither an existing file nor a ":cs2:" PIN-pad id is assumed to be an
			 * HMAC password. It cannot be pre-checked against the HSM (a bad logon would log everyone
			 * out), so cxi.logon(...) is the real check; only the field count is validated here.
			 */
			valid = (nkp.length == 2 || nkp.length == 3);
			if (!valid) { System.out.println("Auth param format invalid for " + a); }
		}

		/**
		 * Logs {@code cxi} on as this user. Failures are printed, not thrown.
		 *
		 * @return {@code true} if {@code cxi.logon} returned normally
		 */
		boolean login (CryptoServerCXI cxi) {
			/* change 01g replaced logonpass/logonsign with the simple logon */
			byte[] pinbytes = (pin != null) ? pin.getBytes() : new byte[0];
			try {
				cxi.logon(user, spec, pinbytes);
				return true;
			} catch (IOException e) {
				System.out.println(user + " login failed:  " + e.getMessage());
			} catch (CryptoServerException e) {
				System.out.println(user + " login failed:  0x" + Integer.toHexString(e.ErrorCode));
			} finally {
				// wipe the byte copy on every path; the String 'pin' itself cannot be scrubbed
				Arrays.fill(pinbytes, (byte) 0);
			}
			return false;
		}
	}

}
