Sentis output doesn't match with the onnx output for segmentation

Hello,

I’m trying to run the SegFormer segmentation model with Sentis. The onnx model runs smoothly in python. But Sentis outputs the same label (2) for all pixels. I tried the model with different inputs. But the segmentation result is always the same after argmax over the logits. Here is the script to run the model:

I want to learn if the way I’m running the model is wrong or Sentis doesn’t support the SegFormer model.

Cheers,
YS

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using UnityEngine.InputSystem;
using UnityEngine.UI;


public class SegmentImage : MonoBehaviour
{
    RenderTexture outputTexture;

    public int k_LayersPerFrame = 2;
    [SerializeField] Texture2D inputImage;
    [SerializeField] ModelAsset modelAsset;
    Model runtimeModel;
    Tensor<float> inputTensor;
    Worker worker;
    bool m_Started = false;
    IEnumerator executionSchedule;

    [SerializeField] RawImage inputDisplay;
    void Start()
    {
        Application.targetFrameRate = 60;
        runtimeModel = ModelLoader.Load(modelAsset);
        // runtimeModel = ModelLoader.Load(Application.streamingAssetsPath + "/model.sentis");

        FunctionalGraph processor = new FunctionalGraph();
        var input = processor.AddInput(runtimeModel, 0);
        Tensor<float> imageMean = new Tensor<float>(new TensorShape(1,3,1,1), new float[] {0.485f, 0.456f, 0.406f});
        Tensor<float> imageStd = new Tensor<float>(new TensorShape(1,3,1,1), new float[] {0.229f, 0.224f, 0.225f});
        var shiftedInput = Functional.Sub(input, Functional.Constant(imageMean));
        var normalizedInput = Functional.Div(shiftedInput, Functional.Constant(imageStd));
        var logits = Functional.Forward(runtimeModel, normalizedInput);


        Model pro_RuntimeModel = processor.Compile(logits);
        worker = new Worker(pro_RuntimeModel, BackendType.CPU);

        Texture2D resizedImage = new Texture2D(512, 512);
        inputDisplay.texture = resizedImage;
        Graphics.ConvertTexture(inputImage, resizedImage);
        resizedImage = SceneSegmentor.ConvertToRGB(resizedImage);

        inputTensor = TextureConverter.ToTensor(resizedImage);
        Debug.Log("input tensor " +  inputTensor.ReadbackAndClone()[0, 0, 0, 0]);

    }


    private void Update()
    {
        if (Keyboard.current.spaceKey.wasPressedThisFrame || m_Started)
            Segment(); 
    }

    private void Segment()
    {
        if (!m_Started)
        {
            // ExecuteLayerByLayer starts the scheduling of the model
            executionSchedule = worker.ScheduleIterable(inputTensor);
            m_Started = true;
        }

        int it = 0;
        while (executionSchedule.MoveNext())
        {
            if (++it % k_LayersPerFrame == 0)
                return;
        }

        var outputTensor = worker.PeekOutput() as Tensor<float>;
        var cpuCopyTensor = outputTensor.ReadbackAndClone();
        // cpuCopyTensor is a CPU copy of the output tensor. You can access it and modify it
        var segmentationMap = new Tensor<float>(new TensorShape(1, 4, cpuCopyTensor.shape[2], cpuCopyTensor.shape[3]));
        //Find the argmax of the output tensor in dimension 1


        for(int k = 0; k < cpuCopyTensor.shape[2]; k++)
        {
            for(int l = 0; l < cpuCopyTensor.shape[3]; l++)
            {
                float max = -100;
                int maxIndex = 0;
                for(int m = 0; m < cpuCopyTensor.shape[1]; m++)
                {
                    if(cpuCopyTensor[0, m, k, l] > max)
                    {
                        max = cpuCopyTensor[0, m, k, l];
                        maxIndex = m;
                    }
                }
                Color labelColor = SegmentUtils.colors[maxIndex];   
                segmentationMap[0, 0, k, l] = labelColor.r;
                segmentationMap[0, 1, k, l] = labelColor.g;
                segmentationMap[0, 2, k, l] = labelColor.b;
                segmentationMap[0, 3, k, l] = 1;
            }
        }
        

        // Debug.Log(cpuCopyTensor[0, 0, 0, 0]);
        RenderTexture segmentTexture = TextureConverter.ToTexture(segmentationMap);
        inputDisplay.texture = segmentTexture;
        
        // Set this flag to false to run the network again
        m_Started = false;
        cpuCopyTensor.Dispose();
    }

