You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

183 lines
10 KiB

using System;
using System.Runtime.InteropServices;
using UnityEngine;
using UnityEngine.Profiling;
using UnityEngine.Rendering;
namespace FidelityFX.OpticalFlow
{
/// <summary>
/// Base class for a single Optical Flow compute pass. Stores the context description, resources and
/// constant buffer passed to the constructor, holds the compute shader and kernel index configured by
/// <see cref="InitComputeShader"/>, and wraps every dispatch in a profiler sample.
/// </summary>
internal abstract class OpticalFlowPass: IDisposable
{
    protected readonly OpticalFlow.ContextDescription ContextDescription;
    protected readonly OpticalFlowResources Resources;
    protected readonly ComputeBuffer Constants;

    // The three members below are only assigned by InitComputeShader; they keep their default values
    // when a derived class never calls it.
    protected ComputeShader ComputeShader;
    protected int KernelIndex;
    private CustomSampler _sampler;

    /// <summary>
    /// Stores the supplied context description, resources and constant buffer for use by derived passes.
    /// </summary>
    /// <param name="contextDescription">Context description kept in <see cref="ContextDescription"/>.</param>
    /// <param name="resources">Resource container kept in <see cref="Resources"/>.</param>
    /// <param name="constants">Constant buffer kept in <see cref="Constants"/>.</param>
    protected OpticalFlowPass(OpticalFlow.ContextDescription contextDescription, OpticalFlowResources resources, ComputeBuffer constants)
    {
        ContextDescription = contextDescription;
        Resources = resources;
        Constants = constants;
    }

    /// <summary>
    /// Does nothing in the base class; derived passes can override it.
    /// </summary>
    public virtual void Dispose()
    {
    }

    /// <summary>
    /// Records this pass into <paramref name="commandBuffer"/>, surrounded by a begin/end profiler sample.
    /// </summary>
    /// <param name="commandBuffer">Command buffer that receives the sample markers and the pass commands.</param>
    /// <param name="dispatchParams">Dispatch description forwarded to <see cref="DoScheduleDispatch"/>.</param>
    /// <param name="frameIndex">Frame index forwarded to <see cref="DoScheduleDispatch"/>.</param>
    /// <param name="dispatchX">Dispatch size in X, forwarded unchanged.</param>
    /// <param name="dispatchY">Dispatch size in Y, forwarded unchanged.</param>
    /// <param name="dispatchZ">Dispatch size in Z, forwarded unchanged; defaults to 1.</param>
    public void ScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ = 1)
    {
        commandBuffer.BeginSample(_sampler);
        try
        {
            DoScheduleDispatch(commandBuffer, dispatchParams, frameIndex, dispatchX, dispatchY, dispatchZ);
        }
        finally
        {
            // Close the sample even when the pass throws, so the begin/end markers already recorded
            // in the command buffer stay paired. The exception itself still propagates to the caller.
            commandBuffer.EndSample(_sampler);
        }
    }

    /// <summary>
    /// Implemented by each pass to bind its resources and record its dispatch. Called by
    /// <see cref="ScheduleDispatch"/> between the begin and end sample markers.
    /// </summary>
    protected abstract void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ);

    /// <summary>
    /// Assigns the compute shader for this pass, looks up its "CS" kernel and creates the profiler
    /// sampler named after the pass.
    /// </summary>
    /// <param name="passName">Name used for the profiler sampler and in the error message.</param>
    /// <param name="shader">Compute shader to use for this pass.</param>
    /// <exception cref="MissingReferenceException">Thrown when <paramref name="shader"/> is null.</exception>
    protected void InitComputeShader(string passName, ComputeShader shader)
    {
        if (shader == null)
        {
            throw new MissingReferenceException($"Shader for Optical Flow pass '{passName}' could not be loaded! Please ensure it is included in the project correctly.");
        }

        ComputeShader = shader;
        KernelIndex = ComputeShader.FindKernel("CS");
        _sampler = CustomSampler.Create(passName);
    }
}
/// <summary>
/// Optical Flow pass driven by the "Prepare Luma" shader. It binds the dispatch's color input together
/// with level 0 of the optical flow input for the requested frame index, then dispatches the kernel.
/// </summary>
internal class OpticalFlowPrepareLumaPass : OpticalFlowPass
{
    /// <summary>
    /// Forwards the arguments to the base class and initialises the pass with the prepare-luma shader
    /// taken from the context description.
    /// </summary>
    public OpticalFlowPrepareLumaPass(
        OpticalFlow.ContextDescription contextDescription,
        OpticalFlowResources resources,
        ComputeBuffer constants)
        : base(contextDescription, resources, constants)
    {
        InitComputeShader("Prepare Luma", contextDescription.shaders.prepareLuma);
    }

