#!/usr/bin/env python3.7

from datetime import datetime, timedelta, timezone
import logging
from typing import List

import asn1

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import AttributeOID, NameOID
from cryptography.hazmat.primitives import serialization

from OpenSSL.crypto import load_certificate_request, dump_certificate_request
from OpenSSL.crypto import load_certificate, dump_certificate
from OpenSSL.crypto import FILETYPE_TEXT, FILETYPE_ASN1, FILETYPE_PEM

from . import Iso7816Algo, has_key_usage, has_extended_key_usage

# Module-level logger, named after this module
logger = logging.getLogger(__name__)

# Private OID identifying the KeypMe certificate name inside the Subject
# Alternative Name extension: it is the type-id of the OtherName entry written
# by sign_certificate_request() and looked up by certificate_get_name().
# 1.3.6.1.4.1.63847.1.3.1 lives under PEN Number '63847' ('labapart UG').
KEYPME_CERTIFICATE_NAME_OID_STR = "1.3.6.1.4.1.63847.1.3.1"
KEYPME_CERTIFICATE_NAME_OID = x509.ObjectIdentifier(KEYPME_CERTIFICATE_NAME_OID_STR)

def _convert_dn_to_x509_name(dn):
    """
    Convert a comma-separated distinguished name string into name attributes.

    Every component must look like ``<type>=<value>``. The type is matched
    case-insensitively against the supported set (cn, l, c, ou, o, uid, dc).
    Whitespace around the type and around the value is ignored, so
    ``"CN=John Doe, O=Example"`` is accepted. Only the first '=' separates
    type and value, so values may contain '='.

    Escaped commas inside values (``CN=Doe\\, John``) are NOT supported: the
    string is split on every ','.

    :param dn: distinguished name string, e.g. "CN=John Doe,O=Example,C=DE"
    :return: list of x509.NameAttribute, in the order they appear in ``dn``
    :raises RuntimeError: if a component uses an unsupported type or has no '='
    """
    id_to_oid = {
        'cn': NameOID.COMMON_NAME,
        'l': NameOID.LOCALITY_NAME,
        'c': NameOID.COUNTRY_NAME,
        'ou': NameOID.ORGANIZATIONAL_UNIT_NAME,
        'o': NameOID.ORGANIZATION_NAME,
        'uid': NameOID.USER_ID,
        'dc': NameOID.DOMAIN_COMPONENT}
    names = []
    for rdn in dn.split(','):
        # partition() splits on the first '=' only, keeping any later '=' in the value
        key, separator, value = rdn.partition('=')
        key = key.strip()

        oid = id_to_oid.get(key.lower())
        if oid is None:
            raise RuntimeError(f"RDN '{key}' is not supported")
        if not separator:
            # Previously surfaced as a bare IndexError
            raise RuntimeError(f"RDN '{key}' has no value (expected '<type>=<value>')")

        names.append(x509.NameAttribute(oid, value.strip()))

    return names

def create_DER_certificate_request(dn, algo=Iso7816Algo.RSA_2048.value, private_key=None, challenge_password=None) -> bytes:
    """
    Build a certificate signing request for ``dn`` and return it DER-encoded.

    :param dn: subject distinguished name string (see _convert_dn_to_x509_name)
    :param algo: Iso7816Algo value selecting the key to generate when no
                 ``private_key`` is given (RSA 2048 or RSA 1024)
    :param private_key: key used to sign the request; generated when falsy
    :param challenge_password: optional text added as the challengePassword
                               attribute (UTF-8 encoded)
    :return: the signed request as DER bytes
    :raises RuntimeError: if a key must be generated and ``algo`` is unsupported
    """
    if not private_key:
        # No key supplied: pick the RSA modulus size matching the requested algo
        if algo == Iso7816Algo.RSA_2048.value:
            rsa_key_size = 2048
        elif algo == Iso7816Algo.RSA_1024.value:
            rsa_key_size = 1024
        else:
            raise RuntimeError(f"Algo {algo} is not supported.")
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=rsa_key_size)

    subject = x509.Name(_convert_dn_to_x509_name(dn))
    csr_builder = (
        x509.CertificateSigningRequestBuilder()
        .subject_name(subject)
        # The request is for an end-entity certificate, never a CA
        .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True)
    )
    if challenge_password:
        password_bytes = challenge_password.encode('utf-8')
        csr_builder = csr_builder.add_attribute(AttributeOID.CHALLENGE_PASSWORD, password_bytes)

    csr = csr_builder.sign(private_key, hashes.SHA256())
    return csr.public_bytes(serialization.Encoding.DER)