    private void OnDestroy()
    {
        worker.Dispose();
        inputTensor.Dispose();
        outputTexture.Release();
    }
}

Hi YS,
Do you have a link to the model you are using? And we will look into it. Thank you

Viviane

Here you can download the model:
https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512/resolve/b9175de73a0a34f7843135853d27629aa6987b2f/onnx/model.onnx

Thank you,
YS

Hi YS,

SceneSegmentor and SegmentUtils are not defined on my side, so I cannot run your code.

Is there a reason you are not using Functional.ArgMax() ? is it for debugging purposes?

Did you try checking the values of the input and outputs tensors for the worker, to see if they match what you expect?

Regards,
Viviane

Hi Viviane,
Here I’m posting the updated code for SegmentImage and the SegmentsUtils, without having dependency on SceneSegmentor.

ImageSegment.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using UnityEngine.InputSystem;
using UnityEngine.UI;


public class SegmentImage : MonoBehaviour
{
    RenderTexture outputTexture;

    public int k_LayersPerFrame = 2;
    [SerializeField] Texture2D inputImage;
    [SerializeField] ModelAsset modelAsset;
    Model runtimeModel;
    Tensor<float> inputTensor;
    Worker worker;
    bool m_Started = false;

    [SerializeField] RawImage inputDisplay;
    void Start()
    {
        Application.targetFrameRate = 60;
        runtimeModel = ModelLoader.Load(modelAsset);
        // runtimeModel = ModelLoader.Load(Application.streamingAssetsPath + "/model.sentis");

        FunctionalGraph processor = new FunctionalGraph();
        var input = processor.AddInput(runtimeModel, 0);
        // var scaledInput = Functional.Mul(input, Functional.Constant(1/255.0f));
        // // // Debug.Log(scaledInput.shape);
        Tensor<float> imageMean = new Tensor<float>(new TensorShape(1,3,1,1), new float[] {0.485f, 0.456f, 0.406f});
        Tensor<float> imageStd = new Tensor<float>(new TensorShape(1,3,1,1), new float[] {0.229f, 0.224f, 0.225f});
        var shiftedInput = Functional.Sub(input, Functional.Constant(imageMean));
        var normalizedInput = Functional.Div(shiftedInput, Functional.Constant(imageStd));
        var logits = Functional.Forward(runtimeModel, normalizedInput);
        // var segmentation = Functional.ArgMax(logits[0], 1);

        Model pro_RuntimeModel = processor.Compile(logits);
        worker = new Worker(pro_RuntimeModel, BackendType.CPU);

        Texture2D resizedImage = new Texture2D(512, 512);
        inputDisplay.texture = resizedImage;
        Graphics.ConvertTexture(inputImage, resizedImage);
        resizedImage = SegmentUtils.ConvertToRGB(resizedImage);

        inputTensor = TextureConverter.ToTensor(resizedImage);
        Debug.Log("input tensor " +  inputTensor.ReadbackAndClone()[0, 0, 0, 0]);

    }

    IEnumerator executionSchedule;
    private void Update()
    {
        if (Keyboard.current.spaceKey.wasPressedThisFrame || m_Started)
            Segment(); 
    }

