/*
 * Decompiled with CFR 0.152.
 */
package io.cdap.mmds.data;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableSet;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.mmds.api.Modeler;
import io.cdap.mmds.data.ColumnSplitStats;
import io.cdap.mmds.data.ColumnStats;
import io.cdap.mmds.data.DataSplit;
import io.cdap.mmds.data.DataSplitInfo;
import io.cdap.mmds.data.DataSplitStats;
import io.cdap.mmds.data.DataSplitTable;
import io.cdap.mmds.data.EvaluationMetrics;
import io.cdap.mmds.data.Experiment;
import io.cdap.mmds.data.ExperimentMetaTable;
import io.cdap.mmds.data.ExperimentStats;
import io.cdap.mmds.data.ExperimentsMeta;
import io.cdap.mmds.data.ModelKey;
import io.cdap.mmds.data.ModelMeta;
import io.cdap.mmds.data.ModelStatus;
import io.cdap.mmds.data.ModelTable;
import io.cdap.mmds.data.ModelTrainerInfo;
import io.cdap.mmds.data.ModelsMeta;
import io.cdap.mmds.data.SortInfo;
import io.cdap.mmds.data.SortType;
import io.cdap.mmds.data.SplitKey;
import io.cdap.mmds.data.SplitStatus;
import io.cdap.mmds.modeler.Modelers;
import io.cdap.mmds.proto.BadRequestException;
import io.cdap.mmds.proto.ConflictException;
import io.cdap.mmds.proto.CreateModelRequest;
import io.cdap.mmds.proto.ExperimentNotFoundException;
import io.cdap.mmds.proto.ModelNotFoundException;
import io.cdap.mmds.proto.SplitNotFoundException;
import io.cdap.mmds.proto.TrainModelRequest;
import io.cdap.mmds.spec.Parameters;
import io.cdap.mmds.stats.CategoricalHisto;
import io.cdap.mmds.stats.NumericHisto;
import io.cdap.mmds.stats.NumericStats;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.function.Predicate;
import org.apache.twill.filesystem.Location;

public class ExperimentStore {
    /** Outcome types treated as categorical when a split schema is checked against its experiment. */
    private static final Set<Schema.Type> CATEGORICAL_TYPES = ImmutableSet.of(Schema.Type.BOOLEAN, Schema.Type.STRING);
    /** Outcome types treated as numeric when a split schema is checked against its experiment. */
    private static final Set<Schema.Type> NUMERIC_TYPES = ImmutableSet.of(Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE);
    /** Table used to get, list, put and delete experiments. */
    private final ExperimentMetaTable experiments;
    /** Table used to get, list, add and delete data splits and to track which models use them. */
    private final DataSplitTable splits;
    /** Table used to get, list, add, update and delete models. */
    private final ModelTable models;

    /**
     * Creates a store backed by the given tables.
     *
     * @param experiments table holding experiments
     * @param splits table holding data splits
     * @param models table holding models
     */
    public ExperimentStore(ExperimentMetaTable experiments, DataSplitTable splits, ModelTable models) {
        this.models = models;
        this.splits = splits;
        this.experiments = experiments;
    }

    /**
     * Lists experiments without filtering, honoring the requested sort order.
     *
     * @param offset offset passed through to the experiments table
     * @param limit limit passed through to the experiments table
     * @param sortInfo how the experiments should be sorted
     * @return the experiments returned by the experiments table
     */
    public ExperimentsMeta listExperiments(int offset, int limit, SortInfo sortInfo) {
        // Delegate to the filtered overload with an accept-all predicate so that sortInfo is
        // actually applied instead of being silently dropped.
        return this.listExperiments(offset, limit, experiment -> true, sortInfo);
    }

    /**
     * Lists experiments, delegating filtering and sorting to the experiments table.
     *
     * @param offset offset passed through to the experiments table
     * @param limit limit passed through to the experiments table
     * @param predicate filter passed through to the experiments table
     * @param sortInfo sort order passed through to the experiments table
     * @return the experiments returned by the experiments table
     */
    public ExperimentsMeta listExperiments(int offset, int limit, Predicate<Experiment> predicate, SortInfo sortInfo) {
        ExperimentsMeta listing = experiments.list(offset, limit, predicate, sortInfo);
        return listing;
    }

