#!/usr/bin/env python3

# PIP requirements for 'pySCEP'
#   - asn1crypto
#   - certbuilder
#   - cryptography
#   - csrbuilder
#   - oscrypto
#   - requests
#   - six

import argparse
import io
import logging

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import AttributeOID, NameOID, ExtendedKeyUsageOID

from oscrypto import keys

import requests

from scep import Client

# Get an instance of a logger
logger = logging.getLogger(__name__)


class ScepException(Exception):
    """Base class for every error raised by the SCEP helpers in this module."""


class ScepCertificateNotFound(ScepException):
    """No issued certificate could be extracted from an enrolment response."""


class ScepPasswordNotValid(ScepException):
    """The PKCS#12 identifier could not be parsed (typically a wrong password)."""


class ScepServerUnreachable(ScepException):
    """The SCEP server could not be contacted or did not answer with HTTP 200."""

def _test_generate_pkcs10_with_password_challenge(cn, challenge_password, private_key=None):
    """Build a signed PKCS#10 CSR that carries a SCEP challenge password.

    The CSR asks for an end-entity certificate:
      * subject: ``CN=<cn>``
      * challengePassword attribute: *challenge_password*, UTF-8 encoded
      * BasicConstraints CA=False (critical)
      * ExtendedKeyUsage serverAuth + clientAuth (non-critical)
      * KeyUsage digitalSignature + keyEncipherment (critical)

    :param cn: Common Name placed in the CSR subject.
    :param challenge_password: challenge password as ``str``.
    :param private_key: optional RSA or EC private key used to sign the CSR.
        When omitted, a fresh 2048-bit RSA key is generated. That key is not
        returned, so the issued certificate cannot be used with it. Pass your
        own key if you need the certificate afterwards.
    :returns: ``cryptography.x509.CertificateSigningRequest`` signed with SHA-256.
    """
    if private_key is None:
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=default_backend()
        )

    csr_builder = (
        x509.CertificateSigningRequestBuilder()
        .subject_name(x509.Name([
            x509.NameAttribute(NameOID.COMMON_NAME, cn),
        ]))
        # SCEP servers authorise the enrolment request with this attribute
        .add_attribute(AttributeOID.CHALLENGE_PASSWORD, challenge_password.encode('utf-8'))
        .add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True
        )
        .add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CLIENT_AUTH
            ]),
            critical=False
        )
        .add_extension(
            x509.KeyUsage(
                digital_signature=True,
                key_encipherment=True,
                content_commitment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=False,
                crl_sign=False,
                encipher_only=False,
                decipher_only=False
            ),
            critical=True
        )
    )

    # Self-sign the request: proves possession of the private key
    return csr_builder.sign(
        private_key=private_key,
        algorithm=hashes.SHA256(),
        backend=default_backend()
    )

def scep_client_test(scep_server_url, identifier, p12_file=None, p12_password=None, timeout=2):
    """Check that a SCEP server answers and, optionally, that a PKCS#12 identifier opens.

    :param scep_server_url: SCEP endpoint URL queried with ``operation=GetCACaps``.
    :param identifier: not used by this function; kept for interface compatibility.
    :param p12_file: optional PKCS#12 content as ``bytes``, ``bytearray``,
        ``memoryview`` or a binary file object opened for reading.
    :param p12_password: password for *p12_file* (``str`` or ``bytes``).
    :param timeout: timeout in seconds for the GetCACaps HTTP request.
    :returns: list of capability strings advertised by the server. Blank lines
        and surrounding whitespace (including ``\\r`` from CRLF) are removed.
    :raises ScepServerUnreachable: on network error, timeout or a non-200 answer.
    :raises ScepPasswordNotValid: when *p12_file* cannot be parsed with *p12_password*.
    """
    try:
        # Do not use pySCEP to get capability to use a shorter timeout
        res = requests.get(scep_server_url, params={'operation': 'GetCACaps', 'message': ''}, timeout=timeout)
    except requests.exceptions.Timeout as exc:
        # Timeout covers both ConnectTimeout and ReadTimeout
        logger.error("scep_client_test:server:timeout")
        raise ScepServerUnreachable() from exc
    except Exception as exc:  # pylint: disable=broad-except
        logger.error("scep_client_test:server:%s:%s", type(exc), str(exc))
        raise ScepServerUnreachable() from exc

    # Checked outside the try so our own exception is not caught and re-wrapped
    if res.status_code != 200:
        logger.error("scep_client_test:server:status:%s", res.status_code)
        raise ScepServerUnreachable()

    # Servers may answer with LF or CRLF line endings and a trailing newline
    capabilities = [line.strip() for line in res.text.splitlines() if line.strip()]
    logger.debug("SCEP server capabilities for '%s': %s", scep_server_url, capabilities)

    # Confirm the password is valid
    if p12_file:
        try:
            if isinstance(p12_password, str):
                p12_password = p12_password.encode('utf-8')
            if hasattr(p12_file, 'read'):
                p12_file = p12_file.read()
            if isinstance(p12_file, (memoryview, bytearray)):
                # parse_pkcs12 requires an immutable byte string
                p12_file = bytes(p12_file)

            keys.parse_pkcs12(p12_file, password=p12_password)
        except Exception as exc:  # pylint: disable=broad-except
            logger.error("scep_client_test:password:%s:%s", type(exc), str(exc))
            raise ScepPasswordNotValid() from exc

    return capabilities