    private void Segment()
    {
        if (!m_Started)
        {
            // ExecuteLayerByLayer starts the scheduling of the model
            executionSchedule = worker.ScheduleIterable(inputTensor);
            m_Started = true;
        }

        int it = 0;
        while (executionSchedule.MoveNext())
        {
            if (++it % k_LayersPerFrame == 0)
                return;
        }

        var outputTensor = worker.PeekOutput() as Tensor<float>;
        var cpuCopyTensor = outputTensor.ReadbackAndClone();
        // cpuCopyTensor is a CPU copy of the output tensor. You can access it and modify it
        var segmentationMap = new Tensor<float>(new TensorShape(1, 4, cpuCopyTensor.shape[2], cpuCopyTensor.shape[3]));
        //Find the argmax of the output tensor in dimension 1


        for(int k = 0; k < cpuCopyTensor.shape[2]; k++)
        {
            for(int l = 0; l < cpuCopyTensor.shape[3]; l++)
            {
                float max = -100;
                int maxIndex = 0;
                for(int m = 0; m < cpuCopyTensor.shape[1]; m++)
                {
                    if(cpuCopyTensor[0, m, k, l] > max)
                    {
                        max = cpuCopyTensor[0, m, k, l];
                        maxIndex = m;
                    }
                }
                // // if(k == 0 && l > 0 && l < 10)
                // // {
                // //     Debug.Log("Max value: " + max);
                // //     Debug.Log("Max index: " + maxIndex);
                // // }
                // if(maxIndex != 0)
                // {
                //     Debug.Log("Max value: " + max);
                //     Debug.Log("Max index: " + maxIndex);
                // }
                Color labelColor = SegmentUtils.colors[maxIndex];   
                segmentationMap[0, 0, k, l] = labelColor.r;
                segmentationMap[0, 1, k, l] = labelColor.g;
                segmentationMap[0, 2, k, l] = labelColor.b;
                segmentationMap[0, 3, k, l] = 1;
            }
        }
        

        // Debug.Log(cpuCopyTensor[0, 0, 0, 0]);
        RenderTexture segmentTexture = TextureConverter.ToTexture(segmentationMap);
        inputDisplay.texture = segmentTexture;
        
        // Set this flag to false to run the network again
        m_Started = false;
        cpuCopyTensor.Dispose();
    }

    private void OnDestroy()
    {
        worker.Dispose();
        inputTensor.Dispose();
        outputTexture.Release();
    }
}

