package com.turbospaces.common;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.PKIXParameters;
import java.security.cert.TrustAnchor;
import java.security.cert.X509Certificate;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.ImmutableList;
import com.google.common.net.HostAndPort;
import com.turbospaces.ups.PlainServiceInfo;
import com.turbospaces.ups.PlainServiceInfo;

/**
 * Accumulates {@link TrustManager}s (from key store files, certificate files, explicit certificates, or certificate
 * chains fetched live from remote endpoints) and turns them into an {@link SSLContext} or {@link SSLSocketFactory}.
 * <p>
 * The collected trust managers live in an unsynchronized {@link LinkedHashSet}; configure an instance from a single
 * thread before building.
 * <p>
 * NOTE(review): per the {@link SSLContext#init} contract, JSSE only uses the first trust manager of a given type in
 * the supplied array. When several {@link X509TrustManager}s have been registered (e.g. {@link #loadKeystore} followed
 * by {@link #addUntrustedCertificates(Collection)}), {@link #build()} silently ignores all but the first and
 * {@link #buildSSLFactory()} rejects the configuration. Prefer registering all trust material in one call.
 */
public class SSL {
    private static final Logger LOGGER = LoggerFactory.getLogger( SSL.class );
    /** Port used when a service or host does not specify one. */
    private static final int DEFAULT_HTTPS_PORT = 443;
    /** Read timeout while probing a remote certificate chain, so a silent peer cannot block the caller forever. */
    private static final int HANDSHAKE_TIMEOUT_MILLIS = 30_000;

    /** Trust managers collected so far, in registration order (order matters to JSSE, see class docs). */
    protected final Set<TrustManager> trustManagers = new LinkedHashSet<TrustManager>();

