package com.utimaco.aut;


import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import com.utimaco.SimpleArgs;
import com.utimaco.bench.KeyOpts_XMSS;
import com.utimaco.cs2.mdl.ErrList;
import com.utimaco.cs2.mdl.pqmi.HbsPubInfo;
import com.utimaco.cs2.mdl.pqmi.Pqmi;
import com.utimaco.cs2.mdl.pqmi.PqmiKeyStore;

import CryptoServerAPI.CryptoServerException;
import CryptoServerCXI.CryptoServerCXI;
import CryptoServerCXI.CryptoServerCXI.KeyAttributes;

public class KeystoreTest implements AutTest {
	static int module_id = 0xa6; // 

	static final int sfc_id_keystore = Pqmi.SFC_KEYSTORE; 
	static final int sfc_id_ks_list = Pqmi.SFC_LIST_KEYS;
	
	static final int STORE_OPERATION = Pqmi.KEY_LIST_FROM_STORE;
	static final int STORE_LIST_ALL = Pqmi.KT_FLAG_MASK; // "0xFF"
	// or you can filter on the _KT_FLAG values, ie
	// ML_DSA_KT_FLAG, ML_KEM_KT_FLAG, HBS_LMS_KT_FLAG HBS_HSS_KT_FLAG HBS_XMSS_KT_FLAG

	CryptoServerCXI cxi;
	SimpleArgs args;
	boolean ctors;

	public KeystoreTest(CryptoServerCXI cxi, SimpleArgs cla) {
		this(cxi, cla, false);
	}
	public KeystoreTest(CryptoServerCXI cxi, SimpleArgs cla, boolean asynch) {
		this.cxi = cxi;
		this.args = cla;
		this.ctors = false;
	}

	public boolean go() {
		boolean ret = true;
		boolean ret2 = false;

		module_id = AutTest.getModuleId(cxi, "PQMI");
		
		int [] kt = { Pqmi.KT_FLAG_MASK, 
				Pqmi.ML_DSA_KT_FLAG, Pqmi.ML_KEM_KT_FLAG,
				Pqmi.HBS_HSS_KT_FLAG, Pqmi.HBS_LMS_KT_FLAG, Pqmi.HBS_XMSS_KT_FLAG
		};

		for (int i = 0; i < kt.length; i++) {
			PqmiKeyStore ks = new PqmiKeyStore(cxi);
			ks.flags(STORE_OPERATION); // list of indexes
			ks.type(Pqmi.KT_FLAG_MASK); // everything
			
			AutTest.announce(String.format("KEYSTORE-LIST (0x%02x) > ", ks.type()));
			try {
				ret2 = false;
				ks.exec(cxi, module_id, sfc_id_keystore);
				ret2 = true;
				System.out.println("Index | Type    | Group                            | Name             | Spec  ");
				System.out.println("                > Info...");
				System.out.println("------|---------|----------------------------------|------------------|-------");
				ByteBuffer bb = ByteBuffer.wrap(ks.resp);
				int count = bb.getShort();
				int ndx = 1;
				while (count-- > 0) {
					int type = bb.get();
					if (type < 0) {
						type += 256;
					}
					@SuppressWarnings("unused")
					int zero = bb.get();
					short len = bb.getShort();
					switch (type) {
					case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_4X4_R3:
					case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_6X5_R3:
					case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_8X7_R3:
					case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_4X4_AES_R3:
					case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_6X5_AES_R3:
					case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_8X7_AES_R3:
						infoPrint_Dilithium(ndx, bb, type, len);
						break;
					case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_512: 
					case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_768: 
					case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_1024: 
					case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_512_90S:
					case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_768_90S:
					case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_1024_90S:
						infoPrint_Kyber(ndx, bb, type, len);
						break;
					case Pqmi.ML_DSA_KT_FLAG + Pqmi.ML_DSA_KT_44:
					case Pqmi.ML_DSA_KT_FLAG + Pqmi.ML_DSA_KT_65:
					case Pqmi.ML_DSA_KT_FLAG + Pqmi.ML_DSA_KT_87:
						infoPrint_MLDSA(ndx, bb, type, len);
						break;
					case Pqmi.ML_KEM_KT_FLAG + Pqmi.ML_KEM_KT_512:
					case Pqmi.ML_KEM_KT_FLAG + Pqmi.ML_KEM_KT_768:
					case Pqmi.ML_KEM_KT_FLAG + Pqmi.ML_KEM_KT_1024:
						infoPrint_MLKEM(ndx, bb, type, len);
						break;
					case Pqmi.HBS_LMS_KT_FLAG:
					case Pqmi.HBS_HSS_KT_FLAG:
					case 0x4C4D53: // deprecated do not use
					case 0x485353: // deprecated do not use
						infoPrint_LMS(ndx, bb, type, len);
						break;
					case Pqmi.HBS_XMSS_KT_FLAG:
					case 0x584D5353: // deprecated do not use
						infoPrint_XMSS(ndx, bb, type, len);
						break;
						default:
							System.out.format("Unknown type? %x\n", type);
					}
					ndx++;
				}
			} catch (CryptoServerException e) {
				System.err.format("Error 0x%08X\n", e.ErrorCode);
				if ((e.ErrorCode & (int)0xBFFF0000) < 0xB1000000) 
					 { System.err.print(e.getLocalizedMessage()); }
				else { System.err.println(ErrList.errtext(e.ErrorCode)); }
				
			} catch (Exception e) {
				e.printStackTrace();
			}
			System.out.println(ret2 ? "Complete\n" : "FAILED");
			ret = ret && ret2;
		}
		
		return ret;
	}
	
