package ai.fritz.vision.imagesegmentation;

import ai.fritz.vision.FritzVisionImage;
import ai.fritz.vision.base.FritzVisionPredictor;
import android.graphics.Bitmap;
import android.util.Log;
import android.util.Size;
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.List;
import org.tensorflow.lite.Tensor;

/* loaded from: classes.dex */
/**
 * Runs a TensorFlow Lite segmentation model on a {@link FritzVisionImage} and converts the raw
 * per-pixel class scores into a {@link FritzVisionSegmentationResult}.
 *
 * <p>The input buffer, output buffer and pixel scratch array are allocated once and reused by
 * every call to {@link #predict(FritzVisionImage)}; this class performs no synchronization of
 * its own.
 */
public class FritzVisionSegmentationPredictor extends FritzVisionPredictor {
    private static final String TAG = FritzVisionSegmentationPredictor.class.getSimpleName();

    /** Size in bytes of one float score in the output buffer. */
    private static final int BYTES_PER_FLOAT = 4;

    /** Mask applied to extract one 8-bit colour channel from a packed pixel. */
    private static final int CHANNEL_MASK = 255;

    /** Divisor that scales an 8-bit channel value into the range [0, 1]. */
    private static final float CHANNEL_MAX = 255.0f;

    /** Offset subtracted after scaling so channel values are centred on zero. */
    private static final float CHANNEL_OFFSET = 0.5f;

    private ByteBuffer inputByteBuffer;
    private Size inputSize;
    private int[] intValues;
    private FritzVisionSegmentationPredictorOptions options;
    private ByteBuffer outputByteBuffer;
    private Size outputSize;
    private MaskClass[] segmentClassifications;

    /**
     * Creates a predictor for the given on-device model.
     *
     * @param segmentationOnDeviceModel model supplying the interpreter and the class list
     * @param fritzVisionSegmentationPredictorOptions options; {@code targetClasses} may be
     *     {@code null} to keep every class of the model
     */
    public FritzVisionSegmentationPredictor(SegmentationOnDeviceModel segmentationOnDeviceModel, FritzVisionSegmentationPredictorOptions fritzVisionSegmentationPredictorOptions) {
        super(segmentationOnDeviceModel, fritzVisionSegmentationPredictorOptions);
        initializeValues(segmentationOnDeviceModel, fritzVisionSegmentationPredictorOptions);
    }

    /**
     * Derives the input/output sizes from the interpreter's first input and output tensors and
     * allocates the reusable native-order buffers.
     */
    private void initializeValues(SegmentationOnDeviceModel segmentationOnDeviceModel, FritzVisionSegmentationPredictorOptions fritzVisionSegmentationPredictorOptions) {
        this.options = fritzVisionSegmentationPredictorOptions;
        this.segmentClassifications = setTargetClassifications(segmentationOnDeviceModel.getClassifications(), fritzVisionSegmentationPredictorOptions.targetClasses);

        Tensor inputTensor = this.interpreter.getInputTensor(0);
        // The input size must come from the input tensor only; it drives image resizing and the
        // number of pixels written into the input buffer.
        this.inputSize = getSizeFromTensor(inputTensor);
        this.inputByteBuffer = ByteBuffer.allocateDirect(inputTensor.numElements() * inputTensor.dataType().byteSize());
        this.inputByteBuffer.order(ByteOrder.nativeOrder());

        Tensor outputTensor = this.interpreter.getOutputTensor(0);
        this.outputSize = getSizeFromTensor(outputTensor);
        // postprocess() reads the output as 4-byte floats, so the buffer is sized accordingly.
        this.outputByteBuffer = ByteBuffer.allocateDirect(outputTensor.numElements() * BYTES_PER_FLOAT);
        this.outputByteBuffer.order(ByteOrder.nativeOrder());

        this.intValues = new int[this.inputSize.getHeight() * this.inputSize.getWidth()];
    }

