diff --git a/src/amd/compiler/aco_insert_fp_mode.cpp b/src/amd/compiler/aco_insert_fp_mode.cpp
index c4b6c0e66f4..fef489817b6 100644
--- a/src/amd/compiler/aco_insert_fp_mode.cpp
+++ b/src/amd/compiler/aco_insert_fp_mode.cpp
@@ -28,7 +28,7 @@ static_assert(mode_field_count <= sizeof(mode_mask) * 8, "larger mode_mask neede
 struct fp_mode_state {
    uint8_t fields[mode_field_count] = {};
-   mode_mask dirty = 0; /* BITFIELD_BIT(enum mode_field) */
+   mode_mask required = 0; /* BITFIELD_BIT(enum mode_field) */
 
    fp_mode_state() = default;
 
@@ -41,51 +41,84 @@ struct fp_mode_state {
       fields[mode_fp16_ovfl] = 0;
    }
 
-   void join(const fp_mode_state& other)
+   /* Returns a mask of fields that cannot be joined. */
+   mode_mask join(const fp_mode_state& other)
    {
-      dirty |= other.dirty;
-      for (unsigned i = 0; i < mode_field_count; i++) {
-         if (fields[i] != other.fields[i])
-            dirty |= BITFIELD_BIT(i);
+      const std::array<mode_mask, 3> part_masks = {
+         BITFIELD_BIT(mode_round32) | BITFIELD_BIT(mode_round16_64),
+         BITFIELD_BIT(mode_denorm32) | BITFIELD_BIT(mode_denorm16_64),
+         BITFIELD_BIT(mode_fp16_ovfl),
+      };
+
+      mode_mask result = 0;
+      for (mode_mask part : part_masks) {
+         bool can_join = true;
+         u_foreach_bit (i, required & other.required & part) {
+            if (fields[i] != other.fields[i])
+               can_join = false;
+         }
+
+         if (!can_join) {
+            result |= part;
+            continue;
+         }
+
+         u_foreach_bit (i, ~required & other.required & part)
+            fields[i] = other.fields[i];
+
+         required |= other.required & part;
       }
+
+      return result;
    }
 
-   bool require(mode_field field, uint8_t val)
+   void require(mode_field field, uint8_t val)
    {
-      if (fields[field] == val && !(dirty & BITFIELD_BIT(field)))
-         return false;
       fields[field] = val;
-      dirty |= BITFIELD_BIT(field);
-      return true;
+      required |= BITFIELD_BIT(field);
    }
 
    uint8_t round() const { return fields[mode_round32] | (fields[mode_round16_64] << 2); }
 
    uint8_t denorm() const { return fields[mode_denorm32] | (fields[mode_denorm16_64] << 2); }
+
+   uint8_t round_denorm() const { return round() | (denorm() << 4); }
 };
 
 struct fp_mode_ctx {
    std::vector<fp_mode_state> block_states;
+
+   uint32_t last_set[mode_field_count];
+
    Program* program;
 };
 
 void
-emit_set_mode(Builder& bld, const fp_mode_state& state)
+set_mode(fp_mode_ctx* ctx, Block* block, fp_mode_state& state, unsigned idx, mode_mask mask)
 {
-   bool set_round = state.dirty & (BITFIELD_BIT(mode_round32) | BITFIELD_BIT(mode_round16_64));
-   bool set_denorm = state.dirty & (BITFIELD_BIT(mode_denorm32) | BITFIELD_BIT(mode_denorm16_64));
-   bool set_fp16_ovfl = state.dirty & BITFIELD_BIT(mode_fp16_ovfl);
+   Builder bld(ctx->program, block);
+   bld.reset(&block->instructions, block->instructions.begin() + idx);
+
+   bool set_round = mask & (BITFIELD_BIT(mode_round32) | BITFIELD_BIT(mode_round16_64));
+   bool set_denorm = mask & (BITFIELD_BIT(mode_denorm32) | BITFIELD_BIT(mode_denorm16_64));
+   bool set_fp16_ovfl = mask & BITFIELD_BIT(mode_fp16_ovfl);
 
    if (bld.program->gfx_level >= GFX10) {
-      if (set_round)
+      if (set_round) {
          bld.sopp(aco_opcode::s_round_mode, state.round());
-      if (set_denorm)
+         mask |= BITFIELD_BIT(mode_round32) | BITFIELD_BIT(mode_round16_64);
+      }
+      if (set_denorm) {
          bld.sopp(aco_opcode::s_denorm_mode, state.denorm());
+         mask |= BITFIELD_BIT(mode_denorm32) | BITFIELD_BIT(mode_denorm16_64);
+      }
    } else if (set_round || set_denorm) {
       /* "((size - 1) << 11) | register" (MODE is encoded as register 1) */
-      uint8_t val = state.round() | (state.denorm() << 4);
+      uint8_t val = state.round_denorm();
       bld.sopk(aco_opcode::s_setreg_imm32_b32, Operand::literal32(val), (7 << 11) | 1);
+
+      mask |= BITFIELD_BIT(mode_round32) | BITFIELD_BIT(mode_round16_64);
+      mask |= BITFIELD_BIT(mode_denorm32) | BITFIELD_BIT(mode_denorm16_64);
    }
 
    if (set_fp16_ovfl) {
@@ -95,10 +128,15 @@ emit_set_mode(Builder& bld, const fp_mode_state& state)
       bld.sopk(aco_opcode::s_setreg_imm32_b32, Operand::literal32(state.fields[mode_fp16_ovfl]),
                (0 << 11) | (23 << 6) | 1);
    }
+
+   state.required &= ~mask;
+
+   u_foreach_bit (i, mask)
+      ctx->last_set[i] = MIN2(ctx->last_set[i], block->index);
 }
 
 mode_mask
-vmem_default_needs(Instruction* instr)
+vmem_default_needs(const Instruction* instr)
 {
    switch (instr->opcode) {
    case aco_opcode::buffer_atomic_fcmpswap:
@@ -139,7 +177,7 @@ vmem_default_needs(Instruction* instr)
 }
 
 bool
-instr_ignores_round_mode(Instruction* instr)
+instr_ignores_round_mode(const Instruction* instr)
 {
    switch (instr->opcode) {
    case aco_opcode::v_min_f64_e64:
@@ -221,26 +259,16 @@ instr_ignores_round_mode(Instruction* instr)
 }
 
 mode_mask
-instr_default_needs(fp_mode_ctx* ctx, Block* block, Instruction* instr)
+instr_default_needs(const fp_mode_ctx* ctx, const Instruction* instr)
 {
    if ((instr->isVMEM() || instr->isFlatLike()) && ctx->program->gfx_level < GFX12)
       return vmem_default_needs(instr);
 
    switch (instr->opcode) {
-   case aco_opcode::s_branch:
-   case aco_opcode::s_cbranch_scc0:
-   case aco_opcode::s_cbranch_scc1:
-   case aco_opcode::s_cbranch_vccz:
-   case aco_opcode::s_cbranch_vccnz:
-   case aco_opcode::s_cbranch_execz:
-   case aco_opcode::s_cbranch_execnz:
-      if (instr->salu().imm > block->index)
-         return 0;
-      FALLTHROUGH;
    case aco_opcode::s_swappc_b64:
    case aco_opcode::s_setpc_b64:
    case aco_opcode::s_call_b64:
-      /* Restore defaults on loop back edges and calls. */
+      /* Restore defaults on calls. */
       return BITFIELD_MASK(mode_field_count);
    case aco_opcode::ds_cmpst_f32:
    case aco_opcode::ds_min_f32:
@@ -315,100 +343,127 @@ instr_default_needs(fp_mode_ctx* ctx, Block* block, Instruction* instr)
 void
 emit_set_mode_block(fp_mode_ctx* ctx, Block* block)
 {
-   Builder bld(ctx->program, block);
-   fp_mode_state fp_state;
    const fp_mode_state default_state(block->fp_mode);
+   fp_mode_state fp_state = default_state;
 
-   if (block->index == 0) {
-      bool inital_unknown = (ctx->program->info.merged_shader_compiled_separately &&
-                             ctx->program->stage.sw == SWStage::GS) ||
-                            (ctx->program->info.merged_shader_compiled_separately &&
-                             ctx->program->stage.sw == SWStage::TCS);
-
-      if (inital_unknown) {
-         fp_state.dirty = BITFIELD_MASK(mode_field_count) & ~BITFIELD_BIT(mode_fp16_ovfl);
-      } else {
-         float_mode program_mode;
-         program_mode.val = ctx->program->config->float_mode;
-         fp_state = fp_mode_state(program_mode);
-      }
-   } else if (block->linear_preds.empty()) {
-      fp_state = default_state;
+   if (block->kind & block_kind_end_with_regs) {
+      /* Restore default. */
+      fp_state.required = BITFIELD_MASK(mode_field_count);
+      assert(block->linear_succs.empty());
    } else {
-      assert(block->linear_preds[0] < block->index);
-      fp_state = ctx->block_states[block->linear_preds[0]];
-      for (unsigned i = 1; i < block->linear_preds.size(); i++) {
-         unsigned pred = block->linear_preds[i];
-         fp_mode_state other = pred < block->index
-                                  ? ctx->block_states[pred]
-                                  : fp_mode_state(ctx->program->blocks[pred].fp_mode);
-         fp_state.join(other);
+      for (unsigned succ : block->linear_succs) {
+         /* Skip loop headers, they are handled at the end. */
+         if (succ <= block->index)
+            continue;
+
+         fp_mode_state& other = ctx->block_states[succ];
+         mode_mask to_set = fp_state.join(other);
+
+         if (to_set) {
+            Block* succ_block = &ctx->program->blocks[succ];
+            set_mode(ctx, succ_block, other, 0, to_set);
+         }
       }
    }
 
-   /* If we don't know the value, set it to the default one next time. */
-   u_foreach_bit (field, fp_state.dirty)
-      fp_state.fields[field] = default_state.fields[field];
-
-   for (std::vector<aco_ptr<Instruction>>::iterator it = block->instructions.begin();
-        it < block->instructions.end(); ++it) {
-      bool set_mode = false;
-
-      Instruction* instr = it->get();
+   for (int idx = block->instructions.size() - 1; idx >= 0; idx--) {
+      Instruction* instr = block->instructions[idx].get();
+      fp_mode_state instr_state;
 
       if (instr->opcode == aco_opcode::p_v_cvt_f16_f32_rtne ||
          instr->opcode == aco_opcode::p_s_cvt_f16_f32_rtne) {
-         set_mode |= fp_state.require(mode_round16_64, fp_round_ne);
-         set_mode |= fp_state.require(mode_fp16_ovfl, default_state.fields[mode_fp16_ovfl]);
-         set_mode |= fp_state.require(mode_denorm16_64, default_state.fields[mode_denorm16_64]);
+         instr_state.require(mode_round16_64, fp_round_ne);
+         instr_state.require(mode_fp16_ovfl, default_state.fields[mode_fp16_ovfl]);
+         instr_state.require(mode_denorm16_64, default_state.fields[mode_denorm16_64]);
+
         if (instr->opcode == aco_opcode::p_v_cvt_f16_f32_rtne)
            instr->opcode = aco_opcode::v_cvt_f16_f32;
         else
            instr->opcode = aco_opcode::s_cvt_f16_f32;
      } else if (instr->opcode == aco_opcode::p_v_cvt_f16_f32_rtpi ||
                 instr->opcode == aco_opcode::p_v_cvt_f16_f32_rtni) {
-         set_mode |= fp_state.require(mode_round16_64, instr->opcode == aco_opcode::p_v_cvt_f16_f32_rtpi ? fp_round_pi : fp_round_ni);
-         set_mode |= fp_state.require(mode_fp16_ovfl, default_state.fields[mode_fp16_ovfl]);
-         set_mode |= fp_state.require(mode_denorm16_64, default_state.fields[mode_denorm16_64]);
-         set_mode |= fp_state.require(mode_denorm32, default_state.fields[mode_denorm32]);
+         instr_state.require(mode_round16_64, instr->opcode == aco_opcode::p_v_cvt_f16_f32_rtpi
+                                                 ? fp_round_pi
+                                                 : fp_round_ni);
+         instr_state.require(mode_fp16_ovfl, default_state.fields[mode_fp16_ovfl]);
+         instr_state.require(mode_denorm16_64, default_state.fields[mode_denorm16_64]);
+         instr_state.require(mode_denorm32, default_state.fields[mode_denorm32]);
+
         instr->opcode = aco_opcode::v_cvt_f16_f32;
      } else if (instr->opcode == aco_opcode::p_v_cvt_pk_fp8_f32_ovfl) {
-         set_mode |= fp_state.require(mode_fp16_ovfl, 1);
+         instr_state.require(mode_fp16_ovfl, 1);
         instr->opcode = aco_opcode::v_cvt_pk_fp8_f32;
      } else if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz ||
                 instr->opcode == aco_opcode::p_v_fma_mixhi_f16_rtz) {
-         set_mode |= fp_state.require(mode_round16_64, fp_round_tz);
-         set_mode |= fp_state.require(mode_round32, default_state.fields[mode_round32]);
-         set_mode |= fp_state.require(mode_denorm16_64, default_state.fields[mode_denorm16_64]);
-         set_mode |= fp_state.require(mode_denorm32, default_state.fields[mode_denorm32]);
+         instr_state.require(mode_round16_64, fp_round_tz);
+         instr_state.require(mode_round32, default_state.fields[mode_round32]);
+         instr_state.require(mode_denorm16_64, default_state.fields[mode_denorm16_64]);
+         instr_state.require(mode_denorm32, default_state.fields[mode_denorm32]);
+
         if (instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz)
            instr->opcode = aco_opcode::v_fma_mixlo_f16;
         else
           instr->opcode = aco_opcode::v_fma_mixhi_f16;
      } else {
-         mode_mask default_needs = instr_default_needs(ctx, block, instr);
+         mode_mask default_needs = instr_default_needs(ctx, instr);
         u_foreach_bit (i, default_needs)
-            set_mode |= fp_state.require((mode_field)i, default_state.fields[i]);
+            instr_state.require((mode_field)i, default_state.fields[i]);
      }
 
-      if (set_mode) {
-         bld.reset(&block->instructions, it);
-         emit_set_mode(bld, fp_state);
-         fp_state.dirty = 0;
-         /* Update the iterator if it was invalidated */
-         it = bld.it;
+      mode_mask to_set = fp_state.join(instr_state);
+
+      if (to_set) {
+         /* If the mode required by the current instruction is incompatible with
+          * the mode(s) required by future instructions, set the next mode after
+          * the current instruction and update the required mode.
+          */
+         set_mode(ctx, block, fp_state, idx + 1, to_set);
+         to_set = fp_state.join(instr_state);
+         assert(!to_set);
      }
   }
 
-   if (block->kind & block_kind_end_with_regs) {
-      /* Restore default. */
-      for (unsigned i = 0; i < mode_field_count; i++)
-         fp_state.require((mode_field)i, default_state.fields[i]);
-      if (fp_state.dirty) {
-         bld.reset(block);
-         emit_set_mode(bld, fp_state);
-         fp_state.dirty = 0;
+   if (block->linear_preds.empty()) {
+
+      if (fp_state.fields[mode_fp16_ovfl] == 0) {
+         /* We always set fp16_ovfl=0 from the command stream. */
+         fp_state.required &= ~BITFIELD_BIT(mode_fp16_ovfl);
      }
+
+      bool initial_unknown =
+         ctx->program->info.merged_shader_compiled_separately &&
+         (ctx->program->stage.sw == SWStage::GS ||
+          ctx->program->stage.sw == SWStage::TCS);
+
+      if (ctx->program->stage == raytracing_cs || block->index) {
+         /* Assume the default state is already set. */
+         for (unsigned i = 0; i < mode_field_count; i++) {
+            if (fp_state.fields[i] == default_state.fields[i])
+               fp_state.required &= ~BITFIELD_BIT(i);
+         }
+      } else if (!initial_unknown) {
+         /* Set what's required from the command stream. */
+         ctx->program->config->float_mode = fp_state.round_denorm();
+         fp_state.required &= BITFIELD_BIT(mode_fp16_ovfl);
+      }
+
+      if (fp_state.required)
+         set_mode(ctx, block, fp_state, 0, fp_state.required);
+   } else if (block->kind & block_kind_loop_header) {
+      uint32_t max_pred = 0;
+      for (uint32_t pred : block->linear_preds)
+         max_pred = MAX2(max_pred, pred);
+
+      assert(max_pred != 0);
+
+      mode_mask to_set = 0;
+      /* Check if any mode was changed during the loop. */
+      u_foreach_bit (i, fp_state.required) {
+         if (ctx->last_set[i] <= max_pred)
+            to_set |= BITFIELD_BIT(i);
+      }
+      if (to_set)
+         set_mode(ctx, block, fp_state, 0, to_set);
   }
 
   ctx->block_states[block->index] = fp_state;
@@ -428,9 +483,11 @@ insert_fp_mode(Program* program)
    fp_mode_ctx ctx;
    ctx.program = program;
    ctx.block_states.resize(program->blocks.size());
+   for (unsigned i = 0; i < mode_field_count; i++)
+      ctx.last_set[i] = UINT32_MAX;
 
-   for (Block& block : program->blocks)
-      emit_set_mode_block(&ctx, &block);
+   for (int i = program->blocks.size() - 1; i >= 0; i--)
+      emit_set_mode_block(&ctx, &program->blocks[i]);
 }
 
 } // namespace aco
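Review note (not part of the patch): below is a minimal, self-contained sketch of the require()/join() bookkeeping the rewritten pass relies on. All names here (FpState, Field, Mask) are hypothetical stand-ins, and it joins fields individually instead of grouping them into the per-hardware-register part_masks the patch uses. Because the pass now walks blocks and instructions in reverse, the accumulated state always describes what *later* code requires; a non-zero join result at an instruction means a mode-setting instruction must be emitted right after it.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

/* Hypothetical stand-ins for the pass's types: one value per mode field,
 * plus a bitmask of the fields that later code actually depends on. */
enum Field { round32, round16_64, denorm32, denorm16_64, field_count };
using Mask = uint8_t;

struct FpState {
   uint8_t fields[field_count] = {};
   Mask required = 0;

   void require(Field f, uint8_t val)
   {
      fields[f] = val;          /* value the instruction needs */
      required |= Mask(1) << f; /* mark the field as live */
   }

   /* Merge 'other' (requirements of the current instruction) into this
    * state (requirements of everything executed later). Returns the mask
    * of fields where both sides care but disagree; those need an explicit
    * mode-setting instruction between the two program points. */
   Mask join(const FpState& other)
   {
      Mask conflicts = 0;
      for (int f = 0; f < field_count; f++) {
         Mask bit = Mask(1) << f;
         if ((required & other.required & bit) && fields[f] != other.fields[f]) {
            conflicts |= bit; /* both sides care, values disagree */
         } else if (other.required & bit) {
            fields[f] = other.fields[f]; /* adopt the only requirement */
            required |= bit;
         }
      }
      return conflicts;
   }
};

int main()
{
   FpState future;                /* what instructions after this point need */
   future.require(round16_64, 1); /* say, a directed rounding mode */

   FpState instr;                /* what the current instruction needs */
   instr.require(round16_64, 0); /* say, round-to-nearest-even */

   Mask to_set = future.join(instr);
   assert(to_set != 0); /* conflict: emit a mode switch after 'instr' */
   printf("conflict mask: 0x%x\n", to_set);

   /* After emitting the switch, the real pass clears 'required' for the
    * fields just written and re-joins; the second join then succeeds. */
   future.required &= ~to_set;
   assert(future.join(instr) == 0);
}
```

This mirrors why set_mode() clears state.required for everything it wrote and why emit_set_mode_block() can assert that the second join after a set_mode() call returns no conflicts.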