You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

207 lines
4.8 KiB

#pragma kernel Kernel
#pragma kernel CopyBuffer
// Ref: https://poniesandlight.co.uk/reflect/bitonic_merge_sort/
#pragma multi_compile _ STAGE_BMS STAGE_LOCAL_DISPERSE STAGE_BIG_FLIP STAGE_BIG_DISPERSE
#pragma only_renderers d3d11 playstation xboxone xboxseries vulkan metal switch
// Disable warning for auto unrolling of single iteration loop.
#pragma warning(disable : 3557)
#define LOCAL_SIZE 1024
#define MAX_DISPERSE_UNROLL_COUNT 16
// Bundle of compute-shader system values handed to every kernel entry point in this file.
struct Semantics
{
    // Flattened index of the thread within its thread group.
    uint  groupIndex       : SV_GroupIndex;
    // Index of the thread group within the dispatch.
    uint3 groupID          : SV_GroupID;
    // Index of the thread within the whole dispatch.
    uint3 dispatchThreadID : SV_DispatchThreadID;
};
// Resources for the CopyBuffer kernel: source buffer, destination buffer, and the number of
// 32-bit entries to copy from one to the other.
RWByteAddressBuffer _CopySrcBuffer;
RWByteAddressBuffer _CopyDstBuffer;
uint _CopyEntriesCount;
// Copy of the kernel's input semantics, stored by Kernel so the helper functions can read the
// thread and group indices without taking them as parameters.
static Semantics s_Semantics;
// Passed as `h` to whichever stage function (LocalBMS / LocalDisperse / BigFlip / BigDisperse)
// the STAGE_* keyword selects.
uint _H;
// Element count that Kernel checks indices against before loading from / storing to the
// global buffers.
uint _Total;
// Global scratch space. Keys and values are parallel arrays of 32-bit entries, addressed in
// bytes (element index << 2).
RWByteAddressBuffer _KeyBuffer;
RWByteAddressBuffer _ValueBuffer;
// Local scratch space. Each thread group stages LOCAL_SIZE * 2 keys and values here
// (two entries per thread).
groupshared uint gs_Keys [LOCAL_SIZE * 2];
groupshared uint gs_Values [LOCAL_SIZE * 2];
// Compare-and-swap on the global buffers.
// i.x / i.y are element indices into _KeyBuffer and _ValueBuffer. When the key at i.x is
// smaller than the key at i.y, both keys and their associated values are exchanged, so the
// larger key ends up at i.x.
void GlobalCompareAndSwap(int2 i)
{
    // Entries are 32 bits wide: element index -> byte address.
    const int2 byteAddr = i << 2;

    const uint keyAtX = _KeyBuffer.Load(byteAddr.x);
    const uint keyAtY = _KeyBuffer.Load(byteAddr.y);
    const uint valAtX = _ValueBuffer.Load(byteAddr.x);
    const uint valAtY = _ValueBuffer.Load(byteAddr.y);

    const bool outOfOrder = keyAtX < keyAtY;
    if (outOfOrder)
    {
        _KeyBuffer.Store(byteAddr.x, keyAtY);
        _KeyBuffer.Store(byteAddr.y, keyAtX);
        _ValueBuffer.Store(byteAddr.x, valAtY);
        _ValueBuffer.Store(byteAddr.y, valAtX);
    }
}
// Compare-and-swap on the group-shared scratch arrays.
// i.x / i.y index gs_Keys and gs_Values. When the key at i.x is smaller than the key at i.y,
// the two keys and their values trade places, so the larger key ends up at i.x.
void LocalCompareAndSwap(int2 i)
{
    const uint2 keys   = uint2(gs_Keys[i.x],   gs_Keys[i.y]);
    const uint2 values = uint2(gs_Values[i.x], gs_Values[i.y]);

    if (keys.x < keys.y)
    {
        gs_Keys[i.x]   = keys.y;
        gs_Keys[i.y]   = keys.x;
        gs_Values[i.x] = values.y;
        gs_Values[i.y] = values.x;
    }
}
// Flip step executed directly on the global buffers.
// Elements are viewed as consecutive blocks of h entries; each thread pairs one entry from
// the first half of its block with the mirrored entry counted from the end of the block and
// compare-swaps them. Does nothing when h is smaller than 2 * LOCAL_SIZE.
void BigFlip(uint h)
{
    if (h < LOCAL_SIZE * 2)
        return;

    const uint threadIndex = s_Semantics.dispatchThreadID.x;
    const uint halfHeight  = h / 2;

    // Start of the h-sized block this thread works in, and its position inside the half block.
    const uint blockStart = h * ((threadIndex * 2) / h);
    const uint lane       = threadIndex % halfHeight;

    const uint first  = blockStart + lane;
    const uint second = blockStart + h - lane - 1;
    GlobalCompareAndSwap(int2(first, second));
}
// Disperse step executed directly on the global buffers.
// Elements are viewed as consecutive blocks of h entries; each thread compare-swaps one entry
// from the first half of its block with the entry half a block further on.
// Does nothing when h is smaller than 2 * LOCAL_SIZE.
void BigDisperse(uint h)
{
    if (h < LOCAL_SIZE * 2)
        return;

    const uint threadIndex = s_Semantics.dispatchThreadID.x;
    const uint halfHeight  = h / 2;

    // Start of the h-sized block this thread works in, and its position inside the half block.
    const uint blockStart = h * ((threadIndex * 2) / h);
    const uint lane       = threadIndex % halfHeight;

    const uint first  = blockStart + lane;
    const uint second = first + halfHeight;
    GlobalCompareAndSwap(int2(first, second));
}
// Flip step on the group-shared scratch arrays.
// Within each block of h entries, the thread compare-swaps an entry from the first half with
// its mirror counted from the end of the block.
void LocalFlip(uint h)
{
    const uint threadIndex = s_Semantics.groupIndex;
    const uint halfHeight  = h / 2;
    const uint lane        = threadIndex % halfHeight;
    const uint blockStart  = ((threadIndex * 2) / h) * h;

    const int2 pair = int2(blockStart + lane, blockStart + (h - 1 - lane));

    // Shared-memory writes from the previous step must be visible to the whole group
    // before any thread reads its pair.
    GroupMemoryBarrierWithGroupSync();
    LocalCompareAndSwap(pair);
}
// Disperse steps on the group-shared scratch arrays.
// Starting at block size h and halving it each pass until it reaches 1, every thread
// compare-swaps an entry from the first half of its block with the entry half a block
// further on. A group barrier precedes each pass.
void LocalDisperse(uint h)
{
    const uint threadIndex = s_Semantics.groupIndex;

    [unroll(MAX_DISPERSE_UNROLL_COUNT)]
    for (uint span = h; span > 1; span >>= 1)
    {
        const uint halfSpan   = span >> 1;
        const uint lane       = threadIndex % halfSpan;
        const uint blockStart = ((threadIndex * 2) / span) * span;
        const uint first      = blockStart + lane;

        // Make the previous pass' shared-memory writes visible before reading.
        GroupMemoryBarrierWithGroupSync();
        LocalCompareAndSwap(int2(first, first + halfSpan));
    }
}
// Bitonic merge sort performed on the group-shared scratch arrays.
// For every block size hh = 2, 4, ... up to and including h, runs one flip step at hh
// followed by the disperse steps starting at hh / 2.
void LocalBMS(uint h)
{
    for (uint hh = 2; hh <= h; hh <<= 1)
    {
        LocalFlip(hh);
        LocalDisperse(hh / 2);
    }
}
// Sort kernel entry point. The work performed is selected at compile time by the STAGE_* keyword:
//   STAGE_BMS            - stage the group's window in shared memory, run LocalBMS(_H), write back.
//   STAGE_LOCAL_DISPERSE - stage the group's window in shared memory, run LocalDisperse(_H), write back.
//   STAGE_BIG_FLIP       - run BigFlip(_H) directly on the global buffers.
//   STAGE_BIG_DISPERSE   - run BigDisperse(_H) directly on the global buffers.
// For the two local stages each thread group owns a window of 2 * LOCAL_SIZE consecutive
// elements and each thread stages two of them.
[numthreads(LOCAL_SIZE, 1, 1)]
void Kernel(Semantics s)
{
    // Push the semantics to static global.
    s_Semantics = s;

    const uint t = s_Semantics.groupIndex;
    // First element of the window owned by this thread group.
    const uint offset = 2 * LOCAL_SIZE * s_Semantics.groupID.x;

#if defined(STAGE_BMS) || defined(STAGE_LOCAL_DISPERSE)
    const uint local0  = t * 2 + 0;
    const uint local1  = t * 2 + 1;
    const uint global0 = offset + local0;
    const uint global1 = offset + local1;

    // Bounds are checked per element and against the global element index: the two elements of
    // a pair can fall on different sides of _Total (odd counts), and groups other than the
    // first start at a non-zero offset.
    const bool inRange0 = global0 < _Total;
    const bool inRange1 = global1 < _Total;

    // Slots past the end are padded with zero keys/values. The compare-and-swap helpers move
    // the larger key to the lower index and only swap on a strict '<', so the padding never
    // displaces a real element.
    uint key0 = 0;
    uint value0 = 0;
    uint key1 = 0;
    uint value1 = 0;

    if (inRange0)
    {
        key0   = _KeyBuffer.Load(global0 << 2);
        value0 = _ValueBuffer.Load(global0 << 2);
    }

    if (inRange1)
    {
        key1   = _KeyBuffer.Load(global1 << 2);
        value1 = _ValueBuffer.Load(global1 << 2);
    }

    gs_Keys  [local0] = key0;
    gs_Keys  [local1] = key1;
    gs_Values[local0] = value0;
    gs_Values[local1] = value1;
#endif

#if defined(STAGE_BMS)
    LocalBMS(_H);
#elif defined(STAGE_LOCAL_DISPERSE)
    LocalDisperse(_H);
#elif defined(STAGE_BIG_FLIP)
    BigFlip(_H);
#elif defined(STAGE_BIG_DISPERSE)
    BigDisperse(_H);
#endif

#if defined(STAGE_BMS) || defined(STAGE_LOCAL_DISPERSE)
    // The last compare-and-swap pass must be complete for the whole group before write-back.
    GroupMemoryBarrierWithGroupSync();

    if (inRange0)
    {
        _KeyBuffer.Store  (global0 << 2, gs_Keys  [local0]);
        _ValueBuffer.Store(global0 << 2, gs_Values[local0]);
    }

    if (inRange1)
    {
        _KeyBuffer.Store  (global1 << 2, gs_Keys  [local1]);
        _ValueBuffer.Store(global1 << 2, gs_Values[local1]);
    }
#endif
}
// Copies _CopyEntriesCount 32-bit entries from _CopySrcBuffer to _CopyDstBuffer,
// one entry per thread. Threads whose index is at or beyond the entry count do nothing.
[numthreads(64, 1, 1)]
void CopyBuffer(Semantics s)
{
    const uint entry = s.dispatchThreadID.x;
    if (entry >= _CopyEntriesCount)
        return;

    // Entries are 32 bits wide: entry index -> byte address.
    const uint byteAddress = entry << 2;
    _CopyDstBuffer.Store(byteAddress, _CopySrcBuffer.Load(byteAddress));
}