You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
219 lines
13 KiB
219 lines
13 KiB
using System;
|
|
using System.Runtime.InteropServices;
|
|
using UnityEngine;
|
|
using UnityEngine.Profiling;
|
|
using UnityEngine.Rendering;
|
|
|
|
namespace FidelityFX.FrameGen
|
|
{
|
|
internal abstract class OpticalFlowPass : IDisposable
{
    // Shared optical-flow textures that the concrete passes bind as inputs/outputs.
    protected readonly OpticalFlowResources Resources;

    // Constant buffer handed to every pass at construction time.
    protected readonly ComputeBuffer Constants;

    // Shader and kernel used by this pass; both are assigned by InitComputeShader.
    protected ComputeShader ComputeShader;
    protected int KernelIndex;

    // Profiler sample that brackets this pass in the command buffer; created by InitComputeShader.
    private CustomSampler _sampler;

    /// <summary>
    /// Stores the shared resources and constant buffer. Derived passes must call
    /// <see cref="InitComputeShader"/> from their own constructor, because
    /// <see cref="ScheduleDispatch"/> relies on the sampler it creates.
    /// </summary>
    /// <param name="resources">Shared optical-flow resources.</param>
    /// <param name="constants">Constant buffer bound by the concrete passes.</param>
    protected OpticalFlowPass(OpticalFlowResources resources, ComputeBuffer constants)
    {
        Constants = constants;
        Resources = resources;
    }

    /// <summary>
    /// The base pass owns nothing that needs releasing; derived passes may override this.
    /// </summary>
    public virtual void Dispose()
    {
    }

    /// <summary>
    /// Records this pass into <paramref name="commandBuffer"/>, surrounded by a profiler
    /// sample named after the pass. The actual bindings and dispatch are supplied by
    /// <see cref="DoScheduleDispatch"/>.
    /// </summary>
    /// <param name="commandBuffer">Command buffer that receives the commands.</param>
    /// <param name="dispatchParams">Per-dispatch external inputs and outputs.</param>
    /// <param name="frameIndex">Frame slot used to pick between per-frame resources.</param>
    /// <param name="level">Pyramid level being processed (ignored by some passes).</param>
    /// <param name="dispatchX">Thread-group count in X.</param>
    /// <param name="dispatchY">Thread-group count in Y.</param>
    /// <param name="dispatchZ">Thread-group count in Z (defaults to 1).</param>
    public void ScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ = 1)
    {
        commandBuffer.BeginSample(_sampler);

        DoScheduleDispatch(commandBuffer, dispatchParams, frameIndex, level, dispatchX, dispatchY, dispatchZ);

        commandBuffer.EndSample(_sampler);
    }

    /// <summary>
    /// Binds the pass-specific resources and records the compute dispatch.
    /// </summary>
    protected abstract void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ);

    /// <summary>
    /// Assigns the compute shader for this pass, looks up its "CS" kernel and creates the
    /// profiling sampler.
    /// </summary>
    /// <param name="passName">Readable pass name, used for the profiler sample and the error message.</param>
    /// <param name="shader">Compute shader implementing the pass.</param>
    /// <exception cref="MissingReferenceException">Thrown when <paramref name="shader"/> is missing.</exception>
    protected void InitComputeShader(string passName, ComputeShader shader)
    {
        // Keep '==' here (not 'is null'): UnityEngine.Object overloads equality so that
        // missing/destroyed assets also compare equal to null.
        if (shader == null)
        {
            string message = $"Shader for Optical Flow pass '{passName}' could not be loaded! Please ensure it is included in the project correctly.";
            throw new MissingReferenceException(message);
        }

        ComputeShader = shader;
        KernelIndex = shader.FindKernel("CS");
        _sampler = CustomSampler.Create(passName);
    }
}
|
|
|
|
internal class OpticalFlowPrepareLumaPass : OpticalFlowPass
{
    /// <summary>
    /// Creates the pass from the <c>prepareLuma</c> shader of the context description.
    /// </summary>
    public OpticalFlowPrepareLumaPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
        : base(resources, constants)
    {
        InitComputeShader("Prepare Luma", contextDescription.shaders.prepareLuma);
    }