	private void spacing(String d, int l) {
		int ll = l;
		while (ll > d.length()) {
			System.out.print(d);
			ll -= d.length();
		}
		System.out.print(d.substring(0, ll));
	}
	private void spacer() {
		System.out.print(" | ");
	}
	private void space(int l) {
		String s = "                                                       ";
		spacing(s, l);
	}
	@SuppressWarnings("unused")
	private void dash(int l) {
		String d = "-------------------------------------------------------";
		spacing(d, l);
	}

	private void infoPrint(int ndx, String type, String group, String name, int spec, String info) {
		int digits = (ndx > 9) ? 2 : 1;
		digits = (ndx > 99) ? 3 : digits;
		int fw = 5;
		space(fw - digits);
		System.out.format("%d", ndx);
		spacer();
		System.out.format("%-7s", type);
		spacer();
		System.out.format("%-32s", group);
		spacer();
		System.out.format("%-16s", name);
		spacer();
		digits = (spec > 9) ? 2 : 1;
		digits = (spec > 99) ? 3 : digits;
		fw = 5;
		space(fw - digits);  
		System.out.format("%d\n                > ", spec);
		System.out.println(info);
	}

	private int load_int4(byte [] in) {
		int out = 0;
		out = (in[0] << 24);
		out += (in[1] << 16);
		out += (in[2] << 8);
		out += in[3];
		return out;
	}

	private String exportString(int export) {
		switch (export) {
		case CryptoServerCXI.KEY_EXPORT_ALLOW_PLAIN:
			return "Allow Plain";
		case CryptoServerCXI.KEY_EXPORT_ALLOW:
			return "Allow";
		case CryptoServerCXI.KEY_EXPORT_DENY_BACKUP:
			return "Deny Backup";
		}
		return "Backup Only";
	}
	private String usageString(int usage) {
		switch (usage) {
		case CryptoServerCXI.KEY_USAGE_SIGN:
			return "Sign";
		case CryptoServerCXI.KEY_USAGE_VERIFY:
			return "Verify";
		case CryptoServerCXI.KEY_USAGE_ENCRYPT:
			return "Encrypt";
		case CryptoServerCXI.KEY_USAGE_DECRYPT:
			return "Decrypt";
		case CryptoServerCXI.KEY_USAGE_DERIVE:
			return "Derive/Key Agreement";
		case CryptoServerCXI.KEY_USAGE_UNWRAP:
			return "Unwrap";
		case CryptoServerCXI.KEY_USAGE_WRAP:
			return "Wrap/Key-Wrap";
		case CryptoServerCXI.KEY_USAGE_ENCRYPT | CryptoServerCXI.KEY_USAGE_DECRYPT:
			return "Enc/Dec";
		case CryptoServerCXI.KEY_USAGE_SIGN | CryptoServerCXI.KEY_USAGE_VERIFY:
			return "Sign/Verify";
		case CryptoServerCXI.KEY_USAGE_WRAP | CryptoServerCXI.KEY_USAGE_UNWRAP:
			return "Wrap/Unwrap";
			
		}
		return "Unset";
	}
	
	private void infoPrint_Kyber(int ndx, ByteBuffer bb, int type, short len) {
		byte [] pinfo = new byte[len];
		bb.get(pinfo);
		ByteBuffer pl = ByteBuffer.wrap(pinfo);
		String type_s = "KYB_RD3 - ";
		switch (type) {
		case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_512: type_s += "512"; break;
		case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_768: type_s += "768"; break;
		case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_1024: type_s += "1024"; break;
		case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_512_90S: type_s += "90S Variant 512"; break;
		case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_768_90S: type_s += "90S Variant 768"; break;
		case Pqmi.KYBER_KT_FLAG + Pqmi.KYBER_1024_90S: type_s += "90S Variant 1024"; break;
		}
		infoPrint_ML(ndx, type_s, pl, type, len, "");
	}
	
