using System;
using System.Runtime.InteropServices;
using UnityEngine;
using UnityEngine.Assertions;
using UnityEngine.Profiling;
using UnityEngine.Rendering;

namespace FidelityFX.FrameGen
{
    /// <summary>
    /// Runtime context for the FidelityFX optical flow effect.
    /// Owns the optical flow resources, the two constant buffers and the seven compute passes,
    /// and records the complete optical flow pipeline into a <see cref="CommandBuffer"/> on each <see cref="Dispatch"/>.
    /// </summary>
    public class OpticalFlowContext: FfxContextBase
    {
        // Wrap-around bound for the internal frame counter; only its parity (frameIndex) is consumed by the passes.
        private const int MaxQueuedFrames = 16;

        // Block size passed to OpticalFlow.GetOpticalFlowTextureSize; also bounds the compute pass's pixels per thread
        // and sizes the scale pass's thread groups.
        private const int OpticalFlowBlockSize = 8;

        private OpticalFlow.ContextDescription _contextDescription;

        private OpticalFlowPass _prepareLumaPass;
        private OpticalFlowPass _generateInputPyramidPass;
        private OpticalFlowPass _generateScdHistogramPass;
        private OpticalFlowPass _computeScdDivergencePass;
        private OpticalFlowPass _computeOpticalFlowPass;
        private OpticalFlowPass _filterOpticalFlowPass;
        private OpticalFlowPass _scaleOpticalFlowPass;

        private readonly OpticalFlowResources _resources = new OpticalFlowResources();

        private readonly ConstantsBuffer _opticalFlowConstants = new ConstantsBuffer();
        private readonly ConstantsBuffer _spdConstants = new ConstantsBuffer();

        // True until the first Dispatch after Create; forces a history reset on that dispatch.
        private bool _firstExecution;
        // Monotonic frame counter (mod MaxQueuedFrames) used to derive the ping-pong frame index.
        private int _resourceFrameIndex;
        // Per-level optical flow texture sizes, level 0 being the finest; recomputed on every Dispatch.
        private readonly Vector2Int[] _opticalFlowTextureSizes = new Vector2Int[OpticalFlow.OpticalFlowMaxPyramidLevels];

        private readonly CustomSampler _sampler = CustomSampler.Create("Optical Flow");

        /// <summary>
        /// Initializes the context: creates the constant buffers and the resources for the given description,
        /// creates all compute passes, and arms a history reset for the next <see cref="Dispatch"/>.
        /// </summary>
        /// <param name="contextDescription">Description of the context; its resolution sizes all internal work.</param>
        public void Create(in OpticalFlow.ContextDescription contextDescription)
        {
            _contextDescription = contextDescription;

            _opticalFlowConstants.Create();
            _spdConstants.Create();

            _firstExecution = true;
            _resourceFrameIndex = 0;

            _opticalFlowConstants.Value.inputLumaResolution = _contextDescription.resolution;

            _resources.Create(_contextDescription);
            CreatePasses();
        }

        /// <summary>Instantiates every pass, sharing the context description, resources and constant buffers.</summary>
        private void CreatePasses()
        {
            _prepareLumaPass = new OpticalFlowPrepareLumaPass(_contextDescription, _resources, _opticalFlowConstants);
            _generateInputPyramidPass = new OpticalFlowGenerateInputPyramidPass(_contextDescription, _resources, _opticalFlowConstants, _spdConstants);
            _generateScdHistogramPass = new OpticalFlowGenerateSCDHistogramPass(_contextDescription, _resources, _opticalFlowConstants);
            _computeScdDivergencePass = new OpticalFlowComputeSCDDivergencePass(_contextDescription, _resources, _opticalFlowConstants);
            _computeOpticalFlowPass = new OpticalFlowComputePass(_contextDescription, _resources, _opticalFlowConstants);
            _filterOpticalFlowPass = new OpticalFlowFilterPass(_contextDescription, _resources, _opticalFlowConstants);
            _scaleOpticalFlowPass = new OpticalFlowScalePass(_contextDescription, _resources, _opticalFlowConstants);
        }

