/*
 * Decompiled with CFR 0.152.
 */
package co.cask.mmds.data;

import co.cask.mmds.NullableMath;
import co.cask.mmds.data.SplitCountVal;
import co.cask.mmds.data.SplitHistogramBin;
import co.cask.mmds.data.SplitVal;
import co.cask.mmds.stats.CategoricalHisto;
import co.cask.mmds.stats.NumericBin;
import co.cask.mmds.stats.NumericHisto;
import com.google.common.collect.Sets;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Per-column statistics comparing a train split against a test split.
 *
 * <p>Each statistic is a {@link SplitVal} holding a train value and a test value. Numeric columns
 * populate the zero/positive/negative counts and min/max/mean/stddev, and leave {@code numEmpty} and
 * {@code unique} as {@code null}. Categorical columns populate {@code numEmpty} and {@code unique}
 * and leave the numeric statistics {@code null}. Both kinds carry a histogram whose bins hold train
 * and test counts, plus a divergence score computed from that histogram.
 */
public class ColumnSplitStats {
    /**
     * Scientific-notation formatter for bin bounds whose magnitude is very large or very small.
     *
     * <p>{@link DecimalFormat} is not thread-safe, so each thread gets its own instance. The symbols
     * are pinned to {@link Locale#ROOT} so the decimal separator is always '.', which keeps bin labels
     * such as "[1.5,2.5)" unambiguous regardless of the JVM default locale.
     */
    private static final ThreadLocal<DecimalFormat> NOTATION_FORMAT = ThreadLocal.withInitial(() -> {
        DecimalFormat notation = new DecimalFormat("0.00E0", DecimalFormatSymbols.getInstance(Locale.ROOT));
        notation.setRoundingMode(RoundingMode.HALF_UP);
        return notation;
    });

    /**
     * Plain decimal formatter (at most four fraction digits) for all other bin bounds.
     * It is per-thread and locale-pinned for the same reasons as {@link #NOTATION_FORMAT}.
     */
    private static final ThreadLocal<DecimalFormat> DECIMAL_FORMAT = ThreadLocal.withInitial(
        () -> new DecimalFormat("###.####", DecimalFormatSymbols.getInstance(Locale.ROOT)));

    private final String field;
    private final SplitVal<Long> numTotal;
    private final SplitVal<Long> numNull;
    private final SplitVal<Long> numEmpty;
    private final SplitVal<Long> unique;
    private final SplitVal<Long> numZero;
    private final SplitVal<Long> numPositive;
    private final SplitVal<Long> numNegative;
    private final SplitVal<Double> min;
    private final SplitVal<Double> max;
    private final SplitVal<Double> mean;
    private final SplitVal<Double> stddev;
    private final List<SplitHistogramBin> histo;
    private final double divergence;

    /**
     * Creates column statistics from precomputed split values and derives the divergence score.
     *
     * <p>{@code numTotal}, {@code numNull} and {@code histo} must be non-null because the divergence
     * computation reads them. The remaining statistics may be {@code null} when they do not apply
     * to the column type.
     *
     * @param field       column name
     * @param numTotal    total value counts
     * @param numNull     null value counts
     * @param numEmpty    empty value counts (categorical columns only, otherwise {@code null})
     * @param unique      distinct value counts (categorical columns only, otherwise {@code null})
     * @param numZero     zero counts (numeric columns only, otherwise {@code null})
     * @param numPositive positive counts (numeric columns only, otherwise {@code null})
     * @param numNegative negative counts (numeric columns only, otherwise {@code null})
     * @param min         minimum values (numeric columns only, otherwise {@code null})
     * @param max         maximum values (numeric columns only, otherwise {@code null})
     * @param mean        mean values (numeric columns only, otherwise {@code null})
     * @param stddev      standard deviations (numeric columns only, otherwise {@code null})
     * @param histo       histogram bins holding train and test counts
     */
    public ColumnSplitStats(String field, SplitVal<Long> numTotal, SplitVal<Long> numNull, SplitVal<Long> numEmpty, SplitVal<Long> unique, SplitVal<Long> numZero, SplitVal<Long> numPositive, SplitVal<Long> numNegative, SplitVal<Double> min, SplitVal<Double> max, SplitVal<Double> mean, SplitVal<Double> stddev, List<SplitHistogramBin> histo) {
        this.field = field;
        this.numTotal = numTotal;
        this.numNull = numNull;
        this.numEmpty = numEmpty;
        this.unique = unique;
        this.numZero = numZero;
        this.numPositive = numPositive;
        this.numNegative = numNegative;
        this.min = min;
        this.max = max;
        this.mean = mean;
        this.stddev = stddev;
        this.histo = histo;
        this.divergence = computeDivergence(numTotal, numNull, histo);
    }