	private void infoPrint_MLKEM(int ndx, ByteBuffer bb, int type, short len) {
		byte [] pinfo = new byte[len];
		bb.get(pinfo);
		ByteBuffer pl = ByteBuffer.wrap(pinfo);
		String type_s = "MLKEM - ";
		switch (type) {
		case Pqmi.ML_KEM_KT_FLAG + Pqmi.ML_KEM_KT_512: type_s += "4x4"; break;
		case Pqmi.ML_KEM_KT_FLAG + Pqmi.ML_KEM_KT_768: type_s += "6x5"; break;
		case Pqmi.ML_KEM_KT_FLAG + Pqmi.ML_KEM_KT_1024: type_s += "8x7"; break;
		}
		infoPrint_ML(ndx, type_s, pl, type, len, "");
	}
	
	private void infoPrint_Dilithium(int ndx, ByteBuffer bb, int type, short len) {
		byte [] pinfo = new byte[len];
		bb.get(pinfo);
		ByteBuffer pl = ByteBuffer.wrap(pinfo);
		String type_s = "DIL_RD3 - ";
		switch (type) {
		case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_4X4_R3: type_s += "4x4"; break;
		case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_6X5_R3: type_s += "6x5"; break;
		case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_8X7_R3: type_s += "8x7"; break;
		case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_4X4_AES_R3: type_s += "AES Variant 4x4"; break;
		case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_6X5_AES_R3: type_s += "AES Variant 6x5"; break;
		case Pqmi.DILITHIUM_KT_FLAG + Pqmi.DILITHIUM_8X7_AES_R3: type_s += "AES Variant 8x7"; break;
		}
		infoPrint_ML(ndx, type_s, pl, type, len, "");
	}
	
	private void infoPrint_MLDSA(int ndx, ByteBuffer bb, int type, short len) {
		byte [] pinfo = new byte[len];
		bb.get(pinfo);
		ByteBuffer pl = ByteBuffer.wrap(pinfo);
		String type_s = "MLDSA - ";
		switch (type) {
		case Pqmi.ML_DSA_KT_FLAG + Pqmi.ML_DSA_KT_44: type_s += "4x4"; break;
		case Pqmi.ML_DSA_KT_FLAG + Pqmi.ML_DSA_KT_65: type_s += "6x5"; break;
		case Pqmi.ML_DSA_KT_FLAG + Pqmi.ML_DSA_KT_87: type_s += "8x7"; break;
		}
		infoPrint_ML(ndx, type_s, pl, type, len, "");
	}
	
	private void infoPrint_ML(int ndx, String kind, ByteBuffer pl, int type, short len, String info) {
		// it's a property list
		String group = "";
		String name = "";
		int spec = 0;
		int export = 0;
		int usage = 0;
		while (pl.hasRemaining()) {
			short tag = pl.getShort();
			short flen = pl.getShort();
			byte [] data = new byte[flen];
			pl.get(data);
			switch (tag) {
			case KeyAttributes.PROP_KEY_GROUP:
				group = new String(data);
				break;
			case KeyAttributes.PROP_KEY_NAME:
				name = new String(data);
				break;
			case KeyAttributes.PROP_KEY_SPEC:
				spec = load_int4(data);
				break;
			case KeyAttributes.PROP_KEY_EXPORT:
				export = load_int4(data);
				break;
			case KeyAttributes.PROP_KEY_USAGE:
				usage = load_int4(data);
				break;
			default:
				System.out.format(">%d/%d:%02x%02x%02x%02x<\n", tag, flen, data[0], data[1], data[2], data[3]);
				break;
			}
			
		}
		String type_s = String.format("0x%02x", type);

		info = String.format("[%s] Usage: %s, Export: %s", kind, usageString(usage), exportString(export));
		
		infoPrint(ndx, type_s, group, name, spec, info);
	}
	
	private void infoPrint_XMSS(int ndx, ByteBuffer bb, int type, short len) {
		byte [] pinfo = new byte[len];
		bb.get(pinfo);
		HbsPubInfo hpi = new HbsPubInfo(pinfo);
		String type_s = "XMSS " + (hpi.levels() == 1 ? "-MT":"");
		
		String info = KeyOpts_XMSS.getSchemeFromOid(hpi.lm_array(), hpi.levels()==1);
		
		infoPrint(ndx, type_s, new String(hpi.group()), new String(hpi.name()), hpi.spec(), info);
	}
	
	private void infoPrint_LMS(int ndx, ByteBuffer bb, int type, short len) {
		byte [] pinfo = new byte[len];
		bb.get(pinfo);
		HbsPubInfo hpi = new HbsPubInfo(pinfo);
		
		infoPrint(ndx, "LMS/HSS", new String(hpi.group()), new String(hpi.name()), hpi.spec(), "");
	}

}