        /// <summary>
        /// Releases all passes (in reverse creation order), then the resources and the constant buffers.
        /// </summary>
        public void Destroy()
        {
            DestroyPass(ref _scaleOpticalFlowPass);
            DestroyPass(ref _filterOpticalFlowPass);
            DestroyPass(ref _computeOpticalFlowPass);
            DestroyPass(ref _computeScdDivergencePass);
            DestroyPass(ref _generateScdHistogramPass);
            DestroyPass(ref _generateInputPyramidPass);
            DestroyPass(ref _prepareLumaPass);

            _resources.Destroy();

            _spdConstants.Destroy();
            _opticalFlowConstants.Destroy();
        }

        /// <summary>
        /// Records one frame of optical flow work into <paramref name="commandBuffer"/>, wrapped in a profiler sample.
        /// </summary>
        /// <param name="commandBuffer">Command buffer that receives the clears, constant uploads and compute dispatches.</param>
        /// <param name="dispatchDescription">
        /// Per-frame inputs. When <c>Reset</c> is set (or on the first dispatch after <see cref="Create"/>) the constant
        /// frame counter restarts at 0 and the scene-change-detection and input pyramid targets are cleared.
        /// </param>
        public void Dispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription)
        {
            ref var constants = ref _opticalFlowConstants.Value;
            constants.backbufferTransferFunction = (uint)dispatchDescription.BackbufferTransferFunction;
            constants.minMaxLuminance = dispatchDescription.MinMaxLuminance;

            // Always 0 or 1: the passes receive it to pick one of the two per-level input textures (see the reset clears).
            int frameIndex = _resourceFrameIndex % 2;

            bool resetAccumulation = dispatchDescription.Reset || _firstExecution;
            _firstExecution = false;

            if (resetAccumulation)
                constants.frameIndex = 0;
            else
                constants.frameIndex++;

            commandBuffer.BeginSample(_sampler);

            if (resetAccumulation)
                ClearHistoryResources(commandBuffer, dispatchDescription);

            // The SPD constants must be filled before both buffers are uploaded for the first group of passes.
            SetupSpdConstants(out var threadGroupSizeOpticalFlowInputPyramid);
            _opticalFlowConstants.UpdateBufferData(commandBuffer);
            _spdConstants.UpdateBufferData(commandBuffer);

            DispatchPrepareLuma(commandBuffer, dispatchDescription, frameIndex);
            _generateInputPyramidPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, threadGroupSizeOpticalFlowInputPyramid.x, threadGroupSizeOpticalFlowInputPyramid.y);
            DispatchSceneChangeDetection(commandBuffer, dispatchDescription, frameIndex);
            DispatchOpticalFlowPyramid(commandBuffer, dispatchDescription, frameIndex);

            commandBuffer.EndSample(_sampler);

