// Copyright 2013 Michel Kraemer
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pw.prok.download;

import groovy.lang.Closure;
import org.apache.http.*;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.AuthCache;
import org.apache.http.client.ClientProtocolException;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.client.utils.DateUtils;
import org.apache.http.config.Registry;
import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.HttpClientConnectionManager;
import org.apache.http.conn.socket.ConnectionSocketFactory;
import org.apache.http.conn.socket.PlainConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.impl.auth.BasicScheme;
import org.apache.http.impl.client.BasicAuthCache;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.apache.http.protocol.HttpDateGenerator;
import org.apache.http.util.EntityUtils;
import org.gradle.api.GradleException;
import org.gradle.api.Project;
import org.gradle.internal.logging.progress.ProgressLogger;
import org.gradle.internal.logging.progress.ProgressLoggerFactory;
import pw.prok.download.internal.InsecureHostnameVerifier;
import pw.prok.download.internal.InsecureTrustManager;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import java.io.*;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.zip.GZIPInputStream;

/**
 * Downloads a file and displays progress
 *
 * @author Michel Kraemer
 */
public class DownloadAction implements DownloadSpec {
    // Hostname verifier and trust managers handed to the SSL socket factory
    // that is only built when acceptAnyCertificate is enabled
    // (see getInsecureSSLSocketFactory).
    // NOTE(review): presumably these accept every host/certificate - confirm
    // against the classes in pw.prok.download.internal.
    private static final HostnameVerifier INSECURE_HOSTNAME_VERIFIER = new InsecureHostnameVerifier();
    private static final TrustManager[] INSECURE_TRUST_MANAGERS = {new InsecureTrustManager()};

    // the Gradle project this action belongs to (logger, build dir, offline flag)
    private final Project project;
    // all URLs registered through src(Object)
    private List<URL> sources = new ArrayList<URL>(1);
    // destination file or directory set through dest(Object)
    private File dest;
    // when true, informational log output and the progress logger are suppressed
    private boolean quiet = false;
    // when false, sources whose destination file already exists are not downloaded
    private boolean overwrite = true;
    // when true, the destination's timestamp is sent as If-Modified-Since
    private boolean onlyIfNewer = false;
    // when true, "Accept-Encoding: gzip" is sent with every request
    private boolean compress = true;
    // credentials for preemptive authentication; used only if both are non-null
    private String username;
    private String password;
    // additional request headers; null until the first header is configured
    private Map<String, String> headers;
    // when true, the HTTP client is built with the insecure SSL socket factory
    private boolean acceptAnyCertificate = false;

    // number of sources that were found to be up to date during execution
    private AtomicInteger upToDate = new AtomicInteger(0);
    // number of sources skipped because Gradle runs in offline mode
    private AtomicInteger skipped = new AtomicInteger(0);

    // lazily created by getInsecureSSLSocketFactory()
    private SSLConnectionSocketFactory insecureSSLSocketFactory = null;

    /**
     * Creates a new download action
     *
     * @param project the project to be built
     */
    public DownloadAction(Project project) {
        this.project = project;
    }

    /**
     * Starts downloading all configured sources and blocks until every
     * submitted download has finished. Downloads run on a fixed thread pool
     * of {@code maxDownloadThreads} threads that share one HTTP client.
     *
     * @throws IOException              if a file could not be downloaded, if
     *                                  waiting for a download was interrupted, or if the HTTP client could
     *                                  not be closed
     * @throws IllegalArgumentException if no source or destination has been
     *                                  configured, or if multiple sources are configured and the destination
     *                                  exists but is not a directory
     */
    public void execute() throws IOException {
        if (sources.isEmpty()) {
            throw new IllegalArgumentException("Please provide a download source");
        }
        if (dest == null) {
            throw new IllegalArgumentException("Please provide a download destination");
        }

        if (dest.equals(project.getBuildDir())) {
            //make sure build dir exists
            dest.mkdirs();
        }

        if (sources.size() > 1 && !dest.isDirectory()) {
            if (!dest.exists()) {
                // create directory automatically
                dest.mkdirs();
            } else {
                throw new IllegalArgumentException("If multiple sources are provided "
                        + "the destination has to be a directory.");
            }
        }

        CloseableHttpClient client = createHttpClient();
        try {
            ExecutorService executors = Executors.newFixedThreadPool(maxDownloadThreads);
            try {
                List<Future<?>> futures = new ArrayList<Future<?>>();
                for (URL src : sources) {
                    Future<?> future = execute(client, executors, src);
                    if (future != null) futures.add(future);
                }

                for (Future<?> future : futures) {
                    try {
                        future.get();
                    } catch (InterruptedException e) {
                        //keep the interrupt visible to the caller's thread
                        Thread.currentThread().interrupt();
                        throw new IOException(e);
                    } catch (Exception e) {
                        throw new IOException(e);
                    }
                }
            } finally {
                executors.shutdown();
            }
        } finally {
            //release the pooled connections held by the client; without this
            //the connection manager stays open after every execution
            client.close();
        }
    }

