From c212a00ebfb3901df95beabd2bcce76715581060 Mon Sep 17 00:00:00 2001 From: Mike Blumenkrantz Date: Tue, 6 Sep 2022 15:50:23 -0400 Subject: [PATCH] zink: handle 64bit float atomics Part-of: --- .../drivers/zink/nir_to_spirv/nir_to_spirv.c | 64 +++++++++++++++++-- src/gallium/drivers/zink/zink_compiler.c | 16 +++++ 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 0ed25f39f4f..f7ceff7d660 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -214,7 +214,7 @@ emit_access_decorations(struct ntv_context *ctx, nir_variable *var, SpvId var_id } static SpvOp -get_atomic_op(nir_intrinsic_op op) +get_atomic_op(struct ntv_context *ctx, unsigned bit_size, nir_intrinsic_op op) { switch (op) { #define CASE_ATOMIC_OP(type) \ @@ -222,6 +222,32 @@ get_atomic_op(nir_intrinsic_op op) case nir_intrinsic_image_deref_atomic_##type: \ case nir_intrinsic_shared_atomic_##type +#define ATOMIC_FCAP(NAME) \ + do {\ + if (bit_size == 16) \ + spirv_builder_emit_cap(&ctx->builder, SpvCapabilityAtomicFloat16##NAME##EXT); \ + if (bit_size == 32) \ + spirv_builder_emit_cap(&ctx->builder, SpvCapabilityAtomicFloat32##NAME##EXT); \ + if (bit_size == 64) \ + spirv_builder_emit_cap(&ctx->builder, SpvCapabilityAtomicFloat64##NAME##EXT); \ + } while (0) + + CASE_ATOMIC_OP(fadd): + ATOMIC_FCAP(Add); + if (bit_size == 16) + spirv_builder_emit_extension(&ctx->builder, "SPV_EXT_shader_atomic_float16_add"); + else + spirv_builder_emit_extension(&ctx->builder, "SPV_EXT_shader_atomic_float_add"); + return SpvOpAtomicFAddEXT; + CASE_ATOMIC_OP(fmax): + ATOMIC_FCAP(MinMax); + spirv_builder_emit_extension(&ctx->builder, "SPV_EXT_shader_atomic_float_min_max"); + return SpvOpAtomicFMaxEXT; + CASE_ATOMIC_OP(fmin): + ATOMIC_FCAP(MinMax); + spirv_builder_emit_extension(&ctx->builder, "SPV_EXT_shader_atomic_float_min_max"); + return SpvOpAtomicFMinEXT; + CASE_ATOMIC_OP(add): return SpvOpAtomicIAdd; CASE_ATOMIC_OP(umin): @@ -248,7 +274,22 @@ get_atomic_op(nir_intrinsic_op op) } return 0; } + +static bool +atomic_op_is_float(nir_intrinsic_op op) +{ + switch (op) { + CASE_ATOMIC_OP(fadd): + CASE_ATOMIC_OP(fmax): + CASE_ATOMIC_OP(fmin): + return true; + default: + break; + } + return false; +} #undef CASE_ATOMIC_OP + static SpvId emit_float_const(struct ntv_context *ctx, int bit_size, double value) { @@ -2672,7 +2713,7 @@ static void handle_atomic_op(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId ptr, SpvId param, SpvId param2, nir_alu_type type) { SpvId dest_type = get_dest_type(ctx, &intr->dest, type); - SpvId result = emit_atomic(ctx, get_atomic_op(intr->intrinsic), dest_type, ptr, param, param2); + SpvId result = emit_atomic(ctx, get_atomic_op(ctx, nir_dest_bit_size(intr->dest), intr->intrinsic), dest_type, ptr, param, param2); assert(result); store_dest(ctx, &intr->dest, result, type); } @@ -2685,10 +2726,13 @@ emit_deref_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId param2 = 0; + if (nir_src_bit_size(intr->src[1]) == 64) + spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics); + if (intr->intrinsic == nir_intrinsic_deref_atomic_comp_swap) param2 = get_src(ctx, &intr->src[2]); - handle_atomic_op(ctx, intr, ptr, param, param2, nir_type_uint32); + handle_atomic_op(ctx, intr, ptr, param, param2, atomic_op_is_float(intr->intrinsic) ? nir_type_float : nir_type_uint32); } static void @@ -2701,17 +2745,18 @@ emit_shared_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, SpvStorageClassWorkgroup, dest_type); - SpvId offset = emit_binop(ctx, SpvOpUDiv, get_uvec_type(ctx, 32, 1), get_src(ctx, &intr->src[0]), emit_uint_const(ctx, 32, 4)); + SpvId offset = emit_binop(ctx, SpvOpUDiv, get_uvec_type(ctx, 32, 1), get_src(ctx, &intr->src[0]), emit_uint_const(ctx, 32, bit_size / 8)); SpvId shared_block = get_shared_block(ctx, bit_size); SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type, shared_block, &offset, 1); - + if (nir_src_bit_size(intr->src[1]) == 64) + spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics); SpvId param2 = 0; if (intr->intrinsic == nir_intrinsic_shared_atomic_comp_swap) param2 = get_src(ctx, &intr->src[2]); - handle_atomic_op(ctx, intr, ptr, param, param2, nir_type_uint32); + handle_atomic_op(ctx, intr, ptr, param, param2, atomic_op_is_float(intr->intrinsic) ? nir_type_float : nir_type_uint32); } static void @@ -3183,6 +3228,10 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvMemorySemanticsAcquireReleaseMask); break; + case nir_intrinsic_deref_atomic_fadd: + case nir_intrinsic_deref_atomic_fmin: + case nir_intrinsic_deref_atomic_fmax: + case nir_intrinsic_deref_atomic_fcomp_swap: case nir_intrinsic_deref_atomic_add: case nir_intrinsic_deref_atomic_umin: case nir_intrinsic_deref_atomic_imin: @@ -3196,6 +3245,9 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) emit_deref_atomic_intrinsic(ctx, intr); break; + case nir_intrinsic_shared_atomic_fadd: + case nir_intrinsic_shared_atomic_fmin: + case nir_intrinsic_shared_atomic_fmax: case nir_intrinsic_shared_atomic_add: case nir_intrinsic_shared_atomic_umin: case nir_intrinsic_shared_atomic_imin: diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c index dce7d53d339..3ac3b0bb269 100644 --- a/src/gallium/drivers/zink/zink_compiler.c +++ b/src/gallium/drivers/zink/zink_compiler.c @@ -1218,6 +1218,18 @@ rewrite_atomic_ssbo_instr(nir_builder *b, nir_instr *instr, struct bo_vars *bo) nir_intrinsic_op op; nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); switch (intr->intrinsic) { + case nir_intrinsic_ssbo_atomic_fadd: + op = nir_intrinsic_deref_atomic_fadd; + break; + case nir_intrinsic_ssbo_atomic_fmin: + op = nir_intrinsic_deref_atomic_fmin; + break; + case nir_intrinsic_ssbo_atomic_fmax: + op = nir_intrinsic_deref_atomic_fmax; + break; + case nir_intrinsic_ssbo_atomic_fcomp_swap: + op = nir_intrinsic_deref_atomic_fcomp_swap; + break; case nir_intrinsic_ssbo_atomic_add: op = nir_intrinsic_deref_atomic_add; break; @@ -1297,6 +1309,10 @@ remove_bo_access_instr(nir_builder *b, nir_instr *instr, void *data) nir_src *src; bool ssbo = true; switch (intr->intrinsic) { + case nir_intrinsic_ssbo_atomic_fadd: + case nir_intrinsic_ssbo_atomic_fmin: + case nir_intrinsic_ssbo_atomic_fmax: + case nir_intrinsic_ssbo_atomic_fcomp_swap: case nir_intrinsic_ssbo_atomic_add: case nir_intrinsic_ssbo_atomic_umin: case nir_intrinsic_ssbo_atomic_imin: