using System;
using System.Runtime.InteropServices;
using UnityEngine;
using UnityEngine.Profiling;
using UnityEngine.Rendering;

namespace FidelityFX.OpticalFlow
{
    /// <summary>
    /// Base class for a single Optical Flow compute pass. It stores the shared context description,
    /// resources and constants buffer, resolves the compute kernel, and wraps every dispatch in a
    /// named profiler sample.
    /// </summary>
    internal abstract class OpticalFlowPass : IDisposable
    {
        protected readonly OpticalFlow.ContextDescription ContextDescription;
        protected readonly OpticalFlowResources Resources;
        protected readonly ComputeBuffer Constants;

        // Both are assigned by InitComputeShader; they stay at their defaults for passes that never call it.
        protected ComputeShader ComputeShader;
        protected int KernelIndex;

        // Created by InitComputeShader from the pass name.
        private CustomSampler _sampler;

        /// <summary>
        /// Stores the references shared by all passes. Derived classes are expected to call
        /// <see cref="InitComputeShader"/> from their own constructor to finish the setup.
        /// </summary>
        protected OpticalFlowPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
        {
            ContextDescription = contextDescription;
            Resources = resources;
            Constants = constants;
        }

        /// <summary>
        /// Does nothing in the base class; derived passes may override it.
        /// </summary>
        public virtual void Dispose()
        {
        }

        /// <summary>
        /// Records this pass into <paramref name="commandBuffer"/>, surrounded by the pass's profiler sample.
        /// </summary>
        /// <param name="commandBuffer">Command buffer that receives the recorded commands.</param>
        /// <param name="dispatchParams">Per-dispatch parameters forwarded to the concrete pass.</param>
        /// <param name="frameIndex">Index forwarded to the concrete pass; the implemented passes use it to select per-frame input textures.</param>
        /// <param name="dispatchX">Thread group count in X.</param>
        /// <param name="dispatchY">Thread group count in Y.</param>
        /// <param name="dispatchZ">Thread group count in Z.</param>
        public void ScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ = 1)
        {
            commandBuffer.BeginSample(_sampler);
            try
            {
                DoScheduleDispatch(commandBuffer, dispatchParams, frameIndex, dispatchX, dispatchY, dispatchZ);
            }
            finally
            {
                // Keep the Begin/End sample pair balanced even when the pass throws
                // (e.g. the passes below that are not implemented yet).
                commandBuffer.EndSample(_sampler);
            }
        }

