Hi everyone,
I’m relatively new to compute shaders but have experience with graphics shaders, so I’m trying to figure out what I might be missing.
I’m working on a project where I want to manipulate the mesh of a UI.Image
component in Unity using a compute shader. The goal is to create a lattice modifier-like system that can tessellate and deform the mesh dynamically based on a lattice. However, I’m currently stuck at the first step.
Here’s what I’m trying to do:
- Extract the
MeshData
from theImage.sprite
. - Send this data to a compute buffer for processing via a compute shader.
- Retrieve the processed mesh data from the buffer.
- Set the modified mesh back to the
UI.Image
component using the sprite’s mesh.
The issue is that the mesh I get back after processing is distorted and crumbled. There are no errors or warnings, so I’m struggling to debug the problem.
If anyone could provide guidance, share sample code, or point me toward relevant resources, I’d greatly appreciate it!
Attaching my code snippet below.
Thank you in advance for your help.
CODE:
/////////////////////////////////////////////
My Compute Shader :
/////////////////////////////////////////////
#pragma kernel Main
float3 GetNormalsFromVertices(in float3 a,in float3 b,in float3 c)
{
return normalize(cross(b - a,c - a));
}
struct appdata
{
float3 vertex;
float2 uv;
float extraData;
};
struct v2f
{
float3 vertex;
float2 uv;
float extraData;
};
struct DrawTriangle
{
float3 normalWS;
v2f vertices[3];
};
StructuredBuffer<appdata> _SourceVertexBuffer;
StructuredBuffer<int> _SourceIndexBuffer;
AppendStructuredBuffer<DrawTriangle> _DrawTrianglesBuffer;
int _numSourceTriangles;
float3 _offsetPos;
float4x4 _localToWorldMatrix;
v2f TransformToWorldSpace(appdata v)
{
v2f o = (v2f)0;
o.vertex = mul(_localToWorldMatrix,float4(v.vertex,1)).xyz + _offsetPos;
//o.vertex += _offsetPos;
o.uv = v.uv;
o.extraData = v.extraData;
return o;
}
[numthreads(128,1,1)]
void Main (uint3 id : SV_DispatchThreadID)
{
if((int)id.x >= _numSourceTriangles)
{
return;
}
int indexStart = id.x * 3;
//because 3 vertices make up a triangle,so to this we will offset to get each vertex values form buffer
v2f inputs[3];
//Retrieve data from Buffer
inputs[0] = TransformToWorldSpace(_SourceVertexBuffer[_SourceIndexBuffer[indexStart]]);
inputs[1] = TransformToWorldSpace(_SourceVertexBuffer[_SourceIndexBuffer[indexStart + 1]]);
inputs[2] = TransformToWorldSpace(_SourceVertexBuffer[_SourceIndexBuffer[indexStart + 2]]);
DrawTriangle i;
i.normalWS = GetNormalsFromVertices(inputs[0].vertex,inputs[1].vertex,inputs[2].vertex);
i.vertices[0] = inputs[0];
i.vertices[1] = inputs[1];
i.vertices[2] = inputs[2];
_DrawTrianglesBuffer.Append(i);
}
/////////////////////////////////////////////
My Script:
/////////////////////////////////////////////
using Unity.VisualScripting;
using UnityEngine;
using UnityEngine.UI;
public class Lattice2DSpriteScript : MonoBehaviour
{
[SerializeField] private ComputeShader _shader;
[SerializeField] private Vector3 _offsetXY;
private bool Init;
private ComputeBuffer _SourceVertexBuffer;
private ComputeBuffer _SourceIndexBuffer;
private ComputeBuffer _DrawTrianglesBuffer;
private int kernelIndex;
private int dispatchSize;
[System.Runtime.InteropServices.StructLayout(System.Runtime.InteropServices.LayoutKind.Sequential)]
private struct appdata
{
public Vector3 vertex;//12 bytes
public Vector2 uv; //8 bytes
private float extraData; //extra padding because it was thrwing error
}
private const int SOURCE_VERTEX_STRIDE = (sizeof(float) * 3) + (sizeof(float) * 2) + sizeof(float);
private const int SOURCE_TRIANGLE_STRIDE = sizeof(int);
private const int DRAW_STRIDE = sizeof(float) * 3 + ((sizeof(float) * 3 + sizeof(float) * 2) * 3); // Normal + 3 vertices * (position + uv)
//sizeof(float) is 4 bytes
private static Mesh ComposeMesh(appdata[] generatedVertices, int[] indices)
{
Mesh mesh = new Mesh();
Vector3[] position = new Vector3[generatedVertices.Length];
Vector2[] uvs = new Vector2[generatedVertices.Length];
for (int i = 0; i < generatedVertices.Length; i++)
{
var v = generatedVertices[i];
position[i] = v.vertex;
uvs[i] = v.uv;
}
mesh.SetVertices(position);
mesh.SetUVs(0, uvs);
mesh.SetIndices(indices, MeshTopology.Triangles, 0, true);
mesh.Optimize();
return mesh;
}
private void OnEnable()
{
if (Init)
{
OnDisable();
}
Init = true;
if (GetComponent<Image>() == null)
{
Debug.LogError("No Image Component Attached");
return;
}
InitializeBuffers();
}
private void InitializeBuffers()
{
var sprite = GetComponent<Image>().sprite;
Vector2[] position = sprite.vertices;
Vector2[] uv = sprite.uv;
int[] index = new int[sprite.triangles.Length];
for (int i = 0; i < sprite.triangles.Length; i++)
{
index[i] = sprite.triangles[i]; // This is done to convert ushort to int for Mesh API
}
appdata[] vertex = new appdata[position.Length];
for (int i = 0; i < vertex.Length; i++)
{
vertex[i] = new appdata()
{
vertex = position[i],
uv = uv[i],
};
}
int numTriangles = index.Length / 3;
_SourceVertexBuffer = new ComputeBuffer(position.Length,SOURCE_VERTEX_STRIDE, ComputeBufferType.Structured, ComputeBufferMode.Immutable);
_SourceVertexBuffer.SetData(vertex);
_SourceIndexBuffer = new ComputeBuffer(index.Length,SOURCE_TRIANGLE_STRIDE, ComputeBufferType.Structured, ComputeBufferMode.Immutable);
_SourceIndexBuffer.SetData(index);
_DrawTrianglesBuffer = new ComputeBuffer(numTriangles * 3, DRAW_STRIDE, ComputeBufferType.Append);
_DrawTrianglesBuffer.SetCounterValue(0);
kernelIndex = _shader.FindKernel("Main");
_shader.SetBuffer(kernelIndex, "_SourceVertexBuffer", _SourceVertexBuffer);
_shader.SetBuffer(kernelIndex, "_SourceIndexBuffer", _SourceIndexBuffer);
_shader.SetBuffer(kernelIndex, "_DrawTrianglesBuffer", _DrawTrianglesBuffer);
_shader.SetInt("_numSourceTriangles", numTriangles);
_shader.GetKernelThreadGroupSizes(kernelIndex, out uint threadGroupSize, out _, out _);
dispatchSize = Mathf.CeilToInt((float)numTriangles / threadGroupSize);
}
private void LateUpdate()
{
_DrawTrianglesBuffer.SetCounterValue(0);
_shader.SetMatrix("_localToWorldMatrix", transform.localToWorldMatrix);
_shader.SetVector("_offsetPos", _offsetXY);
GetComponent<Image>().material.SetBuffer("_DrawTrianglesBuffer",_DrawTrianglesBuffer);
_shader.Dispatch(kernelIndex, dispatchSize, 1, 1);
//Reading back from the buffer
int numTriangles = _DrawTrianglesBuffer.count / 3;
var generatedVertices = new appdata[numTriangles * 3];
_DrawTrianglesBuffer.GetData(generatedVertices);
int[] indices = new int[numTriangles * 3];
for (int i = 0; i < numTriangles; i++)
{
indices[i * 3] = (i * 3);
indices[i * 3 + 1] = (i * 3 + 1);
indices[i * 3 + 2] = (i * 3 + 2);
}
Mesh mesh = ComposeMesh(generatedVertices, indices);
GetComponent<Image>().canvasRenderer.SetMesh(mesh);
Canvas.ForceUpdateCanvases();
}
private void OnDisable()
{
if (Init)
{
_SourceVertexBuffer.Release();
_SourceIndexBuffer.Release();
_DrawTrianglesBuffer.Release();
_SourceVertexBuffer = null;
_SourceIndexBuffer = null;
_DrawTrianglesBuffer = null;
}
Init = false;
}
}
/////////////////////////////////////////////
My Shader:
/////////////////////////////////////////////
Shader "Hidden/Lattice2DShader"
{
Properties
{
_Color("Color",color) = (1,1,1,1)
_MainTex ("Texture", 2D) = "white" {}
[Header(Stencil Operations)][Space(5)]
_StencilComp ("Stencil Comparison", Float) = 8
_Stencil ("Stencil ID", Float) = 0
_StencilOp ("Stencil Operation", Float) = 0
_StencilWriteMask ("Stencil Write Mask", Float) = 255
_StencilReadMask ("Stencil Read Mask", Float) = 255
_ColorMask ("Color Mask", Float) = 15
}
CGINCLUDE
// Upgrade NOTE: excluded shader from DX11; has structs without semantics (struct v2f members positionWS,uv)
#pragma exclude_renderers d3d11
struct appdata
{
float3 vertex : POSITION;
float2 uv : TEXCOORD0;
};
struct v2f {
float4 vertex : SV_POSITION; // Position in clip space
float3 normalWS : TEXCOORD0; // Normal vector in world space
float2 uv : TEXCOORD1; // UVs
};
struct DrawTriangle //struct that will be outputed
{
float3 normalWS;
appdata vertices[3]; //since each triangle needs 3 vertex but only 1 normal because all vertices will share it
};
StructuredBuffer<DrawTriangle> _DrawTrianglesBuffer;
ENDCG
SubShader{
Tags {
"RenderType" = "Transparent"
"Queue" = "Transparent"
"IgnoreProjector" = "True"
"PreviewType" = "Plane"
}
Stencil
{
Ref [_Stencil]
Comp [_StencilComp]
Pass [_StencilOp]
ReadMask [_StencilReadMask]
WriteMask [_StencilWriteMask]
}
ColorMask [_ColorMask]
Cull Off Lighting Off ZWrite Off ZTest Always
Blend SrcAlpha OneMinusSrcAlpha
Pass {
Name "ForwardLit"
CGPROGRAM
#pragma vertex vert
#pragma fragment frag
#pragma target 5.0
#pragma multi_compile __ UNITY_UI_CLIP_RECT
#include "UnityCG.cginc"
#include "UnityUI.cginc"
sampler2D _MainTex;
float4 _MainTex_ST;
v2f vert(uint vertexID: SV_VertexID)
{
v2f o = (v2f)0;
DrawTriangle readTri = _DrawTrianglesBuffer[vertexID / 3];
appdata readVertices = readTri.vertices[vertexID % 3];
o.vertex = UnityObjectToClipPos(readVertices.vertex);
o.normalWS = readTri.normalWS;
o.uv = TRANSFORM_TEX(readVertices.uv, _MainTex);
return o;
}
float4 frag(v2f input) : SV_Target {
float3 albedo = tex2D(_MainTex,input.uv).rgb;
return float4(albedo,1);
}
ENDCG
}
}
}