From 404a090adfcbabb42a26c9cc7f5230165ed1fab2 Mon Sep 17 00:00:00 2001 From: Nico de Poel Date: Sun, 4 Aug 2024 01:04:56 +0200 Subject: [PATCH] Reworked optical flow passes to inherit from FfxPassBase, with a bit of a hack to work around the need for the extra level parameter. Made the pass classes sealed, which is a more sensible way to please the compiler about using virtual methods inside constructors. --- .../FrameInterpolationPass.cs | 22 ++--- Runtime/OpticalFlow/OpticalFlowPass.cs | 95 ++++++++----------- 2 files changed, 51 insertions(+), 66 deletions(-) diff --git a/Runtime/FrameInterpolation/FrameInterpolationPass.cs b/Runtime/FrameInterpolation/FrameInterpolationPass.cs index a2bf6ac..2a71d58 100644 --- a/Runtime/FrameInterpolation/FrameInterpolationPass.cs +++ b/Runtime/FrameInterpolation/FrameInterpolationPass.cs @@ -23,7 +23,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationReconstructAndDilatePass : FrameInterpolationPass + internal sealed class FrameInterpolationReconstructAndDilatePass : FrameInterpolationPass { public FrameInterpolationReconstructAndDilatePass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) @@ -54,7 +54,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationSetupPass : FrameInterpolationPass + internal sealed class FrameInterpolationSetupPass : FrameInterpolationPass { public FrameInterpolationSetupPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) @@ -79,7 +79,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationReconstructPreviousDepthPass : FrameInterpolationPass + internal sealed class FrameInterpolationReconstructPreviousDepthPass : FrameInterpolationPass { public FrameInterpolationReconstructPreviousDepthPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) @@ -101,7 +101,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationGameMotionVectorFieldPass : FrameInterpolationPass + internal sealed class FrameInterpolationGameMotionVectorFieldPass : FrameInterpolationPass { public FrameInterpolationGameMotionVectorFieldPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) @@ -125,7 +125,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationOpticalFlowVectorFieldPass : FrameInterpolationPass + internal sealed class FrameInterpolationOpticalFlowVectorFieldPass : FrameInterpolationPass { public FrameInterpolationOpticalFlowVectorFieldPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) @@ -154,7 +154,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationDisocclusionMaskPass : FrameInterpolationPass + internal sealed class FrameInterpolationDisocclusionMaskPass : FrameInterpolationPass { public FrameInterpolationDisocclusionMaskPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) @@ -179,7 +179,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationInterpolationPass : FrameInterpolationPass + internal sealed class FrameInterpolationInterpolationPass : FrameInterpolationPass { public FrameInterpolationInterpolationPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) @@ -207,7 +207,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationInpaintingPyramidPass : FrameInterpolationPass + internal sealed class FrameInterpolationInpaintingPyramidPass : FrameInterpolationPass { private readonly ComputeBuffer _spdConstants; @@ -239,7 +239,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationInpaintingPass : FrameInterpolationPass + internal sealed class FrameInterpolationInpaintingPass : FrameInterpolationPass { public FrameInterpolationInpaintingPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) @@ -262,7 +262,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationGameVectorFieldInpaintingPyramidPass : FrameInterpolationPass + internal sealed class FrameInterpolationGameVectorFieldInpaintingPyramidPass : FrameInterpolationPass { private readonly ComputeBuffer _spdConstants; @@ -295,7 +295,7 @@ namespace FidelityFX.FrameGen } } - internal class FrameInterpolationDebugViewPass : FrameInterpolationPass + internal sealed class FrameInterpolationDebugViewPass : FrameInterpolationPass { public FrameInterpolationDebugViewPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants) : base(contextDescription, resources, constants) diff --git a/Runtime/OpticalFlow/OpticalFlowPass.cs b/Runtime/OpticalFlow/OpticalFlowPass.cs index ef9ed85..fd7efbc 100644 --- a/Runtime/OpticalFlow/OpticalFlowPass.cs +++ b/Runtime/OpticalFlow/OpticalFlowPass.cs @@ -6,49 +6,34 @@ using UnityEngine.Rendering; namespace FidelityFX.FrameGen { - internal abstract class OpticalFlowPass: IDisposable + internal abstract class OpticalFlowPass: FfxPassBase { protected readonly OpticalFlowResources Resources; protected readonly ComputeBuffer Constants; - protected ComputeShader ComputeShader; - protected int KernelIndex; - - private CustomSampler _sampler; - protected OpticalFlowPass(OpticalFlowResources resources, ComputeBuffer constants) + : base("Optical Flow") { Resources = resources; Constants = constants; } - - public virtual void Dispose() - { - } - public void ScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ = 1) + public void ScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ = 1) { - commandBuffer.BeginSample(_sampler); - DoScheduleDispatch(commandBuffer, dispatchParams, frameIndex, level, dispatchX, dispatchY, dispatchZ); - commandBuffer.EndSample(_sampler); + commandBuffer.BeginSample(Sampler); + DoScheduleDispatch(commandBuffer, dispatchParams, bufferIndex, level, dispatchX, dispatchY, dispatchZ); + commandBuffer.EndSample(Sampler); } + + protected abstract void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ); - protected abstract void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ); - - protected void InitComputeShader(string passName, ComputeShader shader) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int dispatchX, int dispatchY, int dispatchZ) { - if (shader == null) - { - throw new MissingReferenceException($"Shader for Optical Flow pass '{passName}' could not be loaded! Please ensure it is included in the project correctly."); - } - - ComputeShader = shader; - KernelIndex = ComputeShader.FindKernel("CS"); - _sampler = CustomSampler.Create(passName); + // Optical Flow dispatch requires an extra `level` parameter, so we don't use this overload } } - internal class OpticalFlowPrepareLumaPass : OpticalFlowPass + internal sealed class OpticalFlowPrepareLumaPass : OpticalFlowPass { public OpticalFlowPrepareLumaPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants) : base(resources, constants) @@ -56,11 +41,11 @@ namespace FidelityFX.FrameGen InitComputeShader("Prepare Luma", contextDescription.shaders.prepareLuma); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ) { commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvInputColor, dispatchParams.Color); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][bufferIndex]); commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants); @@ -68,7 +53,7 @@ namespace FidelityFX.FrameGen } } - internal class OpticalFlowGenerateInputPyramidPass : OpticalFlowPass + internal sealed class OpticalFlowGenerateInputPyramidPass : OpticalFlowPass { private ComputeBuffer _spdConstants; @@ -79,15 +64,15 @@ namespace FidelityFX.FrameGen InitComputeShader("Generate Optical Flow Input Pyramid", contextDescription.shaders.generateOpticalFlowInputPyramid); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ) { - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel1, Resources.OpticalFlowInputLevels[1][frameIndex]); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel2, Resources.OpticalFlowInputLevels[2][frameIndex]); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel3, Resources.OpticalFlowInputLevels[3][frameIndex]); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel4, Resources.OpticalFlowInputLevels[4][frameIndex]); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel5, Resources.OpticalFlowInputLevels[5][frameIndex]); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel6, Resources.OpticalFlowInputLevels[6][frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][bufferIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel1, Resources.OpticalFlowInputLevels[1][bufferIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel2, Resources.OpticalFlowInputLevels[2][bufferIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel3, Resources.OpticalFlowInputLevels[3][bufferIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel4, Resources.OpticalFlowInputLevels[4][bufferIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel5, Resources.OpticalFlowInputLevels[5][bufferIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel6, Resources.OpticalFlowInputLevels[6][bufferIndex]); commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants); commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbSpd, _spdConstants); @@ -96,7 +81,7 @@ namespace FidelityFX.FrameGen } } - internal class OpticalFlowGenerateSCDHistogramPass : OpticalFlowPass + internal sealed class OpticalFlowGenerateSCDHistogramPass : OpticalFlowPass { public OpticalFlowGenerateSCDHistogramPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants) : base(resources, constants) @@ -104,9 +89,9 @@ namespace FidelityFX.FrameGen InitComputeShader("Generate SCD Histogram", contextDescription.shaders.generateScdHistogram); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ) { - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[0][bufferIndex]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram); commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants); @@ -115,7 +100,7 @@ namespace FidelityFX.FrameGen } } - internal class OpticalFlowComputeSCDDivergencePass : OpticalFlowPass + internal sealed class OpticalFlowComputeSCDDivergencePass : OpticalFlowPass { public OpticalFlowComputeSCDDivergencePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants) : base(resources, constants) @@ -123,7 +108,7 @@ namespace FidelityFX.FrameGen InitComputeShader("Compute SCD Divergence", contextDescription.shaders.computeScdDivergence); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ) { commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdPreviousHistogram, Resources.OpticalFlowSCDPreviousHistogram); @@ -136,7 +121,7 @@ namespace FidelityFX.FrameGen } } - internal class OpticalFlowComputePass : OpticalFlowPass + internal sealed class OpticalFlowComputePass : OpticalFlowPass { public OpticalFlowComputePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants) : base(resources, constants) @@ -144,12 +129,12 @@ namespace FidelityFX.FrameGen InitComputeShader("Optical Flow Search", contextDescription.shaders.computeOpticalFlow); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ) { - int levelIndex = frameIndex ^ (level & 1); + int levelIndex = bufferIndex ^ (level & 1); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][frameIndex]); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][frameIndex ^ 1]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][bufferIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][bufferIndex ^ 1]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex]); commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD); @@ -159,7 +144,7 @@ namespace FidelityFX.FrameGen } } - internal class OpticalFlowFilterPass : OpticalFlowPass + internal sealed class OpticalFlowFilterPass : OpticalFlowPass { public OpticalFlowFilterPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants) : base(resources, constants) @@ -167,9 +152,9 @@ namespace FidelityFX.FrameGen InitComputeShader("Optical Flow Filter", contextDescription.shaders.filterOpticalFlow); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ) { - int levelIndex = frameIndex ^ (level & 1); + int levelIndex = bufferIndex ^ (level & 1); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPrevious, Resources.OpticalFlowLevels[level][levelIndex]); @@ -189,7 +174,7 @@ namespace FidelityFX.FrameGen } } - internal class OpticalFlowScalePass : OpticalFlowPass + internal sealed class OpticalFlowScalePass : OpticalFlowPass { public OpticalFlowScalePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants) : base(resources, constants) @@ -197,15 +182,15 @@ namespace FidelityFX.FrameGen InitComputeShader("Optical Flow Scale", contextDescription.shaders.scaleOpticalFlow); } - protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) + protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ) { if (level <= 0) return; - int levelIndex = frameIndex ^ (level & 1); + int levelIndex = bufferIndex ^ (level & 1); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][frameIndex]); - commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][frameIndex ^ 1]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][bufferIndex]); + commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][bufferIndex ^ 1]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex ^ 1]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowNextLevel, Resources.OpticalFlowLevels[level - 1][levelIndex ^ 1]);