package writers;

import org.apache.hadoop.fs.Path;
import org.apache.spark.ml.util.MLWriter;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import models.IndependentGaussianModel;

/**
 * {@link MLWriter} that persists the parameters of an {@link IndependentGaussianModel}
 * as a single-row Parquet dataset.
 *
 * <p>The row is written to the {@code data} child of the path passed to
 * {@link #saveImpl(String)}. Its fields are the ones exposed by the nested {@link Data} bean.
 */
public class IndependentGaussianWriter extends MLWriter {

  /** Name of the child directory, under the save path, that receives the Parquet output. */
  private static final String DATA_DIR = "data";

  /** Model whose parameters are written out. */
  private final IndependentGaussianModel instance;

  /**
   * Creates a writer for the given model.
   *
   * @param instance the model to persist; must not be {@code null}
   * @throws NullPointerException if {@code instance} is {@code null}
   */
  public IndependentGaussianWriter(IndependentGaussianModel instance) {
    // Fail here rather than with an anonymous NPE later, in the middle of a save.
    this.instance = Objects.requireNonNull(instance, "instance must not be null");
  }

  /**
   * Writes the model parameters as one Parquet row under {@code <path>/data}.
   *
   * @param path base directory of the saved model
   */
  @Override
  public void saveImpl(String path) {
    List<Data> rows = new ArrayList<>();
    rows.add(snapshot());

    String dataPath = new Path(path, DATA_DIR).toString();

    // Data.class is handed to Spark so that the row schema is derived from the bean.
    sparkSession()
        .createDataFrame(rows, Data.class)
        .repartition(1)
        .write()
        .parquet(dataPath);
  }

  /** Copies the values read from the model's getters into a fresh {@link Data} bean. */
  private Data snapshot() {
    Data data = new Data();
    data.setInputColName(instance.getInputColName());
    data.setOutProbColumnName(instance.getOutProbColumnName());
    data.setOutDecisionColumnName(instance.getOutDecisionColumnName());
    data.setMeanValues(instance.getMeanValues());
    data.setStdValues(instance.getStdValues());
    data.setDecisionThreshold(instance.getDecisionThreshold());
    return data;
  }

  /**
   * Bean holding the values persisted for one model.
   *
   * <p>This class is passed to {@code createDataFrame} as the bean class, so the implicit
   * public no-arg constructor and the getter/setter pairs below should be kept as they are.
   * Arrays are stored and returned by reference, without copying.
   */
  public static class Data implements Serializable {

    private static final long serialVersionUID = 8249297463275856588L;

    private String inputColName;
    private String outProbColumnName;
    private String outDecisionColumnName;

    private double[] meanValues;
    private double[] stdValues;

    private double decisionThreshold;

    /** @return the stored input column name */
    public String getInputColName() {
      return inputColName;
    }

    /** @param inputColName the input column name to store */
    public void setInputColName(String inputColName) {
      this.inputColName = inputColName;
    }

    /** @return the stored output probability column name */
    public String getOutProbColumnName() {
      return outProbColumnName;
    }

    /** @param outProbColumnName the output probability column name to store */
    public void setOutProbColumnName(String outProbColumnName) {
      this.outProbColumnName = outProbColumnName;
    }

    /** @return the stored output decision column name */
    public String getOutDecisionColumnName() {
      return outDecisionColumnName;
    }

    /** @param outDecisionColumnName the output decision column name to store */
    public void setOutDecisionColumnName(String outDecisionColumnName) {
      this.outDecisionColumnName = outDecisionColumnName;
    }

    /** @return the stored mean values (the same array instance that was set) */
    public double[] getMeanValues() {
      return meanValues;
    }

    /** @param meanValues the mean values to store (kept by reference) */
    public void setMeanValues(double[] meanValues) {
      this.meanValues = meanValues;
    }

    /** @return the stored standard-deviation values (the same array instance that was set) */
    public double[] getStdValues() {
      return stdValues;
    }

    /** @param stdValues the standard-deviation values to store (kept by reference) */
    public void setStdValues(double[] stdValues) {
      this.stdValues = stdValues;
    }

    /** @return the stored decision threshold */
    public double getDecisionThreshold() {
      return decisionThreshold;
    }

    /** @param decisionThreshold the decision threshold to store */
    public void setDecisionThreshold(double decisionThreshold) {
      this.decisionThreshold = decisionThreshold;
    }
  }
}
