You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

240 lines
12 KiB

using System;
using System.Runtime.InteropServices;
using UnityEngine;
using UnityEngine.Assertions;
using UnityEngine.Rendering;
namespace FidelityFX.OpticalFlow
{
/// <summary>
/// Owns the compute passes, intermediate resources and constant buffers of the FidelityFX optical flow
/// effect, and records its per-frame GPU work into a <see cref="CommandBuffer"/>.
/// </summary>
/// <remarks>
/// Lifecycle: call <see cref="Create"/> once, <see cref="Dispatch"/> once per frame, and
/// <see cref="Destroy"/> to release everything.
/// NOTE(review): the compute, filter and scale passes are created and destroyed here but are not yet
/// scheduled by <see cref="Dispatch"/> (see the TODOs in its pyramid loop).
/// </remarks>
public class OpticalFlowContext
{
    // Modulus of the rolling resource frame counter advanced at the end of every Dispatch.
    private const int MaxQueuedFrames = 16;

    private OpticalFlow.ContextDescription _contextDescription;

    private OpticalFlowPass _prepareLumaPass;
    private OpticalFlowPass _generateInputPyramidPass;
    private OpticalFlowPass _generateScdHistogramPass;
    private OpticalFlowPass _computeScdDivergencePass;
    private OpticalFlowPass _computeOpticalFlowPass;
    private OpticalFlowPass _filterOpticalFlowPass;
    private OpticalFlowPass _scaleOpticalFlowPass;

    private readonly OpticalFlowResources _resources = new OpticalFlowResources();

    // The constants live in single-element arrays so they can be edited in place through the ref
    // properties below and uploaded as an array via CommandBuffer.SetBufferData.
    private ComputeBuffer _opticalFlowConstantsBuffer;
    private readonly OpticalFlow.OpticalFlowConstants[] _opticalFlowConstantsArray = { new OpticalFlow.OpticalFlowConstants() };
    private ref OpticalFlow.OpticalFlowConstants Constants => ref _opticalFlowConstantsArray[0];

    private ComputeBuffer _spdConstantsBuffer;
    private readonly OpticalFlow.SpdConstants[] _spdConstantsArray = { new OpticalFlow.SpdConstants() };
    private ref OpticalFlow.SpdConstants SpdConsts => ref _spdConstantsArray[0];

    private bool _firstExecution;
    private int _resourceFrameIndex;
    private readonly Vector2Int[] _opticalFlowTextureSizes = new Vector2Int[OpticalFlow.OpticalFlowMaxPyramidLevels];

    /// <summary>
    /// Initializes the context: allocates the constant buffers, resets frame tracking, and creates the
    /// resources and passes.
    /// </summary>
    /// <param name="contextDescription">Description whose resolution sizes the resources and dispatches.</param>
    public void Create(OpticalFlow.ContextDescription contextDescription)
    {
        _contextDescription = contextDescription;

        _opticalFlowConstantsBuffer = CreateConstantBuffer<OpticalFlow.OpticalFlowConstants>();
        _spdConstantsBuffer = CreateConstantBuffer<OpticalFlow.SpdConstants>();

        // Makes the first Dispatch behave like a reset (clears resources, restarts the frame counter).
        _firstExecution = true;
        _resourceFrameIndex = 0;

        Constants.inputLumaResolution = _contextDescription.resolution;

        _resources.Create(_contextDescription);
        CreatePasses();
    }

    /// <summary>Creates every pass, sharing the context description, resources and constant buffers.</summary>
    private void CreatePasses()
    {
        _prepareLumaPass = new OpticalFlowPrepareLumaPass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
        _generateInputPyramidPass = new OpticalFlowGenerateInputPyramidPass(_contextDescription, _resources, _opticalFlowConstantsBuffer, _spdConstantsBuffer);
        _generateScdHistogramPass = new OpticalFlowGenerateSCDHistogramPass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
        _computeScdDivergencePass = new OpticalFlowComputeSCDDivergencePass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
        _computeOpticalFlowPass = new OpticalFlowComputePass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
        _filterOpticalFlowPass = new OpticalFlowFilterPass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
        _scaleOpticalFlowPass = new OpticalFlowScalePass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
    }

