You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
275 lines
14 KiB
275 lines
14 KiB
using System;
|
|
using System.Runtime.InteropServices;
|
|
using UnityEngine;
|
|
using UnityEngine.Assertions;
|
|
using UnityEngine.Profiling;
|
|
using UnityEngine.Rendering;
|
|
|
|
namespace FidelityFX.FrameGen
{
    /// <summary>
    /// Drives the optical-flow stage of frame generation: owns its constant buffers, intermediate
    /// resources and compute passes, and records their dispatches into a <see cref="CommandBuffer"/>.
    /// Call <see cref="Create"/> before any <see cref="Dispatch"/>, and <see cref="Destroy"/> to release everything.
    /// </summary>
    public class OpticalFlowContext
    {
        // _resourceFrameIndex wraps at this value; only its parity (0 or 1) is handed to the passes as the frame index.
        private const int MaxQueuedFrames = 16;

        // Number of pyramid levels processed per dispatch (asserted to be <= OpticalFlow.OpticalFlowMaxPyramidLevels).
        private const int AdvancedAlgorithmIterations = 7;

        // Block size passed to OpticalFlow.GetOpticalFlowTextureSize; the scale pass uses half of it per group axis.
        private const int OpticalFlowBlockSize = 8;

        private OpticalFlow.ContextDescription _contextDescription;

        private OpticalFlowPass _prepareLumaPass;
        private OpticalFlowPass _generateInputPyramidPass;
        private OpticalFlowPass _generateScdHistogramPass;
        private OpticalFlowPass _computeScdDivergencePass;
        private OpticalFlowPass _computeOpticalFlowPass;
        private OpticalFlowPass _filterOpticalFlowPass;
        private OpticalFlowPass _scaleOpticalFlowPass;

        private readonly OpticalFlowResources _resources = new OpticalFlowResources();

        // Single-element arrays back the constant buffers so they can be uploaded with SetBufferData;
        // the ref-returning properties allow in-place edits of element 0.
        private ComputeBuffer _opticalFlowConstantsBuffer;
        private readonly OpticalFlow.OpticalFlowConstants[] _opticalFlowConstantsArray = { new OpticalFlow.OpticalFlowConstants() };
        private ref OpticalFlow.OpticalFlowConstants Constants => ref _opticalFlowConstantsArray[0];

        private ComputeBuffer _spdConstantsBuffer;
        private readonly OpticalFlow.SpdConstants[] _spdConstantsArray = { new OpticalFlow.SpdConstants() };
        private ref OpticalFlow.SpdConstants SpdConsts => ref _spdConstantsArray[0];

        private bool _firstExecution;
        private int _resourceFrameIndex;
        private readonly Vector2Int[] _opticalFlowTextureSizes = new Vector2Int[OpticalFlow.OpticalFlowMaxPyramidLevels];
        private readonly CustomSampler _sampler = CustomSampler.Create("Optical Flow");

        /// <summary>
        /// Initializes the context: allocates the optical-flow and SPD constant buffers, resets the
        /// per-context frame state, creates the intermediate resources and constructs every pass.
        /// </summary>
        /// <param name="contextDescription">Description (including the input resolution) shared with the resources and passes.</param>
        /// <remarks>
        /// NOTE(review): calling Create again without an intervening <see cref="Destroy"/> overwrites the
        /// previous constant buffers and passes without releasing them.
        /// </remarks>
        public void Create(OpticalFlow.ContextDescription contextDescription)
        {
            _contextDescription = contextDescription;

            _opticalFlowConstantsBuffer = CreateConstantBuffer<OpticalFlow.OpticalFlowConstants>();
            _spdConstantsBuffer = CreateConstantBuffer<OpticalFlow.SpdConstants>();

            // Forces the first Dispatch to clear resources and restart the frame counter.
            _firstExecution = true;
            _resourceFrameIndex = 0;

            Constants.inputLumaResolution = _contextDescription.resolution;

            _resources.Create(_contextDescription);
            CreatePasses();
        }

        // Constructs all passes; every pass shares the resources and the optical-flow constant buffer.
        private void CreatePasses()
        {
            _prepareLumaPass = new OpticalFlowPrepareLumaPass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _generateInputPyramidPass = new OpticalFlowGenerateInputPyramidPass(_contextDescription, _resources, _opticalFlowConstantsBuffer, _spdConstantsBuffer);
            _generateScdHistogramPass = new OpticalFlowGenerateSCDHistogramPass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _computeScdDivergencePass = new OpticalFlowComputeSCDDivergencePass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _computeOpticalFlowPass = new OpticalFlowComputePass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _filterOpticalFlowPass = new OpticalFlowFilterPass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
            _scaleOpticalFlowPass = new OpticalFlowScalePass(_contextDescription, _resources, _opticalFlowConstantsBuffer);
        }