SegmentUtils.cs

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class SegmentUtils : MonoBehaviour
{

    public static Texture2D ConvertToRGB(Texture2D rgbaTexture)
    {
        // Check if the texture is null
        if (rgbaTexture == null)
        {
            Debug.LogError("Input texture is null.");
            return null;
        }

        // Create a new RGB texture with the same dimensions
        Texture2D rgbTexture = new Texture2D(rgbaTexture.width, rgbaTexture.height, TextureFormat.RGB24, false);

        // Get the pixel data from the original texture
        Color[] rgbaPixels = rgbaTexture.GetPixels();

        // Prepare an array for RGB pixels
        Color[] rgbPixels = new Color[rgbaPixels.Length];

        // Copy RGB values and ignore the alpha channel
        for (int i = 0; i < rgbaPixels.Length; i++)
        {
            rgbPixels[i] = new Color(rgbaPixels[i].r, rgbaPixels[i].g, rgbaPixels[i].b);
        }

        // Apply the new pixel data to the RGB texture
        rgbTexture.SetPixels(rgbPixels);
        rgbTexture.Apply();

        return rgbTexture;
    }

    public static readonly Dictionary<int, string> id2label = new Dictionary<int, string>
    {
        {0, "wall"},
        {1, "building"},
        {2, "sky"},
        {3, "floor"},
        {4, "tree"},
        {5, "ceiling"},
        {6, "road"},
        {7, "bed"},
        {8, "windowpane"},
        {9, "grass"},
        {10, "cabinet"},
        {11, "sidewalk"},
        {12, "person"},
        {13, "earth"},
        {14, "door"},
        {15, "table"},
        {16, "mountain"},
        {17, "plant"},
        {18, "curtain"},
        {19, "chair"},
        {20, "car"},
        {21, "water"},
        {22, "painting"},
        {23, "sofa"},
        {24, "shelf"},
        {25, "house"},
        {26, "sea"},
        {27, "mirror"},
        {28, "rug"},
        {29, "field"},
        {30, "armchair"},
        {31, "seat"},
        {32, "fence"},
        {33, "desk"},
        {34, "rock"},
        {35, "wardrobe"},
        {36, "lamp"},
        {37, "bathtub"},
        {38, "railing"},
        {39, "cushion"},
        {40, "base"},
        {41, "box"},
        {42, "column"},
        {43, "signboard"},
        {44, "chest of drawers"},
        {45, "counter"},
        {46, "sand"},
        {47, "sink"},
        {48, "skyscraper"},
        {49, "fireplace"},
        {50, "refrigerator"},
        {51, "grandstand"},
        {52, "path"},
        {53, "stairs"},
        {54, "runway"},
        {55, "case"},
        {56, "pool table"},
        {57, "pillow"},
        {58, "screen door"},
        {59, "stairway"},
        {60, "river"},
        {61, "bridge"},
        {62, "bookcase"},
        {63, "blind"},
        {64, "coffee table"},
        {65, "toilet"},
        {66, "flower"},
        {67, "book"},
        {68, "hill"},
        {69, "bench"},
        {70, "countertop"},
        {71, "stove"},
        {72, "palm"},
        {73, "kitchen island"},
        {74, "computer"},
        {75, "swivel chair"},
        {76, "boat"},
        {77, "bar"},
        {78, "arcade machine"},
        {79, "hovel"},
        {80, "bus"},
        {81, "towel"},
        {82, "light"},
        {83, "truck"},
        {84, "tower"},
        {85, "chandelier"},
        {86, "awning"},
        {87, "streetlight"},
        {88, "booth"},
        {89, "television receiver"},
        {90, "airplane"},
        {91, "dirt track"},
        {92, "apparel"},
        {93, "pole"},
        {94, "land"},
        {95, "bannister"},
        {96, "escalator"},
        {97, "ottoman"},
        {98, "bottle"},
        {99, "buffet"},
        {100, "poster"},
        {101, "stage"},
        {102, "van"},
        {103, "ship"},
        {104, "fountain"},
        {105, "conveyer belt"},
        {106, "canopy"},
        {107, "washer"},
        {108, "plaything"},
        {109, "swimming pool"},
        {110, "stool"},
        {111, "barrel"},
        {112, "basket"},
        {113, "waterfall"},
        {114, "tent"},
        {115, "bag"},
        {116, "minibike"},
        {117, "cradle"},
        {118, "oven"},
        {119, "ball"},
        {120, "food"},
        {121, "step"},
        {122, "tank"},
        {123, "trade name"},
        {124, "microwave"},
        {125, "pot"},
        {126, "animal"},
        {127, "bicycle"},
        {128, "lake"},
        {129, "dishwasher"},
        {130, "screen"},
        {131, "blanket"},
        {132, "sculpture"},
        {133, "hood"},
        {134, "sconce"},
        {135, "vase"},
        {136, "traffic light"},
        {137, "tray"},
        {138, "ashcan"},
        {139, "fan"},
        {140, "pier"},
        {141, "crt screen"},
        {142, "plate"},
        {143, "monitor"},
        {144, "bulletin board"},
        {145, "shower"},
        {146, "radiator"},
        {147, "glass"},
        {148, "clock"},
        {149, "flag"}
    };
public static readonly Color[] colors = new Color[]
{
    new Color(120 / 255f, 120 / 255f, 120 / 255f),
    new Color(180 / 255f, 120 / 255f, 120 / 255f),
    new Color(6 / 255f, 230 / 255f, 230 / 255f),
    new Color(80 / 255f, 50 / 255f, 50 / 255f),
    new Color(4 / 255f, 200 / 255f, 3 / 255f),
    new Color(120 / 255f, 120 / 255f, 80 / 255f),
    new Color(140 / 255f, 140 / 255f, 140 / 255f),
    new Color(204 / 255f, 5 / 255f, 255 / 255f),
    new Color(230 / 255f, 230 / 255f, 230 / 255f),
    new Color(4 / 255f, 250 / 255f, 7 / 255f),
    new Color(224 / 255f, 5 / 255f, 255 / 255f),
    new Color(235 / 255f, 255 / 255f, 7 / 255f),
    new Color(150 / 255f, 5 / 255f, 61 / 255f),
    new Color(120 / 255f, 120 / 255f, 70 / 255f),
    new Color(8 / 255f, 255 / 255f, 51 / 255f),
    new Color(255 / 255f, 6 / 255f, 82 / 255f),
    new Color(143 / 255f, 255 / 255f, 140 / 255f),
    new Color(204 / 255f, 255 / 255f, 4 / 255f),
    new Color(255 / 255f, 51 / 255f, 7 / 255f),
    new Color(204 / 255f, 70 / 255f, 3 / 255f),
    new Color(0 / 255f, 102 / 255f, 200 / 255f),
    new Color(61 / 255f, 230 / 255f, 250 / 255f),
    new Color(255 / 255f, 6 / 255f, 51 / 255f),
    new Color(11 / 255f, 102 / 255f, 255 / 255f),
    new Color(255 / 255f, 7 / 255f, 71 / 255f),
    new Color(255 / 255f, 9 / 255f, 224 / 255f),
    new Color(9 / 255f, 7 / 255f, 230 / 255f),
    new Color(220 / 255f, 220 / 255f, 220 / 255f),
    new Color(255 / 255f, 9 / 255f, 92 / 255f),
    new Color(112 / 255f, 9 / 255f, 255 / 255f),
    new Color(8 / 255f, 255 / 255f, 214 / 255f),
    new Color(7 / 255f, 255 / 255f, 224 / 255f),
    new Color(255 / 255f, 184 / 255f, 6 / 255f),
    new Color(10 / 255f, 255 / 255f, 71 / 255f),
    new Color(255 / 255f, 41 / 255f, 10 / 255f),
    new Color(7 / 255f, 255 / 255f, 255 / 255f),
    new Color(224 / 255f, 255 / 255f, 8 / 255f),
    new Color(102 / 255f, 8 / 255f, 255 / 255f),
    new Color(255 / 255f, 61 / 255f, 6 / 255f),
    new Color(255 / 255f, 194 / 255f, 7 / 255f),
    new Color(255 / 255f, 122 / 255f, 8 / 255f),
    new Color(0 / 255f, 255 / 255f, 20 / 255f),
    new Color(255 / 255f, 8 / 255f, 41 / 255f),
    new Color(255 / 255f, 5 / 255f, 153 / 255f),
    new Color(6 / 255f, 51 / 255f, 255 / 255f),
    new Color(235 / 255f, 12 / 255f, 255 / 255f),
    new Color(160 / 255f, 150 / 255f, 20 / 255f),
    new Color(0 / 255f, 163 / 255f, 255 / 255f),
    new Color(140 / 255f, 140 / 255f, 140 / 255f),
    new Color(250 / 255f, 10 / 255f, 15 / 255f),
    new Color(20 / 255f, 255 / 255f, 0 / 255f),
    new Color(31 / 255f, 255 / 255f, 0 / 255f),
    new Color(255 / 255f, 31 / 255f, 0 / 255f),
    new Color(255 / 255f, 224 / 255f, 0 / 255f),
    new Color(153 / 255f, 255 / 255f, 0 / 255f),
    new Color(0 / 255f, 0 / 255f, 255 / 255f),
    new Color(255 / 255f, 71 / 255f, 0 / 255f),
    new Color(0 / 255f, 235 / 255f, 255 / 255f),
    new Color(0 / 255f, 173 / 255f, 255 / 255f),
    new Color(31 / 255f, 0 / 255f, 255 / 255f),
    new Color(11 / 255f, 200 / 255f, 200 / 255f),
    new Color(255 / 255f, 82 / 255f, 0 / 255f),
    new Color(0 / 255f, 255 / 255f, 245 / 255f),
    new Color(0 / 255f, 61 / 255f, 255 / 255f),
    new Color(0 / 255f, 255 / 255f, 112 / 255f),
    new Color(0 / 255f, 255 / 255f, 133 / 255f),
    new Color(255 / 255f, 0 / 255f, 0 / 255f),
    new Color(255 / 255f, 163 / 255f, 0 / 255f),
    new Color(255 / 255f, 102 / 255f, 0 / 255f),
    new Color(194 / 255f, 255 / 255f, 0 / 255f),
    new Color(0 / 255f, 143 / 255f, 255 / 255f),
    new Color(51 / 255f, 255 / 255f, 0 / 255f),
    new Color(0 / 255f, 82 / 255f, 255 / 255f),
    new Color(0 / 255f, 255 / 255f, 41 / 255f),
    new Color(0 / 255f, 255 / 255f, 173 / 255f),
    new Color(10 / 255f, 0 / 255f, 255 / 255f),
    new Color(173 / 255f, 255 / 255f, 0 / 255f),
    new Color(0 / 255f, 255 / 255f, 153 / 255f),
    new Color(255 / 255f, 92 / 255f, 0 / 255f),
    new Color(255 / 255f, 0 / 255f, 255 / 255f),
    new Color(255 / 255f, 0 / 255f, 245 / 255f),
    new Color(255 / 255f, 0 / 255f, 102 / 255f),
    new Color(255 / 255f, 173 / 255f, 0 / 255f),
    new Color(255 / 255f, 0 / 255f, 20 / 255f),
    new Color(255 / 255f, 184 / 255f, 184 / 255f),
    new Color(0 / 255f, 31 / 255f, 255 / 255f),
    new Color(0 / 255f, 255 / 255f, 61 / 255f),
    new Color(0 / 255f, 71 / 255f, 255 / 255f),
    new Color(255 / 255f, 0 / 255f, 204 / 255f),
    new Color(0 / 255f, 255 / 255f, 194 / 255f),
    new Color(0 / 255f, 255 / 255f, 82 / 255f),
    new Color(0 / 255f, 10 / 255f, 255 / 255f),
    new Color(0 / 255f, 112 / 255f, 255 / 255f),
    new Color(51 / 255f, 0 / 255f, 255 / 255f),
    new Color(0 / 255f, 194 / 255f, 255 / 255f),
    new Color(0 / 255f, 122 / 255f, 255 / 255f),
    new Color(0 / 255f, 255 / 255f, 163 / 255f),
    new Color(255 / 255f, 153 / 255f, 0 / 255f),
    new Color(0 / 255f, 255 / 255f, 10 / 255f),
    new Color(255 / 255f, 112 / 255f, 0 / 255f),
    new Color(143 / 255f, 255 / 255f, 0 / 255f),
    new Color(82 / 255f, 0 / 255f, 255 / 255f),
    new Color(163 / 255f, 255 / 255f, 0 / 255f),
    new Color(255 / 255f, 235 / 255f, 0 / 255f),
    new Color(8 / 255f, 184 / 255f, 170 / 255f),
    new Color(133 / 255f, 0 / 255f, 255 / 255f),
    new Color(0 / 255f, 255 / 255f, 92 / 255f),
    new Color(184 / 255f, 0 / 255f, 255 / 255f),
    new Color(255 / 255f, 0 / 255f, 31 / 255f),
    new Color(0 / 255f, 184 / 255f, 255 / 255f),
    new Color(0 / 255f, 214 / 255f, 255 / 255f),
    new Color(255 / 255f, 0 / 255f, 112 / 255f),
    new Color(92 / 255f, 255 / 255f, 0 / 255f),
    new Color(0 / 255f, 224 / 255f, 255 / 255f),
    new Color(112 / 255f, 224 / 255f, 255 / 255f),
    new Color(70 / 255f, 184 / 255f, 160 / 255f),
    new Color(163 / 255f, 0 / 255f, 255 / 255f),
    new Color(153 / 255f, 0 / 255f, 255 / 255f),
    new Color(71 / 255f, 255 / 255f, 0 / 255f),
    new Color(255 / 255f, 0 / 255f, 163 / 255f),
    new Color(255 / 255f, 204 / 255f, 0 / 255f),
    new Color(255 / 255f, 0 / 255f, 143 / 255f),
    new Color(0 / 255f, 255 / 255f, 235 / 255f),
    new Color(133 / 255f, 255 / 255f, 0 / 255f),
    new Color(255 / 255f, 0 / 255f, 235 / 255f),
    new Color(245 / 255f, 0 / 255f, 255 / 255f),
    new Color(255 / 255f, 0 / 255f, 122 / 255f),
    new Color(255 / 255f, 245 / 255f, 0 / 255f),
    new Color(10 / 255f, 190 / 255f, 212 / 255f),
    new Color(214 / 255f, 255 / 255f, 0 / 255f),
    new Color(0 / 255f, 204 / 255f, 255 / 255f),
    new Color(20 / 255f, 0 / 255f, 255 / 255f),
    new Color(255 / 255f, 255 / 255f, 0 / 255f),
    new Color(0 / 255f, 153 / 255f, 255 / 255f),
    new Color(0 / 255f, 41 / 255f, 255 / 255f),
    new Color(0 / 255f, 255 / 255f, 204 / 255f),
    new Color(41 / 255f, 0 / 255f, 255 / 255f),
    new Color(41 / 255f, 255 / 255f, 0 / 255f),
    new Color(173 / 255f, 0 / 255f, 255 / 255f),
    new Color(0 / 255f, 245 / 255f, 255 / 255f),
    new Color(71 / 255f, 0 / 255f, 255 / 255f),
    new Color(122 / 255f, 0 / 255f, 255 / 255f),
    new Color(0 / 255f, 255 / 255f, 184 / 255f),
    new Color(0 / 255f, 92 / 255f, 255 / 255f),
    new Color(184 / 255f, 255 / 255f, 0 / 255f),
    new Color(0 / 255f, 133 / 255f, 255 / 255f),
    new Color(255 / 255f, 214 / 255f, 0 / 255f),
    new Color(25 / 255f, 194 / 255f, 194 / 255f),
    new Color(102 / 255f, 255 / 255f, 0 / 255f),
    new Color(92 / 255f, 0 / 255f, 255 / 255f)
};
}

