You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
247 lines
8.2 KiB
247 lines
8.2 KiB
/**********************************************************************
|
|
Copyright (c) 2019 Advanced Micro Devices, Inc. All rights reserved.
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
of this software and associated documentation files (the "Software"), to deal
|
|
in the Software without restriction, including without limitation the rights
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
copies of the Software, and to permit persons to whom the Software is
|
|
furnished to do so, subject to the following conditions:
|
|
The above copyright notice and this permission notice shall be included in
|
|
all copies or substantial portions of the Software.
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
|
THE SOFTWARE.
|
|
********************************************************************/
|
|
|
|
// WebGPU has extremely strict uniformity requirements that are incompatible with the current implementation of this shader.
#pragma exclude_renderers webgpu

// When defined, Scatter carries one value along with every key it moves.
#define SORT_VALUES

// ---------------------------------------------------------------------------
// Per-dispatch constants
// ---------------------------------------------------------------------------

// Number of keys to process; elements at or beyond this index are padded on
// load and skipped on store.
uint g_constants_num_keys;
// Row stride of the group histogram table: the entry for (bin, group) is read
// from g_group_histograms_offset + g_constants_num_blocks * bin + group.
uint g_constants_num_blocks;
// Right shift applied to a key before the current digit is masked out.
uint g_constants_bit_shift;

// ---------------------------------------------------------------------------
// Element offsets of the individual regions inside g_buffer
// ---------------------------------------------------------------------------

uint g_input_keys_offset;
uint g_group_histograms_offset;
uint g_output_keys_offset;
uint g_input_values_offset;
uint g_output_values_offset;

// Single buffer that holds the keys, the values and the group histograms; the
// offsets above select the region being read or written.
RWStructuredBuffer<uint> g_buffer : register(u0);

// ---------------------------------------------------------------------------
// Tuning constants
// ---------------------------------------------------------------------------

// Width of the digit handled by one dispatch, and the resulting bin count.
#define NUM_BITS_PER_PASS 4
#define NUM_BINS (1 << NUM_BITS_PER_PASS)
// Threads per group, and how many keys each of those threads processes.
#define GROUP_SIZE 256u
#define KEYS_PER_THREAD 4u
//#define HISTOGRAM_TYPE short
// Key substituted for elements that lie past g_constants_num_keys.
#define INT_MAX 0x7fffffff

// ---------------------------------------------------------------------------
// Group-shared (LDS) storage
// ---------------------------------------------------------------------------

