You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

228 lines
12 KiB

using System;
using UnityEngine;
using UnityEngine.Assertions;
using UnityEngine.Profiling;
using UnityEngine.Rendering;
namespace FidelityFX.FrameGen
{
    /// <summary>
    /// Owns the optical flow passes, their shared resources and constant buffers, and records the whole
    /// optical flow pipeline (luma preparation, input pyramid, scene change detection, and the per-level
    /// optical flow search/filter/scale) into a Unity <see cref="CommandBuffer"/> on each
    /// <see cref="Dispatch"/> call.
    /// </summary>
    public class OpticalFlowContext: FfxContextBase
    {
        // Period of the resource frame counter. Only its parity is read (in Dispatch), so this must stay
        // even for the 0/1 alternation to survive the wrap-around.
        private const int MaxQueuedFrames = 16;

        // Number of pyramid levels the optical flow search walks through per dispatch.
        private const int AdvancedAlgorithmIterations = 7;

        // Block size used to size the level-0 optical flow texture and the scale pass thread groups.
        private const int OpticalFlowBlockSize = 8;

        /// <summary>Block size exposed to callers; always <c>OpticalFlow.MinBlockSize</c>.</summary>
        public int BlockSize => OpticalFlow.MinBlockSize;

        /// <summary>
        /// Optical flow texture size for the context resolution, computed with <c>OpticalFlow.MinBlockSize</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): Dispatch sizes the level-0 texture with <c>OpticalFlowBlockSize</c> (8). The two
        /// agree only if <c>OpticalFlow.MinBlockSize</c> is also 8 — confirm.
        /// </remarks>
        public Vector2Int TextureSize => OpticalFlow.GetOpticalFlowTextureSize(_contextDescription.resolution, OpticalFlow.MinBlockSize);

        private OpticalFlow.ContextDescription _contextDescription;

        private OpticalFlowPrepareLumaPass _prepareLumaPass;
        private OpticalFlowGenerateInputPyramidPass _generateInputPyramidPass;
        private OpticalFlowGenerateSCDHistogramPass _generateScdHistogramPass;
        private OpticalFlowComputeSCDDivergencePass _computeScdDivergencePass;
        private OpticalFlowComputePass _computeOpticalFlowPass;
        private OpticalFlowFilterPass _filterOpticalFlowPass;
        private OpticalFlowScalePass _scaleOpticalFlowPass;

        private readonly OpticalFlowResources _resources = new OpticalFlowResources();
        private readonly ConstantsBuffer<OpticalFlow.OpticalFlowConstants> _opticalFlowConstants = new ConstantsBuffer<OpticalFlow.OpticalFlowConstants>();
        private readonly ConstantsBuffer<OpticalFlow.SpdConstants> _spdConstants = new ConstantsBuffer<OpticalFlow.SpdConstants>();

        // True until the first Dispatch after Create; forces an accumulation reset on that dispatch.
        private bool _firstExecution;

        // Counts dispatches modulo MaxQueuedFrames; its parity is handed to every pass as the frame index.
        private int _resourceFrameIndex;

        // Per-level optical flow texture sizes, recomputed on each dispatch (index 0 is the largest level).
        private readonly Vector2Int[] _opticalFlowTextureSizes = new Vector2Int[OpticalFlow.OpticalFlowMaxPyramidLevels];

        private readonly CustomSampler _sampler = CustomSampler.Create("Optical Flow");

        /// <summary>
        /// Initializes the context for the given description. Creates both constant buffers, the shared
        /// resources and every pass, and arms a reset for the next <see cref="Dispatch"/>.
        /// </summary>
        /// <param name="contextDescription">Context settings; its <c>resolution</c> drives every dispatch size.</param>
        public void Create(in OpticalFlow.ContextDescription contextDescription)
        {
            _contextDescription = contextDescription;

            _opticalFlowConstants.Create();
            _spdConstants.Create();

            _firstExecution = true;
            _resourceFrameIndex = 0;

            _opticalFlowConstants.Value.inputLumaResolution = _contextDescription.resolution;

            _resources.Create(_contextDescription);
            CreatePasses();
        }

