/*
 * Decompiled with CFR 0.152.
 */
package co.cask.mmds.data;

import co.cask.cdap.api.common.Bytes;
import co.cask.cdap.api.data.schema.Schema;
import co.cask.cdap.api.dataset.DatasetProperties;
import co.cask.cdap.api.dataset.lib.IndexedTable;
import co.cask.cdap.api.dataset.table.Put;
import co.cask.cdap.api.dataset.table.Row;
import co.cask.cdap.api.dataset.table.Scan;
import co.cask.cdap.api.dataset.table.Scanner;
import co.cask.mmds.data.CountTable;
import co.cask.mmds.data.DataSplitStats;
import co.cask.mmds.data.EvaluationMetrics;
import co.cask.mmds.data.Experiment;
import co.cask.mmds.data.ModelKey;
import co.cask.mmds.data.ModelMeta;
import co.cask.mmds.data.ModelStatus;
import co.cask.mmds.data.ModelsMeta;
import co.cask.mmds.data.SortInfo;
import co.cask.mmds.data.SortType;
import co.cask.mmds.proto.CreateModelRequest;
import co.cask.mmds.proto.TrainModelRequest;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import javax.annotation.Nullable;

/**
 * Stores model metadata in an {@link IndexedTable}, one row per model.
 *
 * <p>Row keys have the form {@code <experiment>/<modelId>}, so all models of one experiment are read
 * with a single prefix scan. The model name column is indexed (see {@link #DATASET_PROPERTIES}).
 * Per-experiment model counts are maintained through the {@link CountTable} row-count helpers.
 *
 * <p>NOTE(review): the prefix scan for experiment {@code a} also matches keys of an experiment
 * named {@code a/...}; this assumes experiment names never contain {@code "/"} — confirm that this
 * is validated at the API layer.
 */
public class ModelTable extends CountTable<IndexedTable> {
    private static final Gson GSON = new Gson();
    private static final Type MAP_TYPE = new TypeToken<Map<String, String>>() { }.getType();
    private static final Type LIST_TYPE = new TypeToken<List<String>>() { }.getType();
    private static final Type SET_TYPE = new TypeToken<Set<String>>() { }.getType();
    private static final String SEPARATOR = "/";
    private static final String EXPERIMENT_COL = "experiment";
    private static final String ID_COL = "id";
    private static final String NAME_COL = "name";
    private static final String DESC_COL = "description";
    private static final String ALGO_COL = "algorithm";
    private static final String SPLIT_COL = "split";
    private static final String OUTCOME_COL = "outcome";
    private static final String HYPER_PARAMS_COL = "hyperparameters";
    private static final String FEATURES_COL = "features";
    private static final String CATEGORICAL_FEATURES_COL = "catfeatures";
    private static final String CREATE_TIME_COL = "createtime";
    private static final String TRAIN_TIME_COL = "trainedtime";
    private static final String DEPLOY_TIME_COL = "deploytime";
    private static final String STATUS_COL = "status";
    private static final String DIRECTIVES_COL = "directives";
    private static final String PREDICTIONS_COL = "predictions";
    private static final String PRECISION_COL = "precision";
    private static final String RECALL_COL = "recall";
    private static final String F1_COL = "f1";
    private static final String RMSE_COL = "rmse";
    private static final String R2_COL = "r2";
    private static final String EVARIANCE_COL = "evariance";
    private static final String MAE_COL = "mae";

    /** Dataset properties for this table: the model name column is indexed. */
    public static final DatasetProperties DATASET_PROPERTIES =
        DatasetProperties.builder().add("columnsToIndex", NAME_COL).build();

    /** Orders models by ascending name. */
    private static final Comparator<ModelMeta> NAME_ORDER = new Comparator<ModelMeta>() {
        @Override
        public int compare(ModelMeta o1, ModelMeta o2) {
            return o1.getName().compareTo(o2.getName());
        }
    };

    public ModelTable(IndexedTable table) {
        super(table);
    }

