diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index dbc4c4dd9b5..700c065a227 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -79,6 +79,7 @@ static const struct spirv_capabilities implemented_capabilities = { .Float16Buffer = true, .Float64 = true, .FloatControls2 = true, + .FMAKHR = true, .FragmentBarycentricKHR = true, .FragmentDensityEXT = true, .FragmentFullyCoveredEXT = true, @@ -6846,6 +6847,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, case SpvOpFSub: case SpvOpIMul: case SpvOpFMul: + case SpvOpFmaKHR: case SpvOpUDiv: case SpvOpSDiv: case SpvOpFDiv: diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 6898b2bf151..dd3226ba641 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -807,6 +807,14 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, nir_fdot(&b->nb, src[0], src[1]); break; + case SpvOpFmaKHR: + const unsigned save_fp_math_ctrl = b->nb.fp_math_ctrl; + + b->nb.fp_math_ctrl |= nir_fp_exact; + dest->def = nir_ffma(&b->nb, src[0], src[1], src[2]); + b->nb.fp_math_ctrl = save_fp_math_ctrl; + break; + case SpvOpIAddCarry: vtn_assert(glsl_type_is_struct_or_ifc(dest_type)); dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]); diff --git a/src/gallium/frontends/rusticl/core/device.rs b/src/gallium/frontends/rusticl/core/device.rs index 3e2c9d43370..4d0ea7c8ea0 100644 --- a/src/gallium/frontends/rusticl/core/device.rs +++ b/src/gallium/frontends/rusticl/core/device.rs @@ -637,6 +637,7 @@ impl DeviceBase { add_spirv(c"SPV_KHR_bit_instructions"); add_spirv(c"SPV_KHR_expect_assume"); add_spirv(c"SPV_KHR_float_controls"); + add_spirv(c"SPV_KHR_fma"); add_spirv(c"SPV_KHR_integer_dot_product"); add_spirv(c"SPV_KHR_no_integer_wrap_decoration"); @@ -647,6 +648,7 @@ impl DeviceBase { add_cap(SpvCapability::SpvCapabilityDotProductInput4x8BitPacked); add_cap(SpvCapability::SpvCapabilityExpectAssumeKHR); add_cap(SpvCapability::SpvCapabilityFloat16Buffer); + add_cap(SpvCapability::SpvCapabilityFMAKHR); add_cap(SpvCapability::SpvCapabilityInt8); add_cap(SpvCapability::SpvCapabilityInt16); add_cap(SpvCapability::SpvCapabilityLinkage);