Great work! The code from Mytino worked for me. In case anyone wants to try the compute shader, here is the script I use to call the InclusiveScan. It does the following:
- Prepare the group buffers as a pyramid (needed when there are more than 512×512 elements, or even more than 512×512×512)
- Scan groups in a V-loop manner
The modified compute shader (adds support for specifying the total number of elements to scan):
// https://discussions.unity.com/t/692358/13
// Thread-group width used by every kernel's [numthreads] attribute and as the
// length of the groupshared scratch array below.
#define THREADS_PER_GROUP 512 // Ensure that this equals the "threadsPerGroup" variables in the host scripts using this.
// Element count that the kernels compare indices against before reading from
// InputBufR (or, in AddScannedSums, before updating OutputBufW).
int N;
// Read-only source values for the current kernel.
StructuredBuffer<uint> InputBufR;
// Destination buffer written by the current kernel.
RWStructuredBuffer<uint> OutputBufW;
// Per-group scratch space, one slot per thread in the group; used by Scan.
groupshared uint bucket[THREADS_PER_GROUP];
// Inclusive prefix sum across one thread group.
//   id : dispatch-wide thread index; selects the OutputBufW slot to write.
//   gi : index of this thread inside its group; selects the bucket slot.
//   x  : this thread's input value.
// After the loop, bucket[gi] holds the sum of the x values of threads 0..gi
// of the same group, and that value is stored to OutputBufW[id].
void Scan(uint id, uint gi, uint x)
{
bucket[gi] = x;
[unroll]
// Offset doubles each pass: 1, 2, 4, ... while below THREADS_PER_GROUP.
for (uint t = 1; t < THREADS_PER_GROUP; t <<= 1) {
// Wait until every thread's latest bucket write (initial store or previous pass) is visible.
GroupMemoryBarrierWithGroupSync();
uint temp = bucket[gi];
// Accumulate the partial sum that ends t slots to the left, if there is one.
if (gi >= t) temp += bucket[gi - t];
// Wait until every thread has finished reading before any slot is overwritten.
GroupMemoryBarrierWithGroupSync();
bucket[gi] = temp;
}
// Each thread reads back only the slot it wrote itself, so no further barrier is used here.
OutputBufW[id] = bucket[gi];
}
// Each thread group computes an inclusive prefix sum over its own
// THREADS_PER_GROUP-wide slice of InputBufR, independently of the other groups.
#pragma kernel ScanInGroupsInclusive
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanInGroupsInclusive(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex)
{
    // Threads whose index is not below N contribute zero to the group scan.
    uint value = 0;
    const bool inRange = ((int)id < N);
    if (inRange)
    {
        value = InputBufR[id];
    }

    Scan(id, gi, value);
}
// Per-group scan like ScanInGroupsInclusive, but every thread takes the input of
// its left neighbour (thread 0 takes zero), so that the final result (obtained
// after the ScanSums and AddScannedSums calls) is an exclusive prefix sum.
#pragma kernel ScanInGroupsExclusive
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanInGroupsExclusive(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex)
{
    // For id == 0 the unsigned subtraction wraps; viewed as a signed int it is
    // negative and therefore rejected by the range test below.
    const uint shifted = id - 1;
    const int shiftedSigned = (int)shifted;

    uint value = 0;
    if (shiftedSigned >= 0 && shiftedSigned < N)
    {
        value = InputBufR[shifted];
    }

    Scan(id, gi, value);
}
// Scans the per-group totals produced by a preceding ScanInGroupsInclusive/Exclusive
// call. Thread `id` picks up the last element of group (id - 1) from InputBufR
// (thread 0 picks up zero) and the picked values are then scanned within the group.
#pragma kernel ScanSums
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanSums(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex)
{
    // Index of the last element of the previous group. For id == 0 this wraps
    // and is negative when viewed as a signed int, so it fails the range test.
    const uint lastOfPrevGroup = id * THREADS_PER_GROUP - 1;
    const int lastSigned = (int)lastOfPrevGroup;

    uint value = 0;
    if (lastSigned >= 0 && lastSigned < N)
    {
        value = InputBufR[lastOfPrevGroup];
    }

    Scan(id, gi, value);
}
// Final pass: every element with index below N is increased by the scanned sum
// stored for its thread group (InputBufR[gid]), turning the isolated per-group
// scans from the first kernel call into the complete prefix sum.
#pragma kernel AddScannedSums
[numthreads(THREADS_PER_GROUP, 1, 1)]
void AddScannedSums(uint id : SV_DispatchThreadID, uint gid : SV_GroupID)
{
    // Leave elements at or beyond N untouched.
    if ((int)id >= N)
    {
        return;
    }

    OutputBufW[id] = OutputBufW[id] + InputBufR[gid];
}
The script to call the compute shader.
/// <summary>
/// Host-side driver for the multi-level prefix sum kernels
/// (ScanInGroupsInclusive, ScanSums, AddScannedSums). Owns a pyramid of
/// intermediate buffers that hold per-group sums for each level.
/// </summary>
/// <remarks>
/// NUM_GROUPS is not defined inside this struct; it is presumably a
/// ceil-division helper (element count / group size, rounded up) declared
/// elsewhere in the project — confirm before reuse.
/// NOTE(review): graphics APIs cap the thread-group count per dispatch
/// dimension (65535 on Direct3D 11), which bounds the usable <c>num</c> —
/// confirm for the target platform.
/// </remarks>
struct ScanHelper
{
    const int threadsPerGroup = 512; // THREADS_PER_GROUP in ScanOperations.compute