    /**
     * Builds statistics for a numeric column from its train and test histograms.
     * {@code numEmpty} and {@code unique} are left {@code null}.
     *
     * @throws IllegalArgumentException if the two histograms do not share identical bins
     */
    public ColumnSplitStats(String field, NumericHisto train, NumericHisto test) {
        this(field,
             new SplitCountVal(train.getTotalCount(), test.getTotalCount()),
             new SplitCountVal(train.getNullCount(), test.getNullCount()),
             null,
             null,
             new SplitCountVal(train.getZeroCount(), test.getZeroCount()),
             new SplitCountVal(train.getPositiveCount(), test.getPositiveCount()),
             new SplitCountVal(train.getNegativeCount(), test.getNegativeCount()),
             new SplitVal<Double>(train.getMin(), test.getMin(), NullableMath.min(train.getMin(), test.getMin())),
             new SplitVal<Double>(train.getMax(), test.getMax(), NullableMath.max(train.getMax(), test.getMax())),
             new SplitVal<Double>(train.getMean(), test.getMean(),
                                  NullableMath.mean(train.getMean(), train.getNonNullCount(),
                                                    test.getMean(), test.getNonNullCount())),
             new SplitVal<Double>(train.getStddev(), test.getStddev(),
                                  NullableMath.stddev(train.getM2(), train.getMean(), train.getNonNullCount(),
                                                      test.getM2(), test.getMean(), test.getNonNullCount())),
             convert(train, test));
    }

    /**
     * Builds statistics for a categorical column from its train and test histograms.
     *
     * <p>{@code unique} holds the number of distinct categories in train, in test, and in the union
     * of both. The numeric statistics are left {@code null}.
     */
    public ColumnSplitStats(String field, CategoricalHisto train, CategoricalHisto test) {
        this(field,
             new SplitCountVal(train.getTotalCount(), test.getTotalCount()),
             new SplitCountVal(train.getNullCount(), test.getNullCount()),
             new SplitCountVal(train.getEmptyCount(), test.getEmptyCount()),
             new SplitVal<Long>(Long.valueOf(train.getCounts().size()),
                                Long.valueOf(test.getCounts().size()),
                                Long.valueOf(Sets.union(train.getCounts().keySet(), test.getCounts().keySet()).size())),
             null, null, null, null, null, null, null,
             convert(train, test));
    }

    /** Returns the histogram bins (the list passed at construction, not a copy). */
    public List<SplitHistogramBin> getHisto() {
        return this.histo;
    }

    /** Returns the column name. */
    public String getField() {
        return this.field;
    }

    /** Returns the total value counts. */
    public SplitVal<Long> getNumTotal() {
        return this.numTotal;
    }

    /** Returns the null value counts. */
    public SplitVal<Long> getNumNull() {
        return this.numNull;
    }

    /** Returns the empty value counts, or {@code null} for numeric columns. */
    public SplitVal<Long> getNumEmpty() {
        return this.numEmpty;
    }

    /** Returns the distinct-category counts, or {@code null} for numeric columns. */
    public SplitVal<Long> getUnique() {
        return this.unique;
    }

    /** Returns the zero counts, or {@code null} for categorical columns. */
    public SplitVal<Long> getNumZero() {
        return this.numZero;
    }

    /** Returns the positive counts, or {@code null} for categorical columns. */
    public SplitVal<Long> getNumPositive() {
        return this.numPositive;
    }

    /** Returns the negative counts, or {@code null} for categorical columns. */
    public SplitVal<Long> getNumNegative() {
        return this.numNegative;
    }

    /** Returns the minimum values, or {@code null} for categorical columns. */
    public SplitVal<Double> getMin() {
        return this.min;
    }

    /** Returns the maximum values, or {@code null} for categorical columns. */
    public SplitVal<Double> getMax() {
        return this.max;
    }

    /** Returns the mean values, or {@code null} for categorical columns. */
    public SplitVal<Double> getMean() {
        return this.mean;
    }

    /** Returns the standard deviations, or {@code null} for categorical columns. */
    public SplitVal<Double> getStddev() {
        return this.stddev;
    }

    /** Returns the smoothed divergence of the test histogram from the train histogram, in [0, 1]. */
    public double getDivergence() {
        return this.divergence;
    }

