/*
 * Decompiled with CFR 0.152.
 */
package science.aist.machinelearning.algorithm.nn;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.file.FileSystem;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.ToDoubleFunction;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import science.aist.machinelearning.core.AbstractAlgorithm;
import science.aist.machinelearning.core.Gene;
import science.aist.machinelearning.core.Problem;
import science.aist.machinelearning.core.ProblemGene;
import science.aist.machinelearning.core.Solution;
import science.aist.machinelearning.core.SolutionGene;
import science.aist.machinelearning.core.options.Descriptor;

/**
 * Feed-forward neural network algorithm backed by a DL4J {@link MultiLayerNetwork}.
 *
 * <p>The network topology (hidden layer sizes) is fixed at construction time; the sizes of the
 * input and output layer are derived from the first problem/solution pair handed to
 * {@link #train(List, List)}. The underlying DL4J network is built lazily on the first
 * {@code train} call, so configuration setters that act on the builder, the activations or the
 * loss function only influence the network when they are called before that first training.
 *
 * <p>Instances obtained through one of the {@code load} methods carry an already built network
 * but no configuration builder; calling the builder-based setters (for example
 * {@link #setIterations(int)}) on such an instance fails with a {@link NullPointerException}.
 *
 * @param <GT> gene type of the solutions (network output)
 * @param <PT> gene type of the problems (network input)
 */
