using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using UnityEngine;

namespace FidelityFX
{
    /// <summary>
    /// Unity-side FSR2 context: owns the FSR2 compute shaders and constant buffers and fills/dispatches them
    /// once per frame from <see cref="Dispatch"/>.
    /// Several FSR2 passes are still pending (see the TODO notes); currently only the constant setup and the
    /// RCAS sharpening pass (or a plain blit when sharpening is disabled) are executed.
    /// </summary>
    public class Fsr2Context
    {
        private const int MaxQueuedFrames = 16;

        /// <summary>
        /// Equivalent of C's FLT_EPSILON (2^-23), as used by the reference FSR2 implementation.
        /// Note that Mathf.Epsilon / float.Epsilon is the smallest non-zero float instead, which vanishes when
        /// added to values like 1.0f and therefore cannot serve as a guard against division by zero here.
        /// </summary>
        private const float FltEpsilon = 1.192092896e-07f;

        private Fsr2.ContextDescription _contextDescription;

        private ComputeShader _prepareInputColorShader;
        private ComputeShader _depthClipShader;
        private ComputeShader _reconstructPreviousDepthShader;
        private ComputeShader _lockShader;
        private ComputeShader _accumulateShader;
        private ComputeShader _generateReactiveShader;
        private ComputeShader _rcasShader;
        private ComputeShader _computeLuminancePyramidShader;
        private ComputeShader _tcrAutogenShader;

        // Each constant block lives in a one-element array so it can be uploaded with ComputeBuffer.SetData,
        // while the ref-returning properties allow in-place modification of the single element.
        private ComputeBuffer _fsr2ConstantsBuffer;
        private readonly Fsr2Constants[] _fsr2ConstantsArray = { new Fsr2Constants() };
        private ref Fsr2Constants Constants => ref _fsr2ConstantsArray[0];

        private ComputeBuffer _spdConstantsBuffer;
        private readonly SpdConstants[] _spdConstantsArray = { new SpdConstants() };
        private ref SpdConstants SpdConsts => ref _spdConstantsArray[0];

        private ComputeBuffer _rcasConstantsBuffer;
        private readonly RcasConstants[] _rcasConstantsArray = new RcasConstants[1];
        private ref RcasConstants RcasConsts => ref _rcasConstantsArray[0];

        private ComputeBuffer _generateReactiveConstantsBuffer;
        private readonly GenerateReactiveConstants[] _generateReactiveConstantsArray = { new GenerateReactiveConstants() };
        private ref GenerateReactiveConstants GenReactiveConsts => ref _generateReactiveConstantsArray[0];

        private bool _firstExecution;
        private Vector2 _previousJitterOffset;
        private uint _resourceFrameIndex;

        /// <summary>
        /// Initializes the context: allocates the constant buffers, resets per-context state and loads the
        /// compute shaders through the callbacks supplied in <paramref name="contextDescription"/>.
        /// </summary>
        /// <param name="contextDescription">Display/max render sizes, feature flags and shader loading callbacks.</param>
        public void Create(Fsr2.ContextDescription contextDescription)
        {
            _contextDescription = contextDescription;

            _fsr2ConstantsBuffer = CreateConstantBuffer<Fsr2Constants>();
            _spdConstantsBuffer = CreateConstantBuffer<SpdConstants>();
            _rcasConstantsBuffer = CreateConstantBuffer<RcasConstants>();
            _generateReactiveConstantsBuffer = CreateConstantBuffer<GenerateReactiveConstants>();

            // Set defaults
            _firstExecution = true;
            _resourceFrameIndex = 0;

            Constants.displaySize = _contextDescription.DisplaySize;

            // Generate the data for the Lanczos2 LUT as signed 16-bit normalized weights.
            // NOTE(review): the weights are not uploaded anywhere yet; that is part of the resource TODO below.
            const uint lanczos2LutWidth = 128;
            short[] lanczos2Weights = new short[lanczos2LutWidth];
            for (uint currentLanczosWidthIndex = 0; currentLanczosWidthIndex < lanczos2LutWidth; ++currentLanczosWidthIndex)
            {
                float x = 2.0f * currentLanczosWidthIndex / (lanczos2LutWidth - 1);
                float y = Fsr2.Lanczos2(x);
                lanczos2Weights[currentLanczosWidthIndex] = (short)Mathf.Round(y * 32767.0f);
            }

            InitShaders();

            // TODO: create resources, i.e. render textures used for intermediate results.
            // Note that "aliasable" resources should be equivalent to GetTemporary render textures
            // UAVs *may* be an issue with the PS4 not handling simultaneous reading and writing to an RT properly
            // Unity does have Graphics.SetRandomWriteTarget for enabling UAV on ComputeBuffers or RTs
            // Unity doesn't do 1D textures so just default to Texture2D
        }