    /**
     * Decides whether a single source has to be downloaded and, if so,
     * submits a download task to the given executor.
     *
     * @param client    the HTTP client shared by all downloads
     * @param executors the executor the download task is submitted to
     * @param src       the URL to download
     * @return the future of the submitted download, or <code>null</code> if
     * the download was skipped (offline mode, existing file with overwrite
     * disabled, or file still within the cache period)
     * @throws IOException           declared for callers; not thrown by the checks in this method itself
     * @throws IllegalStateException if Gradle runs in offline mode and the
     *                               destination file does not exist
     */
    private Future<?> execute(CloseableHttpClient client, ExecutorService executors, URL src) throws IOException {
        final File destFile = destFile(src);
        // in case offline mode is enabled don't try to download if
        // destination already exists
        if (project.getGradle().getStartParameter().isOffline()) {
            if (destFile.exists()) {
                if (!quiet) {
                    project.getLogger().info("Skipping existing file '" +
                            destFile.getName() + "' in offline mode.");
                }
                skipped.incrementAndGet();
                return null;
            }
            throw new IllegalStateException("Unable to download " + src +
                    " in offline mode.");
        }

        if (!overwrite && destFile.exists()) {
            if (!quiet) {
                project.getLogger().info("Destination file already exists. "
                        + "Skipping '" + destFile.getName() + "'");
            }
            upToDate.incrementAndGet();
            return null;
        }

        // Skip only if the file exists and is younger than the cache period.
        // A missing file reports lastModified() == 0 and must always be
        // downloaded; a file older than the cache period has expired.
        if (cache > 0 && destFile.exists()
                && System.currentTimeMillis() - destFile.lastModified() < cache) {
            if (!quiet) {
                project.getLogger().info("File '" + destFile.getName() + "' not expired cache period.");
            }
            upToDate.incrementAndGet();
            return null;
        }

        return executors.submit(new Downloader(client, src, destFile));
    }

    /**
     * Resolves the file a source is saved to. If the destination is a
     * directory the file name is taken from the part of the URL string after
     * its last slash (a single trailing slash is ignored); otherwise the
     * destination itself is used and its parent directory is created.
     *
     * @param src the source URL
     * @return the file to save the source to
     */
    private File destFile(URL src) {
        final File target = dest;
        if (!target.isDirectory()) {
            //destination is the file itself: make sure its directory exists
            File parentDir = target.getParentFile();
            if (parentDir != null) {
                parentDir.mkdirs();
            }
            return target;
        }

        //destination is a directory: guess the file name from the URL
        String url = src.toString();
        String trimmed = url.endsWith("/") ? url.substring(0, url.length() - 1) : url;
        String fileName = trimmed.substring(trimmed.lastIndexOf('/') + 1);
        return new File(dest, fileName);
    }