def load_DER_certificate_request(req_der: bytes|None) -> x509.CertificateSigningRequest:
    """
    Parse a DER-encoded certificate signing request.

    :param req_der: DER bytes of the request; a memoryview or bytearray is
                    copied into bytes before parsing
    :return: the parsed x509.CertificateSigningRequest
    :raises ValueError: if ``req_der`` is None (errors raised by the parser
                        itself are propagated unchanged)
    """
    if req_der is None:
        raise ValueError("No DER certificate request provided (got None)")
    if isinstance(req_der, (memoryview, bytearray)):
        # Normalise buffer-like inputs to immutable bytes for the parser
        req_der = bytes(req_der)

    req = x509.load_der_x509_csr(req_der, default_backend())
    logger.debug(req.subject)
    return req

def DER_certificate_request_to_str(der_cert):
    """
    Render a DER-encoded certificate request as text.

    :param der_cert: DER bytes (or memoryview) of the certificate request
    :return: the FILETYPE_TEXT dump as a str, or None when the input is None
             or empty
    """
    is_empty = (der_cert is None) or (len(der_cert) == 0)
    if is_empty:
        logger.debug("DER Certificate Request is not valid (CSR is empty)")
        return None

    # pyOpenSSL is handed plain bytes, so unwrap memoryview inputs first
    der_bytes = der_cert.tobytes() if isinstance(der_cert, memoryview) else der_cert
    csr = load_certificate_request(FILETYPE_ASN1, der_bytes)
    text_dump = dump_certificate_request(FILETYPE_TEXT, csr)
    return text_dump.decode()

def DER_certificate_to_str(der_cert):
    """
    Render a DER-encoded certificate as text.

    :param der_cert: DER bytes of the certificate; a memoryview is converted
                     to bytes first, as DER_certificate_request_to_str does
    :return: the FILETYPE_TEXT dump of the certificate as a str
    """
    if isinstance(der_cert, memoryview):
        der_cert = der_cert.tobytes()
    cert = load_certificate(FILETYPE_ASN1, der_cert)
    return dump_certificate(FILETYPE_TEXT, cert).decode()

def DER_certificate_to_PEM(der_cert):
    """
    Convert a DER-encoded certificate to its PEM representation.

    :param der_cert: DER bytes of the certificate; a memoryview is converted
                     to bytes first, as DER_certificate_request_to_str does
    :return: the PEM-encoded certificate as a str
    """
    if isinstance(der_cert, memoryview):
        der_cert = der_cert.tobytes()
    cert = load_certificate(FILETYPE_ASN1, der_cert)
    return dump_certificate(FILETYPE_PEM, cert).decode()

def create_kerberos_san(realm, principal_name):
    """
    Encode the ASN.1 value used for the Kerberos OtherName SAN entry.

    The emitted structure is:

        SEQUENCE {
          [0] { <tag 27> realm }
          [1] { SEQUENCE {
                  [0] { 1 }
                  [1] { SEQUENCE { <tag 27> principal_name } }
          } }
        }

    NOTE(review): this looks like KRB5PrincipalName from RFC 4556 with
    name-type 1 and a single name-string component - confirm.

    :param realm: Kerberos realm (str), written UTF-8 encoded
    :param principal_name: principal name (str), written UTF-8 encoded as one
                           component
    :return: the encoded bytes produced by the asn1 encoder
    """
    # Universal tag number 27 is GeneralString in ASN.1
    general_string_tag = 27
    sequence = asn1.Numbers.Sequence
    context = asn1.Classes.Context

    enc = asn1.Encoder()
    enc.start()

    enc.enter(sequence)

    # [0] realm
    enc.enter(0, context)
    enc.write(realm.encode('utf-8'), general_string_tag)
    enc.leave()

    # [1] principal name
    enc.enter(1, context)
    enc.enter(sequence)

    # [0] name type, written without an explicit tag number
    enc.enter(0, context)
    enc.write(1)
    enc.leave()

    # [1] sequence holding the single name component
    enc.enter(1, context)
    enc.enter(sequence)
    enc.write(principal_name.encode('utf-8'), general_string_tag)
    enc.leave()
    enc.leave()

    # close the inner SEQUENCE and its [1] wrapper
    enc.leave()
    enc.leave()

    # close the outer SEQUENCE
    enc.leave()

    return enc.output()