        /// <summary>
        /// Binds the pass-specific resources and records the actual compute dispatch.
        /// </summary>
        protected abstract void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ);

        /// <summary>
        /// Assigns the compute shader, looks up its "CS" kernel and creates the profiler sampler for this pass.
        /// </summary>
        /// <param name="passName">Name used in the error message and as the profiler sampler name.</param>
        /// <param name="shader">Compute shader that implements the pass.</param>
        /// <exception cref="MissingReferenceException">Thrown when <paramref name="shader"/> is null.</exception>
        protected void InitComputeShader(string passName, ComputeShader shader)
        {
            if (shader == null)
            {
                throw new MissingReferenceException($"Shader for Optical Flow pass '{passName}' could not be loaded! Please ensure it is included in the project correctly.");
            }

            ComputeShader = shader;
            KernelIndex = ComputeShader.FindKernel("CS");
            _sampler = CustomSampler.Create(passName);
        }

        /// <summary>
        /// Returns the total size in bytes of <paramref name="buffer"/> (element count times stride).
        /// </summary>
        /// <remarks>
        /// Used as the size argument when binding constant buffers from offset 0.
        /// NOTE(review): this assumes each constant buffer was allocated to hold exactly its constants
        /// struct - confirm against the code that creates the buffers.
        /// </remarks>
        protected static int GetBufferSize(ComputeBuffer buffer)
        {
            return buffer.count * buffer.stride;
        }
    }

    /// <summary>
    /// Pass that reads the dispatch's color input and writes level 0 of the optical flow input
    /// for the given frame index.
    /// </summary>
    internal class OpticalFlowPrepareLumaPass : OpticalFlowPass
    {
        public OpticalFlowPrepareLumaPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(contextDescription, resources, constants)
        {
            InitComputeShader("Prepare Luma", contextDescription.shaders.prepareLuma);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        {
            ref var color = ref dispatchParams.color;

            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvInputColor, color.RenderTarget, color.MipLevel, color.SubElement);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]);

            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, GetBufferSize(Constants));

            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Pass that binds optical flow input levels 0 through 6 for the given frame index, together with
    /// the optical flow and SPD constant buffers, and dispatches the pyramid generation shader.
    /// </summary>
    internal class OpticalFlowGenerateInputPyramidPass : OpticalFlowPass
    {
        private readonly ComputeBuffer _spdConstants;

        public OpticalFlowGenerateInputPyramidPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
            : base(contextDescription, resources, constants)
        {
            _spdConstants = spdConstants;

            InitComputeShader("Generate Optical Flow Input Pyramid", contextDescription.shaders.generateOpticalFlowInputPyramid);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        {
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel1, Resources.OpticalFlowInputLevels[1][frameIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel2, Resources.OpticalFlowInputLevels[2][frameIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel3, Resources.OpticalFlowInputLevels[3][frameIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel4, Resources.OpticalFlowInputLevels[4][frameIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel5, Resources.OpticalFlowInputLevels[5][frameIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel6, Resources.OpticalFlowInputLevels[6][frameIndex]);

            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, GetBufferSize(Constants));
            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbSpd, _spdConstants, 0, GetBufferSize(_spdConstants));

            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Pass that reads level 0 of the optical flow input for the given frame index and writes the
    /// scene change detection (SCD) histogram resource.
    /// </summary>
    internal class OpticalFlowGenerateSCDHistogramPass : OpticalFlowPass
    {
        public OpticalFlowGenerateSCDHistogramPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(contextDescription, resources, constants)
        {
            InitComputeShader("Generate SCD Histogram", contextDescription.shaders.generateScdHistogram);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        {
            // TODO: probably needs to be input from this frame (result from previous passes), but double check to be sure
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram);

            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, GetBufferSize(Constants));

            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Pass that binds the current, previous and temporary SCD histogram resources plus the
    /// dispatch's SCD output target, and dispatches the SCD divergence shader.
    /// </summary>
    internal class OpticalFlowComputeSCDDivergencePass : OpticalFlowPass
    {
        public OpticalFlowComputeSCDDivergencePass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(contextDescription, resources, constants)
        {
            InitComputeShader("Compute SCD Divergence", contextDescription.shaders.computeScdDivergence);
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        {
            ref var scdOutput = ref dispatchParams.opticalFlowSCD;

            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdPreviousHistogram, Resources.OpticalFlowSCDPreviousHistogram);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdTemp, Resources.OpticalFlowSCDTemp);
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, scdOutput.RenderTarget, scdOutput.MipLevel, scdOutput.SubElement);

            commandBuffer.SetComputeConstantBufferParam(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, GetBufferSize(Constants));

            commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
        }
    }

    /// <summary>
    /// Placeholder pass: no shader is initialized and scheduling a dispatch throws
    /// <see cref="NotImplementedException"/>.
    /// </summary>
    internal class OpticalFlowComputePass : OpticalFlowPass
    {
        public OpticalFlowComputePass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(contextDescription, resources, constants)
        {
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        {
            throw new NotImplementedException();
        }
    }

    /// <summary>
    /// Placeholder pass: no shader is initialized and scheduling a dispatch throws
    /// <see cref="NotImplementedException"/>.
    /// </summary>
    internal class OpticalFlowFilterPass : OpticalFlowPass
    {
        public OpticalFlowFilterPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(contextDescription, resources, constants)
        {
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        {
            throw new NotImplementedException();
        }
    }

    /// <summary>
    /// Placeholder pass: no shader is initialized and scheduling a dispatch throws
    /// <see cref="NotImplementedException"/>.
    /// </summary>
    internal class OpticalFlowScalePass : OpticalFlowPass
    {
        public OpticalFlowScalePass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
            : base(contextDescription, resources, constants)
        {
        }

        protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        {
            throw new NotImplementedException();
        }
    }
}