/*
 * Decompiled with CFR 0.152.
 */
package science.aist.machinelearning.algorithm.clustering.kmeans;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Required;
import science.aist.machinelearning.algorithm.clustering.kmeans.Cluster;
import science.aist.machinelearning.algorithm.clustering.kmeans.ElementToVector;
import science.aist.machinelearning.algorithm.clustering.kmeans.KMeansCluster;
import science.aist.machinelearning.algorithm.clustering.kmeans.VectorDistance;
import science.aist.machinelearning.core.Algorithm;
import science.aist.machinelearning.core.Gene;
import science.aist.machinelearning.core.Problem;
import science.aist.machinelearning.core.Solution;
import science.aist.machinelearning.core.SolutionGene;
import science.aist.machinelearning.core.analytics.Analytics;
import science.aist.machinelearning.core.options.Descriptor;

/**
 * K-means style clustering of arbitrary elements.
 *
 * <p>Every element of the problem is mapped to a {@code double[]} through the configured
 * {@link ElementToVector}; distances between a cluster center and an element vector are
 * obtained from the configured {@link VectorDistance}. Up to {@code numberOfClusters}
 * randomly chosen elements seed the clusters, every remaining element is attached to its
 * nearest cluster, and elements are then re-assigned until the fraction of elements that
 * changed cluster in one pass is no longer greater than {@code epsilon} or the iteration
 * limit is reached.
 *
 * <p>The random generator is created with the fixed seed {@code 42}; use
 * {@link #setSeed(int)} to change it.
 *
 * @param <T> type of the elements that are clustered
 */