        /// <summary>
        /// Releases everything allocated by <see cref="Create"/>: the passes (in reverse creation order),
        /// then the intermediate resources, then the constant buffers.
        /// </summary>
        /// <remarks>
        /// Passes and constant buffers are nulled after release, so repeated calls skip them; whether
        /// OpticalFlowResources.Destroy tolerates a repeated call is not visible here.
        /// </remarks>
        public void Destroy()
        {
            DestroyPass(ref _scaleOpticalFlowPass);
            DestroyPass(ref _filterOpticalFlowPass);
            DestroyPass(ref _computeOpticalFlowPass);
            DestroyPass(ref _computeScdDivergencePass);
            DestroyPass(ref _generateScdHistogramPass);
            DestroyPass(ref _generateInputPyramidPass);
            DestroyPass(ref _prepareLumaPass);

            _resources.Destroy();

            DestroyConstantBuffer(ref _spdConstantsBuffer);
            DestroyConstantBuffer(ref _opticalFlowConstantsBuffer);
        }

        /// <summary>
        /// Records one frame of optical-flow work into <paramref name="commandBuffer"/>: luma preparation,
        /// input pyramid generation, the SCD histogram/divergence passes and, for each pyramid level from the
        /// smallest to the full-size one, the compute, filter and (except at level 0) scale passes.
        /// </summary>
        /// <param name="commandBuffer">Command buffer the dispatches and constant-buffer uploads are recorded into.</param>
        /// <param name="dispatchDescription">Per-frame inputs; <c>Reset</c> forces the resources to be cleared.</param>
        public void Dispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription)
        {
            Constants.backbufferTransferFunction = (uint)dispatchDescription.BackbufferTransferFunction;
            Constants.minMaxLuminance = dispatchDescription.MinMaxLuminance;

            // Alternates between the two resource slots (0/1) that the passes index with it.
            int frameIndex = _resourceFrameIndex % 2;

            bool resetAccumulation = dispatchDescription.Reset || _firstExecution;
            _firstExecution = false;

            if (resetAccumulation)
            {
                Constants.frameIndex = 0;
            }
            else
            {
                Constants.frameIndex++;
            }

            commandBuffer.BeginSample(_sampler);

            if (resetAccumulation)
            {
                ClearResourcesForReset(commandBuffer, dispatchDescription);
            }

            SetupSpdConstants(out var threadGroupSizeOpticalFlowInputPyramid);

            commandBuffer.SetBufferData(_opticalFlowConstantsBuffer, _opticalFlowConstantsArray);
            commandBuffer.SetBufferData(_spdConstantsBuffer, _spdConstantsArray);

            SchedulePrepareLuma(commandBuffer, dispatchDescription, frameIndex);

            _generateInputPyramidPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0,
                threadGroupSizeOpticalFlowInputPyramid.x, threadGroupSizeOpticalFlowInputPyramid.y);

            ScheduleScdPasses(commandBuffer, dispatchDescription, frameIndex);

            Assert.IsTrue(AdvancedAlgorithmIterations <= OpticalFlow.OpticalFlowMaxPyramidLevels);

            ComputeOpticalFlowTextureSizes(AdvancedAlgorithmIterations);

            // Walk from the smallest pyramid level (highest index) up to the full-size level 0.
            for (int level = AdvancedAlgorithmIterations - 1; level >= 0; --level)
            {
                ScheduleOpticalFlowLevel(commandBuffer, dispatchDescription, frameIndex, level);
            }

            commandBuffer.EndSample(_sampler);