    /// <summary>
    /// Reads the caller-supplied colour input and writes into the level-0 input texture of
    /// the current frame slot. <paramref name="level"/> is not used by this pass.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
    {
        var shader = ComputeShader;
        int kernel = KernelIndex;

        commandBuffer.SetComputeResourceParam(shader, kernel, OpticalFlowShaderIDs.SrvInputColor, dispatchParams.Color);
        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][frameIndex]);
        commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(shader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);

        commandBuffer.DispatchCompute(shader, kernel, dispatchX, dispatchY, dispatchZ);
    }
}
|
|
|
|
internal class OpticalFlowGenerateInputPyramidPass : OpticalFlowPass
{
    // Shader property IDs of the pyramid UAVs; entry i is bound to OpticalFlowInputLevels[i].
    private static readonly int[] InputLevelIds =
    {
        OpticalFlowShaderIDs.UavOpticalFlowInput,
        OpticalFlowShaderIDs.UavOpticalFlowInputLevel1,
        OpticalFlowShaderIDs.UavOpticalFlowInputLevel2,
        OpticalFlowShaderIDs.UavOpticalFlowInputLevel3,
        OpticalFlowShaderIDs.UavOpticalFlowInputLevel4,
        OpticalFlowShaderIDs.UavOpticalFlowInputLevel5,
        OpticalFlowShaderIDs.UavOpticalFlowInputLevel6,
    };

    // Constant buffer bound as OpticalFlow.SpdConstants; only assigned in the constructor.
    private readonly ComputeBuffer _spdConstants;

    /// <summary>
    /// Creates the pass from the <c>generateOpticalFlowInputPyramid</c> shader of the context description.
    /// </summary>
    /// <param name="contextDescription">Context description providing the shader.</param>
    /// <param name="resources">Shared optical-flow resources.</param>
    /// <param name="constants">Optical-flow constant buffer.</param>
    /// <param name="spdConstants">Constant buffer bound to <c>CbSpd</c>.</param>
    public OpticalFlowGenerateInputPyramidPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
        : base(resources, constants)
    {
        _spdConstants = spdConstants;
        InitComputeShader("Generate Optical Flow Input Pyramid", contextDescription.shaders.generateOpticalFlowInputPyramid);
    }

    /// <summary>
    /// Binds input pyramid levels 0 through 6 of the current frame slot, plus both constant
    /// buffers, and records the dispatch. <paramref name="level"/> is not used by this pass.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
    {
        var inputLevels = Resources.OpticalFlowInputLevels;

        // Bound in ascending level order, matching the shader's UAV set.
        for (int i = 0; i < InputLevelIds.Length; i++)
        {
            commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, InputLevelIds[i], inputLevels[i][frameIndex]);
        }

        commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
        commandBuffer.SetComputeConstantBufferParam<OpticalFlow.SpdConstants>(ComputeShader, OpticalFlowShaderIDs.CbSpd, _spdConstants);

        commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
    }
}
|
|
|
|
internal class OpticalFlowGenerateSCDHistogramPass : OpticalFlowPass
{
    /// <summary>
    /// Creates the pass from the <c>generateScdHistogram</c> shader of the context description.
    /// </summary>
    public OpticalFlowGenerateSCDHistogramPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
        : base(resources, constants)
    {
        InitComputeShader("Generate SCD Histogram", contextDescription.shaders.generateScdHistogram);
    }

    /// <summary>
    /// Reads the level-0 input texture of the current frame slot and writes into the shared
    /// SCD histogram texture. <paramref name="level"/> is not used by this pass.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
    {
        var shader = ComputeShader;
        int kernel = KernelIndex;
        var currentInput = Resources.OpticalFlowInputLevels[0][frameIndex];

        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.SrvOpticalFlowInput, currentInput);
        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram);
        commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(shader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);

        commandBuffer.DispatchCompute(shader, kernel, dispatchX, dispatchY, dispatchZ);
    }
}
|
|
|
|
internal class OpticalFlowComputeSCDDivergencePass : OpticalFlowPass
{
    /// <summary>
    /// Creates the pass from the <c>computeScdDivergence</c> shader of the context description.
    /// </summary>
    public OpticalFlowComputeSCDDivergencePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
        : base(resources, constants)
    {
        InitComputeShader("Compute SCD Divergence", contextDescription.shaders.computeScdDivergence);
    }

    /// <summary>
    /// Binds the current and previous SCD histograms, the SCD scratch texture and the
    /// caller-supplied SCD output, then records the dispatch. <paramref name="frameIndex"/>
    /// and <paramref name="level"/> are not used by this pass.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
    {
        var shader = ComputeShader;
        int kernel = KernelIndex;
        var res = Resources;

        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, res.OpticalFlowSCDHistogram);
        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlowScdPreviousHistogram, res.OpticalFlowSCDPreviousHistogram);
        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlowScdTemp, res.OpticalFlowSCDTemp);
        commandBuffer.SetComputeResourceParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD);

        commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(shader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
        commandBuffer.DispatchCompute(shader, kernel, dispatchX, dispatchY, dispatchZ);
    }
}
|
|
|
|
internal class OpticalFlowComputePass : OpticalFlowPass
{
    /// <summary>
    /// Creates the pass from the <c>computeOpticalFlow</c> shader of the context description.
    /// </summary>
    public OpticalFlowComputePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
        : base(resources, constants)
    {
        InitComputeShader("Optical Flow Search", contextDescription.shaders.computeOpticalFlow);
    }

