/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.graph.inference;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.graph.inference.Message;
import gov.sandia.cognition.graph.inference.Node;
import gov.sandia.cognition.graph.inference.SumProductInferencingAlgorithm;
import gov.sandia.cognition.util.Pair;
import java.util.Collection;

/**
 * Sum-product inferencing variant that propagates messages only along the
 * direction of each edge: from the edge's first (source) node to its second
 * (target) node.
 *
 * <p>Message initialization links each target node only to its source node,
 * and each temporary message is computed solely from the source side of the
 * edge. Iteration, convergence and threading are handled by
 * {@link SumProductInferencingAlgorithm}.
 *
 * @param <LabelType> the type of the labels that nodes may take
 */
@PublicationReference(author={"Tu-Thach Quach and Jeremy D. Wendt"}, title="A diffusion model for maximizing influence spread in large networks", type=PublicationType.Conference, publication="Proceedings of the International Conference on Social Informatics", year=2016)
public class SumProductDirectedPropagation<LabelType>
extends SumProductInferencingAlgorithm<LabelType> {

    /** Iteration limit used by the no-argument constructor. */
    private static final int DEFAULT_MAX_NUM_ITERATIONS = 20;

    /** Epsilon forwarded to the superclass when none is specified. */
    private static final double DEFAULT_EPS = 0.001;

    /** Thread count forwarded to the superclass when none is specified. */
    private static final int DEFAULT_NUM_THREADS = 4;

    /**
     * Creates a directed sum-product propagator.
     *
     * @param maxNumIterations the maximum number of iterations, forwarded to
     *     the superclass
     * @param eps the epsilon value forwarded to the superclass (presumably a
     *     convergence tolerance; see SumProductInferencingAlgorithm)
     * @param numThreads the number of threads forwarded to the superclass
     */
    public SumProductDirectedPropagation(int maxNumIterations, double eps, int numThreads) {
        super(maxNumIterations, eps, numThreads);
    }

    /**
     * Creates a directed sum-product propagator with the default epsilon
     * (0.001) and thread count (4).
     *
     * @param maxNumIterations the maximum number of iterations, forwarded to
     *     the superclass
     */
    public SumProductDirectedPropagation(int maxNumIterations) {
        this(maxNumIterations, DEFAULT_EPS, DEFAULT_NUM_THREADS);
    }

    /**
     * Creates a directed sum-product propagator with the default settings:
     * 20 iterations, epsilon 0.001 and 4 threads.
     */
    public SumProductDirectedPropagation() {
        this(DEFAULT_MAX_NUM_ITERATIONS, DEFAULT_EPS, DEFAULT_NUM_THREADS);
    }

    /**
     * Sets up message storage for one edge by linking only the edge's target
     * (second) node to its source (first) node. No reverse link is created,
     * which is what makes the propagation directed.
     *
     * @param edgePair the (source node index, target node index) of the edge
     */
    @Override
    void initMessages(Pair<Integer, Integer> edgePair) {
        Node targetNode = (Node) this.nodes.get(edgePair.getSecond());
        // NOTE(review): the meaning of the boolean flag is defined by
        // Node.link; presumably false marks a one-way link. Confirm there.
        targetNode.link(edgePair.getFirst(), false);
    }

    /**
     * Computes the not-yet-committed message sent along {@code edge} from its
     * source node to its target node, and stores it as the target message's
     * temporary values.
     *
     * <p>For each target label {@code t} (in iteration order, index
     * {@code i}), the stored value is the sum over source labels {@code s} of
     * {@code exp(-unary(source, s) - pairwise(edge, s, t)
     * + logMessageSum(source, s, excluding target) - max)}. Here {@code max} is
     * the largest log score over all (s, t) pairs. Subtracting it keeps the
     * exponentials from overflowing.
     *
     * @param edge the index of the edge whose message is computed
     */
    @Override
    protected void computeTemporaryMessage(int edge) {
        final Pair<Integer, Integer> edgePair = this.fn.getEdge(edge);
        final Node sourceNode = (Node) this.nodes.get(edgePair.getFirst());
        final Node targetNode = (Node) this.nodes.get(edgePair.getSecond());
        final int sourceNodeId = sourceNode.getId();
        final int targetNodeId = targetNode.getId();
        final Collection<LabelType> sourceLabels = this.fn.getPossibleLabels(sourceNodeId);
        final Collection<LabelType> targetLabels = this.fn.getPossibleLabels(targetNodeId);
        final int numSourceLabels = sourceLabels.size();
        final int numTargetLabels = targetLabels.size();

        // Log-space score for every (target, source) label pair, stored
        // target-major: index = targetIdx * numSourceLabels + sourceIdx.
        final double[] logValues = new double[numSourceLabels * numTargetLabels];
        // Start below any finite score; if every score is -Infinity, the
        // exponentials below evaluate to 0 rather than NaN.
        double max = -Double.MAX_VALUE;
        int ij = 0;
        for (LabelType targetLabel : targetLabels) {
            int sourceLabelIdx = 0;
            for (LabelType sourceLabel : sourceLabels) {
                final double logValue = -this.fn.getUnaryCost(sourceNodeId, sourceLabel)
                    - this.fn.getPairwiseCost(edge, sourceLabel, targetLabel)
                    + sourceNode.getLogMessageSum(sourceLabelIdx, targetNodeId);
                logValues[ij] = logValue;
                max = Math.max(logValue, max);
                ++sourceLabelIdx;
                ++ij;
            }
        }

        // Marginalize out the source label for each target label, using the
        // same target-major layout as above.
        final Message targetMessage = targetNode.getMessageFromSource(sourceNodeId);
        ij = 0;
        for (int i = 0; i < numTargetLabels; ++i) {
            double value = 0.0;
            for (int j = 0; j < numSourceLabels; ++j) {
                value += Math.exp(logValues[ij] - max);
                ++ij;
            }
            targetMessage.setTempValue(i, value);
        }
    }
}