    /// <summary>Element capacity the pyramid is currently sized for; 0 when nothing is allocated.</summary>
    public int size;

    /// <summary>Intermediate group-sum buffers, one per pyramid level (level 0 first).</summary>
    public List<ComputeBuffer> group_buffer;

    /// <summary>Element count of each buffer in <see cref="group_buffer"/>, index-aligned with it.</summary>
    public List<int> work_size;

    /// <summary>
    /// Computes the inclusive prefix sum of the first <paramref name="num"/>
    /// elements of <paramref name="inputs"/> into <paramref name="outputs"/>.
    /// </summary>
    /// <param name="num">Number of elements to scan. Values of zero or less are a no-op.</param>
    /// <param name="scanOperations">Compute shader providing the scan kernels.</param>
    /// <param name="inputs">Source buffer bound as InputBufR for the first pass.</param>
    /// <param name="outputs">Destination buffer; receives the final prefix sum.</param>
    public void InclusiveScan(int num, ComputeShader scanOperations,
        ComputeBuffer inputs, ComputeBuffer outputs)
    {
        // Nothing to scan; this also avoids issuing a dispatch with zero thread groups.
        if (num <= 0)
        {
            return;
        }

        this.RequireBuffer(num);

        // 1. Per group scan
        int kernelScan = scanOperations.FindKernel("ScanInGroupsInclusive");
        scanOperations.SetInt("N", num);
        scanOperations.SetBuffer(kernelScan, "InputBufR", inputs);
        scanOperations.SetBuffer(kernelScan, "OutputBufW", outputs);
        scanOperations.Dispatch(kernelScan, NUM_GROUPS(num, threadsPerGroup), 1, 1);

        // A single thread group already holds the complete result
        // (this includes num == threadsPerGroup).
        if (num <= threadsPerGroup)
        {
            return;
        }

        int kernelScanSums = scanOperations.FindKernel("ScanSums");
        int kernelAdd = scanOperations.FindKernel("AddScannedSums");

        // 2. Scan per group sum ("N" is still num from step 1).
        scanOperations.SetBuffer(kernelScanSums, "InputBufR", outputs);
        scanOperations.SetBuffer(kernelScanSums, "OutputBufW", this.group_buffer[0]);
        scanOperations.Dispatch(kernelScanSums, NUM_GROUPS(this.work_size[0], threadsPerGroup), 1, 1);

        // Continue down the pyramid: each level scans the group sums of the level above it.
        for (int l = 0; l < this.group_buffer.Count - 1; ++l)
        {
            scanOperations.SetInt("N", this.work_size[l]);
            scanOperations.SetBuffer(kernelScanSums, "InputBufR", this.group_buffer[l]);
            scanOperations.SetBuffer(kernelScanSums, "OutputBufW", this.group_buffer[l + 1]);
            scanOperations.Dispatch(kernelScanSums, NUM_GROUPS(this.work_size[l + 1], threadsPerGroup), 1, 1);
        }

        // 3. Walk back up: add each level's scanned sums into the level above it.
        for (int l = this.group_buffer.Count - 1; l > 0; --l)
        {
            int work_sz = this.work_size[l - 1];
            scanOperations.SetInt("N", work_sz);
            scanOperations.SetBuffer(kernelAdd, "InputBufR", this.group_buffer[l]);
            scanOperations.SetBuffer(kernelAdd, "OutputBufW", this.group_buffer[l - 1]);
            scanOperations.Dispatch(kernelAdd, NUM_GROUPS(work_sz, threadsPerGroup), 1, 1);
        }

        // 3. Add scanned group sum to the per-group scans of the actual data.
        // Only the groups that cover `num` elements are launched; work_size[0]
        // is derived from the (larger) capacity and would launch idle groups.
        scanOperations.SetInt("N", num);
        scanOperations.SetBuffer(kernelAdd, "InputBufR", this.group_buffer[0]);
        scanOperations.SetBuffer(kernelAdd, "OutputBufW", outputs);
        scanOperations.Dispatch(kernelAdd, NUM_GROUPS(num, threadsPerGroup), 1, 1);
    }