Sorry for the long code, I couldn’t attach the file.

I tried using Functional.Argmax() but it doesn’t work. Using the following code results in a null output at worker.PeekOutput()

        var segmentation = Functional.ArgMax(logits[0], 1);
        Model pro_RuntimeModel = processor.Compile(segmentation);

I checked the input and it seems to be in the correct range. However, the output tensor has mostly higher values than the onnx output.

Kind regards,
YS

Could you check if the first few values in the input image tensor are exactly what you are expecting? It will help narrow down if the problem is in the model inference or the TextureToTensor call.

For the functional API, did you use PeekOutput(outputName), if so be careful because the output name(s) changes after running the Compile method.

Hi YS,

I looked at your script, and got it to work. Here are some things I changed.

  1. k_LayersPerFrame = 2 seems very low as there are 445 layers in the imported model. So that would take 223 frames to run. To simplify, I used Schedule() to run the whole inference in one frame. It can be distributed on 2 or 3 frames if needed.
  2. Used the RenderTexture outputTexture instead of inputDisplay, for clarity, and added outputMaterial to hold the outputTexture.
  3. TextureConverter.ToTexture() and ToTensor() are deprecated, updated.
  4. Run on BackendType.GPUCompute for performance.
  5. Removed SegmentUtils.ConvertToRGB() as it gave nothing with my inputs.

