diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 0fae0814204..00c846f82b1 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -65,8 +65,10 @@ enum Label { label_abs_fp16 = 1ull << 19, label_neg_fp16 = 1ull << 20, label_fcanonicalize_fp16 = 1ull << 21, - - label_canonicalized = 1ull << 22, + /* One label for each bit size because there are packed fp32 definitions. The three labels are independent bits, so a value may carry any subset of them; canonicalized_label() maps a bit size of 16, 32 or 64 to its label. */ + label_canonicalized_fp16 = 1ull << 22, + label_canonicalized_fp32 = 1ull << 23, + label_canonicalized_fp64 = 1ull << 24, /* label_{omod2,omod4,omod5,clamp} are used for both 16 and * 32-bit operations but this doesn't cause any issues because @@ -94,6 +96,22 @@ static constexpr uint64_t temp_labels = label_temp | label_uniform_bool | label_ static constexpr uint64_t val_labels = label_constant | label_mad; +static constexpr uint64_t canonicalized_labels = + label_canonicalized_fp16 | label_canonicalized_fp32 | label_canonicalized_fp64; + +static Label +canonicalized_label(unsigned bit_size) +{ + if (bit_size == 16) + return label_canonicalized_fp16; + else if (bit_size == 32) + return label_canonicalized_fp32; + else if (bit_size == 64) + return label_canonicalized_fp64; + else + UNREACHABLE("unknown canonicalized size"); +} + static_assert((instr_mod_labels & temp_labels) == 0, "labels cannot intersect"); static_assert((instr_mod_labels & val_labels) == 0, "labels cannot intersect"); static_assert((temp_labels & val_labels) == 0, "labels cannot intersect"); @@ -293,9 +311,9 @@ struct ssa_info { : label & label_fcanonicalize_fp32_64; } - void set_canonicalized() { add_label(label_canonicalized); } + void set_canonicalized(unsigned bit_size) { add_label(canonicalized_label(bit_size)); } - bool is_canonicalized() { return label & label_canonicalized; } + bool is_canonicalized(unsigned bit_size) { return label & canonicalized_label(bit_size); } void set_extract() { add_label(label_extract); } @@ -1484,6 +1502,93 @@ 
alu_opt_info_to_instr(opt_ctx& ctx, alu_opt_info& info, Instruction* old_instr) return instr; } +uint64_t +operand_canonicalized_labels(opt_ctx& ctx, Operand op) +{ + if (op.isConstant()) { + uint64_t val = op.constantValue64(); + uint64_t res = 0; + if (op.size() == 2) { + if (((val << 1) >> 1) == 0 || ((val << 1) >> 1) > 0x000f'ffff'ffff'ffffull) + res |= label_canonicalized_fp64; + } else if (op.size() == 1) { + /* Check both fp16 halves for denorms because of packed math and opsel. A half passes when its 15 magnitude bits are zero or above 0x3ff, i.e. it is not a nonzero value with an all-zero exponent field. */ + if (((val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff) && + ((val & 0x7fff0000) == 0 || (val & 0x7fff0000) > 0x3ff0000)) + res |= label_canonicalized_fp16; +/* The same test on all 31 magnitude bits read as a single fp32 value; 0x7fffff is the 23-bit mantissa mask. */ if ((val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff) + res |= label_canonicalized_fp32; + } + return res; + } else if (op.isTemp()) {/* Temporaries report whichever canonicalized labels ctx.info already holds for them. */ + return ctx.info[op.tempId()].label & canonicalized_labels; + } +/* An operand that is neither a constant nor a temporary gets no label. */ + return 0; +} +/* Adds canonicalized_fp16/fp32/fp64 labels to the first definition of instr. Phis, the listed copy, move and readlane opcodes, selects, s_mul_i32 with an SCC-defined operand and (below GFX9) float min/max/med3 take their labels from their operands; any other VALU, SALU or VINTRP instruction is labelled from the float result type in the opcode table. Every branch only ORs labels in or calls set_canonicalized. */ +void +gather_canonicalized(opt_ctx& ctx, aco_ptr& instr) +{ + if (instr->isSDWA() || instr->definitions.size() == 0) + return;/* Skip SDWA instructions and instructions without a definition. */ + + if (is_phi(instr)) { + /* This is correct even for loop header phis because label is 0 initially. 
*/ + uint64_t label = canonicalized_labels;/* Start with every size and keep only those that all phi operands are labelled for. */ + for (Operand& op : instr->operands) + label &= operand_canonicalized_labels(ctx, op); + + ctx.info[instr->definitions[0].tempId()].label |= label; + } else if (instr->opcode == aco_opcode::p_parallelcopy || + instr->opcode == aco_opcode::p_as_uniform || instr->opcode == aco_opcode::v_mov_b32 || + instr->opcode == aco_opcode::v_mov_b16 || + instr->opcode == aco_opcode::v_readfirstlane_b32 || + instr->opcode == aco_opcode::v_readlane_b32 || + instr->opcode == aco_opcode::v_readlane_b32_e64) {/* These opcodes pass on the labels of operand 0 unchanged. */ + ctx.info[instr->definitions[0].tempId()].label |= + operand_canonicalized_labels(ctx, instr->operands[0]); + } else if (instr->opcode == aco_opcode::v_cndmask_b32 || + instr->opcode == aco_opcode::v_cndmask_b16 || + instr->opcode == aco_opcode::s_cselect_b32 || + instr->opcode == aco_opcode::s_cselect_b64) {/* Selects: intersect the labels of operands 0 and 1 only; any further operand is ignored here. */ + uint64_t label = canonicalized_labels; + for (unsigned i = 0; i < 2; i++) + label &= operand_canonicalized_labels(ctx, instr->operands[i]); + + ctx.info[instr->definitions[0].tempId()].label |= label; + } else if (instr->opcode == aco_opcode::s_mul_i32) {/* Look for a temporary operand that is the last definition of its parent instruction and is fixed to SCC. */ + for (unsigned i = 0; i < 2; i++) { + if (!instr->operands[i].isTemp()) + continue; + Temp tmp = instr->operands[i].getTemp(); + Definition parent_def = ctx.info[tmp.id()].parent_instr->definitions.back(); + if (parent_def.getTemp() == tmp && parent_def.isFixed() && parent_def.physReg() == scc) { + /* The operand is either 0 or 1, so this is a select between 0 and the other operand. 
*/ + ctx.info[instr->definitions[0].tempId()].label |= + operand_canonicalized_labels(ctx, instr->operands[!i]);/* operands[!i] is the other multiplicand. */ + break; + } + } + } else if (ctx.program->gfx_level < GFX9 && + (instr->opcode == aco_opcode::v_max_f32 || instr->opcode == aco_opcode::v_min_f32 || + instr->opcode == aco_opcode::v_max_f64_e64 || + instr->opcode == aco_opcode::v_min_f64_e64 || + instr->opcode == aco_opcode::v_max3_f32 || instr->opcode == aco_opcode::v_min3_f32 || + instr->opcode == aco_opcode::v_med3_f32 || instr->opcode == aco_opcode::v_max_f16 || + instr->opcode == aco_opcode::v_min_f16)) {/* Below GFX9, these float min, max and med3 opcodes keep only the labels common to all operands. */ + uint64_t label = canonicalized_labels; + for (Operand& op : instr->operands) + label &= operand_canonicalized_labels(ctx, op); + + ctx.info[instr->definitions[0].tempId()].label |= label; + } else if (instr->isVALU() || instr->isSALU() || instr->isVINTRP()) {/* Other VALU, SALU and VINTRP results are labelled when the opcode table gives a float result type of at least 16 bits. */ + aco_type type = instr_info.alu_opcode_infos[(int)instr->opcode].def_types[0]; + if (type.base_type == aco_base_type_float && type.bit_size >= 16) + ctx.info[instr->definitions[0].tempId()].set_canonicalized(type.bit_size); + } +} + bool can_use_VOP3(opt_ctx& ctx, const aco_ptr& instr) { @@ -1878,25 +1983,6 @@ remove_operand_extract(opt_ctx& ctx, aco_ptr& instr) } } -bool -does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op) -{ - switch (op) { - case aco_opcode::v_min_f32: - case aco_opcode::v_max_f32: - case aco_opcode::v_med3_f32: - case aco_opcode::v_min3_f32: - case aco_opcode::v_max3_f32: - case aco_opcode::v_min_f16: - case aco_opcode::v_max_f16: return ctx.program->gfx_level > GFX8; - case aco_opcode::v_cndmask_b32: - case aco_opcode::v_cndmask_b16: - case aco_opcode::v_mov_b32: - case aco_opcode::v_mov_b16: return false; - default: return true; - } -} - bool can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags, bool allow_cselect = false) { @@ -1932,24 +2018,6 @@ can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags, bool allow_c } } -bool -is_op_canonicalized(opt_ctx& ctx, Operand op) -{ - float_mode* 
fp = &ctx.fp_mode; - if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) || - (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep) - return true; - - if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant())) { - uint64_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue(); - if (op.bytes() == 2) - return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff; - else if (op.bytes() == 4) - return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff; - } - return false; -} - bool is_scratch_offset_valid(opt_ctx& ctx, Instruction* instr, int64_t offset0, int64_t offset1) { @@ -2037,7 +2105,7 @@ parse_operand(opt_ctx& ctx, Temp tmp, alu_opt_op& op_info, aco_type& type) if (info.is_fcanonicalize(bit_size) || info.is_abs(bit_size) || info.is_neg(bit_size)) { type.num_components = 1; type.bit_size = bit_size; - if (ctx.info[info.temp.id()].is_canonicalized() || + if (ctx.info[info.temp.id()].is_canonicalized(bit_size) ||/* The label is looked up for the same bit_size as the fcanonicalize, abs or neg label tested above. */ (bit_size == 32 ? 
ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == fp_denorm_keep) type.base_type = aco_base_type_uint; else @@ -2306,7 +2374,7 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr& instr, bool uses_va instr.reset(alu_opt_info_to_instr(ctx, result_info, instr.release())); for (const Definition& def : instr->definitions) - ctx.info[def.tempId()].label &= instr_mod_labels | label_canonicalized; + ctx.info[def.tempId()].label &= instr_mod_labels | canonicalized_labels;/* Labels outside instr_mod_labels and the three canonicalized labels are cleared. */ } void @@ -2529,19 +2597,7 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) if (instr->opcode == aco_opcode::p_extract || instr->opcode == aco_opcode::p_extract_vector) extract_apply_extract(ctx, instr); - if (instr->isVALU() || (instr->isVINTRP() && instr->opcode != aco_opcode::v_interp_mov_f32)) { - if (instr_info.alu_opcode_infos[(int)instr->opcode].output_modifiers || instr->isVINTRP() || - instr->opcode == aco_opcode::v_cndmask_b32) { - bool canonicalized = true; - if (!does_fp_op_flush_denorms(ctx, instr->opcode)) { - unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 
2 : instr->operands.size(); - for (unsigned i = 0; canonicalized && (i < ops); i++) - canonicalized = is_op_canonicalized(ctx, instr->operands[i]); - } - if (canonicalized) - ctx.info[instr->definitions[0].tempId()].set_canonicalized(); - } - } + gather_canonicalized(ctx, instr);/* Record the canonicalized labels of the result before the per-opcode labelling in the switch below. */ switch (instr->opcode) { case aco_opcode::p_create_vector: { @@ -2745,8 +2801,6 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) instr->operands[0].constantValue64()); } else if (instr->operands[0].isTemp()) { ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp()); - if (ctx.info[instr->operands[0].tempId()].is_canonicalized()) - ctx.info[instr->definitions[0].tempId()].set_canonicalized(); } else { assert(instr->operands[0].isFixed()); } @@ -2792,7 +2846,8 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) else if (!abs && neg && other.type() == RegType::vgpr) ctx.info[instr->definitions[0].tempId()].set_neg(other, bit_size); else if (!abs && !neg) { - if (denorm_mode == fp_denorm_keep || ctx.info[other.id()].is_canonicalized()) + if (denorm_mode == fp_denorm_keep || + ctx.info[other.id()].is_canonicalized(bit_size))/* In that case the result is labelled as the temporary other itself instead of as fcanonicalize of other. */ ctx.info[instr->definitions[0].tempId()].set_temp(other); else ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other, bit_size); @@ -2901,12 +2956,6 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp); } break; - case aco_opcode::s_mul_i32: - /* Testing every uint32_t shows that 0x3f800000*n is never a denormal. - * This pattern is created from a uniform nir_op_b2f. */ - if (instr->operands[0].constantEquals(0x3f800000u)) - ctx.info[instr->definitions[0].tempId()].set_canonicalized(); - break; case aco_opcode::p_extract: { if (instr->operands[0].isTemp()) { ctx.info[instr->definitions[0].tempId()].set_extract();