    /**
     * Computes the KL divergence of the test bin distribution from the train bin distribution,
     * sum(p_test * ln(p_test / p_train)), and clamps the result to [0, 1].
     *
     * <p>Every bin receives one pseudo-count (add-one smoothing). Probabilities are therefore always
     * positive and the logarithm is finite. The denominators are the non-null counts plus the number
     * of bins. NOTE(review): the smoothed probabilities sum to 1 only if the bin counts sum to the
     * non-null count; confirm this holds for categorical histograms, where empty values may be counted
     * separately.
     *
     * @return the clamped divergence, or 0 when {@code histo} is empty
     */
    private static double computeDivergence(SplitVal<Long> numTotal, SplitVal<Long> numNull, List<SplitHistogramBin> histo) {
        double trainNonNull = numTotal.getTrain() - numNull.getTrain() + (long) histo.size();
        double testNonNull = numTotal.getTest() - numNull.getTest() + (long) histo.size();
        double div = 0.0;
        for (SplitHistogramBin bin : histo) {
            double trainProbability = (double) (1L + bin.getCount().getTrain()) / trainNonNull;
            double testProbability = (double) (1L + bin.getCount().getTest()) / testNonNull;
            div += testProbability * Math.log(testProbability / trainProbability);
        }
        return Math.max(0.0, Math.min(1.0, div));
    }

    /**
     * Pairs the bins of two numeric histograms into labelled train/test bins.
     *
     * @throws IllegalArgumentException if the histograms have a different number of bins, or any
     *                                  pair of bins differs in bounds or upper-bound inclusivity
     */
    private static List<SplitHistogramBin> convert(NumericHisto train, NumericHisto test) {
        if (train.getBins().size() != test.getBins().size()) {
            throw new IllegalArgumentException("Cannot combine numeric histograms with different bins.");
        }
        List<SplitHistogramBin> bins = new ArrayList<>(train.getBins().size());
        Iterator<NumericBin> trainBins = train.getBins().iterator();
        Iterator<NumericBin> testBins = test.getBins().iterator();
        while (trainBins.hasNext()) {
            NumericBin bin1 = trainBins.next();
            NumericBin bin2 = testBins.next();
            // NOTE(review): exact comparison of bounds. If NumericBin#getLo/getHi return boxed Double
            // rather than double, != compares references; confirm the return types.
            if (bin1.getLo() != bin2.getLo() || bin1.getHi() != bin2.getHi() || bin1.isHiInclusive() != bin2.isHiInclusive()) {
                throw new IllegalArgumentException("Cannot combine numeric histograms with different bins. Bin1 = " + format(bin1) + ", Bin2 = " + format(bin2));
            }
            bins.add(new SplitHistogramBin(format(bin1), new SplitCountVal(bin1.getCount(), bin2.getCount())));
        }
        return bins;
    }

    /**
     * Merges two categorical histograms into one bin per category in the union of both.
     *
     * <p>A category missing from one side gets a count of 0 on that side. Bins are ordered by train
     * count descending, then test count descending, then category name ascending.
     */
    public static List<SplitHistogramBin> convert(CategoricalHisto train, CategoricalHisto test) {
        Map<String, Long> trainCounts = train.getCounts();
        Map<String, Long> testCounts = test.getCounts();
        List<SplitHistogramBin> bins = new ArrayList<>(trainCounts.size());
        for (Map.Entry<String, Long> trainEntry : trainCounts.entrySet()) {
            String category = trainEntry.getKey();
            Long testCount = testCounts.get(category);
            bins.add(new SplitHistogramBin(category, new SplitCountVal(trainEntry.getValue(), testCount == null ? 0L : testCount)));
        }
        for (Map.Entry<String, Long> testEntry : testCounts.entrySet()) {
            String category = testEntry.getKey();
            if (!trainCounts.containsKey(category)) {
                bins.add(new SplitHistogramBin(category, new SplitCountVal(0L, testEntry.getValue())));
            }
        }
        bins.sort((h1, h2) -> {
            int cmp = Long.compare(h2.getCount().getTrain(), h1.getCount().getTrain());
            if (cmp != 0) {
                return cmp;
            }
            cmp = Long.compare(h2.getCount().getTest(), h1.getCount().getTest());
            if (cmp != 0) {
                return cmp;
            }
            return h1.getBin().compareTo(h2.getBin());
        });
        return bins;
    }

    /** Renders a numeric bin as "[lo,hi]" when the upper bound is inclusive, else "[lo,hi)". */
    private static String format(NumericBin bin) {
        return String.format(bin.isHiInclusive() ? "[%s,%s]" : "[%s,%s)", format(bin.getLo()), format(bin.getHi()));
    }

    /**
     * Formats a bin bound. Scientific notation is used when the magnitude exceeds 1000 or is non-zero
     * and below 0.001. Otherwise plain decimal notation is used, with up to four fraction digits.
     */
    private static String format(double val) {
        double mag = Math.abs(val);
        boolean useNotation = mag > 1000.0 || (mag < 0.001 && mag > 0.0);
        return (useNotation ? NOTATION_FORMAT : DECIMAL_FORMAT).get().format(val);
    }
}

