/*
 * Decompiled with CFR 0.152.
 */
package ch.epfl.leb.defcon.predictors.internal;

import ch.epfl.leb.defcon.predictors.ImageBitDepthException;
import ch.epfl.leb.defcon.predictors.NoLocalCountMapException;
import ch.epfl.leb.defcon.predictors.Predictor;
import ch.epfl.leb.defcon.predictors.SessionClosedException;
import ch.epfl.leb.defcon.predictors.UninitializedPredictorException;
import ch.epfl.leb.defcon.predictors.internal.AbstractPredictor;
import ij.ImagePlus;
import ij.gui.Roi;
import ij.plugin.filter.Convolver;
import ij.process.ByteProcessor;
import ij.process.FloatProcessor;
import ij.process.ImageProcessor;
import ij.process.ShortProcessor;
import java.awt.Rectangle;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.Tensor;

/**
 * Default {@link Predictor} that runs the session's TensorFlow graph on an 8- or 16-bit image.
 *
 * <p>The image is fed to the graph node {@code "input_tensor"} and the result is fetched from
 * {@code "output_tensor"}, which is copied into a {@code [1][height][width][1]} float array.
 * That array is stored as a density map, and the sum of its pixels is stored as the count.
 *
 * <p>No synchronization is performed: each prediction overwrites the stored results, so an
 * instance should not be shared between threads without external locking.
 */
