diff --git a/src/amd/compiler/aco_interface.cpp b/src/amd/compiler/aco_interface.cpp index 9a61c8d2433..062402d732b 100644 --- a/src/amd/compiler/aco_interface.cpp +++ b/src/amd/compiler/aco_interface.cpp @@ -478,6 +478,8 @@ aco_nir_op_supports_packed_math_16bit(const nir_alu_instr* alu) case nir_op_ishl: case nir_op_ishr: case nir_op_ushr: return true; + case nir_op_u2u16: + case nir_op_i2i16: return alu->src[0].src.ssa->bit_size == 8; default: return false; } } diff --git a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp index e5ca68fb2b5..eb5baa76db0 100644 --- a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp +++ b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp @@ -2964,33 +2964,38 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) } case nir_op_i2i8: case nir_op_i2i16: - case nir_op_i2i32: { - if (dst.type() == RegType::sgpr && instr->src[0].src.ssa->bit_size < 32) { - /* no need to do the extract in get_alu_src() */ - sgpr_extract_mode mode = instr->def.bit_size > instr->src[0].src.ssa->bit_size - ? sgpr_extract_sext - : sgpr_extract_undef; - extract_8_16_bit_sgpr_element(ctx, dst, &instr->src[0], mode); - } else { - const unsigned input_bitsize = instr->src[0].src.ssa->bit_size; - const unsigned output_bitsize = instr->def.bit_size; - convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]), input_bitsize, output_bitsize, - output_bitsize > input_bitsize, dst); - } - break; - } + case nir_op_i2i32: case nir_op_u2u8: case nir_op_u2u16: case nir_op_u2u32: { + const unsigned input_bitsize = instr->src[0].src.ssa->bit_size; + const unsigned output_bitsize = instr->def.bit_size; + bool sext = + instr->op == nir_op_i2i8 || instr->op == nir_op_i2i16 || instr->op == nir_op_i2i32; + bool trunc = output_bitsize <= input_bitsize; + + if (instr->def.num_components == 2) { + assert(output_bitsize == 16 && input_bitsize == 8); + assert((instr->src[0].swizzle[0] & ~0x3) == (instr->src[0].swizzle[1] & ~0x3)); + + Temp src = get_ssa_temp(ctx, instr->src[0].src.ssa); + if (src.bytes() >= 4) + src = emit_extract_vector(ctx, src, instr->src[0].swizzle[0] & ~0x3, v1); + + emit_pk_int16_from_8bit(ctx, dst, src, instr->src[0].swizzle[0] & 0x3, + instr->src[0].swizzle[1] & 0x3, sext); + break; + } + if (dst.type() == RegType::sgpr && instr->src[0].src.ssa->bit_size < 32) { /* no need to do the extract in get_alu_src() */ - sgpr_extract_mode mode = instr->def.bit_size > instr->src[0].src.ssa->bit_size - ? sgpr_extract_zext - : sgpr_extract_undef; + sgpr_extract_mode mode = trunc ? sgpr_extract_undef + : sext ? sgpr_extract_sext + : sgpr_extract_zext; extract_8_16_bit_sgpr_element(ctx, dst, &instr->src[0], mode); } else { - convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]), instr->src[0].src.ssa->bit_size, - instr->def.bit_size, false, dst); + convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]), input_bitsize, output_bitsize, + sext && !trunc, dst); } break; }