    /// <summary>
    /// Releases passes, resources and constant buffers in the reverse order of <see cref="Create"/>.
    /// Pass and buffer fields are nulled, so calling this twice does not dispose them again.
    /// </summary>
    public void Destroy()
    {
        DestroyPass(ref _scaleOpticalFlowPass);
        DestroyPass(ref _filterOpticalFlowPass);
        DestroyPass(ref _computeOpticalFlowPass);
        DestroyPass(ref _computeScdDivergencePass);
        DestroyPass(ref _generateScdHistogramPass);
        DestroyPass(ref _generateInputPyramidPass);
        DestroyPass(ref _prepareLumaPass);

        _resources.Destroy();

        DestroyConstantBuffer(ref _spdConstantsBuffer);
        DestroyConstantBuffer(ref _opticalFlowConstantsBuffer);
    }

    /// <summary>
    /// Records one frame of optical flow work into <paramref name="commandBuffer"/>.
    /// </summary>
    /// <remarks>
    /// On a reset (or on the first call after <see cref="Create"/>), the intermediate targets are cleared
    /// first. Then the constants are uploaded, and the prepare-luma, input-pyramid, SCD-histogram and
    /// SCD-divergence passes are scheduled in that order.
    /// </remarks>
    /// <param name="commandBuffer">Command buffer that receives the clears, constant uploads and dispatches.</param>
    /// <param name="dispatchDescription">Per-frame inputs; its <c>reset</c> flag restarts accumulation.</param>
    public void Dispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription)
    {
        const int advancedAlgorithmIterations = 7;
        const int opticalFlowBlockSize = 8;

        Constants.backbufferTransferFunction = (uint)dispatchDescription.backbufferTransferFunction;
        Constants.minMaxLuminance = dispatchDescription.minMaxLuminance;

        // Alternates 0/1 between frames, because _resourceFrameIndex advances by one per frame modulo an
        // even MaxQueuedFrames. Presumably it selects which entry of the two-element resource arrays the
        // passes use; verify against the pass implementations.
        int frameIndex = _resourceFrameIndex % 2;

        bool resetAccumulation = dispatchDescription.reset || _firstExecution;
        _firstExecution = false;

        // Constants are CPU-side until uploaded below, so updating them together with recording the
        // clears keeps the command-buffer order identical.
        if (resetAccumulation)
        {
            Constants.frameIndex = 0;
            ClearResourcesForReset(commandBuffer, dispatchDescription);
        }
        else
        {
            Constants.frameIndex++;
        }

        SetupSpdConstants(out var threadGroupSizeOpticalFlowInputPyramid);
        commandBuffer.SetBufferData(_opticalFlowConstantsBuffer, _opticalFlowConstantsArray);
        commandBuffer.SetBufferData(_spdConstantsBuffer, _spdConstantsArray);

        SchedulePrepareLuma(commandBuffer, dispatchDescription, frameIndex);
        _generateInputPyramidPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex,
            threadGroupSizeOpticalFlowInputPyramid.x, threadGroupSizeOpticalFlowInputPyramid.y);
        ScheduleGenerateScdHistogram(commandBuffer, dispatchDescription, frameIndex);
        ScheduleComputeScdDivergence(commandBuffer, dispatchDescription, frameIndex);

        // Kept as compile-time constants: they are assigned to constant-buffer fields and passed to
        // OpticalFlow helpers whose parameter types may be unsigned.
        const int pyramidMaxIterations = advancedAlgorithmIterations;
        Assert.IsTrue(pyramidMaxIterations <= OpticalFlow.OpticalFlowMaxPyramidLevels);

        // Level 0 is derived from the full resolution; each further level halves it, rounding up.
        _opticalFlowTextureSizes[0] = OpticalFlow.GetOpticalFlowTextureSize(_contextDescription.resolution, opticalFlowBlockSize);
        for (int i = 1; i < pyramidMaxIterations; ++i)
        {
            _opticalFlowTextureSizes[i] = new Vector2Int(
                (_opticalFlowTextureSizes[i - 1].x + 1) / 2,
                (_opticalFlowTextureSizes[i - 1].y + 1) / 2
            );
        }

