/*
 * Decompiled with CFR 0.152.
 */
package io.cdap.mmds.modeler.train;

import com.google.common.collect.ImmutableMap;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.mmds.api.AlgorithmType;
import io.cdap.mmds.api.Modeler;
import io.cdap.mmds.data.EvaluationMetrics;
import io.cdap.mmds.data.ModelTrainerInfo;
import io.cdap.mmds.modeler.Modelers;
import io.cdap.mmds.modeler.feature.FeatureGeneratorTrainer;
import io.cdap.mmds.modeler.train.ModelOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Trains a Spark ML prediction model on a training dataset and evaluates it on a test dataset.
 *
 * <p>Every field of the data-split schema except the outcome field is used as a feature. STRING
 * and BOOLEAN fields are treated as categorical. A STRING or BOOLEAN outcome turns the problem
 * into label indexing followed by reverse-mapping of the predicted index back to the label.
 */
public class ModelTrainer {
    private static final Logger LOG = LoggerFactory.getLogger(ModelTrainer.class);

    /** Column the predictor reads features from (assumed to be produced by FeatureGeneratorTrainer — TODO confirm). */
    private static final String FEATURES_COL = "_features";
    /** Column that holds the final (human-readable) prediction in the output dataset. */
    private static final String PREDICTION_COL = "_prediction";

    private final String algorithm;
    private final String outcomeField;
    private final Schema.Type outcomeType;
    private final Map<String, String> trainingParams;
    private final List<String> featureNames;
    private final Set<String> categoricalFeatures;
    private final Schema schema;

    /**
     * Captures the algorithm, hyperparameters, schema and outcome from the trainer info, and derives
     * the feature list (all non-outcome fields) and the categorical subset of it.
     *
     * @param modelTrainerInfo model, experiment and data-split information for this training run
     */
    public ModelTrainer(ModelTrainerInfo modelTrainerInfo) {
        this.algorithm = modelTrainerInfo.getModel().getAlgorithm();
        this.trainingParams = ImmutableMap.copyOf(modelTrainerInfo.getModel().getHyperparameters());
        this.schema = modelTrainerInfo.getDataSplitStats().getSchema();
        this.outcomeField = modelTrainerInfo.getExperiment().getOutcome();
        // Locale.ROOT: a default-locale toUpperCase() turns "string" into "STRİNG" under a Turkish
        // locale, which would make Schema.Type.valueOf fail.
        this.outcomeType = Schema.Type.valueOf(
            modelTrainerInfo.getExperiment().getOutcomeType().toUpperCase(Locale.ROOT));
        this.featureNames = new ArrayList<>();
        this.categoricalFeatures = new HashSet<>();
        for (Schema.Field field : this.schema.getFields()) {
            String fieldName = field.getName();
            if (fieldName.equals(this.outcomeField)) {
                continue;
            }
            this.featureNames.add(fieldName);
            Schema fieldSchema = field.getSchema();
            Schema.Type fieldType = fieldSchema.isNullable()
                ? fieldSchema.getNonNullable().getType()
                : fieldSchema.getType();
            if (isCategorical(fieldType)) {
                this.categoricalFeatures.add(fieldName);
            }
        }
    }

    /** Returns true for the schema types this trainer treats as categorical (STRING and BOOLEAN). */
    private boolean isCategorical(Schema.Type type) {
        return type == Schema.Type.STRING || type == Schema.Type.BOOLEAN;
    }

