using System;
using System.Runtime.InteropServices;
using UnityEngine;
using UnityEngine.Profiling;
using UnityEngine.Rendering;

namespace FidelityFX.FrameGen
{
    /// <summary>
    /// Common base for the optical flow compute passes.
    /// Optical flow dispatches need an extra pyramid <c>level</c> argument, so this class
    /// provides a level-aware <c>ScheduleDispatch</c> overload and an abstract level-aware
    /// <c>DoScheduleDispatch</c> that each concrete pass implements.
    /// </summary>
    internal abstract class OpticalFlowPass : FfxPassBase
    {
        protected readonly OpticalFlowResources Resources;
        protected readonly ComputeBuffer Constants;

        /// <param name="resources">Intermediate textures shared by all optical flow passes.</param>
        /// <param name="constants">Buffer bound as the optical flow constant buffer by every pass.</param>
        protected OpticalFlowPass(OpticalFlowResources resources, ComputeBuffer constants)
            : base("Optical Flow")
        {
            Resources = resources;
            Constants = constants;
        }

        /// <summary>
        /// Records this pass into <paramref name="commandBuffer"/>, bracketed by
        /// <c>BeginSample</c>/<c>EndSample</c> on this pass's <c>Sampler</c>.
        /// </summary>
        /// <param name="commandBuffer">Command buffer to record into.</param>
        /// <param name="dispatchParams">Per-dispatch inputs/outputs supplied by the caller.</param>
        /// <param name="bufferIndex">
        /// Index of the current frame's input buffer; <c>bufferIndex ^ 1</c> is used as the
        /// previous frame's buffer (so this is presumably always 0 or 1 — TODO confirm at caller).
        /// </param>
        /// <param name="level">Pyramid level this dispatch operates on.</param>
        /// <param name="dispatchX">Thread group count in X.</param>
        /// <param name="dispatchY">Thread group count in Y.</param>
        /// <param name="dispatchZ">Thread group count in Z.</param>
        public void ScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ = 1)
        {
            commandBuffer.BeginSample(Sampler);
            DoScheduleDispatch(commandBuffer, dispatchParams, bufferIndex, level, dispatchX, dispatchY, dispatchZ);
            commandBuffer.EndSample(Sampler);
        }

