/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.graph.inference;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.graph.inference.EnergyFunction;
import gov.sandia.cognition.graph.inference.EnergyFunctionSolver;
import gov.sandia.cognition.graph.inference.Node;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;

@PublicationReference(
    author={"Jonathan S. Yedidia, William T. Freeman, and Yair Weiss"},
    title="Understanding Belief Propagation and its Generalizations",
    type=PublicationType.TechnicalReport,
    year=2001,
    notes={"Institution: Mitsubishi Electric Research Laboratories"})
/**
 * Base class for multi-threaded sum-product (belief propagation) solvers over an
 * {@link EnergyFunction}.
 *
 * <p>Subclasses supply per-edge message initialization ({@link #initMessages}) and the
 * per-edge temporary message computation ({@link #computeTemporaryMessage}). This class
 * builds the {@link Node}s and partitions the edges and nodes into work groups. It then
 * runs three phases on a fixed-size thread pool: message computation, node normalization
 * and belief computation.
 *
 * <p>Instances are not intended to be used by more than one caller at a time: {@link #solve()}
 * reuses per-instance work queues.
 *
 * @param <LabelType> the type of label a node can take
 */
public abstract class SumProductInferencingAlgorithm<LabelType>
implements EnergyFunctionSolver<LabelType> {
    /**
     * Default convergence threshold. {@link #solve()} compares it against the largest value
     * returned by {@code Node.update()} in an iteration.
     */
    public static final double DEFAULT_EPS = 0.001;

    /** Default maximum number of message-passing iterations. */
    public static final int DEFAULT_MAX_ITERATIONS = 20;

    /** Default number of worker threads. */
    public static final int DEFAULT_NUM_THREADS = 4;

    /** Work groups are created at this multiple of the thread count, for load balancing. */
    private static final int PIECES_PER_THREAD = 10;

    /** Convergence threshold on the per-iteration maximum node update. */
    private final double eps;

    /** Upper bound on message-passing iterations in {@link #solve()}. */
    private final int maxNumIterations;

    /** Size of the worker thread pool and number of worker tasks per phase. */
    private final int numThreads;

    /** One node per energy-function node, built by {@link #init}; index == node id. */
    protected List<Node<LabelType>> nodes;

    /** The energy function supplied to the most recent {@link #init} call, or null before that. */
    protected EnergyFunction<LabelType> fn;

    /** Per-phase queue of edge-index groups that workers drain concurrently. */
    private ConcurrentLinkedQueue<List<Integer>> edgeGroups;

    /** Per-phase queue of node groups that workers drain concurrently. */
    private ConcurrentLinkedQueue<List<Node<LabelType>>> nodeGroups;

    /** Immutable-by-convention partition of edge indices; copied into {@link #edgeGroups}. */
    private List<List<Integer>> edgeGroupsMaster;

    /** Immutable-by-convention partition of nodes; copied into {@link #nodeGroups}. */
    private List<List<Node<LabelType>>> nodeGroupsMaster;

    /**
     * Creates a solver.
     *
     * @param maxNumIterations maximum number of message-passing iterations (asserted &gt; 0)
     * @param eps convergence threshold on the largest per-iteration node update
     * @param numThreads number of worker threads to use
     * @throws IllegalArgumentException if {@code numThreads < 1}
     */
    public SumProductInferencingAlgorithm(int maxNumIterations, double eps, int numThreads) {
        assert (maxNumIterations > 0);
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be at least 1, got " + numThreads);
        }
        this.maxNumIterations = maxNumIterations;
        this.eps = eps;
        this.numThreads = numThreads;
        this.fn = null;
    }

    /**
     * Creates a solver with {@link #DEFAULT_EPS} and {@link #DEFAULT_NUM_THREADS}.
     *
     * @param maxNumIterations maximum number of message-passing iterations
     */
    public SumProductInferencingAlgorithm(int maxNumIterations) {
        this(maxNumIterations, DEFAULT_EPS, DEFAULT_NUM_THREADS);
    }

    /** Creates a solver with all default settings. */
    public SumProductInferencingAlgorithm() {
        this(DEFAULT_MAX_ITERATIONS, DEFAULT_EPS, DEFAULT_NUM_THREADS);
    }

    /**
     * Runs sum-product iterations until the largest value returned by {@code Node.update()}
     * in an iteration drops below {@code eps}, or {@code maxNumIterations} iterations have run.
     * It then computes the beliefs of every node.
     *
     * <p>A thread pool is created for this call and is always shut down before returning,
     * including when a worker fails.
     *
     * @return true if the convergence threshold was reached within the iteration limit
     * @throws IllegalStateException if {@link #init} has not been called
     * @throws RuntimeException wrapping a worker failure or an interruption while waiting
     */
    @Override
    public boolean solve() {
        if (this.edgeGroups == null || this.nodeGroups == null) {
            throw new IllegalStateException("init must be called before solve");
        }
        this.edgeGroups.clear();
        this.nodeGroups.clear();

        List<SolveTask> tasks = new ArrayList<SolveTask>(this.numThreads);
        for (int i = 0; i < this.numThreads; ++i) {
            tasks.add(new SolveTask());
        }
        List<Future<?>> futures = new ArrayList<Future<?>>(this.numThreads);

        ThreadFactory threadFactory = new ThreadFactory() {
            private final String baseName = "BpSolver-";
            private int counter = 0;

            @Override
            public Thread newThread(Runnable r) {
                return new Thread(r, this.baseName + this.counter++);
            }
        };
        ExecutorService executorService = Executors.newFixedThreadPool(this.numThreads, threadFactory);

        boolean converged = false;
        try {
            int iterCount = 0;
            while (!converged && iterCount < this.maxNumIterations) {
                this.copyFromMasters();
                this.loadAndStartFutures(futures, executorService, tasks, SolverSetting.COMPUTE_MESSAGES);
                this.waitForThreadsToComplete(futures);
                this.loadAndStartFutures(futures, executorService, tasks, SolverSetting.NORMALIZE_NODES);
                this.waitForThreadsToComplete(futures);

                // Each task tracked the largest update among the nodes it normalized.
                double delta = 0.0;
                for (SolveTask task : tasks) {
                    delta = Math.max(delta, task.getDelta());
                }
                converged = delta < this.eps;
                ++iterCount;
            }
            this.copyFromMasters();
            this.loadAndStartFutures(futures, executorService, tasks, SolverSetting.COMPUTE_BELIEFS);
            this.waitForThreadsToComplete(futures);
        } finally {
            // On a failure path, emptying the queues lets still-running workers finish quickly.
            this.edgeGroups.clear();
            this.nodeGroups.clear();
            // On the normal path every task has already completed, so this only releases threads.
            executorService.shutdownNow();
        }
        return converged;
    }

    /**
     * Configures every task for {@code setting} and submits it to the executor.
     *
     * @param futures receives one future per submitted task
     * @param executorService pool to submit to
     * @param tasks the worker tasks, one per thread
     * @param setting the phase each task should run
     */
    private void loadAndStartFutures(List<Future<?>> futures, ExecutorService executorService,
            List<SolveTask> tasks, SolverSetting setting) {
        for (SolveTask task : tasks) {
            task.setting = setting;
            futures.add(executorService.submit(task));
        }
    }

    /**
     * Blocks until every future has completed, then clears {@code futures}.
     *
     * @param futures futures of the current phase
     * @throws RuntimeException wrapping a task failure, or an interruption (after restoring the
     *     calling thread's interrupt flag)
     */
    private void waitForThreadsToComplete(List<Future<?>> futures) {
        try {
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        } finally {
            futures.clear();
        }
    }

    /**
     * Computes the temporary message(s) for one edge.
     *
     * <p>This is called from several worker threads at once, each handling different edges.
     * Implementations must tolerate that.
     *
     * @param edgeIndex index of the edge in the energy function
     */
    protected abstract void computeTemporaryMessage(int edgeIndex);

    /**
     * Refills the per-phase work queues from the master partitions.
     *
     * @throws RuntimeException if either queue still holds work from a previous phase
     */
    private void copyFromMasters() {
        if (!this.edgeGroups.isEmpty() || !this.nodeGroups.isEmpty()) {
            throw new RuntimeException("Can't copy if the destinations aren't empty");
        }
        this.edgeGroups.addAll(this.edgeGroupsMaster);
        this.nodeGroups.addAll(this.nodeGroupsMaster);
    }

    /**
     * Initializes the messages for one edge. {@link #init} calls this for every edge after
     * {@link #nodes} has been rebuilt, and before {@link #fn} is updated.
     *
     * @param edge the value returned by {@code EnergyFunction.getEdge} for this edge
     *     (presumably its two endpoint node indices; confirm against EnergyFunction)
     */
    abstract void initMessages(Pair<Integer, Integer> edge);

    /**
     * Prepares the solver for {@code f}. It builds one {@link Node} per energy-function node,
     * initializes every edge's messages, resets the nodes, and partitions edges and nodes into
     * work groups for the worker threads.
     *
     * @param f the energy function to solve
     */
    @Override
    public void init(EnergyFunction<LabelType> f) {
        this.nodes = new ArrayList<Node<LabelType>>(f.numNodes());
        for (int i = 0; i < f.numNodes(); ++i) {
            this.nodes.add(new Node<LabelType>(i, f.getPossibleLabels(i)));
        }
        for (int edge = 0; edge < f.numEdges(); ++edge) {
            this.initMessages(f.getEdge(edge));
        }
        for (Node<LabelType> node : this.nodes) {
            node.resetToOne();
        }
        this.fn = f;

        // More groups than threads, so faster threads can pick up extra work.
        int numPieces = this.numThreads * PIECES_PER_THREAD;
        List<Integer> edgeIndices = new ArrayList<Integer>(f.numEdges());
        for (int i = 0; i < f.numEdges(); ++i) {
            edgeIndices.add(i);
        }
        this.edgeGroupsMaster = partition(edgeIndices, numPieces);
        this.nodeGroupsMaster = partition(this.nodes, numPieces);

        this.nodeGroups = new ConcurrentLinkedQueue<List<Node<LabelType>>>();
        this.edgeGroups = new ConcurrentLinkedQueue<List<Integer>>();
    }

    /**
     * Splits {@code items} into at most {@code numPieces} contiguous, non-empty groups of
     * nearly equal size, preserving order. Returns an empty list for empty input.
     *
     * @param items elements to split
     * @param numPieces maximum number of groups (must be positive)
     * @param <T> element type
     * @return freshly allocated groups covering every element exactly once
     */
    private static <T> List<List<T>> partition(List<T> items, int numPieces) {
        int size = items.size();
        int pieceSize = Math.max(1, size / numPieces + (size % numPieces == 0 ? 0 : 1));
        List<List<T>> pieces = new ArrayList<List<T>>();
        for (int start = 0; start < size; start += pieceSize) {
            int end = Math.min(size, start + pieceSize);
            pieces.add(new ArrayList<T>(items.subList(start, end)));
        }
        return pieces;
    }

    /**
     * Returns the belief stored on node {@code i} for {@code label}.
     *
     * @param i node index
     * @param label label index
     * @return the node's belief for that label
     */
    @Override
    public double getBelief(int i, int label) {
        return this.nodes.get(i).getBelief(label);
    }

    /**
     * A reusable worker. It drains the shared work queue for the phase selected by
     * {@link #setting}. The submitting thread sets {@code setting} before each submission and
     * reads {@code delta} after {@code Future.get()}.
     */
    private class SolveTask implements Runnable {
        /** Largest {@code Node.update()} value seen during the last run (0 for other phases). */
        private double delta;

        /** Phase to execute on the next run. */
        private SolverSetting setting;

        @Override
        public void run() {
            this.delta = 0.0;
            switch (this.setting) {
                case COMPUTE_MESSAGES:
                    this.computeMessages();
                    break;
                case NORMALIZE_NODES:
                    this.normalizeNodes();
                    break;
                case COMPUTE_BELIEFS:
                    this.computeBeliefs();
                    break;
                default:
                    throw new RuntimeException("Unhandled case, setting = " + this.setting);
            }
        }

        /** Computes temporary messages for every edge group this worker can claim. */
        private void computeMessages() {
            List<Integer> group;
            while ((group = SumProductInferencingAlgorithm.this.edgeGroups.poll()) != null) {
                for (int edge : group) {
                    SumProductInferencingAlgorithm.this.computeTemporaryMessage(edge);
                }
            }
        }

        /** Normalizes and updates each claimed node, tracking the largest update value. */
        private void normalizeNodes() {
            List<Node<LabelType>> group;
            while ((group = SumProductInferencingAlgorithm.this.nodeGroups.poll()) != null) {
                for (Node<LabelType> node : group) {
                    node.normalizeMessagesForSumProductAlgorithm();
                    this.delta = Math.max(this.delta, node.update());
                }
            }
        }

        /** Computes beliefs for each claimed node against the current energy function. */
        private void computeBeliefs() {
            List<Node<LabelType>> group;
            while ((group = SumProductInferencingAlgorithm.this.nodeGroups.poll()) != null) {
                for (Node<LabelType> node : group) {
                    node.computeBeliefsForSumProductAlgorithm(SumProductInferencingAlgorithm.this.fn);
                }
            }
        }

        /**
         * @return the largest node update observed during the most recent run
         */
        public double getDelta() {
            return this.delta;
        }
    }

    /** Phases a {@link SolveTask} can execute. */
    private enum SolverSetting {
        COMPUTE_MESSAGES,
        NORMALIZE_NODES,
        COMPUTE_BELIEFS
    }
}