    /**
     * Looks up an experiment by name.
     *
     * @param experimentName name of the experiment
     * @return the experiment, never null
     * @throws ExperimentNotFoundException if the experiments table has no entry for the name
     */
    public Experiment getExperiment(String experimentName) {
        Experiment found = experiments.get(experimentName);
        if (found != null) {
            return found;
        }
        throw new ExperimentNotFoundException(experimentName);
    }

    /**
     * Builds summary statistics over every model in an experiment.
     *
     * <p>Two passes are made over the models. The first counts algorithms and statuses and
     * tracks the min/max of each evaluation metric. The second builds a histogram for each
     * metric whose min is non-null. The rmse, r2, mae and evariance histograms span the observed
     * min/max with at most 10 bins (fewer when there are fewer models); precision, recall and f1
     * histograms always span 0.0 to 1.0 with 10 bins.
     *
     * @param experimentName name of the experiment
     * @return stats holding the experiment, per-metric histograms keyed by metric name, and
     *     the algorithm and status histograms
     * @throws ExperimentNotFoundException if the experiment does not exist
     */
    public ExperimentStats getExperimentStats(String experimentName) {
        Experiment experiment = getExperiment(experimentName);
        HashMap<String, ColumnStats> metricStats = new HashMap<String, ColumnStats>();
        CategoricalHisto algorithmCounts = new CategoricalHisto();
        CategoricalHisto statusCounts = new CategoricalHisto();
        List<ModelMeta> allModels = listModels(experimentName, 0, Integer.MAX_VALUE, new SortInfo(SortType.ASC)).getModels();
        if (allModels.isEmpty()) {
            return new ExperimentStats(experiment, metricStats, new ColumnStats(algorithmCounts), new ColumnStats(statusCounts));
        }

        ModelMeta firstModel = allModels.get(0);
        List<ModelMeta> remainingModels = allModels.subList(1, allModels.size());

        // Pass 1: seed the counters and ranges from the first model, then fold in the rest.
        algorithmCounts.update(firstModel.getAlgorithm());
        ModelStatus firstStatus = firstModel.getStatus();
        String firstStatusName = firstStatus == null ? null : firstStatus.toString();
        statusCounts.update(firstStatusName);
        EvaluationMetrics seed = firstModel.getEvaluationMetrics();
        NumericStats rmseRange = new NumericStats(seed.getRmse());
        NumericStats r2Range = new NumericStats(seed.getR2());
        NumericStats maeRange = new NumericStats(seed.getMae());
        NumericStats evarianceRange = new NumericStats(seed.getEvariance());
        NumericStats precisionRange = new NumericStats(seed.getPrecision());
        NumericStats recallRange = new NumericStats(seed.getRecall());
        NumericStats f1Range = new NumericStats(seed.getF1());
        for (ModelMeta model : remainingModels) {
            algorithmCounts.update(model.getAlgorithm());
            ModelStatus status = model.getStatus();
            String statusName = status == null ? null : status.toString();
            statusCounts.update(statusName);
            EvaluationMetrics current = model.getEvaluationMetrics();
            rmseRange.update(current.getRmse());
            r2Range.update(current.getR2());
            maeRange.update(current.getMae());
            evarianceRange.update(current.getEvariance());
            precisionRange.update(current.getPrecision());
            recallRange.update(current.getRecall());
            f1Range.update(current.getF1());
        }

        // Pass 2: create a histogram for each metric that has a range, seeded with the first model.
        seed = firstModel.getEvaluationMetrics();
        int numBins = Math.min(10, (int) statusCounts.getTotalCount());
        NumericHisto rmseHisto = rmseRange.getMin() == null
            ? null : new NumericHisto(rmseRange.getMin(), rmseRange.getMax(), numBins, seed.getRmse());
        NumericHisto r2Histo = r2Range.getMin() == null
            ? null : new NumericHisto(r2Range.getMin(), r2Range.getMax(), numBins, seed.getR2());
        NumericHisto maeHisto = maeRange.getMin() == null
            ? null : new NumericHisto(maeRange.getMin(), maeRange.getMax(), numBins, seed.getMae());
        NumericHisto evarianceHisto = evarianceRange.getMin() == null
            ? null : new NumericHisto(evarianceRange.getMin(), evarianceRange.getMax(), numBins, seed.getEvariance());
        NumericHisto precisionHisto = precisionRange.getMin() == null
            ? null : new NumericHisto(0.0, 1.0, 10, seed.getPrecision());
        NumericHisto recallHisto = recallRange.getMin() == null
            ? null : new NumericHisto(0.0, 1.0, 10, seed.getRecall());
        NumericHisto f1Histo = f1Range.getMin() == null
            ? null : new NumericHisto(0.0, 1.0, 10, seed.getF1());

        for (ModelMeta model : remainingModels) {
            EvaluationMetrics current = model.getEvaluationMetrics();
            if (rmseHisto != null) {
                rmseHisto.update(current.getRmse());
            }
            if (r2Histo != null) {
                r2Histo.update(current.getR2());
            }
            if (maeHisto != null) {
                maeHisto.update(current.getMae());
            }
            if (evarianceHisto != null) {
                evarianceHisto.update(current.getEvariance());
            }
            if (precisionHisto != null) {
                precisionHisto.update(current.getPrecision());
            }
            if (recallHisto != null) {
                recallHisto.update(current.getRecall());
            }
            if (f1Histo != null) {
                f1Histo.update(current.getF1());
            }
        }

        // Only metrics that produced a histogram are reported.
        if (rmseHisto != null) {
            metricStats.put("rmse", new ColumnStats(rmseHisto));
        }
        if (r2Histo != null) {
            metricStats.put("r2", new ColumnStats(r2Histo));
        }
        if (maeHisto != null) {
            metricStats.put("mae", new ColumnStats(maeHisto));
        }
        if (evarianceHisto != null) {
            metricStats.put("evariance", new ColumnStats(evarianceHisto));
        }
        if (precisionHisto != null) {
            metricStats.put("precision", new ColumnStats(precisionHisto));
        }
        if (recallHisto != null) {
            metricStats.put("recall", new ColumnStats(recallHisto));
        }
        if (f1Histo != null) {
            metricStats.put("f1", new ColumnStats(f1Histo));
        }
        return new ExperimentStats(experiment, metricStats, new ColumnStats(algorithmCounts), new ColumnStats(statusCounts));
    }

