Getting Distorted Mesh,When Passing UI.Image.Sprite's Mesh Data to a Compute Shader

Hi everyone,
I’m relatively new to compute shaders but have experience with graphics shaders, so I’m trying to figure out what I might be missing.

I’m working on a project where I want to manipulate the mesh of a UI.Image component in Unity using a compute shader. The goal is to create a lattice modifier-like system that can tessellate and deform the mesh dynamically based on a lattice. However, I’m currently stuck at the first step.

Here’s what I’m trying to do:

  1. Extract the MeshData from the Image.sprite.
  2. Send this data to a compute buffer for processing via a compute shader.
  3. Retrieve the processed mesh data from the buffer.
  4. Set the modified mesh back to the UI.Image component using the sprite’s mesh.

The issue is that the mesh I get back after processing is distorted and crumbled. There are no errors or warnings, so I’m struggling to debug the problem.

If anyone could provide guidance, share sample code, or point me toward relevant resources, I’d greatly appreciate it!
Attaching my code snippet below.
Thank you in advance for your help.

CODE:
/////////////////////////////////////////////
My Compute Shader :
/////////////////////////////////////////////
#pragma kernel Main

float3 GetNormalsFromVertices(in float3 a,in float3 b,in float3 c)
{
    return normalize(cross(b - a,c - a));
}
struct appdata 
{
    float3 vertex;
    float2 uv;
    float extraData;
};
struct v2f 
{
    float3 vertex;
    float2 uv;
    float extraData;
};

struct DrawTriangle 
{
    float3 normalWS;
    v2f vertices[3];     
};

StructuredBuffer<appdata> _SourceVertexBuffer;
StructuredBuffer<int> _SourceIndexBuffer;
AppendStructuredBuffer<DrawTriangle> _DrawTrianglesBuffer; 

int _numSourceTriangles;
float3 _offsetPos;
float4x4 _localToWorldMatrix;

v2f TransformToWorldSpace(appdata v)
{
    v2f o = (v2f)0;
    o.vertex = mul(_localToWorldMatrix,float4(v.vertex,1)).xyz + _offsetPos;
    //o.vertex += _offsetPos;
    o.uv = v.uv;
    o.extraData = v.extraData;
    return o;
}

[numthreads(128,1,1)]
void Main (uint3 id : SV_DispatchThreadID)
{
    if((int)id.x >= _numSourceTriangles)
    {
        return;
    }

    int indexStart = id.x * 3; 
    //because 3 vertices make up a triangle,so to this we will offset to get each vertex values form buffer 
    v2f inputs[3];
    //Retrieve data from Buffer 
    inputs[0] = TransformToWorldSpace(_SourceVertexBuffer[_SourceIndexBuffer[indexStart]]); 
    inputs[1] = TransformToWorldSpace(_SourceVertexBuffer[_SourceIndexBuffer[indexStart + 1]]);
    inputs[2] = TransformToWorldSpace(_SourceVertexBuffer[_SourceIndexBuffer[indexStart + 2]]);
    
    DrawTriangle i;
    i.normalWS = GetNormalsFromVertices(inputs[0].vertex,inputs[1].vertex,inputs[2].vertex);    
    i.vertices[0] = inputs[0];
    i.vertices[1] = inputs[1];
    i.vertices[2] = inputs[2];

    _DrawTrianglesBuffer.Append(i);
}
/////////////////////////////////////////////
My Script:
/////////////////////////////////////////////
using Unity.VisualScripting;
using UnityEngine;
using UnityEngine.UI;

public class Lattice2DSpriteScript : MonoBehaviour
{    
    [SerializeField] private ComputeShader _shader;
    [SerializeField] private Vector3 _offsetXY;

    private bool Init;
    private ComputeBuffer _SourceVertexBuffer;
    private ComputeBuffer _SourceIndexBuffer;
    private ComputeBuffer _DrawTrianglesBuffer;    

    private int kernelIndex;
    private int dispatchSize;        

    [System.Runtime.InteropServices.StructLayout(System.Runtime.InteropServices.LayoutKind.Sequential)]
    private struct appdata
    {
        public Vector3 vertex;//12 bytes
        public Vector2 uv; //8 bytes    
        private float extraData; //extra padding because it was thrwing error 
    }

