using System;
using System.Runtime.InteropServices;
using UnityEngine;
using UnityEngine.Assertions;
using UnityEngine.Rendering;

namespace FidelityFX.OpticalFlow
{
    /// <summary>
    /// Holds the passes, intermediate resources and constant buffers of an optical flow instance
    /// and records the per-frame work into a <see cref="CommandBuffer"/>.
    /// </summary>
    /// <remarks>
    /// Call <see cref="Create"/> before the first <see cref="Dispatch"/> and <see cref="Destroy"/>
    /// when the context is no longer needed; <see cref="Destroy"/> releases both constant buffers,
    /// disposes every pass and destroys the shared resources.
    /// </remarks>
    public class OpticalFlowContext
    {
        // _resourceFrameIndex wraps around at this value (see the end of Dispatch).
        private const int MaxQueuedFrames = 16;

        private OpticalFlow.ContextDescription _contextDescription;

        private OpticalFlowPass _prepareLumaPass;
        private OpticalFlowPass _generateInputPyramidPass;
        private OpticalFlowPass _generateScdHistogramPass;
        private OpticalFlowPass _computeScdDivergencePass;
        private OpticalFlowPass _computeOpticalFlowPass;
        private OpticalFlowPass _filterOpticalFlowPass;
        private OpticalFlowPass _scaleOpticalFlowPass;

        private readonly OpticalFlowResources _resources = new OpticalFlowResources();

        // Each constants struct lives in a single-element array so the same storage can be handed to
        // CommandBuffer.SetBufferData and also be mutated in place through the ref-returning properties.
        private ComputeBuffer _opticalFlowConstantsBuffer;
        private readonly OpticalFlow.OpticalFlowConstants[] _opticalFlowConstantsArray = { new OpticalFlow.OpticalFlowConstants() };
        private ref OpticalFlow.OpticalFlowConstants Constants => ref _opticalFlowConstantsArray[0];

        private ComputeBuffer _spdConstantsBuffer;
        private readonly OpticalFlow.SpdConstants[] _spdConstantsArray = { new OpticalFlow.SpdConstants() };
        private ref OpticalFlow.SpdConstants SpdConsts => ref _spdConstantsArray[0];

        // True until the first Dispatch; forces that dispatch to behave like a reset.
        private bool _firstExecution;
        private int _resourceFrameIndex;

        // Scratch storage for the per-level optical flow texture sizes computed in Dispatch.
        private readonly Vector2Int[] _opticalFlowTextureSizes = new Vector2Int[OpticalFlowResources.OpticalFlowMaxPyramidLevels];

        /// <summary>
        /// Initializes the context: allocates both constant buffers, resets the frame bookkeeping,
        /// creates the shared resources and instantiates all passes.
        /// </summary>
        /// <param name="contextDescription">Settings for this context; its resolution is stored as the input luma resolution.</param>
        public void Create(OpticalFlow.ContextDescription contextDescription)
        {
            _contextDescription = contextDescription;

            // The element type decides the buffer stride, so it must match the array uploaded into each buffer.
            _opticalFlowConstantsBuffer = CreateConstantBuffer<OpticalFlow.OpticalFlowConstants>();
            _spdConstantsBuffer = CreateConstantBuffer<OpticalFlow.SpdConstants>();

            _firstExecution = true;
            _resourceFrameIndex = 0;

            Constants.inputLumaResolution = _contextDescription.resolution;

            _resources.Create(_contextDescription);
            CreatePasses();
        }

        /// <summary>
        /// Instantiates every pass with the context description, the shared resources and the constant buffer(s) it uses.
        /// </summary>
        private void CreatePasses()
        {
            _prepareLumaPass = new OpticalFlowPrepareLumaPass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _generateInputPyramidPass = new OpticalFlowGenerateInputPyramidPass(_contextDescription, _resources, _opticalFlowConstantsBuffer, _spdConstantsBuffer);
            _generateScdHistogramPass = new OpticalFlowGenerateSCDHistogramPass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _computeScdDivergencePass = new OpticalFlowComputeSCDDivergencePass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _computeOpticalFlowPass = new OpticalFlowComputePass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _filterOpticalFlowPass = new OpticalFlowFilterPass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _scaleOpticalFlowPass = new OpticalFlowScalePass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
        }