def create_email_san(email):
    """
    Encode ``email`` as a single ASN.1 UTF8String.

    :param email: e-mail address to encode
    :return: the encoded bytes produced by the asn1 encoder
    """
    utf8_encoder = asn1.Encoder()
    utf8_encoder.start()
    utf8_encoder.write(email, asn1.Numbers.UTF8String)
    encoded_email = utf8_encoder.output()
    return encoded_email

def create_upn_san(upn):
    """
    Encode ``upn`` (a user principal name) as a single ASN.1 UTF8String.

    :param upn: user principal name to encode
    :return: the encoded bytes produced by the asn1 encoder
    """
    utf8_encoder = asn1.Encoder()
    utf8_encoder.start()
    utf8_encoder.write(upn, asn1.Numbers.UTF8String)
    encoded_upn = utf8_encoder.output()
    return encoded_upn

def create_certificate_name_san(certificate_name):
    """
    Encode ``certificate_name`` as a single ASN.1 PrintableString.

    NOTE(review): PrintableString only allows a restricted character set and
    nothing here validates ``certificate_name`` against it - confirm callers
    only pass compatible names.

    :param certificate_name: certificate name to encode
    :return: the encoded bytes produced by the asn1 encoder
    """
    printable_encoder = asn1.Encoder()
    printable_encoder.start()
    printable_encoder.write(certificate_name, asn1.Numbers.PrintableString)
    encoded_name = printable_encoder.output()
    return encoded_name

def _add_certificate_key_usage(cert_builder, key_usage: int):
    """
    Add a non-critical Key Usage extension to the certificate builder.
    :param cert_builder: x509.CertificateBuilder instance
    :param key_usage: int representing the key usage flags; 0 means no
                      extension is added
    :return: x509.CertificateBuilder instance, with the Key Usage extension
             added unless key_usage is 0
    """
    if key_usage == 0:
        logger.debug("No Key Usage defined, skipping")
        return cert_builder

    # x509.KeyUsage keyword -> flag name passed to has_key_usage()
    keyword_to_flag = {
        'digital_signature': 'digital_signature',
        'content_commitment': 'non_repudiation',
        'key_encipherment': 'key_encipherment',
        'data_encipherment': 'data_encipherment',
        'key_agreement': 'key_agreement',
        'key_cert_sign': 'key_cert_sign',
        'crl_sign': 'crl_sign',
        'encipher_only': 'encipher_only',
        'decipher_only': 'decipher_only',
    }
    usage_kwargs = {
        keyword: has_key_usage(key_usage, flag)
        for keyword, flag in keyword_to_flag.items()
    }
    key_usage_extension = x509.KeyUsage(**usage_kwargs)
    return cert_builder.add_extension(key_usage_extension, critical=False)

