diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 2980b757143..8e31caddf0e 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -3930,184 +3930,6 @@ combine_add_bcnt(opt_ctx& ctx, aco_ptr& instr) return false; } -bool -get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3, - aco_opcode* med3, aco_opcode* minmax, bool* some_gfx9_only) -{ - switch (op) { -#define MINMAX(type, gfx9) \ - case aco_opcode::v_min_##type: \ - case aco_opcode::v_max_##type: \ - *min = aco_opcode::v_min_##type; \ - *max = aco_opcode::v_max_##type; \ - *med3 = aco_opcode::v_med3_##type; \ - *min3 = aco_opcode::v_min3_##type; \ - *max3 = aco_opcode::v_max3_##type; \ - *minmax = op == *min ? aco_opcode::v_maxmin_##type : aco_opcode::v_minmax_##type; \ - *some_gfx9_only = gfx9; \ - return true; -#define MINMAX_INT16(type, gfx9) \ - case aco_opcode::v_min_##type: \ - case aco_opcode::v_max_##type: \ - *min = aco_opcode::v_min_##type; \ - *max = aco_opcode::v_max_##type; \ - *med3 = aco_opcode::v_med3_##type; \ - *min3 = aco_opcode::v_min3_##type; \ - *max3 = aco_opcode::v_max3_##type; \ - *minmax = aco_opcode::num_opcodes; \ - *some_gfx9_only = gfx9; \ - return true; -#define MINMAX_INT16_E64(type, gfx9) \ - case aco_opcode::v_min_##type##_e64: \ - case aco_opcode::v_max_##type##_e64: \ - *min = aco_opcode::v_min_##type##_e64; \ - *max = aco_opcode::v_max_##type##_e64; \ - *med3 = aco_opcode::v_med3_##type; \ - *min3 = aco_opcode::v_min3_##type; \ - *max3 = aco_opcode::v_max3_##type; \ - *minmax = aco_opcode::num_opcodes; \ - *some_gfx9_only = gfx9; \ - return true; - MINMAX(f32, false) - MINMAX(u32, false) - MINMAX(i32, false) - MINMAX(f16, true) - MINMAX_INT16(u16, true) - MINMAX_INT16(i16, true) - MINMAX_INT16_E64(u16, true) - MINMAX_INT16_E64(i16, true) -#undef MINMAX_INT16_E64 -#undef MINMAX_INT16 -#undef MINMAX - default: return false; - } -} - -/* when ub > lb: - * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub) - * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub) - */ -bool -combine_clamp(opt_ctx& ctx, aco_ptr& instr, aco_opcode min, aco_opcode max, - aco_opcode med) -{ - /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's - * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if - * minVal > maxVal, which means we can always select it to a v_med3_f32 */ - aco_opcode other_op; - if (instr->opcode == min) - other_op = max; - else if (instr->opcode == max) - other_op = min; - else - return false; - - for (unsigned swap = 0; swap < 2; swap++) { - Operand operands[3]; - bool clamp, precise; - bitarray8 opsel = 0, neg = 0, abs = 0; - uint8_t omod = 0; - if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg, - abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) { - /* max(min(src, upper), lower) returns upper if src is NaN, but - * med3(src, lower, upper) returns lower. - */ - if (precise && instr->opcode != min && - (min == aco_opcode::v_min_f16 || min == aco_opcode::v_min_f32)) - continue; - - int const0_idx = -1, const1_idx = -1; - uint32_t const0 = 0, const1 = 0; - for (int i = 0; i < 3; i++) { - uint32_t val; - bool hi16 = opsel & (1 << i); - if (operands[i].isConstant()) { - val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue(); - } else if (operands[i].isTemp() && ctx.info[operands[i].tempId()].is_constant()) { - val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0); - } else { - continue; - } - if (const0_idx >= 0) { - const1_idx = i; - const1 = val; - } else { - const0_idx = i; - const0 = val; - } - } - if (const0_idx < 0 || const1_idx < 0) - continue; - - int lower_idx = const0_idx; - switch (min) { - case aco_opcode::v_min_f32: - case aco_opcode::v_min_f16: { - float const0_f, const1_f; - if (min == aco_opcode::v_min_f32) { - memcpy(&const0_f, &const0, 4); - memcpy(&const1_f, &const1, 4); - } else { - const0_f = _mesa_half_to_float(const0); - const1_f = _mesa_half_to_float(const1); - } - if (abs[const0_idx]) - const0_f = fabsf(const0_f); - if (abs[const1_idx]) - const1_f = fabsf(const1_f); - if (neg[const0_idx]) - const0_f = -const0_f; - if (neg[const1_idx]) - const1_f = -const1_f; - lower_idx = const0_f < const1_f ? const0_idx : const1_idx; - break; - } - case aco_opcode::v_min_u32: { - lower_idx = const0 < const1 ? const0_idx : const1_idx; - break; - } - case aco_opcode::v_min_u16: - case aco_opcode::v_min_u16_e64: { - lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx; - break; - } - case aco_opcode::v_min_i32: { - int32_t const0_i = - const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0; - int32_t const1_i = - const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1; - lower_idx = const0_i < const1_i ? const0_idx : const1_idx; - break; - } - case aco_opcode::v_min_i16: - case aco_opcode::v_min_i16_e64: { - int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0; - int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1; - lower_idx = const0_i < const1_i ? const0_idx : const1_idx; - break; - } - default: break; - } - int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx; - - if (instr->opcode == min) { - if (upper_idx != 0 || lower_idx == 0) - return false; - } else { - if (upper_idx == 0 || lower_idx != 0) - return false; - } - - ctx.uses[instr->operands[swap].tempId()]--; - create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod); - - return true; - } - } - - return false; -} - bool interp_can_become_fma(opt_ctx& ctx, aco_ptr& instr) { @@ -4780,6 +4602,41 @@ create_fma_cb(opt_ctx& ctx, alu_opt_info& info) return false; } +template +bool +create_med3_cb(opt_ctx& ctx, alu_opt_info& info) +{ + aco_type type = instr_info.alu_opcode_infos[(int)info.opcode].def_types[0]; + + /* NaN correctness needs max first, then min. */ + if (!max_first && type.base_type == aco_base_type_float && info.defs[0].isPrecise()) + return false; + + uint64_t upper = 0; + uint64_t lower = 0; + + if (!op_info_get_constant(ctx, info.operands[0], type, &upper)) + return false; + + if (!op_info_get_constant(ctx, info.operands[1], type, &lower) && + !op_info_get_constant(ctx, info.operands[2], type, &lower)) + return false; + + if (!max_first) + std::swap(upper, lower); + + switch (info.opcode) { + case aco_opcode::v_med3_f32: return uif(lower) <= uif(upper); + case aco_opcode::v_med3_f16: return _mesa_half_to_float(lower) <= _mesa_half_to_float(upper); + case aco_opcode::v_med3_u32: return uint32_t(lower) <= uint32_t(upper); + case aco_opcode::v_med3_u16: return uint16_t(lower) <= uint16_t(upper); + case aco_opcode::v_med3_i32: return int32_t(lower) <= int32_t(upper); + case aco_opcode::v_med3_i16: return int16_t(lower) <= int16_t(upper); + default: UNREACHABLE("invalid clamp"); + } + return false; +} + bool is_mul(Instruction* instr) { @@ -5003,14 +4860,6 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) combine_sabsdiff(ctx, instr); } else if (instr->opcode == aco_opcode::v_and_b32) { combine_v_andor_not(ctx, instr); - } else { - aco_opcode min, max, min3, max3, med3, minmax; - bool some_gfx9_only; - if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &minmax, - &some_gfx9_only) && - (!some_gfx9_only || ctx.program->gfx_level >= GFX9)) { - combine_clamp(ctx, instr, min, max, med3); - } } alu_opt_info info; @@ -5060,50 +4909,74 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) add_opt(v_max_f32, v_max3_f32, 0x3, "120", nullptr, true); if (ctx.program->gfx_level >= GFX11) add_opt(v_min_f32, v_minmax_f32, 0x3, "120", nullptr, true); + else + add_opt(v_min_f32, v_med3_f32, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_min_f32) { add_opt(v_min_f32, v_min3_f32, 0x3, "120", nullptr, true); if (ctx.program->gfx_level >= GFX11) add_opt(v_max_f32, v_maxmin_f32, 0x3, "120", nullptr, true); + else + add_opt(v_max_f32, v_med3_f32, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_max_u32) { add_opt(v_max_u32, v_max3_u32, 0x3, "120", nullptr, true); if (ctx.program->gfx_level >= GFX11) add_opt(v_min_u32, v_minmax_u32, 0x3, "120", nullptr, true); + else + add_opt(v_min_u32, v_med3_u32, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_min_u32) { add_opt(v_min_u32, v_min3_u32, 0x3, "120", nullptr, true); if (ctx.program->gfx_level >= GFX11) add_opt(v_max_u32, v_maxmin_u32, 0x3, "120", nullptr, true); + else + add_opt(v_max_u32, v_med3_u32, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_max_i32) { add_opt(v_max_i32, v_max3_i32, 0x3, "120", nullptr, true); if (ctx.program->gfx_level >= GFX11) add_opt(v_min_i32, v_minmax_i32, 0x3, "120", nullptr, true); + else + add_opt(v_min_i32, v_med3_i32, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_min_i32) { add_opt(v_min_i32, v_min3_i32, 0x3, "120", nullptr, true); if (ctx.program->gfx_level >= GFX11) add_opt(v_max_i32, v_maxmin_i32, 0x3, "120", nullptr, true); + else + add_opt(v_max_i32, v_med3_i32, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_max_f16 && ctx.program->gfx_level >= GFX9) { add_opt(v_max_f16, v_max3_f16, 0x3, "120", nullptr, true); if (ctx.program->gfx_level >= GFX11) add_opt(v_min_f16, v_minmax_f16, 0x3, "120", nullptr, true); + else + add_opt(v_min_f16, v_med3_f16, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_min_f16 && ctx.program->gfx_level >= GFX9) { add_opt(v_min_f16, v_min3_f16, 0x3, "120", nullptr, true); if (ctx.program->gfx_level >= GFX11) add_opt(v_max_f16, v_maxmin_f16, 0x3, "120", nullptr, true); + else + add_opt(v_max_f16, v_med3_f16, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_max_u16 && ctx.program->gfx_level >= GFX9) { add_opt(v_max_u16, v_max3_u16, 0x3, "120", nullptr, true); + add_opt(v_min_u16, v_med3_u16, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_min_u16 && ctx.program->gfx_level >= GFX9) { add_opt(v_min_u16, v_min3_u16, 0x3, "120", nullptr, true); + add_opt(v_max_u16, v_med3_u16, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_max_i16 && ctx.program->gfx_level >= GFX9) { add_opt(v_max_i16, v_max3_i16, 0x3, "120", nullptr, true); + add_opt(v_min_i16, v_med3_i16, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_min_i16 && ctx.program->gfx_level >= GFX9) { add_opt(v_min_i16, v_min3_i16, 0x3, "120", nullptr, true); + add_opt(v_max_i16, v_med3_i16, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_max_u16_e64) { add_opt(v_max_u16_e64, v_max3_u16, 0x3, "120", nullptr, true); + add_opt(v_min_u16_e64, v_med3_u16, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_min_u16_e64) { add_opt(v_min_u16_e64, v_min3_u16, 0x3, "120", nullptr, true); + add_opt(v_max_u16_e64, v_med3_u16, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_max_i16_e64) { add_opt(v_max_i16_e64, v_max3_i16, 0x3, "120", nullptr, true); + add_opt(v_min_i16_e64, v_med3_i16, 0x3, "012", create_med3_cb, true); } else if (info.opcode == aco_opcode::v_min_i16_e64) { add_opt(v_min_i16_e64, v_min3_i16, 0x3, "120", nullptr, true); + add_opt(v_max_i16_e64, v_med3_i16, 0x3, "012", create_med3_cb, true); } if (match_and_apply_patterns(ctx, info, patterns)) {