/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.text.topic;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.DiscreteSamplingUtil;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.Randomized;
import java.util.Collection;
import java.util.Iterator;
import java.util.Random;

@PublicationReferences(references={@PublicationReference(author={"David M. Blei", "Andrew Y. Ng", "Michael I. Jordan"}, title="Latent Dirichlet Allocation", year=2003, type=PublicationType.Journal, publication="Journal of Machine Learning Research", pages={993, 1022}, url="http://www.cs.princeton.edu/~blei/papers/BleiNgJordan2003.pdf"), @PublicationReference(author={"Gregor Heinrich"}, title="Parameter estimation for text analysis", year=2009, type=PublicationType.TechnicalReport, url="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.149.1327&rep=rep1&type=pdf")})
public class LatentDirichletAllocationVectorGibbsSampler
extends AbstractAnytimeBatchLearner<Collection<? extends Vectorizable>, Result>
implements Randomized {
    public static final int DEFAULT_TOPIC_COUNT = 10;
    public static final double DEFAULT_ALPHA = 5.0;
    public static final double DEFAULT_BETA = 0.5;
    public static final int DEFAULT_MAX_ITERATIONS = 10000;
    public static final int DEFAULT_BURN_IN_ITERATIONS = 2000;
    public static final int DEFAULT_ITERATIONS_PER_SAMPLE = 100;
    protected int topicCount;
    protected double alpha;
    protected double beta;
    protected int burnInIterations;
    protected int iterationsPerSample;
    protected Random random;
    protected transient int documentCount;
    protected transient int termCount;
    protected transient int[][] documentTopicCount;
    protected transient int[] documentTopicSum;
    protected transient int[][] topicTermCount;
    protected transient int[] topicTermSum;
    protected transient int[] occurrenceTopicAssignments;
    protected transient int[] documentTermPairsCounts;
    protected transient int[] documentTerms;
    protected transient int[] documentTermCounts;
    protected transient double[] topicCumulativeProportions;
    protected transient int sampleCount;
    protected transient Result result;

    /**
     * Creates a sampler configured with the class defaults:
     * {@link #DEFAULT_TOPIC_COUNT}, {@link #DEFAULT_ALPHA}, {@link #DEFAULT_BETA},
     * {@link #DEFAULT_MAX_ITERATIONS}, {@link #DEFAULT_BURN_IN_ITERATIONS},
     * {@link #DEFAULT_ITERATIONS_PER_SAMPLE} and a freshly constructed, unseeded
     * {@link Random}.
     */
    public LatentDirichletAllocationVectorGibbsSampler() {
        // Refer to the named constants rather than repeating their literal values,
        // so the defaults cannot silently drift from the published constants.
        this(DEFAULT_TOPIC_COUNT, DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_MAX_ITERATIONS,
            DEFAULT_BURN_IN_ITERATIONS, DEFAULT_ITERATIONS_PER_SAMPLE, new Random());
    }

    /**
     * Creates a sampler with explicit settings. Each value is stored through its
     * corresponding setter, so whatever validation the setter performs applies here.
     *
     * @param topicCount          number of topics to fit
     * @param alpha               value used as the additive term on document-topic counts
     * @param beta                value used as the additive term on topic-term counts
     * @param maxIterations       iteration limit handed to the superclass constructor
     * @param burnInIterations    iterations to run before parameters are first read
     * @param iterationsPerSample spacing, in iterations, between parameter reads
     * @param random              random number generator used for topic sampling
     */
    public LatentDirichletAllocationVectorGibbsSampler(int topicCount, double alpha, double beta, int maxIterations, int burnInIterations, int iterationsPerSample, Random random) {
        super(maxIterations);

        setTopicCount(topicCount);
        setAlpha(alpha);
        setBeta(beta);
        setBurnInIterations(burnInIterations);
        setIterationsPerSample(iterationsPerSample);
        setRandom(random);
    }

    /**
     * Counts the whole-number term occurrences represented by a vector.
     *
     * <p>Each entry is truncated to an {@code int} and only strictly positive
     * results are added. This is the same rule the rest of this class applies when
     * it enumerates term occurrences, so the total always equals the number of
     * occurrences that are actually visited. Negative and NaN entries contribute
     * nothing (a NaN truncates to 0).
     *
     * @param v the vector whose entries are interpreted as term counts
     * @return the sum of the positive truncated entries, capped at
     *         {@link Integer#MAX_VALUE}
     */
    private static int intNorm1(Vector v) {
        // Accumulate in a long so a very large document cannot wrap around.
        long total = 0L;
        for (VectorEntry entry : v) {
            int count = (int) entry.getValue();
            if (count > 0) {
                total += count;
            }
        }
        return (int) Math.min(total, (long) Integer.MAX_VALUE);
    }

    /**
     * Builds the sampler state from {@code this.data} and assigns every term
     * occurrence a uniformly random initial topic.
     *
     * <p>The work is done in three passes:
     * <ol>
     * <li>count the (document, term) pairs and the total number of occurrences;</li>
     * <li>flatten the documents into the {@code documentTerms} /
     *     {@code documentTermCounts} arrays, recording how many pairs each document
     *     owns in {@code documentTermPairsCounts};</li>
     * <li>walk every occurrence, draw a topic for it and update the count tables.</li>
     * </ol>
     * An entry only counts when its value, truncated to an {@code int}, is strictly
     * positive.
     *
     * @return {@code false} if the data is empty, in which case no state is
     *         initialized; {@code true} otherwise
     * @throws RuntimeException if the total number of occurrences exceeds
     *         {@link Integer#MAX_VALUE}, or if the passes disagree about how many
     *         pairs or occurrences exist
     */
    protected boolean initializeAlgorithm() {
        if (CollectionUtil.isEmpty(this.data)) {
            return false;
        }

        this.documentCount = this.data.size();
        this.termCount = DatasetUtil.getDimensionality(this.data);
        this.documentTopicCount = new int[this.documentCount][this.topicCount];
        this.documentTopicSum = new int[this.documentCount];
        this.topicTermCount = new int[this.topicCount][this.termCount];
        this.topicTermSum = new int[this.topicCount];
        this.topicCumulativeProportions = new double[this.topicCount];
        this.sampleCount = 0;

        // Pass 1: size the flattened arrays. Occurrences are totalled from the same
        // truncated, positive counts the later passes use, so the assignment array
        // always matches the number of occurrences that will be visited, even when
        // a vector holds negative or NaN entries. Each document is converted once.
        long totalOccurrences = 0L;
        int documentTermPairsCount = 0;
        for (Vectorizable m : this.data) {
            for (VectorEntry entry : m.convertToVector()) {
                int count = (int) entry.getValue();
                if (count <= 0) {
                    continue;
                }
                totalOccurrences += count;
                ++documentTermPairsCount;
            }
        }
        if (totalOccurrences > Integer.MAX_VALUE) {
            throw new RuntimeException("The number of occurrences cannot exceed the maximum number of slots in an array (Integer.MAX_VALUE)");
        }

        this.occurrenceTopicAssignments = new int[(int) totalOccurrences];
        this.documentTermPairsCounts = new int[this.documentCount];
        this.documentTerms = new int[documentTermPairsCount];
        this.documentTermCounts = new int[documentTermPairsCount];

        // Pass 2: flatten every document into (term, count) pairs.
        int document = 0;
        int documentTermPairsIndex = 0;
        for (Vectorizable m : this.data) {
            int termsInDocument = 0;
            for (VectorEntry entry : m.convertToVector()) {
                int count = (int) entry.getValue();
                if (count <= 0) {
                    continue;
                }
                this.documentTerms[documentTermPairsIndex] = entry.getIndex();
                this.documentTermCounts[documentTermPairsIndex] = count;
                ++termsInDocument;
                ++documentTermPairsIndex;
            }
            this.documentTermPairsCounts[document] = termsInDocument;
            ++document;
        }
        if (documentTermPairsIndex != documentTermPairsCount) {
            throw new RuntimeException("The two loops didn't count the same number of terms (" + documentTermPairsCount + " != " + documentTermPairsIndex + ")");
        }

        // Pass 3: give each occurrence a random starting topic and record it in
        // the document-topic and topic-term tables.
        int docTermIndex = 0;
        int occurrence = 0;
        for (document = 0; document < this.documentTermPairsCounts.length; ++document) {
            int docUniqueTerms = this.documentTermPairsCounts[document];
            for (int docUniqueTerm = 0; docUniqueTerm < docUniqueTerms; ++docUniqueTerm) {
                int term = this.documentTerms[docTermIndex];
                int count = this.documentTermCounts[docTermIndex];
                for (int i = 0; i < count; ++i) {
                    int topic = this.random.nextInt(this.topicCount);
                    ++this.documentTopicCount[document][topic];
                    ++this.documentTopicSum[document];
                    ++this.topicTermCount[topic][term];
                    ++this.topicTermSum[topic];
                    this.occurrenceTopicAssignments[occurrence] = topic;
                    ++occurrence;
                }
                ++docTermIndex;
            }
        }
        if (occurrence != this.occurrenceTopicAssignments.length) {
            throw new RuntimeException("Didn't iterate to the end of the occurrenceTopicAssignments array.  occurrence is " + occurrence + " instead of " + this.occurrenceTopicAssignments.length);
        }
        if (docTermIndex != this.documentTerms.length) {
            throw new RuntimeException("Didn't iterate to the end of the documentTerms array.  docTermIndex is " + docTermIndex + " instead of " + this.documentTerms.length);
        }

        this.result = new Result(this.topicCount, this.documentCount, this.termCount, (int) totalOccurrences);
        return true;
    }

    /**
     * Runs one sweep over every term occurrence. For each occurrence the current
     * topic assignment is removed from the count tables, a replacement topic is
     * drawn with {@link #sampleTopic(int, int, double[])}, and the tables are
     * updated with the new assignment. Once the burn-in iterations have passed,
     * {@link #readParameters()} is invoked every {@code iterationsPerSample}
     * iterations.
     *
     * @return always {@code true}
     * @throws RuntimeException if the sweep does not end exactly at the end of the
     *         occurrence and document-term arrays
     */
    protected boolean step() {
        int pairCursor = 0;
        int occurrenceCursor = 0;

        for (int doc = 0; doc < this.documentTermPairsCounts.length; doc++) {
            for (int pairsLeft = this.documentTermPairsCounts[doc]; pairsLeft > 0; pairsLeft--) {
                final int term = this.documentTerms[pairCursor];

                for (int repeatsLeft = this.documentTermCounts[pairCursor]; repeatsLeft > 0; repeatsLeft--) {
                    // Take the occurrence out of the tables before sampling.
                    final int previous = this.occurrenceTopicAssignments[occurrenceCursor];
                    this.documentTopicCount[doc][previous]--;
                    this.documentTopicSum[doc]--;
                    this.topicTermCount[previous][term]--;
                    this.topicTermSum[previous]--;

                    // Draw the replacement and put the occurrence back.
                    final int drawn = this.sampleTopic(doc, term, this.topicCumulativeProportions);
                    this.occurrenceTopicAssignments[occurrenceCursor] = drawn;
                    this.documentTopicCount[doc][drawn]++;
                    this.documentTopicSum[doc]++;
                    this.topicTermCount[drawn][term]++;
                    this.topicTermSum[drawn]++;

                    occurrenceCursor++;
                }
                pairCursor++;
            }
        }

        if (occurrenceCursor != this.occurrenceTopicAssignments.length) {
            throw new RuntimeException("Didn't iterate to the end of the occurrenceTopicAssignments array.  occurrence is " + occurrenceCursor + " instead of " + this.occurrenceTopicAssignments.length);
        }
        if (pairCursor != this.documentTerms.length) {
            throw new RuntimeException("Didn't iterate to the end of the documentTerms array.  docTermIndex is " + pairCursor + " instead of " + this.documentTerms.length);
        }

        if (this.iteration >= this.burnInIterations) {
            final int sinceBurnIn = this.iteration - this.burnInIterations;
            if (sinceBurnIn % this.iterationsPerSample == 0) {
                this.readParameters();
            }
        }
        return true;
    }

    /**
     * Draws a topic for one occurrence of {@code term} in {@code document}.
     *
     * <p>For each topic an unnormalized weight
     * {@code (topicTermCount + beta) * (documentTopicCount + alpha) / (topicTermSum + termCount * beta)}
     * is computed. The running totals of those weights are written into
     * {@code topicCumulativeProportions}, which is then passed to
     * {@code DiscreteSamplingUtil.sampleIndexFromCumulativeProportions}.
     *
     * @param document index of the document that contains the occurrence
     * @param term index of the term
     * @param topicCumulativeProportions scratch array, overwritten for indices
     *        {@code 0 .. topicCount - 1}
     * @return the index returned by the sampling utility
     */
    protected int sampleTopic(int document, int term, double[] topicCumulativeProportions) {
        final int[] topicCountsForDocument = this.documentTopicCount[document];

        double runningTotal = 0.0;
        for (int k = 0; k < this.topicCount; k++) {
            final double termWeight = (double) this.topicTermCount[k][term] + this.beta;
            final double documentWeight = (double) topicCountsForDocument[k] + this.alpha;
            final double normalizer = (double) this.topicTermSum[k] + (double) this.termCount * this.beta;

            runningTotal += (termWeight * documentWeight) / normalizer;
            topicCumulativeProportions[k] = runningTotal;
        }

        return DiscreteSamplingUtil.sampleIndexFromCumulativeProportions(this.random, topicCumulativeProportions);
    }

    /**
     * Finalizes the accumulated result. If no sample has been read yet, one is
     * read now. If more than one sample was accumulated, every entry of the
     * topic-term and document-topic tables is divided by the number of samples.
     * With exactly one sample nothing is changed.
     */
    protected void cleanupAlgorithm() {
        if (this.sampleCount <= 0) {
            this.readParameters();
            return;
        }
        if (this.sampleCount == 1) {
            // A single sample needs no averaging.
            return;
        }

        for (int topic = 0; topic < this.topicCount; topic++) {
            for (int term = 0; term < this.termCount; term++) {
                this.result.topicTermProbabilities[topic][term] /= (double) this.sampleCount;
            }
        }

        for (int document = 0; document < this.documentCount; document++) {
            for (int topic = 0; topic < this.topicCount; topic++) {
                this.result.documentTopicProbabilities[document][topic] /= (double) this.sampleCount;
            }
        }
    }

    /**
     * Adds one sample of the current state to the result and increments
     * {@code sampleCount}.
     *
     * <p>Each topic-term cell is increased by
     * {@code (topicTermCount + beta) / (topicTermSum + termCount * beta)} and each
     * document-topic cell by
     * {@code (documentTopicCount + alpha) / (documentTopicSum + topicCount * alpha)}.
     * The values are summed here, not averaged.
     */
    protected void readParameters() {
        this.sampleCount++;

        final double betaTotal = (double) this.termCount * this.beta;
        for (int k = 0; k < this.topicCount; k++) {
            final double[] termRow = this.result.topicTermProbabilities[k];
            final int[] termCounts = this.topicTermCount[k];
            for (int t = 0; t < this.termCount; t++) {
                termRow[t] += ((double) termCounts[t] + this.beta) / ((double) this.topicTermSum[k] + betaTotal);
            }
        }

        final double alphaTotal = (double) this.topicCount * this.alpha;
        for (int d = 0; d < this.documentCount; d++) {
            final double[] topicRow = this.result.documentTopicProbabilities[d];
            final int[] topicCounts = this.documentTopicCount[d];
            for (int k = 0; k < this.topicCount; k++) {
                topicRow[k] += ((double) topicCounts[k] + this.alpha) / ((double) this.documentTopicSum[d] + alphaTotal);
            }
        }
    }

    /**
     * Gets the result object currently held by this sampler.
     *
     * @return the stored result; {@code null} if none has been created yet
     */
    public Result getResult() {
        return result;
    }

    /**
     * Gets the configured number of topics.
     *
     * @return the topic count
     */
    public int getTopicCount() {
        return topicCount;
    }

    /**
     * Sets the number of topics after checking the value with
     * {@code ArgumentChecker.assertIsPositive}.
     *
     * @param topicCount the number of topics
     */
    public void setTopicCount(int topicCount) {
        ArgumentChecker.assertIsPositive("topicCount", topicCount);
        this.topicCount = topicCount;
    }

    /**
     * Gets the alpha parameter, the additive term applied to document-topic counts.
     *
     * @return alpha
     */
    public double getAlpha() {
        return alpha;
    }

    /**
     * Sets the alpha parameter after checking the value with
     * {@code ArgumentChecker.assertIsPositive}.
     *
     * @param alpha the additive term applied to document-topic counts
     */
    public void setAlpha(double alpha) {
        ArgumentChecker.assertIsPositive("alpha", alpha);
        this.alpha = alpha;
    }

    /**
     * Gets the beta parameter, the additive term applied to topic-term counts.
     *
     * @return beta
     */
    public double getBeta() {
        return beta;
    }

    /**
     * Sets the beta parameter after checking the value with
     * {@code ArgumentChecker.assertIsPositive}.
     *
     * @param beta the additive term applied to topic-term counts
     */
    public void setBeta(double beta) {
        ArgumentChecker.assertIsPositive("beta", beta);
        this.beta = beta;
    }

    /**
     * Gets the number of iterations that must pass before parameters are read.
     *
     * @return the burn-in iteration count
     */
    public int getBurnInIterations() {
        return burnInIterations;
    }

    /**
     * Sets the burn-in iteration count after checking the value with
     * {@code ArgumentChecker.assertIsNonNegative}.
     *
     * @param burnInIterations iterations to run before parameters are first read
     */
    public void setBurnInIterations(int burnInIterations) {
        ArgumentChecker.assertIsNonNegative("burnInIterations", burnInIterations);
        this.burnInIterations = burnInIterations;
    }

    /**
     * Gets the spacing, in iterations, between parameter reads after burn-in.
     *
     * @return the iterations per sample
     */
    public int getIterationsPerSample() {
        return iterationsPerSample;
    }

    /**
     * Sets the spacing between parameter reads after checking the value with
     * {@code ArgumentChecker.assertIsPositive}.
     *
     * @param iterationsPerSample iterations between parameter reads
     */
    public void setIterationsPerSample(int iterationsPerSample) {
        ArgumentChecker.assertIsPositive("iterationsPerSample", iterationsPerSample);
        this.iterationsPerSample = iterationsPerSample;
    }

    /**
     * Gets the random number generator used for topic sampling.
     *
     * @return the random number generator
     */
    public Random getRandom() {
        return random;
    }

    /**
     * Sets the random number generator used for topic sampling. The value is
     * stored as given; no validation is performed.
     *
     * @param random the random number generator
     */
    public void setRandom(final Random random) {
        this.random = random;
    }

    /**
     * Gets the number of documents recorded by the most recent initialization.
     *
     * @return the document count
     */
    public int getDocumentCount() {
        return documentCount;
    }

    /**
     * Gets the number of terms (the data dimensionality) recorded by the most
     * recent initialization.
     *
     * @return the term count
     */
    public int getTermCount() {
        return termCount;
    }

    /**
     * Holds the output of the sampler: a topic-by-term table, a document-by-topic
     * table and the total number of term occurrences that were processed.
     */
    public static class Result extends AbstractCloneableSerializable {
        /** Table indexed as {@code [topic][term]}. */
        protected double[][] topicTermProbabilities;

        /** Table indexed as {@code [document][topic]}. */
        protected double[][] documentTopicProbabilities;

        /** Total number of term occurrences supplied at construction. */
        protected int totalOccurrences;

        /**
         * Creates a result with zero-filled tables of the given sizes.
         *
         * @param topicCount       number of topics
         * @param documentCount    number of documents
         * @param termCount        number of terms
         * @param totalOccurrences total number of term occurrences
         */
        public Result(int topicCount, int documentCount, int termCount, int totalOccurrences) {
            this.topicTermProbabilities = new double[topicCount][termCount];
            this.documentTopicProbabilities = new double[documentCount][topicCount];
            this.totalOccurrences = totalOccurrences;
        }

        /**
         * Gets the number of topics, taken from the row count of the topic-term
         * table.
         *
         * @return the topic count
         */
        public int getTopicCount() {
            return this.topicTermProbabilities.length;
        }

        /**
         * Gets the number of documents, taken from the row count of the
         * document-topic table.
         *
         * @return the document count
         */
        public int getDocumentCount() {
            return this.documentTopicProbabilities.length;
        }

        /**
         * Gets the number of terms, taken from the length of the first row of the
         * topic-term table.
         *
         * @return the term count, or 0 when the table has no rows
         */
        public int getTermCount() {
            // Guard the zero-topic case: there is no first row to measure.
            if (this.topicTermProbabilities.length == 0) {
                return 0;
            }
            return this.topicTermProbabilities[0].length;
        }

        /**
         * Gets the total number of term occurrences.
         *
         * @return the total occurrences
         */
        public int getTotalOccurrences() {
            return this.totalOccurrences;
        }

        /**
         * Gets the document-topic table. The internal array is returned directly,
         * not a copy.
         *
         * @return the table indexed as {@code [document][topic]}
         */
        public double[][] getDocumentTopicProbabilities() {
            return this.documentTopicProbabilities;
        }

        /**
         * Replaces the document-topic table. The given array is stored directly,
         * not copied.
         *
         * @param documentTopicProbabilities table indexed as {@code [document][topic]}
         */
        public void setDocumentTopicProbabilities(double[][] documentTopicProbabilities) {
            this.documentTopicProbabilities = documentTopicProbabilities;
        }

        /**
         * Gets the topic-term table. The internal array is returned directly, not
         * a copy.
         *
         * @return the table indexed as {@code [topic][term]}
         */
        public double[][] getTopicTermProbabilities() {
            return this.topicTermProbabilities;
        }

        /**
         * Replaces the topic-term table. The given array is stored directly, not
         * copied.
         *
         * @param topicTermProbabilities table indexed as {@code [topic][term]}
         */
        public void setTopicTermProbabilities(double[][] topicTermProbabilities) {
            this.topicTermProbabilities = topicTermProbabilities;
        }
    }
}

