You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

216 lines
7.4 KiB

// Kernel entry points exposed by this compute shader.
#pragma kernel MainCalculateLevelDispatchArgsFromConst
#pragma kernel MainCalculateLevelDispatchArgsFromBuffer
#pragma kernel MainPrefixSumOnGroup
#pragma kernel MainPrefixSumOnGroupExclusive
#pragma kernel MainPrefixSumNextInput
#pragma kernel MainPrefixSumResolveParent
#pragma kernel MainPrefixSumResolveParentExclusive
// Generated header; presumably the source of GROUP_SIZE, ARGS_BUFFER_STRIDE, ARGS_BUFFER_UPPER,
// ARGS_BUFFER_LOWER and the LevelOffsets struct used below -- confirm against the C# side.
#include "Packages/com.unity.render-pipelines.core/Runtime/Utilities/GPUPrefixSum/GPUPrefixSum.Data.cs.hlsl"
#pragma only_renderers d3d11 playstation xboxone xboxseries vulkan metal switch
// #pragma enable_d3d11_debug_symbols
// Raw input values read by the scan / next-input / resolve kernels (byte addressed, 4 bytes per element).
ByteAddressBuffer _InputBuffer;
// Raw output values written by the scan / next-input / resolve kernels (byte addressed, 4 bytes per element).
RWByteAddressBuffer _OutputBuffer;
// Holds the element count read by MainCalculateLevelDispatchArgsFromBuffer at element index _InputOffset.
ByteAddressBuffer _InputCountBuffer;
// Number of levels, read at byte 0 by the per-group scan.
ByteAddressBuffer _TotalLevelsBuffer;
// Number of levels, written at byte 0 by the level-offset calculation.
RWByteAddressBuffer _OutputTotalLevelsBuffer;
// Dispatch arguments: per level, (x, y, z) group counts at ARGS_BUFFER_UPPER and at ARGS_BUFFER_LOWER.
RWBuffer<uint> _OutputDispatchLevelArgsBuffer;
// Per-level layout table (read side), indexed by _CurrentLevel.
StructuredBuffer<LevelOffsets> _LevelsOffsetsBuffer;
// Per-level layout table (write side), filled by the level-offset calculation.
RWStructuredBuffer<LevelOffsets> _OutputLevelsOffsetsBuffer;
// Four integer arguments packed into a float4; each component is reinterpreted bit-for-bit with asint().
float4 _PrefixSumIntArgs;
// x: element count used by MainCalculateLevelDispatchArgsFromConst.
#define _InputCount asint(_PrefixSumIntArgs.x)
// y: number of level slots whose dispatch arguments must be initialized.
#define _MaxLevels asint(_PrefixSumIntArgs.y)
// z: element index (not byte offset) of the count inside _InputCountBuffer.
#define _InputOffset asint(_PrefixSumIntArgs.z)
// w: index of the level the current dispatch operates on.
#define _CurrentLevel asint(_PrefixSumIntArgs.w)
// Number of GROUP_SIZE-wide groups required to cover v elements (integer ceiling division).
uint DivUpGroup(uint v)
{
    const uint roundedUp = v + GROUP_SIZE - 1;
    return roundedUp / GROUP_SIZE;
}
// Rounds v up to the nearest multiple of GROUP_SIZE.
uint AlignUpGroup(uint v)
{
    const uint groupsNeeded = DivUpGroup(v);
    return groupsNeeded * GROUP_SIZE;
}
// Builds the per-level layout table and the dispatch arguments for a multi-level
// prefix sum over currentCount elements.
//
// For every level produced this writes:
//   - _OutputLevelsOffsetsBuffer[level]: the level's element count, its offset, and the
//     offset of the previous level (parentOffset).
//   - _OutputDispatchLevelArgsBuffer: two (x, y, z) group-count triplets inside the level's
//     ARGS_BUFFER_STRIDE-sized slot, one at ARGS_BUFFER_UPPER and one at ARGS_BUFFER_LOWER.
// Level slots in [levelCounts, _MaxLevels) receive all-zero dispatch arguments, and the
// number of levels produced is stored at byte 0 of _OutputTotalLevelsBuffer.
//
// NOTE(review): the level loop is not clamped to _MaxLevels; presumably the caller sizes the
// output buffers for the largest count it can supply -- confirm.
void MainCalculateLevelOffsetsCommon(uint currentCount)
{
    // Round the element count up to a whole number of groups.
    uint alignedSupportMaxCount = AlignUpGroup(currentCount);
    uint prevSize = 0;
    uint totalSize = 0;
    uint levelCounts = 0;
    // An empty input produces no levels at all.
    bool canReduce = alignedSupportMaxCount > 0;
    [loop]
    while (canReduce)
    {
        LevelOffsets offsets;
        // Level 0 keeps the exact (unaligned) element count; upper levels use the aligned size.
        offsets.count = levelCounts == 0 ? currentCount : alignedSupportMaxCount;
        offsets.offset = totalSize;
        offsets.parentOffset = prevSize;
        _OutputLevelsOffsetsBuffer[levelCounts] = offsets;

        uint groupCount = DivUpGroup(alignedSupportMaxCount);
        uint argsBase = ARGS_BUFFER_STRIDE * levelCounts;

        // Upper dispatch arguments.
        _OutputDispatchLevelArgsBuffer[argsBase + ARGS_BUFFER_UPPER + 0] = groupCount;
        _OutputDispatchLevelArgsBuffer[argsBase + ARGS_BUFFER_UPPER + 1] = 1;
        _OutputDispatchLevelArgsBuffer[argsBase + ARGS_BUFFER_UPPER + 2] = 1;

        // Lower dispatch arguments: cancelled (x = 0) when level 0 fits in a single group.
        bool singleGroupInput = groupCount == 1 && levelCounts == 0;
        _OutputDispatchLevelArgsBuffer[argsBase + ARGS_BUFFER_LOWER + 0] = singleGroupInput ? 0u : groupCount;
        _OutputDispatchLevelArgsBuffer[argsBase + ARGS_BUFFER_LOWER + 1] = 1;
        _OutputDispatchLevelArgsBuffer[argsBase + ARGS_BUFFER_LOWER + 2] = 1;

        prevSize = totalSize;
        totalSize += alignedSupportMaxCount;
        ++levelCounts;

        // A level that fits in one group is the last one.
        if (alignedSupportMaxCount <= GROUP_SIZE)
        {
            canReduce = false;
        }

        // The next level holds one element per group of this level.
        alignedSupportMaxCount = AlignUpGroup(groupCount);
    }

    // Zero out the dispatch arguments of every unused level slot.
    [loop]
    for (int i = levelCounts; i < _MaxLevels; ++i)
    {
        int unusedBase = ARGS_BUFFER_STRIDE * i;
        _OutputDispatchLevelArgsBuffer[unusedBase + ARGS_BUFFER_UPPER + 0] = 0;
        _OutputDispatchLevelArgsBuffer[unusedBase + ARGS_BUFFER_UPPER + 1] = 0;
        _OutputDispatchLevelArgsBuffer[unusedBase + ARGS_BUFFER_UPPER + 2] = 0;
        _OutputDispatchLevelArgsBuffer[unusedBase + ARGS_BUFFER_LOWER + 0] = 0;
        _OutputDispatchLevelArgsBuffer[unusedBase + ARGS_BUFFER_LOWER + 1] = 0;
        _OutputDispatchLevelArgsBuffer[unusedBase + ARGS_BUFFER_LOWER + 2] = 0;
    }

    _OutputTotalLevelsBuffer.Store(0, levelCounts);
}
// Single-thread kernel: computes level offsets and dispatch arguments from the element
// count passed in the packed constant _InputCount.
[numthreads(1,1,1)]
void MainCalculateLevelDispatchArgsFromConst()
{
    const uint elementCount = (uint)_InputCount;
    MainCalculateLevelOffsetsCommon(elementCount);
}
// Single-thread kernel: computes level offsets and dispatch arguments from an element
// count read out of _InputCountBuffer at element index _InputOffset.
[numthreads(1,1,1)]
void MainCalculateLevelDispatchArgsFromBuffer()
{
    // Element index -> byte address (4 bytes per element).
    const int countByteAddress = _InputOffset << 2;
    const uint elementCount = _InputCountBuffer.Load(countByteAddress);
    MainCalculateLevelOffsetsCommon(elementCount);
}
// Index of the last level (total level count minus one), loaded once per group by thread 0.
groupshared uint gs_LastLevelIndex;
// Layout entry of the level being processed, loaded once per group by thread 0.
groupshared LevelOffsets gs_CurrentLevelOffsets;
// Per-group scratch array for the in-group scan, one slot per thread.
groupshared uint gs_prefixCache[GROUP_SIZE];
// Scans one GROUP_SIZE-wide tile of _InputBuffer in group shared memory and stores each
// thread's running sum into _OutputBuffer inside the current level's region.
//   dispatchThreadID : x is the element index inside the current level.
//   groupThreadIndex : flattened thread index inside the group.
//   isExclusive      : when true and the scan has only one level, the thread's own input is
//                      subtracted from its result before it is stored.
void MainPrefixSumOnGroupCommon(int3 dispatchThreadID, int groupThreadIndex, bool isExclusive)
{
    // Thread 0 fetches the values every thread of the group needs.
    if (groupThreadIndex == 0)
    {
        gs_CurrentLevelOffsets = _LevelsOffsetsBuffer[_CurrentLevel];
        gs_LastLevelIndex = _TotalLevelsBuffer.Load(0) - 1u;
    }
    GroupMemoryBarrierWithGroupSync();

    const int elementIndex = dispatchThreadID.x;

    // Threads past the end of the level contribute zero.
    uint loaded = 0u;
    if ((uint)elementIndex < gs_CurrentLevelOffsets.count)
    {
        loaded = _InputBuffer.Load(elementIndex << 2);
    }
    gs_prefixCache[groupThreadIndex] = loaded;
    GroupMemoryBarrierWithGroupSync();

    // Hillis-Steele scan: on each pass every thread adds the value `stride` slots to its left.
    for (int stride = 1; stride < GROUP_SIZE; stride <<= 1)
    {
        uint addend = 0u;
        if (groupThreadIndex >= stride)
        {
            addend = gs_prefixCache[groupThreadIndex - stride];
        }
        // All reads of this pass must finish before any slot is overwritten.
        GroupMemoryBarrierWithGroupSync();
        gs_prefixCache[groupThreadIndex] += addend;
        GroupMemoryBarrierWithGroupSync();
    }

    uint scanned = gs_prefixCache[groupThreadIndex];
    const bool singleLevel = gs_LastLevelIndex == 0;
    if (isExclusive && singleLevel)
    {
        scanned -= loaded;
    }

    _OutputBuffer.Store((elementIndex + gs_CurrentLevelOffsets.offset) << 2, scanned);
}
// Kernel: per-group scan without the exclusive adjustment.
[numthreads(GROUP_SIZE, 1, 1)]
void MainPrefixSumOnGroup(int3 dispatchThreadID : SV_DispatchThreadID, int groupThreadIndex : SV_GroupIndex)
{
    const bool exclusiveScan = false;
    MainPrefixSumOnGroupCommon(dispatchThreadID, groupThreadIndex, exclusiveScan);
}
// Kernel: per-group scan with the exclusive adjustment enabled.
[numthreads(GROUP_SIZE, 1, 1)]
void MainPrefixSumOnGroupExclusive(int3 dispatchThreadID : SV_DispatchThreadID, int groupThreadIndex : SV_GroupIndex)
{
    const bool exclusiveScan = true;
    MainPrefixSumOnGroupCommon(dispatchThreadID, groupThreadIndex, exclusiveScan);
}
// Kernel: for each thread, reads the last slot of group `dispatchThreadID.x` inside the
// current level's region of _InputBuffer and writes it to element `dispatchThreadID.x`
// of _OutputBuffer.
[numthreads(GROUP_SIZE, 1, 1)]
void MainPrefixSumNextInput(int3 dispatchThreadID : SV_DispatchThreadID, int groupThreadIndex : SV_GroupIndex)
{
    // Thread 0 publishes the current level's layout to the whole group.
    if (groupThreadIndex == 0)
    {
        gs_CurrentLevelOffsets = _LevelsOffsetsBuffer[_CurrentLevel];
    }
    GroupMemoryBarrierWithGroupSync();

    // Index of the last element of the group this thread summarizes.
    const uint lastInGroup = gs_CurrentLevelOffsets.offset + dispatchThreadID.x * GROUP_SIZE + GROUP_SIZE - 1;
    const uint groupTotal = _InputBuffer.Load(lastInGroup << 2);
    _OutputBuffer.Store(dispatchThreadID.x << 2, groupTotal);
}
// Value added to every element handled by the group; loaded once by thread 0.
groupshared uint gs_ParentSum;

