Browse Source

Reworked passes to make the context description copy unnecessary, and made use of extension methods to simplify resource binding in Optical Flow passes.

fsr3framegen
Nico de Poel 2 years ago
parent
commit
b21b1b43c4
  1. 41
      Runtime/FrameInterpolation/FrameInterpolationPass.cs
  2. 65
      Runtime/OpticalFlow/OpticalFlowPass.cs

41
Runtime/FrameInterpolation/FrameInterpolationPass.cs

@ -5,34 +5,27 @@ namespace FidelityFX.FrameGen
{ {
internal abstract class FrameInterpolationPass: FfxPassWithFlags<FrameInterpolation.DispatchDescription, FrameInterpolation.InitializationFlags> internal abstract class FrameInterpolationPass: FfxPassWithFlags<FrameInterpolation.DispatchDescription, FrameInterpolation.InitializationFlags>
{ {
protected readonly FrameInterpolation.ContextDescription ContextDescription;
protected readonly FrameInterpolationResources Resources; protected readonly FrameInterpolationResources Resources;
protected readonly ComputeBuffer Constants; protected readonly ComputeBuffer Constants;
protected FrameInterpolationPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base("Frame Interpolation")
protected FrameInterpolationPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base("Frame Interpolation", contextDescription.flags)
{ {
ContextDescription = contextDescription;
Resources = resources; Resources = resources;
Constants = constants; Constants = constants;
} }
protected void InitComputeShader(string passName, ComputeShader shader)
protected override void SetupShaderKeywords()
{ {
InitComputeShader(passName, shader, ContextDescription.flags);
}
protected override void SetupShaderKeywords(FrameInterpolation.InitializationFlags flags)
{
if ((flags & FrameInterpolation.InitializationFlags.EnableDisplayResolutionMotionVectors) == 0) ComputeShader.EnableKeyword("FFX_FRAMEINTERPOLATION_OPTION_LOW_RES_MOTION_VECTORS");
if ((flags & FrameInterpolation.InitializationFlags.EnableJitterMotionVectors) != 0) ComputeShader.EnableKeyword("FFX_FRAMEINTERPOLATION_OPTION_JITTERED_MOTION_VECTORS");
if ((flags & FrameInterpolation.InitializationFlags.EnableDepthInverted) != 0) ComputeShader.EnableKeyword("FFX_FRAMEINTERPOLATION_OPTION_INVERTED_DEPTH");
if ((Flags & FrameInterpolation.InitializationFlags.EnableDisplayResolutionMotionVectors) == 0) ComputeShader.EnableKeyword("FFX_FRAMEINTERPOLATION_OPTION_LOW_RES_MOTION_VECTORS");
if ((Flags & FrameInterpolation.InitializationFlags.EnableJitterMotionVectors) != 0) ComputeShader.EnableKeyword("FFX_FRAMEINTERPOLATION_OPTION_JITTERED_MOTION_VECTORS");
if ((Flags & FrameInterpolation.InitializationFlags.EnableDepthInverted) != 0) ComputeShader.EnableKeyword("FFX_FRAMEINTERPOLATION_OPTION_INVERTED_DEPTH");
} }
} }
internal class FrameInterpolationReconstructAndDilatePass : FrameInterpolationPass internal class FrameInterpolationReconstructAndDilatePass : FrameInterpolationPass
{ {
public FrameInterpolationReconstructAndDilatePass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
public FrameInterpolationReconstructAndDilatePass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
InitComputeShader("Reconstruct and Dilate", contextDescription.shaders.reconstructAndDilate); InitComputeShader("Reconstruct and Dilate", contextDescription.shaders.reconstructAndDilate);
@ -63,7 +56,7 @@ namespace FidelityFX.FrameGen
internal class FrameInterpolationSetupPass : FrameInterpolationPass internal class FrameInterpolationSetupPass : FrameInterpolationPass
{ {
public FrameInterpolationSetupPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
public FrameInterpolationSetupPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
InitComputeShader("Setup", contextDescription.shaders.setup); InitComputeShader("Setup", contextDescription.shaders.setup);
@ -88,7 +81,7 @@ namespace FidelityFX.FrameGen
internal class FrameInterpolationReconstructPreviousDepthPass : FrameInterpolationPass internal class FrameInterpolationReconstructPreviousDepthPass : FrameInterpolationPass
{ {
public FrameInterpolationReconstructPreviousDepthPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
public FrameInterpolationReconstructPreviousDepthPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
InitComputeShader("Reconstruct Previous Depth", contextDescription.shaders.reconstructPreviousDepth); InitComputeShader("Reconstruct Previous Depth", contextDescription.shaders.reconstructPreviousDepth);
@ -110,7 +103,7 @@ namespace FidelityFX.FrameGen
internal class FrameInterpolationGameMotionVectorFieldPass : FrameInterpolationPass internal class FrameInterpolationGameMotionVectorFieldPass : FrameInterpolationPass
{ {
public FrameInterpolationGameMotionVectorFieldPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
public FrameInterpolationGameMotionVectorFieldPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
InitComputeShader("Game Motion Vector Field", contextDescription.shaders.gameMotionVectorField); InitComputeShader("Game Motion Vector Field", contextDescription.shaders.gameMotionVectorField);
@ -134,7 +127,7 @@ namespace FidelityFX.FrameGen
internal class FrameInterpolationOpticalFlowVectorFieldPass : FrameInterpolationPass internal class FrameInterpolationOpticalFlowVectorFieldPass : FrameInterpolationPass
{ {
public FrameInterpolationOpticalFlowVectorFieldPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
public FrameInterpolationOpticalFlowVectorFieldPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
InitComputeShader("Optical Flow Vector Field", contextDescription.shaders.opticalFlowVectorField); InitComputeShader("Optical Flow Vector Field", contextDescription.shaders.opticalFlowVectorField);
@ -163,7 +156,7 @@ namespace FidelityFX.FrameGen
internal class FrameInterpolationDisocclusionMaskPass : FrameInterpolationPass internal class FrameInterpolationDisocclusionMaskPass : FrameInterpolationPass
{ {
public FrameInterpolationDisocclusionMaskPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
public FrameInterpolationDisocclusionMaskPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
InitComputeShader("Disocclusion Mask", contextDescription.shaders.disocclusionMask); InitComputeShader("Disocclusion Mask", contextDescription.shaders.disocclusionMask);
@ -188,7 +181,7 @@ namespace FidelityFX.FrameGen
internal class FrameInterpolationInterpolationPass : FrameInterpolationPass internal class FrameInterpolationInterpolationPass : FrameInterpolationPass
{ {
public FrameInterpolationInterpolationPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
public FrameInterpolationInterpolationPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
InitComputeShader("Interpolation", contextDescription.shaders.interpolation); InitComputeShader("Interpolation", contextDescription.shaders.interpolation);
@ -218,7 +211,7 @@ namespace FidelityFX.FrameGen
{ {
private readonly ComputeBuffer _spdConstants; private readonly ComputeBuffer _spdConstants;
public FrameInterpolationInpaintingPyramidPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
public FrameInterpolationInpaintingPyramidPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
_spdConstants = spdConstants; _spdConstants = spdConstants;
@ -248,7 +241,7 @@ namespace FidelityFX.FrameGen
internal class FrameInterpolationInpaintingPass : FrameInterpolationPass internal class FrameInterpolationInpaintingPass : FrameInterpolationPass
{ {
public FrameInterpolationInpaintingPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
public FrameInterpolationInpaintingPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
InitComputeShader("Inpainting", contextDescription.shaders.inpainting); InitComputeShader("Inpainting", contextDescription.shaders.inpainting);
@ -273,7 +266,7 @@ namespace FidelityFX.FrameGen
{ {
private readonly ComputeBuffer _spdConstants; private readonly ComputeBuffer _spdConstants;
public FrameInterpolationGameVectorFieldInpaintingPyramidPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
public FrameInterpolationGameVectorFieldInpaintingPyramidPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
_spdConstants = spdConstants; _spdConstants = spdConstants;
@ -304,7 +297,7 @@ namespace FidelityFX.FrameGen
internal class FrameInterpolationDebugViewPass : FrameInterpolationPass internal class FrameInterpolationDebugViewPass : FrameInterpolationPass
{ {
public FrameInterpolationDebugViewPass(FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
public FrameInterpolationDebugViewPass(in FrameInterpolation.ContextDescription contextDescription, FrameInterpolationResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants) : base(contextDescription, resources, constants)
{ {
InitComputeShader("Debug View", contextDescription.shaders.debugView); InitComputeShader("Debug View", contextDescription.shaders.debugView);

65
Runtime/OpticalFlow/OpticalFlowPass.cs

@ -8,7 +8,6 @@ namespace FidelityFX.FrameGen
{ {
internal abstract class OpticalFlowPass: IDisposable internal abstract class OpticalFlowPass: IDisposable
{ {
protected readonly OpticalFlow.ContextDescription ContextDescription;
protected readonly OpticalFlowResources Resources; protected readonly OpticalFlowResources Resources;
protected readonly ComputeBuffer Constants; protected readonly ComputeBuffer Constants;
@ -17,9 +16,8 @@ namespace FidelityFX.FrameGen
private CustomSampler _sampler; private CustomSampler _sampler;
protected OpticalFlowPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
protected OpticalFlowPass(OpticalFlowResources resources, ComputeBuffer constants)
{ {
ContextDescription = contextDescription;
Resources = resources; Resources = resources;
Constants = constants; Constants = constants;
} }
@ -52,20 +50,19 @@ namespace FidelityFX.FrameGen
internal class OpticalFlowPrepareLumaPass : OpticalFlowPass internal class OpticalFlowPrepareLumaPass : OpticalFlowPass
{ {
public OpticalFlowPrepareLumaPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants)
public OpticalFlowPrepareLumaPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{ {
InitComputeShader("Prepare Luma", contextDescription.shaders.prepareLuma); InitComputeShader("Prepare Luma", contextDescription.shaders.prepareLuma);
} }
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{ {
ref var color = ref dispatchParams.Color;
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvInputColor, color.RenderTarget, color.MipLevel, color.SubElement);
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvInputColor, dispatchParams.Color);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]);
commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>());
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
} }
@ -75,8 +72,8 @@ namespace FidelityFX.FrameGen
{ {
private ComputeBuffer _spdConstants; private ComputeBuffer _spdConstants;
public OpticalFlowGenerateInputPyramidPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
: base(contextDescription, resources, constants)
public OpticalFlowGenerateInputPyramidPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
: base(resources, constants)
{ {
_spdConstants = spdConstants; _spdConstants = spdConstants;
InitComputeShader("Generate Optical Flow Input Pyramid", contextDescription.shaders.generateOpticalFlowInputPyramid); InitComputeShader("Generate Optical Flow Input Pyramid", contextDescription.shaders.generateOpticalFlowInputPyramid);
@ -92,8 +89,8 @@ namespace FidelityFX.FrameGen
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel5, Resources.OpticalFlowInputLevels[5][frameIndex]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel5, Resources.OpticalFlowInputLevels[5][frameIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel6, Resources.OpticalFlowInputLevels[6][frameIndex]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel6, Resources.OpticalFlowInputLevels[6][frameIndex]);
commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>());
commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbSpd, _spdConstants, 0, Marshal.SizeOf<OpticalFlow.SpdConstants>());
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.SpdConstants>(ComputeShader, OpticalFlowShaderIDs.CbSpd, _spdConstants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
} }
@ -101,19 +98,18 @@ namespace FidelityFX.FrameGen
internal class OpticalFlowGenerateSCDHistogramPass : OpticalFlowPass internal class OpticalFlowGenerateSCDHistogramPass : OpticalFlowPass
{ {
public OpticalFlowGenerateSCDHistogramPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants)
public OpticalFlowGenerateSCDHistogramPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{ {
InitComputeShader("Generate SCD Histogram", contextDescription.shaders.generateScdHistogram); InitComputeShader("Generate SCD Histogram", contextDescription.shaders.generateScdHistogram);
} }
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{ {
// TODO: probably needs to be input from this frame (result from previous passes), but double check to be sure
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram);
commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>());
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
} }
@ -121,22 +117,20 @@ namespace FidelityFX.FrameGen
internal class OpticalFlowComputeSCDDivergencePass : OpticalFlowPass internal class OpticalFlowComputeSCDDivergencePass : OpticalFlowPass
{ {
public OpticalFlowComputeSCDDivergencePass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants)
public OpticalFlowComputeSCDDivergencePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{ {
InitComputeShader("Compute SCD Divergence", contextDescription.shaders.computeScdDivergence); InitComputeShader("Compute SCD Divergence", contextDescription.shaders.computeScdDivergence);
} }
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{ {
ref var scdOutput = ref dispatchParams.OpticalFlowSCD;
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdPreviousHistogram, Resources.OpticalFlowSCDPreviousHistogram); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdPreviousHistogram, Resources.OpticalFlowSCDPreviousHistogram);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdTemp, Resources.OpticalFlowSCDTemp); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdTemp, Resources.OpticalFlowSCDTemp);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, scdOutput.RenderTarget, scdOutput.MipLevel, scdOutput.SubElement);
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD);
commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>());
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
} }
@ -144,8 +138,8 @@ namespace FidelityFX.FrameGen
internal class OpticalFlowComputePass : OpticalFlowPass internal class OpticalFlowComputePass : OpticalFlowPass
{ {
public OpticalFlowComputePass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants)
public OpticalFlowComputePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{ {
InitComputeShader("Optical Flow Search", contextDescription.shaders.computeOpticalFlow); InitComputeShader("Optical Flow Search", contextDescription.shaders.computeOpticalFlow);
} }
@ -153,14 +147,13 @@ namespace FidelityFX.FrameGen
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ) protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{ {
int levelIndex = frameIndex ^ (level & 1); int levelIndex = frameIndex ^ (level & 1);
ref var scdOutput = ref dispatchParams.OpticalFlowSCD;
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][frameIndex]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][frameIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][frameIndex ^ 1]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][frameIndex ^ 1]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, scdOutput.RenderTarget, scdOutput.MipLevel, scdOutput.SubElement);
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD);
commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>());
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
} }
@ -168,8 +161,8 @@ namespace FidelityFX.FrameGen
internal class OpticalFlowFilterPass : OpticalFlowPass internal class OpticalFlowFilterPass : OpticalFlowPass
{ {
public OpticalFlowFilterPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants)
public OpticalFlowFilterPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{ {
InitComputeShader("Optical Flow Filter", contextDescription.shaders.filterOpticalFlow); InitComputeShader("Optical Flow Filter", contextDescription.shaders.filterOpticalFlow);
} }
@ -183,15 +176,14 @@ namespace FidelityFX.FrameGen
if (level == 0) if (level == 0)
{ {
// Final output (levels are counted in reverse) // Final output (levels are counted in reverse)
ref var ofVector = ref dispatchParams.OpticalFlowVector;
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, ofVector.RenderTarget, ofVector.MipLevel, ofVector.SubElement);
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, dispatchParams.OpticalFlowVector);
} }
else else
{ {
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex ^ 1]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex ^ 1]);
} }
commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>());
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
} }
@ -199,8 +191,8 @@ namespace FidelityFX.FrameGen
internal class OpticalFlowScalePass : OpticalFlowPass internal class OpticalFlowScalePass : OpticalFlowPass
{ {
public OpticalFlowScalePass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(contextDescription, resources, constants)
public OpticalFlowScalePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{ {
InitComputeShader("Optical Flow Scale", contextDescription.shaders.scaleOpticalFlow); InitComputeShader("Optical Flow Scale", contextDescription.shaders.scaleOpticalFlow);
} }
@ -211,16 +203,15 @@ namespace FidelityFX.FrameGen
return; return;
int levelIndex = frameIndex ^ (level & 1); int levelIndex = frameIndex ^ (level & 1);
ref var scdOutput = ref dispatchParams.OpticalFlowSCD;
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][frameIndex]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][frameIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][frameIndex ^ 1]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][frameIndex ^ 1]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex ^ 1]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex ^ 1]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowNextLevel, Resources.OpticalFlowLevels[level - 1][levelIndex ^ 1]); commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowNextLevel, Resources.OpticalFlowLevels[level - 1][levelIndex ^ 1]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, scdOutput.RenderTarget, scdOutput.MipLevel, scdOutput.SubElement);
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD);
commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>());
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ); commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
} }

Loading…
Cancel
Save