public class KMeansClustering<T>
implements Algorithm<Cluster<T>, T> {
    private final Logger logger = LoggerFactory.getLogger(this.getClass());
    private final Random random = new Random(42L);
    private Analytics analytics;
    private int numberOfClusters;
    private ElementToVector<T> elementToVector;
    private VectorDistance vectorDistance;
    private double epsilon = 0.01;
    private int maxIterations = 25;

    /**
     * Clusters the elements of the given problem without a previous solution.
     *
     * @param problem problem whose genes are the elements to cluster
     * @return solution with one solution gene per created cluster
     */
    public Solution<Cluster<T>, T> solve(Problem<T> problem) {
        return this.solve(problem, null);
    }

    /**
     * Clusters the elements of the given problem.
     *
     * <p>At most {@code min(numberOfClusters, number of elements)} clusters are created.
     * The iteration counter starts at 1 and the loop runs while it is smaller than
     * {@code maxIterations}, so at most {@code maxIterations - 1} re-assignment passes are
     * executed.
     *
     * @param problem      problem whose genes are the elements to cluster
     * @param bestSolution ignored by this implementation; a warning is logged when non-null
     * @return solution with one solution gene per created cluster
     */
    public Solution<Cluster<T>, T> solve(Problem<T> problem, Solution<Cluster<T>, T> bestSolution) {
        if (this.analytics != null) {
            this.analytics.startAnalytics();
        }
        if (bestSolution != null) {
            this.logger.warn("Best Solution is ignored in current implementation");
        }
        if (this.analytics != null) {
            this.analytics.logProblem(problem);
        }

        // Explicitly collect into an ArrayList: seeds are removed from this list below and
        // Collectors.toList() gives no mutability guarantee.
        List<T> elements = problem.getProblemGenes().stream()
                .map(Gene::getGene)
                .collect(Collectors.toCollection(ArrayList::new));
        // Keep the first vector when equal elements occur more than once; without the merge
        // function toMap would fail with an IllegalStateException on duplicate keys.
        Map<T, double[]> elementVectorCache = elements.stream()
                .collect(Collectors.toMap(Function.identity(), this.elementToVector::mapElementToVector, (first, second) -> first));

        int size = elements.size();
        List<ElementInCluster<T>> assignments = new ArrayList<>(size);
        List<KMeansCluster<T>> clusters = new ArrayList<>();

        // Seed the clusters with randomly chosen elements.
        int seedCount = Math.min(this.numberOfClusters, size);
        for (int i = 0; i < seedCount; ++i) {
            // nextInt's bound is exclusive, so the full list size is the correct bound;
            // using size - 1 would never allow the last element to become a seed.
            T seed = elements.remove(this.random.nextInt(elements.size()));
            KMeansCluster<T> cluster = this.createCluster(seed, elementVectorCache.get(seed));
            clusters.add(cluster);
            assignments.add(new ElementInCluster<>(seed, cluster));
        }

        // Attach every remaining element to the cluster whose center is nearest.
        for (T element : elements) {
            double[] elementVector = elementVectorCache.get(element);
            KMeansCluster<T> nearest = this.findNearestCluster(clusters, elementVector);
            if (nearest != null) {
                this.addElementToCluster(nearest, element, elementVector);
                assignments.add(new ElementInCluster<>(element, nearest));
            }
        }

        int iteration = 1;
        double delta = 1.0;
        while (delta > this.epsilon && iteration < this.maxIterations) {
            int changes = 0;
            for (ElementInCluster<T> assignment : assignments) {
                double[] elementVector = elementVectorCache.get(assignment.element);
                KMeansCluster<T> nearest = this.findNearestCluster(clusters, elementVector);
                if (nearest != null && nearest != assignment.cluster) {
                    this.removeElementFromCluster(assignment.cluster, assignment.element, elementVector);
                    this.addElementToCluster(nearest, assignment.element, elementVector);
                    assignment.cluster = nearest;
                    ++changes;
                }
            }
            // Fraction of elements that moved during this pass (0 for an empty problem).
            delta = size == 0 ? 0.0 : (double) changes / (double) size;
            ++iteration;
            this.logger.debug("Iteration: {} [MaxIterations: {}]", iteration, this.maxIterations);
            this.logger.debug("Current epsilon: {} [Epsilon: {}]", delta, this.epsilon);
        }

        Solution<Cluster<T>, T> solution = new Solution<>();
        solution.setSolutionGenes(clusters.stream()
                .map(cluster -> new SolutionGene<Cluster<T>, T>(cluster))
                .collect(Collectors.toList()));
        if (this.analytics != null) {
            this.analytics.logSolution(solution);
        }
        if (this.analytics != null) {
            this.analytics.finishAnalytics();
        }
        return solution;
    }

    /**
     * @return the analytics instance used for logging, or {@code null} when none is set
     */
    public Analytics getAnalytics() {
        return this.analytics;
    }

    /**
     * @param analytics analytics instance to notify while solving; may be {@code null}
     */
    public void setAnalytics(Analytics analytics) {
        this.analytics = analytics;
    }

    /**
     * Returns the options {@code numberOfClusters}, {@code elementToVector} and
     * {@code vectorDistance}, each wrapped in a {@link Descriptor}.
     *
     * @return map from option name to descriptor of the current value
     */
    public Map<String, Descriptor> getOptions() {
        return Stream.<Pair<String, Object>>of(
                Pair.of("numberOfClusters", this.getNumberOfClusters()),
                Pair.of("elementToVector", this.getElementToVector()),
                Pair.of("vectorDistance", this.getVectorDistance()))
                .map(option -> Pair.of(option.getLeft(), new Descriptor(option.getRight())))
                .collect(Collectors.toMap(Pair::getLeft, Pair::getRight));
    }

    /**
     * Applies the given options through {@link #setOption(String, Descriptor)}.
     *
     * <p>Processing stops at the first option that cannot be set; options after it are not
     * applied.
     *
     * @param options option names mapped to the descriptors holding the new values
     * @return {@code true} if every option was set, {@code false} otherwise
     */
    public boolean setOptions(Map<String, Descriptor> options) {
        return options.entrySet().stream().allMatch(p -> this.setOption(p.getKey(), p.getValue()));
    }

    /**
     * Sets a single option by reflectively invoking the public one-argument setter
     * {@code set<Name>} with the value held by the descriptor.
     *
     * @param name       option name, e.g. {@code "numberOfClusters"}
     * @param descriptor descriptor holding the new value
     * @return {@code true} if the setter was found and invoked successfully; {@code false}
     *         if the name or descriptor is missing, no matching setter exists, the value
     *         does not fit the setter's parameter, or the setter itself failed
     */
    public boolean setOption(String name, Descriptor descriptor) {
        if (name == null || name.isEmpty() || descriptor == null) {
            this.logger.warn("Could not set specified option");
            return false;
        }
        // Character.toUpperCase is locale independent, unlike String.toUpperCase().
        String setterName = "set" + Character.toUpperCase(name.charAt(0)) + name.substring(1);
        try {
            Arrays.stream(this.getClass().getMethods())
                    .filter(m -> m.getParameterCount() == 1 && m.getName().equals(setterName))
                    .findFirst()
                    .orElseThrow(() -> new NoSuchMethodException(setterName))
                    .invoke(this, descriptor.getValue());
            return true;
        }
        catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException | IllegalArgumentException e) {
            // IllegalArgumentException is what Method.invoke raises for a value whose type
            // does not match the setter's parameter.
            this.logger.warn("Could not set specified option", e);
            return false;
        }
    }

    /**
     * Creates a cluster containing exactly one element. Center and vector sum are
     * independent copies of the element's vector.
     */
    private KMeansCluster<T> createCluster(T firstElement, double[] vectorFirstElement) {
        KMeansCluster<T> cluster = new KMeansCluster<>();
        cluster.getElementsModfi().add(firstElement);
        cluster.setClusterCenter(vectorFirstElement.clone());
        cluster.setClusterVectorSum(vectorFirstElement.clone());
        return cluster;
    }

    /**
     * Finds the cluster whose center has the smallest distance to the given vector.
     * The distance to each cluster is computed once; on ties the earliest cluster wins.
     *
     * @return the nearest cluster, or {@code null} when {@code clusters} is empty
     */
    private KMeansCluster<T> findNearestCluster(List<KMeansCluster<T>> clusters, double[] elementVector) {
        KMeansCluster<T> nearest = null;
        double nearestDistance = 0.0;
        for (KMeansCluster<T> candidate : clusters) {
            double distance = this.vectorDistance.calculateDistance(candidate.getClusterCenter(), elementVector);
            // Double.compare keeps the ordering (including NaN) of Comparator.comparingDouble.
            if (nearest == null || Double.compare(distance, nearestDistance) < 0) {
                nearest = candidate;
                nearestDistance = distance;
            }
        }
        return nearest;
    }

    /**
     * Adds the element to the cluster and updates the cluster's vector sum and center
     * (center = vector sum / number of elements).
     *
     * @param cluster       cluster that receives the element
     * @param element       element to add
     * @param elementVector vector representation of {@code element}
     */
    void addElementToCluster(KMeansCluster<T> cluster, T element, double[] elementVector) {
        cluster.getElementsModfi().add(element);
        int nrOfElementsInCluster = cluster.getElementsModfi().size();
        double[] vectorSum = cluster.getClusterVectorSum();
        double[] center = cluster.getClusterCenter();
        for (int i = 0; i < vectorSum.length; ++i) {
            vectorSum[i] += elementVector[i];
            center[i] = vectorSum[i] / (double) nrOfElementsInCluster;
        }
    }

    /**
     * Removes the element from the cluster and updates the cluster's vector sum and center.
     *
     * <p>When the cluster becomes empty the vector sum is reset to zero and the last center
     * is kept, because dividing by an element count of zero would turn the center into
     * NaN/Infinity.
     *
     * @param cluster       cluster that currently holds the element
     * @param element       element to remove
     * @param elementVector vector representation of {@code element}
     * @throws IllegalStateException if the element is not contained in the cluster
     */
    void removeElementFromCluster(KMeansCluster<T> cluster, T element, double[] elementVector) {
        if (!cluster.getElementsModfi().remove(element)) {
            throw new IllegalStateException("Trying to remove a element from the cluster which is not in the cluster.");
        }
        int nrOfElementsInCluster = cluster.getElementsModfi().size();
        double[] vectorSum = cluster.getClusterVectorSum();
        if (nrOfElementsInCluster == 0) {
            Arrays.fill(vectorSum, 0.0);
            return;
        }
        double[] center = cluster.getClusterCenter();
        for (int i = 0; i < vectorSum.length; ++i) {
            vectorSum[i] -= elementVector[i];
            center[i] = vectorSum[i] / (double) nrOfElementsInCluster;
        }
    }

    /**
     * @return the requested number of clusters
     */
    public int getNumberOfClusters() {
        return this.numberOfClusters;
    }

    /**
     * @param numberOfClusters upper bound for the number of clusters to create
     */
    @Required
    public void setNumberOfClusters(int numberOfClusters) {
        this.numberOfClusters = numberOfClusters;
    }

    /**
     * @return the mapping from element to vector
     */
    public ElementToVector<T> getElementToVector() {
        return this.elementToVector;
    }

    /**
     * @param elementToVector mapping used to turn every element into a {@code double[]}
     */
    @Required
    public void setElementToVector(ElementToVector<T> elementToVector) {
        this.elementToVector = elementToVector;
    }

    /**
     * @return the distance function between cluster center and element vector
     */
    public VectorDistance getVectorDistance() {
        return this.vectorDistance;
    }

    /**
     * @param vectorDistance distance function between cluster center and element vector
     */
    @Required
    public void setVectorDistance(VectorDistance vectorDistance) {
        this.vectorDistance = vectorDistance;
    }

    /**
     * Re-seeds the random generator used to pick the initial cluster elements.
     *
     * @param seed new seed
     */
    public void setSeed(int seed) {
        this.random.setSeed(seed);
    }

    /**
     * @return threshold for the fraction of re-assigned elements below which iteration stops
     */
    public double getEpsilon() {
        return this.epsilon;
    }

    /**
     * @param epsilon iteration continues only while the fraction of elements re-assigned in
     *                the last pass is greater than this value
     */
    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /**
     * @return the iteration limit
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * @param maxIterations iteration limit; the counter starts at 1 and the loop runs while
     *                      it is smaller than this value
     */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /**
     * Mutable record of the cluster an element is currently assigned to.
     */
    private static final class ElementInCluster<E> {
        final E element;
        KMeansCluster<E> cluster;

        ElementInCluster(E element, KMeansCluster<E> cluster) {
            this.element = element;
            this.cluster = cluster;
        }
    }
}

