/*
 * Decompiled with CFR 0.152.
 */
package io.cdap.mmds.modeler.feature;

import io.cdap.mmds.modeler.feature.FeatureGenerator;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Training-time {@link FeatureGenerator} that builds and fits the Spark ML pipeline which turns the
 * cleaned feature columns into a single assembled feature vector column named {@code _features}.
 */
public class FeatureGeneratorTrainer
extends FeatureGenerator {
    /**
     * Creates a trainer for the given features.
     *
     * @param features names of the features to use; passed to the superclass. The order in which
     *     the superclass exposes them determines the order of entries in the assembled vector.
     * @param categoricalFeatures names of the features to treat as categorical; passed to the
     *     superclass (presumably consulted by {@code isCategorical} — confirm in FeatureGenerator)
     */
    public FeatureGeneratorTrainer(List<String> features, Set<String> categoricalFeatures) {
        super(features, categoricalFeatures);
    }

    /**
     * Builds and fits the feature-generation pipeline on {@code cleanData}.
     *
     * <p>The pipeline has one {@link StringIndexer} per categorical feature. Each indexer reads
     * the feature's clean column and writes its indexed column. The indexers are followed by a
     * {@link VectorAssembler}, which combines the following columns into {@code _features}, in
     * feature order and each exactly once:
     * <ul>
     *   <li>the indexed column of every categorical feature, and</li>
     *   <li>the clean column of every other feature.</li>
     * </ul>
     *
     * @param cleanData dataset expected to contain the clean column of every feature
     * @return the pipeline model fitted on {@code cleanData}
     */
    @Override
    protected PipelineModel getFeatureGenModel(Dataset<Row> cleanData) {
        List<PipelineStage> stages = new ArrayList<>();
        List<String> assemblerInputs = new ArrayList<>();
        for (String feature : this.features) {
            if (this.isCategorical(feature)) {
                String indexedName = this.indexedName(feature);
                // "skip" makes the fitted indexer drop rows with invalid values (e.g. labels not
                // seen during fit) instead of failing on them.
                stages.add(new StringIndexer()
                        .setInputCol(this.cleanName(feature))
                        .setOutputCol(indexedName)
                        .setHandleInvalid("skip"));
                assemblerInputs.add(indexedName);
            } else {
                // Non-categorical features go into the vector as-is, exactly once, so no
                // column is duplicated in the assembled vector.
                assemblerInputs.add(this.cleanName(feature));
            }
        }
        // The assembler must run last, after every indexed column it reads has been produced.
        stages.add(new VectorAssembler()
                .setInputCols(assemblerInputs.toArray(new String[0]))
                .setOutputCol("_features"));
        Pipeline featureGenPipeline = new Pipeline().setStages(stages.toArray(new PipelineStage[0]));
        return featureGenPipeline.fit(cleanData);
    }
}

