#pragma kernel Kernel
#pragma kernel CopyBuffer

// Ref: https://poniesandlight.co.uk/reflect/bitonic_merge_sort/

// One STAGE_* keyword selects which step of the sort Kernel performs.
// With no keyword set, Kernel performs no loads, stores or compare/swaps.
#pragma multi_compile _ STAGE_BMS STAGE_LOCAL_DISPERSE STAGE_BIG_FLIP STAGE_BIG_DISPERSE

#pragma only_renderers d3d11 playstation xboxone xboxseries vulkan metal switch

// Disable warning for auto unrolling of single iteration loop.
#pragma warning(disable : 3557)

// Threads per group. Every thread owns two elements, so a group covers LOCAL_SIZE * 2 elements.
#define LOCAL_SIZE 1024
// Upper bound handed to [unroll] for the loop in LocalDisperse.
#define MAX_DISPERSE_UNROLL_COUNT 16

// System-value inputs shared by both kernels.
struct Semantics
{
    uint  groupIndex       : SV_GroupIndex;
    uint3 groupID          : SV_GroupID;
    uint3 dispatchThreadID : SV_DispatchThreadID;
};

// Resources for buffer copy.
RWByteAddressBuffer _CopySrcBuffer;
RWByteAddressBuffer _CopyDstBuffer;
uint _CopyEntriesCount; // Number of 32-bit entries CopyBuffer copies.

// Copy of the kernel inputs, so the helpers below can read thread IDs without extra parameters.
static Semantics s_Semantics;

// Passed as 'h' to the stage function selected by the STAGE_* keyword.
uint _H;
// Compared against the group-local element index in Kernel to decide between loading data and zero padding.
// NOTE(review): presumably the number of elements to sort - confirm against the host code.
uint _Total;

// Global scratch space.
RWByteAddressBuffer _KeyBuffer;
RWByteAddressBuffer _ValueBuffer;

// Local scratch space.
groupshared uint gs_Keys  [LOCAL_SIZE * 2];
groupshared uint gs_Values[LOCAL_SIZE * 2];

// Compares the keys at element indices i.x and i.y in the global buffers and, when
// key[i.x] < key[i.y], exchanges both the keys and their values (the larger key ends up at i.x).
// Indices are in elements; '<< 2' turns them into byte addresses of 32-bit entries.
void GlobalCompareAndSwap(int2 i)
{
    const uint key0 = _KeyBuffer.Load(i.x << 2);
    const uint key1 = _KeyBuffer.Load(i.y << 2);

    const uint val0 = _ValueBuffer.Load(i.x << 2);
    const uint val1 = _ValueBuffer.Load(i.y << 2);

    if (key0 < key1)
    {
        _KeyBuffer.Store(i.x << 2, key1);
        _KeyBuffer.Store(i.y << 2, key0);

        _ValueBuffer.Store(i.x << 2, val1);
        _ValueBuffer.Store(i.y << 2, val0);
    }
}

// Same as GlobalCompareAndSwap, but operating on the groupshared copies.
void LocalCompareAndSwap(int2 i)
{
    const uint key0 = gs_Keys[i.x];
    const uint key1 = gs_Keys[i.y];

    const uint val0 = gs_Values[i.x];
    const uint val1 = gs_Values[i.y];

    if (key0 < key1)
    {
        gs_Keys[i.x] = key1;
        gs_Keys[i.y] = key0;

        gs_Values[i.x] = val1;
        gs_Values[i.y] = val0;
    }
}

// Flip step on the global buffers: inside each block of h elements, the thread pairs the element
// at offset (t_p % half_h) from the block start with the element at the same offset from the block end.
void BigFlip(uint h)
{
    // Does nothing for blocks smaller than what a single group covers (LOCAL_SIZE * 2).
    if (LOCAL_SIZE * 2 > h)
        return;

    uint t_p = s_Semantics.dispatchThreadID.x;

    uint half_h = h >> 1;

    // First element of the block of h elements this thread works in.
    uint q = ((2 * t_p) / h) * h;
    uint x = q + (t_p % half_h);
    uint y = q + h - (t_p % half_h) - 1;

    GlobalCompareAndSwap(int2(x, y));
}

// Disperse step on the global buffers: inside each block of h elements, the thread pairs an
// element of the first half with the element half_h further on.
void BigDisperse(uint h)
{
    // Does nothing for blocks smaller than what a single group covers (LOCAL_SIZE * 2).
    if (LOCAL_SIZE * 2 > h)
        return;

    uint t_p = s_Semantics.dispatchThreadID.x;

    uint half_h = h >> 1;

    // First element of the block of h elements this thread works in.
    uint q = ((2 * t_p) / h) * h;
    uint x = q + (t_p % (half_h));
    uint y = q + (t_p % (half_h)) + half_h;

    GlobalCompareAndSwap(int2(x, y));
}