    /**
     * Loads a key store from {@code file} and registers the first trust manager derived from it.
     *
     * @param file key store file in the platform default key store format
     * @param password key store password
     * @throws Exception if the file cannot be read or parsed, or the trust manager factory cannot be initialized
     */
    public void loadKeystore(File file, String password) throws Exception {
        KeyStore keyStore = SSL.loadTrustStore( file, password );
        TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm() );
        trustManagerFactory.init( keyStore );
        X509TrustManager trustManager = (X509TrustManager) trustManagerFactory.getTrustManagers()[0];
        trustManagers.add( trustManager );
    }

    /**
     * Fetches and trusts the certificate chains presented by the given services.
     * A service without a positive port is probed on port 443.
     *
     * @param infos services to probe
     * @return this instance, for chaining
     */
    public SSL addUntrustedCertificates(PlainServiceInfo... infos) {
        List<HostAndPort> targets = Arrays.stream( infos )
                .map( info -> HostAndPort.fromParts( info.getHost(), info.getPort() > 0 ? info.getPort() : DEFAULT_HTTPS_PORT ) )
                .collect( Collectors.toList() );
        return addUntrustedCertificates( targets );
    }

    /**
     * Fetches and trusts the certificate chains presented by the given endpoints.
     *
     * @param hostAndPorts endpoints to probe; a missing port defaults to 443
     * @return this instance, for chaining
     */
    public SSL addUntrustedCertificates(HostAndPort... hostAndPorts) {
        return addUntrustedCertificates( ImmutableList.copyOf( hostAndPorts ) );
    }

    /**
     * Connects to every endpoint, records the certificates each one presents (without validating them) and registers
     * a single trust manager covering all of them.
     *
     * @param hostAndPorts endpoints to probe; a missing port defaults to 443
     * @return this instance, for chaining
     */
    public SSL addUntrustedCertificates(Collection<HostAndPort> hostAndPorts) {
        try {
            // LinkedHashSet: de-duplicates certificates shared by several endpoints while keeping a stable alias order
            Set<X509Certificate> untrustedCertificates = new LinkedHashSet<>();
            for ( HostAndPort hap : hostAndPorts ) {
                Collections.addAll( untrustedCertificates, collectChain( hap ) );
            }
            addUntrustedCertificatesFromChain( untrustedCertificates );
        }
        catch ( Exception err ) {
            ExceptionUtils.wrapAndThrow( err );
        }
        return this;
    }

    /**
     * Builds a TLS {@link SSLContext} initialized with all collected trust managers (see the class-level note about
     * JSSE only honoring the first one).
     *
     * @return a freshly initialized TLS context
     */
    public SSLContext build() {
        // the loop only satisfies the compiler: wrapAndThrow always throws but is not declared to
        for ( ;; ) {
            try {
                return newTlsContext();
            }
            catch ( Exception err ) {
                ExceptionUtils.wrapAndThrow( err );
            }
        }
    }

    /**
     * Builds a TLS socket factory together with the single trust manager backing it. Both are typically handed
     * together to HTTP clients that need the pair.
     *
     * @return the socket factory and its trust manager
     * @throws IllegalStateException if no trust manager, or more than one, has been registered
     */
    public Map.Entry<SSLSocketFactory, X509TrustManager> buildSSLFactory() {
        for ( ;; ) {
            try {
                if ( trustManagers.isEmpty() ) {
                    throw new IllegalStateException( "No trust managers configured; load a keystore or add certificates first" );
                }
                if ( trustManagers.size() > 1 ) {
                    throw new IllegalStateException( "Unexpected default trust managers:" + Arrays.toString( trustManagers.toArray() ) );
                }
                X509TrustManager trustManager = (X509TrustManager) trustManagers.iterator().next();
                SSLContext sslcontext = newTlsContext();
                return new AbstractMap.SimpleEntry<SSLSocketFactory, X509TrustManager>( sslcontext.getSocketFactory(), trustManager );
            }
            catch ( Exception err ) {
                ExceptionUtils.wrapAndThrow( err );
            }
        }
    }

    /**
     * Registers the trust managers produced by a temporary in-memory key store containing {@code chain}.
     * Does nothing for an empty collection.
     *
     * @param chain certificates to trust; stored under aliases "0", "1", ...
     * @throws Exception if the key store or trust manager factory cannot be created or initialized
     */
    public void addUntrustedCertificatesFromChain(Collection<X509Certificate> chain) throws Exception {
        if ( chain.isEmpty() ) {
            return;
        }
        KeyStore ks = KeyStore.getInstance( KeyStore.getDefaultType() );
        ks.load( null, null ); // initializes an empty, in-memory key store

        int count = 0;
        for ( X509Certificate cert : chain ) {
            LOGGER.trace( "adding trusted material {}", cert );
            ks.setCertificateEntry( String.valueOf( count++ ), cert );
        }

        TrustManagerFactory tmfactory = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm() );
        tmfactory.init( ks );
        Collections.addAll( trustManagers, tmfactory.getTrustManagers() );
    }

    /**
     * Loads each file as a key store in the platform default format (without a password) and registers the trust
     * managers derived from it.
     *
     * @param files key store files
     * @throws Exception if a file cannot be read or parsed
     */
    public void addUntrustedCertificatesFromFiles(Collection<File> files) throws Exception {
        TrustManagerFactory tmfactory = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm() );

        for ( File f : files ) {
            LOGGER.trace( "adding trusted material from {}", f );
            // key stores are binary; feed the raw bytes through untouched (a String round-trip corrupts them)
            byte[] payload = Files.readAllBytes( f.toPath() );
            LOGGER.trace( "read {} bytes from {}", payload.length, f );

            KeyStore ks = KeyStore.getInstance( KeyStore.getDefaultType() );
            ks.load( new ByteArrayInputStream( payload ), null );

            tmfactory.init( ks );
            Collections.addAll( trustManagers, tmfactory.getTrustManagers() );
        }
    }

    /**
     * Writes {@code trustStore} to a new temporary file that is deleted when the JVM exits.
     *
     * @param trustStore store to persist
     * @param password password protecting the written store
     * @return the temporary file
     * @throws Exception if the file cannot be created or the store cannot be written
     */
    public static File dumpTrustStore(KeyStore trustStore, String password) throws Exception {
        File f = File.createTempFile( "truststore", null );
        f.deleteOnExit(); // registered before writing so a failed store() does not leave the file behind
        try (FileOutputStream out = new FileOutputStream( f )) {
            trustStore.store( out, password.toCharArray() );
        }
        return f;
    }

    /**
     * Re-initializes {@code trustStore} as empty, copies in the CA certificates accepted by the JVM's default trust
     * manager (aliases "0", "1", ...), then adds {@code chain} (aliases "alias-0", "alias-1", ...).
     *
     * @param trustStore store to (re)populate; any previous contents are discarded
     * @param chain additional certificates to trust
     * @throws Exception if the default trust manager or the store cannot be initialized
     */
    public static void addCertificates(KeyStore trustStore, Collection<Certificate> chain) throws Exception {
        TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm() );
        trustManagerFactory.init( (KeyStore) null ); // null selects the platform default trust store
        X509TrustManager defaultTrustManager = (X509TrustManager) trustManagerFactory.getTrustManagers()[0];
        X509Certificate[] cacerts = defaultTrustManager.getAcceptedIssuers();

        trustStore.load( null );
        int counter = 0;
        for ( X509Certificate cert : cacerts ) {
            trustStore.setCertificateEntry( String.valueOf( counter++ ), cert );
        }
        counter = 0;
        for ( Certificate cert : chain ) {
            String alias = "alias-" + counter++;
            LOGGER.debug( "adding additional cert={} alias = {}", cert, alias );
            trustStore.setCertificateEntry( alias, cert );
        }
    }

    /**
     * Loads a key store in the platform default format from {@code file}. When trace logging is enabled, every
     * trusted-certificate entry is logged.
     * <p>
     * Key stores without trusted-certificate entries load normally; they are no longer rejected by the
     * diagnostic logging.
     *
     * @param file key store file
     * @param password key store password
     * @return the loaded key store
     * @throws Exception if the file cannot be read or parsed
     */
    public static KeyStore loadTrustStore(File file, String password) throws Exception {
        KeyStore keystore = KeyStore.getInstance( KeyStore.getDefaultType() );
        try (FileInputStream is = new FileInputStream( file )) {
            keystore.load( is, password.toCharArray() );
        }
        if ( LOGGER.isTraceEnabled() ) {
            for ( String alias : Collections.list( keystore.aliases() ) ) {
                if ( keystore.isCertificateEntry( alias ) ) {
                    Certificate cert = keystore.getCertificate( alias );
                    if ( cert instanceof X509Certificate ) {
                        LOGGER.trace( cert.toString() );
                    }
                }
            }
        }
        return keystore;
    }

    /**
     * Performs a TLS handshake with {@code hostAndPort} and returns every certificate the server presented.
     * Trust failures are logged at trace level and otherwise ignored, so untrusted (e.g. self-signed) chains are
     * returned too.
     *
     * @param hostAndPort endpoint to probe; a missing port defaults to 443
     * @return the presented certificates, in the order received
     * @throws RuntimeException wrapping the cause if connecting or the handshake fails or times out
     * @throws Exception if the default trust manager or TLS context cannot be initialized
     */
    public static X509Certificate[] collectChain(HostAndPort hostAndPort) throws Exception {
        LOGGER.debug( "collecting certificate chain to {}", hostAndPort );

        TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm() );
        trustManagerFactory.init( (KeyStore) null ); // null selects the platform default trust store
        X509TrustManager defaultTrustManager = (X509TrustManager) trustManagerFactory.getTrustManagers()[0];

        List<X509Certificate> chain = new ArrayList<>();
        // records whatever chain is presented, then defers to the default trust manager purely for diagnostics;
        // validation failures are swallowed on purpose so the handshake completes for untrusted peers
        X509TrustManager collector = new X509TrustManager() {
            @Override
            public synchronized X509Certificate[] getAcceptedIssuers() {
                return defaultTrustManager.getAcceptedIssuers();
            }
            @Override
            public synchronized void checkClientTrusted(X509Certificate[] certChain, String authType) throws CertificateException {
                try {
                    Collections.addAll( chain, certChain );
                    defaultTrustManager.checkClientTrusted( certChain, authType );
                }
                catch ( CertificateException err ) {
                    LOGGER.trace( err.getMessage(), err );
                }
            }
            @Override
            public synchronized void checkServerTrusted(X509Certificate[] certChain, String authType) throws CertificateException {
                try {
                    Collections.addAll( chain, certChain );
                    defaultTrustManager.checkServerTrusted( certChain, authType );
                }
                catch ( CertificateException err ) {
                    LOGGER.trace( err.getMessage(), err );
                }
            }
        };
        SSLContext context = SSLContext.getInstance( "TLS" );
        context.init( null, new TrustManager[] { collector }, null );
        SSLSocketFactory factory = context.getSocketFactory();

        try (SSLSocket socket = (SSLSocket) factory.createSocket( hostAndPort.getHost(), hostAndPort.getPortOrDefault( DEFAULT_HTTPS_PORT ) )) {
            // bound the handshake: without a read timeout a peer that never answers blocks this call indefinitely
            socket.setSoTimeout( HANDSHAKE_TIMEOUT_MILLIS );
            socket.startHandshake();
        }
        catch ( Exception e ) {
            throw new RuntimeException( "unable to complete TLS handshake with " + hostAndPort, e );
        }

        return chain.toArray( new X509Certificate[chain.size()] );
    }

    /** Creates a TLS context with no explicit key managers and the currently collected trust managers. */
    private SSLContext newTlsContext() throws Exception {
        SSLContext sslcontext = SSLContext.getInstance( "TLS" );
        sslcontext.init( null, trustManagers.toArray( new TrustManager[trustManagers.size()] ), null );
        return sslcontext;
    }
}
