import enum
import logging

import google.protobuf.empty_pb2
import grpc

from . import keypme_pkcs11_bridge_pb2_grpc
from . import keypme_pkcs11_bridge_pb2

class KeypMePkcs11Exception(Exception):
    """Root of the exception hierarchy defined by this KeypMe PKCS#11 client module."""

class KeypMePkcs11InvalidSignatureException(KeypMePkcs11Exception):
    """Signals that the bridge rejected a signature during verification.

    The verify helpers in this module raise it when the gRPC error details
    equal ``"CKR_SIGNATURE_INVALID"``.
    """

def keypme_pkcs11_client_open(server, root_certificate_chain: bytes|None = None,
                              client_certificate: bytes|None = None, client_key: bytes|None = None,
                              target_name_override = None) -> grpc.Channel:
    """Open a gRPC channel to the KeypMe PKCS#11 bridge server.

    A TLS channel is created when a root certificate chain is supplied
    (server authentication) and/or when both a client certificate and a
    client key are supplied (mutual TLS). Only when no TLS material is given
    at all is a plaintext (insecure) channel returned.

    Args:
        server: Target address handed to gRPC (e.g. ``"host:port"``).
        root_certificate_chain: PEM root certificates used to verify the
            server, or None to let the gRPC runtime use its default roots
            when TLS is enabled by client credentials.
        client_certificate: PEM client certificate chain for mutual TLS.
        client_key: PEM private key matching ``client_certificate``.
        target_name_override: Optional value for the
            ``grpc.ssl_target_name_override`` channel option (TLS only).

    ``bytes``, ``bytearray`` and ``memoryview`` are accepted for the PEM
    arguments; they are normalised to ``bytes`` before reaching gRPC.

    Returns:
        A ``grpc.Channel``; the caller is responsible for closing it.

    Raises:
        ValueError: If exactly one of ``client_certificate`` / ``client_key``
            is provided.
    """
    def _as_bytes(value):
        # grpc.ssl_channel_credentials expects real bytes objects, so convert
        # zero-copy buffers and mutable byte arrays.
        if isinstance(value, (memoryview, bytearray)):
            return bytes(value)
        return value

    root_certificate_chain = _as_bytes(root_certificate_chain)
    client_certificate = _as_bytes(client_certificate)
    client_key = _as_bytes(client_key)

    if (client_certificate is None) != (client_key is None):
        # A lone certificate or key cannot form a client identity. Fail loudly
        # rather than silently falling back to a plaintext connection.
        raise ValueError("client_certificate and client_key must be provided together")

    if root_certificate_chain is None and client_certificate is None:
        # No TLS material at all: plaintext channel.
        return grpc.insecure_channel(server)

    channel_credentials = grpc.ssl_channel_credentials(
            root_certificates=root_certificate_chain,
            private_key=client_key,
            certificate_chain=client_certificate)

    options = [
    #    ('grpc.keepalive_time_ms', 10000),
    #    ('grpc.keepalive_timeout_ms', 5000),
    #    ('grpc.keepalive_permit_without_calls', True),
    #    ('grpc.http2.max_pings_without_data', 0),
    #    ('grpc.http2.min_time_between_pings_ms', 10000),
    #    ('grpc.http2.min_ping_interval_without_data_ms', 300000)
    ]

    if target_name_override is not None:
        options.append(('grpc.ssl_target_name_override', target_name_override))

    return grpc.secure_channel(server, channel_credentials, options=options)

def keypme_pkcs11_client_test_server(channel: grpc.Channel, login_username = None, login_key: bytes = None):
    """Check that the bridge answers by fetching its token information.

    Args:
        channel: gRPC channel to the bridge server.
        login_username: Optional proprietary-login user name. A ``LoginRequest``
            is assembled from it but is not attached to the RPC yet (see TODO).
        login_key: Optional login key; only consulted when ``login_username``
            is truthy.

    Returns:
        The response message of the ``getInfo`` RPC.

    Raises:
        grpc.RpcError: Re-raised after being logged.
    """
    pb2 = keypme_pkcs11_bridge_pb2

    # TODO: Add support to test login username and key -- the request below is
    # prepared but not sent yet.
    pending_login = None
    if login_username:
        pending_login = pb2.LoginRequest(user_type=pb2.LoginRequest.UserType.proprietary)
        pending_login.username = login_username
        if login_key:
            pending_login.key = login_key

    pkcs11 = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        token_info = pkcs11.getInfo(google.protobuf.empty_pb2.Empty())
    except grpc.RpcError as rpc_error:
        logging.error(f"Received unknown RPC error: code={rpc_error.code()} message={rpc_error.details()}")
        raise

    logging.info("Token info: %s", token_info)
    return token_info