As for the ArgMax, your code looks correct. Certainly it would be much faster by using Functional.ArgMax. I will look into it when I have time.

Viviane

using UnityEngine;
using Unity.Sentis;
using UnityEngine.InputSystem;


public class SegmentImage : MonoBehaviour
{
    RenderTexture outputTexture;

    [SerializeField] Texture2D inputImage;
    [SerializeField] ModelAsset modelAsset;
    Model runtimeModel;
    Tensor<float> inputTensor;
    Worker worker;

    [SerializeField] Material outputMaterial;
    void Start()
    {
        Application.targetFrameRate = 60;
        runtimeModel = ModelLoader.Load(modelAsset);

        FunctionalGraph processor = new FunctionalGraph();
        var input = processor.AddInput(runtimeModel, 0);

        Tensor<float> imageMean = new Tensor<float>(new TensorShape(1,3,1,1), new float[] {0.485f, 0.456f, 0.406f});
        Tensor<float> imageStd = new Tensor<float>(new TensorShape(1,3,1,1), new float[] {0.229f, 0.224f, 0.225f});
        var shiftedInput = Functional.Sub(input, Functional.Constant(imageMean));
        var normalizedInput = Functional.Div(shiftedInput, Functional.Constant(imageStd));
        var logits = Functional.Forward(runtimeModel, normalizedInput);

        Model pro_RuntimeModel = processor.Compile(logits);
        worker = new Worker(pro_RuntimeModel, BackendType.GPUCompute);

        Texture2D resizedImage = new Texture2D(512, 512);
        //inputDisplay.texture = resizedImage;
        Graphics.ConvertTexture(inputImage, resizedImage);
        //resizedImage = SegmentUtils.ConvertToRGB(resizedImage);

        inputTensor = new Tensor<float>(new TensorShape(1, 3, 512, 512));
        TextureConverter.ToTensor(resizedImage, inputTensor, new TextureTransform());
        //Debug.Log("input tensor " +  inputTensor.ReadbackAndClone()[0, 0, 0, 0]);
        outputTexture = new RenderTexture(512, 512, 0, RenderTextureFormat.ARGBFloat);
    }

