/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.graph.inference;

import gov.sandia.cognition.graph.inference.EnergyFunctionSolver;
import gov.sandia.cognition.graph.inference.NodeNameAwareEnergyFunction;
import gov.sandia.cognition.util.Pair;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;

/**
 * Wraps a {@link NodeNameAwareEnergyFunction} and memoizes the values returned by
 * its {@code getUnaryCost} and {@code getPairwiseCost} methods. Every other method
 * is forwarded directly to the wrapped function.
 * <p>
 * The cache tables are sized from the wrapped function's node count, edge count
 * and per-node label counts at construction time. NOTE(review): if the wrapped
 * graph or its label sets change afterwards the tables are not resized, so a new
 * wrapper should be constructed in that case.
 * <p>
 * Labels are located by their position in the collection returned by the wrapped
 * function's {@code getPossibleLabels}. That collection is presumably iterated in
 * the same order on every call — TODO confirm for all wrapped implementations.
 * <p>
 * The caches are plain arrays with no synchronization, so an instance should not
 * be shared between threads without external locking.
 *
 * @param <LabelType> type of the labels that can be assigned to nodes
 * @param <NodeNameType> type used to name nodes
 */
public class CostSpeedupEnergyFunction<LabelType, NodeNameType>
implements NodeNameAwareEnergyFunction<LabelType, NodeNameType> {
    /**
     * Marker for a cache slot whose cost has not been computed yet. NaN is used
     * instead of {@link Double#MAX_VALUE} because MAX_VALUE is a plausible real
     * cost (e.g. for a disallowed label), which would make such a slot uncacheable.
     * Slots must be tested with {@link Double#isNaN(double)}, not {@code ==}.
     */
    private static final double NOT_COMPUTED = Double.NaN;

    /** The energy function whose costs are memoized. */
    private final NodeNameAwareEnergyFunction<LabelType, NodeNameType> wrapped;

    /** Layout: pairwiseCosts[edgeId][srcLabelIdx * dstLabelCount + dstLabelIdx]. */
    private final double[][] pairwiseCosts;

    /** Layout: unaryCosts[nodeId][labelIdx]. */
    private final double[][] unaryCosts;

    /**
     * Creates a caching wrapper and allocates one cache slot per node/label and
     * per edge/label-pair of {@code wrapme}, all initially marked as not computed.
     *
     * @param wrapme the energy function to wrap; must not be null
     * @throws NullPointerException if {@code wrapme} is null
     */
    public CostSpeedupEnergyFunction(NodeNameAwareEnergyFunction<LabelType, NodeNameType> wrapme) {
        this.wrapped = Objects.requireNonNull(wrapme, "wrapme");
        int m = wrapme.numEdges();
        this.pairwiseCosts = new double[m][];
        for (int i = 0; i < m; ++i) {
            Pair<Integer, Integer> edge = wrapme.getEdge(i);
            int srcLabelsCnt = wrapme.getPossibleLabels(edge.getFirst()).size();
            int dstLabelsCnt = wrapme.getPossibleLabels(edge.getSecond()).size();
            this.pairwiseCosts[i] = new double[srcLabelsCnt * dstLabelsCnt];
        }
        int n = wrapme.numNodes();
        this.unaryCosts = new double[n][];
        for (int i = 0; i < n; ++i) {
            this.unaryCosts[i] = new double[wrapme.getPossibleLabels(i).size()];
        }
        // Use the private static helper rather than the overridable
        // clearStoredCosts() so no subclass code runs during construction.
        markAllNotComputed(this.pairwiseCosts);
        markAllNotComputed(this.unaryCosts);
    }

    /**
     * Discards every memoized unary and pairwise cost, so subsequent cost queries
     * are recomputed from the wrapped function.
     */
    public void clearStoredCosts() {
        // Iterate the tables themselves rather than the wrapped counts so the
        // loop can never index past the arrays that were actually allocated.
        markAllNotComputed(this.pairwiseCosts);
        markAllNotComputed(this.unaryCosts);
    }

    /** Resets every slot in {@code table} to {@link #NOT_COMPUTED}. */
    private static void markAllNotComputed(double[][] table) {
        for (double[] row : table) {
            Arrays.fill(row, NOT_COMPUTED);
        }
    }

    /**
     * Forwards the label assignment to the wrapped function, then discards all
     * memoized costs.
     * <p>
     * NOTE(review): the wrapped {@code setLabel} mutates the wrapped function and
     * may change the costs it reports. Clearing here prevents the cache from
     * serving values computed before that change.
     */
    @Override
    public void setLabel(NodeNameType node, LabelType label) {
        this.wrapped.setLabel(node, label);
        this.clearStoredCosts();
    }

    /** Delegates to the wrapped function. */
    @Override
    public Map<LabelType, Double> getBeliefs(NodeNameType node, EnergyFunctionSolver<LabelType> bp) {
        return this.wrapped.getBeliefs(node, bp);
    }

    /** Delegates to the wrapped function. */
    @Override
    public Collection<LabelType> getPossibleLabels(int nodeId) {
        return this.wrapped.getPossibleLabels(nodeId);
    }

    /** Delegates to the wrapped function. */
    @Override
    public int numEdges() {
        return this.wrapped.numEdges();
    }

    /** Delegates to the wrapped function. */
    @Override
    public int numNodes() {
        return this.wrapped.numNodes();
    }

    /** Delegates to the wrapped function. */
    @Override
    public Pair<Integer, Integer> getEdge(int i) {
        return this.wrapped.getEdge(i);
    }

    /** Delegates to the wrapped function (potentials are not cached). */
    @Override
    public double getUnaryPotential(int i, LabelType label) {
        return this.wrapped.getUnaryPotential(i, label);
    }

    /** Delegates to the wrapped function (potentials are not cached). */
    @Override
    public double getPairwisePotential(int edgeId, LabelType ilabel, LabelType jlabel) {
        return this.wrapped.getPairwisePotential(edgeId, ilabel, jlabel);
    }

    /**
     * Returns the zero-based position of {@code label} in the iteration order of
     * {@code labels}. Null elements are compared safely.
     *
     * @throws IllegalArgumentException if {@code label} is not in {@code labels}
     */
    private static <LabelType> int indexOf(LabelType label, Collection<LabelType> labels) {
        int idx = 0;
        for (LabelType l : labels) {
            if (Objects.equals(l, label)) {
                return idx;
            }
            ++idx;
        }
        throw new IllegalArgumentException("Unable to find input label (" + label + ") in input");
    }

    /**
     * Returns the wrapped function's unary cost for node {@code i} taking
     * {@code label}. The cost is computed on the first request and served from the
     * cache afterwards, until {@link #clearStoredCosts()} or {@link #setLabel} is
     * called.
     *
     * @throws IllegalArgumentException if {@code label} is not one of node
     *         {@code i}'s possible labels
     */
    @Override
    public double getUnaryCost(int i, LabelType label) {
        int idx = indexOf(label, this.wrapped.getPossibleLabels(i));
        double cost = this.unaryCosts[i][idx];
        if (Double.isNaN(cost)) {
            cost = this.wrapped.getUnaryCost(i, label);
            this.unaryCosts[i][idx] = cost;
        }
        return cost;
    }

    /**
     * Returns the wrapped function's pairwise cost for edge {@code edgeId} when its
     * first endpoint takes {@code ilabel} and its second endpoint takes
     * {@code jlabel}. The cost is computed on the first request and served from the
     * cache afterwards, until {@link #clearStoredCosts()} or {@link #setLabel} is
     * called.
     *
     * @throws IllegalArgumentException if either label is not a possible label of
     *         its endpoint
     */
    @Override
    public double getPairwiseCost(int edgeId, LabelType ilabel, LabelType jlabel) {
        Pair<Integer, Integer> endpoints = this.wrapped.getEdge(edgeId);
        Collection<LabelType> ilabels = this.wrapped.getPossibleLabels(endpoints.getFirst());
        Collection<LabelType> jlabels = this.wrapped.getPossibleLabels(endpoints.getSecond());
        // Row-major over (source label, destination label), matching the constructor's layout.
        int idx = indexOf(ilabel, ilabels) * jlabels.size() + indexOf(jlabel, jlabels);
        double cost = this.pairwiseCosts[edgeId][idx];
        if (Double.isNaN(cost)) {
            cost = this.wrapped.getPairwiseCost(edgeId, ilabel, jlabel);
            this.pairwiseCosts[edgeId][idx] = cost;
        }
        return cost;
    }
}