def keypme_pkcs11_client_crypto_support(channel: grpc.Channel, login_username = None, login_key: bytes = None):
    """Ask the bridge which cryptographic features it supports.

    Args:
        channel: gRPC channel to the bridge server.
        login_username: Optional proprietary-login user name. A ``LoginRequest``
            is assembled from it but is not attached to the RPC yet (see TODO).
        login_key: Optional login key; only consulted when ``login_username``
            is truthy.

    Returns:
        The response message of the ``getCryptographySupport`` RPC.

    Raises:
        grpc.RpcError: Re-raised after being logged.
    """
    LoginRequest = keypme_pkcs11_bridge_pb2.LoginRequest

    # TODO: Add support to test login username and key -- the request below is
    # prepared but not sent yet.
    unused_login = None
    if login_username:
        unused_login = LoginRequest(user_type=LoginRequest.UserType.proprietary)
        unused_login.username = login_username
        if login_key:
            unused_login.key = login_key

    pkcs11 = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        support = pkcs11.getCryptographySupport(google.protobuf.empty_pb2.Empty())
    except grpc.RpcError as rpc_error:
        logging.error(f"Received unknown RPC error: code={rpc_error.code()} message={rpc_error.details()}")
        raise

    logging.info("Cryptography support: %s", support)
    return support

def keypme_pkcs11_client_token_create(channel: grpc.Channel, token_label, so_pin, user_pin, login_username = None, login_key: bytes = None, reinit_if_exist: bool = False) -> None:
    """Create a token on the bridge via the ``createToken`` RPC.

    Args:
        channel: gRPC channel to the bridge server.
        token_label: Label of the token to create.
        so_pin: Security-officer PIN, sent as ``so_pin``.
        user_pin: User PIN, sent as ``pin``.
        login_username: Optional proprietary-login user name; when truthy a
            ``LoginRequest`` is attached to the request as ``login``.
        login_key: Optional login key; only used when ``login_username`` is truthy.
        reinit_if_exist: Forwarded as ``reinit_if_exist`` in the request.

    Raises:
        grpc.RpcError: Re-raised after logging (NOT_FOUND gets its own message).
    """
    pb2 = keypme_pkcs11_bridge_pb2

    login = None
    if login_username:
        login = pb2.LoginRequest(user_type=pb2.LoginRequest.UserType.proprietary)
        login.username = login_username
        if login_key:
            login.key = login_key

    pkcs11 = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        pkcs11.createToken(pb2.TokenCreateRequest(
            login=login,
            token_label=token_label,
            pin=user_pin,
            so_pin=so_pin,
            reinit_if_exist=reinit_if_exist,
        ))
    except grpc.RpcError as rpc_error:
        status = rpc_error.code()
        if status != grpc.StatusCode.NOT_FOUND:
            logging.error(f"Received unknown RPC error: code={status} message={rpc_error.details()}")
        else:
            logging.error(f"Received NOT_FOUND RPC error: message={rpc_error.details()}")
        raise

    logging.info(f"Token '{token_label}' created.")

#
# Session
#

def keypme_pkcs11_client_session_open(channel: grpc.Channel, token_label, user_pin) -> int:
    """Open a session on the token labelled ``token_label``.

    Args:
        channel: gRPC channel to the bridge server.
        token_label: Label of the token to open a session on.
        user_pin: User PIN, sent as ``pin``.

    Returns:
        The ``session`` field of the ``open`` RPC response.

    Raises:
        grpc.RpcError: Re-raised after logging (NOT_FOUND gets its own message).
    """
    pkcs11 = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        opened = pkcs11.open(keypme_pkcs11_bridge_pb2.SessionOpenRequest(token_label=token_label, pin=user_pin))
    except grpc.RpcError as rpc_error:
        status = rpc_error.code()
        if status != grpc.StatusCode.NOT_FOUND:
            logging.error(f"Received unknown RPC error: code={status} message={rpc_error.details()}")
        else:
            logging.error(f"Received NOT_FOUND RPC error: message={rpc_error.details()}")
        raise
    return opened.session

def keypme_pkcs11_client_session_close(channel: grpc.Channel, session) -> None:
    """Close a session through the ``close`` RPC.

    Args:
        channel: gRPC channel to the bridge server.
        session: Session handle, presumably the value returned by
            ``keypme_pkcs11_client_session_open``.

    Raises:
        grpc.RpcError: Re-raised after logging (NOT_FOUND gets its own message).
    """
    pkcs11 = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        # NOTE(review): the close request is built with SessionOpenRequest, the
        # same message type used to open a session. Confirm against the .proto
        # that Pkcs11.close takes this type and that it has a `session` field.
        pkcs11.close(keypme_pkcs11_bridge_pb2.SessionOpenRequest(session=session))
    except grpc.RpcError as rpc_error:
        status = rpc_error.code()
        if status != grpc.StatusCode.NOT_FOUND:
            logging.error(f"Received unknown RPC error: code={status} message={rpc_error.details()}")
        else:
            logging.error(f"Received NOT_FOUND RPC error: message={rpc_error.details()}")
        raise