        /// <summary>
        /// Tears the context down in the reverse order of <see cref="Create"/>: passes first, then the
        /// shared resources, then the constant buffers. Passes and buffers that are already null are skipped.
        /// </summary>
        public void Destroy()
        {
            DestroyPass(ref _scaleOpticalFlowPass);
            DestroyPass(ref _filterOpticalFlowPass);
            DestroyPass(ref _computeOpticalFlowPass);
            DestroyPass(ref _computeScdDivergencePass);
            DestroyPass(ref _generateScdHistogramPass);
            DestroyPass(ref _generateInputPyramidPass);
            DestroyPass(ref _prepareLumaPass);

            _resources.Destroy();

            DestroyConstantBuffer(ref _spdConstantsBuffer);
            DestroyConstantBuffer(ref _opticalFlowConstantsBuffer);
        }

        /// <summary>
        /// Records one frame of optical flow work into <paramref name="commandBuffer"/>.
        /// </summary>
        /// <remarks>
        /// Updates the constants, clears the intermediate render targets when a reset is requested (or on the
        /// first dispatch after <see cref="Create"/>), uploads both constant buffers and schedules the
        /// prepare-luma, input-pyramid, SCD-histogram and SCD-divergence passes. The per-level loop at the end
        /// currently only uploads per-level constants; the compute, filter and scale passes are not scheduled yet.
        /// </remarks>
        /// <param name="commandBuffer">Command buffer that receives the clears, buffer uploads and pass dispatches.</param>
        /// <param name="dispatchDescription">Per-frame parameters (transfer function, luminance range, reset flag, output targets).</param>
        public void Dispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription)
        {
            const int advancedAlgorithmIterations = 7;
            const int opticalFlowBlockSize = 8;

            Constants.backbufferTransferFunction = (uint)dispatchDescription.backbufferTransferFunction;
            Constants.minMaxLuminance = dispatchDescription.minMaxLuminance;

            // Alternates between 0 and 1 from one dispatch to the next.
            int frameIndex = _resourceFrameIndex % 2;
            bool resetAccumulation = dispatchDescription.reset || _firstExecution;
            _firstExecution = false;

            if (resetAccumulation)
            {
                Constants.frameIndex = 0;

                // Wipe every target that carries data over between frames.
                commandBuffer.SetRenderTarget(_resources.OpticalFlowSCDTemp);
                commandBuffer.ClearRenderTarget(false, true, Color.clear);
                commandBuffer.SetRenderTarget(dispatchDescription.opticalFlowSCD.RenderTarget);
                commandBuffer.ClearRenderTarget(false, true, Color.clear);
                commandBuffer.SetRenderTarget(_resources.OpticalFlowSCDHistogram);
                commandBuffer.ClearRenderTarget(false, true, Color.clear);
                commandBuffer.SetRenderTarget(_resources.OpticalFlowSCDPreviousHistogram);
                commandBuffer.ClearRenderTarget(false, true, Color.clear);

                for (int i = 0; i < 2; ++i)
                {
                    commandBuffer.SetRenderTarget(_resources.OpticalFlowInput[i]);
                    commandBuffer.ClearRenderTarget(false, true, Color.clear);
                    commandBuffer.SetRenderTarget(_resources.OpticalFlowInputLevel1[i]);
                    commandBuffer.ClearRenderTarget(false, true, Color.clear);
                    commandBuffer.SetRenderTarget(_resources.OpticalFlowInputLevel2[i]);
                    commandBuffer.ClearRenderTarget(false, true, Color.clear);
                    commandBuffer.SetRenderTarget(_resources.OpticalFlowInputLevel3[i]);
                    commandBuffer.ClearRenderTarget(false, true, Color.clear);
                    commandBuffer.SetRenderTarget(_resources.OpticalFlowInputLevel4[i]);
                    commandBuffer.ClearRenderTarget(false, true, Color.clear);
                    commandBuffer.SetRenderTarget(_resources.OpticalFlowInputLevel5[i]);
                    commandBuffer.ClearRenderTarget(false, true, Color.clear);
                    commandBuffer.SetRenderTarget(_resources.OpticalFlowInputLevel6[i]);
                    commandBuffer.ClearRenderTarget(false, true, Color.clear);
                }
            }
            else
            {
                Constants.frameIndex++;
            }

            SetupSpdConstants(out var threadGroupSizeOpticalFlowInputPyramid);

            commandBuffer.SetBufferData(_opticalFlowConstantsBuffer, _opticalFlowConstantsArray);
            commandBuffer.SetBufferData(_spdConstantsBuffer, _spdConstantsArray);

            // Prepare luma: each thread covers a 2x2 pixel footprint, each group is 16x16 threads.
            {
                const int threadGroupSizeX = 16;
                const int threadGroupSizeY = 16;
                const int threadPixelsX = 2;
                const int threadPixelsY = 2;
                int dispatchX = ((_contextDescription.resolution.x + (threadPixelsX - 1)) / threadPixelsX + (threadGroupSizeX - 1)) / threadGroupSizeX;
                int dispatchY = ((_contextDescription.resolution.y + (threadPixelsY - 1)) / threadPixelsY + (threadGroupSizeY - 1)) / threadGroupSizeY;
                _prepareLumaPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, dispatchX, dispatchY);
            }