    /**
     * Generates features, trains the configured model, predicts on the test data and computes
     * evaluation metrics.
     *
     * @param training training rows; rows with a null outcome are dropped
     * @param test test rows; rows with a null outcome are dropped
     * @return the trained model, feature-generation and target-index models, metrics and test
     *     predictions (the prediction column followed by every schema field)
     * @throws IOException propagated from the signature of the original implementation
     * @throws RuntimeException if evaluation metrics cannot be computed
     */
    public ModelOutput train(Dataset<Row> training, Dataset<Row> test) throws IOException {
        // Rows without an outcome value can be neither trained on nor evaluated.
        Dataset<Row> rawTraining = training.na().drop(new String[]{this.outcomeField});
        Dataset<Row> rawTest = test.na().drop(new String[]{this.outcomeField});

        LOG.info("Generating features for training and test data.");
        FeatureGeneratorTrainer featureGenerator =
            new FeatureGeneratorTrainer(this.featureNames, this.categoricalFeatures);
        Dataset<Row> trainingFeatures = featureGenerator.generateFeatures(rawTraining, this.outcomeField);
        LOG.info("Training features successfully generated.");
        Dataset<Row> testFeatures = featureGenerator.generateFeatures(rawTest, this.outcomeField);
        LOG.info("Test features successfully generated.");

        String finalOutcomeField = this.outcomeField;
        String numericPredictionField = PREDICTION_COL;
        StringIndexerModel targetIndexModel = null;
        boolean isCategoricalOutput = isCategorical(this.outcomeType);
        if (isCategoricalOutput) {
            String strOutcomeField = this.outcomeField;
            if (this.outcomeType == Schema.Type.BOOLEAN) {
                // Index a string copy of the boolean outcome (presumably because StringIndexer does
                // not accept boolean input columns).
                strOutcomeField = "_c_" + this.outcomeField;
                Column outcomeAsStr = new Column(this.outcomeField).cast(DataTypes.StringType);
                trainingFeatures = trainingFeatures.withColumn(strOutcomeField, outcomeAsStr);
                testFeatures = testFeatures.withColumn(strOutcomeField, outcomeAsStr);
            }
            // The label indexer is fit on training data only.
            // NOTE(review): StringIndexer defaults to handleInvalid="error", so an outcome value seen
            // only in the test data will presumably fail at evaluation time — confirm this is intended.
            finalOutcomeField = "_t_" + this.outcomeField;
            StringIndexer targetIndexer = new StringIndexer()
                .setInputCol(strOutcomeField)
                .setOutputCol(finalOutcomeField);
            targetIndexModel = targetIndexer.fit(trainingFeatures);
            trainingFeatures = targetIndexModel.transform(trainingFeatures);
            testFeatures = targetIndexModel.transform(testFeatures);
            // The model writes the numeric label index here; IndexToString later fills PREDICTION_COL.
            numericPredictionField = "_n_" + numericPredictionField;
        }

        Modeler modeler = Modelers.getModeler(this.algorithm);
        Predictor<?, ?, ?> predictor = modeler.createPredictor(this.trainingParams);
        predictor.setLabelCol(finalOutcomeField);
        predictor.setFeaturesCol(FEATURES_COL);
        predictor.setPredictionCol(numericPredictionField);
        LOG.info("Training model...");
        PredictionModel<?, ?> model = predictor.fit(trainingFeatures);
        LOG.info("Model successfully trained.");

        LOG.info("Generating predictions on test data.");
        Dataset<Row> predictions = model.transform(testFeatures);
        LOG.info("Predictions successfully generated.");
        if (isCategoricalOutput) {
            IndexToString reverseIndex = new IndexToString()
                .setLabels(targetIndexModel.labels())
                .setInputCol(numericPredictionField)
                .setOutputCol(PREDICTION_COL);
            predictions = reverseIndex.transform(predictions);
        }

        LOG.info("Calculating evaluation metrics...");
        // Metrics compare the numeric prediction with the numeric label (the index for categorical outcomes).
        RDD<Tuple2<Object, Object>> predictionAndLabels = predictions
            .select(new Column(numericPredictionField), new Column(finalOutcomeField).cast(DataTypes.DoubleType))
            .toJavaRDD()
            .map(new PredictionLabelFunction())
            .rdd();
        boolean isRegression = modeler.getAlgorithm().getType() == AlgorithmType.REGRESSION;
        EvaluationMetrics evaluationMetrics = evaluate(predictionAndLabels, isRegression);

        Dataset<Row> predictionsClean = predictions.select(outputColumns());
        return ModelOutput.builder()
            .setTargetIndexModel(targetIndexModel)
            .setFeatureGenModel(featureGenerator.getFeatureGenModel())
            .setModel((MLWritable) model)
            .setEvaluationMetrics(evaluationMetrics)
            .setFeatureNames(this.featureNames)
            .setCategoricalFeatures(this.categoricalFeatures)
            .setPredictions(predictionsClean)
            .setAlgorithmType(modeler.getAlgorithm().getType())
            .setSchema(this.schema)
            .build();
    }

    /**
     * Computes regression or weighted multiclass metrics over (prediction, label) pairs.
     *
     * @param predictionAndLabels (prediction, label) tuples of doubles
     * @param isRegression true for regression metrics, false for multiclass classification metrics
     * @return the computed evaluation metrics
     * @throws RuntimeException wrapping the IllegalArgumentException raised while computing metrics
     */
    private static EvaluationMetrics evaluate(RDD<Tuple2<Object, Object>> predictionAndLabels,
                                              boolean isRegression) {
        try {
            if (isRegression) {
                RegressionMetrics metrics = new RegressionMetrics(predictionAndLabels, false);
                double rmse = metrics.rootMeanSquaredError();
                double r2 = metrics.r2();
                double mae = metrics.meanAbsoluteError();
                double explainedVariance = metrics.explainedVariance();
                LOG.info("root mean squared error = {}, r2 = {}, mean absolute error = {}, explained variance = {}",
                         new Object[]{rmse, r2, mae, explainedVariance});
                return new EvaluationMetrics(rmse, r2, explainedVariance, mae);
            }
            MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels);
            double precision = metrics.weightedPrecision();
            double recall = metrics.weightedRecall();
            double f1 = metrics.weightedFMeasure();
            LOG.info("precision = {}, recall = {}, f1 = {}", new Object[]{precision, recall, f1});
            return new EvaluationMetrics(precision, recall, f1);
        } catch (IllegalArgumentException e) {
            throw new RuntimeException("Failed to get evaluation metrics for the model. Please check the logs for warnings or errors related to training problems. If there were training problems, please check that your features are the correct type. String features should represent categories, with multiple records for each category value. For example, an ID should not be used as a feature, as there is a unique value for each record.", e);
        }
    }

    /** Builds the output projection: the prediction column, then every schema field in schema order. */
    private Column[] outputColumns() {
        List<Schema.Field> fields = this.schema.getFields();
        Column[] columns = new Column[fields.size() + 1];
        columns[0] = new Column(PREDICTION_COL);
        int i = 1;
        for (Schema.Field field : fields) {
            columns[i++] = new Column(field.getName());
        }
        return columns;
    }

    /** Converts a two-column (prediction, label) row into the tuple form used by mllib metrics. */
    private static class PredictionLabelFunction implements Function<Row, Tuple2<Object, Object>> {
        @Override
        public Tuple2<Object, Object> call(Row row) throws Exception {
            return new Tuple2<Object, Object>(row.get(0), row.get(1));
        }
    }
}