    private const int SOURCE_VERTEX_STRIDE = (sizeof(float) * 3) + (sizeof(float) * 2) + sizeof(float);  
    private const int SOURCE_TRIANGLE_STRIDE = sizeof(int); 
    private const int DRAW_STRIDE = sizeof(float) * 3 + ((sizeof(float) * 3 + sizeof(float) * 2) * 3); // Normal + 3 vertices * (position + uv)
    //sizeof(float) is 4 bytes

    private static Mesh ComposeMesh(appdata[] generatedVertices, int[] indices)
    {
        Mesh mesh = new Mesh();
        Vector3[] position = new Vector3[generatedVertices.Length];
        Vector2[] uvs = new Vector2[generatedVertices.Length];

        for (int i = 0; i < generatedVertices.Length; i++)
        {
            var v = generatedVertices[i];
            position[i] = v.vertex;
            uvs[i] = v.uv;
        }

        mesh.SetVertices(position);
        mesh.SetUVs(0, uvs);        
        mesh.SetIndices(indices, MeshTopology.Triangles, 0, true);
        mesh.Optimize();
        return mesh;
    }

    private void OnEnable()
    {
        if (Init)
        {
            OnDisable();
        }
        Init = true;

        if (GetComponent<Image>() == null)
        {
            Debug.LogError("No Image Component Attached");
            return;
        }
        
        InitializeBuffers();
    }

    private void InitializeBuffers()
    {
        var sprite = GetComponent<Image>().sprite;
        Vector2[] position = sprite.vertices;
        Vector2[] uv = sprite.uv;
        int[] index = new int[sprite.triangles.Length];
        for (int i = 0; i < sprite.triangles.Length; i++)
        {
            index[i] = sprite.triangles[i]; // This is done to convert ushort to int for Mesh API
        }

        appdata[] vertex = new appdata[position.Length];
        for (int i = 0; i < vertex.Length; i++)
        {
            vertex[i] = new appdata()
            {
                vertex = position[i],
                uv = uv[i],
            };
        }

        int numTriangles = index.Length / 3;

        _SourceVertexBuffer = new ComputeBuffer(position.Length,SOURCE_VERTEX_STRIDE, ComputeBufferType.Structured, ComputeBufferMode.Immutable);
        _SourceVertexBuffer.SetData(vertex);

        _SourceIndexBuffer = new ComputeBuffer(index.Length,SOURCE_TRIANGLE_STRIDE, ComputeBufferType.Structured, ComputeBufferMode.Immutable);
        _SourceIndexBuffer.SetData(index);

        _DrawTrianglesBuffer = new ComputeBuffer(numTriangles * 3, DRAW_STRIDE, ComputeBufferType.Append);
        _DrawTrianglesBuffer.SetCounterValue(0);

        kernelIndex = _shader.FindKernel("Main");
        _shader.SetBuffer(kernelIndex, "_SourceVertexBuffer", _SourceVertexBuffer);
        _shader.SetBuffer(kernelIndex, "_SourceIndexBuffer", _SourceIndexBuffer);
        _shader.SetBuffer(kernelIndex, "_DrawTrianglesBuffer", _DrawTrianglesBuffer);
        _shader.SetInt("_numSourceTriangles", numTriangles);
            
        _shader.GetKernelThreadGroupSizes(kernelIndex, out uint threadGroupSize, out _, out _);
        dispatchSize = Mathf.CeilToInt((float)numTriangles / threadGroupSize);
    }