        /// <summary>Constructs every pass against the shared description, resources and constant buffers.</summary>
        private void CreatePasses()
        {
            _prepareLumaPass = new OpticalFlowPrepareLumaPass(_contextDescription, _resources, _opticalFlowConstants);
            _generateInputPyramidPass = new OpticalFlowGenerateInputPyramidPass(_contextDescription, _resources, _opticalFlowConstants, _spdConstants);
            _generateScdHistogramPass = new OpticalFlowGenerateSCDHistogramPass(_contextDescription, _resources, _opticalFlowConstants);
            _computeScdDivergencePass = new OpticalFlowComputeSCDDivergencePass(_contextDescription, _resources, _opticalFlowConstants);
            _computeOpticalFlowPass = new OpticalFlowComputePass(_contextDescription, _resources, _opticalFlowConstants);
            _filterOpticalFlowPass = new OpticalFlowFilterPass(_contextDescription, _resources, _opticalFlowConstants);
            _scaleOpticalFlowPass = new OpticalFlowScalePass(_contextDescription, _resources, _opticalFlowConstants);
        }

        /// <summary>
        /// Destroys the passes in reverse creation order, then the shared resources and both constant buffers.
        /// </summary>
        public void Destroy()
        {
            DestroyPass(ref _scaleOpticalFlowPass);
            DestroyPass(ref _filterOpticalFlowPass);
            DestroyPass(ref _computeOpticalFlowPass);
            DestroyPass(ref _computeScdDivergencePass);
            DestroyPass(ref _generateScdHistogramPass);
            DestroyPass(ref _generateInputPyramidPass);
            DestroyPass(ref _prepareLumaPass);

            _resources.Destroy();
            _spdConstants.Destroy();
            _opticalFlowConstants.Destroy();
        }

        /// <summary>
        /// Records one frame of the optical flow pipeline into <paramref name="commandBuffer"/>, wrapped in
        /// the "Optical Flow" profiler sample.
        /// </summary>
        /// <param name="commandBuffer">Command buffer that receives the clears, constant uploads and dispatches.</param>
        /// <param name="dispatchDescription">
        /// Per-frame inputs. When <c>reset</c> is set (or on the first dispatch after <see cref="Create"/>),
        /// the scene change detection and input pyramid resources are cleared and the constants' frame
        /// index restarts at 0; otherwise that frame index is incremented.
        /// </param>
        public void Dispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchDescription)
        {
            ref var constants = ref _opticalFlowConstants.Value;
            constants.backbufferTransferFunction = (uint)dispatchDescription.backbufferTransferFunction;
            constants.minMaxLuminance = dispatchDescription.minMaxLuminance;

            // Alternates 0/1 between dispatches; presumably selects between the two per-frame resource
            // sets (cf. the two entries per level of OpticalFlowInputLevels).
            int frameIndex = _resourceFrameIndex % 2;

            bool resetAccumulation = dispatchDescription.reset || _firstExecution;
            _firstExecution = false;

            if (resetAccumulation)
            {
                constants.frameIndex = 0;
            }
            else
            {
                constants.frameIndex++;
            }

            commandBuffer.BeginSample(_sampler);

            if (resetAccumulation)
            {
                ClearAccumulationResources(commandBuffer, dispatchDescription);
            }

            SetupSpdConstants(out var threadGroupSizeOpticalFlowInputPyramid);
            _opticalFlowConstants.UpdateBufferData(commandBuffer);
            _spdConstants.UpdateBufferData(commandBuffer);

            DispatchPrepareLuma(commandBuffer, dispatchDescription, frameIndex);
            _generateInputPyramidPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, threadGroupSizeOpticalFlowInputPyramid.x, threadGroupSizeOpticalFlowInputPyramid.y);
            DispatchSceneChangeDetection(commandBuffer, dispatchDescription, frameIndex);
            DispatchOpticalFlowPyramid(commandBuffer, dispatchDescription, frameIndex);

            commandBuffer.EndSample(_sampler);