    /**
     * Writes the given experiment to the experiments table.
     *
     * @param experiment the experiment to store
     */
    public void putExperiment(Experiment experiment) {
        experiments.put(experiment);
    }

    /**
     * Deletes an experiment along with its models and splits.
     *
     * @param experimentName name of the experiment to delete
     * @throws ExperimentNotFoundException if the experiment does not exist
     */
    public void deleteExperiment(String experimentName) {
        // Existence check only; the lookup throws if the experiment is missing.
        getExperiment(experimentName);
        // Models and splits are removed before the experiment entry itself.
        models.delete(experimentName);
        splits.delete(experimentName);
        experiments.delete(experimentName);
    }

    /**
     * Lists the models of an experiment.
     *
     * @param experimentName name of the experiment
     * @param offset offset passed through to the models table
     * @param limit limit passed through to the models table
     * @param sortInfo sort order passed through to the models table
     * @return the models returned by the models table
     * @throws ExperimentNotFoundException if the experiment does not exist
     */
    public ModelsMeta listModels(String experimentName, int offset, int limit, SortInfo sortInfo) {
        // Existence check only; the lookup throws if the experiment is missing.
        getExperiment(experimentName);
        ModelsMeta listing = models.list(experimentName, offset, limit, sortInfo);
        return listing;
    }

    /**
     * Looks up the metadata for a model.
     *
     * @param modelKey experiment and model identifying the model
     * @return the model metadata, never null
     * @throws ExperimentNotFoundException if the model's experiment does not exist
     * @throws ModelNotFoundException if the models table has no entry for the key
     */
    public ModelMeta getModel(ModelKey modelKey) {
        // Existence check only; the lookup throws if the experiment is missing.
        getExperiment(modelKey.getExperiment());
        ModelMeta found = models.get(modelKey);
        if (found != null) {
            return found;
        }
        throw new ModelNotFoundException(modelKey);
    }

