package transformers.readers;

import org.apache.hadoop.fs.Path;
import org.apache.spark.ml.util.DefaultParamsReader;
import org.apache.spark.ml.util.DefaultParamsReader$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.sql.Row;

import java.util.List;

import transformers.PreProcessingTransformer;

/**
 * {@link MLReader} that rebuilds a {@link PreProcessingTransformer} from a saved location.
 *
 * <p>Loading combines two sources: the params metadata read through
 * {@code DefaultParamsReader}, and a parquet dataset under {@code <path>/data} from which the
 * {@code inputCols} and {@code featuresCol} values are taken.
 */
public class PreProcessingReader extends MLReader<PreProcessingTransformer> {

  /** Class name handed to {@code loadMetadata}; defaults to the transformer's own class name. */
  private String className = PreProcessingTransformer.class.getName();

  /**
   * Returns the class name passed to the metadata loader.
   *
   * @return the currently configured class name
   */
  public String getClassName() {
    return className;
  }

  /**
   * Overrides the class name passed to the metadata loader.
   *
   * @param className class name to use on subsequent {@link #load(String)} calls
   */
  public void setClassName(String className) {
    this.className = className;
  }

  /**
   * Loads a {@link PreProcessingTransformer} from {@code path}.
   *
   * @param path root location of the saved transformer; the column data is read from the
   *     {@code data} child of this location
   * @return a transformer with its input columns, features column and params restored
   */
  @Override
  public PreProcessingTransformer load(String path) {
    DefaultParamsReader.Metadata metadata =
        DefaultParamsReader$.MODULE$.loadMetadata(path, sc(), className);

    String dataPath = new Path(path, "data").toString();

    // Each column name must be its own argument: a single "inputCols, featuresCol" string is
    // treated as ONE column name, which does not exist and would leave no index 1 to read below.
    Row row = sparkSession().read().parquet(dataPath).select("inputCols", "featuresCol").head();

    // NOTE(review): assumes inputCols was written as an array of strings and featuresCol as a
    // string; confirm against the matching writer.
    List<String> inputColList = row.getList(0);
    String[] inputCols = inputColList.toArray(new String[0]);
    String featuresCol = row.getString(1);

    PreProcessingTransformer transformer =
        new PreProcessingTransformer().setInputCols(inputCols).setFeaturesCol(featuresCol);

    // Applied after the explicit setters, same order as before.
    DefaultParamsReader$.MODULE$.getAndSetParams(transformer, metadata);

    return transformer;
  }
}