            _resourceFrameIndex = (_resourceFrameIndex + 1) % MaxQueuedFrames;
        }

        /// <summary>Clears the SCD targets and both copies of every input pyramid level to zero.</summary>
        private void ClearHistoryResources(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription)
        {
            commandBuffer.SetRenderTarget(_resources.OpticalFlowSCDTemp);
            commandBuffer.ClearRenderTarget(false, true, Color.clear);

            commandBuffer.SetRenderTarget(dispatchDescription.OpticalFlowSCD.RenderTarget);
            commandBuffer.ClearRenderTarget(false, true, Color.clear);

            commandBuffer.SetRenderTarget(_resources.OpticalFlowSCDHistogram);
            commandBuffer.ClearRenderTarget(false, true, Color.clear);

            commandBuffer.SetRenderTarget(_resources.OpticalFlowSCDPreviousHistogram);
            commandBuffer.ClearRenderTarget(false, true, Color.clear);

            for (int i = 0; i < 2; ++i)
            {
                for (int level = 0; level < OpticalFlow.OpticalFlowMaxPyramidLevels; ++level)
                {
                    commandBuffer.SetRenderTarget(_resources.OpticalFlowInputLevels[level][i]);
                    commandBuffer.ClearRenderTarget(false, true, Color.clear);
                }
            }
        }

        /// <summary>Schedules the luma preparation pass: 16x16 thread groups, each thread covering 2x2 pixels.</summary>
        private void DispatchPrepareLuma(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
        {
            const int threadGroupSizeX = 16;
            const int threadGroupSizeY = 16;
            const int threadPixelsX = 2;
            const int threadPixelsY = 2;

            int dispatchX = DivideRoundUp(DivideRoundUp(_contextDescription.resolution.x, threadPixelsX), threadGroupSizeX);
            int dispatchY = DivideRoundUp(DivideRoundUp(_contextDescription.resolution.y, threadPixelsY), threadGroupSizeY);

            _prepareLumaPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, dispatchX, dispatchY);
        }

        /// <summary>Schedules scene change detection: histogram generation followed by divergence computation.</summary>
        private void DispatchSceneChangeDetection(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
        {
            {
                // Only the X thread group size (32) feeds the dispatch size here; Y and Z are fixed counts.
                const int threadGroupSizeX = 32;
                int strataWidth = (_contextDescription.resolution.x / 4) / OpticalFlow.HistogramsPerDim;

                int dispatchX = DivideRoundUp(strataWidth, threadGroupSizeX);
                const int dispatchY = 16;
                const int dispatchZ = OpticalFlow.HistogramsPerDim * OpticalFlow.HistogramsPerDim;

                _generateScdHistogramPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, dispatchX, dispatchY, dispatchZ);
            }

            {
                const int dispatchX = OpticalFlow.HistogramsPerDim * OpticalFlow.HistogramsPerDim;
                const int dispatchY = OpticalFlow.HistogramShifts;

                _computeScdDivergencePass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, dispatchX, dispatchY);
            }
        }

        /// <summary>
        /// Runs the optical flow search from the coarsest pyramid level down to level 0. Each level re-uploads the
        /// constants with the current level, then computes and filters the flow, and (except at level 0) scales
        /// the result up to the next finer level.
        /// </summary>
        private void DispatchOpticalFlowPyramid(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
        {
            const int advancedAlgorithmIterations = 7;
            const int pyramidMaxIterations = advancedAlgorithmIterations;
            Assert.IsTrue(pyramidMaxIterations <= OpticalFlow.OpticalFlowMaxPyramidLevels);

            ComputeOpticalFlowTextureSizes(pyramidMaxIterations);

            ref var constants = ref _opticalFlowConstants.Value;
            for (int level = pyramidMaxIterations - 1; level >= 0; --level)
            {
                constants.opticalFlowPyramidLevel = (uint)level;
                constants.opticalFlowPyramidLevelCount = pyramidMaxIterations;
                _opticalFlowConstants.UpdateBufferData(commandBuffer);

                DispatchComputeOpticalFlow(commandBuffer, dispatchDescription, frameIndex, level);
                DispatchFilterOpticalFlow(commandBuffer, dispatchDescription, frameIndex, level);

                if (level > 0)
                    DispatchScaleOpticalFlow(commandBuffer, dispatchDescription, frameIndex, level);
            }
        }

        /// <summary>
        /// Fills <see cref="_opticalFlowTextureSizes"/> for the first <paramref name="levelCount"/> levels;
        /// each level is the previous one halved, rounding up.
        /// </summary>
        private void ComputeOpticalFlowTextureSizes(int levelCount)
        {
            _opticalFlowTextureSizes[0] = OpticalFlow.GetOpticalFlowTextureSize(_contextDescription.resolution, OpticalFlowBlockSize);
            for (int i = 1; i < levelCount; ++i)
            {
                Vector2Int previous = _opticalFlowTextureSizes[i - 1];
                _opticalFlowTextureSizes[i] = new Vector2Int((previous.x + 1) / 2, (previous.y + 1) / 2);
            }
        }

        /// <summary>Schedules the optical flow search for one level, sized from the input luma at that level.</summary>
        private void DispatchComputeOpticalFlow(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription, int frameIndex, int level)
        {
            const int threadPixels = 4;
            Assert.IsTrue(OpticalFlowBlockSize >= threadPixels);

            const int threadGroupSizeX = 4;
            const int threadGroupSizeY = 16;
            const int threadGroupSize = threadGroupSizeX * threadGroupSizeY; // 64 threads per group

            int inputLumaWidth = Math.Max(_contextDescription.resolution.x >> level, 1);
            int inputLumaHeight = Math.Max(_contextDescription.resolution.y >> level, 1);

            int dispatchX = DivideRoundUp(DivideRoundUp(inputLumaWidth, threadPixels) * threadGroupSizeY, threadGroupSize);
            int dispatchY = DivideRoundUp(inputLumaHeight, threadGroupSizeY);

            _computeOpticalFlowPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, level, dispatchX, dispatchY);
        }

        /// <summary>Schedules the flow filter for one level using 16x4 thread groups over that level's texture size.</summary>
        private void DispatchFilterOpticalFlow(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription, int frameIndex, int level)
        {
            const int threadGroupSizeX = 16;
            const int threadGroupSizeY = 4;

            Vector2Int levelSize = _opticalFlowTextureSizes[level];
            int dispatchX = DivideRoundUp(levelSize.x, threadGroupSizeX);
            int dispatchY = DivideRoundUp(levelSize.y, threadGroupSizeY);

            _filterOpticalFlowPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, level, dispatchX, dispatchY);
        }

        /// <summary>
        /// Schedules the scale pass for <paramref name="level"/> (&gt; 0), sized by the next finer level's texture size.
        /// </summary>
        private void DispatchScaleOpticalFlow(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription, int frameIndex, int level)
        {
            const int threadGroupSizeX = OpticalFlowBlockSize / 2;
            const int threadGroupSizeY = OpticalFlowBlockSize / 2;

            Vector2Int nextLevelSize = _opticalFlowTextureSizes[level - 1];
            int dispatchX = DivideRoundUp(nextLevelSize.x, threadGroupSizeX);
            int dispatchY = DivideRoundUp(nextLevelSize.y, threadGroupSizeY);
            const int dispatchZ = 1;

            _scaleOpticalFlowPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, level, dispatchX, dispatchY, dispatchZ);
        }

        /// <summary>
        /// Fills the SPD constants for the full input resolution and reports the thread group count
        /// needed for the input pyramid pass.
        /// </summary>
        /// <param name="dispatchThreadGroupCount">Receives the X/Y thread group counts computed by FfxSpd.</param>
        private void SetupSpdConstants(out Vector2Int dispatchThreadGroupCount)
        {
            const int resolutionMultiplier = 1;
            ref OpticalFlow.SpdConstants spdConstants = ref _spdConstants.Value;
            FfxSpd.SetupSpdConstants(_contextDescription.resolution * resolutionMultiplier, ref spdConstants.spd, out dispatchThreadGroupCount);
            spdConstants.numWorkGroupsOpticalFlowInputPyramid = spdConstants.spd.numWorkGroups;
        }

        /// <summary>Integer ceiling division, written as (value + divisor - 1) / divisor.</summary>
        private static int DivideRoundUp(int value, int divisor)
        {
            return (value + divisor - 1) / divisor;
        }

        /// <summary>Disposes <paramref name="pass"/> if present and clears the reference; safe to call on null.</summary>
        private static void DestroyPass(ref OpticalFlowPass pass)
        {
            if (pass == null)
                return;

            pass.Dispose();
            pass = null;
        }
    }
}