    /// <summary>
    /// Binds the current- and other-frame input textures of <paramref name="level"/>, the
    /// flow texture for this level/frame slot and the SCD output, then records the dispatch.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
    {
        // The flow slot flips with both frame parity and level parity. `frameIndex ^ 1`
        // below selects the other input slot, so frameIndex is presumably 0 or 1.
        int flowSlot = frameIndex ^ (level & 1);

        var shader = ComputeShader;
        int kernel = KernelIndex;
        var inputs = Resources.OpticalFlowInputLevels[level];

        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.SrvOpticalFlowInput, inputs[frameIndex]);
        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, inputs[frameIndex ^ 1]);
        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlow, Resources.OpticalFlowLevels[level][flowSlot]);
        commandBuffer.SetComputeResourceParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD);

        commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(shader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
        commandBuffer.DispatchCompute(shader, kernel, dispatchX, dispatchY, dispatchZ);
    }
}
|
|
|
|
internal class OpticalFlowFilterPass : OpticalFlowPass
{
    /// <summary>
    /// Creates the pass from the <c>filterOpticalFlow</c> shader of the context description.
    /// </summary>
    public OpticalFlowFilterPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
        : base(resources, constants)
    {
        InitComputeShader("Optical Flow Filter", contextDescription.shaders.filterOpticalFlow);
    }

    /// <summary>
    /// Reads the flow texture of <paramref name="level"/> in slot <c>frameIndex ^ (level &amp; 1)</c>.
    /// At level 0 it writes to the caller's optical-flow vector output; at other levels it
    /// writes to the opposite slot of the same level.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
    {
        int readSlot = frameIndex ^ (level & 1);

        var shader = ComputeShader;
        int kernel = KernelIndex;
        var flows = Resources.OpticalFlowLevels[level];

        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.SrvOpticalFlowPrevious, flows[readSlot]);

        if (level != 0)
        {
            // Intermediate level: ping-pong into the other slot of this level.
            commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlow, flows[readSlot ^ 1]);
        }
        else
        {
            // Final output (levels are counted in reverse)
            commandBuffer.SetComputeResourceParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlow, dispatchParams.OpticalFlowVector);
        }

        commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(shader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
        commandBuffer.DispatchCompute(shader, kernel, dispatchX, dispatchY, dispatchZ);
    }
}
|
|
|
|
internal class OpticalFlowScalePass : OpticalFlowPass
{
    /// <summary>
    /// Creates the pass from the <c>scaleOpticalFlow</c> shader of the context description.
    /// </summary>
    public OpticalFlowScalePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
        : base(resources, constants)
    {
        InitComputeShader("Optical Flow Scale", contextDescription.shaders.scaleOpticalFlow);
    }

    /// <summary>
    /// Reads the flow of <paramref name="level"/> and writes into the flow texture of
    /// <c>level - 1</c>. Records nothing when <paramref name="level"/> is 0 or negative,
    /// since there is no lower level to write to.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
    {
        if (level <= 0)
        {
            return;
        }

        // Opposite of this level's slot (frameIndex ^ (level & 1)). This equals
        // frameIndex ^ ((level - 1) & 1), so the same slot index is used on both levels.
        int targetSlot = (frameIndex ^ (level & 1)) ^ 1;

        var shader = ComputeShader;
        int kernel = KernelIndex;
        var inputs = Resources.OpticalFlowInputLevels[level];

        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.SrvOpticalFlowInput, inputs[frameIndex]);
        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, inputs[frameIndex ^ 1]);
        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.SrvOpticalFlow, Resources.OpticalFlowLevels[level][targetSlot]);

        commandBuffer.SetComputeTextureParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlowNextLevel, Resources.OpticalFlowLevels[level - 1][targetSlot]);
        commandBuffer.SetComputeResourceParam(shader, kernel, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD);

        commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(shader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
        commandBuffer.DispatchCompute(shader, kernel, dispatchX, dispatchY, dispatchZ);
    }
}
|
|
}
|