// Scratch memory for group operations (broadcasts between threads).
groupshared int lds_scratch[GROUP_SIZE];
// Temporary storage for the keys; also the workspace of BlockScan.
groupshared int lds_keys[GROUP_SIZE];
// Cache for scanned histogram for faster indexing.
groupshared int lds_scanned_histogram[NUM_BINS];
// Block histogram.
groupshared int lds_histogram[NUM_BINS];
|
|
|
|
// Computes the exclusive prefix sum of `key` over the thread group.
//
//   key   - this thread's contribution to the sum.
//   lidx  - index of the calling thread inside the group (0 .. GROUP_SIZE - 1).
//   returns the sum of `key` over all threads with a smaller lidx.
//
// The function contains group-wide barriers, so every thread of the group has
// to call it. It overwrites lds_keys.
int BlockScan(int key, uint lidx)
{
#ifndef USE_WAVE_INTRINSICS
    // Publish this thread's input so other threads can combine it.
    lds_keys[lidx] = key;

    GroupMemoryBarrierWithGroupSync();

    // Reduction: build partial sums in place, doubling the span on each level.
    uint span = 1;
    for (; span < GROUP_SIZE; span <<= 1)
    {
        if (lidx < GROUP_SIZE / (2 * span))
        {
            const uint upper = 2 * (lidx + 1) * span - 1;
            const uint lower = (2 * lidx + 1) * span - 1;
            lds_keys[upper] += lds_keys[lower];
        }

        GroupMemoryBarrierWithGroupSync();
    }

    // The root is replaced by 0 so that the down-sweep yields an exclusive scan.
    if (lidx == 0)
    {
        lds_keys[GROUP_SIZE - 1] = 0;
    }

    GroupMemoryBarrierWithGroupSync();

    // Down-sweep: hand each node's prefix to its lower child, and the prefix
    // plus the lower child's sum to its upper child.
    for (span = GROUP_SIZE >> 1; span > 0; span >>= 1)
    {
        if (lidx < GROUP_SIZE / (2 * span))
        {
            const uint upper = 2 * (lidx + 1) * span - 1;
            const uint lower = (2 * lidx + 1) * span - 1;

            const int lower_sum = lds_keys[lower];
            const int prefix = lds_keys[upper];
            lds_keys[lower] = prefix;
            lds_keys[upper] = prefix + lower_sum;
        }

        GroupMemoryBarrierWithGroupSync();
    }

    return lds_keys[lidx];
#else
    // NOTE(review): this path derives the wave index from lidx and scans the
    // per-wave totals with a single wave, so it presumably requires that lidx
    // is laid out linearly across waves and that GROUP_SIZE / lane count does
    // not exceed the lane count - confirm for the targeted hardware.

    // Zero the workspace so slots without a wave total add nothing below.
    lds_keys[lidx] = 0;

    GroupMemoryBarrierWithGroupSync();

    // Scan inside each wave first.
    const uint lane_count = WaveGetLaneCount();
    const uint wave_index = lidx / lane_count;
    int scanned = WavePrefixSum(key);

    // The last lane of every wave knows the wave total and stores it in LDS.
    if (WaveGetLaneIndex() == lane_count - 1)
    {
        lds_keys[wave_index] = scanned + key;
    }

    GroupMemoryBarrierWithGroupSync();

    // The first wave turns the wave totals into per-wave base offsets.
    if (wave_index == 0)
    {
        lds_keys[lidx] = WavePrefixSum(lds_keys[lidx]);
    }

    GroupMemoryBarrierWithGroupSync();

    // Shift the in-wave result by the base offset of the wave it belongs to.
    scanned += lds_keys[wave_index];

    return scanned;
#endif
}
|
|
|
|
// Scatter kernel: every thread group reads GROUP_SIZE * KEYS_PER_THREAD
// consecutive keys, in KEYS_PER_THREAD batches of GROUP_SIZE. Each batch is
// sorted inside LDS by the 4-bit digit found at g_constants_bit_shift, and
// every key (plus its value when SORT_VALUES is defined) is then written to
//     base offset of (digit, group) + rank of the key among equal digits.
// The base offsets are read from the group histogram table in g_buffer.
// NOTE(review): that table is presumably filled by earlier count/scan kernels
// which are not part of this file - confirm against the dispatching code.
//
//   gidx - dispatch-wide thread index (not used by the kernel).
//   lidx - thread index inside the group.
//   bidx - index of the thread group.
//[RootSignature(ROOT_SIGNATURE)]
#pragma kernel Scatter
[numthreads(GROUP_SIZE, 1, 1)]
void Scatter(
    in uint gidx: SV_DispatchThreadID,
    in uint lidx: SV_GroupThreadID,
    in uint bidx: SV_GroupID)
{
    // The first NUM_BINS threads each maintain one bin of the LDS histograms.
    const bool owns_bin = lidx < NUM_BINS;

    // Cache this group's scanned histogram (one entry per bin) in LDS.
    if (owns_bin)
    {
        lds_scanned_histogram[lidx] = g_buffer[g_group_histograms_offset + g_constants_num_blocks * lidx + bidx];
    }

    // First element of the range owned by this group.
    const uint group_base = bidx * GROUP_SIZE * KEYS_PER_THREAD;

    // Each thread handles KEYS_PER_THREAD elements, one per batch.
    for (uint batch = 0; batch < KEYS_PER_THREAD; ++batch)
    {
        // Reset the per-batch digit histogram.
        if (owns_bin)
        {
            lds_histogram[lidx] = 0;
        }

        // Element this thread loads in the current batch.
        const uint src_index = group_base + batch * GROUP_SIZE + lidx;
        const bool in_range = src_index < g_constants_num_keys;

        // Elements past the end of the input are padded with INT_MAX keys (and 0 values).
        int key = in_range ? g_buffer[g_input_keys_offset + src_index] : INT_MAX;
#ifdef SORT_VALUES
        int value = in_range ? g_buffer[g_input_values_offset + src_index] : 0;
#endif

        // Sort the batch in LDS by the 4-bit digit, two bits per step.
        for (uint sub_shift = 0; sub_shift < NUM_BITS_PER_PASS; sub_shift += 2)
        {
            // Determine the 2-bit sub-digit of the key and the byte lane that counts it.
            const int sub_digit = ((key >> g_constants_bit_shift) >> sub_shift) & 0x3;
            const int lane_shift = sub_digit * 8;

            // Packed one-hot histogram: sub-digit 0 counts in byte 0, 1 in byte 1, etc.
            const int packed_one = 1 << lane_shift;

            // Per byte lane: how many lower threads hold the same sub-digit.
            int packed_prefix = BlockScan(packed_one, lidx);

            // The last thread knows the totals of the whole batch and broadcasts them.
            if (lidx == (GROUP_SIZE - 1))
            {
                lds_scratch[0] = packed_prefix + packed_one;
            }

            // Make sure the broadcast happened.
            GroupMemoryBarrierWithGroupSync();

            const int packed_total = lds_scratch[0];

            // Add the totals of all smaller sub-digits to each byte lane, so a
            // lane holds the position at which its sub-digit's run starts.
            const int packed_bases = (packed_total << 8) +
                                     (packed_total << 16) +
                                     (packed_total << 24);
            packed_prefix += packed_bases;

            // Slot of this key inside the batch after the split.
            const int slot = (packed_prefix >> lane_shift) & 0xff;

            // Exchange the keys through LDS.
            lds_keys[slot] = key;

            GroupMemoryBarrierWithGroupSync();

            key = lds_keys[lidx];

#ifdef SORT_VALUES
            // Exchange the values the same way; the barriers keep the key and
            // value phases from overlapping in lds_keys.
            GroupMemoryBarrierWithGroupSync();
            lds_keys[slot] = value;
            GroupMemoryBarrierWithGroupSync();
            value = lds_keys[lidx];
            GroupMemoryBarrierWithGroupSync();
#endif
        }

        // Reconstruct the original 16-bin histogram from the sorted keys.
        const int digit = (key >> g_constants_bit_shift) & 0xf;

        InterlockedAdd(lds_histogram[digit], 1);

        GroupMemoryBarrierWithGroupSync();

        // Scan the histogram: the result is where each digit's run starts in the batch.
        const int run_start = BlockScan(lidx < NUM_BINS ? lds_histogram[lidx] : 0, lidx);

        // Broadcast the scanned histogram via LDS.
        if (owns_bin)
        {
            lds_scratch[lidx] = run_start;
        }

        // Current global write position for this key's digit.
        const int digit_base = lds_scanned_histogram[digit];

        GroupMemoryBarrierWithGroupSync();

        // The batch is sorted, so lidx minus the start of the run is the rank
        // of this key among the keys with the same digit.
        const int rank_in_run = int(lidx) - lds_scratch[digit];

        // Write the element back to global memory unless it lands past the end.
        const int dst_index = digit_base + rank_in_run;
        if (dst_index < int(g_constants_num_keys))
        {
            g_buffer[g_output_keys_offset + dst_index] = key;
#ifdef SORT_VALUES
            g_buffer[g_output_values_offset + dst_index] = value;
#endif
        }

        GroupMemoryBarrierWithGroupSync();

        // Advance the global write positions past the keys of this batch.
        if (owns_bin)
        {
            lds_scanned_histogram[lidx] += lds_histogram[lidx];
        }
    }
}
|