    private void Update()
    {
        if (Keyboard.current.spaceKey.wasPressedThisFrame)
            Segment();
    }

    private void Segment()
    {
        worker.Schedule(inputTensor);

        var outputTensor = worker.PeekOutput() as Tensor<float>;
        var cpuCopyTensor = outputTensor.ReadbackAndClone();
        // cpuCopyTensor is a CPU copy of the output tensor. You can access it and modify it
        var segmentationMap = new Tensor<float>(new TensorShape(1, 4, cpuCopyTensor.shape[2], cpuCopyTensor.shape[3]));
        //Find the argmax of the output tensor in dimension 1

        for(int k = 0; k < cpuCopyTensor.shape[2]; k++)
        {
            for(int l = 0; l < cpuCopyTensor.shape[3]; l++)
            {
                float max = -100;
                int maxIndex = 0;
                for(int m = 0; m < cpuCopyTensor.shape[1]; m++)
                {
                    if(cpuCopyTensor[0, m, k, l] > max)
                    {
                        max = cpuCopyTensor[0, m, k, l];
                        maxIndex = m;
                    }
                }

                Color labelColor = SegmentUtils.colors[maxIndex];
                segmentationMap[0, 0, k, l] = labelColor.r;
                segmentationMap[0, 1, k, l] = labelColor.g;
                segmentationMap[0, 2, k, l] = labelColor.b;
                segmentationMap[0, 3, k, l] = 1;
            }
        }

        TextureConverter.RenderToTexture(segmentationMap, outputTexture, new TextureTransform());
        outputMaterial.mainTexture = outputTexture;

        // Set this flag to false to run the network again
        cpuCopyTensor.Dispose();
    }

    private void OnDestroy()
    {
        worker.Dispose();
        inputTensor.Dispose();
        outputTexture.Release();
    }
}

Thank you, Viviane! I solved the ArgMax issue.

Cheers,
YS

1 Like