            _resourceFrameIndex = (_resourceFrameIndex + 1) % MaxQueuedFrames;
        }

        /// <summary>
        /// Clears to <see cref="Color.clear"/>: the SCD temp, SCD output, SCD histogram and previous
        /// histogram targets, and every level of both input pyramids.
        /// </summary>
        private void ClearAccumulationResources(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription)
        {
            commandBuffer.SetRenderTarget(_resources.OpticalFlowSCDTemp);
            commandBuffer.ClearRenderTarget(false, true, Color.clear);
            commandBuffer.SetRenderTarget(dispatchDescription.opticalFlowSCD.RenderTarget);
            commandBuffer.ClearRenderTarget(false, true, Color.clear);
            commandBuffer.SetRenderTarget(_resources.OpticalFlowSCDHistogram);
            commandBuffer.ClearRenderTarget(false, true, Color.clear);
            commandBuffer.SetRenderTarget(_resources.OpticalFlowSCDPreviousHistogram);
            commandBuffer.ClearRenderTarget(false, true, Color.clear);

            for (int pyramidIndex = 0; pyramidIndex < 2; ++pyramidIndex)
            {
                for (int level = 0; level < OpticalFlow.OpticalFlowMaxPyramidLevels; ++level)
                {
                    commandBuffer.SetRenderTarget(_resources.OpticalFlowInputLevels[level][pyramidIndex]);
                    commandBuffer.ClearRenderTarget(false, true, Color.clear);
                }
            }
        }

        /// <summary>Schedules the luma preparation pass over the full context resolution.</summary>
        private void DispatchPrepareLuma(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
        {
            // 16x16 thread groups, each thread covering a 2x2 pixel area.
            const int threadGroupSizeX = 16;
            const int threadGroupSizeY = 16;
            const int threadPixelsX = 2;
            const int threadPixelsY = 2;

            int dispatchX = DivideRoundingUp(DivideRoundingUp(_contextDescription.resolution.x, threadPixelsX), threadGroupSizeX);
            int dispatchY = DivideRoundingUp(DivideRoundingUp(_contextDescription.resolution.y, threadPixelsY), threadGroupSizeY);
            _prepareLumaPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, dispatchX, dispatchY);
        }

        /// <summary>Schedules the SCD histogram pass followed by the SCD divergence pass.</summary>
        private void DispatchSceneChangeDetection(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
        {
            // Only the X thread-group size (32) matters here: Y is a fixed 16 groups and Z has one group
            // per histogram. The width is quartered before striping — presumably each thread reads four
            // horizontal pixels; confirm against the shader.
            const int histogramGroupSizeX = 32;
            int strataWidth = (_contextDescription.resolution.x / 4) / OpticalFlow.HistogramsPerDim;
            int histogramGroupsX = DivideRoundingUp(strataWidth, histogramGroupSizeX);
            const int histogramGroupsY = 16;
            const int histogramGroupsZ = OpticalFlow.HistogramsPerDim * OpticalFlow.HistogramsPerDim;
            _generateScdHistogramPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, histogramGroupsX, histogramGroupsY, histogramGroupsZ);

            // One group per (histogram, shift) pair.
            const int divergenceGroupsX = OpticalFlow.HistogramsPerDim * OpticalFlow.HistogramsPerDim;
            const int divergenceGroupsY = OpticalFlow.HistogramShifts;
            _computeScdDivergencePass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, 0, divergenceGroupsX, divergenceGroupsY);
        }

        /// <summary>
        /// Walks the pyramid from the coarsest level (smallest texture) down to level 0. Before each level it
        /// uploads that level's index and the level count to the optical flow constants.
        /// </summary>
        private void DispatchOpticalFlowPyramid(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription, int frameIndex)
        {
            const int pyramidMaxIterations = AdvancedAlgorithmIterations;
            Assert.IsTrue(pyramidMaxIterations <= OpticalFlow.OpticalFlowMaxPyramidLevels);

            UpdateOpticalFlowTextureSizes(pyramidMaxIterations);

            ref var constants = ref _opticalFlowConstants.Value;
            for (int level = pyramidMaxIterations - 1; level >= 0; --level)
            {
                constants.opticalFlowPyramidLevel = (uint)level;
                constants.opticalFlowPyramidLevelCount = pyramidMaxIterations;
                _opticalFlowConstants.UpdateBufferData(commandBuffer);

                DispatchPyramidLevel(commandBuffer, dispatchDescription, frameIndex, level);
            }
        }

        /// <summary>
        /// Fills the first <paramref name="levelCount"/> entries of <c>_opticalFlowTextureSizes</c>. Level 0
        /// uses <c>OpticalFlowBlockSize</c>; each further level halves the previous one, rounding up.
        /// </summary>
        private void UpdateOpticalFlowTextureSizes(int levelCount)
        {
            _opticalFlowTextureSizes[0] = OpticalFlow.GetOpticalFlowTextureSize(_contextDescription.resolution, OpticalFlowBlockSize);
            for (int level = 1; level < levelCount; ++level)
            {
                Vector2Int larger = _opticalFlowTextureSizes[level - 1];
                _opticalFlowTextureSizes[level] = new Vector2Int((larger.x + 1) / 2, (larger.y + 1) / 2);
            }
        }

        /// <summary>
        /// Schedules, for one pyramid level, the optical flow compute pass and the filter pass. For levels
        /// above 0 it then schedules the scale pass, sized to the next larger level (<c>level - 1</c>).
        /// </summary>
        private void DispatchPyramidLevel(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchDescription, int frameIndex, int level)
        {
            // Compute pass: sized from the input luma resolution at this level (halved per level, min 1).
            // Thread groups are 4 x 16 = 64 threads; only Y and the total size enter the math.
            int inputLumaWidth = Math.Max(_contextDescription.resolution.x >> level, 1);
            int inputLumaHeight = Math.Max(_contextDescription.resolution.y >> level, 1);
            const int threadPixels = 4;
            Assert.IsTrue(OpticalFlowBlockSize >= threadPixels);
            const int computeGroupSizeY = 16;
            const int computeGroupSize = 64;
            int computeGroupsX = DivideRoundingUp(DivideRoundingUp(inputLumaWidth, threadPixels) * computeGroupSizeY, computeGroupSize);
            int computeGroupsY = DivideRoundingUp(inputLumaHeight, computeGroupSizeY);
            _computeOpticalFlowPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, level, computeGroupsX, computeGroupsY);

            // Filter pass: covers this level's optical flow texture with 16x4 thread groups.
            Vector2Int levelSize = _opticalFlowTextureSizes[level];
            const int filterGroupSizeX = 16;
            const int filterGroupSizeY = 4;
            int filterGroupsX = DivideRoundingUp(levelSize.x, filterGroupSizeX);
            int filterGroupsY = DivideRoundingUp(levelSize.y, filterGroupSizeY);
            _filterOpticalFlowPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, level, filterGroupsX, filterGroupsY);

            if (level > 0)
            {
                // Scale pass: covers the next larger level's texture. The Z thread-group extent does not
                // enter the math; exactly one Z group is dispatched.
                Vector2Int nextLevelSize = _opticalFlowTextureSizes[level - 1];
                const int scaleGroupSizeX = OpticalFlowBlockSize / 2;
                const int scaleGroupSizeY = OpticalFlowBlockSize / 2;
                int scaleGroupsX = DivideRoundingUp(nextLevelSize.x, scaleGroupSizeX);
                int scaleGroupsY = DivideRoundingUp(nextLevelSize.y, scaleGroupSizeY);
                const int scaleGroupsZ = 1;
                _scaleOpticalFlowPass.ScheduleDispatch(commandBuffer, dispatchDescription, frameIndex, level, scaleGroupsX, scaleGroupsY, scaleGroupsZ);
            }
        }

        /// <summary>
        /// Fills the SPD constants for a downsample over the full context resolution. Also copies
        /// <c>spd.numWorkGroups</c> into <c>numWorkGroupsOpticalFlowInputPyramid</c>.
        /// </summary>
        /// <param name="dispatchThreadGroupCount">Thread-group counts reported by <c>FfxSpd.SetupSpdConstants</c>.</param>
        private void SetupSpdConstants(out Vector2Int dispatchThreadGroupCount)
        {
            const int resolutionMultiplier = 1;
            ref OpticalFlow.SpdConstants spdConstants = ref _spdConstants.Value;
            FfxSpd.SetupSpdConstants(_contextDescription.resolution * resolutionMultiplier, ref spdConstants.spd, out dispatchThreadGroupCount);
            spdConstants.numWorkGroupsOpticalFlowInputPyramid = spdConstants.spd.numWorkGroups;
        }

        /// <summary>Integer division rounding up for non-negative operands: <c>(value + divisor - 1) / divisor</c>.</summary>
        private static int DivideRoundingUp(int value, int divisor) => (value + divisor - 1) / divisor;
    }
}