/*
 * Decompiled with CFR 0.152.
 */
package ch.epfl.leb.defcon.predictors.internal;

import ch.epfl.leb.defcon.utils.GraphBuilder;
import ij.ImagePlus;
import net.imagej.tensorflow.Tensors;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.display.imagej.ImageJFunctions;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/**
 * Base class for predictors that run a TensorFlow SavedModel.
 *
 * <p>The class owns the native TensorFlow resources created by
 * {@link #setup(String)} and releases them in {@link #close()}. It implements
 * {@link AutoCloseable} so that instances can be managed with
 * try-with-resources.
 *
 * <p>No synchronization is performed; callers that share an instance between
 * threads must coordinate access themselves.
 */
public abstract class AbstractPredictor implements AutoCloseable {
    /**
     * Session used to run the loaded model. It is {@code null} until
     * {@link #setup(String)} has been called and again after the resources
     * have been released.
     */
    protected Session tfSession;

    /** Set to {@code true} by {@link #close()}; reset by a successful {@link #setup(String)}. */
    protected boolean isClosed = false;

    /**
     * Bundle that {@link #tfSession} was obtained from. It is retained so that
     * the bundle (and with it the model's graph) can be closed, rather than
     * only the session.
     */
    private SavedModelBundle modelBundle;

    /**
     * Releases the native resources held by this predictor and marks it as
     * closed. Safe to call when {@link #setup(String)} was never invoked and
     * safe to call more than once.
     */
    @Override
    public void close() {
        this.releaseResources();
        this.isClosed = true;
    }

    /**
     * Loads the SavedModel stored at {@code pathToModel} using the
     * {@code "serve"} tag and makes its session available through
     * {@link #tfSession}.
     *
     * <p>If a model is already loaded it is released once the new one has been
     * loaded successfully, so a failed load leaves the current model untouched.
     *
     * @param pathToModel directory containing the exported SavedModel
     */
    public void setup(String pathToModel) {
        // Load first: if this throws, the previously loaded model stays usable.
        SavedModelBundle loaded = SavedModelBundle.load(pathToModel, "serve");
        this.releaseResources();
        this.modelBundle = loaded;
        this.tfSession = loaded.session();
        this.isClosed = false;
    }

    /**
     * Converts an image into a float tensor with one extra leading and one
     * extra trailing dimension.
     *
     * <p>The image is converted to 32-bit float, wrapped in a tensor and run
     * through a small temporary graph that applies
     * {@code GraphBuilder.expandDims} twice: first with the constant {@code 0}
     * (named "make_batch"), then with the constant {@code -1} (named
     * "make_channel"). NOTE(review): the resulting layout is presumably
     * [batch, ..., channel] - confirm against {@code GraphBuilder}.
     *
     * <p>All temporary native resources (the intermediate tensor, the graph
     * and the session) are released before returning. The returned tensor is
     * owned by the caller, who is responsible for closing it.
     *
     * @param imp image to convert
     * @return the expanded tensor holding the image data
     */
    // Raw types are kept because the generic signatures of the project's
    // GraphBuilder helpers are not visible from this class.
    @SuppressWarnings({"rawtypes", "unchecked"})
    protected static Tensor<Float> imageToTensor(ImagePlus imp) {
        Img img = ImageJFunctions.convertFloat(imp);
        try (Tensor imageTensor = Tensors.tensorFloat(img);
                Graph graph = new Graph()) {
            Output imageTensorOutput = graph.opBuilder("Const", "tensor_image")
                    .setAttr("dtype", imageTensor.dataType())
                    .setAttr("value", imageTensor)
                    .build()
                    .output(0);
            Output withBatch = GraphBuilder.expandDims(
                    graph, "dim0", imageTensorOutput, GraphBuilder.constant(graph, "make_batch", 0));
            Output expanded = GraphBuilder.expandDims(
                    graph, "dim-1", withBatch, GraphBuilder.constant(graph, "make_channel", -1));
            // The session is opened only after the graph is fully built and is
            // closed before the graph, which is the order try-with-resources
            // guarantees for nested blocks.
            try (Session session = new Session(graph)) {
                return session.runner()
                        .fetch(expanded.op().name())
                        .run()
                        .get(0)
                        .expect(Float.class);
            }
        }
    }

    /**
     * Last-resort cleanup for predictors that were never closed explicitly.
     * Does nothing if the resources have already been released.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            this.releaseResources();
        }
        finally {
            super.finalize();
        }
    }

    /**
     * Closes the loaded bundle and/or session, if any, and clears the
     * references so that a repeated call is a no-op.
     */
    private void releaseResources() {
        SavedModelBundle bundle = this.modelBundle;
        Session session = this.tfSession;
        // Clear the fields first so a failure while closing cannot lead to a
        // second close attempt on the same native handles.
        this.modelBundle = null;
        this.tfSession = null;
        if (bundle == null) {
            // The session was assigned without going through setup().
            if (session != null) {
                session.close();
            }
            return;
        }
        try {
            // A subclass may have replaced tfSession with its own session;
            // close it separately since the bundle does not own it.
            if (session != null && session != bundle.session()) {
                session.close();
            }
        }
        finally {
            // Closing the bundle releases both its session and its graph.
            bundle.close();
        }
    }
}