    /**
     * Lists one page of the models of an experiment, ordered by model name.
     *
     * @param experiment the experiment whose models are listed
     * @param offset index of the first model to return; clamped into {@code [0, total]}
     * @param limit maximum number of models to return; a non-positive value yields an empty page
     * @param sortInfo sort direction; {@link SortType#DESC} sorts names descending, any other
     *     value (including {@code null}) ascending
     * @return the total number of models in the experiment together with the requested page
     */
    public ModelsMeta list(String experiment, int offset, int limit, SortInfo sortInfo) {
        List<ModelMeta> models = new ArrayList<ModelMeta>();
        byte[] startKey = Bytes.toBytes(experiment + SEPARATOR);
        try (Scanner scanner = table.scan(startKey, Bytes.stopKeyForPrefix(startKey))) {
            Row row;
            while ((row = scanner.next()) != null) {
                models.add(fromRow(row));
            }
        }
        Comparator<ModelMeta> order = sortInfo.getSortType() == SortType.DESC
            ? Collections.reverseOrder(NAME_ORDER)
            : NAME_ORDER;
        Collections.sort(models, order);

        int total = models.size();
        int from = Math.min(Math.max(offset, 0), total);
        // Widen to long so that offset + limit cannot overflow (e.g. limit == Integer.MAX_VALUE).
        int to = (int) Math.min((long) from + Math.max(limit, 0), total);
        // Copy the page so the result does not hold a view over the full scanned list.
        return new ModelsMeta(total, new ArrayList<ModelMeta>(models.subList(from, to)));
    }

    /**
     * Reads the metadata of a single model.
     *
     * @param key experiment and model id
     * @return the model metadata, or {@code null} if no row exists for the key
     */
    @Nullable
    public ModelMeta get(ModelKey key) {
        Row row = table.get(getKey(key));
        return row.isEmpty() ? null : fromRow(row);
    }

    /**
     * Sets the status of a model. Moving to {@link ModelStatus#DEPLOYED} also records the deploy
     * time, and moving to {@link ModelStatus#TRAINED} records the trained time (both as the current
     * wall-clock time in milliseconds).
     *
     * @param key experiment and model id
     * @param status the new status
     */
    public void setStatus(ModelKey key, ModelStatus status) {
        Put put = new Put(getKey(key)).add(STATUS_COL, status.name());
        if (status == ModelStatus.DEPLOYED) {
            put.add(DEPLOY_TIME_COL, System.currentTimeMillis());
        } else if (status == ModelStatus.TRAINED) {
            put.add(TRAIN_TIME_COL, System.currentTimeMillis());
        }
        table.put(put);
    }

    /**
     * Deletes a single model row and decrements the experiment's row count by one.
     *
     * <p>NOTE(review): the count is decremented even if no row existed for the key; callers are
     * presumably expected to check existence first — confirm.
     *
     * @param key experiment and model id
     */
    public void delete(ModelKey key) {
        table.delete(getKey(key));
        decrementRowCount(1, key.getExperiment());
    }

    /**
     * Deletes every model of an experiment and decrements the experiment's row count accordingly.
     *
     * @param experiment the experiment whose models are deleted
     * @return the number of deleted models
     */
    public int delete(String experiment) {
        int deleted = delete(experiment, Integer.MAX_VALUE);
        decrementRowCount(deleted, experiment);
        return deleted;
    }

    /**
     * Deletes up to {@code limit} models of an experiment. Unlike {@link #delete(String)}, this does
     * not update the experiment's row count.
     *
     * @param experiment the experiment whose models are deleted
     * @param limit maximum number of rows to delete; a non-positive value deletes nothing
     * @return the number of deleted rows
     */
    public int delete(String experiment, int limit) {
        if (limit <= 0) {
            return 0;
        }
        byte[] startKey = Bytes.toBytes(experiment + SEPARATOR);
        Scan scan = new Scan(startKey, Bytes.stopKeyForPrefix(startKey));
        List<byte[]> keys = new ArrayList<byte[]>();
        try (Scanner scanner = table.scan(scan)) {
            Row row;
            // Check the limit before reading, so no row beyond the limit is fetched.
            while (keys.size() < limit && (row = scanner.next()) != null) {
                keys.add(row.getRow());
            }
        }
        // Keys are collected first and deleted only after the scanner has been closed.
        for (byte[] key : keys) {
            table.delete(key);
        }
        return keys.size();
    }