def _add_certificate_extended_key_usage(cert_builder, extended_key_usage: int, is_critical: bool):
    """
    Add an Extended Key Usage extension to the certificate builder.
    :param cert_builder: x509.CertificateBuilder instance
    :param extended_key_usage: int representing the extended key usage flags,
                               as understood by has_extended_key_usage()
    :param is_critical: whether the extension is marked critical
    :return: x509.CertificateBuilder instance; unchanged when
             extended_key_usage is 0 or selects no known usage
    """
    if int(extended_key_usage) == 0:
        logger.debug("No Extended Key Usage defined, skipping")
        return cert_builder

    eku_oid = x509.oid.ExtendedKeyUsageOID
    # Flag name -> OIDs added for it, in the order they are emitted
    flag_to_oids = (
        # The attribute is SERVER_AUTH; 'SERV_AUTH' does not exist and raised AttributeError
        ('server_auth', (eku_oid.SERVER_AUTH,)),
        ('client_auth', (eku_oid.CLIENT_AUTH,)),
        ('document_signing', (x509.ObjectIdentifier("1.3.6.1.5.5.7.3.36"),)),
        ('code_signing', (eku_oid.CODE_SIGNING,)),
        # Smartcard logon adds both the Kerberos and the Microsoft smart card logon OIDs
        ('smartcard_logon', (
            x509.ObjectIdentifier("1.3.6.1.5.2.3.4"),
            x509.ObjectIdentifier("1.3.6.1.4.1.311.20.2.2"),
        )),
        ('email_protection', (eku_oid.EMAIL_PROTECTION,)),
        ('time_stamping', (eku_oid.TIME_STAMPING,)),
        ('ocsp_signing', (eku_oid.OCSP_SIGNING,)),
    )

    extended_key_usage_list = []
    for flag_name, oids in flag_to_oids:
        if has_extended_key_usage(extended_key_usage, flag_name):
            extended_key_usage_list.extend(oids)

    if not extended_key_usage_list:
        # An Extended Key Usage extension must list at least one purpose
        logger.warning("Extended Key Usage %s does not select any known usage, skipping", extended_key_usage)
        return cert_builder

    return cert_builder.add_extension(x509.ExtendedKeyUsage(extended_key_usage_list), critical=is_critical)

def sign_certificate_request(ca_certificate_path: str, ca_key_path: str,
                             req: x509.CertificateSigningRequest, realm: str|None, principal_name: str|None,
                             key_usage: int, critical_extended_key_usage: int, non_critical_extended_key_usage: int,
                             certificates_id_name = None, email = None, crl_distribution_url = None,
                             dns_names: List[str]|None = None, validity_days: int = 365) -> bytes:
    """
    Issue a certificate for ``req`` signed by the CA read from disk.

    The subject and public key are taken from ``req``; the issuer is the
    subject of the CA certificate. The certificate gets a random serial
    number and is signed with SHA-256.

    :param ca_certificate_path: path of the PEM-encoded CA certificate
    :param ca_key_path: path of the PEM-encoded, unencrypted CA private key
    :param req: certificate signing request to issue a certificate for
    :param realm: Kerberos realm; must be given together with principal_name
    :param principal_name: Kerberos principal; when both are set, a Kerberos
                           OtherName and a Microsoft UPN OtherName
                           ("<principal_name>@<realm>") are added to the SAN
    :param key_usage: key usage flags (see _add_certificate_key_usage)
    :param critical_extended_key_usage: extended key usage flags added as a
                                        critical extension
    :param non_critical_extended_key_usage: extended key usage flags added as
                                            a non-critical extension
    :param certificates_id_name: optional KeypMe certificate name stored as an
                                 OtherName SAN entry
    :param email: optional e-mail address, added to the SAN both as
                  RFC822Name and as an OtherName
    :param crl_distribution_url: optional URL of the CRL distribution point
    :param dns_names: optional DNS names added to the SAN
    :param validity_days: number of days the certificate is valid from now
    :return: the signed certificate as DER bytes
    :raises RuntimeError: if a CA path is None, or if only one of realm /
                          principal_name is provided
    """
    if (ca_certificate_path is None) or (ca_key_path is None):
        raise RuntimeError("Please set environment variable 'KEYPME_SERVER_PKI_ROOT'")

    #
    # Load CA Certificate & Key
    #
    with open(ca_certificate_path, 'rb') as f:
        ca_cert = x509.load_pem_x509_certificate(f.read(), default_backend())
    logger.debug("Signing CA: %s", ca_cert.subject)

    with open(ca_key_path, 'rb') as f:
        # The CA key is loaded without a passphrase
        ca_private_key = serialization.load_pem_private_key(f.read(), password=None)

    # NOTE(review): the signature of `req` is not verified here - confirm that
    # callers check it before requesting a certificate.

    san_extensions = []

    if realm and principal_name:
        kerberos_oid = x509.ObjectIdentifier("1.3.6.1.5.2.2")
        ms_upn_oid = x509.ObjectIdentifier("1.3.6.1.4.1.311.20.2.3")

        kerberos_san = x509.OtherName(kerberos_oid, create_kerberos_san(realm, principal_name))

        upn = f"{principal_name}@{realm}"
        ms_upn_san = x509.OtherName(ms_upn_oid, create_upn_san(upn))

        san_extensions += [kerberos_san, ms_upn_san]
    elif realm is not None or principal_name is not None:
        raise RuntimeError(f"A realm and principal name must be provided (got '{realm}' and '{principal_name}') - or none must be provided")

    if email:
        # Add email address to SAN through RFC822Name
        san_extensions.append(x509.RFC822Name(email))

        # also add email address to SAN through otherName
        other_email_san = x509.OtherName(NameOID.EMAIL_ADDRESS, create_email_san(email.encode('utf-8')))
        san_extensions.append(other_email_san)

    if certificates_id_name:
        certificate_name_san = x509.OtherName(KEYPME_CERTIFICATE_NAME_OID, create_certificate_name_san(certificates_id_name))
        san_extensions.append(certificate_name_san)

    if dns_names:
        # Add DNS names to SAN
        san_extensions.extend(x509.DNSName(dns_name) for dns_name in dns_names)

    # One timezone-aware timestamp so both validity bounds derive from the same
    # instant (datetime.utcnow() is naive and deprecated since Python 3.12)
    now = datetime.now(timezone.utc)

    #TODO: Add work email in extension
    cert_builder = x509.CertificateBuilder().subject_name(
        req.subject
    ).issuer_name(
        ca_cert.subject
    ).public_key(
        req.public_key()
    ).serial_number(
        x509.random_serial_number() # TODO: Load serial number from PKI or from DB
    ).not_valid_before(
        now
    ).not_valid_after(
        now + timedelta(days=validity_days)
    )

    if san_extensions:
        # A Subject Alternative Name extension must contain at least one
        # entry, so it is only added when there is something to put in it
        cert_builder = cert_builder.add_extension(
            x509.SubjectAlternativeName(san_extensions),
            critical=False,
        )

    # Add Key Usage
    cert_builder = _add_certificate_key_usage(cert_builder, key_usage)
    # NOTE(review): when both extended key usage values are non-zero the
    # extension is added twice, which the builder rejects as a duplicate -
    # confirm callers only ever set one of them.
    cert_builder = _add_certificate_extended_key_usage(cert_builder, critical_extended_key_usage, True)
    cert_builder = _add_certificate_extended_key_usage(cert_builder, non_critical_extended_key_usage, False)

    if crl_distribution_url:
        cert_builder = cert_builder.add_extension(
            x509.CRLDistributionPoints([
                x509.DistributionPoint(
                    full_name=[
                        x509.UniformResourceIdentifier(
                            crl_distribution_url
                        )
                    ],
                    relative_name=None,
                    reasons=None,  # Note: Windows does not support reasons in crlDistributionUrl
                    crl_issuer=None,
                )
            ]),
            critical=False,
        )

    # Sign our certificate with the CA private key
    cert = cert_builder.sign(ca_private_key, hashes.SHA256())
    return cert.public_bytes(serialization.Encoding.DER)