            _resourceFrameIndex = (_resourceFrameIndex + 1) % MaxQueuedFrames;
        }

        // Clears the SCD textures (including the caller-provided OpticalFlowSCD target) and both slots of every input pyramid level.
        private void ClearResourcesForReset(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription)
        {
            ClearColorTarget(commandBuffer, _resources.OpticalFlowSCDTemp);
            ClearColorTarget(commandBuffer, dispatchDescription.OpticalFlowSCD.RenderTarget);
            ClearColorTarget(commandBuffer, _resources.OpticalFlowSCDHistogram);
            ClearColorTarget(commandBuffer, _resources.OpticalFlowSCDPreviousHistogram);

            for (int slot = 0; slot < 2; ++slot)
            {
                for (int level = 0; level < OpticalFlow.OpticalFlowMaxPyramidLevels; ++level)
                {
                    ClearColorTarget(commandBuffer, _resources.OpticalFlowInputLevels[level][slot]);
                }
            }
        }

        // Binds the target and clears its color to transparent black (depth is left untouched).
        private static void ClearColorTarget(CommandBuffer commandBuffer, RenderTargetIdentifier target)
        {
            commandBuffer.SetRenderTarget(target);
            commandBuffer.ClearRenderTarget(false, true, Color.clear);
        }

        // Luma preparation: each thread covers 2x2 pixels, in 16x16-thread groups.
        private void SchedulePrepareLuma(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
        {
            const int threadGroupSizeX = 16;
            const int threadGroupSizeY = 16;
            const int threadPixelsX = 2;
            const int threadPixelsY = 2;

            int dispatchX = DivideRoundUp(DivideRoundUp(_contextDescription.resolution.x, threadPixelsX), threadGroupSizeX);
            int dispatchY = DivideRoundUp(DivideRoundUp(_contextDescription.resolution.y, threadPixelsY), threadGroupSizeY);

            _prepareLumaPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, dispatchX, dispatchY);
        }

        // SCD passes (NOTE(review): "SCD" presumably = scene change detection): histogram generation, then divergence.
        private void ScheduleScdPasses(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
        {
            {
                // Only the X group size enters the dispatch math; Y is a fixed 16 and Z spans every histogram cell.
                const int threadGroupSizeX = 32;
                int strataWidth = (_contextDescription.resolution.x / 4) / OpticalFlow.HistogramsPerDim;
                int dispatchX = DivideRoundUp(strataWidth, threadGroupSizeX);
                const int dispatchY = 16;
                const int dispatchZ = OpticalFlow.HistogramsPerDim * OpticalFlow.HistogramsPerDim;

                _generateScdHistogramPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, dispatchX, dispatchY, dispatchZ);
            }

            {
                const int dispatchX = OpticalFlow.HistogramsPerDim * OpticalFlow.HistogramsPerDim;
                const int dispatchY = OpticalFlow.HistogramShifts;

                _computeScdDivergencePass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, dispatchX, dispatchY);
            }
        }

        // Fills _opticalFlowTextureSizes[0..levelCount): level 0 from the input resolution, each next level halved (rounded up).
        private void ComputeOpticalFlowTextureSizes(int levelCount)
        {
            _opticalFlowTextureSizes[0] = OpticalFlow.GetOpticalFlowTextureSize(_contextDescription.resolution, OpticalFlowBlockSize);

            for (int i = 1; i < levelCount; ++i)
            {
                Vector2Int previous = _opticalFlowTextureSizes[i - 1];
                _opticalFlowTextureSizes[i] = new Vector2Int(DivideRoundUp(previous.x, 2), DivideRoundUp(previous.y, 2));
            }
        }

        // Uploads the per-level constants, then records the compute and filter passes for this level and,
        // above level 0, the scale pass sized for the next (larger) level.
        private void ScheduleOpticalFlowLevel(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription, int frameIndex, int level)
        {
            Constants.opticalFlowPyramidLevel = (uint)level;
            Constants.opticalFlowPyramidLevelCount = AdvancedAlgorithmIterations;
            commandBuffer.SetBufferData(_opticalFlowConstantsBuffer, _opticalFlowConstantsArray);

            {
                // Compute pass works on the input luma at this level (resolution >> level, at least 1 pixel).
                int inputLumaWidth = Math.Max(_contextDescription.resolution.x >> level, 1);
                int inputLumaHeight = Math.Max(_contextDescription.resolution.y >> level, 1);

                const int threadPixels = 4;
                Assert.IsTrue(OpticalFlowBlockSize >= threadPixels);
                const int threadGroupSizeY = 16;
                const int threadGroupSize = 64;

                int dispatchX = DivideRoundUp(DivideRoundUp(inputLumaWidth, threadPixels) * threadGroupSizeY, threadGroupSize);
                int dispatchY = DivideRoundUp(inputLumaHeight, threadGroupSizeY);

                _computeOpticalFlowPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, level, dispatchX, dispatchY);
            }

            {
                // Filter pass covers this level's optical-flow texture in 16x4 groups.
                Vector2Int levelSize = _opticalFlowTextureSizes[level];

                const int threadGroupSizeX = 16;
                const int threadGroupSizeY = 4;

                int dispatchX = DivideRoundUp(levelSize.x, threadGroupSizeX);
                int dispatchY = DivideRoundUp(levelSize.y, threadGroupSizeY);

                _filterOpticalFlowPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, level, dispatchX, dispatchY);
            }

            if (level > 0)
            {
                // Scale pass is sized by the next (larger) level's optical-flow texture.
                Vector2Int nextLevelSize = _opticalFlowTextureSizes[level - 1];

                const int threadGroupSizeX = OpticalFlowBlockSize / 2;
                const int threadGroupSizeY = OpticalFlowBlockSize / 2;

                int dispatchX = DivideRoundUp(nextLevelSize.x, threadGroupSizeX);
                int dispatchY = DivideRoundUp(nextLevelSize.y, threadGroupSizeY);
                const int dispatchZ = 1;

                _scaleOpticalFlowPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, level, dispatchX, dispatchY, dispatchZ);
            }
        }

        // Fills SpdConsts for the full input resolution limited to 4 mips and returns the thread-group counts
        // for the input-pyramid pass. NOTE(review): "SPD" presumably = FidelityFX Single Pass Downsampler.
        private void SetupSpdConstants(out Vector2Int dispatchThreadGroupCount)
        {
            const int resolutionMultiplier = 1;

            RectInt rectInfo = new RectInt(0, 0, _contextDescription.resolution.x * resolutionMultiplier, _contextDescription.resolution.y * resolutionMultiplier);
            SpdSetup(rectInfo, out dispatchThreadGroupCount, out var workGroupOffset, out var numWorkGroupsAndMips, 4);

            ref OpticalFlow.SpdConstants spdConstants = ref SpdConsts;
            spdConstants.numWorkGroups = (uint)numWorkGroupsAndMips.x;
            spdConstants.mips = (uint)numWorkGroupsAndMips.y;
            spdConstants.workGroupOffsetX = (uint)workGroupOffset.x;
            spdConstants.workGroupOffsetY = (uint)workGroupOffset.y;
            spdConstants.numWorkGroupsOpticalFlowInputPyramid = (uint)numWorkGroupsAndMips.x;
        }

        // Tiles rectInfo into 64x64-pixel work groups.
        // numWorkGroupsAndMips.x = total work groups; .y = mips, or, when mips < 0,
        // floor(log2(max(width, height))) clamped to at most 12.
        private static void SpdSetup(RectInt rectInfo, out Vector2Int dispatchThreadGroupCount, out Vector2Int workGroupOffset, out Vector2Int numWorkGroupsAndMips, int mips = -1)
        {
            workGroupOffset = new Vector2Int(rectInfo.x / 64, rectInfo.y / 64);

            int endIndexX = (rectInfo.x + rectInfo.width - 1) / 64;
            int endIndexY = (rectInfo.y + rectInfo.height - 1) / 64;

            dispatchThreadGroupCount = new Vector2Int(endIndexX + 1 - workGroupOffset.x, endIndexY + 1 - workGroupOffset.y);

            numWorkGroupsAndMips = new Vector2Int(dispatchThreadGroupCount.x * dispatchThreadGroupCount.y, mips);
            if (mips < 0)
            {
                float resolution = Math.Max(rectInfo.width, rectInfo.height);
                numWorkGroupsAndMips.y = Math.Min(Mathf.FloorToInt(Mathf.Log(resolution, 2.0f)), 12);
            }
        }

        // Integer division rounding up; valid for value >= 0 and divisor > 0 (the only cases used here).
        private static int DivideRoundUp(int value, int divisor) => (value + divisor - 1) / divisor;

        // Allocates a single-element constant buffer sized to the marshaled size of TConstants.
        private static ComputeBuffer CreateConstantBuffer<TConstants>() where TConstants: struct
        {
            return new ComputeBuffer(1, Marshal.SizeOf<TConstants>(), ComputeBufferType.Constant);
        }

        // Releases the buffer if present and nulls the reference so a repeated call is a no-op.
        private static void DestroyConstantBuffer(ref ComputeBuffer bufferRef)
        {
            if (bufferRef == null)
            {
                return;
            }

            bufferRef.Release();
            bufferRef = null;
        }

        // Disposes the pass if present and nulls the reference so a repeated call is a no-op.
        private static void DestroyPass(ref OpticalFlowPass pass)
        {
            if (pass == null)
            {
                return;
            }

            pass.Dispose();
            pass = null;
        }
    }
}
|