diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 5cc74c48b71..0fae0814204 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -55,9 +55,17 @@ enum Label { label_scc_needed = 1ull << 6, label_extract = 1ull << 7, - label_abs = 1ull << 16, - label_neg = 1ull << 17, - label_fcanonicalize = 1ull << 18, + /* These have one label for fp16 and one for fp32/64. + * 32bit vs 64bit type mismatches are impossible because + * of the different register class sizes. + */ + label_abs_fp32_64 = 1ull << 16, + label_neg_fp32_64 = 1ull << 17, + label_fcanonicalize_fp32_64 = 1ull << 18, + label_abs_fp16 = 1ull << 19, + label_neg_fp16 = 1ull << 20, + label_fcanonicalize_fp16 = 1ull << 21, /* fp16 variants take new bits 19-21; the fp32_64 variants reuse bits 16-18 of the labels they replace */ + label_canonicalized = 1ull << 22, /* label_{omod2,omod4,omod5,clamp} are used for both 16 and @@ -77,9 +85,12 @@ enum Label { static constexpr uint64_t instr_mod_labels = label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16; -static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_b2f | - label_uniform_bool | label_scc_invert | label_b2i | - label_fcanonicalize; +static constexpr uint64_t input_mod_labels = + label_abs_fp16 | label_abs_fp32_64 | label_neg_fp16 | label_neg_fp32_64; + +static constexpr uint64_t temp_labels = label_temp | label_uniform_bool | label_scc_invert | + label_b2f | label_b2i | input_mod_labels | + label_fcanonicalize_fp32_64 | label_fcanonicalize_fp16; static constexpr uint64_t val_labels = label_constant | label_mad; @@ -126,25 +137,39 @@ struct ssa_info { bool is_constant() { return label & label_constant; } - void set_abs(Temp abs_temp) + void set_abs(Temp abs_temp, unsigned bit_size) { - add_label(label_abs); + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); + add_label(bit_size == 16 ? 
label_abs_fp16 : label_abs_fp32_64); temp = abs_temp; } - bool is_abs() { return label & label_abs; } - - void set_neg(Temp neg_temp) + bool is_abs(unsigned bit_size) { - add_label(label_neg); + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); + return bit_size == 16 ? label & label_abs_fp16 : label & label_abs_fp32_64; + } + + void set_neg(Temp neg_temp, unsigned bit_size) + { + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); + add_label(bit_size == 16 ? label_neg_fp16 : label_neg_fp32_64); temp = neg_temp; } - bool is_neg() { return label & label_neg; } - - void set_neg_abs(Temp neg_abs_temp) + bool is_neg(unsigned bit_size) { - add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg)); + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); + return bit_size == 16 ? label & label_neg_fp16 : label & label_neg_fp32_64; + } + + void set_neg_abs(Temp neg_abs_temp, unsigned bit_size) + { + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); + if (bit_size == 16) + add_label((Label)((uint32_t)label_abs_fp16 | (uint32_t)label_neg_fp16)); + else + add_label((Label)((uint32_t)label_abs_fp32_64 | (uint32_t)label_neg_fp32_64)); temp = neg_abs_temp; } @@ -254,13 +279,19 @@ struct ssa_info { bool is_b2i() { return label & label_b2i; } - void set_fcanonicalize(Temp tmp) + void set_fcanonicalize(Temp tmp, unsigned bit_size) { - add_label(label_fcanonicalize); + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); + add_label(bit_size == 16 ? label_fcanonicalize_fp16 : label_fcanonicalize_fp32_64); temp = tmp; } - bool is_fcanonicalize() { return label & label_fcanonicalize; } + bool is_fcanonicalize(unsigned bit_size) + { + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); + return bit_size == 16 ? 
label & label_fcanonicalize_fp16 + : label & label_fcanonicalize_fp32_64; + } void set_canonicalized() { add_label(label_canonicalized); } @@ -2002,19 +2033,34 @@ parse_operand(opt_ctx& ctx, Temp tmp, alu_opt_op& op_info, aco_type& type) return true; } - // TODO use parent dst type - if (info.is_fcanonicalize() || info.is_abs() || info.is_neg()) { - if (ctx.info[info.temp.id()].is_canonicalized() || - (tmp.bytes() == 4 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == fp_denorm_keep) - type.base_type = aco_base_type_uint; - else - type.base_type = aco_base_type_float; - } else { - type.base_type = aco_base_type_uint; + for (unsigned bit_size = tmp.size() == 2 ? 64 : 16; bit_size <= tmp.bytes() * 8; bit_size *= 2) { // NOTE(review): probes each candidate float width, first labelled width wins; if tmp.bytes() * 8 can be 128 or more while tmp.size() != 2, bit_size reaches 128 and the is_*() asserts only accept 16/32/64 - confirm such wide temps never get here + if (info.is_fcanonicalize(bit_size) || info.is_abs(bit_size) || info.is_neg(bit_size)) { + type.num_components = 1; + type.bit_size = bit_size; + if (ctx.info[info.temp.id()].is_canonicalized() || + (bit_size == 32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == fp_denorm_keep) + type.base_type = aco_base_type_uint; + else + type.base_type = aco_base_type_float; + + op_info.op = Operand(info.temp); + if (info.is_abs(bit_size)) + op_info.abs[0] = true; + if (info.is_neg(bit_size)) + op_info.neg[0] = true; + return true; + } } + + type.base_type = aco_base_type_uint; type.num_components = 1; type.bit_size = tmp.bytes() * 8; + if (info.is_temp()) { + op_info.op = Operand(info.temp); + return true; + } + if (info.is_extract()) { op_info.extract[0] = parse_extract(info.parent_instr); op_info.op = info.parent_instr->operands[0]; @@ -2052,14 +2098,6 @@ parse_operand(opt_ctx& ctx, Temp tmp, alu_opt_op& op_info, aco_type& type) return true; } - if (info.is_temp() || info.is_fcanonicalize() || info.is_abs() || info.is_neg()) { - op_info.op = Operand(info.temp); - if (info.is_abs()) - op_info.abs[0] = true; - if (info.is_neg()) - op_info.neg[0] = true; - return true; - } return false; } @@ -2723,6 +2761,7 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) 
/* TODO: try to move the negate/abs modifier to the consumer instead */ bool uses_mods = instr->usesModifiers(); bool fp16 = instr->opcode == aco_opcode::v_mul_f16; + unsigned bit_size = fp16 ? 16 : 32; /* forwarded to the set_*() label helpers below so they pick the fp16 or the fp32_64 label */ unsigned denorm_mode = fp16 ? ctx.fp_mode.denorm16_64 : ctx.fp_mode.denorm32; for (unsigned i = 0; i < 2; i++) { @@ -2747,16 +2786,16 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) } if (abs && neg && other.type() == RegType::vgpr) - ctx.info[instr->definitions[0].tempId()].set_neg_abs(other); + ctx.info[instr->definitions[0].tempId()].set_neg_abs(other, bit_size); else if (abs && !neg && other.type() == RegType::vgpr) - ctx.info[instr->definitions[0].tempId()].set_abs(other); + ctx.info[instr->definitions[0].tempId()].set_abs(other, bit_size); else if (!abs && neg && other.type() == RegType::vgpr) - ctx.info[instr->definitions[0].tempId()].set_neg(other); + ctx.info[instr->definitions[0].tempId()].set_neg(other, bit_size); else if (!abs && !neg) { if (denorm_mode == fp_denorm_keep || ctx.info[other.id()].is_canonicalized()) ctx.info[instr->definitions[0].tempId()].set_temp(other); else - ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other); + ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other, bit_size); } } else if (uses_mods || (instr->definitions[0].isSZPreserve() && instr->opcode != aco_opcode::v_mul_legacy_f32)) { @@ -4444,7 +4483,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) * floats. 
*/ /* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */ - if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) && + if ((ctx.info[instr->definitions[0].tempId()].label & input_mod_labels) && ctx.uses[ctx.info[instr->definitions[0].tempId()].temp.id()] == 1) { Temp val = ctx.info[instr->definitions[0].tempId()].temp; Instruction* mul_instr = ctx.info[val.id()].parent_instr; @@ -4467,8 +4506,8 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) /* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */ ctx.uses[mul_instr->definitions[0].tempId()]--; Definition def = instr->definitions[0]; - bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg(); - bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs(); + bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg(def.bytes() * 8); + bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs(def.bytes() * 8); /* NOTE(review): the guard above matches any bit in input_mod_labels, but only the variant for def.bytes() * 8 is queried here; if they disagree both flags are false - confirm the definition size always matches the label's bit size */ uint32_t pass_flags = instr->pass_flags; Format format = mul_instr->format == Format::VOP2 ? asVOP3(Format::VOP2) : mul_instr->format; instr.reset(create_instruction(mul_instr->opcode, format, mul_instr->operands.size(), 1));