    /**
     * Configure proxy for a given HTTP host. The proxy is read from the
     * system properties <code>&lt;scheme&gt;.proxyHost</code> and
     * <code>&lt;scheme&gt;.proxyPort</code>, where the scheme is the one of the
     * given host (only http, https and ftp are considered).
     *
     * @param httpHost the HTTP host
     * @return the proxy or <code>null</code> if no proxy is necessary
     * @throws IOException              if the proxy could not be configured
     * @throws IllegalArgumentException if the configured proxy port is not a
     *                                  valid integer
     */
    private HttpHost configureProxy(HttpHost httpHost) throws IOException {
        HttpHost proxy = null;

        String scheme = httpHost.getSchemeName();
        if (!"http".equals(scheme) && !"https".equals(scheme) && !"ftp".equals(scheme)) {
            return null;
        }

        String host = System.getProperty(scheme + ".proxyHost");
        if (host != null) {
            String portStr = System.getProperty(scheme + ".proxyPort");
            if (portStr != null) {
                int port;
                try {
                    port = Integer.parseInt(portStr);
                } catch (NumberFormatException e) {
                    //keep the parse failure as cause so the stack trace stays complete
                    throw new IllegalArgumentException("Illegal proxy port: " + portStr, e);
                }
                proxy = new HttpHost(host, port);
            } else {
                proxy = new HttpHost(host);
            }
        }

        return proxy;
    }

    /**
     * Creates an HTTP client backed by a pooling connection manager. If
     * {@code acceptAnyCertificate} is enabled, https connections use the
     * insecure SSL socket factory while plain http connections keep using the
     * regular plain socket factory.
     *
     * @return the HTTP client
     */
    private CloseableHttpClient createHttpClient() {
        HttpClientBuilder builder = HttpClientBuilder.create();

        HttpClientConnectionManager connectionManager;

        //accept any certificate if necessary
        if (acceptAnyCertificate) {
            SSLConnectionSocketFactory icsf = getInsecureSSLSocketFactory();
            builder.setSSLSocketFactory(icsf);
            //A custom registry replaces the default one completely, so "http"
            //has to be registered as well. Otherwise plain http sources and
            //http proxies are rejected as an unsupported scheme.
            Registry<ConnectionSocketFactory> registry =
                    RegistryBuilder.<ConnectionSocketFactory>create()
                            .register("http", PlainConnectionSocketFactory.getSocketFactory())
                            .register("https", icsf)
                            .build();
            connectionManager = new PoolingHttpClientConnectionManager(registry);
        } else {
            connectionManager = new PoolingHttpClientConnectionManager();
        }
        builder.setConnectionManager(connectionManager);

        return builder.build();
    }

    /**
     * Opens a connection to the given HTTP host and requests a file. Sends
     * an If-Modified-Since header if the given timestamp is greater than 0.
     * Configured credentials, proxy settings (from system properties), custom
     * headers and the gzip Accept-Encoding header are applied to the request.
     *
     * @param httpHost  the HTTP host to connect to
     * @param file      the file to request
     * @param timestamp the timestamp of the destination file, in milliseconds
     * @param client    the HTTP client to use to perform the request
     * @return the response with a 2xx or 304 (Not Modified) status; the caller
     * is responsible for closing it
     * @throws ClientProtocolException if the server answered with any other
     *                                 status code (the response is closed before throwing)
     * @throws IOException             if the connection could not be opened
     */
    private CloseableHttpResponse openConnection(HttpHost httpHost, String file,
                                                 long timestamp, CloseableHttpClient client) throws IOException {
        //perform preemptive authentication
        HttpClientContext context = null;
        if (username != null && password != null) {
            context = HttpClientContext.create();
            addAuthentication(httpHost, username, password, true, context);
        }

        //create request
        HttpGet get = new HttpGet(file);

        //configure proxy
        HttpHost proxy = configureProxy(httpHost);
        if (proxy != null) {
            RequestConfig config = RequestConfig.custom()
                    .setProxy(proxy)
                    .build();
            get.setConfig(config);

            //add authentication information for proxy
            String scheme = httpHost.getSchemeName();
            String proxyUser = System.getProperty(scheme + ".proxyUser");
            String proxyPassword = System.getProperty(scheme + ".proxyPassword");
            if (proxyUser != null && proxyPassword != null) {
                if (context == null) {
                    context = HttpClientContext.create();
                }
                addAuthentication(proxy, proxyUser, proxyPassword, false, context);
            }
        }

        //set If-Modified-Since header
        if (timestamp > 0) {
            DateFormat format = new SimpleDateFormat(HttpDateGenerator.PATTERN_RFC1123, Locale.US);
            format.setTimeZone(HttpDateGenerator.GMT);
            get.setHeader("If-Modified-Since", format.format(timestamp));
        }

        //set headers
        if (headers != null) {
            for (Map.Entry<String, String> headerEntry : headers.entrySet()) {
                get.addHeader(headerEntry.getKey(), headerEntry.getValue());
            }
        }

        //enable compression
        if (compress) {
            get.setHeader("Accept-Encoding", "gzip");
        }

        //execute request
        CloseableHttpResponse response = client.execute(httpHost, get, context);

        //handle response
        int code = response.getStatusLine().getStatusCode();
        if ((code < 200 || code > 299) && code != HttpStatus.SC_NOT_MODIFIED) {
            String reason = response.getStatusLine().getReasonPhrase();
            //the caller never sees this response, so it has to be released
            //here or its pooled connection would leak
            try {
                response.close();
            } catch (IOException ignored) {
                //best effort: the protocol error below is the relevant failure
            }
            throw new ClientProtocolException(reason);
        }

        return response;
    }

