Compiling long and complex compute shaders

I have written a compute shader that is approximately 350 lines long (including all #includes), and now I’m trying to load it in the Unity Editor.
Here is the code:
Common.cginc

#ifndef COMMON_CGINC
#define COMMON_CGINC

// A 2D triangle whose edges can individually be flagged as "borders".
// Edge i runs from v[i] to v[(i + 1) % 3]; bit i of `borders` is set when that
// edge is a border (see SetBorder / HasBorder).
struct BorderedTriangle2D
{
    float2 v[3];   // vertices
    uint borders;  // per-edge border flags, bits 0..2
};

// Sets (value == true) or clears (value == false) bit `destIndex` of `dest`,
// leaving every other bit untouched.
void SetBorder(inout uint dest, uint destIndex, bool value)
{
    uint mask = 1 << destIndex;
    dest = value ? (dest | mask) : (dest & ~mask);
}

// Returns true when bit `srcIndex` of `src` is set.
bool HasBorder(in uint src, uint srcIndex)
{
    return ((src >> srcIndex) & 1) != 0;
}

// A 3D triangle (three vertex positions). It may be stored in structured
// buffers, so the field layout is part of the contract with the consumer.
struct Triangle
{
    float3 v[3];
};

// Returns the point where the line through `from` and `to` crosses the
// vertical line at `x`. The x component is returned exactly as `x`; only y is
// interpolated. from.x == to.x divides by zero (non-finite y).
float2 IntersectX(float2 from, float2 to, float x)
{
    float t = (x - from.x) / (to.x - from.x);
    float y = lerp(from.y, to.y, t);
    return float2(x, y);
}

// Returns the point where the line through `from` and `to` crosses the
// horizontal line at `y`. The y component is returned exactly as `y`; only x
// is interpolated. from.y == to.y divides by zero (non-finite x).
float2 IntersectY(float2 from, float2 to, float y)
{
    float t = (y - from.y) / (to.y - from.y);
    float x = lerp(from.x, to.x, t);
    return float2(x, y);
}

// Rotates `v` by `angle` radians using the standard 2D rotation matrix
// [c -s; s c] (counter-clockwise for positive angles in a y-up frame).
// The locals are named s/c instead of sin/cos so they do not shadow the HLSL
// intrinsics of the same name.
float2 Turn(float2 v, float angle)
{
    float s, c;
    sincos(angle, s, c);
    return float2(v.x * c - v.y * s, v.x * s + v.y * c);
}

// Must be `static const`: a non-static global in HLSL is implicitly `uniform`,
// so it lives in the $Globals constant buffer, its initializer is not applied,
// and PI would read whatever the host bound there (0 in Unity).
static const float PI = 3.14159265f;
#endif

BorderedTriangle_ClipShiftSpatialize.cginc

#ifndef BORDEREDTRIANGLE_CLIPSHIFTSPATIALIZE_CGINC
#define BORDEREDTRIANGLE_CLIPSHIFTSPATIALIZE_CGINC

#include "Common.cginc"

// Clips `inp` against the vertical line at `x`, keeping the half-plane
// x >= `x` when `positive` is true and x <= `x` otherwise. Vertices lying
// exactly on the line count as kept.
// The kept part is returned in res[0 .. number - 1]:
//   - no vertex kept   -> number = 0 (res is not written);
//   - all 3 kept       -> res[0] = inp, number = 1;
//   - 1 vertex kept    -> one smaller triangle, number = 1;
//   - 2 vertices kept  -> the kept quad split into two triangles, number = 2.
// The input's vertex order (winding) is preserved. Edges that lie on the
// clip line are flagged as borders, edges lying on an original edge inherit
// that edge's border bit, and the quad's internal diagonal is not a border.
// Callers must only read res[k] for k < number.
void SplitX(BorderedTriangle2D inp, float x, bool positive, out BorderedTriangle2D res[2], out uint number)
{
    // Count the vertices on the kept side.
    uint cntUpper = 0;
    for (uint i = 0; i < 3; i++)
        if (positive ? inp.v[i].x >= x : inp.v[i].x <= x)
            cntUpper++;
    uint idx;
    float2 left, right, cur;
    switch (cntUpper)
    {
        case 0:
            number = 0;
            break;
        case 3:
            res[0] = inp;
            number = 1;
            break;
        case 1:
            // idx is the only kept vertex; cut both of its adjacent edges.
            idx = positive ? (inp.v[0].x >= x ? 0 : inp.v[1].x >= x ? 1 : 2) : (inp.v[0].x <= x ? 0 : inp.v[1].x <= x ? 1 : 2);
            left = inp.v[(idx + 2) % 3];
            right = inp.v[(idx + 1) % 3];
            cur = inp.v[idx];
            res[0].v[0] = IntersectX(left, cur, x);
            res[0].v[1] = cur;
            res[0].v[2] = IntersectX(right, cur, x);
            res[0].borders = 4; // edge 2 (v[2] -> v[0]) lies on the clip line
            SetBorder(res[0].borders, 0, HasBorder(inp.borders, (idx + 2) % 3)); // part of original edge left -> cur
            SetBorder(res[0].borders, 1, HasBorder(inp.borders, idx));           // part of original edge cur -> right
            number = 1;
            break;
        case 2:
            // idx is the only vertex that is cut away; the kept region is the
            // quad (left, cur-left cut, cur-right cut, right).
            idx = positive ? (inp.v[0].x < x ? 0 : inp.v[1].x < x ? 1 : 2) : (inp.v[0].x > x ? 0 : inp.v[1].x > x ? 1 : 2);
            left = inp.v[(idx + 2) % 3];
            right = inp.v[(idx + 1) % 3];
            cur = inp.v[idx];
            float2 curRight = IntersectX(cur, right, x);
            res[0].v[0] = left;
            res[0].v[1] = IntersectX(cur, left, x);
            res[0].v[2] = curRight;
            res[0].borders = 2; // edge 1 lies on the clip line; edge 2 is the internal diagonal
            SetBorder(res[0].borders, 0, HasBorder(inp.borders, (idx + 2) % 3));
            res[1].v[0] = left;
            res[1].v[1] = curRight;
            res[1].v[2] = right;
            res[1].borders = 0; // edge 0 is the internal diagonal
            SetBorder(res[1].borders, 1, HasBorder(inp.borders, idx));
            SetBorder(res[1].borders, 2, HasBorder(inp.borders, (idx + 1) % 3));
            number = 2;
            break;
    }
}

// Clips `inp` against the horizontal line at `y`, keeping the half-plane
// y >= `y` when `positive` is true and y <= `y` otherwise. Same contract as
// SplitX with the roles of x and y exchanged:
//   - no vertex kept   -> number = 0 (res is not written);
//   - all 3 kept       -> res[0] = inp, number = 1;
//   - 1 vertex kept    -> one smaller triangle, number = 1;
//   - 2 vertices kept  -> the kept quad split into two triangles, number = 2.
// Winding is preserved; edges on the clip line become borders, edges on an
// original edge inherit its border bit, the quad's diagonal is not a border.
// Callers must only read res[k] for k < number.
void SplitY(BorderedTriangle2D inp, float y, bool positive, out BorderedTriangle2D res[2], out uint number)
{
    // Count the vertices on the kept side.
    uint cntUpper = 0;
    for (uint i = 0; i < 3; i++)
        if (positive ? inp.v[i].y >= y : inp.v[i].y <= y)
            cntUpper++;
    uint idx;
    float2 left, right, cur;
    switch (cntUpper)
    {
        case 0:
            number = 0;
            break;
        case 3:
            res[0] = inp;
            number = 1;
            break;
        case 1:
            // idx is the only kept vertex; cut both of its adjacent edges.
            idx = positive ? (inp.v[0].y >= y ? 0 : inp.v[1].y >= y ? 1 : 2) : (inp.v[0].y <= y ? 0 : inp.v[1].y <= y ? 1 : 2);
            left = inp.v[(idx + 2) % 3];
            right = inp.v[(idx + 1) % 3];
            cur = inp.v[idx];
            res[0].v[0] = IntersectY(left, cur, y);
            res[0].v[1] = cur;
            res[0].v[2] = IntersectY(right, cur, y);
            res[0].borders = 4; // edge 2 (v[2] -> v[0]) lies on the clip line
            SetBorder(res[0].borders, 0, HasBorder(inp.borders, (idx + 2) % 3)); // part of original edge left -> cur
            SetBorder(res[0].borders, 1, HasBorder(inp.borders, idx));           // part of original edge cur -> right
            number = 1;
            break;
        case 2:
            // idx is the only vertex that is cut away; the kept region is the
            // quad (left, cur-left cut, cur-right cut, right).
            idx = positive ? (inp.v[0].y < y ? 0 : inp.v[1].y < y ? 1 : 2) : (inp.v[0].y > y ? 0 : inp.v[1].y > y ? 1 : 2);
            left = inp.v[(idx + 2) % 3];
            right = inp.v[(idx + 1) % 3];
            cur = inp.v[idx];
            float2 curRight = IntersectY(cur, right, y);
            res[0].v[0] = left;
            res[0].v[1] = IntersectY(cur, left, y);
            res[0].v[2] = curRight;
            res[0].borders = 2; // edge 1 lies on the clip line; edge 2 is the internal diagonal
            SetBorder(res[0].borders, 0, HasBorder(inp.borders, (idx + 2) % 3));
            res[1].v[0] = left;
            res[1].v[1] = curRight;
            res[1].v[2] = right;
            res[1].borders = 0; // edge 0 is the internal diagonal
            SetBorder(res[1].borders, 1, HasBorder(inp.borders, idx));
            SetBorder(res[1].borders, 2, HasBorder(inp.borders, (idx + 1) % 3));
            number = 2;
            break;
    }
}

// Clips `current` against `clipRect` and extrudes the surviving piece into a
// prism, handing every resulting 3D triangle to
// BorderedTriangle_ClipShiftSpatialize_Submit:
//   - a cap at z = shift.z and a cap at z = shift.z + shift.w (the latter with
//     reversed vertex order), both translated by shift.xy;
//   - for every edge whose border bit is set, a wall quad between the two caps
//     (emitted with both windings, see the TODO below).
//
// clipRect is (xMin, yMin, xMax, yMax). Clipping runs in 4 passes
// (x >= xMin, y >= yMin, x <= xMax, y <= yMax). Each pass may split the
// triangle into two pieces, and bit `pass` of `threadId` picks which piece this
// invocation follows. An invocation whose piece does not exist, or whose piece
// is degenerate (two coincident vertices), emits nothing.
void BorderedTriangle_ClipShiftSpatialize_Shader (BorderedTriangle2D current, uint threadId, float4 clipRect, float4 shift, BorderedTriangle_ClipShiftSpatialize_UserData userData)
{
    // `pass` and `edge` are deliberately distinct names: FXC scopes for-loop
    // variables to the enclosing block, so reusing `i` for both loops
    // triggers warning X3078.
    for (uint pass = 0; pass < 4; pass++)
    {
        BorderedTriangle2D buf[2];
        uint number;
        if (pass % 2 == 0)
            SplitX(current, pass == 0 ? clipRect.x : clipRect.z, pass == 0, buf, number);
        else
            SplitY(current, pass == 1 ? clipRect.y : clipRect.w, pass == 1, buf, number);
        uint cNumber = (threadId & (1 << pass)) >> pass;
        if (cNumber >= number)
            return;
        current = buf[cNumber];
        if (!any(current.v[1] - current.v[0]) || !any(current.v[2] - current.v[1]) || !any(current.v[2] - current.v[0]))
            return;
    }

    // Normalize the winding so that the z component of the cross product is <= 0.
    // Swapping v[1] and v[2] renumbers the edges: new edge 0 (v0 -> v2) is the
    // old edge 2, new edge 2 (v1 -> v0) is the old edge 0, and edge 1 stays
    // edge 1. Border bits 0 and 2 must be exchanged accordingly, otherwise the
    // walls below are generated on the wrong edges.
    if (cross(float3(current.v[1] - current.v[0], 0), float3(current.v[2] - current.v[0], 0)).z > 0)
    {
        float2 t = current.v[1];
        current.v[1] = current.v[2];
        current.v[2] = t;
        bool oldBorder0 = HasBorder(current.borders, 0);
        SetBorder(current.borders, 0, HasBorder(current.borders, 2));
        SetBorder(current.borders, 2, oldBorder0);
    }

    Triangle outTriangle;

    // Bottom cap (z = shift.z).
    outTriangle.v[0] = float3(current.v[0], 0) + shift.xyz;
    outTriangle.v[1] = float3(current.v[1], 0) + shift.xyz;
    outTriangle.v[2] = float3(current.v[2], 0) + shift.xyz;
    BorderedTriangle_ClipShiftSpatialize_Submit(outTriangle, userData);

    // Top cap (z = shift.z + shift.w), reversed vertex order.
    outTriangle.v[0] = float3(current.v[0], shift.w) + shift.xyz;
    outTriangle.v[1] = float3(current.v[2], shift.w) + shift.xyz;
    outTriangle.v[2] = float3(current.v[1], shift.w) + shift.xyz;
    BorderedTriangle_ClipShiftSpatialize_Submit(outTriangle, userData);

    // Side walls along border edges; edge k runs from v[k] to v[(k + 1) % 3].
    for (uint edge = 0; edge < 3; edge++)
    {
        if (!HasBorder(current.borders, edge))
            continue;
        float2 cur = current.v[edge];
        float2 next = current.v[(edge + 1) % 3];
        float3 curLow = float3(cur, 0) + shift.xyz;
        float3 curHigh = float3(cur, shift.w) + shift.xyz;
        float3 nextLow = float3(next, 0) + shift.xyz;
        float3 nextHigh = float3(next, shift.w) + shift.xyz;

        outTriangle.v[0] = curLow;
        outTriangle.v[1] = curHigh;
        outTriangle.v[2] = nextHigh;
        BorderedTriangle_ClipShiftSpatialize_Submit(outTriangle, userData);
        outTriangle.v[0] = curLow;
        outTriangle.v[1] = nextHigh;
        outTriangle.v[2] = nextLow;
        BorderedTriangle_ClipShiftSpatialize_Submit(outTriangle, userData);

        // TODO Should we always add these?
        // The same wall quad with the opposite winding (double-sided wall).
        outTriangle.v[0] = curLow;
        outTriangle.v[1] = nextHigh;
        outTriangle.v[2] = curHigh;
        BorderedTriangle_ClipShiftSpatialize_Submit(outTriangle, userData);
        outTriangle.v[0] = curLow;
        outTriangle.v[1] = nextLow;
        outTriangle.v[2] = nextHigh;
        BorderedTriangle_ClipShiftSpatialize_Submit(outTriangle, userData);
    }
}

#endif

Vector2_FixedShape.cginc

#ifndef VECTOR2_FIXEDSHAPE_CGINC
#define VECTOR2_FIXEDSHAPE_CGINC

#include "Common.cginc"

// Emits part of the round join at `cur`, where segment prev -> cur meets
// segment cur -> next. The join is a fan of triangles (cur, arc point, arc point)
// on a circle of radius lineWidth / 2 around `cur`. The arc starts at normal1
// (perpendicular to cur -> prev) and ends at normal2 (perpendicular to
// cur -> next), both oriented away from the corner bisector. It is divided into
// 6 equal slices: slice `part` spans the fractions (part - 1) / 6 .. part / 6 of
// the sweep, and only slices with from <= part < to are emitted, so several
// threads can share one join. Each slice has borders = 2, i.e. only edge 1
// (the chord between the two arc points) is a border. Slices are passed to
// Vector2_FixedShape_Submit.
// NOTE(review): prev == cur or next == cur makes normalize() produce NaN; not guarded here.
void OutputCurveTriangles(float2 prev, float2 cur, float2 next, float lineWidth, int from, int to, Vector2_FixedShape_UserData userData)
{
    float2 back = normalize(prev - cur);
    float2 normal1 = float2(-back.y, back.x);
    if (normal1.y < 0)
        normal1 = -normal1;
    float2 front = normalize(next - cur);
    float2 bisector;
    if (!any(back + front))
    {
        // back and front are exactly opposite (collinear segments), so
        // back + front is zero; use a perpendicular of the segment instead.
        bisector = normalize(back - front);
        bisector = float2(-bisector.y, bisector.x);
    }
    else
        bisector = normalize(back + front);
    float2 normal2 = float2(-front.y, front.x);
    // Orient both normals away from the bisector.
    if (dot(normal1, bisector) > 0)
        normal1 = -normal1;
    if (dot(normal2, bisector) > 0)
        normal2 = -normal2;
    float startAngle = atan2(normal1.y, normal1.x);
    float endAngle = atan2(normal2.y, normal2.x);
    BorderedTriangle2D result;
    result.borders = 2;
    result.v[0] = cur;
    // Two mirrored cases. The sweep direction differs, and so does the order of
    // v[1]/v[2], presumably to keep the fan's winding consistent -- TODO confirm.
    if (bisector.y < 0)
    {
        // Make the sweep (endAngle - startAngle) non-positive.
        if (endAngle > startAngle)
            endAngle -= PI * 2;
        for (int part = from; part < to; part++)
        {
            float intermediate = lerp(0, endAngle - startAngle, part / 6.);
            float prevIntermed = lerp(0, endAngle - startAngle, (part - 1) / 6.);
            result.v[1] = cur + Turn(normal1, prevIntermed) * lineWidth / 2;
            result.v[2] = cur + Turn(normal1, intermediate) * lineWidth / 2;
            Vector2_FixedShape_Submit(result, userData);
        }
    }
    else
    {
        // Make the sweep (endAngle - startAngle) non-negative.
        if (endAngle < startAngle)
            endAngle += PI * 2;
        for (int part = from; part < to; part++)
        {
            float intermediate = lerp(0, endAngle - startAngle, part / 6.);
            float prevIntermed = lerp(0, endAngle - startAngle, (part - 1) / 6.);
            result.v[1] = cur + Turn(normal1, intermediate) * lineWidth / 2;
            result.v[2] = cur + Turn(normal1, prevIntermed) * lineWidth / 2;
            Vector2_FixedShape_Submit(result, userData);
        }
    }
}

// Emits the straight part of the stroke from `cur` to `next` as a rectangle of
// width `lineWidth`, split into two triangles that are passed to
// Vector2_FixedShape_Submit. Both long sides are always borders; the cap at
// `cur` is a border only when `startBound` is set, and the cap at `next` only
// when `endBound` is set.
void OutputLineTriangles(float2 cur, float2 next, float lineWidth, bool startBound, bool endBound, Vector2_FixedShape_UserData userData)
{
    // Unit normal of the segment, flipped so that its y component is >= 0.
    float2 dir = normalize(cur - next);
    float2 normal = float2(-dir.y, dir.x);
    if (normal.y < 0)
        normal = -normal;
    float2 halfWidth = normal * lineWidth / 2;

    float2 startTop = cur + halfWidth;
    float2 startBottom = cur - halfWidth;
    float2 endTop = next + halfWidth;
    float2 endBottom = next - halfWidth;

    BorderedTriangle2D tri;

    // (startTop, endBottom, startBottom): edge 1 = bottom side, edge 2 = start cap.
    tri.v[0] = startTop;
    tri.v[1] = endBottom;
    tri.v[2] = startBottom;
    tri.borders = 2 | (startBound ? 4 : 0);
    Vector2_FixedShape_Submit(tri, userData);

    // (startTop, endTop, endBottom): edge 0 = top side, edge 1 = end cap.
    tri.v[0] = startTop;
    tri.v[1] = endTop;
    tri.v[2] = endBottom;
    tri.borders = 1 | (endBound ? 2 : 0);
    Vector2_FixedShape_Submit(tri, userData);
}

// Line stage for a single point `cPoint`. Nothing is emitted when there is no
// next point.
//   threadId 0..2: each emits two of the six round-join slices at cPoint
//                  (slices [1 + 2 * threadId, 3 + 2 * threadId)), but only when
//                  a previous point exists;
//   threadId 3:    emits the straight segment cPoint -> nextPoint, whose caps
//                  are borders at the start (!hasPrevPoint) and end
//                  (!hasNextNextPoint) of a line.
// Any other threadId emits nothing.
void Vector2_FixedShape_Shader(float2 prevPoint, float2 cPoint, float2 nextPoint, bool hasPrevPoint, bool hasNextPoint, bool hasNextNextPoint, float lineWidth, uint threadId, Vector2_FixedShape_UserData userData)
{
    if (!hasNextPoint)
        return;

    if (threadId == 3)
    {
        OutputLineTriangles(cPoint, nextPoint, lineWidth, !hasPrevPoint, !hasNextNextPoint, userData);
    }
    else if (threadId < 3 && hasPrevPoint)
    {
        int firstSlice = 1 + threadId * 2;
        OutputCurveTriangles(prevPoint, cPoint, nextPoint, lineWidth, firstSlice, firstSlice + 2, userData);
    }
}
#endif

Vector2_FixedShape_ClipShiftSpatialize.compute

CGPROGRAM
#pragma kernel Vector2_FixedShape_ClipShiftSpatialize

#include "Common.cginc"

// Per-dispatch parameters; the kernel only reads Parameters[0].
struct ParamStructure
{
    float4 Shift;     // xyz: offset added to every output vertex; w: z offset of the top cap
    float4 ClipRect;  // clipping rectangle as (xMin, yMin, xMax, yMax)
    float LineWidth;  // width of the generated line
};

// One input point. Elements adjacent in Input whose Numbers differ by exactly 1
// are treated as neighbouring points of the same line.
struct NumberedVector
{
    int Number;
    float2 Value;  // 2D position of the point
};

StructuredBuffer<NumberedVector> Input;  // line points; one thread group per element
StructuredBuffer<uint> InputIndirectArgs;  // [0] = number of valid elements in Input
StructuredBuffer<ParamStructure> Parameters;  // [0] = parameters for this dispatch
globallycoherent RWStructuredBuffer<int> OutputIndirectArgs;  // [0] += 3 per emitted triangle (presumably a draw's vertex count -- confirm)
AppendStructuredBuffer<Triangle> Output;  // emitted 3D triangles

// The clip/extrude stage needs no per-call data in this kernel; the empty
// struct only satisfies the BorderedTriangle_ClipShiftSpatialize_* interface.
struct BorderedTriangle_ClipShiftSpatialize_UserData
{
};

// Final sink of the pipeline: appends one finished triangle to Output and
// atomically adds its vertex count to OutputIndirectArgs[0].
void BorderedTriangle_ClipShiftSpatialize_Submit(Triangle tri, BorderedTriangle_ClipShiftSpatialize_UserData userData)
{
    const int verticesPerTriangle = 3;
    Output.Append(tri);
    InterlockedAdd(OutputIndirectArgs[0], verticesPerTriangle);
}

#include "BorderedTriangle_ClipShiftSpatialize.cginc"

// User data of the line stage: the invocation's piece selector, which is
// forwarded as `threadId` to the clipping stage.
typedef uint Vector2_FixedShape_UserData;

// Bridges the two stages: each 2D triangle produced by the line stage is fed
// directly into the clip / shift / spatialize stage using Parameters[0].
void Vector2_FixedShape_Submit(BorderedTriangle2D tri, Vector2_FixedShape_UserData userData)
{
    ParamStructure params = Parameters[0];
    BorderedTriangle_ClipShiftSpatialize_UserData clipUserData;
    BorderedTriangle_ClipShiftSpatialize_Shader(tri, userData, params.ClipRect, params.Shift, clipUserData);
}

#include "Vector2_FixedShape.cginc"

// One thread group per input point (groupId.x indexes Input).
// threadId.x (0..3) selects which part of the line shape to build (see
// Vector2_FixedShape_Shader); threadId.y (0..15) selects which clipped piece
// to follow (see BorderedTriangle_ClipShiftSpatialize_Shader).
[numthreads(4,16,1)]
void Vector2_FixedShape_ClipShiftSpatialize(uint3 groupId: SV_GroupID, uint3 threadId: SV_GroupThreadID)
{
    uint nStructures = InputIndirectArgs[0];
    uint index = groupId[0];
    // Groups past the valid element count (e.g. a dispatch sized for the
    // buffer's capacity) would otherwise read stale data and emit garbage.
    if (index >= nStructures)
        return;
    // `index + k < nStructures` instead of `index < nStructures - k`: the
    // latter wraps around for unsigned counts smaller than k.
    bool hasNextPoint = index + 1 < nStructures && Input[index + 1].Number == Input[index].Number + 1;
    bool hasNextNextPoint = index + 2 < nStructures && Input[index + 2].Number == Input[index].Number + 2;
    bool hasPrevPoint = index > 0 && Input[index - 1].Number == Input[index].Number - 1;
    float2 prevPoint = hasPrevPoint ? Input[index - 1].Value : 0;
    float2 cPoint = Input[index].Value;
    float2 nextPoint = hasNextPoint ? Input[index + 1].Value : 0;
    Vector2_FixedShape_Shader(prevPoint, cPoint, nextPoint, hasPrevPoint, hasNextPoint, hasNextNextPoint, Parameters[0].LineWidth, threadId[0], threadId[1]);
}
ENDCG

The non-optimized shader bytecode is about 2000 instructions with a number of branches.
As you may notice, I am attempting to process shader input in a pipeline fashion:

  • the shader gets an array of Vector2 as input;
  • a line shape through input points is produced as a set of 2D triangles;
  • each triangle is then converted to a number of 3D triangles, forming a spatial line shape which is further displayed.

I have tried to split my code into several shaders that are applied one after another. This approach led to flickering when the resulting pipeline of separate shaders was run every frame. I don’t want to use a geometry shader for drawing this, because the shader input does not change from frame to frame.

Can you suggest a way to use long and complex compute shaders in Unity without waiting half an hour for the shader to compile? Is my only option to compile the shader for a specific platform at runtime by means of a native plugin?

I have written C++ code that compiles the shader for a single platform (D3D11, cs_5_0) and creates the shader object:
main.cpp

#include <d3d11.h>
#include <D3DCompiler.h>
#include <wrl/client.h>

#include <stdio.h>
#include <chrono>

using namespace std::chrono;

int main()
{
    ID3D11Device *device;
    HRESULT hr;
    hr = D3D11CreateDevice(nullptr, D3D_DRIVER_TYPE_HARDWARE, nullptr, 0, nullptr, 0, D3D11_SDK_VERSION, &device, nullptr, nullptr);
    if (hr != S_OK)
    {
        printf("Failed to create d3d11 device: %u\n", hr);
        return -1;
    }
    ID3DBlob *outBlob, *errorBlob;
    milliseconds start = duration_cast<milliseconds>(system_clock::now().time_since_epoch());
    hr = D3DCompileFromFile(L"Vector2_FixedShape_ClipShiftSpatialize.hlsl", nullptr, D3D_COMPILE_STANDARD_FILE_INCLUDE, "Vector2_FixedShape_ClipShiftSpatialize", "cs_5_0", D3DCOMPILE_OPTIMIZATION_LEVEL3, 0, &outBlob, &errorBlob);
    milliseconds end = duration_cast<milliseconds>(system_clock::now().time_since_epoch());
    if (hr != S_OK)
    {
        printf("Compilation error with code %u\n", hr);
        if (errorBlob)
        {
            fwrite(errorBlob->GetBufferPointer(), 1, errorBlob->GetBufferSize(), stdout);
            errorBlob->Release();
        }
        return -1;
    }
    ID3D11ComputeShader* d3d11Shader = nullptr;
    hr = device->CreateComputeShader(outBlob->GetBufferPointer(), outBlob->GetBufferSize(), nullptr, &d3d11Shader);
    if (hr == S_OK)
       printf("Shader compilation for a single platform took %I64u millis, bytecode length %u\n", (end - start).count(), outBlob->GetBufferSize());
    else
        printf("Shader creation error with code %u\n", hr);
    if (d3d11Shader)
        d3d11Shader->Release();
    if (outBlob)
        outBlob->Release();
    if (errorBlob)
        errorBlob->Release();
    device->Release();
    getc(stdin);
    return 0;
}

This program outputs the following: Shader compilation for a single platform took 4767 millis, bytecode length 55008
The shader is compiled for a single platform and loaded in about 5 seconds. By contrast, importing it into Unity takes an eternity (more than a whole night).

UPD: It appears that CSSetShader takes much longer than the loading procedure in C++. I’m continuing to investigate; maybe I’m really not supposed to run compute shader programs larger than 50,000 bytes.

UPD[2]: I optimized the shader code a little and reduced its bytecode size to about 15,000 bytes. After this, the shader loaded successfully in Unity, so it looks like my issue is resolved.

I can’t fully understand your problem and chosen solution, but I had a similar issue, and I’m glad I actually took action to help myself and look it up.
Simply using “#include” inside the compute shader and splitting the code into multiple compute shaders worked wonders.

Leaving this here for future travellers:

In compute shaders, add [fastopt] on the line preceding EVERY for loop. It’s obviously not the best idea for production, but a simple stripping step can remove the attributes when the shader needs to be compiled for production. This improves compilation speed by a factor of around 100 (a 7-minute compile after a single-character change now takes 4 seconds), at least if you have a lot of complex data that requires loops. My guess is that the process used to decide whether a loop is better off being unrolled can be extremely aggressive, so switch it off until you NEED it to make such an intense determination. A 100x compile-time speedup like that should really be one of Unity’s shader compilation options.

2 Likes