package models;

import org.apache.spark.ml.Model;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.util.MLReadable;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.io.IOException;

import transformers.udf.IndependentGaussianDecisionUdf;
import transformers.udf.IndependentGaussianUdf;
import readers.IndependentGaussianReader;
import writers.IndependentGaussianWriter;

import static org.apache.spark.sql.functions.callUDF;

/**
 * Spark ML {@link Model} that appends two columns to a dataset.
 *
 * <ul>
 *   <li>A probability column ({@code outProbColumnName}, double) is computed by
 *       {@link IndependentGaussianUdf} from {@code inputColName}.
 *   <li>A decision column ({@code outDecisionColumnName}, integer) is computed by
 *       {@link IndependentGaussianDecisionUdf} from that probability column.
 * </ul>
 *
 * <p>The mean/std arrays and the decision threshold are fixed at construction time and handed to
 * those UDFs; the actual scoring and thresholding rules live in the UDF classes. Persistence is
 * delegated to {@link IndependentGaussianWriter} and {@link IndependentGaussianReader}.
 *
 * <p>NOTE(review): {@code meanValues} and {@code stdValues} presumably hold one entry per
 * component of the input column — confirm against {@link IndependentGaussianUdf}.
 */
public class IndependentGaussianModel extends Model<IndependentGaussianModel>
    implements MLWritable, MLReadable<IndependentGaussianModel> {

  private static final long serialVersionUID = -2615582624462629788L;

  /** Name under which the probability UDF is registered in the dataset's session. */
  private static final String PROB_UDF_NAME = "independentGaussianProb";

  /** Name under which the decision UDF is registered in the dataset's session. */
  private static final String DECISION_UDF_NAME = "independentGaussianDecision";

  final String inputColName;
  final String outProbColumnName;
  final String outDecisionColumnName;

  final double[] meanValues;
  final double[] stdValues;

  final double decisionThreshold;

  /**
   * Creates a fitted model.
   *
   * @param inputColName column the probability UDF reads
   * @param outProbColumnName column that receives the probability (double)
   * @param outDecisionColumnName column that receives the decision (integer)
   * @param meanValues mean values passed to {@link IndependentGaussianUdf}; copied
   * @param stdValues standard-deviation values passed to {@link IndependentGaussianUdf}; copied
   * @param decisionThreshold threshold passed to {@link IndependentGaussianDecisionUdf}
   */
  public IndependentGaussianModel(final String inputColName, final String outProbColumnName,
      final String outDecisionColumnName,
      final double[] meanValues, final double[] stdValues,
      final double decisionThreshold) {

    this.inputColName = inputColName;
    this.outProbColumnName = outProbColumnName;
    this.outDecisionColumnName = outDecisionColumnName;

    // Defensive copies: a caller mutating its arrays afterwards must not change the model.
    this.meanValues = copyOf(meanValues);
    this.stdValues = copyOf(stdValues);

    this.decisionThreshold = decisionThreshold;
  }

  /** Returns a copy of {@code values}, or {@code null} when {@code values} is {@code null}. */
  private static double[] copyOf(final double[] values) {
    return values == null ? null : values.clone();
  }

  /** Returns the name of the column the probability UDF reads. */
  public String getInputColName() {
    return inputColName;
  }

  /** Returns the name of the output probability column. */
  public String getOutProbColumnName() {
    return outProbColumnName;
  }

  /** Returns the name of the output decision column. */
  public String getOutDecisionColumnName() {
    return outDecisionColumnName;
  }

  /** Returns a copy of the mean values; mutating it does not affect this model. */
  public double[] getMeanValues() {
    return copyOf(meanValues);
  }

  /** Returns a copy of the standard-deviation values; mutating it does not affect this model. */
  public double[] getStdValues() {
    return copyOf(stdValues);
  }

  /** Returns the threshold handed to {@link IndependentGaussianDecisionUdf}. */
  public double getDecisionThreshold() {
    return decisionThreshold;
  }

  /**
   * Adds the probability and decision columns to {@code dataset}.
   *
   * <p>Both UDFs are registered by fixed name in the dataset's session UDF registry before use.
   * Another model transforming in the same session therefore re-registers (presumably replacing)
   * them.
   *
   * @param dataset input data; must contain {@code inputColName}
   * @return {@code dataset} with {@code outProbColumnName} and {@code outDecisionColumnName}
   *     added (or replaced, per {@code withColumn} semantics)
   */
  @Override
  public Dataset<Row> transform(final Dataset<?> dataset) {
    dataset.sqlContext().udf().register(PROB_UDF_NAME,
        new IndependentGaussianUdf(meanValues, stdValues), DataTypes.DoubleType);
    dataset.sqlContext().udf().register(DECISION_UDF_NAME,
        new IndependentGaussianDecisionUdf(decisionThreshold), DataTypes.IntegerType);

    final Dataset<Row> withProb = dataset.withColumn(outProbColumnName,
        callUDF(PROB_UDF_NAME, dataset.col(inputColName)));

    // The decision is derived from the freshly added probability column.
    return withProb.withColumn(outDecisionColumnName,
        callUDF(DECISION_UDF_NAME, withProb.col(outProbColumnName)));
  }

  /**
   * Derives the output schema of {@link #transform}: the input schema plus a nullable double
   * probability column and a nullable integer decision column. Spark UDF results are generally
   * declared nullable.
   *
   * @param structType schema of the input dataset
   * @return schema of the dataset {@link #transform} would produce
   */
  @Override
  public StructType transformSchema(final StructType structType) {
    final StructType withProb = withColumnField(structType,
        DataTypes.createStructField(outProbColumnName, DataTypes.DoubleType, true));
    return withColumnField(withProb,
        DataTypes.createStructField(outDecisionColumnName, DataTypes.IntegerType, true));
  }

  /**
   * Mirrors {@code Dataset.withColumn} at the schema level. Fields whose name equals
   * {@code field.name()} (exact, case-sensitive match) are replaced in place; otherwise
   * {@code field} is appended.
   */
  private static StructType withColumnField(final StructType schema, final StructField field) {
    final StructField[] fields = schema.fields().clone();
    boolean replaced = false;
    for (int i = 0; i < fields.length; i++) {
      if (fields[i].name().equals(field.name())) {
        fields[i] = field;
        replaced = true;
      }
    }
    return replaced ? new StructType(fields) : schema.add(field);
  }

  /**
   * Returns a copy of this model with the same columns, parameters and parent estimator.
   *
   * <p>Spark callers such as {@code PipelineModel.copy} expect a non-null instance here.
   *
   * @param paramMap extra params to apply to the copy
   * @return a new, equivalent model
   */
  @Override
  public IndependentGaussianModel copy(final ParamMap paramMap) {
    final IndependentGaussianModel copied = new IndependentGaussianModel(inputColName,
        outProbColumnName, outDecisionColumnName, meanValues, stdValues, decisionThreshold);
    return copyValues(copied, paramMap).setParent(parent());
  }

  /**
   * Returns this model's identifier.
   *
   * <p>NOTE(review): the value is derived only from {@code serialVersionUID}, so every instance
   * shares the same uid. It is kept unchanged for persistence compatibility.
   */
  @Override
  public String uid() {
    return "IndependentGaussianModel" + serialVersionUID;
  }

  /** Returns a reader that loads models of this type. */
  @Override
  public MLReader<IndependentGaussianModel> read() {
    return new IndependentGaussianReader();
  }

  /** Loads a model from {@code path} using {@link #read()}. */
  @Override
  public IndependentGaussianModel load(final String path) {
    return read().load(path);
  }

  /** Returns a writer that persists this model. */
  @Override
  public IndependentGaussianWriter write() {
    return new IndependentGaussianWriter(this);
  }

  /**
   * Saves this model to {@code path}, replacing anything already stored there.
   *
   * @param path destination directory
   * @throws IOException if writing fails
   */
  @Override
  public void save(final String path) throws IOException {
    // Use MLWriter.save, not saveImpl. save() honours overwrite() by clearing an existing path
    // before delegating to saveImpl; calling saveImpl directly silently ignored overwrite().
    write().overwrite().save(path);
  }
}