        /// <summary>
        /// Binds this pass's resources and records its compute dispatch for the given pyramid level.
        /// </summary>
        protected abstract void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ);

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int dispatchX, int dispatchY, int dispatchZ)
        {
            // Optical Flow dispatch requires an extra `level` parameter, so we don't use this overload.
            // Intentionally a no-op: all optical flow passes go through the level-aware overload above.
        }

        /// <summary>
        /// Selects which of the two optical flow textures of a pyramid level the search pass writes.
        /// The filter pass reads that slot and writes the other one (<c>index ^ 1</c>).
        /// Because the result flips with the level's parity, consecutive levels use opposite slots,
        /// i.e. <c>GetLevelIndex(b, level - 1) == GetLevelIndex(b, level) ^ 1</c>.
        /// </summary>
        protected static int GetLevelIndex(int bufferIndex, int level) => bufferIndex ^ (level & 1);
    }

    /// <summary>
    /// Reads the caller's color input and writes level 0 of the current optical flow input pyramid.
    /// </summary>
    internal sealed class OpticalFlowPrepareLumaPass : OpticalFlowPass
    {
        public OpticalFlowPrepareLumaPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(resources, constants)
        {
            InitComputeShader("Prepare Luma", contextDescription.shaders.prepareLuma);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
        {
            commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvInputColor, dispatchParams.color);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][bufferIndex]);
            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Binds all seven levels (0..6) of the current input pyramid as UAVs so the shader can fill
    /// levels 1..6 from level 0.
    /// </summary>
    internal sealed class OpticalFlowGenerateInputPyramidPass : OpticalFlowPass
    {
        // UAV property IDs for the input pyramid, indexed by pyramid level (0..6).
        private static readonly int[] InputLevelPropertyIds =
        {
            OpticalFlowShaderIDs.UavOpticalFlowInput,
            OpticalFlowShaderIDs.UavOpticalFlowInputLevel1,
            OpticalFlowShaderIDs.UavOpticalFlowInputLevel2,
            OpticalFlowShaderIDs.UavOpticalFlowInputLevel3,
            OpticalFlowShaderIDs.UavOpticalFlowInputLevel4,
            OpticalFlowShaderIDs.UavOpticalFlowInputLevel5,
            OpticalFlowShaderIDs.UavOpticalFlowInputLevel6,
        };

        // Bound as the SPD constant buffer (presumably the FidelityFX Single Pass Downsampler's constants).
        private readonly ComputeBuffer _spdConstants;

        public OpticalFlowGenerateInputPyramidPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
            : base(resources, constants)
        {
            _spdConstants = spdConstants;
            InitComputeShader("Generate Optical Flow Input Pyramid", contextDescription.shaders.generateOpticalFlowInputPyramid);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
        {
            // The `level` argument is unused: this single dispatch covers every pyramid level.
            for (int pyramidLevel = 0; pyramidLevel < InputLevelPropertyIds.Length; pyramidLevel++)
            {
                commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, InputLevelPropertyIds[pyramidLevel], Resources.OpticalFlowInputLevels[pyramidLevel][bufferIndex]);
            }

            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbSpd, _spdConstants);
            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Builds the SCD histogram (SCD is presumably "scene change detection") from level 0 of the
    /// current input pyramid.
    /// </summary>
    internal sealed class OpticalFlowGenerateSCDHistogramPass : OpticalFlowPass
    {
        public OpticalFlowGenerateSCDHistogramPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(resources, constants)
        {
            InitComputeShader("Generate SCD Histogram", contextDescription.shaders.generateScdHistogram);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
        {
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[0][bufferIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram);
            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Compares the current and previous SCD histograms and writes the result to the caller's
    /// <c>opticalFlowSCD</c> output.
    /// </summary>
    internal sealed class OpticalFlowComputeSCDDivergencePass : OpticalFlowPass
    {
        public OpticalFlowComputeSCDDivergencePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(resources, constants)
        {
            InitComputeShader("Compute SCD Divergence", contextDescription.shaders.computeScdDivergence);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
        {
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdPreviousHistogram, Resources.OpticalFlowSCDPreviousHistogram);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdTemp, Resources.OpticalFlowSCDTemp);
            commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.opticalFlowSCD);
            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Optical flow search for one pyramid level: reads the current and previous input at that
    /// level and writes the level's search slot (see <c>GetLevelIndex</c>).
    /// </summary>
    internal sealed class OpticalFlowComputePass : OpticalFlowPass
    {
        public OpticalFlowComputePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(resources, constants)
        {
            InitComputeShader("Optical Flow Search", contextDescription.shaders.computeOpticalFlow);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
        {
            var inputLevel = Resources.OpticalFlowInputLevels[level];
            int levelIndex = GetLevelIndex(bufferIndex, level);

            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, inputLevel[bufferIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, inputLevel[bufferIndex ^ 1]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex]);
            commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.opticalFlowSCD);
            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Filters one pyramid level's search result.
    /// It reads the level's search slot and writes the other slot of the same level.
    /// At level 0 it writes the caller's <c>opticalFlowVector</c> output instead.
    /// </summary>
    internal sealed class OpticalFlowFilterPass : OpticalFlowPass
    {
        public OpticalFlowFilterPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(resources, constants)
        {
            InitComputeShader("Optical Flow Filter", contextDescription.shaders.filterOpticalFlow);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
        {
            var flowLevel = Resources.OpticalFlowLevels[level];
            int levelIndex = GetLevelIndex(bufferIndex, level);

            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPrevious, flowLevel[levelIndex]);

            if (level == 0)
            {
                // Final output (levels are counted in reverse)
                commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, dispatchParams.opticalFlowVector);
            }
            else
            {
                commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, flowLevel[levelIndex ^ 1]);
            }

            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Reads a level's filtered flow and writes it into level <c>level - 1</c>.
    /// The target is the slot that the level <c>level - 1</c> search pass is bound to.
    /// Does nothing for level 0.
    /// </summary>
    internal sealed class OpticalFlowScalePass : OpticalFlowPass
    {
        public OpticalFlowScalePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(resources, constants)
        {
            InitComputeShader("Optical Flow Scale", contextDescription.shaders.scaleOpticalFlow);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
        {
            // There is no level below 0 to write into; the filter pass writes the final output there.
            if (level <= 0)
            {
                return;
            }

            var inputLevel = Resources.OpticalFlowInputLevels[level];
            int levelIndex = GetLevelIndex(bufferIndex, level);

            // Equal to `levelIndex ^ 1` (parity flips between adjacent levels): the slot the
            // search pass binds as its UAV at level - 1.
            int nextLevelIndex = GetLevelIndex(bufferIndex, level - 1);

            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, inputLevel[bufferIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, inputLevel[bufferIndex ^ 1]);
            // Filtered result of this level (written by the filter pass into the non-search slot).
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex ^ 1]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowNextLevel, Resources.OpticalFlowLevels[level - 1][nextLevelIndex]);
            commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.opticalFlowSCD);
            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }
}