        /// <summary>Loads every FSR2 compute shader that has not been loaded yet.</summary>
        private void InitShaders()
        {
            LoadComputeShader("FSR2/ffx_fsr2_compute_luminance_pyramid_pass", ref _computeLuminancePyramidShader);
            LoadComputeShader("FSR2/ffx_fsr2_rcas_pass", ref _rcasShader);
            LoadComputeShader("FSR2/ffx_fsr2_prepare_input_color_pass", ref _prepareInputColorShader);
            LoadComputeShader("FSR2/ffx_fsr2_depth_clip_pass", ref _depthClipShader);
            LoadComputeShader("FSR2/ffx_fsr2_reconstruct_previous_depth_pass", ref _reconstructPreviousDepthShader);
            LoadComputeShader("FSR2/ffx_fsr2_lock_pass", ref _lockShader);
            LoadComputeShader("FSR2/ffx_fsr2_accumulate_pass", ref _accumulateShader);
            LoadComputeShader("FSR2/ffx_fsr2_autogen_reactive_pass", ref _generateReactiveShader);
            LoadComputeShader("FSR2/ffx_fsr2_tcr_autogen_pass", ref _tcrAutogenShader);
        }

        /// <summary>
        /// Runs one frame: updates all constant buffers and executes the implemented passes, writing the
        /// result into <c>dispatchParams.Output</c>.
        /// </summary>
        /// <param name="dispatchParams">Per-frame inputs (textures, camera data, jitter, sharpening settings).</param>
        public void Dispatch(Fsr2.DispatchDescription dispatchParams)
        {
            if (_firstExecution)
            {
                // TODO: clear values
            }

            // TODO: setup resource indices for buffers that get swapped per frame

            bool resetAccumulation = dispatchParams.Reset || _firstExecution;
            _firstExecution = false;

            // TODO Register resources: ...

            SetupConstants(dispatchParams, resetAccumulation);

            // Reactive mask bias
            // Thread group counts for the render- and display-resolution passes; these are not consumed yet
            // and are kept for the pending passes.
            const int threadGroupWorkRegionDim = 8;
            int dispatchSrcX = DivideRoundingUp(Constants.renderSize.x, threadGroupWorkRegionDim);
            int dispatchSrcY = DivideRoundingUp(Constants.renderSize.y, threadGroupWorkRegionDim);
            int dispatchDstX = DivideRoundingUp(_contextDescription.DisplaySize.x, threadGroupWorkRegionDim);
            int dispatchDstY = DivideRoundingUp(_contextDescription.DisplaySize.y, threadGroupWorkRegionDim);

            if (resetAccumulation)
            {
                // TODO: clear reconstructed depth for max depth store
            }

            // Auto exposure
            SetupSpdConstants(Constants.renderSize);

            if (dispatchParams.EnableSharpening)
            {
                SetupRcasConstants(dispatchParams);

                // Run the RCAS sharpening filter on the upscaled image
                int rcasKernel = _rcasShader.FindKernel("CS");
                _rcasShader.SetTexture(rcasKernel, "r_input_exposure", dispatchParams.Exposure);
                _rcasShader.SetTexture(rcasKernel, "r_rcas_input", dispatchParams.Input);
                _rcasShader.SetTexture(rcasKernel, "rw_upscaled_output", dispatchParams.Output);
                _rcasShader.SetConstantBuffer("cbFSR2", _fsr2ConstantsBuffer, 0, Marshal.SizeOf<Fsr2Constants>());
                _rcasShader.SetConstantBuffer("cbRCAS", _rcasConstantsBuffer, 0, Marshal.SizeOf<RcasConstants>());

                // RCAS writes the display-resolution output, so size the dispatch from the display size rather
                // than the screen, which differs whenever the output is an off-screen target of another size.
                const int threadGroupWorkRegionDimRcas = 16;
                int threadGroupsX = DivideRoundingUp(_contextDescription.DisplaySize.x, threadGroupWorkRegionDimRcas);
                int threadGroupsY = DivideRoundingUp(_contextDescription.DisplaySize.y, threadGroupWorkRegionDimRcas);
                _rcasShader.Dispatch(rcasKernel, threadGroupsX, threadGroupsY, 1);
            }
            else
            {
                Graphics.Blit(dispatchParams.Input, dispatchParams.Output);
            }

            _resourceFrameIndex = (_resourceFrameIndex + 1) % MaxQueuedFrames;

            // TODO Unregister resources: release temp RT's
        }