    /// <summary>
    /// Records the texture bindings, the constant buffer binding and the dispatch for this pass.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
    {
        // Input: the color resource view supplied by the caller.
        ref var inputColor = ref dispatchParams.color;
        commandBuffer.SetComputeTextureParam(
            ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvInputColor,
            inputColor.RenderTarget, inputColor.MipLevel, inputColor.SubElement);

        // Level 0 of the optical flow input, selected by frame index.
        var baseLevelInput = Resources.OpticalFlowInputLevels[0][frameIndex];
        commandBuffer.SetComputeTextureParam(
            ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, baseLevelInput);

        int constantsSize = Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>();
        commandBuffer.SetComputeConstantBufferParam(
            ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, constantsSize);

        commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
    }
}
/// <summary>
/// Optical Flow pass driven by the "Generate Optical Flow Input Pyramid" shader. It binds levels 0
/// through 6 of the optical flow input for the requested frame index, plus the optical flow and SPD
/// constant buffers, then dispatches the kernel.
/// </summary>
internal class OpticalFlowGenerateInputPyramidPass : OpticalFlowPass
{
    // SPD constant buffer handed in by the creator of this pass.
    private ComputeBuffer _spdConstants;

    /// <summary>
    /// Forwards the common arguments to the base class, keeps the SPD constant buffer and initialises
    /// the pass with the input-pyramid shader taken from the context description.
    /// </summary>
    public OpticalFlowGenerateInputPyramidPass(
        OpticalFlow.ContextDescription contextDescription,
        OpticalFlowResources resources,
        ComputeBuffer constants,
        ComputeBuffer spdConstants)
        : base(contextDescription, resources, constants)
    {
        _spdConstants = spdConstants;
        InitComputeShader("Generate Optical Flow Input Pyramid", contextDescription.shaders.generateOpticalFlowInputPyramid);
    }

    /// <summary>
    /// Records the seven input-level bindings, both constant buffer bindings and the dispatch.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
    {
        var inputLevels = Resources.OpticalFlowInputLevels;

        // One binding per input level, each selected by frame index.
        commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInput, inputLevels[0][frameIndex]);
        commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel1, inputLevels[1][frameIndex]);
        commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel2, inputLevels[2][frameIndex]);
        commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel3, inputLevels[3][frameIndex]);
        commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel4, inputLevels[4][frameIndex]);
        commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel5, inputLevels[5][frameIndex]);
        commandBuffer.SetComputeTextureParam(ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowInputLevel6, inputLevels[6][frameIndex]);

        int opticalFlowConstantsSize = Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>();
        commandBuffer.SetComputeConstantBufferParam(
            ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, opticalFlowConstantsSize);

        int spdConstantsSize = Marshal.SizeOf<OpticalFlow.SpdConstants>();
        commandBuffer.SetComputeConstantBufferParam(
            ComputeShader, OpticalFlowShaderIDs.CbSpd, _spdConstants, 0, spdConstantsSize);

        commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
    }
}
/// <summary>
/// Optical Flow pass driven by the "Generate SCD Histogram" shader. It binds level 0 of the optical
/// flow input for the requested frame index and the SCD histogram resource, then dispatches the kernel.
/// </summary>
internal class OpticalFlowGenerateSCDHistogramPass : OpticalFlowPass
{
    /// <summary>
    /// Forwards the arguments to the base class and initialises the pass with the SCD histogram
    /// shader taken from the context description.
    /// </summary>
    public OpticalFlowGenerateSCDHistogramPass(
        OpticalFlow.ContextDescription contextDescription,
        OpticalFlowResources resources,
        ComputeBuffer constants)
        : base(contextDescription, resources, constants)
    {
        InitComputeShader("Generate SCD Histogram", contextDescription.shaders.generateScdHistogram);
    }