    /**
     * Registers username/password credentials for a host in the given
     * context. A credentials provider and an auth cache are created and
     * attached to the context if it does not have them yet.
     *
     * @param host            the host the credentials apply to
     * @param username        the username
     * @param password        the password
     * @param preAuthenticate <code>true</code> if the <code>Basic</code> scheme
     *                        should be cached for the host so that it is used preemptively
     *                        (should be <code>false</code> when adding authentication for a proxy
     *                        server)
     * @param context         the context in which the authentication information
     *                        should be saved
     */
    private void addAuthentication(HttpHost host, String username,
                                   String password, boolean preAuthenticate, HttpClientContext context) {
        //make sure the context carries an auth cache
        AuthCache cache = context.getAuthCache();
        if (cache == null) {
            cache = new BasicAuthCache();
            context.setAuthCache(cache);
        }

        //make sure the context carries a credentials provider
        CredentialsProvider provider = context.getCredentialsProvider();
        if (provider == null) {
            provider = new BasicCredentialsProvider();
            context.setCredentialsProvider(provider);
        }

        UsernamePasswordCredentials credentials =
                new UsernamePasswordCredentials(username, password);
        AuthScope scope = new AuthScope(host);
        provider.setCredentials(scope, credentials);

        if (!preAuthenticate) {
            return;
        }
        cache.put(host, new BasicScheme());
    }

    /**
     * Formats a byte count for display: plain bytes below 1024, whole
     * kilobytes below 1024 KB, and megabytes or gigabytes with two decimals
     * above that.
     *
     * @param bytes the number of bytes
     * @return the human-readable string
     */
    private String toLengthText(long bytes) {
        final long kilo = 1024L;
        final long mega = kilo * 1024L;
        final long giga = mega * 1024L;

        if (bytes >= giga) {
            return String.format("%.2f GB", bytes / (1024.0 * 1024.0 * 1024.0));
        }
        if (bytes >= mega) {
            return String.format("%.2f MB", bytes / (1024.0 * 1024.0));
        }
        if (bytes >= kilo) {
            return (bytes / kilo) + " KB";
        }
        return bytes + " B";
    }

    /**
     * Reads the Last-Modified header of a {@link HttpResponse} and converts
     * it to a timestamp
     *
     * @param response the {@link HttpResponse}
     * @return the last-modified time in milliseconds, or 0 if the header is
     * missing, empty or cannot be parsed
     */
    private long parseLastModified(HttpResponse response) {
        Header lastModified = response.getLastHeader("Last-Modified");
        String raw = lastModified == null ? null : lastModified.getValue();
        if (raw == null || raw.isEmpty()) {
            return 0;
        }
        Date parsed = DateUtils.parseDate(raw);
        return parsed == null ? 0 : parsed.getTime();
    }

    /**
     * Tells whether the given {@link HttpEntity} declares a gzip content
     * encoding (compared case-insensitively)
     *
     * @param entity the entity to check
     * @return true if the entity's content encoding is gzip, false if the
     * encoding header is missing, has no value or names anything else
     */
    private boolean isContentCompressed(HttpEntity entity) {
        Header encoding = entity.getContentEncoding();
        //a null or empty value can never equal "gzip"
        return encoding != null && "gzip".equalsIgnoreCase(encoding.getValue());
    }