        /// <summary>
        /// Fills the main FSR2 constant block from the dispatch parameters and uploads it to the GPU.
        /// </summary>
        /// <param name="dispatchParams">Per-frame inputs.</param>
        /// <param name="resetAccumulation">When true, the frame index and jitter phase count restart.</param>
        private void SetupConstants(Fsr2.DispatchDescription dispatchParams, bool resetAccumulation)
        {
            ref Fsr2Constants constants = ref Constants;

            constants.jitterOffset = dispatchParams.JitterOffset;

            // A non-positive render size component means "use the full input texture size".
            constants.renderSize = new Vector2Int(
                dispatchParams.RenderSize.x > 0 ? dispatchParams.RenderSize.x : dispatchParams.Input.width,
                dispatchParams.RenderSize.y > 0 ? dispatchParams.RenderSize.y : dispatchParams.Input.height);
            constants.maxRenderSize = _contextDescription.MaxRenderSize;
            constants.inputColorResourceDimensions = new Vector2Int(dispatchParams.Input.width, dispatchParams.Input.height);

            // Compute the horizontal FOV for the shader from the vertical one.
            // Uses the resolved render size so the fallback above cannot produce a 0/0 aspect ratio.
            float aspectRatio = (float)constants.renderSize.x / constants.renderSize.y;
            float cameraAngleHorizontal = Mathf.Atan(Mathf.Tan(dispatchParams.CameraFovAngleVertical / 2.0f) * aspectRatio) * 2.0f;
            constants.tanHalfFOV = Mathf.Tan(cameraAngleHorizontal * 0.5f);
            constants.viewSpaceToMetersFactor = (dispatchParams.ViewSpaceToMetersFactor > 0.0f) ? dispatchParams.ViewSpaceToMetersFactor : 1.0f;

            // Compute params to enable device depth to view space depth computation in shader
            constants.deviceToViewDepth = SetupDeviceDepthToViewSpaceDepthParams(dispatchParams, constants.renderSize);

            // To be updated if resource is larger than the actual image size
            constants.downscaleFactor = new Vector2(
                (float)constants.renderSize.x / _contextDescription.DisplaySize.x,
                (float)constants.renderSize.y / _contextDescription.DisplaySize.y);
            constants.previousFramePreExposure = constants.preExposure;
            constants.preExposure = (dispatchParams.PreExposure != 0) ? dispatchParams.PreExposure : 1.0f;

            // Motion vector data
            Vector2Int motionVectorsTargetSize = (_contextDescription.Flags & Fsr2.InitializationFlags.EnableDisplayResolutionMotionVectors) != 0
                ? constants.displaySize
                : constants.renderSize;
            constants.motionVectorScale = dispatchParams.MotionVectorScale / motionVectorsTargetSize;

            // Compute jitter cancellation
            if ((_contextDescription.Flags & Fsr2.InitializationFlags.EnableMotionVectorsJitterCancellation) != 0)
            {
                constants.motionVectorJitterCancellation = (_previousJitterOffset - constants.jitterOffset) / motionVectorsTargetSize;
                _previousJitterOffset = constants.jitterOffset;
            }

            // After a reset jump straight to the target phase count; otherwise move towards it one step per frame.
            int jitterPhaseCount = Fsr2.GetJitterPhaseCount(constants.renderSize.x, _contextDescription.DisplaySize.x);
            if (resetAccumulation || constants.jitterPhaseCount == 0)
            {
                constants.jitterPhaseCount = jitterPhaseCount;
            }
            else
            {
                int jitterPhaseCountDelta = (int)(jitterPhaseCount - constants.jitterPhaseCount);
                if (jitterPhaseCountDelta > 0)
                {
                    constants.jitterPhaseCount++;
                }
                else if (jitterPhaseCountDelta < 0)
                {
                    constants.jitterPhaseCount--;
                }
            }

            // Convert delta time to seconds and clamp to [0, 1]
            constants.deltaTime = Mathf.Clamp01(dispatchParams.FrameTimeDelta / 1000.0f);

            if (resetAccumulation)
            {
                constants.frameIndex = 0;
            }
            else
            {
                constants.frameIndex++;
            }

            // Shading change usage of the SPD mip levels
            constants.lumaMipLevelToUse = 4; // NOTE: this is derived from a bunch of auto-generated constant values in the FSR2 code

            float mipDiv = 2 << constants.lumaMipLevelToUse;
            constants.lumaMipDimensions.x = (int)(constants.maxRenderSize.x / mipDiv);
            constants.lumaMipDimensions.y = (int)(constants.maxRenderSize.y / mipDiv);

            _fsr2ConstantsBuffer.SetData(_fsr2ConstantsArray);
        }