    /// <summary>
    /// Records the texture bindings, the constant buffer binding and the dispatch for this pass.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
    {
        // TODO: this presumably has to be this frame's input (the output of the preceding passes);
        // still needs to be double-checked.
        var baseLevelInput = Resources.OpticalFlowInputLevels[0][frameIndex];
        commandBuffer.SetComputeTextureParam(
            ComputeShader, KernelIndex, OpticalFlowShaderIDs.SrvOpticalFlowInput, baseLevelInput);

        var histogram = Resources.OpticalFlowSCDHistogram;
        commandBuffer.SetComputeTextureParam(
            ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram, histogram);

        int constantsSize = Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>();
        commandBuffer.SetComputeConstantBufferParam(
            ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, constantsSize);

        commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
    }
}
/// <summary>
/// Optical Flow pass driven by the "Compute SCD Divergence" shader. It binds the SCD histogram, the
/// previous SCD histogram, the SCD temp resource and the dispatch's SCD output, then dispatches the kernel.
/// </summary>
internal class OpticalFlowComputeSCDDivergencePass : OpticalFlowPass
{
    /// <summary>
    /// Forwards the arguments to the base class and initialises the pass with the SCD divergence
    /// shader taken from the context description.
    /// </summary>
    public OpticalFlowComputeSCDDivergencePass(
        OpticalFlow.ContextDescription contextDescription,
        OpticalFlowResources resources,
        ComputeBuffer constants)
        : base(contextDescription, resources, constants)
    {
        InitComputeShader("Compute SCD Divergence", contextDescription.shaders.computeScdDivergence);
    }

    /// <summary>
    /// Records the four texture bindings, the constant buffer binding and the dispatch for this pass.
    /// </summary>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
    {
        // Output view supplied by the caller through the dispatch description.
        ref var scdTarget = ref dispatchParams.opticalFlowSCD;

        // Resources owned by the shared resource container.
        commandBuffer.SetComputeTextureParam(
            ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdHistogram,
            Resources.OpticalFlowSCDHistogram);
        commandBuffer.SetComputeTextureParam(
            ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdPreviousHistogram,
            Resources.OpticalFlowSCDPreviousHistogram);
        commandBuffer.SetComputeTextureParam(
            ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdTemp,
            Resources.OpticalFlowSCDTemp);

        commandBuffer.SetComputeTextureParam(
            ComputeShader, KernelIndex, OpticalFlowShaderIDs.UavOpticalFlowScdOutput,
            scdTarget.RenderTarget, scdTarget.MipLevel, scdTarget.SubElement);

        int constantsSize = Marshal.SizeOf<OpticalFlow.OpticalFlowConstants>();
        commandBuffer.SetComputeConstantBufferParam(
            ComputeShader, OpticalFlowShaderIDs.CbOpticalFlow, Constants, 0, constantsSize);

        commandBuffer.DispatchCompute(ComputeShader, KernelIndex, dispatchX, dispatchY, dispatchZ);
    }
}
/// <summary>
/// Placeholder for the optical flow compute pass. The constructor only forwards to the base class
/// (it does not call InitComputeShader), and recording a dispatch is not implemented yet.
/// </summary>
internal class OpticalFlowComputePass : OpticalFlowPass
{
    /// <summary>
    /// Forwards the arguments to the base class; no shader is initialised here.
    /// </summary>
    public OpticalFlowComputePass(
        OpticalFlow.ContextDescription contextDescription,
        OpticalFlowResources resources,
        ComputeBuffer constants)
        : base(contextDescription, resources, constants)
    {
    }

    /// <summary>
    /// Not implemented.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        => throw new NotImplementedException();
}
/// <summary>
/// Placeholder for the optical flow filter pass. The constructor only forwards to the base class
/// (it does not call InitComputeShader), and recording a dispatch is not implemented yet.
/// </summary>
internal class OpticalFlowFilterPass : OpticalFlowPass
{
    /// <summary>
    /// Forwards the arguments to the base class; no shader is initialised here.
    /// </summary>
    public OpticalFlowFilterPass(
        OpticalFlow.ContextDescription contextDescription,
        OpticalFlowResources resources,
        ComputeBuffer constants)
        : base(contextDescription, resources, constants)
    {
    }

    /// <summary>
    /// Not implemented.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        => throw new NotImplementedException();
}
/// <summary>
/// Placeholder for the optical flow scale pass. The constructor only forwards to the base class
/// (it does not call InitComputeShader), and recording a dispatch is not implemented yet.
/// </summary>
internal class OpticalFlowScalePass : OpticalFlowPass
{
    /// <summary>
    /// Forwards the arguments to the base class; no shader is initialised here.
    /// </summary>
    public OpticalFlowScalePass(
        OpticalFlow.ContextDescription contextDescription,
        OpticalFlowResources resources,
        ComputeBuffer constants)
        : base(contextDescription, resources, constants)
    {
    }

    /// <summary>
    /// Not implemented.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    protected override void DoScheduleDispatch(CommandBuffer commandBuffer, OpticalFlow.DispatchDescription dispatchParams, int frameIndex, int dispatchX, int dispatchY, int dispatchZ)
        => throw new NotImplementedException();
}
}