public class NeuralNetwork<GT, PT>
extends AbstractAlgorithm<GT, PT> {
    /** Seed handed to the DL4J configuration builder by the public constructor. */
    private static final long SEED = 65738384L;
    /** Default iteration count handed to the DL4J configuration builder. */
    private static final int DEFAULT_ITERATIONS = 100;
    /** Default learning rate handed to the DL4J configuration builder. */
    private static final double DEFAULT_LEARNING_RATE = 0.03;
    /** Name of the zip entry that stores the transformers and the epoch count next to the DL4J model. */
    private static final String SERIALIZED_STATE_ENTRY = "neuralNetwork.obj";

    private final transient Logger log = LoggerFactory.getLogger(getClass());
    private final Map<String, Descriptor> specificOptions;
    private int epochs = 10;
    private boolean trained = false;
    private boolean init = false;
    private MultiLayerNetwork neuralNetwork = null;
    private NeuralNetConfiguration.Builder builder;
    private int[] hiddenLayers;
    private LossFunctions.LossFunction outputLayerLossFunction = LossFunctions.LossFunction.MSE;
    private Activation[] activationForLayers;
    private Consumer<NeuralNetConfiguration.Builder> builderConsumer;
    private Consumer<NeuralNetConfiguration.ListBuilder> listBuilderConsumer;
    private Consumer<MultiLayerNetwork> multiLayerNetworkConsumer;
    private ToDoubleFunction<PT> problemToDoubleTransformer;
    private ToDoubleFunction<GT> solutionToDoubleTransformer;
    private Function<Double, GT> doubleToSolutionTransformer;
    // NOTE(review): never read or written anywhere in this class; kept to leave the field layout untouched.
    private int updateFrequency;

    /**
     * Creates an untrained network with the given hidden layer sizes.
     *
     * <p>All hidden layers start with {@link Activation#SIGMOID}, the output layer with
     * {@link Activation#IDENTITY}.
     *
     * @param hiddenLayers number of neurons per hidden layer, in order; must contain at least one entry
     * @throws IllegalArgumentException if {@code hiddenLayers} is {@code null} or empty
     */
    public NeuralNetwork(int[] hiddenLayers) {
        this();
        if (hiddenLayers == null || hiddenLayers.length == 0) {
            throw new IllegalArgumentException("At least one hidden Layer is needed");
        }
        // Defensive copy: the topology must not change when the caller reuses its array.
        this.hiddenLayers = hiddenLayers.clone();
        this.builder = new NeuralNetConfiguration.Builder();
        // One activation per hidden layer plus one for the output layer.
        this.activationForLayers = new Activation[hiddenLayers.length + 1];
        for (int i = 0; i < this.activationForLayers.length - 1; ++i) {
            this.activationForLayers[i] = Activation.SIGMOID;
        }
        this.activationForLayers[this.activationForLayers.length - 1] = Activation.IDENTITY;
        this.builder.seed(SEED).iterations(DEFAULT_ITERATIONS).learningRate(DEFAULT_LEARNING_RATE);
    }

    /** Used by {@code load}: creates an instance without builder, topology or activations. */
    private NeuralNetwork() {
        this.specificOptions = new HashMap<String, Descriptor>();
    }

    /**
     * Restores a network that was written with {@link #save(String)}.
     *
     * <p>SECURITY: the transformers are restored with Java-native deserialization
     * ({@link ObjectInputStream#readObject()}). Only load files from trusted sources.
     *
     * @param path                      path of the zip file produced by {@code save}
     * @param multiLayerNetworkConsumer callback that receives the restored DL4J network; may be
     *                                  {@code null} if no callback is needed
     * @param <GT>                      gene type of the solutions
     * @param <PT>                      gene type of the problems
     * @return the restored network, already flagged as trained
     * @throws RuntimeException wrapping the {@link IOException} or {@link ClassNotFoundException}
     *                          raised while reading the file
     */
    @SuppressWarnings("unchecked") // the stored objects were written by save() with exactly these types
    public static <GT, PT> NeuralNetwork<GT, PT> load(String path, Consumer<MultiLayerNetwork> multiLayerNetworkConsumer) {
        URI uri = URI.create("jar:" + Paths.get(path).toUri());
        try (FileSystem fs = FileSystems.newFileSystem(uri, new HashMap<String, Object>());
             ObjectInputStream reader = new ObjectInputStream(Files.newInputStream(fs.getPath(SERIALIZED_STATE_ENTRY)))) {
            NeuralNetwork<GT, PT> nn = new NeuralNetwork<>();
            // Read order must mirror the write order in save().
            nn.setProblemToDoubleTransformer((ToDoubleFunction<PT>) reader.readObject());
            nn.setSolutionToDoubleTransformer((ToDoubleFunction<GT>) reader.readObject());
            nn.setDoubleToSolutionTransformer((Function<Double, GT>) reader.readObject());
            nn.setEpochs(reader.readInt());
            nn.neuralNetwork = ModelSerializer.restoreMultiLayerNetwork(path);
            nn.setTrained();
            if (multiLayerNetworkConsumer != null) {
                multiLayerNetworkConsumer.accept(nn.neuralNetwork);
            }
            return nn;
        }
        catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Restores a network from a stream without any post-processing callback.
     *
     * @param is   stream containing the zip data produced by {@link #save(String)}
     * @param <GT> gene type of the solutions
     * @param <PT> gene type of the problems
     * @return the restored network
     * @see #load(InputStream, Consumer)
     */
    public static <GT, PT> NeuralNetwork<GT, PT> load(InputStream is) {
        return NeuralNetwork.load(is, (MultiLayerNetwork n) -> {});
    }

    /**
     * Restores a network from a stream by spooling it into a temporary zip file and delegating to
     * {@link #load(String, Consumer)}. The temporary file is deleted afterwards; a failed deletion
     * is only logged.
     *
     * @param is                        stream containing the zip data produced by {@link #save(String)}
     * @param multiLayerNetworkConsumer callback that receives the restored DL4J network
     * @param <GT>                      gene type of the solutions
     * @param <PT>                      gene type of the problems
     * @return the restored network
     * @throws RuntimeException wrapping the {@link IOException} raised while creating or filling
     *                          the temporary file
     */
    public static <GT, PT> NeuralNetwork<GT, PT> load(InputStream is, Consumer<MultiLayerNetwork> multiLayerNetworkConsumer) {
        File uniqueFile = null;
        try {
            uniqueFile = File.createTempFile(UUID.randomUUID().toString(), ".zip");
            FileUtils.copyInputStreamToFile(is, uniqueFile);
            return NeuralNetwork.load(uniqueFile.getPath(), multiLayerNetworkConsumer);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        finally {
            if (uniqueFile != null && !uniqueFile.delete()) {
                LoggerFactory.getLogger(NeuralNetwork.class).warn("Could not delete temp file: {}", uniqueFile.getAbsolutePath());
            }
        }
    }

    /**
     * Restores a network from a file without any post-processing callback.
     *
     * @param path path of the zip file produced by {@link #save(String)}
     * @param <GT> gene type of the solutions
     * @param <PT> gene type of the problems
     * @return the restored network
     * @see #load(String, Consumer)
     */
    public static <GT, PT> NeuralNetwork<GT, PT> load(String path) {
        return NeuralNetwork.load(path, (MultiLayerNetwork n) -> {});
    }

    /**
     * Trains the network on the given samples, building it first if this is the first training.
     *
     * <p>The complete data set is fitted once per configured epoch.
     *
     * @param input  problems; {@code input.get(i)} belongs to {@code output.get(i)}
     * @param output expected solutions, same size as {@code input}
     * @return this instance
     * @throws IllegalArgumentException if a list is {@code null} or empty or the sizes differ
     * @throws IllegalStateException    if the problem or solution transformer has not been set
     */
    public NeuralNetwork<GT, PT> train(List<Problem<PT>> input, List<Solution<GT, PT>> output) {
        if (input == null || output == null || input.isEmpty() || output.isEmpty() || input.size() != output.size()) {
            throw new IllegalArgumentException("Input and Output must have the same size and non empty!");
        }
        // Fail before the network is built so a misconfigured instance is not left half-initialized.
        requireConfigured(this.problemToDoubleTransformer, "problemToDoubleTransformer");
        requireConfigured(this.solutionToDoubleTransformer, "solutionToDoubleTransformer");
        this.init(input, output);
        org.nd4j.linalg.dataset.DataSet dataSet = this.createDataSet(input, output);
        for (int i = 0; i < this.epochs; ++i) {
            this.neuralNetwork.fit((DataSet) dataSet);
        }
        this.setTrained();
        return this;
    }

    /**
     * Builds and initializes the DL4J network on first use; does nothing once the network exists.
     * Layer sizes of the input and output layer are taken from the first sample.
     */
    private void init(List<Problem<PT>> input, List<Solution<GT, PT>> output) {
        if (this.isInit()) {
            return;
        }
        int inputSize = input.get(0).getProblemGenes().size();
        int outputSize = output.get(0).getSolutionGenes().size();
        if (this.builderConsumer != null) {
            this.builderConsumer.accept(this.builder);
        }
        NeuralNetConfiguration.ListBuilder lb = this.builder.list();
        for (int i = 0; i < this.hiddenLayers.length; ++i) {
            // The first hidden layer is fed by the input, every other one by its predecessor.
            int nIn = i == 0 ? inputSize : this.hiddenLayers[i - 1];
            lb.layer(i, new DenseLayer.Builder()
                    .nIn(nIn)
                    .nOut(this.hiddenLayers[i])
                    .activation(this.activationForLayers[i])
                    .build());
        }
        // The output layer always uses the LAST configured activation.
        lb.layer(this.hiddenLayers.length, new OutputLayer.Builder(this.outputLayerLossFunction)
                .activation(this.activationForLayers[this.activationForLayers.length - 1])
                .nIn(this.hiddenLayers[this.hiddenLayers.length - 1])
                .nOut(outputSize)
                .build());
        lb.pretrain(false).backprop(true);
        if (this.listBuilderConsumer != null) {
            this.listBuilderConsumer.accept(lb);
        }
        this.neuralNetwork = new MultiLayerNetwork(lb.build());
        this.neuralNetwork.init();
        if (this.multiLayerNetworkConsumer != null) {
            this.multiLayerNetworkConsumer.accept(this.neuralNetwork);
        }
        this.setInit();
    }

    /**
     * Converts the samples into one DL4J data set: one row per sample, one column per gene.
     * The column counts are taken from the first sample.
     */
    private org.nd4j.linalg.dataset.DataSet createDataSet(List<Problem<PT>> input, List<Solution<GT, PT>> output) {
        int sampleCount = input.size();
        int inputSize = input.get(0).getProblemGenes().size();
        int outputSize = output.get(0).getSolutionGenes().size();
        INDArray inputData = Nd4j.zeros(sampleCount, inputSize);
        INDArray outputData = Nd4j.zeros(sampleCount, outputSize);
        for (int row = 0; row < sampleCount; ++row) {
            Problem<PT> problem = input.get(row);
            for (int col = 0; col < inputSize; ++col) {
                PT gene = problem.getProblemGenes().get(col).getGene();
                inputData.putScalar(new int[]{row, col}, this.problemToDoubleTransformer.applyAsDouble(gene));
            }
            Solution<GT, PT> solution = output.get(row);
            for (int col = 0; col < outputSize; ++col) {
                GT gene = solution.getSolutionGenes().get(col).getGene();
                outputData.putScalar(new int[]{row, col}, this.solutionToDoubleTransformer.applyAsDouble(gene));
            }
        }
        return new org.nd4j.linalg.dataset.DataSet(inputData, outputData);
    }

    /**
     * Returns the live map of options recorded by the setters of this class.
     *
     * @return option name to descriptor of the last value set
     */
    protected Map<String, Descriptor> getSpecificOptions() {
        return this.specificOptions;
    }

    /**
     * Applies an option by reflectively calling the setter {@code set<Name>} declared on the
     * runtime class of this object with the descriptor's value.
     *
     * <p>Only single-parameter methods are considered, and every matching overload is tried until
     * one accepts the value, so the outcome does not depend on the order in which reflection
     * returns the methods.
     *
     * @param name       option name, e.g. {@code "epochs"} for {@code setEpochs}
     * @param descriptor descriptor whose value is passed to the setter
     * @return {@code true} if a setter accepted the value, {@code false} otherwise (the reason is logged)
     */
    protected boolean setSpecificOption(String name, Descriptor descriptor) {
        if (name == null || name.isEmpty()) {
            this.log.error("Specific options does not exist");
            return false;
        }
        // Character.toUpperCase is locale independent ("iterations" must not become "İterations").
        String setterFunctionName = "set" + Character.toUpperCase(name.charAt(0)) + name.substring(1);
        Object value = descriptor.getValue();
        IllegalArgumentException typeMismatch = null;
        for (Method method : getClass().getDeclaredMethods()) {
            if (method.getParameterCount() != 1 || !method.getName().equals(setterFunctionName)) {
                continue;
            }
            try {
                method.invoke(this, new Object[]{value});
                return true;
            }
            catch (IllegalAccessException | InvocationTargetException e) {
                this.log.error("Failed to call setter method", e);
                return false;
            }
            catch (IllegalArgumentException e) {
                // Wrong parameter type for this overload; another overload may still fit.
                typeMismatch = e;
            }
        }
        if (typeMismatch != null) {
            this.log.error("Illegal Type for this option", typeMismatch);
        } else {
            this.log.error("Specific options does not exist");
        }
        return false;
    }

    /** Records the value of an option so that it shows up in {@link #getSpecificOptions()}. */
    @SuppressWarnings({"rawtypes", "unchecked"}) // the option map is declared with the raw Descriptor type
    private void registerOption(String key, Object value) {
        this.specificOptions.put(key, new Descriptor(value));
    }

    /** Throws if a transformer that is needed for the current operation has not been set. */
    private static void requireConfigured(Object transformer, String name) {
        if (transformer == null) {
            throw new IllegalStateException(name + " has to be set before the network can be used");
        }
    }

    /**
     * Sets how many times the data set is fitted per {@link #train(List, List)} call.
     *
     * @param epochs number of passes over the data set
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setEpochs(int epochs) {
        this.epochs = epochs;
        registerOption("epochs", epochs);
        return this;
    }

    /**
     * Sets a callback that may customize the configuration builder right before the layers are
     * added during the first training.
     *
     * @param builderConsumer callback, or {@code null} for none
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setBuilderConsumer(Consumer<NeuralNetConfiguration.Builder> builderConsumer) {
        this.builderConsumer = builderConsumer;
        registerOption("builderConsumer", builderConsumer);
        return this;
    }

    /**
     * Sets a callback that may customize the list builder after all layers have been added and
     * before the network is built.
     *
     * @param listBuilderConsumer callback, or {@code null} for none
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setListBuilderConsumer(Consumer<NeuralNetConfiguration.ListBuilder> listBuilderConsumer) {
        this.listBuilderConsumer = listBuilderConsumer;
        registerOption("listBuilderConsumer", listBuilderConsumer);
        return this;
    }

    /**
     * Sets a callback that receives the freshly initialized DL4J network during the first training.
     *
     * @param multiLayerNetworkConsumer callback, or {@code null} for none
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setMultiLayerNetworkConsumer(Consumer<MultiLayerNetwork> multiLayerNetworkConsumer) {
        this.multiLayerNetworkConsumer = multiLayerNetworkConsumer;
        registerOption("multiLayerNetworkConsumer", multiLayerNetworkConsumer);
        return this;
    }

    /**
     * Sets the function that maps a problem gene to the numeric network input.
     *
     * @param problemTransformer mapping from problem gene to double
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setProblemToDoubleTransformer(ToDoubleFunction<PT> problemTransformer) {
        this.problemToDoubleTransformer = problemTransformer;
        // NOTE(review): this key does not match the setter name, so setSpecificOption("problemTransformer", ...)
        // cannot find a setter. Kept as is because changing the key could break consumers of the option map.
        registerOption("problemTransformer", problemTransformer);
        return this;
    }

    /**
     * Sets the function that maps a solution gene to the numeric training target.
     *
     * @param solutionToDoubleTransformer mapping from solution gene to double
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setSolutionToDoubleTransformer(ToDoubleFunction<GT> solutionToDoubleTransformer) {
        this.solutionToDoubleTransformer = solutionToDoubleTransformer;
        registerOption("solutionToDoubleTransformer", solutionToDoubleTransformer);
        return this;
    }

    /**
     * Sets the function that maps a numeric network output back to a solution gene.
     *
     * @param doubleToSolutionTransformer mapping from double to solution gene
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setDoubleToSolutionTransformer(Function<Double, GT> doubleToSolutionTransformer) {
        this.doubleToSolutionTransformer = doubleToSolutionTransformer;
        registerOption("doubleToSolutionTransformer", doubleToSolutionTransformer);
        return this;
    }

    /**
     * Sets the loss function of the output layer; only used when the network is built.
     *
     * @param outputLayerLossFunction loss function for the output layer
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setOutputLayerLossFunction(LossFunctions.LossFunction outputLayerLossFunction) {
        this.outputLayerLossFunction = outputLayerLossFunction;
        registerOption("outputLayerLossFunction", outputLayerLossFunction);
        return this;
    }

    /**
     * Sets the iteration count on the configuration builder.
     *
     * @param iterations iteration count passed to the builder
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setIterations(int iterations) {
        this.builder.iterations(iterations);
        registerOption("iterations", iterations);
        return this;
    }

    /**
     * Sets the optimization algorithm on the configuration builder.
     *
     * @param algo optimization algorithm passed to the builder
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setOptimizationAlgo(OptimizationAlgorithm algo) {
        this.builder.setOptimizationAlgo(algo);
        registerOption("optimizationAlgo", algo);
        return this;
    }

    /**
     * Sets the learning rate on the configuration builder.
     *
     * @param learningRate learning rate passed to the builder
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setLearningRate(double learningRate) {
        this.builder.setLearningRate(learningRate);
        registerOption("learningRate", learningRate);
        return this;
    }

    /**
     * Sets the updater on the configuration builder.
     *
     * @param updater updater passed to the builder
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setUpdater(Updater updater) {
        this.builder.updater(updater);
        registerOption("updater", updater);
        return this;
    }

    /**
     * Sets the weight initialization scheme on the configuration builder.
     *
     * @param weightInit weight initialization passed to the builder
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setWeightInit(WeightInit weightInit) {
        this.builder.setWeightInit(weightInit);
        registerOption("weightInit", weightInit);
        return this;
    }

    /**
     * Replaces the activations of all layers. When the network is built, entry {@code i} is used
     * for hidden layer {@code i} and the last entry for the output layer. The array is stored by
     * reference and its length is not validated here.
     *
     * @param activations activations for the hidden layers followed by the output layer
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setActivation(Activation[] activations) {
        this.activationForLayers = activations;
        registerOption("activation", activations);
        return this;
    }

    /**
     * Replaces the activation of a single layer.
     *
     * @param layerIdx   index into the activation array (hidden layers first, output layer last)
     * @param activation activation to use for that layer
     * @return this instance
     */
    public NeuralNetwork<GT, PT> setActivation(int layerIdx, Activation activation) {
        this.activationForLayers[layerIdx] = activation;
        // NOTE(review): keys like "activation0" have no matching single-argument setter for setSpecificOption.
        registerOption("activation" + layerIdx, activation);
        return this;
    }

    /**
     * Writes the network to a zip file: first the DL4J model, then an additional entry holding the
     * three transformers and the epoch count. The transformers are written with Java
     * serialization, so they have to be serializable.
     *
     * @param path target file
     * @throws IllegalStateException if the network has not been built yet
     * @throws RuntimeException      wrapping the {@link IOException} raised while writing; if the
     *                               additional entry cannot be written the file is deleted again
     */
    public void save(String path) {
        if (!this.isInit()) {
            throw new IllegalStateException("Network is created after initialization, So before saving at least one training has to be executed.");
        }
        try {
            ModelSerializer.writeModel(this.neuralNetwork, path, true);
        }
        catch (IOException e) {
            this.log.error(e.getMessage(), e);
            throw new RuntimeException(e);
        }
        // Re-open the zip that was just written and add our own entry to it.
        URI uri = URI.create("jar:" + Paths.get(path).toUri());
        try (FileSystem fs = FileSystems.newFileSystem(uri, new HashMap<String, Object>());
             ObjectOutputStream writer = new ObjectOutputStream(Files.newOutputStream(fs.getPath(SERIALIZED_STATE_ENTRY), StandardOpenOption.CREATE))) {
            // Write order must mirror the read order in load().
            writer.writeObject(this.problemToDoubleTransformer);
            writer.writeObject(this.solutionToDoubleTransformer);
            writer.writeObject(this.doubleToSolutionTransformer);
            writer.writeInt(this.epochs);
        }
        catch (IOException e) {
            this.log.error(e.getMessage(), e);
            // A model file without the transformer entry cannot be loaded, so remove it.
            if (!new File(path).delete()) {
                this.log.warn("Could not delete file on error.");
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Feeds the problem through the trained network and converts every output value into a
     * solution gene.
     *
     * @param problem problem to solve
     * @return solution with one gene per network output
     * @throws IllegalStateException if the network has not been trained or loaded, or if the
     *                               problem or result transformer has not been set
     */
    public Solution<GT, PT> solve(Problem<PT> problem) {
        if (!this.isTrained()) {
            throw new IllegalStateException("Network has first to be trained via train() function");
        }
        requireConfigured(this.problemToDoubleTransformer, "problemToDoubleTransformer");
        requireConfigured(this.doubleToSolutionTransformer, "doubleToSolutionTransformer");
        double[] problemData = problem.getProblemGenes().stream().map(Gene::getGene).mapToDouble(this.problemToDoubleTransformer).toArray();
        INDArray out = this.neuralNetwork.output(Nd4j.create(problemData), false);
        List<SolutionGene<GT, PT>> solutionGenes = new ArrayList<>();
        for (int i = 0; i < out.length(); ++i) {
            SolutionGene<GT, PT> solutionGene = new SolutionGene<>();
            solutionGene.setGene(this.doubleToSolutionTransformer.apply(out.getDouble(i)));
            solutionGenes.add(solutionGene);
        }
        Solution<GT, PT> solution = new Solution<>();
        solution.setSolutionGenes(solutionGenes);
        return solution;
    }

    /**
     * @return {@code true} once the network has been trained or loaded
     */
    public boolean isTrained() {
        return this.trained;
    }

    /** Marks the network as trained, which implies that it has been built. */
    private void setTrained() {
        this.setInit();
        this.trained = true;
    }

    /** @return {@code true} once the underlying DL4J network exists */
    private boolean isInit() {
        return this.init;
    }

    /** Marks the underlying DL4J network as built. */
    private void setInit() {
        this.init = true;
    }

    /**
     * Solves the problem; the previous best solution is ignored.
     *
     * @param problem      problem to solve
     * @param bestSolution ignored
     * @return result of {@link #solve(Problem)}
     */
    public Solution<GT, PT> solve(Problem<PT> problem, Solution<GT, PT> bestSolution) {
        return this.solve(problem);
    }
}