    /**
     * Compares the number of sources counted as up to date with the total
     * number of configured sources.
     *
     * @return true if the download destination is up to date, i.e. every
     * configured source was counted as up to date
     */
    boolean isUpToDate() {
        return upToDate.get() == sources.size();
    }

    /**
     * Compares the number of sources skipped in offline mode with the total
     * number of configured sources.
     *
     * @return true if execution of this task has been skipped, i.e. every
     * configured source was skipped
     */
    boolean isSkipped() {
        return skipped.get() == sources.size();
    }

    /**
     * Resolves the destination file of every configured source, in the order
     * the sources were added.
     *
     * @return the list of destination files
     */
    List<File> getOutputFiles() {
        final int count = sources.size();
        List<File> outputs = new ArrayList<File>(count);
        Iterator<URL> it = sources.iterator();
        while (it.hasNext()) {
            outputs.add(destFile(it.next()));
        }
        return outputs;
    }

    /**
     * Adds one or more download sources. A closure is called first and its
     * result is used as the source. Collections and arrays are added element
     * by element (recursively).
     *
     * @param src a URL, a CharSequence, a Collection, an array, or a closure
     *            producing one of these
     * @throws MalformedURLException    if a CharSequence is not a valid URL
     * @throws IllegalArgumentException if the source has any other type or is null
     */
    @Override
    public void src(Object src) throws MalformedURLException {
        //lazily evaluate closure
        Object resolved = src instanceof Closure ? ((Closure<?>) src).call() : src;

        if (resolved instanceof CharSequence) {
            sources.add(new URL(resolved.toString()));
            return;
        }
        if (resolved instanceof URL) {
            sources.add((URL) resolved);
            return;
        }
        if (resolved instanceof Collection) {
            Iterator<?> elements = ((Collection<?>) resolved).iterator();
            while (elements.hasNext()) {
                src(elements.next());
            }
            return;
        }
        if (resolved == null || !resolved.getClass().isArray()) {
            throw new IllegalArgumentException("Download source must " +
                    "either be a URL, a CharSequence, a Collection or an array.");
        }

        final int length = Array.getLength(resolved);
        for (int index = 0; index < length; index++) {
            src(Array.get(resolved, index));
        }
    }

    /**
     * Sets the download destination. A closure is called first and its
     * result is used as the destination. A CharSequence is resolved through
     * the project's <code>file</code> method; a File is used as is.
     *
     * @param dest a File, a CharSequence, or a closure producing one of these
     * @throws IllegalArgumentException if the destination has any other type
     */
    @Override
    public void dest(Object dest) {
        //lazily evaluate closure
        Object resolved = dest instanceof Closure ? ((Closure<?>) dest).call() : dest;

        if (resolved instanceof CharSequence) {
            this.dest = project.file(resolved.toString());
            return;
        }
        if (resolved instanceof File) {
            this.dest = (File) resolved;
            return;
        }
        throw new IllegalArgumentException("Download destination must " +
                "either be a File or a CharSequence");
    }

    /**
     * Sets the quiet flag.
     *
     * @param quiet true to suppress informational log messages and the
     *              progress logger
     */
    @Override
    public void quiet(boolean quiet) {
        this.quiet = quiet;
    }

    /**
     * Sets the overwrite flag.
     *
     * @param overwrite false to skip sources whose destination file already exists
     */
    @Override
    public void overwrite(boolean overwrite) {
        this.overwrite = overwrite;
    }

    /**
     * Sets the onlyIfNewer flag. Calling this method also resets the cache
     * period to 0.
     *
     * @param onlyIfNewer true to send the destination file's timestamp as
     *                    If-Modified-Since
     */
    @Override
    public void onlyIfNewer(boolean onlyIfNewer) {
        this.onlyIfNewer = onlyIfNewer;
        this.cache = 0;
    }

    /**
     * Sets the compress flag.
     *
     * @param compress true to request gzip-compressed responses
     */
    @Override
    public void compress(boolean compress) {
        this.compress = compress;
    }

    /**
     * Sets the username for authentication. Credentials are only sent if
     * both username and password are set.
     *
     * @param username the username
     */
    @Override
    public void username(String username) {
        this.username = username;
    }

