From 4be784e42703135d628f32ac23e16113bc4d3f89 Mon Sep 17 00:00:00 2001 From: Nico de Poel Date: Tue, 23 Jul 2024 22:36:37 +0200 Subject: [PATCH] Implemented passes up to and including SCD divergence --- Runtime/OpticalFlow/OpticalFlowContext.cs | 45 +++++++++++++++- Runtime/OpticalFlow/OpticalFlowPass.cs | 57 ++++++++++++++++----- Runtime/OpticalFlow/OpticalFlowResources.cs | 11 ++-- 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/Runtime/OpticalFlow/OpticalFlowContext.cs b/Runtime/OpticalFlow/OpticalFlowContext.cs index a3a950e..b556ca2 100644 --- a/Runtime/OpticalFlow/OpticalFlowContext.cs +++ b/Runtime/OpticalFlow/OpticalFlowContext.cs @@ -1,6 +1,7 @@ using System; using System.Runtime.InteropServices; using UnityEngine; +using UnityEngine.Assertions; using UnityEngine.Rendering; namespace FidelityFX.OpticalFlow @@ -31,6 +32,7 @@ namespace FidelityFX.OpticalFlow private bool _firstExecution; private int _resourceFrameIndex; + private readonly Vector2Int[] _opticalFlowTextureSizes = new Vector2Int[OpticalFlowResources.OpticalFlowMaxPyramidLevels]; public void Create(OpticalFlow.ContextDescription contextDescription) { @@ -121,7 +123,7 @@ namespace FidelityFX.OpticalFlow } } - SetupSpdConstants(out var dispatchThreadGroupCount); + SetupSpdConstants(out var threadGroupSizeOpticalFlowInputPyramid); commandBuffer.SetBufferData(_opticalFlowConstantsBuffer, _opticalFlowConstantsArray); commandBuffer.SetBufferData(_spdConstantsBuffer, _spdConstantsArray); @@ -135,6 +137,47 @@ namespace FidelityFX.OpticalFlow int dispatchY = ((_contextDescription.resolution.y + (threadPixelsY - 1)) / threadPixelsY + (threadGroupSizeY - 1)) / threadGroupSizeY; _prepareLumaPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, dispatchX, dispatchY); } + + _generateInputPyramidPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, threadGroupSizeOpticalFlowInputPyramid.x, threadGroupSizeOpticalFlowInputPyramid.y); + + { + const int threadGroupSizeX = 32; + const int threadGroupSizeY = 8; + int strataWidth = (_contextDescription.resolution.x / 4) / OpticalFlowResources.HistogramsPerDim; + int strataHeight = _contextDescription.resolution.y / OpticalFlowResources.HistogramsPerDim; + int dispatchX = (strataWidth + threadGroupSizeX - 1) / threadGroupSizeX; + const int dispatchY = 16; + const int dispatchZ = OpticalFlowResources.HistogramsPerDim * OpticalFlowResources.HistogramsPerDim; + _generateScdHistogramPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, dispatchX, dispatchY, dispatchZ); + } + { + const int dispatchX = OpticalFlowResources.HistogramsPerDim * OpticalFlowResources.HistogramsPerDim; + const int dispatchY = OpticalFlowResources.HistogramShifts; + _computeScdDivergencePass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, dispatchX, dispatchY); + } + + const int pyramidMaxIterations = advancedAlgorithmIterations; + Assert.IsTrue(pyramidMaxIterations <= OpticalFlowResources.OpticalFlowMaxPyramidLevels); + + _opticalFlowTextureSizes[0] = OpticalFlowResources.GetOpticalFlowTextureSize(_contextDescription.resolution, opticalFlowBlockSize); + for (int i = 1; i < pyramidMaxIterations; ++i) + { + _opticalFlowTextureSizes[i] = new Vector2Int( + (_opticalFlowTextureSizes[i - 1].x + 1) / 2, + (_opticalFlowTextureSizes[i - 1].y + 1) / 2 + ); + } + + for (int level = pyramidMaxIterations - 1; level >= 0; --level) + { + // TODO: need to flip-flop Optical Flow output resources between levels as well (isOddFrame != isOddLevel) + + Constants.opticalFlowPyramidLevel = (uint)level; + Constants.opticalFlowPyramidLevelCount = pyramidMaxIterations; + commandBuffer.SetBufferData(_opticalFlowConstantsBuffer, _opticalFlowConstantsArray); + + // TODO: need to somehow address resources by level, not great when you've got a bazillion loose RT arrays... *sigh* + } _resourceFrameIndex = (_resourceFrameIndex + 1) % MaxQueuedFrames; } diff --git a/Runtime/OpticalFlow/OpticalFlowPass.cs b/Runtime/OpticalFlow/OpticalFlowPass.cs index e4202fb..952ea72 100644 --- a/Runtime/OpticalFlow/OpticalFlowPass.cs +++ b/Runtime/OpticalFlow/OpticalFlowPass.cs @@ -28,14 +28,14 @@ namespace FidelityFX.OpticalFlow { } - public void ScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY) + public void ScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ = 1) { commandBuffer.BeginSample(_sampler); - DoScheduleDispatch(commandBuffer, dispatchParams, frameIndex, dispatchX, dispatchY); + DoScheduleDispatch(commandBuffer, dispatchParams, frameIndex, dispatchX, dispatchY, dispatchZ); commandBuffer.EndSample(_sampler); } - protected abstract void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY); + protected abstract void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ); protected void InitComputeShader(string passName, ComputeShader shader) { @@ -58,7 +58,7 @@ namespace FidelityFX.OpticalFlow InitComputeShader("Prepare Luma", contextDescription.shaders.prepareLuma); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ) { ref var color = ref dispatchParams.color; commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvInputColor, color.RenderTarget, color.MipLevel, color.SubElement); @@ -67,7 +67,7 @@ namespace FidelityFX.OpticalFlow commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf()); - commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, 1); + commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); } } @@ -79,11 +79,23 @@ namespace FidelityFX.OpticalFlow : base(contextDescription, resources, constants) { _spdConstants = spdConstants; + InitComputeShader("Generate Optical Flow Input Pyramid", contextDescription.shaders.generateOpticalFlowInputPyramid); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ) { - throw new NotImplementedException(); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInput[frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel1, Resources.OpticalFlowInputLevel1[frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel2, Resources.OpticalFlowInputLevel2[frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel3, Resources.OpticalFlowInputLevel3[frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel4, Resources.OpticalFlowInputLevel4[frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel5, Resources.OpticalFlowInputLevel5[frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel6, Resources.OpticalFlowInputLevel6[frameIndex]); + + commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf()); + commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbSpd, _spdConstants, 0, Marshal.SizeOf()); + + commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); } } @@ -92,11 +104,18 @@ namespace FidelityFX.OpticalFlow public OpticalFlowGenerateSCDHistogramPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) { + InitComputeShader("Generate SCD Histogram", contextDescription.shaders.generateScdHistogram); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ) { - throw new NotImplementedException(); + // TODO: probably needs to be input from this frame (result from pyramid pass), but double check to be sure + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInput[frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram); + + commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf()); + + commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); } } @@ -105,11 +124,21 @@ namespace FidelityFX.OpticalFlow public OpticalFlowComputeSCDDivergencePass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) { + InitComputeShader("Compute SCD Divergence", contextDescription.shaders.computeScdDivergence); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ) { - throw new NotImplementedException(); + ref var scdOutput = ref dispatchParams.opticalFlowSCD; + + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdPreviousHistogram, Resources.OpticalFlowSCDPreviousHistogram); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdTemp, Resources.OpticalFlowSCDTemp); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, scdOutput.RenderTarget, scdOutput.MipLevel, scdOutput.SubElement); + + commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf()); + + commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); } } @@ -120,7 +149,7 @@ namespace FidelityFX.OpticalFlow { } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ) { throw new NotImplementedException(); } @@ -133,7 +162,7 @@ namespace FidelityFX.OpticalFlow { } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ) { throw new NotImplementedException(); } @@ -146,7 +175,7 @@ namespace FidelityFX.OpticalFlow { } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ) { throw new NotImplementedException(); } diff --git a/Runtime/OpticalFlow/OpticalFlowResources.cs b/Runtime/OpticalFlow/OpticalFlowResources.cs index a36a198..7856543 100644 --- a/Runtime/OpticalFlow/OpticalFlowResources.cs +++ b/Runtime/OpticalFlow/OpticalFlowResources.cs @@ -5,10 +5,10 @@ namespace FidelityFX.OpticalFlow { internal class OpticalFlowResources { - private const int OpticalFlowMaxPyramidLevels = 7; - private const int HistogramBins = 256; - private const int HistogramsPerDim = 3; - private const int HistogramShifts = 3; + internal const int OpticalFlowMaxPyramidLevels = 7; + internal const int HistogramBins = 256; + internal const int HistogramsPerDim = 3; + internal const int HistogramShifts = 3; public readonly RenderTexture[] OpticalFlowInput = new RenderTexture[2]; public readonly RenderTexture[] OpticalFlowInputLevel1 = new RenderTexture[2]; @@ -107,7 +107,8 @@ namespace FidelityFX.OpticalFlow return (x + (y - 1)) & ~(y - 1); } - private static Vector2Int GetOpticalFlowTextureSize(Vector2Int displaySize, int opticalFlowBlockSize) + // TODO: move these to OpticalFlow class + internal static Vector2Int GetOpticalFlowTextureSize(Vector2Int displaySize, int opticalFlowBlockSize) { int width = (displaySize.x + opticalFlowBlockSize - 1) / opticalFlowBlockSize; int height = (displaySize.y + opticalFlowBlockSize - 1) / opticalFlowBlockSize;