/*
 * Decompiled with CFR 0.152.
 */
package io.cdap.mmds.modeler.train;

import io.cdap.cdap.api.Admin;
import io.cdap.cdap.api.Transactional;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.dataset.DatasetProperties;
import io.cdap.cdap.api.dataset.lib.PartitionDetail;
import io.cdap.cdap.api.dataset.lib.PartitionKey;
import io.cdap.cdap.api.dataset.lib.PartitionOutput;
import io.cdap.cdap.api.dataset.lib.PartitionedFileSet;
import io.cdap.cdap.api.dataset.lib.PartitionedFileSetProperties;
import io.cdap.cdap.api.dataset.lib.Partitioning;
import io.cdap.mmds.Schemas;
import io.cdap.mmds.api.AlgorithmType;
import io.cdap.mmds.data.ModelKey;
import io.cdap.mmds.modeler.train.ModelOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import org.apache.spark.sql.SaveMode;
import org.apache.twill.filesystem.Location;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Persists the artifacts produced by training a model: the optional target index model, the feature
 * generation pipeline, the trained model and, optionally, the predictions made on the training data.
 *
 * <p>Model components are written under {@code <baseLocation>/<experiment>/<model>/<component>}.
 * Predictions are written as CSV into a {@link PartitionedFileSet} partitioned by experiment and
 * model; that dataset is created if it does not exist yet.
 */
public class ModelOutputWriter {
    private static final Logger LOG = LoggerFactory.getLogger(ModelOutputWriter.class);
    private static final String TARGET_INDICES_COMPONENT = "targetindices";
    private static final String FEATURE_GEN_COMPONENT = "featuregen";
    private static final String MODEL_COMPONENT = "model";

    private final Admin admin;
    private final Transactional transactional;
    private final Location baseLocation;
    private final boolean overwrite;

    /**
     * Creates a writer.
     *
     * @param admin used to check for and create the predictions dataset
     * @param transactional used to look up or register the predictions partition
     * @param baseLocation root location under which model components are written
     * @param overwrite if {@code true}, existing model components are deleted before being rewritten;
     *     if {@code false}, {@link #save} fails when a component already exists
     */
    public ModelOutputWriter(Admin admin, Transactional transactional, Location baseLocation, boolean overwrite) {
        this.admin = admin;
        this.transactional = transactional;
        this.baseLocation = baseLocation;
        this.overwrite = overwrite;
    }

    /**
     * Saves all model components and, if a predictions dataset is given, the predictions on the
     * training data.
     *
     * @param modelKey identifies the experiment and model being saved
     * @param modelOutput the training output to persist
     * @param predictionsDataset name of the dataset to write predictions to, or {@code null} to skip
     *     writing predictions
     * @throws IllegalArgumentException if a model component already exists and overwrite is disabled
     * @throws Exception if writing a component, creating the dataset or running the transaction fails
     */
    public void save(ModelKey modelKey, ModelOutput modelOutput, @Nullable String predictionsDataset) throws Exception {
        saveModelComponents(modelKey, modelOutput);
        if (predictionsDataset != null) {
            savePredictions(modelKey, modelOutput, predictionsDataset);
        }
    }

    /**
     * Deletes every stored component of the given model. Components that do not exist are skipped.
     *
     * @param modelKey identifies the experiment and model whose components are deleted
     * @throws IOException if a component location cannot be resolved or inspected
     */
    public void deleteComponents(ModelKey modelKey) throws IOException {
        deleteComponent(modelKey, TARGET_INDICES_COMPONENT);
        deleteComponent(modelKey, FEATURE_GEN_COMPONENT);
        deleteComponent(modelKey, MODEL_COMPONENT);
    }

    /** Writes the target index model (if present), the feature generation pipeline and the trained model. */
    private void saveModelComponents(ModelKey modelKey, ModelOutput modelOutput) throws IOException {
        if (modelOutput.getTargetIndexModel() != null) {
            LOG.info("Saving outcome indices...");
            String path = getPath(modelKey, TARGET_INDICES_COMPONENT);
            modelOutput.getTargetIndexModel().save(path);
            LOG.info("Outcome indices successfully saved.");
        }
        LOG.info("Saving feature generation pipeline...");
        String featureGenPath = getPath(modelKey, FEATURE_GEN_COMPONENT);
        modelOutput.getFeatureGenModel().write().overwrite().save(featureGenPath);
        LOG.info("Feature generation pipeline successfully saved.");
        LOG.info("Saving trained model...");
        String modelPath = getPath(modelKey, MODEL_COMPONENT);
        modelOutput.getModel().write().overwrite().save(modelPath);
        LOG.info("Model successfully saved.");
    }