        /// <summary>
        /// Builds the factors the shaders use to turn device depth into view-space depth.
        /// </summary>
        /// <param name="dispatchParams">Supplies camera near/far planes and vertical FOV (radians).</param>
        /// <param name="renderSize">Resolved render size, used for the aspect ratio.</param>
        /// <returns>
        /// x/y: depth transform terms selected by the inverted/infinite depth flags;
        /// z/w: aspect / cot(fovY/2) and 1 / cot(fovY/2).
        /// </returns>
        private Vector4 SetupDeviceDepthToViewSpaceDepthParams(Fsr2.DispatchDescription dispatchParams, Vector2Int renderSize)
        {
            bool inverted = (_contextDescription.Flags & Fsr2.InitializationFlags.EnableDepthInverted) != 0;
            bool infinite = (_contextDescription.Flags & Fsr2.InitializationFlags.EnableDepthInfinite) != 0;

            // make sure it has no impact if near and far plane values are swapped in dispatch params
            // the flags "inverted" and "infinite" will decide what transform to use
            float min = Mathf.Min(dispatchParams.CameraNear, dispatchParams.CameraFar);
            float max = Mathf.Max(dispatchParams.CameraNear, dispatchParams.CameraFar);

            if (inverted)
            {
                (min, max) = (max, min);
            }

            float q = max / (min - max);
            const float d = -1.0f;

            // Components are indexed by (inverted ? 2 : 0) + (infinite ? 1 : 0):
            // non-inverted finite, non-inverted infinite, inverted finite, inverted infinite.
            // The epsilon keeps the infinite-projection terms away from an exact division by zero in the shader;
            // it must be FLT_EPSILON-sized, as a denormal epsilon would be rounded away.
            Vector4 matrixElemC = new Vector4(q, -1.0f - FltEpsilon, q, 0.0f + FltEpsilon);
            Vector4 matrixElemE = new Vector4(q * min, -min - FltEpsilon, q * min, max);

            // Revert x and y coords
            float aspect = (float)renderSize.x / renderSize.y;
            float cotHalfFovY = Mathf.Cos(0.5f * dispatchParams.CameraFovAngleVertical) / Mathf.Sin(0.5f * dispatchParams.CameraFovAngleVertical);

            int matrixIndex = (inverted ? 2 : 0) + (infinite ? 1 : 0);
            return new Vector4(
                d * matrixElemC[matrixIndex],
                matrixElemE[matrixIndex],
                aspect / cotHalfFovY,
                1.0f / cotHalfFovY);
        }

        /// <summary>
        /// Picks the precomputed RCAS constants closest to the requested sharpness (clamped to [0, 1]) and uploads them.
        /// </summary>
        private void SetupRcasConstants(Fsr2.DispatchDescription dispatchParams)
        {
            int sharpnessIndex = Mathf.RoundToInt(Mathf.Clamp01(dispatchParams.Sharpness) * (RcasConfigs.Count - 1));
            RcasConsts = RcasConfigs[sharpnessIndex];
            _rcasConstantsBuffer.SetData(_rcasConstantsArray);
        }