#
# RSA
#

def keypme_pkcs11_client_rsa_generate(channel: grpc.Channel, session, key_alias, rsa_bit_length) -> keypme_pkcs11_bridge_pb2.RSAPublicKeyResponse:
    """Generate an RSA key pair stored under ``key_alias``.

    Args:
        channel: gRPC channel to the bridge server.
        session: Handle of an open session.
        key_alias: Name placed in the request's ``KeyAlias`` message.
        rsa_bit_length: Requested key size in bits, sent as ``bit_length``.

    Returns:
        The full ``rsaGenerateKey`` response message.

    Raises:
        grpc.RpcError: Re-raised after being logged.
    """
    pb2 = keypme_pkcs11_bridge_pb2
    pkcs11 = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        request = pb2.RSAGenerateKeyRequest(
            session=session,
            alias=pb2.KeyAlias(name=key_alias),
            bit_length=rsa_bit_length,
        )
        return pkcs11.rsaGenerateKey(request)
    except grpc.RpcError as rpc_error:
        logging.error(f"Received unknown RPC error: code={rpc_error.code()} message={rpc_error.details()}")
        raise

def keypme_pkcs11_client_rsa_sign(channel: grpc.Channel, session, key_alias, signature_mechanism: keypme_pkcs11_bridge_pb2.RSAPkcs1SignatureMechanism, data: bytes) -> bytes:
    """Sign ``data`` with the RSA key stored under ``key_alias`` (PKCS#1 RPC).

    Args:
        channel: gRPC channel to the bridge server.
        session: Handle of an open session.
        key_alias: Name placed in the request's ``KeyAlias`` message.
        signature_mechanism: Mechanism forwarded as ``mechanism``.
        data: Payload to sign.

    Returns:
        The ``data`` field of the ``rsaPkcs1Sign`` response.

    Raises:
        grpc.RpcError: Re-raised after being logged.
    """
    pb2 = keypme_pkcs11_bridge_pb2
    pkcs11 = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        request = pb2.RSAPkcs1SignRequest(
            session=session,
            alias=pb2.KeyAlias(name=key_alias),
            mechanism=signature_mechanism,
            data=data,
        )
        signed = pkcs11.rsaPkcs1Sign(request)
        return signed.data
    except grpc.RpcError as rpc_error:
        logging.error(f"Received unknown RPC error: code={rpc_error.code()} message={rpc_error.details()}")
        raise

def keypme_pkcs11_client_rsa_verify(channel: grpc.Channel, session, key_alias, signature_mechanism: keypme_pkcs11_bridge_pb2.RSAPkcs1SignatureMechanism, data: bytes, signature: bytes) -> None:
    """Verify an RSA PKCS#1 ``signature`` over ``data`` with the key under ``key_alias``.

    Returns normally when the ``rsaPkcs1Verify`` RPC succeeds.

    Args:
        channel: gRPC channel to the bridge server.
        session: Handle of an open session.
        key_alias: Name placed in the request's ``KeyAlias`` message.
        signature_mechanism: Mechanism forwarded as ``mechanism``.
        data: Signed payload.
        signature: Signature to check.

    Raises:
        KeypMePkcs11InvalidSignatureException: If the RPC error details equal
            ``"CKR_SIGNATURE_INVALID"``; the original ``grpc.RpcError`` is
            chained as ``__cause__``.
        grpc.RpcError: For any other RPC failure, re-raised after logging.
    """
    stub = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        key_alias_message = keypme_pkcs11_bridge_pb2.KeyAlias(name=key_alias)
        stub.rsaPkcs1Verify(keypme_pkcs11_bridge_pb2.RSAPkcs1VerifyRequest(session=session, alias=key_alias_message, mechanism=signature_mechanism, data=data, signature=signature))
    except grpc.RpcError as rpc_error:
        details = rpc_error.details()
        if details == "CKR_SIGNATURE_INVALID":
            # A rejected signature is surfaced as a dedicated exception (not
            # logged as an error). Keep the RPC error as the explicit cause so
            # the server-side status stays visible in tracebacks.
            raise KeypMePkcs11InvalidSignatureException(details) from rpc_error
        logging.error(f"Received unknown RPC error: code={rpc_error.code()} message={details}")
        raise

#
# ML-DSA
#