        // Coarsest level first, down to level 0.
        for (int level = pyramidMaxIterations - 1; level >= 0; --level)
        {
            // TODO: need to flip-flop Optical Flow output resources between levels as well (isOddFrame != isOddLevel)

            Constants.opticalFlowPyramidLevel = (uint)level;
            Constants.opticalFlowPyramidLevelCount = pyramidMaxIterations;
            commandBuffer.SetBufferData(_opticalFlowConstantsBuffer, _opticalFlowConstantsArray);

            // TODO: need to somehow address resources by level, not great when you've got a bazillion loose RT arrays... *sigh*
            // NOTE(review): no pass is dispatched here yet, so these per-level uploads currently have no consumer.
        }

        _resourceFrameIndex = (_resourceFrameIndex + 1) % MaxQueuedFrames;
    }

    /// <summary>
    /// Records color clears (to <see cref="Color.clear"/>) of the SCD targets and of both entries of every
    /// input pyramid level.
    /// </summary>
    private void ClearResourcesForReset(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription)
    {
        ClearColorTarget(commandBuffer, _resources.OpticalFlowSCDTemp);
        ClearColorTarget(commandBuffer, dispatchDescription.opticalFlowSCD.RenderTarget);
        ClearColorTarget(commandBuffer, _resources.OpticalFlowSCDHistogram);
        ClearColorTarget(commandBuffer, _resources.OpticalFlowSCDPreviousHistogram);

        for (int i = 0; i < 2; ++i)
        {
            ClearColorTarget(commandBuffer, _resources.OpticalFlowInput[i]);
            ClearColorTarget(commandBuffer, _resources.OpticalFlowInputLevel1[i]);
            ClearColorTarget(commandBuffer, _resources.OpticalFlowInputLevel2[i]);
            ClearColorTarget(commandBuffer, _resources.OpticalFlowInputLevel3[i]);
            ClearColorTarget(commandBuffer, _resources.OpticalFlowInputLevel4[i]);
            ClearColorTarget(commandBuffer, _resources.OpticalFlowInputLevel5[i]);
            ClearColorTarget(commandBuffer, _resources.OpticalFlowInputLevel6[i]);
        }
    }

    /// <summary>Binds <paramref name="target"/> and clears its color (not depth) to transparent black.</summary>
    private static void ClearColorTarget(CommandBuffer commandBuffer, RenderTargetIdentifier target)
    {
        commandBuffer.SetRenderTarget(target);
        commandBuffer.ClearRenderTarget(false, true, Color.clear);
    }

    /// <summary>
    /// Schedules the prepare-luma pass. Each thread covers 2x2 pixels and groups are 16x16 threads.
    /// </summary>
    private void SchedulePrepareLuma(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
    {
        const int threadGroupSizeX = 16;
        const int threadGroupSizeY = 16;
        const int threadPixelsX = 2;
        const int threadPixelsY = 2;

        int dispatchX = DivideRoundingUp(DivideRoundingUp(_contextDescription.resolution.x, threadPixelsX), threadGroupSizeX);
        int dispatchY = DivideRoundingUp(DivideRoundingUp(_contextDescription.resolution.y, threadPixelsY), threadGroupSizeY);

        _prepareLumaPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, dispatchX, dispatchY);
    }

    /// <summary>
    /// Schedules the SCD histogram pass. Only the X group count depends on the resolution; Y is fixed at 16
    /// and Z covers every histogram cell.
    /// </summary>
    private void ScheduleGenerateScdHistogram(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
    {
        const int threadGroupSizeX = 32;

        int strataWidth = (_contextDescription.resolution.x / 4) / OpticalFlow.HistogramsPerDim;
        int dispatchX = DivideRoundingUp(strataWidth, threadGroupSizeX);
        const int dispatchY = 16;
        const int dispatchZ = OpticalFlow.HistogramsPerDim * OpticalFlow.HistogramsPerDim;

        _generateScdHistogramPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, dispatchX, dispatchY, dispatchZ);
    }

    /// <summary>
    /// Schedules the SCD divergence pass: one group per (histogram cell, histogram shift) pair.
    /// </summary>
    private void ScheduleComputeScdDivergence(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
    {
        const int dispatchX = OpticalFlow.HistogramsPerDim * OpticalFlow.HistogramsPerDim;
        const int dispatchY = OpticalFlow.HistogramShifts;

        _computeScdDivergencePass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, dispatchX, dispatchY);
    }

    /// <summary>Integer division rounded up; intended for non-negative <paramref name="value"/>.</summary>
    private static int DivideRoundingUp(int value, int divisor)
    {
        return (value + divisor - 1) / divisor;
    }

    /// <summary>
    /// Fills the SPD constants for a single-pass downsample of the full input resolution, using 4 mips.
    /// </summary>
    /// <param name="dispatchThreadGroupCount">Thread-group counts to dispatch for the input pyramid pass.</param>
    private void SetupSpdConstants(out Vector2Int dispatchThreadGroupCount)
    {
        const int resolutionMultiplier = 1;
        RectInt rectInfo = new RectInt(0, 0, _contextDescription.resolution.x * resolutionMultiplier, _contextDescription.resolution.y * resolutionMultiplier);
        SpdSetup(rectInfo, out dispatchThreadGroupCount, out var workGroupOffset, out var numWorkGroupsAndMips, 4);

        ref OpticalFlow.SpdConstants spdConstants = ref SpdConsts;
        spdConstants.numWorkGroups = (uint)numWorkGroupsAndMips.x;
        spdConstants.mips = (uint)numWorkGroupsAndMips.y;
        spdConstants.workGroupOffsetX = (uint)workGroupOffset.x;
        spdConstants.workGroupOffsetY = (uint)workGroupOffset.y;
        spdConstants.numWorkGroupsOpticalFlowInputPyramid = (uint)numWorkGroupsAndMips.x;
    }

    /// <summary>
    /// Computes SPD dispatch parameters for <paramref name="rectInfo"/>, tiled in 64x64-pixel work groups.
    /// </summary>
    /// <param name="rectInfo">Pixel rectangle to downsample.</param>
    /// <param name="dispatchThreadGroupCount">Number of 64x64 tiles covering the rectangle, per axis.</param>
    /// <param name="workGroupOffset">Tile index of the rectangle's top-left corner.</param>
    /// <param name="numWorkGroupsAndMips">x: total tile count; y: mip count.</param>
    /// <param name="mips">
    /// Mip count to use. If negative, it is derived as floor(log2) of the larger dimension, capped at 12.
    /// </param>
    private static void SpdSetup(RectInt rectInfo, out Vector2Int dispatchThreadGroupCount, out Vector2Int workGroupOffset, out Vector2Int numWorkGroupsAndMips, int mips = -1)
    {
        workGroupOffset = new Vector2Int(rectInfo.x / 64, rectInfo.y / 64);

        int endIndexX = (rectInfo.x + rectInfo.width - 1) / 64;
        int endIndexY = (rectInfo.y + rectInfo.height - 1) / 64;

        dispatchThreadGroupCount = new Vector2Int(endIndexX + 1 - workGroupOffset.x, endIndexY + 1 - workGroupOffset.y);
        numWorkGroupsAndMips = new Vector2Int(dispatchThreadGroupCount.x * dispatchThreadGroupCount.y, mips);

        if (mips < 0)
        {
            float resolution = Math.Max(rectInfo.width, rectInfo.height);
            numWorkGroupsAndMips.y = Math.Min(Mathf.FloorToInt(Mathf.Log(resolution, 2.0f)), 12);
        }
    }

    /// <summary>Allocates a one-element constant buffer sized to <typeparamref name="TConstants"/>.</summary>
    private static ComputeBuffer CreateConstantBuffer<TConstants>() where TConstants: struct
    {
        return new ComputeBuffer(1, Marshal.SizeOf<TConstants>(), ComputeBufferType.Constant);
    }

    /// <summary>Releases the buffer, if any, and nulls the reference.</summary>
    private static void DestroyConstantBuffer(ref ComputeBuffer bufferRef)
    {
        if (bufferRef == null)
            return;

        bufferRef.Release();
        bufferRef = null;
    }

    /// <summary>Disposes the pass, if any, and nulls the reference.</summary>
    private static void DestroyPass(ref OpticalFlowPass pass)
    {
        if (pass == null)
            return;

        pass.Dispose();
        pass = null;
    }
}
}