        /// <summary>
        /// Fills and uploads the single-pass-downsampler (SPD) constants for a rectangle covering the render size.
        /// </summary>
        /// <param name="renderSize">Resolved render size of the current frame.</param>
        private void SetupSpdConstants(Vector2Int renderSize)
        {
            RectInt rectInfo = new RectInt(0, 0, renderSize.x, renderSize.y);
            SpdSetup(rectInfo, out _, out Vector2Int workGroupOffset, out Vector2Int numWorkGroupsAndMips);

            // Downsample
            ref SpdConstants spdConstants = ref SpdConsts;
            spdConstants.numWorkGroups = (uint)numWorkGroupsAndMips.x;
            spdConstants.mips = (uint)numWorkGroupsAndMips.y;
            spdConstants.workGroupOffsetX = (uint)workGroupOffset.x;
            spdConstants.workGroupOffsetY = (uint)workGroupOffset.y;
            spdConstants.renderSizeX = (uint)renderSize.x;
            spdConstants.renderSizeY = (uint)renderSize.y;

            _spdConstantsBuffer.SetData(_spdConstantsArray);
        }

        /// <summary>
        /// Computes SPD dispatch parameters for <paramref name="rectInfo"/> using 64x64-pixel tiles.
        /// </summary>
        /// <param name="rectInfo">Region to downsample.</param>
        /// <param name="dispatchThreadGroupCount">Number of tiles in x and y.</param>
        /// <param name="workGroupOffset">Index of the first tile in x and y.</param>
        /// <param name="numWorkGroupsAndMips">x: total tile count; y: mip count.</param>
        /// <param name="mips">Mip count to use; when negative it is derived from the rectangle size, capped at 12.</param>
        private static void SpdSetup(RectInt rectInfo,
            out Vector2Int dispatchThreadGroupCount, out Vector2Int workGroupOffset, out Vector2Int numWorkGroupsAndMips, int mips = -1)
        {
            workGroupOffset = new Vector2Int(rectInfo.x / 64, rectInfo.y / 64);

            int endIndexX = (rectInfo.x + rectInfo.width - 1) / 64;
            int endIndexY = (rectInfo.y + rectInfo.height - 1) / 64;

            dispatchThreadGroupCount = new Vector2Int(endIndexX + 1 - workGroupOffset.x, endIndexY + 1 - workGroupOffset.y);

            numWorkGroupsAndMips = new Vector2Int(dispatchThreadGroupCount.x * dispatchThreadGroupCount.y, mips);
            if (mips < 0)
            {
                float resolution = Math.Max(rectInfo.width, rectInfo.height);
                numWorkGroupsAndMips.y = Math.Min(Mathf.FloorToInt(Mathf.Log(resolution, 2.0f)), 12);
            }
        }

        /// <summary>
        /// Releases all constant buffers and unloads all compute shaders owned by this context.
        /// Safe to call more than once.
        /// </summary>
        public void Destroy()
        {
            DestroyConstantBuffer(ref _generateReactiveConstantsBuffer);
            DestroyConstantBuffer(ref _rcasConstantsBuffer);
            DestroyConstantBuffer(ref _spdConstantsBuffer);
            DestroyConstantBuffer(ref _fsr2ConstantsBuffer);

            DestroyComputeShader(ref _tcrAutogenShader);
            DestroyComputeShader(ref _generateReactiveShader);
            DestroyComputeShader(ref _accumulateShader);
            DestroyComputeShader(ref _lockShader);
            DestroyComputeShader(ref _reconstructPreviousDepthShader);
            DestroyComputeShader(ref _depthClipShader);
            DestroyComputeShader(ref _prepareInputColorShader);
            DestroyComputeShader(ref _rcasShader);
            DestroyComputeShader(ref _computeLuminancePyramidShader);
        }

        /// <summary>Integer division of a non-negative <paramref name="value"/> that rounds up.</summary>
        private static int DivideRoundingUp(int value, int divisor) => (value + divisor - 1) / divisor;

