/*
 * Decompiled with CFR 0.152.
 */
package org.opencb.hpg.bigdata.analysis.variant.statistics;

import java.security.InvalidParameterException;
import org.apache.commons.math3.distribution.FDistribution;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.opencb.hpg.bigdata.analysis.variant.statistics.CMCResult;

/**
 * Combined Multivariate and Collapsing (CMC) rare-variant association test.
 *
 * <p>Variants whose estimated minor allele frequency (column mean / 2) is below a threshold are
 * collapsed into a single carrier-indicator column (only when more than one variant is rare).
 * The resulting matrix is shifted by -1 and compared between cases (phenotype entries equal to
 * 1.0) and controls (all other entries) with a two-sample Hotelling T^2 statistic. An asymptotic
 * p-value is derived from an F distribution and, optionally, an empirical p-value from phenotype
 * permutations.
 *
 * <p>Each call to {@code run} builds and returns a fresh {@link CMCResult}; no per-run state is
 * kept on the instance.
 *
 * <p>NOTE(review): presumably genotypes are diploid allele counts (0/1/2) and phenotypes are 0/1
 * — confirm against callers.
 */
public class CMC {

    /**
     * Runs the CMC test with a MAF threshold of 0.05 and 100 permutations.
     *
     * @param phenotype per-individual phenotype; entries equal to 1.0 are cases, all others controls
     * @param genotype  individuals x variants genotype matrix (one row per phenotype entry)
     * @return a new result holding counts, the statistic and the p-values
     * @throws InvalidParameterException if the inputs are inconsistent (see the full overload)
     */
    public CMCResult run(RealVector phenotype, RealMatrix genotype) {
        return run(phenotype, genotype, 0.05, 100);
    }

    /**
     * Runs the CMC test using an unseeded random generator for permutations.
     *
     * @param phenotype       per-individual phenotype; entries equal to 1.0 are cases
     * @param genotype        individuals x variants genotype matrix
     * @param maf             variants with estimated MAF strictly below this value are "rare"
     * @param numPermutations number of phenotype permutations; {@code <= 0} skips the permutation p-value
     * @return a new result holding counts, the statistic and the p-values
     * @throws InvalidParameterException if the inputs are inconsistent (see the full overload)
     */
    public CMCResult run(RealVector phenotype, RealMatrix genotype, double maf, int numPermutations) {
        return run(phenotype, genotype, maf, numPermutations, new RandomDataGenerator());
    }

    /**
     * Runs the CMC test, drawing permutations from {@code random}. A seeded generator makes the
     * permutation p-value reproducible.
     *
     * @param phenotype       per-individual phenotype; entries equal to 1.0 are cases, all others controls
     * @param genotype        individuals x variants genotype matrix (one row per phenotype entry)
     * @param maf             variants with estimated MAF strictly below this value are "rare"
     * @param numPermutations number of phenotype permutations; {@code <= 0} skips the permutation p-value
     * @param random          source of permutations; only used when {@code numPermutations > 0}
     * @return a new result holding counts, the statistic and the p-values. The asymptotic p-value
     *     is NaN when there are too few individuals for the F approximation.
     * @throws InvalidParameterException if the genotype row count differs from the phenotype
     *     length, or if the phenotype has no cases or no controls
     */
    public CMCResult run(RealVector phenotype, RealMatrix genotype, double maf, int numPermutations,
                         RandomDataGenerator random) {
        int numIndividuals = phenotype.getDimension();
        if (genotype.getRowDimension() != numIndividuals) {
            throw new InvalidParameterException("Number of individuals mismatch!");
        }

        CMCResult result = new CMCResult();
        result.setMaf(maf);
        result.setNumPermutations(numPermutations);
        int numVariants = genotype.getColumnDimension();
        result.setNumVariants(numVariants);

        // Estimated MAF = mean allele count / 2 (assumes diploid allele-count coding — TODO confirm).
        double[] alleleMeans = colMeans(genotype);
        boolean[] isRare = new boolean[numVariants];
        int rare = 0;
        for (int j = 0; j < numVariants; ++j) {
            isRare[j] = alleleMeans[j] / 2.0 < maf;
            if (isRare[j]) {
                ++rare;
            }
        }
        result.setNumRareVariants(rare);

        // Collapsing only makes sense when at least two variants are rare.
        RealMatrix base = rare > 1
                ? collapseRareVariants(genotype, isRare, rare)
                : new Array2DRowRealMatrix(genotype.getData(), false);
        // Shift the coding by -1 (e.g. 0/1/2 becomes -1/0/1).
        RealMatrix newGenotype = base.scalarAdd(-1.0);

        int numCases = countCases(phenotype);
        int numControls = numIndividuals - numCases;
        result.setNumCases(numCases);
        result.setNumControls(numControls);
        if (numCases == 0 || numControls == 0) {
            throw new InvalidParameterException("Both cases and controls are required!");
        }

        double stat = computeStatistic(phenotype, newGenotype);
        result.setStatistic(stat);
        result.setAsymPvalue(asymptoticPvalue(stat, numIndividuals, newGenotype.getColumnDimension()));

        if (numPermutations > 0) {
            result.setPermPvalue(permutationPvalue(phenotype, newGenotype, stat, numPermutations, random));
        }
        return result;
    }