// Flip step on groupshared memory, using the same index pattern as BigFlip but with the
// thread's index inside its group.
void LocalFlip(uint h)
{
    const uint t = s_Semantics.groupIndex;

    // Group-wide barrier before this step touches gs_Keys / gs_Values.
    GroupMemoryBarrierWithGroupSync();

    uint half_h = h >> 1;

    uint q = h * ( ( 2 * t ) / h );
    uint x = q + (t % half_h);
    uint y = q + (h - 1 - ( t % half_h ));

    LocalCompareAndSwap(int2(x, y));
}

// Runs disperse steps on groupshared memory for block sizes h, h / 2, ... down to 2.
// Does nothing when h <= 1.
void LocalDisperse(uint h)
{
    const uint t = s_Semantics.groupIndex;

    [unroll(MAX_DISPERSE_UNROLL_COUNT)]
    for ( ; h > 1 ; h /= 2 )
    {
        // Group-wide barrier before every step.
        GroupMemoryBarrierWithGroupSync();

        uint half_h = h >> 1;

        uint q = h * ( ( 2 * t ) / h );
        uint x = q + (t % half_h);
        uint y = q + (half_h + ( t % half_h ));

        LocalCompareAndSwap(int2(x, y));
    }
}

// Bitonic merge sort on groupshared memory: for every block size hh = 2, 4, ... up to h,
// one flip of size hh followed by the disperse steps starting at hh / 2.
void LocalBMS(uint h)
{
    for ( uint hh = 2; hh <= h; hh <<= 1 )
    {
        LocalFlip(hh);
        LocalDisperse( hh / 2 );
    }
}

// Sort kernel. The STAGE_* keyword picks the step:
//   STAGE_BMS / STAGE_LOCAL_DISPERSE   - load two elements per thread into groupshared memory,
//                                        run the step there, then write the elements back.
//   STAGE_BIG_FLIP / STAGE_BIG_DISPERSE - one compare/swap per thread directly on the global buffers.
[numthreads(LOCAL_SIZE, 1, 1)]
void Kernel(Semantics s)
{
    // Push the semantics to static global.
    s_Semantics = s;

    const uint t = s_Semantics.groupIndex;

    // First element of the LOCAL_SIZE * 2 elements covered by this group.
    uint offset = 2 * LOCAL_SIZE * s_Semantics.groupID.x;

#if defined(STAGE_BMS) || defined(STAGE_LOCAL_DISPERSE)
    // NOTE(review): this guard compares the group-local index (t * 2 + 1) with _Total, not
    // offset + t * 2 + 1. If _Total is the global element count, only the first group is
    // actually bounded by it - confirm against the host code.
    // When t * 2 + 1 >= _Total, both of the thread's slots are zero padded (including slot t * 2)
    // and neither is written back below.
    if (t * 2 + 1 < _Total)
    {
        gs_Keys  [t * 2 + 0] = _KeyBuffer.Load  ((offset + t * 2 + 0) << 2);
        gs_Keys  [t * 2 + 1] = _KeyBuffer.Load  ((offset + t * 2 + 1) << 2);

        gs_Values[t * 2 + 0] = _ValueBuffer.Load((offset + t * 2 + 0) << 2);
        gs_Values[t * 2 + 1] = _ValueBuffer.Load((offset + t * 2 + 1) << 2);
    }
    else
    {
        // Zero keys never compare greater than another key, so a swap never moves them to the lower index.
        gs_Keys  [t * 2 + 0] = 0;
        gs_Keys  [t * 2 + 1] = 0;

        gs_Values[t * 2 + 0] = 0;
        gs_Values[t * 2 + 1] = 0;
    }
#endif

#if defined(STAGE_BMS)
    LocalBMS(_H);
#elif defined(STAGE_LOCAL_DISPERSE)
    LocalDisperse(_H);
#elif defined(STAGE_BIG_FLIP)
    BigFlip(_H);
#elif defined(STAGE_BIG_DISPERSE)
    BigDisperse(_H);
#endif

#if defined(STAGE_BMS) || defined(STAGE_LOCAL_DISPERSE)
    // Group-wide barrier after the last compare/swap and before the write back.
    GroupMemoryBarrierWithGroupSync();

    if (t * 2 + 1 < _Total)
    {
        _KeyBuffer.Store  ((offset + t * 2 + 0) << 2, gs_Keys  [t * 2 + 0]);
        _KeyBuffer.Store  ((offset + t * 2 + 1) << 2, gs_Keys  [t * 2 + 1]);

        _ValueBuffer.Store((offset + t * 2 + 0) << 2, gs_Values[t * 2 + 0]);
        _ValueBuffer.Store((offset + t * 2 + 1) << 2, gs_Values[t * 2 + 1]);
    }
#endif
}

// Copies the first _CopyEntriesCount 32-bit entries of _CopySrcBuffer to _CopyDstBuffer,
// one entry per thread; threads at or beyond the count do nothing.
[numthreads(64, 1, 1)]
void CopyBuffer(Semantics s)
{
    if (s.dispatchThreadID.x < _CopyEntriesCount)
    {
        _CopyDstBuffer.Store(s.dispatchThreadID.x << 2, _CopySrcBuffer.Load(s.dispatchThreadID.x << 2));
    }
}