            // Input pyramid: group counts come from the SPD setup above.
            _generateInputPyramidPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, threadGroupSizeOpticalFlowInputPyramid.x, threadGroupSizeOpticalFlowInputPyramid.y);

            // Scene change detection histogram.
            {
                const int threadGroupSizeX = 32;
                int strataWidth = (_contextDescription.resolution.x / 4) / OpticalFlowResources.HistogramsPerDim;
                int dispatchX = (strataWidth + threadGroupSizeX - 1) / threadGroupSizeX;
                const int dispatchY = 16;
                const int dispatchZ = OpticalFlowResources.HistogramsPerDim * OpticalFlowResources.HistogramsPerDim;
                _generateScdHistogramPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, dispatchX, dispatchY, dispatchZ);
            }

            // Scene change detection divergence.
            {
                const int dispatchX = OpticalFlowResources.HistogramsPerDim * OpticalFlowResources.HistogramsPerDim;
                const int dispatchY = OpticalFlowResources.HistogramShifts;
                _computeScdDivergencePass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, dispatchX, dispatchY);
            }

            const int pyramidMaxIterations = advancedAlgorithmIterations;
            Assert.IsTrue(pyramidMaxIterations <= OpticalFlowResources.OpticalFlowMaxPyramidLevels);

            // Level 0 uses the size reported by the resources; every further level halves it, rounding up.
            _opticalFlowTextureSizes[0] = OpticalFlowResources.GetOpticalFlowTextureSize(_contextDescription.resolution, opticalFlowBlockSize);
            for (int i = 1; i < pyramidMaxIterations; ++i)
            {
                _opticalFlowTextureSizes[i] = new Vector2Int(
                    (_opticalFlowTextureSizes[i - 1].x + 1) / 2,
                    (_opticalFlowTextureSizes[i - 1].y + 1) / 2
                );
            }

            // Walk the pyramid from the highest level index down to level 0.
            for (int level = pyramidMaxIterations - 1; level >= 0; --level)
            {
                // TODO: need to flip-flop Optical Flow output resources between levels as well (isOddFrame != isOddLevel)

                // Re-upload the constants so they carry this level's index.
                Constants.opticalFlowPyramidLevel = (uint)level;
                Constants.opticalFlowPyramidLevelCount = pyramidMaxIterations;
                commandBuffer.SetBufferData(_opticalFlowConstantsBuffer, _opticalFlowConstantsArray);

                // TODO: need to somehow address resources by level; this is awkward while the resources are
                // exposed as many separate render target arrays. The compute/filter/scale passes are not
                // scheduled until that is solved.
            }

            _resourceFrameIndex = (_resourceFrameIndex + 1) % MaxQueuedFrames;
        }

        /// <summary>
        /// Fills the SPD constants for a rectangle covering the full context resolution, requesting 4 mips.
        /// </summary>
        /// <param name="dispatchThreadGroupCount">Receives the thread group counts computed by <see cref="SpdSetup"/>.</param>
        private void SetupSpdConstants(out Vector2Int dispatchThreadGroupCount)
        {
            const int resolutionMultiplier = 1;
            RectInt rectInfo = new RectInt(0, 0, _contextDescription.resolution.x * resolutionMultiplier, _contextDescription.resolution.y * resolutionMultiplier);
            SpdSetup(rectInfo, out dispatchThreadGroupCount, out var workGroupOffset, out var numWorkGroupsAndMips, 4);

            ref OpticalFlow.SpdConstants spdConstants = ref SpdConsts;
            spdConstants.numWorkGroups = (uint)numWorkGroupsAndMips.x;
            spdConstants.mips = (uint)numWorkGroupsAndMips.y;
            spdConstants.workGroupOffsetX = (uint)workGroupOffset.x;
            spdConstants.workGroupOffsetY = (uint)workGroupOffset.y;
            spdConstants.numWorkGroupsOpticalFlowInputPyramid = (uint)numWorkGroupsAndMips.x;
        }

        /// <summary>
        /// Computes the dispatch parameters for a rectangle processed in 64x64 tiles.
        /// </summary>
        /// <param name="rectInfo">Rectangle to cover.</param>
        /// <param name="dispatchThreadGroupCount">Number of 64x64 tiles overlapped by the rectangle along each axis.</param>
        /// <param name="workGroupOffset">Index of the first tile overlapped by the rectangle along each axis.</param>
        /// <param name="numWorkGroupsAndMips">x: total tile count (product of both axes); y: mip count.</param>
        /// <param name="mips">
        /// Mip count to report. When negative, it is derived as floor(log2(max(width, height))), capped at 12.
        /// </param>
        private static void SpdSetup(RectInt rectInfo, out Vector2Int dispatchThreadGroupCount, out Vector2Int workGroupOffset, out Vector2Int numWorkGroupsAndMips, int mips = -1)
        {
            workGroupOffset = new Vector2Int(rectInfo.x / 64, rectInfo.y / 64);

            // Index of the last tile touched along each axis (inclusive).
            int endIndexX = (rectInfo.x + rectInfo.width - 1) / 64;
            int endIndexY = (rectInfo.y + rectInfo.height - 1) / 64;

            dispatchThreadGroupCount = new Vector2Int(endIndexX + 1 - workGroupOffset.x, endIndexY + 1 - workGroupOffset.y);
            numWorkGroupsAndMips = new Vector2Int(dispatchThreadGroupCount.x * dispatchThreadGroupCount.y, mips);

            if (mips < 0)
            {
                float resolution = Math.Max(rectInfo.width, rectInfo.height);
                numWorkGroupsAndMips.y = Math.Min(Mathf.FloorToInt(Mathf.Log(resolution, 2.0f)), 12);
            }
        }

        /// <summary>
        /// Allocates a single-element constant buffer whose stride is the marshalled size of <typeparamref name="TConstants"/>.
        /// </summary>
        /// <typeparam name="TConstants">Struct type whose data will be uploaded into the buffer.</typeparam>
        /// <returns>A new <see cref="ComputeBuffer"/> of type <see cref="ComputeBufferType.Constant"/>.</returns>
        private static ComputeBuffer CreateConstantBuffer<TConstants>() where TConstants : struct
        {
            return new ComputeBuffer(1, Marshal.SizeOf<TConstants>(), ComputeBufferType.Constant);
        }

        /// <summary>
        /// Releases the buffer and nulls the reference; does nothing when the reference is already null.
        /// </summary>
        /// <param name="bufferRef">Buffer field to release and clear.</param>
        private static void DestroyConstantBuffer(ref ComputeBuffer bufferRef)
        {
            if (bufferRef == null)
            {
                return;
            }

            bufferRef.Release();
            bufferRef = null;
        }

        /// <summary>
        /// Disposes the pass and nulls the reference; does nothing when the reference is already null.
        /// </summary>
        /// <param name="pass">Pass field to dispose and clear.</param>
        private static void DestroyPass(ref OpticalFlowPass pass)
        {
            if (pass == null)
            {
                return;
            }

            pass.Dispose();
            pass = null;
        }
    }
}