package transformers;

import com.google.common.base.Preconditions;

import org.apache.spark.ml.Model;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.shared.HasInputCols;
import org.apache.spark.ml.util.MLReadable;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

import java.io.IOException;

import transformers.readers.PreProcessingReader;
import transformers.udf.FeaturesPreProcessingTransformerUdf;
import transformers.writers.PreProcessingWriter;

/**
 * Spark ML {@link Model} that applies the {@code absShiftLog10} UDF to each configured input
 * column and assembles the resulting columns into a single vector column with a
 * {@link VectorAssembler}.
 *
 * <p>Configure with {@link #setInputCols(String[])} and then {@link #setFeaturesCol(String)}.
 * The intermediate {@code <column>_preprocessed} columns are dropped after assembly, so the only
 * column added to the input dataset is the features column.
 */
public class PreProcessingTransformer extends Model<PreProcessingTransformer>
    implements MLWritable, MLReadable<PreProcessingTransformer>, HasInputCols {

  private static final long serialVersionUID = 596263430109672895L;

  /** Name under which the pre-processing UDF is registered on the dataset's SQL context. */
  private static final String UDF_NAME = "absShiftLog10";

  /** Suffix appended to an input column name to form its temporary pre-processed column. */
  private static final String PREPROCESSED_SUFFIX = "_preprocessed";

  /** Input-columns param; created lazily by {@link #inputCols()} and then reused. */
  private StringArrayParam inputCols;

  /** Assembles the pre-processed columns into the features column; null until configured. */
  private VectorAssembler vectorTransformer;

  /**
   * Pre-processes every input column with the registered UDF, assembles the results into the
   * features column, and drops the temporary pre-processed columns.
   *
   * @param txs input dataset; must contain every configured input column
   * @return {@code txs} with the features column appended
   * @throws IllegalStateException if {@link #setFeaturesCol(String)} has not been called
   */
  @Override
  public Dataset transform(Dataset txs) {
    VectorAssembler assembler = requireAssembler();
    String[] columnNames = getInputCols();

    // NOTE(review): the UDF is re-registered on every call. The 100.0 argument is passed straight
    // to the UDF constructor; its meaning is defined in FeaturesPreProcessingTransformerUdf.
    txs.sqlContext().udf().register(UDF_NAME, new FeaturesPreProcessingTransformerUdf(100.0),
                                    DataTypes.DoubleType);

    for (String columnName : columnNames) {
      txs = txs.withColumn(columnName + PREPROCESSED_SUFFIX,
          org.apache.spark.sql.functions.callUDF(UDF_NAME, txs.col(columnName)));
    }

    txs = assembler.transform(txs);

    // Remove the intermediate columns so that only the features column is added.
    for (String columnName : columnNames) {
      txs = txs.drop(columnName + PREPROCESSED_SUFFIX);
    }

    return txs;
  }

  /**
   * Returns the schema produced by {@link #transform(Dataset)}: the input schema plus the
   * features column. The temporary pre-processed columns are not part of the result.
   *
   * @param structType input schema
   * @return input schema with the features column appended
   * @throws IllegalStateException if {@link #setFeaturesCol(String)} has not been called
   */
  @Override
  public StructType transformSchema(StructType structType) {
    VectorAssembler assembler = requireAssembler();

    // Describe the intermediate dataset the assembler sees inside transform(). The UDF is
    // registered with DoubleType, so each pre-processed column is a double.
    StructType intermediate = structType;
    for (String columnName : getInputCols()) {
      intermediate =
          intermediate.add(columnName + PREPROCESSED_SUFFIX, DataTypes.DoubleType, true);
    }

    // Let the assembler validate its inputs and describe its output column. Then append only
    // that column, because transform() drops the intermediate ones.
    StructType assembled = assembler.transformSchema(intermediate);
    return structType.add(assembled.apply(assembler.getOutputCol()));
  }

  /**
   * Returns a copy of this transformer with {@code paramMap} applied on top of its params.
   *
   * <p>{@code defaultCopy} is not usable here: it reflectively invokes a {@code (String uid)}
   * constructor, which this class does not declare. The copy's assembler is rebuilt from the
   * copy's own (possibly overridden) input columns.
   *
   * @param paramMap extra params to apply to the copy
   * @return the new instance
   */
  @Override
  public PreProcessingTransformer copy(ParamMap paramMap) {
    PreProcessingTransformer that = new PreProcessingTransformer();
    copyValues(that, paramMap);
    if (vectorTransformer != null && that.isDefined(that.inputCols())) {
      that.setFeaturesCol(getFeaturesCol());
    }
    return that.setParent(parent());
  }

  /**
   * Returns a fixed uid built from {@code serialVersionUID}, so every instance shares it.
   *
   * <p>NOTE(review): Spark generally expects a unique uid per instance. Changing this may affect
   * saved metadata and param ownership, so it is left as-is.
   */
  @Override
  public String uid() {
    return "TxPreProcessingTransformer_" + Long.toString(serialVersionUID);
  }

  /** Returns a writer that persists this transformer. */
  @Override
  public PreProcessingWriter write() {
    return new PreProcessingWriter(this);
  }

  /**
   * Saves this transformer to {@code path}.
   *
   * <p>NOTE(review): this calls {@code saveImpl} directly rather than {@code MLWriter.save}, so
   * it presumably bypasses MLWriter's existing-path/overwrite handling. Confirm this is intended.
   *
   * @param path destination path
   * @throws IOException if writing fails
   */
  @Override
  public void save(String path) throws IOException {
    write().saveImpl(path);
  }

  /** Returns a reader able to load a saved {@code PreProcessingTransformer}. */
  @Override
  public MLReader<PreProcessingTransformer> read() {
    return new PreProcessingReader();
  }

  /**
   * Loads a transformer previously saved at {@code path}.
   *
   * @param path source path
   * @return the loaded transformer
   */
  @Override
  public PreProcessingTransformer load(String path) {
    return read().load(path);
  }

  /**
   * Setter generated for the Scala {@code HasInputCols} trait's {@code inputCols} field. It
   * stores the supplied param as this instance's input-columns param.
   */
  @Override
  public void org$apache$spark$ml$param$shared$HasInputCols$_setter_$inputCols_$eq(
      StringArrayParam stringArrayParam) {
    this.inputCols = stringArrayParam;
  }

  /**
   * Returns the input-columns param.
   *
   * <p>The param is created on first access and then reused, so the field is never null once
   * this method has been called.
   */
  @Override
  public StringArrayParam inputCols() {
    if (inputCols == null) {
      inputCols = new StringArrayParam(this, "inputCols", "Name of columns to be pre-processed");
    }
    return inputCols;
  }

  /**
   * Returns the configured input column names.
   *
   * @throws java.util.NoSuchElementException if the input columns have not been set
   */
  @Override
  public String[] getInputCols() {
    return getOrDefault(inputCols());
  }

  /**
   * Sets the columns to pre-process.
   *
   * <p>If the features column was configured earlier, the assembler is rebuilt so that it reads
   * the pre-processed versions of the new columns.
   *
   * @param value input column names
   * @return this transformer
   */
  public PreProcessingTransformer setInputCols(String[] value) {
    set(inputCols(), value);
    if (vectorTransformer != null && value != null) {
      vectorTransformer = buildAssembler(value, vectorTransformer.getOutputCol());
    }
    return this;
  }

  /**
   * Returns the name of the assembled features column.
   *
   * @throws IllegalStateException if {@link #setFeaturesCol(String)} has not been called
   */
  public String getFeaturesCol() {
    return requireAssembler().getOutputCol();
  }

  /**
   * Sets the name of the output vector column and builds the assembler over the pre-processed
   * input columns.
   *
   * @param value features column name
   * @return this transformer
   * @throws IllegalStateException if no input columns have been set
   */
  public PreProcessingTransformer setFeaturesCol(String value) {
    Preconditions.checkState(
        isDefined(inputCols()) && getInputCols() != null && getInputCols().length > 0,
        "setInputCols must be called with at least one column before setFeaturesCol");

    vectorTransformer = buildAssembler(getInputCols(), value);
    return this;
  }

  /** Returns the configured assembler, failing fast if the features column was never set. */
  private VectorAssembler requireAssembler() {
    Preconditions.checkState(vectorTransformer != null,
        "setFeaturesCol must be called before using this transformer");
    return vectorTransformer;
  }

  /** Builds an assembler over the pre-processed counterparts of {@code columnNames}. */
  private static VectorAssembler buildAssembler(String[] columnNames, String outputCol) {
    String[] preProcessedColumnNames = new String[columnNames.length];
    for (int index = 0; index < columnNames.length; index++) {
      preProcessedColumnNames[index] = columnNames[index] + PREPROCESSED_SUFFIX;
    }
    return new VectorAssembler().setInputCols(preProcessedColumnNames).setOutputCol(outputCol);
  }
}