    /**
     * Records the training request for a model and returns what is needed to train it.
     *
     * <p>The request stored in the models table is rebuilt from the original algorithm and
     * predictions dataset, with hyperparameters replaced by the map produced by the
     * algorithm's modeler.
     *
     * @param key experiment and model identifying the model to train
     * @param trainRequest the requested algorithm, predictions dataset and hyperparameters
     * @param trainingTime time value stored with the training info
     * @return the experiment, split, model id and re-read model metadata
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws ConflictException if the model is not in the DATA_READY state
     * @throws SplitNotFoundException if the model's split does not exist
     */
    public ModelTrainerInfo trainModel(ModelKey key, TrainModelRequest trainRequest, long trainingTime) {
        Experiment experiment = getExperiment(key.getExperiment());
        ModelMeta modelMeta = getModel(key);
        ModelStatus status = modelMeta.getStatus();
        if (status != ModelStatus.DATA_READY) {
            String message = String.format("Cannot train a model that is in the '%s' state.", status);
            throw new ConflictException(message);
        }

        Modeler modeler = Modelers.getModeler(trainRequest.getAlgorithm());
        Parameters parameters = modeler.getParams(trainRequest.getHyperparameters());
        TrainModelRequest requestWithDefaults = new TrainModelRequest(
            trainRequest.getAlgorithm(), trainRequest.getPredictionsDataset(), parameters.toMap());
        models.setTrainingInfo(key, requestWithDefaults, trainingTime);

        DataSplitStats splitStats = getSplit(new SplitKey(key.getExperiment(), modelMeta.getSplit()));
        // Re-read the metadata so the result reflects the training info just written.
        ModelMeta refreshedMeta = models.get(key);
        return new ModelTrainerInfo(experiment, splitStats, key.getModel(), refreshedMeta);
    }

    /**
     * Assigns a split to a model, replacing any split it currently uses.
     *
     * @param key experiment and model identifying the model
     * @param splitId id of the split to assign
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws ConflictException if the model is not in the PREPARING, SPLIT_FAILED,
     *     TRAINING_FAILED or DATA_READY state
     * @throws SplitNotFoundException if the requested split does not exist
     */
    public void setModelSplit(ModelKey key, String splitId) {
        Experiment experiment = getExperiment(key.getExperiment());
        ModelMeta modelMeta = getModel(key);
        ModelStatus status = modelMeta.getStatus();
        boolean splitAssignable = status == ModelStatus.PREPARING
            || status == ModelStatus.SPLIT_FAILED
            || status == ModelStatus.TRAINING_FAILED
            || status == ModelStatus.DATA_READY;
        if (!splitAssignable) {
            throw new ConflictException(String.format("Cannot set a split for a model in the '%s' state. The model must be in the '%s', '%s', '%s', or '%s' state.", status, ModelStatus.PREPARING, ModelStatus.SPLIT_FAILED, ModelStatus.TRAINING_FAILED, ModelStatus.DATA_READY));
        }

        // Validate the new split before touching the current assignment.
        DataSplitStats newSplitStats = getSplit(new SplitKey(key.getExperiment(), splitId));
        String previousSplitId = modelMeta.getSplit();
        if (previousSplitId != null) {
            SplitKey previousSplitKey = new SplitKey(key.getExperiment(), previousSplitId);
            splits.unregisterModel(previousSplitKey, key.getModel());
        }
        models.setSplit(key, newSplitStats, experiment.getOutcome());
        splits.registerModel(new SplitKey(key.getExperiment(), splitId), key.getModel());
    }

