package trainers;

import java.util.Objects;
import java.util.UUID;

import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.types.StructType;

import models.IndependentGaussianModel;

/**
 * Spark ML {@link Estimator} that produces an {@link IndependentGaussianModel}.
 *
 * <p>Fitting runs a {@link StandardScaler} over {@code inputCol} to obtain per-dimension means and
 * standard deviations. It passes them, with the configured decision threshold, to a new
 * {@link IndependentGaussianModel}.
 */
public class IndependentGaussianTrainer extends Estimator<IndependentGaussianModel> {

  private static final long serialVersionUID = 7979905178183125955L;

  /** Prefix of every generated uid. */
  private static final String UID_PREFIX = "IndependentGaussianTrainer";

  /** Output column configured on the scaler; only its fitted statistics are used. */
  private static final String SCALED_COL = "scaledFeatures";

  // Column names handed to the fitted model. Presumably its probability and label output
  // columns -- NOTE(review): confirm against IndependentGaussianModel.
  private static final String PROBABILITY_COL = "anomalous_prob";
  private static final String LABEL_COL = "anomalous_label";

  /** Unique id of this stage; shared by copies made through {@link #copy(ParamMap)}. */
  private final String uid;

  final String inputCol;
  final double decisionThreshold;

  /**
   * Creates a trainer with a freshly generated uid.
   *
   * @param inputCol name of the feature column to fit on; must not be {@code null}
   *     (assumed to be a Vector column, as StandardScaler requires -- TODO confirm at call sites)
   * @param decisionThreshold threshold forwarded unchanged to the fitted model
   * @throws NullPointerException if {@code inputCol} is {@code null}
   */
  public IndependentGaussianTrainer(final String inputCol, final double decisionThreshold) {
    this(randomUid(), inputCol, decisionThreshold);
  }

  /** Creates a trainer with an explicit uid; used by {@link #copy(ParamMap)}. */
  private IndependentGaussianTrainer(
      final String uid, final String inputCol, final double decisionThreshold) {
    this.uid = Objects.requireNonNull(uid, "uid");
    this.inputCol = Objects.requireNonNull(inputCol, "inputCol");
    this.decisionThreshold = decisionThreshold;
  }

  /** Builds a uid of the form {@code IndependentGaussianTrainer_<last 12 chars of a UUID>}. */
  private static String randomUid() {
    final String uuid = UUID.randomUUID().toString();
    return UID_PREFIX + "_" + uuid.substring(uuid.length() - 12);
  }

  /**
   * Fits the per-dimension statistics of {@code inputCol} and wraps them in a model.
   *
   * @param dataset training data containing {@code inputCol}
   * @return a model built from the column's means, standard deviations and the decision threshold
   */
  @Override
  public IndependentGaussianModel fit(final Dataset<?> dataset) {
    // StandardScaler.fit only computes statistics. The scaler model is never used to transform
    // data, so the SCALED_COL output column is not materialized here.
    final StandardScalerModel scalerModel =
        new StandardScaler()
            .setInputCol(inputCol)
            .setOutputCol(SCALED_COL)
            .setWithMean(true)
            .setWithStd(true)
            .fit(dataset);

    return new IndependentGaussianModel(
        inputCol,
        PROBABILITY_COL,
        LABEL_COL,
        scalerModel.mean().toArray(),
        scalerModel.std().toArray(),
        decisionThreshold);
  }

  /**
   * Returns the input schema unchanged.
   *
   * <p>NOTE(review): this neither checks that {@code inputCol} exists nor declares any columns the
   * fitted model may add. Confirm whether pipelines rely on that.
   */
  @Override
  public StructType transformSchema(final StructType structType) {
    return structType;
  }

  /**
   * Returns a copy of this trainer with the same uid, input column and decision threshold.
   *
   * <p>This class declares no Spark {@code Param}s, so {@code paramMap} has nothing to apply.
   */
  @Override
  public IndependentGaussianTrainer copy(final ParamMap paramMap) {
    return new IndependentGaussianTrainer(uid, inputCol, decisionThreshold);
  }

  /** Returns this stage's unique id, which is fixed at construction. */
  @Override
  public String uid() {
    return uid;
  }
}