    /**
     * Sets the password for authentication. Credentials are only sent if
     * both username and password are set.
     *
     * @param password the password
     */
    @Override
    public void password(String password) {
        this.password = password;
    }

    /**
     * Replaces all configured request headers with the given ones.
     *
     * @param headers the new headers; <code>null</code> leaves the header map empty
     */
    @Override
    public void headers(Map<String, String> headers) {
        if (this.headers == null) {
            this.headers = new LinkedHashMap<String, String>();
        } else {
            this.headers.clear();
        }
        if (headers != null) {
            this.headers.putAll(headers);
        }
    }

    /**
     * Sets a single request header, replacing any value previously stored
     * under the same name.
     *
     * @param name  the header name
     * @param value the header value
     */
    @Override
    public void header(String name, String value) {
        if (headers == null) {
            headers = new LinkedHashMap<String, String>();
        }
        headers.put(name, value);
    }

    /**
     * Sets the acceptAnyCertificate flag.
     *
     * @param accept true to build the HTTP client with the insecure SSL
     *               socket factory
     */
    @Override
    public void acceptAnyCertificate(boolean accept) {
        this.acceptAnyCertificate = accept;
    }

    /**
     * @return the single configured source URL if exactly one source is
     * configured, otherwise the list of all sources
     */
    @Override
    public Object getSrc() {
        if (sources.size() == 1) {
            return sources.get(0);
        }
        return sources;
    }

    /**
     * @return the configured destination file or directory (may be null)
     */
    @Override
    public File getDest() {
        return dest;
    }

    /**
     * @return the quiet flag
     */
    @Override
    public boolean isQuiet() {
        return quiet;
    }

    /**
     * @return the overwrite flag
     */
    @Override
    public boolean isOverwrite() {
        return overwrite;
    }

    /**
     * @return the onlyIfNewer flag
     */
    @Override
    public boolean isOnlyIfNewer() {
        return onlyIfNewer;
    }

    /**
     * @return the compress flag
     */
    @Override
    public boolean isCompress() {
        return compress;
    }

    /**
     * @return the username for authentication (may be null)
     */
    @Override
    public String getUsername() {
        return username;
    }

    /**
     * @return the password for authentication (may be null)
     */
    @Override
    public String getPassword() {
        return password;
    }

    /**
     * @return the internal map of request headers (not a copy), or null if no
     * header has been configured yet
     */
    @Override
    public Map<String, String> getHeaders() {
        return headers;
    }

    /**
     * @param name the header name
     * @return the value stored for the given header name, or null if there is
     * none
     */
    @Override
    public String getHeader(String name) {
        if (headers == null) {
            return null;
        }
        return headers.get(name);
    }

    /**
     * @return the acceptAnyCertificate flag
     */
    @Override
    public boolean isAcceptAnyCertificate() {
        return acceptAnyCertificate;
    }

    /**
     * Returns the SSL socket factory built from the insecure trust managers
     * and the insecure hostname verifier. The factory is created on first use
     * and then reused.
     *
     * @return the insecure SSL socket factory
     * @throws RuntimeException if the SSL context could not be created or
     *                          initialized
     */
    private SSLConnectionSocketFactory getInsecureSSLSocketFactory() {
        if (insecureSSLSocketFactory != null) {
            //already created by an earlier call
            return insecureSSLSocketFactory;
        }

        try {
            SSLContext context = SSLContext.getInstance("SSL");
            context.init(null, INSECURE_TRUST_MANAGERS, new SecureRandom());
            insecureSSLSocketFactory =
                    new SSLConnectionSocketFactory(context, INSECURE_HOSTNAME_VERIFIER);
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        } catch (KeyManagementException e) {
            throw new RuntimeException(e);
        }
        return insecureSSLSocketFactory;
    }

    // cache period in milliseconds; 0 disables the cache check
    private long cache = 0;

    /**
     * Sets the cache period. The amount is converted to milliseconds, and
     * the onlyIfNewer flag is set to true if the resulting period is greater
     * than 0 and to false otherwise.
     *
     * @param amount the length of the cache period
     * @param unit   the time unit of <code>amount</code>
     */
    @Override
    public void cache(long amount, TimeUnit unit) {
        cache = unit.toMillis(amount);
        onlyIfNewer = cache > 0;
    }

