package transformers.writers;

import org.apache.hadoop.fs.Path;
import org.apache.spark.ml.util.DefaultParamsWriter;
import org.apache.spark.ml.util.MLWriter;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import transformers.PreProcessingTransformer;

/**
 * {@link MLWriter} that persists a {@link PreProcessingTransformer}.
 *
 * <p>{@link #saveImpl(String)} first writes the transformer's metadata through
 * {@code DefaultParamsWriter.saveMetadata}, then writes a single-row Parquet dataset to the
 * {@code data} sub-directory of the target path. That row holds the transformer's input columns
 * and features column.
 */
public class PreProcessingWriter extends MLWriter {

  /** Transformer that {@link #saveImpl(String)} persists; replaceable via the setter. */
  private PreProcessingTransformer instance;

  /**
   * Creates a writer for the given transformer.
   *
   * @param instance the transformer to persist (not null-checked here)
   */
  public PreProcessingWriter(PreProcessingTransformer instance) {
    this.instance = instance;
  }

  /**
   * Returns the transformer this writer currently persists.
   *
   * @return the current transformer
   */
  public PreProcessingTransformer getInstance() {
    return this.instance;
  }

  /**
   * Replaces the transformer this writer persists.
   *
   * @param instance the new transformer
   */
  public void setInstance(PreProcessingTransformer instance) {
    this.instance = instance;
  }

  /**
   * Writes the transformer to {@code path}: metadata first, then its column configuration as
   * Parquet under {@code path/data}.
   *
   * @param path destination directory
   */
  @Override
  public void saveImpl(String path) {
    // NOTE(review): the two trailing arguments are the default-argument accessors of
    // getMetadataToSave (its 3rd and 4th parameters), reused for saveMetadata's optional
    // trailing arguments. Presumably both are None -- confirm against the Spark version in use.
    DefaultParamsWriter.saveMetadata(
        this.instance,
        path,
        sc(),
        DefaultParamsWriter.getMetadataToSave$default$3(),
        DefaultParamsWriter.getMetadataToSave$default$4());

    List<Data> rows = new ArrayList<>(1);
    rows.add(toData(this.instance));
    String dataDir = new Path(path, "data").toString();

    // Exactly one row, repartitioned into a single partition before the Parquet write.
    sparkSession()
        .createDataFrame(rows, Data.class)
        .repartition(1)
        .write()
        .parquet(dataDir);
  }

  /**
   * Builds a fresh {@link Data} bean from the transformer's input-columns and features-column
   * getters. The input-columns array reference is shared with the transformer, not copied.
   */
  private static Data toData(PreProcessingTransformer source) {
    Data row = new Data();
    row.setInputCols(source.getInputCols());
    row.setFeaturesCol(source.getFeaturesCol());
    return row;
  }

  /**
   * Serializable JavaBean holding the column configuration written by
   * {@link PreProcessingWriter#saveImpl(String)}. It is passed to {@code createDataFrame} as the
   * bean class.
   */
  public static class Data implements Serializable {

    private static final long serialVersionUID = -7753295698381203425L;

    /** Input column names; stored and returned without copying. */
    String[] inputCols;

    /** Name of the features column. */
    String featuresCol;

    /** @return the stored input column names (same array instance that was set) */
    public String[] getInputCols() {
      return this.inputCols;
    }

    /** @param inputCols input column names to store (kept by reference) */
    public void setInputCols(String[] inputCols) {
      this.inputCols = inputCols;
    }

    /** @return the stored features column name */
    public String getFeaturesCol() {
      return this.featuresCol;
    }

    /** @param featuresCol features column name to store */
    public void setFeaturesCol(String featuresCol) {
      this.featuresCol = featuresCol;
    }
  }
}