    /**
     * Removes the split from a model and puts the model back into the PREPARING state.
     *
     * <p>If the split listed exactly one model before the removal, the split is deleted too.
     *
     * @param key experiment and model identifying the model
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws ConflictException if the model is not in the SPLIT_FAILED, DATA_READY or
     *     TRAINING_FAILED state
     * @throws SplitNotFoundException if the model's split does not exist
     */
    public void unassignModelSplit(ModelKey key) {
        // Existence check only; the lookup throws if the experiment is missing.
        getExperiment(key.getExperiment());
        ModelMeta modelMeta = getModel(key);
        ModelStatus status = modelMeta.getStatus();
        boolean splitRemovable = status == ModelStatus.SPLIT_FAILED
            || status == ModelStatus.DATA_READY
            || status == ModelStatus.TRAINING_FAILED;
        if (!splitRemovable) {
            throw new ConflictException(String.format("Cannot unassign the split for a model in the '%s' state. The model must be in the '%s', '%s', or '%s' state.", status, ModelStatus.SPLIT_FAILED, ModelStatus.TRAINING_FAILED, ModelStatus.DATA_READY));
        }

        // Snapshot the split before unregistering so the model count below is the "before" count.
        DataSplitStats splitBefore = getSplit(new SplitKey(key.getExperiment(), modelMeta.getSplit()));
        models.unassignSplit(key);
        models.setStatus(key, ModelStatus.PREPARING);

        SplitKey splitKey = new SplitKey(key.getExperiment(), modelMeta.getSplit());
        splits.unregisterModel(splitKey, key.getModel());
        boolean wasOnlyModel = splitBefore.getModels().size() == 1;
        if (wasOnlyModel) {
            splits.delete(splitKey);
        }
    }

    /**
     * Adds a model to an experiment, optionally assigning it an existing split.
     *
     * <p>When the request names a split, the split must already exist. The new model is then
     * given that split and registered with it, mirroring what {@code setModelSplit} does, so
     * that the split's model list includes the new model.
     *
     * @param experimentName name of the experiment the model belongs to
     * @param createRequest the model creation request; its split may be null
     * @return the id of the newly added model
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws SplitNotFoundException if the request names a split that does not exist
     */
    public String addModel(String experimentName, CreateModelRequest createRequest) {
        Experiment experiment = this.getExperiment(experimentName);
        String splitId = createRequest.getSplit();
        SplitKey splitKey = null;
        DataSplitStats splitStats = null;
        if (splitId != null) {
            splitKey = new SplitKey(experimentName, splitId);
            splitStats = this.splits.get(splitKey);
            if (splitStats == null) {
                throw new SplitNotFoundException(splitKey);
            }
        }
        String modelId = this.models.add(experiment, createRequest, System.currentTimeMillis());
        if (splitStats != null) {
            this.models.setSplit(new ModelKey(experimentName, modelId), splitStats, experiment.getOutcome());
            // Without this the split would not list the model, so deleteSplit would consider the
            // split unused and finishSplit/splitFailed would never update this model's status.
            this.splits.registerModel(splitKey, modelId);
        }
        return modelId;
    }

    /**
     * Sets the directives of a model that is still being prepared.
     *
     * @param key experiment and model identifying the model
     * @param directives the directives to store for the model
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws ConflictException if the model is not in the PREPARING state
     */
    public void setModelDirectives(ModelKey key, List<String> directives) {
        ModelStatus currentStatus = getModel(key).getStatus();
        boolean preparing = currentStatus == ModelStatus.PREPARING;
        if (!preparing) {
            String message = String.format("Directives can only be set or modified if the model is in the %s state.", ModelStatus.PREPARING);
            throw new ConflictException(message);
        }
        models.setDirectives(key, directives);
    }

    /**
     * Passes evaluation results for a model through to the models table.
     *
     * @param key experiment and model identifying the model
     * @param evaluationMetrics the evaluation metrics to store
     * @param trainedTime time value stored with the metrics
     * @param categoricalFeatures names of the categorical features to store
     */
    public void updateModelMetrics(ModelKey key, EvaluationMetrics evaluationMetrics, long trainedTime, Set<String> categoricalFeatures) {
        models.update(key, evaluationMetrics, trainedTime, categoricalFeatures);
    }