    /**
     * Picks, for every output pixel, the class with the highest score and builds the result.
     *
     * <p>The output buffer is read as rows of pixels, each pixel holding one float per entry of
     * {@code segmentClassifications}.
     *
     * @param bitmap the bitmap that was fed to the model; only its dimensions are used
     * @return the segmentation result holding the per-pixel class indices and scores
     */
    private FritzVisionSegmentationResult postprocess(Bitmap bitmap) {
        int height = this.outputSize.getHeight();
        int width = this.outputSize.getWidth();
        int classCount = this.segmentClassifications.length;
        int[][] classIndices = new int[height][width];
        float[][] confidences = new float[height][width];

        // Reset the position so the next interpreter run starts writing at the beginning;
        // the reads below use absolute indices and do not move the position.
        this.outputByteBuffer.rewind();
        for (int row = 0; row < height; row++) {
            int rowOffset = row * width * classCount;
            for (int col = 0; col < width; col++) {
                int pixelOffset = rowOffset + (classCount * col);
                // The search starts from a score of 0, so a pixel with no strictly positive
                // score keeps class index 0 with confidence 0.
                int bestClass = 0;
                float bestScore = 0.0f;
                for (int classIndex = 0; classIndex < classCount; classIndex++) {
                    float score = this.outputByteBuffer.getFloat((pixelOffset + classIndex) * BYTES_PER_FLOAT);
                    if (score > bestScore) {
                        bestClass = classIndex;
                        bestScore = score;
                    }
                }
                classIndices[row][col] = bestClass;
                confidences[row][col] = bestScore;
            }
        }
        Size bitmapSize = new Size(bitmap.getWidth(), bitmap.getHeight());
        return new FritzVisionSegmentationResult(this.options, this.segmentClassifications, bitmapSize, this.outputSize, 0, 0, classIndices, confidences);
    }

    /**
     * Writes the bitmap's pixels into the input buffer as three floats per pixel (red, green,
     * blue), each scaled to [0, 1] and shifted by -0.5.
     *
     * @param bitmap bitmap to feed to the model; assumed to already have the dimensions of
     *     {@code inputSize} (it is produced by {@code FritzVisionImage.prepare(inputSize)})
     */
    private void preprocess(Bitmap bitmap) {
        int bitmapWidth = bitmap.getWidth();
        bitmap.getPixels(this.intValues, 0, bitmapWidth, 0, 0, bitmapWidth, bitmap.getHeight());
        this.inputByteBuffer.rewind();

        int height = this.inputSize.getHeight();
        int width = this.inputSize.getWidth();
        for (int row = 0; row < height; row++) {
            for (int col = 0; col < width; col++) {
                int pixel = this.intValues[(width * row) + col];
                // Packed ARGB pixel: bits 16-23 red, 8-15 green, 0-7 blue; alpha is ignored.
                this.inputByteBuffer.putFloat((((pixel >> 16) & CHANNEL_MASK) / CHANNEL_MAX) - CHANNEL_OFFSET);
                this.inputByteBuffer.putFloat((((pixel >> 8) & CHANNEL_MASK) / CHANNEL_MAX) - CHANNEL_OFFSET);
                this.inputByteBuffer.putFloat(((pixel & CHANNEL_MASK) / CHANNEL_MAX) - CHANNEL_OFFSET);
            }
        }
    }

    /**
     * Restricts the model's classes to the requested targets.
     *
     * @param maskClassArr every class the model can emit, in output-channel order
     * @param list classes to keep, or {@code null} to keep all of them
     * @return {@code maskClassArr} itself when {@code list} is {@code null}; otherwise a new
     *     array of the same length in which every class not contained in {@code list} is
     *     replaced by {@link MaskClass#NONE}
     */
    private MaskClass[] setTargetClassifications(MaskClass[] maskClassArr, List<MaskClass> list) {
        if (list == null) {
            return maskClassArr;
        }
        // Filter a copy so the array handed out by the model is never modified.
        MaskClass[] filtered = maskClassArr.clone();
        for (int i = 0; i < filtered.length; i++) {
            if (!list.contains(filtered[i])) {
                filtered[i] = MaskClass.NONE;
            }
        }
        return filtered;
    }

    /**
     * Runs the segmentation model on the given image.
     *
     * @param fritzVisionImage image to segment; it is prepared to the model's input size first
     * @return the per-pixel segmentation result
     */
    @Override // ai.fritz.vision.base.FritzVisionPredictor
    public FritzVisionSegmentationResult predict(FritzVisionImage fritzVisionImage) {
        // nanoTime is monotonic, so the measured duration is unaffected by wall-clock changes.
        long startNanos = System.nanoTime();
        Bitmap prepared = fritzVisionImage.prepare(this.inputSize);
        preprocess(prepared);
        this.interpreter.run(this.inputByteBuffer, this.outputByteBuffer);
        FritzVisionSegmentationResult result = postprocess(prepared);
        long elapsedMillis = (System.nanoTime() - startNanos) / 1000000L;
        Log.d(TAG, "Predict Time:" + elapsedMillis);
        return result;
    }
}
