/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.graph.inference;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.graph.inference.Message;
import gov.sandia.cognition.graph.inference.Node;
import gov.sandia.cognition.graph.inference.SumProductInferencingAlgorithm;
import gov.sandia.cognition.util.Pair;
import java.util.Collection;

/**
 * Sum-product belief propagation following Yedidia, Freeman and Weiss.
 *
 * <p>Constructor settings and the {@code fn} / {@code nodes} state used below are
 * inherited from {@link SumProductInferencingAlgorithm}. This class overrides only
 * how messages are initialized for an edge and how the temporary message along
 * each direction of an edge is computed.
 *
 * @param <LabelType> type of the labels a node can take
 */
@PublicationReference(author={"Jonathan S. Yedidia, William T. Freeman, and Yair Weiss"}, title="Understanding Belief Propagation and its Generalizations", type=PublicationType.TechnicalReport, year=2001, notes={"Institution: Mitsubishi Electric Research Laboratories"})
public class SumProductBeliefPropagation<LabelType>
extends SumProductInferencingAlgorithm<LabelType> {
    /** Default {@code maxNumIterations} used by the convenience constructors. */
    private static final int DEFAULT_MAX_NUM_ITERATIONS = 20;

    /** Default {@code eps} forwarded to the superclass by the convenience constructors. */
    private static final double DEFAULT_EPS = 0.001;

    /** Default {@code numThreads} forwarded to the superclass by the convenience constructors. */
    private static final int DEFAULT_NUM_THREADS = 4;

    /**
     * Creates a solver with explicit settings, all forwarded unchanged to the superclass.
     *
     * @param maxNumIterations maximum number of iterations
     * @param eps tolerance value; its exact meaning is defined by
     *     {@link SumProductInferencingAlgorithm}
     * @param numThreads number of threads to use
     */
    public SumProductBeliefPropagation(int maxNumIterations, double eps, int numThreads) {
        super(maxNumIterations, eps, numThreads);
    }

    /**
     * Creates a solver with the given iteration limit and default {@code eps} (0.001)
     * and thread count (4).
     *
     * @param maxNumIterations maximum number of iterations
     */
    public SumProductBeliefPropagation(int maxNumIterations) {
        this(maxNumIterations, DEFAULT_EPS, DEFAULT_NUM_THREADS);
    }

    /**
     * Creates a solver with all default settings: 20 iterations, {@code eps} 0.001,
     * and 4 threads.
     */
    public SumProductBeliefPropagation() {
        this(DEFAULT_MAX_NUM_ITERATIONS);
    }

    /**
     * Computes the unnormalized temporary message sent along one direction of an edge.
     * The result is stored via {@code setTempValue} on the target node's message from
     * the source node.
     *
     * <p>For target label index {@code i} the stored value is
     * {@code sum_j exp(-unary(s_j) - pairwise(s_j, t_i) + logMessageSum(j) - max)}.
     * Here {@code max} is the largest exponent over all (source, target) label pairs.
     * Subtracting one common shift keeps {@code exp} from overflowing and leaves the
     * ratios between target labels unchanged.
     *
     * @param edge index of the edge in {@code fn}
     * @param reverse {@code false} to send from the edge's first node to its second;
     *     {@code true} to send from the second node to the first
     */
    private void computeTemporaryMessage(int edge, boolean reverse) {
        Pair<Integer, Integer> edgePair = this.fn.getEdge(edge);
        Integer sourceKey = reverse ? edgePair.getSecond() : edgePair.getFirst();
        Integer targetKey = reverse ? edgePair.getFirst() : edgePair.getSecond();
        // (Node) casts kept: the declared element type of `nodes` is not visible here.
        Node sourceNode = (Node) this.nodes.get(sourceKey);
        Node targetNode = (Node) this.nodes.get(targetKey);
        int sourceNodeId = sourceNode.getId();
        int targetNodeId = targetNode.getId();
        Message targetMessage = targetNode.getMessageFromSource(sourceNodeId);

        // Snapshot each label collection once so every pass uses the same order.
        Object[] sourceLabels = this.fn.getPossibleLabels(sourceNodeId).toArray();
        Object[] targetLabels = this.fn.getPossibleLabels(targetNodeId).toArray();
        int numSource = sourceLabels.length;
        int numTarget = targetLabels.length;

        // The unary cost and the incoming log-message sum depend only on the source
        // label, so compute them once rather than once per target label.
        // NOTE(review): assumes getUnaryCost / getLogMessageSum have no side effects.
        double[] negUnaryCosts = new double[numSource];
        double[] logMessageSums = new double[numSource];
        for (int j = 0; j < numSource; ++j) {
            negUnaryCosts[j] = -this.fn.getUnaryCost(sourceNodeId, sourceLabels[j]);
            logMessageSums[j] = sourceNode.getLogMessageSum(j, targetNodeId);
        }

        // exponents[i * numSource + j] belongs to (target label i, source label j).
        double[] exponents = new double[numTarget * numSource];
        // -MAX_VALUE rather than -infinity: if every exponent is -infinity, the shifted
        // exp below evaluates to 0 instead of NaN.
        double max = -Double.MAX_VALUE;
        int ij = 0;
        for (Object targetLabel : targetLabels) {
            for (int j = 0; j < numSource; ++j) {
                // Both branches pass labels as (edge's first node, edge's second node).
                double pairwiseCost = reverse
                        ? this.fn.getPairwiseCost(edge, targetLabel, sourceLabels[j])
                        : this.fn.getPairwiseCost(edge, sourceLabels[j], targetLabel);
                // Same evaluation order as before: ((-unary) + (-pairwise)) + logMessageSum.
                double exponent = negUnaryCosts[j] - pairwiseCost + logMessageSums[j];
                exponents[ij++] = exponent;
                max = Math.max(exponent, max);
            }
        }

        ij = 0;
        for (int i = 0; i < numTarget; ++i) {
            double sum = 0.0;
            for (int j = 0; j < numSource; ++j) {
                sum += Math.exp(exponents[ij++] - max);
            }
            targetMessage.setTempValue(i, sum);
        }
    }

    /**
     * Computes the temporary messages for both directions of an edge: first from the
     * second node to the first, then from the first node to the second.
     *
     * @param edge index of the edge in {@code fn}
     */
    @Override
    protected void computeTemporaryMessage(int edge) {
        this.computeTemporaryMessage(edge, true);
        this.computeTemporaryMessage(edge, false);
    }

    /**
     * Links the two endpoints of an edge to each other by calling {@code link(other, true)}
     * on each endpoint node.
     * NOTE(review): the meaning of the {@code true} flag is defined by {@code Node.link};
     * confirm there.
     *
     * @param edgePair the (first, second) node keys of the edge
     */
    @Override
    void initMessages(Pair<Integer, Integer> edgePair) {
        Integer first = edgePair.getFirst();
        Integer second = edgePair.getSecond();
        Node firstNode = (Node) this.nodes.get(first);
        firstNode.link(second, true);
        Node secondNode = (Node) this.nodes.get(second);
        secondNode.link(first, true);
    }
}