    /**
     * Creates a new model in the {@link ModelStatus#PREPARING} state and increments the
     * experiment's row count.
     *
     * @param experiment the experiment the model belongs to; its outcome is copied to the model
     * @param createRequest name, description and (optional) directives of the model
     * @param createTs creation timestamp to record
     * @return the generated model id (a random UUID without dashes)
     */
    public String add(Experiment experiment, CreateModelRequest createRequest, long createTs) {
        String id = UUID.randomUUID().toString().replace("-", "");
        String experimentName = experiment.getName();
        Put put = new Put(getKey(experimentName, id))
            .add(EXPERIMENT_COL, experimentName)
            .add(ID_COL, id)
            .add(NAME_COL, createRequest.getName())
            .add(DESC_COL, createRequest.getDescription())
            .add(OUTCOME_COL, experiment.getOutcome())
            .add(CREATE_TIME_COL, createTs)
            .add(STATUS_COL, ModelStatus.PREPARING.name())
            // -1 marks "not trained / not deployed yet"; fromRow uses the same default.
            .add(TRAIN_TIME_COL, -1L)
            .add(DEPLOY_TIME_COL, -1L);
        if (!createRequest.getDirectives().isEmpty()) {
            put.add(DIRECTIVES_COL, GSON.toJson(createRequest.getDirectives()));
        }
        table.put(put);
        incrementRowCount(experimentName);
        return id;
    }

    /**
     * Stores the directives of a model as JSON. An empty list is ignored and leaves the stored
     * directives unchanged.
     *
     * @param key experiment and model id
     * @param directives the directives to store
     */
    public void setDirectives(ModelKey key, List<String> directives) {
        if (directives.isEmpty()) {
            return;
        }
        table.put(new Put(getKey(key)).add(DIRECTIVES_COL, GSON.toJson(directives)));
    }

    /**
     * Assigns a data split to a model. The model status follows from the split status, and the
     * model's features become all split schema fields except the outcome.
     *
     * @param key experiment and model id
     * @param split the split to assign
     * @param outcome name of the outcome field, which is excluded from the features
     * @throws IllegalStateException if the split status is not recognized
     */
    public void setSplit(ModelKey key, DataSplitStats split, String outcome) {
        ModelStatus status;
        switch (split.getStatus()) {
            case SPLITTING:
                status = ModelStatus.SPLITTING;
                break;
            case FAILED:
                status = ModelStatus.SPLIT_FAILED;
                break;
            case COMPLETE:
                status = ModelStatus.DATA_READY;
                break;
            default:
                throw new IllegalStateException("Unknown split status " + split.getStatus());
        }
        List<Schema.Field> fields = split.getSchema().getFields();
        // Clamp so that an empty schema does not produce a negative initial capacity.
        List<String> featureNames = new ArrayList<String>(Math.max(0, fields.size() - 1));
        for (Schema.Field field : fields) {
            String fieldName = field.getName();
            if (!fieldName.equals(outcome)) {
                featureNames.add(fieldName);
            }
        }
        Put put = new Put(getKey(key))
            .add(SPLIT_COL, split.getId())
            .add(STATUS_COL, status.name())
            .add(DIRECTIVES_COL, GSON.toJson(split.getDirectives()))
            .add(FEATURES_COL, GSON.toJson(featureNames));
        table.put(put);
    }

    /**
     * Removes the split column from a model row.
     *
     * @param key experiment and model id
     */
    public void unassignSplit(ModelKey key) {
        table.delete(getKey(key), Bytes.toBytes(SPLIT_COL));
    }

    /**
     * Records the training configuration of a model and moves it to {@link ModelStatus#TRAINING}.
     *
     * @param key experiment and model id
     * @param trainRequest algorithm, hyperparameters and optional predictions dataset
     */
    public void setTrainingInfo(ModelKey key, TrainModelRequest trainRequest) {
        Put put = new Put(getKey(key))
            .add(ALGO_COL, trainRequest.getAlgorithm())
            .add(HYPER_PARAMS_COL, GSON.toJson(trainRequest.getHyperparameters()))
            .add(STATUS_COL, ModelStatus.TRAINING.name());
        if (trainRequest.getPredictionsDataset() != null) {
            put.add(PREDICTIONS_COL, trainRequest.getPredictionsDataset());
        }
        table.put(put);
    }