    /**
     * @return the cache period in milliseconds
     */
    @Override
    public long getCache() {
        return cache;
    }

    // size of the thread pool created in execute()
    private int maxDownloadThreads = 1;

    /**
     * Sets the number of threads used to download sources in parallel. The
     * value is stored without validation.
     *
     * @param maxDownloadThreads the number of download threads
     */
    @Override
    public void maxDownloadThreads(int maxDownloadThreads) {
        this.maxDownloadThreads = maxDownloadThreads;
    }

    /**
     * @return the number of threads used to download sources in parallel
     */
    @Override
    public int getMaxDownloadThreads() {
        return maxDownloadThreads;
    }

    private class Downloader implements Callable<Void> {
        // HTTP client used to perform the request
        private final CloseableHttpClient client;
        // URL to download
        private final URL source;
        // file the download is saved to
        private final File dest;
        // stays null in quiet mode; otherwise assigned reflectively in call()
        private ProgressLogger progressLogger;
        // expected number of bytes (sum of the known content lengths)
        private AtomicLong total = new AtomicLong(0);
        // number of bytes written so far
        private AtomicLong current = new AtomicLong(0);

        /**
         * Creates a download task for a single source
         *
         * @param client the HTTP client to use
         * @param source the URL to download
         * @param dest   the file to save the download to
         */
        Downloader(CloseableHttpClient client, URL source, File dest) {
            this.client = client;
            this.source = source;
            this.dest = dest;
        }

        /**
         * Downloads the source to the destination file. Unless quiet mode is
         * enabled, a Gradle progress logger is obtained reflectively first. The
         * download is skipped (and counted as up to date) if the server answers
         * 304 Not Modified or reports a Last-Modified time that is not newer
         * than the timestamp sent with the request.
         *
         * @return always <code>null</code>
         * @throws Exception if the request or the download fails
         */
        @Override
        public Void call() throws Exception {
            //create progress logger
            if (!quiet) {
                //we are about to access an internal class. Use reflection here to provide
                //as much compatibility to different Gradle versions as possible
                try {
                    Method getServices = project.getClass().getMethod("getServices");
                    Object serviceFactory = getServices.invoke(project);
                    Method get = serviceFactory.getClass().getMethod("get", Class.class);
                    Object progressLoggerFactory = get.invoke(serviceFactory,
                            ProgressLoggerFactory.class);
                    Method newOperation = progressLoggerFactory.getClass().getMethod(
                            "newOperation", Class.class);
                    ProgressLogger logger = (ProgressLogger) newOperation.invoke(
                            progressLoggerFactory, getClass());
                    String desc = "Download " + source.toString();
                    Method setDescription = logger.getClass().getMethod(
                            "setDescription", String.class);
                    setDescription.setAccessible(true);
                    setDescription.invoke(logger, desc);
                    Method setLoggingHeader = logger.getClass().getMethod(
                            "setLoggingHeader", String.class);
                    setLoggingHeader.setAccessible(true);
                    setLoggingHeader.invoke(logger, desc);
                    //publish the logger only once it is fully configured, so a
                    //failure half way through never leaves a logger without
                    //description behind
                    progressLogger = logger;
                } catch (Exception e) {
                    //unable to get progress logger
                    progressLogger = null;
                    project.getLogger().error("Unable to get progress logger. Download "
                            + "progress will not be displayed.");
                }
            }

            CloseableHttpResponse response = null;
            try {
                final long timestamp = onlyIfNewer && dest.exists() ? dest.lastModified() : 0;

                //create HTTP host from URL
                HttpHost httpHost = new HttpHost(source.getHost(), source.getPort(), source.getProtocol());

                //open URL connection
                response = openConnection(httpHost, source.getFile(), timestamp, client);
                if (response == null) {
                    return null;
                }

                //check if file on server was modified
                final long lastModified = parseLastModified(response);
                final int code = response.getStatusLine().getStatusCode();
                if (code == HttpStatus.SC_NOT_MODIFIED ||
                        (lastModified != 0 && timestamp >= lastModified)) {
                    if (!quiet) {
                        project.getLogger().info("Not modified. Skipping '" + source + "'");
                    }
                    upToDate.incrementAndGet();
                    return null;
                }
                performDownload(response, dest);
            } finally {
                if (response != null) response.close();
            }
            return null;
        }