    private void LateUpdate()
    {
        _DrawTrianglesBuffer.SetCounterValue(0);

        _shader.SetMatrix("_localToWorldMatrix", transform.localToWorldMatrix);
        _shader.SetVector("_offsetPos", _offsetXY);
        GetComponent<Image>().material.SetBuffer("_DrawTrianglesBuffer",_DrawTrianglesBuffer);

        _shader.Dispatch(kernelIndex, dispatchSize, 1, 1);

        //Reading back from the buffer 
        int numTriangles = _DrawTrianglesBuffer.count / 3;
        var generatedVertices = new appdata[numTriangles * 3];
        _DrawTrianglesBuffer.GetData(generatedVertices);

        int[] indices = new int[numTriangles * 3];
        for (int i = 0; i < numTriangles; i++)
        {
            indices[i * 3] = (i * 3);
            indices[i * 3 + 1] = (i * 3 + 1);
            indices[i * 3 + 2] = (i * 3 + 2);
        }
        Mesh mesh = ComposeMesh(generatedVertices, indices);
        GetComponent<Image>().canvasRenderer.SetMesh(mesh);

        Canvas.ForceUpdateCanvases();
    }
    private void OnDisable()
    {
        if (Init)
        {
            _SourceVertexBuffer.Release();
            _SourceIndexBuffer.Release();
            _DrawTrianglesBuffer.Release();   

            _SourceVertexBuffer = null;
            _SourceIndexBuffer = null;
            _DrawTrianglesBuffer = null;
        }
        Init = false;
    }
}
/////////////////////////////////////////////
My Shader:
/////////////////////////////////////////////
Shader "Hidden/Lattice2DShader"
{
    Properties
    {
        _Color("Color",color) = (1,1,1,1)
        _MainTex ("Texture", 2D) = "white" {}


        [Header(Stencil Operations)][Space(5)]
        _StencilComp ("Stencil Comparison", Float) = 8
        _Stencil ("Stencil ID", Float) = 0
        _StencilOp ("Stencil Operation", Float) = 0
        _StencilWriteMask ("Stencil Write Mask", Float) = 255
        _StencilReadMask ("Stencil Read Mask", Float) = 255
        _ColorMask ("Color Mask", Float) = 15
    }
    CGINCLUDE                       
        // Upgrade NOTE: excluded shader from DX11; has structs without semantics (struct v2f members positionWS,uv)
        #pragma exclude_renderers d3d11
        struct appdata 
        {
            float3 vertex : POSITION;
            float2 uv : TEXCOORD0;
        };
        struct v2f {
            float4 vertex : SV_POSITION; // Position in clip space            
            float3 normalWS : TEXCOORD0; // Normal vector in world space
            float2 uv : TEXCOORD1; // UVs
        }; 
        struct DrawTriangle //struct that will be outputed
        {
            float3 normalWS;
            appdata vertices[3]; //since each triangle needs 3 vertex but only 1 normal because all vertices will share it
        };       
        StructuredBuffer<DrawTriangle> _DrawTrianglesBuffer;                       
    ENDCG

        SubShader{
            Tags { 
                "RenderType" = "Transparent"
                "Queue" = "Transparent"
                "IgnoreProjector" = "True"
                "PreviewType" = "Plane" 
            }
        Stencil
        {
            Ref [_Stencil]
            Comp [_StencilComp]
            Pass [_StencilOp]
            ReadMask [_StencilReadMask]
            WriteMask [_StencilWriteMask]
        }

        ColorMask [_ColorMask]
        Cull Off Lighting Off ZWrite Off ZTest Always
        Blend SrcAlpha OneMinusSrcAlpha    
        
        Pass {
            Name "ForwardLit"
            CGPROGRAM   
            #pragma vertex vert
            #pragma fragment frag                          
            #pragma target 5.0    
            #pragma multi_compile __ UNITY_UI_CLIP_RECT                                             
            
            #include "UnityCG.cginc"
            #include "UnityUI.cginc"

            sampler2D _MainTex; 
            float4 _MainTex_ST;

            v2f vert(uint vertexID: SV_VertexID) 
            {            
                v2f o = (v2f)0;                                
                DrawTriangle readTri = _DrawTrianglesBuffer[vertexID / 3];
                appdata readVertices = readTri.vertices[vertexID % 3];
            
                o.vertex = UnityObjectToClipPos(readVertices.vertex);                
                o.normalWS = readTri.normalWS;
                o.uv = TRANSFORM_TEX(readVertices.uv, _MainTex);                                       
                return o;
            }
                                                
            float4 frag(v2f input) : SV_Target {
                float3 albedo = tex2D(_MainTex,input.uv).rgb;                                
                return float4(albedo,1);            
            } 
            ENDCG
        }
    }
}