    /**
     * Keeps the common-variant columns (in original order) and appends one column that is 1 for
     * individuals whose rare-variant values sum to non-zero, and 0 otherwise.
     */
    private static RealMatrix collapseRareVariants(RealMatrix genotype, boolean[] isRare, int rare) {
        int numIndividuals = genotype.getRowDimension();
        int numVariants = genotype.getColumnDimension();
        RealMatrix collapsed = new Array2DRowRealMatrix(numIndividuals, numVariants - rare + 1);
        int col = 0;
        for (int j = 0; j < numVariants; ++j) {
            if (!isRare[j]) {
                collapsed.setColumn(col++, genotype.getColumn(j));
            }
        }
        // `col` now indexes the last column, which holds the collapsed rare-variant indicator.
        for (int i = 0; i < numIndividuals; ++i) {
            double sum = 0.0;
            for (int j = 0; j < numVariants; ++j) {
                if (isRare[j]) {
                    sum += genotype.getEntry(i, j);
                }
            }
            collapsed.setEntry(i, col, sum != 0.0 ? 1.0 : 0.0);
        }
        return collapsed;
    }

    /**
     * Converts the T^2 statistic to an F statistic with (M, N - M - 1) degrees of freedom and
     * returns its upper-tail probability. Returns NaN when the denominator degrees of freedom
     * would not be positive.
     */
    private static double asymptoticPvalue(double stat, int numIndividuals, int numNewVariants) {
        int df1 = numNewVariants;
        int df2 = numIndividuals - numNewVariants - 1;
        if (df2 < 1) {
            // F distribution undefined here; report NaN instead of letting FDistribution throw.
            return Double.NaN;
        }
        // Promote to double before multiplying so large cohorts cannot overflow int.
        double fStat = stat * df2 / ((double) numNewVariants * (numIndividuals - 2));
        FDistribution fDistribution = new FDistribution(df1, df2);
        return 1.0 - fDistribution.cumulativeProbability(fStat);
    }

    /**
     * Returns the fraction of phenotype permutations whose statistic strictly exceeds
     * {@code observed}.
     */
    private static double permutationPvalue(RealVector phenotype, RealMatrix genotype, double observed,
                                            int numPermutations, RandomDataGenerator random) {
        int n = phenotype.getDimension();
        RealVector permuted = new ArrayRealVector(n);
        int exceed = 0;
        for (int p = 0; p < numPermutations; ++p) {
            int[] order = random.nextPermutation(n, n);
            for (int j = 0; j < n; ++j) {
                permuted.setEntry(j, phenotype.getEntry(order[j]));
            }
            if (computeStatistic(permuted, genotype) > observed) {
                ++exceed;
            }
        }
        return (double) exceed / (double) numPermutations;
    }

    /** Counts phenotype entries equal to 1.0 (cases); every other value is treated as a control. */
    private static int countCases(RealVector phenotype) {
        int cases = 0;
        for (int i = 0; i < phenotype.getDimension(); ++i) {
            if (phenotype.getEntry(i) == 1.0) {
                ++cases;
            }
        }
        return cases;
    }