        /**
         * Save an HTTP response to a file. The content is first written to a
         * temporary file next to the destination and then moved over the
         * destination, so a failed download never leaves a partial destination
         * file. If the response has no entity nothing is written. If
         * <code>onlyIfNewer</code> is enabled and the response carries a
         * Last-Modified header, the destination file's timestamp is set to it.
         *
         * @param response the response to save
         * @param destFile the destination file
         * @throws IOException     if the response could not be downloaded
         * @throws GradleException if the existing destination file could not
         *                         be deleted or the temporary file could not be moved to the destination
         */
        private void performDownload(HttpResponse response, File destFile)
                throws IOException {
            HttpEntity entity = response.getEntity();
            if (entity == null) {
                //nothing to save; return before any temporary file is created
                return;
            }

            //get content length
            long contentLength = entity.getContentLength();
            if (contentLength >= 0) {
                total.addAndGet(contentLength);
            }

            File tmp = null;
            boolean progressStarted = false;

            //open stream and start downloading
            InputStream is = new BufferedInputStream(entity.getContent());
            try {
                if (isContentCompressed(entity)) {
                    is = new GZIPInputStream(is);
                }
                startProgress();
                progressStarted = true;

                tmp = File.createTempFile(".downloader", destFile.getName(), destFile.getParentFile());
                tmp.deleteOnExit();

                OutputStream os = new BufferedOutputStream(new FileOutputStream(tmp));
                try {
                    byte[] buf = new byte[1024 * 10];
                    int read;
                    while ((read = is.read(buf)) >= 0) {
                        os.write(buf, 0, read);
                        current.addAndGet(read);
                        logProgress();
                    }
                } finally {
                    //flushes the buffer; must happen before the file is moved
                    os.close();
                }

                if (destFile.exists() && !destFile.delete()) {
                    throw new GradleException("Unable to erase dest file " + destFile);
                }
                if (!tmp.renameTo(destFile)) {
                    throw new GradleException("Unable to move " + tmp + " to " + destFile);
                }
            } finally {
                try {
                    //After a successful rename the temporary file is gone, so
                    //only try to delete (and warn about) a file that is still there.
                    if (tmp != null && tmp.exists() && !tmp.delete()) {
                        project.getLogger().warn("Unable to delete temporary file " + tmp);
                    }
                    is.close();
                } finally {
                    EntityUtils.consumeQuietly(entity);
                    if (progressStarted) {
                        completeProgress();
                    }
                }
            }

            long newTimestamp = parseLastModified(response);
            if (onlyIfNewer && newTimestamp > 0) {
                destFile.setLastModified(newTimestamp);
            }
        }

        /**
         * Signals the start of the download to the progress logger, if there
         * is one
         */
        private void startProgress() {
            if (progressLogger == null) {
                return;
            }
            progressLogger.started();
        }

        /**
         * Signals the end of the download to the progress logger, if there is
         * one
         */
        private void completeProgress() {
            if (progressLogger == null) {
                return;
            }
            progressLogger.completed();
        }

        // time (in milliseconds) the last progress message was published
        private volatile long lastProgressPublishing = 0;

        /**
         * Publishes the number of bytes downloaded so far (and the expected
         * total, if known) to the progress logger. A message is published only
         * if more than one second has passed since the previous one.
         */
        private void logProgress() {
            if (progressLogger == null) {
                return;
            }

            final long now = System.currentTimeMillis();
            final long downloaded = this.current.get();
            final long expected = this.total.get();
            if (downloaded <= 0 || now - lastProgressPublishing <= 1000) {
                //nothing downloaded yet, or the last message is too recent
                return;
            }

            lastProgressPublishing = now;
            String message = toLengthText(downloaded);
            if (expected > 0) {
                message = message + " / " + toLengthText(expected);
            }
            progressLogger.progress(message + " downloaded");
        }
    }
}
