You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

204 lines
13 KiB

using System;
using System.Runtime.InteropServices;
using UnityEngine;
using UnityEngine.Profiling;
using UnityEngine.Rendering;
namespace FidelityFX.FrameGen
{
internal abstract class OpticalFlowPass: FfxPassBase<OpticalFlow.DispatchDescription>
{
protected readonly OpticalFlowResources Resources;
protected readonly ComputeBuffer Constants;
protected OpticalFlowPass(OpticalFlowResources resources, ComputeBuffer constants)
: base("Optical Flow")
{
Resources = resources;
Constants = constants;
}
public void ScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ = 1)
{
commandBuffer.BeginSample(Sampler);
DoScheduleDispatch(commandBuffer, dispatchParams, bufferIndex, level, dispatchX, dispatchY, dispatchZ);
commandBuffer.EndSample(Sampler);
}
protected abstract void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ);
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int dispatchX, int dispatchY, int dispatchZ)
{
// Optical Flow dispatch requires an extra `level` parameter, so we don't use this overload
}
}
internal sealed class OpticalFlowPrepareLumaPass : OpticalFlowPass
{
public OpticalFlowPrepareLumaPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{
InitComputeShader("Prepare Luma", contextDescription.shaders.prepareLuma);
}
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvInputColor, dispatchParams.Color);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][bufferIndex]);
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
}
}
internal sealed class OpticalFlowGenerateInputPyramidPass : OpticalFlowPass
{
private ComputeBuffer _spdConstants;
public OpticalFlowGenerateInputPyramidPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants, ComputeBuffer spdConstants)
: base(resources, constants)
{
_spdConstants = spdConstants;
InitComputeShader("Generate Optical Flow Input Pyramid", contextDescription.shaders.generateOpticalFlowInputPyramid);
}
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, Resources.OpticalFlowInputLevels[0][bufferIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel1, Resources.OpticalFlowInputLevels[1][bufferIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel2, Resources.OpticalFlowInputLevels[2][bufferIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel3, Resources.OpticalFlowInputLevels[3][bufferIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel4, Resources.OpticalFlowInputLevels[4][bufferIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel5, Resources.OpticalFlowInputLevels[5][bufferIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel6, Resources.OpticalFlowInputLevels[6][bufferIndex]);
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.SpdConstants>(ComputeShader, OpticalFlowShaderIDs.CbSpd, _spdConstants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
}
}
internal sealed class OpticalFlowGenerateSCDHistogramPass : OpticalFlowPass
{
public OpticalFlowGenerateSCDHistogramPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{
InitComputeShader("Generate SCD Histogram", contextDescription.shaders.generateScdHistogram);
}
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[0][bufferIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram);
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
}
}
internal sealed class OpticalFlowComputeSCDDivergencePass : OpticalFlowPass
{
public OpticalFlowComputeSCDDivergencePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{
InitComputeShader("Compute SCD Divergence", contextDescription.shaders.computeScdDivergence);
}
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, Resources.OpticalFlowSCDHistogram);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdPreviousHistogram, Resources.OpticalFlowSCDPreviousHistogram);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdTemp, Resources.OpticalFlowSCDTemp);
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD);
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
}
}
internal sealed class OpticalFlowComputePass : OpticalFlowPass
{
public OpticalFlowComputePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{
InitComputeShader("Optical Flow Search", contextDescription.shaders.computeOpticalFlow);
}
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{
int levelIndex = bufferIndex ^ (level & 1);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][bufferIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][bufferIndex ^ 1]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex]);
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD);
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
}
}
internal sealed class OpticalFlowFilterPass : OpticalFlowPass
{
public OpticalFlowFilterPass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{
InitComputeShader("Optical Flow Filter", contextDescription.shaders.filterOpticalFlow);
}
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{
int levelIndex = bufferIndex ^ (level & 1);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPrevious, Resources.OpticalFlowLevels[level][levelIndex]);
if (level == 0)
{
// Final output (levels are counted in reverse)
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, dispatchParams.OpticalFlowVector);
}
else
{
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex ^ 1]);
}
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
}
}
internal sealed class OpticalFlowScalePass : OpticalFlowPass
{
public OpticalFlowScalePass(in OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
: base(resources, constants)
{
InitComputeShader("Optical Flow Scale", contextDescription.shaders.scaleOpticalFlow);
}
protected override void DoScheduleDispatch(CommandBuffer commandBuffer, in OpticalFlow.DispatchDescription dispatchParams, int bufferIndex, int level, int dispatchX, int dispatchY, int dispatchZ)
{
if (level <= 0)
return;
int levelIndex = bufferIndex ^ (level & 1);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, Resources.OpticalFlowInputLevels[level][bufferIndex]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowPreviousInput, Resources.OpticalFlowInputLevels[level][bufferIndex ^ 1]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlow, Resources.OpticalFlowLevels[level][levelIndex ^ 1]);
commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowNextLevel, Resources.OpticalFlowLevels[level - 1][levelIndex ^ 1]);
commandBuffer.SetComputeResourceParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput, dispatchParams.OpticalFlowSCD);
commandBuffer.SetComputeConstantBufferParam<OpticalFlow.OpticalFlowConstants>(ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants);
commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
}
}
}