    /**
     * Deletes a model and, if it had a split, unregisters it from that split.
     *
     * @param modelKey experiment and model identifying the model
     * @throws ModelNotFoundException if the models table has no entry for the key
     */
    public void deleteModel(ModelKey modelKey) {
        ModelMeta existing = models.get(modelKey);
        if (existing == null) {
            throw new ModelNotFoundException(modelKey);
        }
        models.delete(modelKey);

        String splitId = existing.getSplit();
        if (splitId == null) {
            return;
        }
        SplitKey splitKey = new SplitKey(modelKey.getExperiment(), splitId);
        splits.unregisterModel(splitKey, modelKey.getModel());
    }

    /**
     * Marks a model as deployed unless it already has a positive deploy time.
     *
     * @param key experiment and model identifying the model
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     */
    public void deployModel(ModelKey key) {
        ModelMeta meta = getModel(key);
        boolean hasDeployTime = meta.getDeploytime() > 0L;
        if (!hasDeployTime) {
            models.setStatus(key, ModelStatus.DEPLOYED);
        }
    }

    /**
     * Transitions a model from the TRAINING state to the TRAINING_FAILED state.
     *
     * @param key experiment and model identifying the model
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws ModelNotFoundException if the model does not exist
     * @throws IllegalStateException if the model is not currently in the TRAINING state
     */
    public void modelFailed(ModelKey key) {
        ModelMeta modelMeta = this.getModel(key);
        ModelStatus currentStatus = modelMeta.getStatus();
        if (currentStatus != ModelStatus.TRAINING) {
            // Argument order matters: the first placeholder is the target state, the second the current one.
            throw new IllegalStateException(String.format("Cannot transition model to '%s' from '%s'", ModelStatus.TRAINING_FAILED, currentStatus));
        }
        this.models.setStatus(key, ModelStatus.TRAINING_FAILED);
    }

    /**
     * Lists the splits of an experiment.
     *
     * @param experimentName name of the experiment
     * @return the splits returned by the splits table
     * @throws ExperimentNotFoundException if the experiment does not exist
     */
    public List<DataSplitStats> listSplits(String experimentName) {
        // Existence check only; the lookup throws if the experiment is missing.
        getExperiment(experimentName);
        List<DataSplitStats> experimentSplits = splits.list(experimentName);
        return experimentSplits;
    }

    /**
     * Adds a split to an experiment after checking that the split schema is compatible with
     * the experiment outcome.
     *
     * <p>The split schema must contain the experiment's outcome field. If the experiment's
     * outcome type is categorical (boolean or string) the split's outcome field must be
     * categorical too, and likewise for numeric types (int, long, float, double). A nullable
     * outcome field is compared using its non-nullable type.
     *
     * @param experimentName name of the experiment the split belongs to
     * @param splitInfo the split to add
     * @param startTimeMillis start time passed through to the splits table
     * @return the new split's id, the experiment, the split and the split's location
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws BadRequestException if the split schema lacks the outcome field or its type
     *     category does not match the experiment's
     */
    public DataSplitInfo addSplit(String experimentName, DataSplit splitInfo, long startTimeMillis) {
        Experiment experiment = this.getExperiment(experimentName);
        // Locale.ROOT keeps the enum-name lookup independent of the JVM default locale; under
        // e.g. a Turkish locale a plain toUpperCase() turns "int" into "İNT", which is not a Schema.Type.
        Schema.Type experimentOutcomeType = Schema.Type.valueOf(experiment.getOutcomeType().toUpperCase(Locale.ROOT));
        Schema splitSchema = splitInfo.getSchema();
        Schema.Field outcomeField = splitSchema.getField(experiment.getOutcome());
        if (outcomeField == null) {
            throw new BadRequestException(String.format("Invalid split schema. The split must contain the experiment outcome '%s'.", experiment.getOutcome()));
        }
        Schema splitOutcomeSchema = outcomeField.getSchema();
        if (splitOutcomeSchema.isNullable()) {
            splitOutcomeSchema = splitOutcomeSchema.getNonNullable();
        }
        Schema.Type splitOutcomeType = splitOutcomeSchema.getType();
        if (CATEGORICAL_TYPES.contains(experimentOutcomeType) && !CATEGORICAL_TYPES.contains(splitOutcomeType)) {
            throw new BadRequestException(String.format("Invalid split schema. Outcome field '%s' is of categorical type '%s' in the experiment , but is of non-categorical type '%s' in the split.", experiment.getOutcome(), experimentOutcomeType, splitOutcomeType));
        }
        if (NUMERIC_TYPES.contains(experimentOutcomeType) && !NUMERIC_TYPES.contains(splitOutcomeType)) {
            throw new BadRequestException(String.format("Invalid split schema. Outcome field '%s' is of numeric type '%s' in the experiment, but is of non-numeric type '%s' in the split.", experiment.getOutcome(), experimentOutcomeType, splitOutcomeType));
        }
        String splitId = this.splits.addSplit(experimentName, splitInfo, startTimeMillis);
        Location splitLocation = this.splits.getLocation(new SplitKey(experimentName, splitId));
        return new DataSplitInfo(splitId, experiment, splitInfo, splitLocation);
    }