// Adds a per-group carry to each element of the parent level's region of _OutputBuffer.
// The carry is the value stored at index (groupID.x - 1) of the current level's region,
// or zero for group 0.
//   dispatchThreadID : x is the element index inside the parent level.
//   groupThreadIndex : flattened thread index inside the group.
//   groupID          : x selects which carry value is used.
//   isExclusive      : when true and _CurrentLevel is 1, the matching _InputBuffer value is
//                      subtracted from the element before the carry is added.
void MainPrefixSumResolveParentCommon(int3 dispatchThreadID, int groupThreadIndex, int3 groupID, bool isExclusive)
{
    if (groupThreadIndex == 0)
    {
        gs_CurrentLevelOffsets = _LevelsOffsetsBuffer[_CurrentLevel];

        uint carry = 0;
        if (groupID.x != 0)
        {
            carry = _OutputBuffer.Load((gs_CurrentLevelOffsets.offset + groupID.x - 1) << 2);
        }
        gs_ParentSum = carry;
    }
    // Make thread 0's shared writes visible to the rest of the group.
    GroupMemoryBarrierWithGroupSync();

    const int elementIndex = gs_CurrentLevelOffsets.parentOffset + dispatchThreadID.x;
    const int byteAddress = elementIndex << 2;

    uint value = _OutputBuffer.Load(byteAddress);
    if (isExclusive && _CurrentLevel == 1)
    {
        value -= _InputBuffer.Load(byteAddress);
    }
    _OutputBuffer.Store(byteAddress, value + gs_ParentSum);
}
// Kernel: propagates per-group carries into the parent level without the exclusive adjustment.
[numthreads(GROUP_SIZE, 1, 1)]
void MainPrefixSumResolveParent(int3 dispatchThreadID : SV_DispatchThreadID, int groupThreadIndex : SV_GroupIndex, int3 groupID : SV_GroupID)
{
    const bool exclusiveScan = false;
    MainPrefixSumResolveParentCommon(dispatchThreadID, groupThreadIndex, groupID, exclusiveScan);
}
// Kernel: propagates per-group carries into the parent level with the exclusive adjustment enabled.
[numthreads(GROUP_SIZE, 1, 1)]
void MainPrefixSumResolveParentExclusive(int3 dispatchThreadID : SV_DispatchThreadID, int groupThreadIndex : SV_GroupIndex, int3 groupID : SV_GroupID)
{
    const bool exclusiveScan = true;
    MainPrefixSumResolveParentCommon(dispatchThreadID, groupThreadIndex, groupID, exclusiveScan);
}