import React, { createContext, useState, useEffect, useContext, useMemo } from "react";
import {InferenceSession} from "onnxruntime-web";
import {modelData, onnxMaskToImage, calcImageScale, arrayToImageData, samPointsAreSimilar} from "../js/SAMUtils";
import npyjs from "npyjs";
import {debounceAsync} from "../js/utils";

const ort = require("onnxruntime-web");


/**
 * React context carrying the SAM/ONNX state and helpers supplied by ONNXProvider.
 * Its default value is null, so a consumer rendered outside an ONNXProvider
 * receives null.
 */
export const ONNXContext = createContext(null);

const ONNXProvider = props => {
    const [model, setModel] = useState(null); // ONNX model
    const [tensor, setTensor] = useState(null); // Image embedding tensor
    const [maskImg, setMaskImg] = useState(null); // Output binary mask
    const [lastProcessedPoints, setLastProcessedPoints] = useState(null);
    const [maskImageData, setMaskImageData] = useState(null);
    const [ready, setReady] = useState(false);

    // The ONNX model expects the input to be rescaled to 1024.
    // The modelScale state variable keeps track of the scale values.
    const [modelScale, setModelScale] = useState(null);

    const resetSAMState = () => {
        setMaskImg(null);
        setLastProcessedPoints(null);
        setMaskImageData(null);
        setTensor(null);
    }

    useEffect(() => {
        // Initialize the ONNX model
        const initModel = async () => {
            try {
                const URL = "./model/facebook_decoder_quant.onnx";
                const model = await InferenceSession.create(URL);
                setModel(model);
                console.log("Initialized SAM Decoder")
                setReady(true);
            } catch (e) {
                console.log(e);
            }
        };

        if (!model)
            initModel();

    }, [model]);

    // Decode a Numpy file into a tensor.
    const loadNpyTensor = async (tensorFile, dType) => {
        let npLoader = new npyjs();
        const npArray = await npLoader.load(tensorFile);
        const tensor = new ort.Tensor(dType, npArray.data, npArray.shape);
        return tensor;
    };

    const loadEmbeddingsFromFile = (filename) => {
        Promise.resolve(loadNpyTensor(filename, "float32")).then(
            (embedding) => {
                setTensor(embedding);
                console.log("Loaded embeddings from file")
            }
        );
    }

    const parseImage = (url) => {
        try {
            const img = new Image();
            img.src = url;
            img.onload = () => {
                const { height, width, samScale } = calcImageScale(img);
                setModelScale({
                    height: height,  // original image height
                    width: width,  // original image width
                    samScale: samScale, // scaling factor for image which has been resized to longest side 1024
                });
            };
        } catch (error) {
            console.log(error);
        }
    }

    const runONNX = debounceAsync(async (points) => {
        if (
            model === null ||
            tensor === null ||
            modelScale === null ||
            points.length === 0 ||
            !ready
        )
            return;

        if (lastProcessedPoints && lastProcessedPoints.length > 0)
            if (samPointsAreSimilar(lastProcessedPoints, points))
                return;

        try {
            // Prepare the model input in the correct format for SAM.
            // The modelData function is from onnxModelAPI.tsx.
            const feeds = modelData({
                points,
                tensor,
                modelScale
            });
            if (feeds === undefined) return;
            setLastProcessedPoints(points);
            setReady(false);
            // Run the SAM ONNX model with the feeds returned from modelData()
            const results = await model.run(feeds);
            const output = results[model.outputNames[0]];
            // The predicted mask returned from the ONNX model is an array which is
            // rendered as an HTML image using onnxMaskToImage() from maskUtils.tsx.
            setMaskImageData(arrayToImageData(output.data, output.dims[2], output.dims[3]));
            setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3]));
            setReady(true);
        } catch (e) {
            console.log(e);
        }
    }, 250);

    return <ONNXContext.Provider value={{
        model, setModel, tensor, setTensor,
        loadNpyTensor, loadEmbeddingsFromFile, runONNX, maskImg, setMaskImg,
        parseImage, lastProcessedPoints, setLastProcessedPoints,
        ready, maskImageData, resetSAMState
    }} {...props} />;
}

/**
 * Hook returning the value supplied by the nearest ONNXProvider,
 * or null when called outside of one.
 */
export function useONNXContext() {
    return useContext(ONNXContext);
}

export default ONNXProvider;