    /** Writes the predictions as CSV into the model's partition of the predictions dataset, creating it if needed. */
    private void savePredictions(ModelKey modelKey, ModelOutput modelOutput, String predictionsDataset) throws Exception {
        if (!admin.datasetExists(predictionsDataset)) {
            createPredictionsDataset(predictionsDataset, modelOutput);
        }
        String path = getPredictionsPath(modelKey, predictionsDataset);
        LOG.info("Saving predictions on training data...");
        modelOutput.getPredictions().write().format("csv").mode(SaveMode.Overwrite).save(path);
        LOG.info("Predictions on training data successfully saved.");
    }

    /**
     * Creates the predictions dataset. Its explore schema is a {@code prediction} column (DOUBLE for
     * regression, STRING otherwise) followed by all fields of the training schema.
     */
    private void createPredictionsDataset(String predictionsDataset, ModelOutput modelOutput) throws Exception {
        Schema inputSchema = modelOutput.getSchema();
        Schema.Type predictionType =
            modelOutput.getAlgorithmType() == AlgorithmType.REGRESSION ? Schema.Type.DOUBLE : Schema.Type.STRING;
        List<Schema.Field> predictionFields = new ArrayList<>();
        predictionFields.add(Schema.Field.of("prediction", Schema.of(predictionType)));
        predictionFields.addAll(inputSchema.getFields());
        Schema predictionSchema = Schema.recordOf(inputSchema.getRecordName() + ".prediction", predictionFields);
        DatasetProperties datasetProperties = PartitionedFileSetProperties.builder()
            .setPartitioning(Partitioning.builder().addStringField("experiment").addStringField("model").build())
            .setEnableExploreOnCreate(true)
            .setExploreFormat("text")
            .setExploreFormatProperty("delimiter", ",")
            .setExploreSchema(Schemas.toHiveSchema(predictionSchema))
            .build();
        admin.createDataset(predictionsDataset, PartitionedFileSet.class.getName(), datasetProperties);
    }

    /**
     * Returns the filesystem path of the model's partition in the predictions dataset, registering the
     * partition first if it does not exist yet.
     */
    private String getPredictionsPath(ModelKey modelKey, String predictionsDataset) throws Exception {
        PartitionKey predictionsPartitionKey = PartitionKey.builder()
            .addStringField("model", modelKey.getModel())
            .addStringField("experiment", modelKey.getExperiment())
            .build();
        AtomicReference<String> path = new AtomicReference<>();
        transactional.execute(datasetContext -> {
            PartitionedFileSet predictionsFileset = datasetContext.getDataset(predictionsDataset);
            PartitionDetail partitionDetail = predictionsFileset.getPartition(predictionsPartitionKey);
            if (partitionDetail == null) {
                // The partition is registered inside the transaction; Spark writes the data into its
                // location afterwards, outside the transaction.
                PartitionOutput partitionOutput = predictionsFileset.getPartitionOutput(predictionsPartitionKey);
                path.set(partitionOutput.getLocation().toURI().getPath());
                partitionOutput.addPartition();
            } else {
                path.set(partitionDetail.getLocation().toURI().getPath());
            }
        });
        return path.get();
    }

    /** Resolves {@code <baseLocation>/<experiment>/<model>/<component>}. */
    private Location getComponentLocation(ModelKey modelKey, String component) throws IOException {
        return baseLocation.append(modelKey.getExperiment()).append(modelKey.getModel()).append(component);
    }

    /**
     * Returns the path to write a component to, first clearing any existing content when overwrite is
     * enabled.
     *
     * @throws IllegalArgumentException if the component exists and overwrite is disabled
     * @throws IOException if the existing component cannot be deleted
     */
    private String getPath(ModelKey modelKey, String component) throws IOException {
        Location location = getComponentLocation(modelKey, component);
        if (location.exists()) {
            if (!overwrite) {
                throw new IllegalArgumentException(location + " already exists.");
            }
            // Spark writes components as directories, so a non-recursive delete would fail on them.
            if (!location.delete(true) && location.exists()) {
                throw new IOException("Failed to delete existing model component at " + location);
            }
        }
        return location.toURI().getPath();
    }

    /** Recursively deletes a single component if it exists; a failed delete is logged, not thrown. */
    private void deleteComponent(ModelKey modelKey, String component) throws IOException {
        Location location = getComponentLocation(modelKey, component);
        if (location.exists() && !location.delete(true) && location.exists()) {
            LOG.warn("Failed to delete model component at {}.", location);
        }
    }
}

