brw: Expand EU validation for DPAS

Allow BFloat16 types when supported and allow destination/accumulator to
match the other source types in Gfx20+.

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34035>
This commit is contained in:
Caio Oliveira 2025-03-12 13:08:17 -07:00 committed by Marge Bot
parent 2b2132d2ac
commit e384ccde28
2 changed files with 73 additions and 17 deletions

View file

@ -1085,6 +1085,10 @@ special_restrictions_for_mixed_float_mode(const struct brw_isa_info *isa,
struct string error_msg = { .str = NULL, .len = 0 }; struct string error_msg = { .str = NULL, .len = 0 };
/* See instruction_restrictions() for DPAS operand type validation. */
if (inst->opcode == BRW_OPCODE_DPAS)
return error_msg;
ERROR_IF(is_pure_bfloat(inst), ERROR_IF(is_pure_bfloat(inst),
"Instructions with pure bfloat16 operands are not supported."); "Instructions with pure bfloat16 operands are not supported.");
@ -2191,14 +2195,31 @@ instruction_restrictions(const struct brw_isa_info *isa,
if (brw_eu_inst_dpas_3src_exec_type(devinfo, inst->raw) == if (brw_eu_inst_dpas_3src_exec_type(devinfo, inst->raw) ==
BRW_ALIGN1_3SRC_EXEC_TYPE_FLOAT) { BRW_ALIGN1_3SRC_EXEC_TYPE_FLOAT) {
ERROR_IF(dst_type != BRW_TYPE_F, if (devinfo->ver < 20) {
"DPAS destination type must be F."); ERROR_IF(src1_type != BRW_TYPE_HF,
ERROR_IF(src0_type != BRW_TYPE_F, "DPAS src1 type must be HF in Gfx12.");
"DPAS src0 type must be F."); ERROR_IF(src2_type != BRW_TYPE_HF,
ERROR_IF(src1_type != BRW_TYPE_HF, "DPAS src2 type must be HF in Gfx12.");
"DPAS src1 type must be HF."); ERROR_IF(dst_type != BRW_TYPE_F,
ERROR_IF(src2_type != BRW_TYPE_HF, "DPAS destination type must be F in Gfx12.");
"DPAS src2 type must be HF."); ERROR_IF(src0_type != BRW_TYPE_F,
"DPAS src0 type must be F in Gfx12.");
} else {
ERROR_IF(src1_type != BRW_TYPE_HF &&
src1_type != BRW_TYPE_BF,
"DPAS src1 type must be HF or BF in Gfx20+.");
ERROR_IF(src2_type != BRW_TYPE_HF &&
src2_type != BRW_TYPE_BF,
"DPAS src2 type must be HF or BF in Gfx20+.");
ERROR_IF(src1_type != src2_type,
"DPAS src1 and src2 with types must match when using float types.");
ERROR_IF(dst_type != BRW_TYPE_F &&
dst_type != src1_type,
"DPAS destination type must be F or match Src1/Src2 in Gfx20+.");
ERROR_IF(src0_type != BRW_TYPE_F &&
src0_type != src1_type,
"DPAS src0 type must be F or match Src1/Src2 in Gfx20+.");
}
} else { } else {
ERROR_IF(dst_type != BRW_TYPE_D && ERROR_IF(dst_type != BRW_TYPE_D &&
dst_type != BRW_TYPE_UD, dst_type != BRW_TYPE_UD,

View file

@ -3183,22 +3183,32 @@ TEST_P(validation_test, dpas_types)
if (devinfo.verx10 < 125) if (devinfo.verx10 < 125)
return; return;
if (devinfo.ver >= 20)
assert(devinfo.has_bfloat16);
#define TV(a, b, c, d, r) \ #define TV(a, b, c, d, r) \
{ BRW_TYPE_ ## a, BRW_TYPE_ ## b, BRW_TYPE_ ## c, BRW_TYPE_ ## d, r } { BRW_TYPE_ ## a, BRW_TYPE_ ## b, BRW_TYPE_ ## c, BRW_TYPE_ ## d, r }
static const struct { const struct {
brw_reg_type dst_type; brw_reg_type dst_type;
brw_reg_type src0_type; brw_reg_type src0_type;
brw_reg_type src1_type; brw_reg_type src1_type;
brw_reg_type src2_type; brw_reg_type src2_type;
bool expected_result; bool expected_result;
} test_vectors[] = { } test_vectors[] = {
TV( F, F, HF, HF, true), TV( F, F, HF, HF, true),
TV( F, HF, HF, HF, false), TV(HF, HF, HF, HF, devinfo.ver >= 20),
TV(HF, F, HF, HF, false), TV( F, HF, HF, HF, devinfo.ver >= 20),
TV(HF, F, HF, HF, devinfo.ver >= 20),
TV( F, F, F, HF, false), TV( F, F, F, HF, false),
TV( F, F, HF, F, false), TV( F, F, HF, F, false),
TV( F, F, BF, BF, devinfo.ver >= 20),
TV(BF, BF, BF, BF, devinfo.ver >= 20),
TV(BF, F, BF, BF, devinfo.ver >= 20),
TV( F, BF, BF, BF, devinfo.ver >= 20),
TV(DF, DF, DF, DF, false), TV(DF, DF, DF, DF, false),
TV(DF, DF, DF, F, false), TV(DF, DF, DF, F, false),
TV(DF, DF, F, DF, false), TV(DF, DF, F, DF, false),
@ -3252,16 +3262,41 @@ TEST_P(validation_test, dpas_types)
: BRW_EXECUTE_8); : BRW_EXECUTE_8);
for (unsigned i = 0; i < ARRAY_SIZE(test_vectors); i++) { for (unsigned i = 0; i < ARRAY_SIZE(test_vectors); i++) {
const auto &t = test_vectors[i];
/* We encode the instruction and then decode to validate. But our
* encode reasonably asserts BF types when unsupported. So skip those.
*
* TODO: Promote up the brw_eu_decoded_inst so that validation test can
* use those types too instead of encoding/decoding.
*/
if (!devinfo.has_bfloat16 &&
(t.dst_type == BRW_TYPE_BF ||
t.src0_type == BRW_TYPE_BF ||
t.src1_type == BRW_TYPE_BF ||
t.src2_type == BRW_TYPE_BF))
continue;
brw_DPAS(p, brw_DPAS(p,
BRW_SYSTOLIC_DEPTH_8, BRW_SYSTOLIC_DEPTH_8,
8, 8,
retype(brw_vec8_grf(0, 0), test_vectors[i].dst_type), retype(brw_vec8_grf(0, 0), t.dst_type),
retype(brw_vec8_grf(16, 0), test_vectors[i].src0_type), retype(brw_vec8_grf(16, 0), t.src0_type),
retype(brw_vec8_grf(32, 0), test_vectors[i].src1_type), retype(brw_vec8_grf(32, 0), t.src1_type),
retype(brw_vec8_grf(48, 0), test_vectors[i].src2_type)); retype(brw_vec8_grf(48, 0), t.src2_type));
EXPECT_EQ(test_vectors[i].expected_result, validate(p)) << char *error = nullptr;
"test vector index = " << i; bool valid = validate(p, &error);
if (t.expected_result) {
EXPECT_TRUE(valid)
<< "Test vector index = " << i << " expected to succeed "
<< "but failed validation with error: '" << error << "'.";
} else {
EXPECT_FALSE(valid)
<< "Test vector index = " << i << " expected to "
<< "fail validation but succeeded.";
}
clear_instructions(p); clear_instructions(p);
} }