    /**
     * Computes the two-sample Hotelling T^2 statistic between case rows (phenotype == 1.0) and
     * control rows, using the pooled covariance and a generalized inverse.
     * Callers must ensure there is at least one case and one control.
     */
    private static double computeStatistic(RealVector phenotype, RealMatrix genotype) {
        int numIndividuals = genotype.getRowDimension();
        int numCols = genotype.getColumnDimension();
        // Count exactly the rows that will be routed to the case matrix, so the sizes always agree.
        int nA = countCases(phenotype);
        int nU = numIndividuals - nA;

        RealMatrix cases = new Array2DRowRealMatrix(nA, numCols);
        RealMatrix controls = new Array2DRowRealMatrix(nU, numCols);
        int caseRow = 0;
        int controlRow = 0;
        for (int row = 0; row < phenotype.getDimension(); ++row) {
            if (phenotype.getEntry(row) == 1.0) {
                cases.setRow(caseRow++, genotype.getRow(row));
            } else {
                controls.setRow(controlRow++, genotype.getRow(row));
            }
        }

        double[] caseMeans = colMeans(cases);
        double[] controlMeans = colMeans(controls);
        RealMatrix dX = subtractVector(cases, caseMeans);
        RealMatrix dY = subtractVector(controls, controlMeans);
        // Pooled within-group covariance: (dX'dX + dY'dY) / (N - 2).
        RealMatrix cov = dX.transpose().multiply(dX)
                .add(dY.transpose().multiply(dY))
                .scalarMultiply(1.0 / (double) (numIndividuals - 2));
        RealMatrix invCov = generalizedInverse(cov);

        RealMatrix diff = new Array2DRowRealMatrix(caseMeans.length, 1);
        diff.setColumn(0, new ArrayRealVector(caseMeans).subtract(new ArrayRealVector(controlMeans)).toArray());
        RealMatrix statMatrix = diff.transpose().multiply(invCov).multiply(diff)
                .scalarMultiply((double) nA * (double) nU / (double) numIndividuals);
        return statMatrix.getEntry(0, 0);
    }

    /**
     * Inverts a square covariance matrix.
     *
     * <ul>
     *   <li>1x1 case: the variance is floored at 1e-8 before inversion.</li>
     *   <li>Otherwise: builds V^{-T} diag(1/lambda) V^{-1} from the eigen-decomposition, zeroing
     *       eigenvalues with |lambda| <= 1e-8 (a generalized inverse).</li>
     * </ul>
     */
    private static RealMatrix generalizedInverse(RealMatrix cov) {
        if (cov.getRowDimension() == 1) {
            if (cov.getEntry(0, 0) < 1.0E-8) {
                cov.setEntry(0, 0, 1.0E-8);
            }
            return new LUDecomposition(cov).getSolver().getInverse();
        }
        EigenDecomposition eigenCov = new EigenDecomposition(cov);
        double[] eigVals = eigenCov.getRealEigenvalues();
        RealMatrix invDiag = new Array2DRowRealMatrix(cov.getRowDimension(), cov.getColumnDimension());
        for (int i = 0; i < eigVals.length; ++i) {
            invDiag.setEntry(i, i, Math.abs(eigVals[i]) <= 1.0E-8 ? 0.0 : 1.0 / eigVals[i]);
        }
        DecompositionSolver solver = new LUDecomposition(eigenCov.getV()).getSolver();
        RealMatrix vInv = solver.getInverse();
        return vInv.transpose().multiply(invDiag).multiply(vInv);
    }

    /** Returns the arithmetic mean of each column of {@code matrix}. */
    private static double[] colMeans(RealMatrix matrix) {
        double[] mean = new double[matrix.getColumnDimension()];
        for (int i = 0; i < matrix.getColumnDimension(); ++i) {
            double sum = 0.0;
            double[] column = matrix.getColumn(i);
            for (int j = 0; j < column.length; ++j) {
                sum += column[j];
            }
            mean[i] = sum / (double) column.length;
        }
        return mean;
    }

    /** Returns a new matrix with {@code vector[col]} subtracted from every entry of column {@code col}. */
    private static RealMatrix subtractVector(RealMatrix matrix, double[] vector) {
        RealMatrix res = new Array2DRowRealMatrix(matrix.getRowDimension(), matrix.getColumnDimension());
        for (int row = 0; row < res.getRowDimension(); ++row) {
            for (int col = 0; col < res.getColumnDimension(); ++col) {
                res.setEntry(row, col, matrix.getEntry(row, col) - vector[col]);
            }
        }
        return res;
    }
}