@enum.unique
class MLDsaSecurityLevel(enum.Enum):
    """Selects the ML-DSA parameter set for key generation.

    Each member's integer ``value`` is the security level number. The ``AxB``
    suffix in each name appears to be the (k x l) matrix dimensions of the
    FIPS 204 parameter sets. Confirm that mapping against the server.
    """

    ML_DSA_SECURITY_LEVEL_2_4x4 = 2  # presumably ML-DSA-44
    ML_DSA_SECURITY_LEVEL_3_6x5 = 3  # presumably ML-DSA-65
    ML_DSA_SECURITY_LEVEL_5_8x7 = 5  # presumably ML-DSA-87

def keypme_pkcs11_client_ml_dsa_generate(channel: grpc.Channel, session, key_alias, security_level: MLDsaSecurityLevel) -> keypme_pkcs11_bridge_pb2.DataResponse:
    """Generate an ML-DSA key pair stored under ``key_alias``.

    Args:
        channel: gRPC channel to the bridge server.
        session: Handle of an open session.
        key_alias: Name placed in the request's ``KeyAlias`` message.
        security_level: Parameter set; its ``.value`` is sent as ``security_level``.

    Returns:
        The full ``mlDsaGenerateKey`` response message.

    Raises:
        grpc.RpcError: Re-raised after being logged.
    """
    pb2 = keypme_pkcs11_bridge_pb2
    pkcs11 = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        request = pb2.MLDSAGenerateKeyRequest(
            session=session,
            alias=pb2.KeyAlias(name=key_alias),
            security_level=security_level.value,
        )
        return pkcs11.mlDsaGenerateKey(request)
    except grpc.RpcError as rpc_error:
        logging.error(f"Received unknown RPC error: code={rpc_error.code()} message={rpc_error.details()}")
        raise

def keypme_pkcs11_client_ml_dsa_sign(channel: grpc.Channel, session, key_alias, signature_mechanism: keypme_pkcs11_bridge_pb2.MLDSASignatureMechanism, data: bytes) -> bytes:
    """Sign ``data`` with the ML-DSA key stored under ``key_alias``.

    Args:
        channel: gRPC channel to the bridge server.
        session: Handle of an open session.
        key_alias: Name placed in the request's ``KeyAlias`` message.
        signature_mechanism: Mechanism forwarded as ``mechanism``.
        data: Payload to sign.

    Returns:
        The ``data`` field of the ``mlDsaSign`` response, not the response
        message itself. This mirrors ``keypme_pkcs11_client_rsa_sign``, which
        is annotated ``bytes``.

    Raises:
        grpc.RpcError: Re-raised after being logged.
    """
    stub = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        key_alias_message = keypme_pkcs11_bridge_pb2.KeyAlias(name=key_alias)
        signature_response = stub.mlDsaSign(keypme_pkcs11_bridge_pb2.MLDSASignRequest(session=session, alias=key_alias_message, mechanism=signature_mechanism, data=data))
        return signature_response.data
    except grpc.RpcError as rpc_error:
        logging.error(f"Received unknown RPC error: code={rpc_error.code()} message={rpc_error.details()}")
        raise

def keypme_pkcs11_client_ml_dsa_verify(channel: grpc.Channel, session, key_alias, signature_mechanism: keypme_pkcs11_bridge_pb2.MLDSASignatureMechanism, data: bytes, signature: bytes) -> None:
    """Verify an ML-DSA ``signature`` over ``data`` with the key under ``key_alias``.

    Returns normally when the ``mlDsaVerify`` RPC succeeds.

    Args:
        channel: gRPC channel to the bridge server.
        session: Handle of an open session.
        key_alias: Name placed in the request's ``KeyAlias`` message.
        signature_mechanism: Mechanism forwarded as ``mechanism``.
        data: Signed payload.
        signature: Signature to check.

    Raises:
        KeypMePkcs11InvalidSignatureException: If the RPC error details equal
            ``"CKR_SIGNATURE_INVALID"``; the original ``grpc.RpcError`` is
            chained as ``__cause__``.
        grpc.RpcError: For any other RPC failure, re-raised after logging.
    """
    stub = keypme_pkcs11_bridge_pb2_grpc.Pkcs11Stub(channel)
    try:
        key_alias_message = keypme_pkcs11_bridge_pb2.KeyAlias(name=key_alias)
        stub.mlDsaVerify(keypme_pkcs11_bridge_pb2.MLDSAVerifyRequest(session=session, alias=key_alias_message, mechanism=signature_mechanism, data=data, signature=signature))
    except grpc.RpcError as rpc_error:
        details = rpc_error.details()
        if details == "CKR_SIGNATURE_INVALID":
            # A rejected signature is surfaced as a dedicated exception (not
            # logged as an error). Keep the RPC error as the explicit cause so
            # the server-side status stays visible in tracebacks.
            raise KeypMePkcs11InvalidSignatureException(details) from rpc_error
        logging.error(f"Received unknown RPC error: code={rpc_error.code()} message={details}")
        raise
