/*
 * Decompiled with CFR 0.152.
 */
package org.opencb.hpg.bigdata.analysis.variant;

import java.util.ArrayList;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
import org.opencb.hpg.bigdata.analysis.variant.VariantAnalysisExecutor;
import org.opencb.hpg.bigdata.core.config.OskarConfiguration;

/**
 * Variant analysis that fits a Spark ML {@link LinearRegression} model and prints
 * the fitted coefficients together with the training summary to standard output.
 *
 * <p>The solver is configured through three tunables: the maximum number of
 * iterations, the regularization parameter and the elastic-net mixing parameter.
 *
 * <p>NOTE(review): {@link #execute()} currently trains on a single hard-coded
 * placeholder point and never reads {@code depVarName} or {@code indepVarName};
 * the real data loading still has to be wired in.
 */
public class LinearRegressionAnalysis
extends VariantAnalysisExecutor {
    /** Maximum number of iterations used when the caller does not supply one. */
    private static final int DEFAULT_NUM_ITERATIONS = 10;
    /** Regularization parameter used when the caller does not supply one. */
    private static final double DEFAULT_REGULARIZATION = 0.3;
    /** Elastic-net mixing parameter used when the caller does not supply one. */
    private static final double DEFAULT_ELASTIC_NET = 0.8;
    /** Length of the placeholder feature vector built in {@link #execute()}. */
    private static final int PLACEHOLDER_NUM_FEATURES = 10;

    private String depVarName;
    private String indepVarName;
    private int numIterations;
    private double regularization;
    private double elasticNet;

    /**
     * Creates an analysis that uses the default solver settings
     * (10 iterations, regularization 0.3, elastic net 0.8).
     *
     * @param studyId       study identifier, forwarded to the parent executor
     * @param depVarName    name of the dependent variable
     * @param indepVarName  name of the independent variable
     * @param configuration configuration, forwarded to the parent executor
     */
    public LinearRegressionAnalysis(String studyId, String depVarName, String indepVarName, OskarConfiguration configuration) {
        this(studyId, depVarName, indepVarName, DEFAULT_NUM_ITERATIONS, DEFAULT_REGULARIZATION, DEFAULT_ELASTIC_NET,
                configuration);
    }

    /**
     * Creates an analysis with explicit solver settings.
     *
     * @param studyId        study identifier, forwarded to the parent executor
     * @param depVarName     name of the dependent variable
     * @param indepVarName   name of the independent variable
     * @param numIterations  value passed to {@code LinearRegression.setMaxIter}
     * @param regularization value passed to {@code LinearRegression.setRegParam}
     * @param elasticNet     value passed to {@code LinearRegression.setElasticNetParam}
     * @param configuration  configuration, forwarded to the parent executor
     */
    public LinearRegressionAnalysis(String studyId, String depVarName, String indepVarName, int numIterations, double regularization, double elasticNet, OskarConfiguration configuration) {
        super(studyId, configuration);
        this.depVarName = depVarName;
        this.indepVarName = indepVarName;
        this.numIterations = numIterations;
        this.regularization = regularization;
        this.elasticNet = elasticNet;
    }

    /**
     * Fits the linear regression model and prints the coefficients, intercept,
     * iteration count, objective history, residuals, RMSE and r2 to standard output.
     *
     * <p>The Spark context opened here is always stopped before the method
     * returns, including when fitting throws.
     */
    @Override
    public void execute() {
        LinearRegression lr = new LinearRegression()
                .setMaxIter(this.numIterations)
                .setRegParam(this.regularization)
                .setElasticNetParam(this.elasticNet);

        // Placeholder training set: one point with a NaN label and an all-zero
        // feature vector. TODO: replace with the actual study data.
        double target = Double.NaN;
        double[] features = new double[PLACEHOLDER_NUM_FEATURES];
        ArrayList<LabeledPoint> list = new ArrayList<LabeledPoint>();
        list.add(new LabeledPoint(target, Vectors.dense(features)));

        // try-with-resources stops the context on every exit path; the original
        // code never stopped it, leaking the context after each run.
        try (JavaSparkContext jsc = new JavaSparkContext()) {
            SQLContext sqlContext = new SQLContext(jsc);
            JavaRDD<LabeledPoint> data = jsc.parallelize(list);
            data.cache();
            Dataset<?> training = sqlContext.createDataFrame(data.rdd(), LabeledPoint.class);

            LinearRegressionModel lrModel = lr.fit(training);
            System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept());

            LinearRegressionTrainingSummary trainingSummary = lrModel.summary();
            System.out.println("numIterations: " + trainingSummary.totalIterations());
            System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
            trainingSummary.residuals().show();
            System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());
            System.out.println("r2: " + trainingSummary.r2());
        }
    }

    /** @return the name of the dependent variable */
    public String getDepVarName() {
        return this.depVarName;
    }

    /** @param depVarName the name of the dependent variable */
    public void setDepVarName(String depVarName) {
        this.depVarName = depVarName;
    }

    /** @return the name of the independent variable */
    public String getIndepVarName() {
        return this.indepVarName;
    }

    /** @param indepVarName the name of the independent variable */
    public void setIndepVarName(String indepVarName) {
        this.indepVarName = indepVarName;
    }

    /** @return the maximum number of solver iterations */
    public int getNumIterations() {
        return this.numIterations;
    }

    /** @param numIterations the maximum number of solver iterations */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /** @return the regularization parameter */
    public double getRegularization() {
        return this.regularization;
    }

    /** @param regularization the regularization parameter */
    public void setRegularization(double regularization) {
        this.regularization = regularization;
    }

    /** @return the elastic-net mixing parameter */
    public double getElasticNet() {
        return this.elasticNet;
    }

    /** @param elasticNet the elastic-net mixing parameter */
    public void setElasticNet(double elasticNet) {
        this.elasticNet = elasticNet;
    }
}