public class DefaultPredictor
extends AbstractPredictor
implements Predictor {
    private static final Logger LOGGER = Logger.getLogger(DefaultPredictor.class.getName());

    /** Message used when a result is requested before any prediction has been made. */
    private static final String NO_PREDICTION_MSG =
            "The Predictor has not yet performed any calculations.";

    /** Sum of all density-map pixels from the most recent prediction; null before the first one. */
    private Double count;
    /** Density map from the most recent prediction; null before the first one. */
    private FloatProcessor densityMap;
    /** Local count map from the last getMaximumLocalCount call for the current prediction; may be null. */
    private FloatProcessor localCountMap;

    /**
     * Crops the processor's current ROI so that its width and height are multiples of 4.
     *
     * <p>The crop is anchored at the ROI's top-left corner, and trailing columns/rows are dropped.
     * The caller's ROI and mask are restored before returning. (The multiple-of-4 requirement
     * presumably comes from the network's down-sampling layers — TODO confirm.)
     *
     * @param ip the image to crop
     * @return a new processor containing the cropped pixels
     * @throws IllegalArgumentException if the ROI is smaller than 4x4 pixels
     */
    private ImageProcessor checkDimensions(ImageProcessor ip) {
        Rectangle currRoi = ip.getRoi();
        ImageProcessor currMask = ip.getMask();
        int croppedWidth = currRoi.width - currRoi.width % 4;
        int croppedHeight = currRoi.height - currRoi.height % 4;
        if (croppedWidth == 0 || croppedHeight == 0) {
            String msg = "The image (or its ROI) must be at least 4x4 pixels, got "
                    + currRoi.width + "x" + currRoi.height + ".";
            LOGGER.log(Level.SEVERE, msg);
            throw new IllegalArgumentException(msg);
        }
        ip.setRoi(new Roi(currRoi.x, currRoi.y, croppedWidth, croppedHeight));
        try {
            return ip.crop();
        } finally {
            // Put the caller's selection back; setRoi(Roi) above replaced the bounds and the mask.
            ip.setRoi(currRoi);
            ip.setMask(currMask);
        }
    }

    /**
     * Returns the integrated count (the sum of all density-map pixels) from the last prediction.
     *
     * @return the count from the most recent call to {@link #predict(ImageProcessor)}
     * @throws UninitializedPredictorException if no prediction has been made yet
     */
    @Override
    public double getCount() throws UninitializedPredictorException {
        if (this.count == null) {
            LOGGER.log(Level.WARNING, NO_PREDICTION_MSG);
            throw new UninitializedPredictorException(NO_PREDICTION_MSG);
        }
        return this.count;
    }

    /**
     * Returns the density map produced by the last prediction.
     *
     * <p>The stored processor itself is returned, not a copy.
     *
     * @return the density map from the most recent call to {@link #predict(ImageProcessor)}
     * @throws UninitializedPredictorException if no prediction has been made yet
     */
    @Override
    public FloatProcessor getDensityMap() throws UninitializedPredictorException {
        if (this.densityMap == null) {
            LOGGER.log(Level.SEVERE, NO_PREDICTION_MSG);
            throw new UninitializedPredictorException(NO_PREDICTION_MSG);
        }
        return this.densityMap;
    }

    /**
     * Returns the local count map computed by the last call to {@link #getMaximumLocalCount(int)}.
     *
     * @return the map of box sums for the current prediction
     * @throws NoLocalCountMapException if no local count map exists for the current prediction
     */
    @Override
    public FloatProcessor getLocalCountMap() throws NoLocalCountMapException {
        if (this.localCountMap == null) {
            String msg = "The Predictor has not yet performed any local count estimates.";
            LOGGER.log(Level.SEVERE, msg);
            throw new NoLocalCountMapException(msg);
        }
        return this.localCountMap;
    }

    /**
     * Sums the density map over every {@code boxSize x boxSize} window and returns the largest sum.
     *
     * <p>Only windows that lie entirely inside the image are kept. The resulting map, of size
     * {@code (width - boxSize + 1) x (height - boxSize + 1)}, is stored and can be read back
     * with {@link #getLocalCountMap()}. The stored density map is left untouched.
     *
     * @param boxSize side length of the square summation window, in pixels
     * @return the maximum local count over all fully contained windows
     * @throws UninitializedPredictorException if no prediction has been made yet
     * @throws IllegalArgumentException if {@code boxSize} is not in {@code [1, min(width, height)]}
     */
    @Override
    public double getMaximumLocalCount(int boxSize) throws UninitializedPredictorException {
        if (this.densityMap == null) {
            LOGGER.log(Level.WARNING, NO_PREDICTION_MSG);
            throw new UninitializedPredictorException(NO_PREDICTION_MSG);
        }
        int width = this.densityMap.getWidth();
        int height = this.densityMap.getHeight();
        if (boxSize < 1 || boxSize > width || boxSize > height) {
            String msg = "boxSize must be between 1 and " + Math.min(width, height)
                    + ", got " + boxSize + ".";
            LOGGER.log(Level.SEVERE, msg);
            throw new IllegalArgumentException(msg);
        }
        // duplicate() copies the pixels. clone() would share them, and convolveFloat writes in
        // place, so the stored density map would be overwritten with the box sums.
        FloatProcessor boxSums = (FloatProcessor) this.densityMap.duplicate();
        float[] kernel = new float[boxSize * boxSize];
        Arrays.fill(kernel, 1.0f);
        Convolver convolver = new Convolver();
        convolver.setNormalize(false); // plain sum over the box rather than a normalized kernel
        convolver.convolveFloat(boxSums, kernel, boxSize, boxSize);
        // Discard border positions whose boxes would extend past the image edge.
        int halfBoxSize = boxSize / 2;
        boxSums.setRoi(new Roi(halfBoxSize, halfBoxSize, width - boxSize + 1, height - boxSize + 1));
        FloatProcessor cropped = boxSums.crop().convertToFloatProcessor();
        cropped.resetMinAndMax();
        // Publish only the fully computed map.
        this.localCountMap = cropped;
        return cropped.getMax();
    }

    /**
     * Converts an 8-bit image to 16 bits and runs the prediction on it.
     *
     * @param bp the (already cropped) 8-bit image
     */
    private void predict(ByteProcessor bp) {
        ShortProcessor sp = bp.convertToShortProcessor();
        this.predict(sp);
    }

    /**
     * Runs the network on {@code ip} and stores the resulting density map and count.
     *
     * <p>The image's current ROI is first cropped so that both dimensions are multiples of 4;
     * {@code ip} itself is not modified. 8-bit input is converted to 16 bits before inference.
     * Any local count map from a previous prediction is discarded.
     *
     * @param ip an 8- or 16-bit image
     * @throws SessionClosedException if the TensorFlow session has been closed
     * @throws ImageBitDepthException if the image is neither 8- nor 16-bit
     * @throws IllegalArgumentException if the ROI is smaller than 4x4 pixels
     */
    @Override
    public void predict(ImageProcessor ip) throws ImageBitDepthException, SessionClosedException {
        if (this.isClosed) {
            String msg = "Cannot call the predict() method:\n the TensorFlow session has been closed.";
            LOGGER.log(Level.WARNING, msg);
            throw new SessionClosedException(msg);
        }
        int bitDepth = ip.getBitDepth();
        switch (bitDepth) {
            case 16: {
                this.predict(this.checkDimensions(ip).convertToShortProcessor());
                break;
            }
            case 8: {
                this.predict(this.checkDimensions(ip).convertToByteProcessor());
                break;
            }
            default: {
                String msg = "The predictor only works on 8 and 16-bit images.";
                LOGGER.log(Level.SEVERE, msg);
                throw new ImageBitDepthException(msg);
            }
        }
    }

    /**
     * Feeds a 16-bit image through the graph and stores the density map and its pixel sum.
     *
     * @param sp the (already cropped) 16-bit image
     */
    private void predict(ShortProcessor sp) {
        ImagePlus imp = new ImagePlus("", (ImageProcessor) sp);
        int height = imp.getHeight();
        int width = imp.getWidth();
        float[][][][] pred;
        // Tensors own native memory that garbage collection does not reliably reclaim, so close
        // both explicitly. The fetched tensor is held as Tensor<?> so that it is still closed
        // if expect() rejects its data type.
        try (Tensor<Float> inputTensor = DefaultPredictor.imageToTensor(imp);
             Tensor<?> outputTensor = this.tfSession.runner()
                     .feed("input_tensor", inputTensor)
                     .fetch("output_tensor")
                     .run()
                     .get(0)) {
            pred = outputTensor.expect(Float.class).copyTo(new float[1][height][width][1]);
        }

        float[] pixels = new float[width * height];
        double sum = 0.0;
        for (int y = 0; y < height; ++y) {
            for (int x = 0; x < width; ++x) {
                float value = pred[0][y][x][0];
                pixels[y * width + x] = value; // FloatProcessor pixels are stored row by row
                sum += value;
            }
        }
        FloatProcessor map = new FloatProcessor(width, height, pixels);
        map.resetMinAndMax();

        this.densityMap = map;
        this.count = sum;
        // A local count map from the previous image no longer describes this prediction.
        this.localCountMap = null;
    }

    /**
     * Closes the TensorFlow session when this object is garbage-collected.
     *
     * <p>Finalizers are not guaranteed to run. The session is closed here only as a last resort.
     * The null check covers instances whose construction failed before a session was assigned.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            if (this.tfSession != null) {
                this.tfSession.close();
            }
        }
        finally {
            super.finalize();
        }
    }
}