    /**
     * Records the result of training: the non-null evaluation metrics, the trained time and the
     * categorical features. Also moves the model to {@link ModelStatus#TRAINED}.
     *
     * @param key experiment and model id
     * @param evaluationMetrics metrics to store; null metrics are not written
     * @param trainedTime trained timestamp to record
     * @param categoricalFeatures categorical feature names, stored as JSON
     */
    public void update(ModelKey key, EvaluationMetrics evaluationMetrics, long trainedTime,
                       Set<String> categoricalFeatures) {
        Put put = new Put(getKey(key));
        if (evaluationMetrics.getPrecision() != null) {
            put.add(PRECISION_COL, evaluationMetrics.getPrecision().doubleValue());
        }
        if (evaluationMetrics.getRecall() != null) {
            put.add(RECALL_COL, evaluationMetrics.getRecall().doubleValue());
        }
        if (evaluationMetrics.getF1() != null) {
            put.add(F1_COL, evaluationMetrics.getF1().doubleValue());
        }
        if (evaluationMetrics.getRmse() != null) {
            put.add(RMSE_COL, evaluationMetrics.getRmse().doubleValue());
        }
        if (evaluationMetrics.getR2() != null) {
            put.add(R2_COL, evaluationMetrics.getR2().doubleValue());
        }
        if (evaluationMetrics.getEvariance() != null) {
            put.add(EVARIANCE_COL, evaluationMetrics.getEvariance().doubleValue());
        }
        if (evaluationMetrics.getMae() != null) {
            put.add(MAE_COL, evaluationMetrics.getMae().doubleValue());
        }
        put.add(STATUS_COL, ModelStatus.TRAINED.name());
        put.add(TRAIN_TIME_COL, trainedTime);
        put.add(CATEGORICAL_FEATURES_COL, GSON.toJson(categoricalFeatures));
        table.put(put);
    }

    /** Converts a stored row into {@link ModelMeta}, substituting defaults for absent columns. */
    private ModelMeta fromRow(Row row) {
        String keyStr = Bytes.toString(row.getRow());
        // Generated model ids never contain the separator (see add), so the id follows the last one.
        String modelId = keyStr.substring(keyStr.lastIndexOf(SEPARATOR) + 1);

        // Gson returns null for a null/empty JSON string; substitute empty collections.
        Map<String, String> hyperParameters = GSON.fromJson(row.getString(HYPER_PARAMS_COL), MAP_TYPE);
        if (hyperParameters == null) {
            hyperParameters = new HashMap<String, String>();
        }
        List<String> features = GSON.fromJson(row.getString(FEATURES_COL), LIST_TYPE);
        if (features == null) {
            features = new ArrayList<String>();
        }
        Set<String> categoricalFeatures = GSON.fromJson(row.getString(CATEGORICAL_FEATURES_COL), SET_TYPE);
        if (categoricalFeatures == null) {
            categoricalFeatures = new HashSet<String>();
        }
        List<String> directives = GSON.fromJson(row.getString(DIRECTIVES_COL), LIST_TYPE);
        if (directives == null) {
            directives = new ArrayList<String>();
        }
        String description = row.getString(DESC_COL);
        if (description == null) {
            description = "";
        }
        String statusStr = row.getString(STATUS_COL);
        ModelStatus status = statusStr == null ? null : ModelStatus.valueOf(statusStr);
        EvaluationMetrics evaluationMetrics = new EvaluationMetrics(
            row.getDouble(PRECISION_COL), row.getDouble(RECALL_COL), row.getDouble(F1_COL),
            row.getDouble(RMSE_COL), row.getDouble(R2_COL), row.getDouble(EVARIANCE_COL),
            row.getDouble(MAE_COL));

        // The casts mirror the compiled code: some setters return a builder supertype.
        ModelMeta.Builder builder = (ModelMeta.Builder) ModelMeta.builder(modelId)
            .setName(row.getString(NAME_COL));
        builder = (ModelMeta.Builder) builder.setDescription(description);
        builder = (ModelMeta.Builder) builder.setOutcome(row.getString(OUTCOME_COL))
            .setAlgorithm(row.getString(ALGO_COL));
        builder = (ModelMeta.Builder) builder.setSplit(row.getString(SPLIT_COL));
        builder = (ModelMeta.Builder) builder.setHyperParameters(hyperParameters);
        builder = (ModelMeta.Builder) builder.setFeatures(features)
            .setStatus(status)
            .setCategoricalFeatures(categoricalFeatures)
            .setCreateTime(row.getLong(CREATE_TIME_COL, -1L))
            .setTrainedTime(row.getLong(TRAIN_TIME_COL, -1L))
            .setDeployTime(row.getLong(DEPLOY_TIME_COL, -1L))
            .setEvaluationMetrics(evaluationMetrics)
            .setDirectives(directives);
        builder = (ModelMeta.Builder) builder.setPredictionsDataset(row.getString(PREDICTIONS_COL));
        return builder.build();
    }

    /** Returns the row key for a model key. */
    private byte[] getKey(ModelKey key) {
        return getKey(key.getExperiment(), key.getModel());
    }

    /** Returns the row key {@code <experiment>/<model>}. */
    private byte[] getKey(String experiment, String model) {
        return Bytes.toBytes(experiment + SEPARATOR + model);
    }
}