def scep_client_get(scep_server_url, p12_file=None, p12_password=None):
    """Create a pySCEP client and the identity used to sign SCEP requests.

    :param scep_server_url: SCEP endpoint URL.
    :param p12_file: optional PKCS#12 identifier as ``bytes``, ``bytearray``,
        ``memoryview`` or a binary file object opened for reading. When
        omitted, a self-signed identity with CN ``PyScep-test`` is generated.
    :param p12_password: password for *p12_file* (``str`` or ``bytes``).
    :returns: tuple ``(client, identity, identity_private_key)``.
    """
    client = Client.Client(scep_server_url)

    if not p12_file:
        # Generate a Self Signed Certificate
        identity, identity_private_key = Client.SigningRequest.generate_self_signed(cn='PyScep-test',
            key_usage={ 'digital_signature', 'key_encipherment' })
        return client, identity, identity_private_key

    if isinstance(p12_password, str):
        p12_password = p12_password.encode('utf-8')

    # Accept the same input forms as scep_client_test
    if hasattr(p12_file, 'read'):
        p12_file = p12_file.read()
    if isinstance(p12_file, (memoryview, bytearray)):
        # parse_pkcs12 requires an immutable byte string
        p12_file = bytes(p12_file)

    private_key_info, certificate, _ = keys.parse_pkcs12(p12_file, password=p12_password)
    identity = Client.Certificate(certificate=certificate)
    identity_private_key = Client.PrivateKey(private_key=private_key_info)

    return client, identity, identity_private_key

def scep_client_enrol(client, identity, identity_private_key, ca_identifier, csr_der):
    """Submit a DER-encoded PKCS#10 CSR to a SCEP server and return the issued certificate.

    :param client: pySCEP ``Client.Client`` instance (see ``scep_client_get``).
    :param identity: pySCEP certificate used to sign the SCEP request.
    :param identity_private_key: pySCEP private key matching *identity*.
    :param ca_identifier: CA identifier forwarded to ``client.enrol``.
    :param csr_der: DER-encoded PKCS#10 certificate signing request.
    :returns: DER bytes of the issued certificate.
    :raises ScepException: if the server answers with FAILURE or PENDING status.
    :raises ScepCertificateNotFound: if no certificate in the response
        certifies the CSR's public key.
    """
    # Convert DER csr to pySCEP csr
    csr = Client.SigningRequest(csr_der)

    res = client.enrol(
        csr=csr,
        identity=identity,
        identity_private_key=identity_private_key,
        identifier=ca_identifier
    )

    if res.status == Client.PKIStatus.FAILURE:
        raise ScepException(f"FAILURE:{res.fail_info}")
    if res.status == Client.PKIStatus.PENDING:
        raise ScepException(f"PENDING:{res.transaction_id}")

    try:
        return res.certificate.to_der()
    except Exception as exc:  # pylint: disable=broad-except
        # NOTE(review): exactly which error pySCEP raises when the response has
        # no single ``certificate`` is not visible here; kept broad as before.
        # Fall back to scanning ``res.certificates`` below.
        cause = exc

    # The issued certificate certifies the CSR's public key, so match on the
    # SubjectPublicKeyInfo. Checking BasicConstraints is not reliable: that
    # extension is optional for end-entity certificates (RFC 5280 4.2.1.9).
    spki_format = (serialization.Encoding.DER, serialization.PublicFormat.SubjectPublicKeyInfo)
    requested_spki = x509.load_der_x509_csr(csr_der, default_backend()).public_key().public_bytes(*spki_format)

    for candidate in res.certificates or ():
        candidate_der = candidate.to_der()
        crypto_certificate = x509.load_der_x509_certificate(candidate_der, default_backend())
        if crypto_certificate.public_key().public_bytes(*spki_format) == requested_spki:
            return candidate_der

    logger.error("No certificate in the SCEP response matches the CSR public key.")
    raise ScepCertificateNotFound() from cause

if __name__ == '__main__':
    # Command-line entry point: enrol one certificate and print it as PEM.
    parser = argparse.ArgumentParser(prog='scep-client', description='Enroll certificate to SCEP server')
    parser.add_argument('--url', default='http://127.0.0.1/ejbca/publicweb/apply/scep/enduser/pkiclient.exe', help='SCEP server')
    parser.add_argument('--ca-identifier', default='MyFirstRootCA', help='CA identifier')
    parser.add_argument('--cn', required=True, help='CSR Common Name')
    parser.add_argument('--challenge-password', required=True, help='CSR Challenge Password')
    parser.add_argument('--scep-identifier', help='Path to PKCS#12 SCEP identifier file')
    parser.add_argument('--scep-identifier-pass', help='PKCS#12 SCEP identifier password')
    args = parser.parse_args()

    # --scep-identifier is a file path, but the SCEP helpers expect the raw
    # PKCS#12 bytes, not the file name.
    scep_identifier_p12 = None
    if args.scep_identifier:
        with open(args.scep_identifier, 'rb') as p12_handle:
            scep_identifier_p12 = p12_handle.read()

    test_client, test_identity, test_identity_private_key = scep_client_get(
        args.url, scep_identifier_p12, args.scep_identifier_pass)

    # NOTE: the CSR key pair is generated internally and not kept, so the
    # printed certificate is only useful as proof that enrolment works.
    test_csr = _test_generate_pkcs10_with_password_challenge(args.cn, args.challenge_password)
    test_csr_der = test_csr.public_bytes(serialization.Encoding.DER)

    test_der_certificate = scep_client_enrol(test_client, test_identity, test_identity_private_key, args.ca_identifier, test_csr_der)

    # Print PEM rather than a Python bytes repr so the output can be saved or piped to openssl
    test_pem_certificate = x509.load_der_x509_certificate(
        test_der_certificate, default_backend()).public_bytes(serialization.Encoding.PEM)
    print(test_pem_certificate.decode('ascii'), end='')