        // Field order and types mirror the shader-side constant buffer layout; do not reorder.
        [Serializable, StructLayout(LayoutKind.Sequential)]
        private struct Fsr2Constants
        {
            public Vector2Int renderSize;
            public Vector2Int maxRenderSize;
            public Vector2Int displaySize;
            public Vector2Int inputColorResourceDimensions;
            public Vector2Int lumaMipDimensions;
            public int lumaMipLevelToUse;
            public int frameIndex;

            public Vector4 deviceToViewDepth;
            public Vector2 jitterOffset;
            public Vector2 motionVectorScale;
            public Vector2 downscaleFactor;
            public Vector2 motionVectorJitterCancellation;
            public float preExposure;
            public float previousFramePreExposure;
            public float tanHalfFOV;
            public float jitterPhaseCount;
            public float deltaTime;
            public float dynamicResChangeFactor;
            public float viewSpaceToMetersFactor;
        }

        [Serializable, StructLayout(LayoutKind.Sequential)]
        private struct SpdConstants
        {
            public uint mips;
            public uint numWorkGroups;
            public uint workGroupOffsetX, workGroupOffsetY;
            public uint renderSizeX, renderSizeY;
        }

        [Serializable, StructLayout(LayoutKind.Sequential)]
        private struct GenerateReactiveConstants
        {
            public float scale;
            public float threshold;
            public float binaryValue;
            public uint flags;
        }

        [Serializable, StructLayout(LayoutKind.Sequential)]
        private struct RcasConstants
        {
            public RcasConstants(uint sharpness, uint halfSharp)
            {
                this.sharpness = sharpness;
                this.halfSharp = halfSharp;
                dummy0 = dummy1 = 0;
            }

            public readonly uint sharpness;
            public readonly uint halfSharp;
            public readonly uint dummy0;
            public readonly uint dummy1;
        }

        /// <summary>
        /// The FSR2 C++ codebase uses floats bitwise converted to ints to pass sharpness parameters to the RCAS shader.
        /// This is not possible in C# without enabling unsafe code compilation, so to avoid that we instead use a table of precomputed values.
        /// Entries are ordered from sharpness 0 to 1 in equal steps.
        /// </summary>
        private static readonly List<RcasConstants> RcasConfigs = new()
        {
            new(1048576000u, 872428544u),
            new(1049178080u, 877212745u),
            new(1049823372u, 882390168u),
            new(1050514979u, 887895276u),
            new(1051256227u, 893859143u),
            new(1052050675u, 900216232u),
            new(1052902144u, 907032080u),
            new(1053814727u, 914306687u),
            new(1054792807u, 922105590u),
            new(1055841087u, 930494326u),
            new(1056964608u, 939538432u),
            new(1057566688u, 944322633u),
            new(1058211980u, 949500056u),
            new(1058903587u, 955005164u),
            new(1059644835u, 960969031u),
            new(1060439283u, 967326120u),
            new(1061290752u, 974141968u),
            new(1062203335u, 981416575u),
            new(1063181415u, 989215478u),
            new(1064229695u, 997604214u),
            new(1065353216u, 1006648320u),
        };

        /// <summary>Creates a single-element constant buffer sized for <typeparamref name="TConstants"/>.</summary>
        private static ComputeBuffer CreateConstantBuffer<TConstants>() where TConstants : struct
        {
            return new ComputeBuffer(1, Marshal.SizeOf<TConstants>(), ComputeBufferType.Constant);
        }

        /// <summary>Loads <paramref name="name"/> through the context callbacks unless the shader is already loaded.</summary>
        private void LoadComputeShader(string name, ref ComputeShader shaderRef)
        {
            if (shaderRef == null)
            {
                shaderRef = _contextDescription.Callbacks.LoadComputeShader(name);
            }
        }

        /// <summary>Releases the buffer (if any) and clears the reference.</summary>
        private static void DestroyConstantBuffer(ref ComputeBuffer bufferRef)
        {
            if (bufferRef == null)
            {
                return;
            }

            bufferRef.Release();
            bufferRef = null;
        }

        /// <summary>Unloads the shader (if any) through the context callbacks and clears the reference.</summary>
        private void DestroyComputeShader(ref ComputeShader shaderRef)
        {
            if (shaderRef == null)
            {
                return;
            }

            _contextDescription.Callbacks.UnloadComputeShader(shaderRef);
            shaderRef = null;
        }
    }
}