    /// <summary>
    /// Ensures the intermediate buffers can serve a scan of
    /// <paramref name="alloc_sz"/> elements, reallocating with 1.5x headroom
    /// when the current capacity is too small or nothing is allocated.
    /// </summary>
    /// <param name="alloc_sz">Number of elements the next scan will process.</param>
    public void RequireBuffer(int alloc_sz)
    {
        // Also reallocate when the lists are missing (fresh struct or after Release()).
        bool allocated = this.group_buffer != null && this.work_size != null;
        if (allocated && this.size >= alloc_sz)
        {
            return;
        }

        this.Release();
        this.size = (int)(alloc_sz * 1.5);
        this.group_buffer = new List<ComputeBuffer>();
        this.work_size = new List<int>();

        // Add one level per reduction step until a level fits in a single thread group.
        int work_sz = this.size;
        while (work_sz > threadsPerGroup)
        {
            work_sz = NUM_GROUPS(work_sz, threadsPerGroup);
            this.group_buffer.Add(new ComputeBuffer(work_sz, sizeof(uint)));
            this.work_size.Add(work_sz);
        }
    }

    /// <summary>
    /// Disposes all intermediate buffers and resets the helper to its
    /// unallocated state so that a later scan allocates fresh buffers.
    /// </summary>
    public void Release()
    {
        if (group_buffer != null)
        {
            foreach (ComputeBuffer buffer in group_buffer)
            {
                if (buffer != null)
                {
                    buffer.Dispose();
                }
            }
            group_buffer = null;
        }

        // Reset the capacity as well: leaving the old value made RequireBuffer
        // skip reallocation, after which InclusiveScan dereferenced the null list.
        work_size = null;
        size = 0;
    }
}
// Compute shader asset that contains the scan kernels (serialized field).
[SerializeField] ComputeShader scanOperations;
// Helper that owns the intermediate group-sum buffers; call mScanHelper.Release()
// when the buffers are no longer needed so they get disposed.
ScanHelper mScanHelper;
// Inclusive prefix sum of the first N elements of `inputs`, written to `outputs`.
mScanHelper.InclusiveScan(N, scanOperations, inputs, outputs);
- Make sure THREADS_PER_GROUP and threadsPerGroup are the same. I've tested that both 512 and 1024 work.