    /**
     * Looks up a split.
     *
     * @param key experiment and split identifying the split
     * @return the split stats, never null
     * @throws ExperimentNotFoundException if the split's experiment does not exist
     * @throws SplitNotFoundException if the splits table has no entry for the key
     */
    public DataSplitStats getSplit(SplitKey key) {
        // Existence check only; the lookup throws if the experiment is missing.
        getExperiment(key.getExperiment());
        DataSplitStats found = splits.get(key);
        if (found != null) {
            return found;
        }
        throw new SplitNotFoundException(key);
    }

    /**
     * Records the results of a finished split and marks every model registered with the split
     * as DATA_READY.
     *
     * @param splitKey experiment and split identifying the split
     * @param trainingPath training path stored with the split
     * @param testPath test path stored with the split
     * @param stats per-column stats stored with the split
     * @param endTime end time stored with the split
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws SplitNotFoundException if the split does not exist when re-read after the update
     */
    public void finishSplit(SplitKey splitKey, String trainingPath, String testPath, List<ColumnSplitStats> stats, long endTime) {
        splits.updateStats(splitKey, trainingPath, testPath, stats, endTime);
        // Read the split after the update to get the models registered with it.
        DataSplitStats updatedSplit = getSplit(splitKey);
        for (String modelId : updatedSplit.getModels()) {
            ModelKey modelKey = new ModelKey(splitKey.getExperiment(), modelId);
            models.setStatus(modelKey, ModelStatus.DATA_READY);
        }
    }

    /**
     * Marks a split as failed and moves every model registered with it to SPLIT_FAILED.
     *
     * @param key experiment and split identifying the split
     * @param failTime failure time passed through to the splits table
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws SplitNotFoundException if the split does not exist
     * @throws IllegalStateException if the split is not in the SPLITTING state
     */
    public void splitFailed(SplitKey key, long failTime) {
        // Existence check only; the lookup throws if the experiment is missing.
        getExperiment(key.getExperiment());
        DataSplitStats failedSplit = getSplit(key);
        boolean splitting = failedSplit.getStatus() == SplitStatus.SPLITTING;
        if (!splitting) {
            throw new IllegalStateException("Cannot transition split to failed state unless it is in the splitting state.");
        }
        splits.splitFailed(key, failTime);
        // The model list comes from the stats read before the split was marked as failed.
        for (String modelId : failedSplit.getModels()) {
            ModelKey modelKey = new ModelKey(key.getExperiment(), modelId);
            models.setStatus(modelKey, ModelStatus.SPLIT_FAILED);
        }
    }

    /**
     * Deletes a split that no model is registered with.
     *
     * @param key experiment and split identifying the split
     * @throws ExperimentNotFoundException if the experiment does not exist
     * @throws SplitNotFoundException if the split does not exist
     * @throws ConflictException if at least one model is still registered with the split
     */
    public void deleteSplit(SplitKey key) {
        DataSplitStats splitStats = getSplit(key);
        if (splitStats.getModels().isEmpty()) {
            splits.delete(key);
            return;
        }
        String modelIds = Joiner.on(',').join(splitStats.getModels());
        throw new ConflictException(String.format("Cannot delete split '%s' since it is used by model(s) '%s'.", key.getSplit(), modelIds));
    }
}