def certificate_get_name(der_cert: bytes) -> str|None:
    """
    Read the KeypMe certificate name stored in a certificate's SAN.

    The name is the decoded value of the first OtherName SAN entry whose
    type-id is KEYPME_CERTIFICATE_NAME_OID.

    :param der_cert: DER bytes of the certificate; a memoryview is converted
                     to bytes first, as load_DER_certificate_request does
    :return: the decoded name, or None when the certificate has no SAN
             extension or no matching OtherName entry
    """
    if isinstance(der_cert, memoryview):
        der_cert = der_cert.tobytes()
    certificate = x509.load_der_x509_certificate(der_cert)

    try:
        san_extension = certificate.extensions.get_extension_for_class(x509.SubjectAlternativeName)
    except x509.extensions.ExtensionNotFound:
        logger.warning("Certificate does not have SAN extension (1)")
        return None

    # The extension value is always set once the extension is found, so no
    # further None check is needed
    other_names = san_extension.value.get_values_for_type(x509.OtherName)
    for other_name in other_names:
        if other_name.type_id == KEYPME_CERTIFICATE_NAME_OID:
            decoder = asn1.Decoder()
            decoder.start(other_name.value)
            # read() yields (tag, value); only the decoded value is needed
            _, value = decoder.read()

            return value

    return None
