/*
 * Decompiled with CFR 0.152.
 */
package ch.epfl.leb.defcon.ij;

import ch.epfl.leb.defcon.predictors.internal.AbstractPredictor;
import ij.IJ;
import ij.ImagePlus;
import ij.WindowManager;
import ij.gui.Roi;
import ij.measure.ResultsTable;
import ij.plugin.filter.PlugInFilter;
import ij.process.ImageProcessor;
import net.imglib2.type.numeric.RealType;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class MaxCountFCN
extends AbstractPredictor
implements PlugInFilter {
    private ImagePlus image;
    private ResultsTable rt;
    private Session tfSession;
    private Roi roi;

    public void run(ImageProcessor ip) {
        int stack_size = this.image.getImageStackSize();
        this.roi = WindowManager.getCurrentImage().getRoi();
        Roi reshapedRoi = this.initRoi();
        for (int i = 1; i <= stack_size; ++i) {
            ImageProcessor proc = this.image.getImageStack().getProcessor(i);
            ImagePlus slice = new ImagePlus("DEFCoN", proc);
            this.rt.incrementCounter();
            float prediction = this.predict(slice, reshapedRoi);
            this.rt.addValue("Max local count (7x7)", (double)prediction);
            IJ.showProgress((int)i, (int)stack_size);
        }
        this.rt.show("Maximum local count");
    }

    public int setup(String pathToModel, ImagePlus imp) {
        if (imp.isLocked()) {
            imp.unlock();
        }
        this.image = imp;
        SavedModelBundle smb = SavedModelBundle.load((String)pathToModel, (String[])new String[]{"serve"});
        this.tfSession = smb.session();
        this.rt = new ResultsTable();
        return 5;
    }

    private Roi initRoi() {
        int image_width = this.image.getWidth();
        int image_height = this.image.getHeight();
        if (this.roi == null) {
            this.roi = new Roi(0, 0, image_width, image_height);
        }
        Roi reshapedRoi = new Roi(this.roi.getBounds().x, this.roi.getBounds().y, this.roi.getBounds().width - this.roi.getBounds().width % 4, this.roi.getBounds().height - this.roi.getBounds().height % 4);
        WindowManager.getCurrentImage().setRoi(reshapedRoi);
        return reshapedRoi;
    }

    private <T extends RealType<T>> float predict(ImagePlus imp, Roi reshapedRoi) {
        imp.setRoi(reshapedRoi);
        ImagePlus impRoi = imp.crop();
        Tensor<Float> inputTensor = MaxCountFCN.imageToTensor(impRoi);
        Tensor outputTensor = ((Tensor)this.tfSession.runner().feed("input_tensor", inputTensor).fetch("output_tensor").run().get(0)).expect(Float.class);
        float[][] pred = (float[][])outputTensor.copyTo((Object)new float[1][1]);
        return pred[0][0];
    }
}

