/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelexport.solr.handler;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.handler.SolrDefaultStreamFactory;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.util.ModelGuesser;
import org.deeplearning4j.util.NetworkUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A decorator {@link TupleStream} that passes selected numeric fields of each tuple read from a
 * wrapped stream through a deeplearning4j {@link Model} and writes the model's outputs back into
 * the same tuple.
 *
 * <p>Required streaming-expression operands:
 * <ul>
 *   <li>exactly one nested stream expression that supplies the tuples;</li>
 *   <li>{@code serializedModelFileName}: resource name opened through the core's
 *       {@link SolrResourceLoader};</li>
 *   <li>{@code inputKeys}: comma-separated tuple field names, read as doubles, that form the
 *       model's single input row;</li>
 *   <li>{@code outputKeys}: comma-separated tuple field names that receive the model's outputs,
 *       in output-index order.</li>
 * </ul>
 *
 * <p>The stream factory must be a {@link SolrDefaultStreamFactory}. The model is restored once,
 * while the stream is being constructed.
 */
public class ModelTupleStream
extends TupleStream
implements Expressible {
    private static final String SERIALIZED_MODEL_FILE_NAME_PARAM = "serializedModelFileName";
    private static final String INPUT_KEYS_PARAM = "inputKeys";
    private static final String OUTPUT_KEYS_PARAM = "outputKeys";
    /** The decorated stream whose tuples are enriched with model outputs. */
    private final TupleStream tupleStream;
    private final String serializedModelFileName;
    /** Raw {@code inputKeys} operand, kept verbatim so {@link #toExpression} can round-trip it. */
    private final String inputKeysParam;
    /** Raw {@code outputKeys} operand, kept verbatim so {@link #toExpression} can round-trip it. */
    private final String outputKeysParam;
    private final String[] inputKeys;
    private final String[] outputKeys;
    private final SolrResourceLoader solrResourceLoader;
    private final Model model;

    /**
     * Parses the streaming expression, constructs the wrapped stream and restores the model.
     *
     * <p>Note: this constructor calls the overridable {@link #openInputStream()} and
     * {@link #restoreModel(InputStream)}. Every field those two methods read is assigned before
     * they are invoked.
     *
     * @param streamExpression the expression describing this stream
     * @param streamFactory the factory used to build the nested stream; must be a
     *     {@link SolrDefaultStreamFactory}
     * @throws IOException if the expression is malformed, the factory has the wrong type, or the
     *     model cannot be opened or restored
     */
    public ModelTupleStream(StreamExpression streamExpression, StreamFactory streamFactory) throws IOException {
        List<StreamExpression> streamExpressions = streamFactory.getExpressionOperandsRepresentingTypes(
                streamExpression, new Class[]{Expressible.class, TupleStream.class});
        if (streamExpressions.size() != 1) {
            throw new IOException("Expected exactly one stream in expression: " + streamExpression);
        }
        this.tupleStream = streamFactory.constructStream(streamExpressions.get(0));
        this.serializedModelFileName = getOperandValue(streamExpression, streamFactory, SERIALIZED_MODEL_FILE_NAME_PARAM);
        this.inputKeysParam = getOperandValue(streamExpression, streamFactory, INPUT_KEYS_PARAM);
        this.inputKeys = parseKeys(this.inputKeysParam);
        this.outputKeysParam = getOperandValue(streamExpression, streamFactory, OUTPUT_KEYS_PARAM);
        this.outputKeys = parseKeys(this.outputKeysParam);
        if (!(streamFactory instanceof SolrDefaultStreamFactory)) {
            throw new IOException(getClass().getName() + " requires a " + SolrDefaultStreamFactory.class.getName() + " StreamFactory");
        }
        this.solrResourceLoader = ((SolrDefaultStreamFactory) streamFactory).getSolrResourceLoader();
        // Close the model resource once it has been restored; it was previously leaked.
        try (InputStream modelStream = openInputStream()) {
            this.model = restoreModel(modelStream);
        }
    }

    /**
     * Returns the string value of a required named operand.
     *
     * @throws IOException if the operand is absent or its parameter is not a plain value
     */
    private static String getOperandValue(StreamExpression streamExpression, StreamFactory streamFactory, String operandName) throws IOException {
        StreamExpressionNamedParameter namedParameter = streamFactory.getNamedOperand(streamExpression, operandName);
        String operandValue = null;
        if (namedParameter != null && namedParameter.getParameter() instanceof StreamExpressionValue) {
            operandValue = ((StreamExpressionValue) namedParameter.getParameter()).getValue();
        }
        if (operandValue == null) {
            throw new IOException("Expected '" + operandName + "' in expression: " + streamExpression);
        }
        return operandValue;
    }

    /**
     * Splits a comma-separated key list and trims surrounding whitespace from each key, so that
     * {@code "a, b"} yields {@code ["a", "b"]} rather than the unmatchable key {@code " b"}.
     */
    private static String[] parseKeys(String keysParam) {
        String[] keys = keysParam.split(",");
        for (int ii = 0; ii < keys.length; ++ii) {
            keys[ii] = keys[ii].trim();
        }
        return keys;
    }

    /** Delegates to the superclass implementation. */
    public Map toMap(Map<String, Object> map) {
        return super.toMap(map);
    }

    /** Passes the stream context through to the wrapped stream. */
    public void setStreamContext(StreamContext streamContext) {
        this.tupleStream.setStreamContext(streamContext);
    }

    /**
     * Returns the direct child of this decorator, which is the wrapped stream. This follows the
     * convention of Solr's built-in decorator streams. The previous implementation returned the
     * wrapped stream's own children and so skipped a level.
     */
    public List<TupleStream> children() {
        List<TupleStream> children = new ArrayList<>();
        children.add(this.tupleStream);
        return children;
    }

    /** Opens the wrapped stream. */
    public void open() throws IOException {
        this.tupleStream.open();
    }

    /** Closes the wrapped stream. */
    public void close() throws IOException {
        this.tupleStream.close();
    }

    /**
     * Reads the next tuple from the wrapped stream and, unless it is the EOF marker, adds the
     * model's outputs to it under the configured output keys.
     *
     * @return the (possibly enriched) tuple; an EOF tuple is returned unchanged
     */
    public Tuple read() throws IOException {
        Tuple tuple = this.tupleStream.read();
        if (tuple.EOF) {
            return tuple;
        }
        INDArray inputs = getInputsFromTuple(tuple);
        INDArray outputs = NetworkUtils.output(this.model, inputs);
        return applyOutputsToTuple(tuple, outputs);
    }

    /** Reports the wrapped stream's sort order; {@link #read()} emits tuples in that same order. */
    public StreamComparator getStreamSort() {
        return this.tupleStream.getStreamSort();
    }

    /** Describes this stream as a "stream-decorator" node whose single child is the wrapped stream. */
    public Explanation toExplanation(StreamFactory streamFactory) throws IOException {
        Explanation childExplanation = this.tupleStream.toExplanation(streamFactory);
        return new StreamExplanation(getStreamNodeId().toString())
                .withChildren(new Explanation[]{childExplanation})
                .withExpressionType("stream-decorator")
                .withFunctionName(streamFactory.getFunctionName(getClass()))
                .withImplementingClass(getClass().getName())
                .withExpression(toExpression(streamFactory, false).toString());
    }

    /** Serializes this stream, including the wrapped stream, back into a streaming expression. */
    public StreamExpressionParameter toExpression(StreamFactory streamFactory) throws IOException {
        return toExpression(streamFactory, true);
    }

    /**
     * Builds the expression for this stream.
     *
     * @param includeStreams if true, embed the wrapped stream's expression; otherwise use the
     *     placeholder {@code "<stream>"}
     * @throws IOException if {@code includeStreams} is true and the wrapped stream is not
     *     {@link Expressible}
     */
    private StreamExpression toExpression(StreamFactory streamFactory, boolean includeStreams) throws IOException {
        StreamExpression streamExpression = new StreamExpression(streamFactory.getFunctionName(getClass()));
        if (!includeStreams) {
            streamExpression.addParameter("<stream>");
        } else if (this.tupleStream instanceof Expressible) {
            streamExpression.addParameter(((Expressible) this.tupleStream).toExpression(streamFactory));
        } else {
            throw new IOException("This " + getClass().getName() + " contains a non-Expressible TupleStream " + this.tupleStream.getClass().getName());
        }
        streamExpression.addParameter(new StreamExpressionNamedParameter(SERIALIZED_MODEL_FILE_NAME_PARAM, this.serializedModelFileName));
        streamExpression.addParameter(new StreamExpressionNamedParameter(INPUT_KEYS_PARAM, this.inputKeysParam));
        streamExpression.addParameter(new StreamExpressionNamedParameter(OUTPUT_KEYS_PARAM, this.outputKeysParam));
        return streamExpression;
    }

    /**
     * Opens the serialized model resource through the core's resource loader. The constructor
     * closes the returned stream after {@link #restoreModel(InputStream)} has finished with it.
     */
    protected InputStream openInputStream() throws IOException {
        return this.solrResourceLoader.openResource(this.serializedModelFileName);
    }

    /**
     * Restores a model from {@code inputStream} using {@link ModelGuesser}, passing the core's
     * instance directory as the second argument.
     *
     * @throws IOException wrapping any failure raised while restoring the model
     */
    protected Model restoreModel(InputStream inputStream) throws IOException {
        File instanceDir = this.solrResourceLoader.getInstancePath().toFile();
        try {
            return ModelGuesser.loadModelGuess(inputStream, instanceDir);
        }
        catch (Exception e) {
            // Boundary catch: ModelGuesser may fail in many ways; surface them all as IOException.
            throw new IOException("Failed to restore model from given file (" + this.serializedModelFileName + ")", e);
        }
    }

    /**
     * Builds a single-row input matrix whose columns are the tuple's values for
     * {@code inputKeys}, in order.
     *
     * @throws NullPointerException with a message naming the key if the tuple has no value for
     *     one of the input keys
     */
    protected INDArray getInputsFromTuple(Tuple tuple) {
        double[] inputs = new double[this.inputKeys.length];
        for (int ii = 0; ii < this.inputKeys.length; ++ii) {
            String key = this.inputKeys[ii];
            Double value = tuple.getDouble(key);
            // Same exception type as the implicit unboxing failure, but with a useful message.
            inputs[ii] = Objects.requireNonNull(value, () -> "Tuple has no value for input key '" + key + "'");
        }
        return Nd4j.create(new double[][]{inputs});
    }

    /**
     * Writes output element {@code ii} of {@code output} into the tuple under {@code outputKeys[ii]},
     * stored as a {@link Float}.
     *
     * @return the same tuple instance, now modified
     */
    protected Tuple applyOutputsToTuple(Tuple tuple, INDArray output) {
        for (int ii = 0; ii < this.outputKeys.length; ++ii) {
            tuple.put(this.outputKeys[ii], Float.valueOf(output.getFloat((long) ii)));
        }
        return tuple;
    }
}

