brw: Expand EU validation for DPAS

Allow BFloat16 types when supported and allow destination/accumulator to
match the other source types in Gfx20+.

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34035>
This commit is contained in:
Caio Oliveira 2025-03-12 13:08:17 -07:00 committed by Marge Bot
parent 2b2132d2ac
commit e384ccde28
2 changed files with 73 additions and 17 deletions

View file

@ -1085,6 +1085,10 @@ special_restrictions_for_mixed_float_mode(const struct brw_isa_info *isa,
struct string error_msg = { .str = NULL, .len = 0 };
/* See instruction_restrictions() for DPAS operand type validation. */
if (inst->opcode == BRW_OPCODE_DPAS)
return error_msg;
ERROR_IF(is_pure_bfloat(inst),
"Instructions with pure bfloat16 operands are not supported.");
@ -2191,14 +2195,31 @@ instruction_restrictions(const struct brw_isa_info *isa,
if (brw_eu_inst_dpas_3src_exec_type(devinfo, inst->raw) ==
BRW_ALIGN1_3SRC_EXEC_TYPE_FLOAT) {
ERROR_IF(dst_type != BRW_TYPE_F,
"DPAS destination type must be F.");
ERROR_IF(src0_type != BRW_TYPE_F,
"DPAS src0 type must be F.");
ERROR_IF(src1_type != BRW_TYPE_HF,
"DPAS src1 type must be HF.");
ERROR_IF(src2_type != BRW_TYPE_HF,
"DPAS src2 type must be HF.");
if (devinfo->ver < 20) {
ERROR_IF(src1_type != BRW_TYPE_HF,
"DPAS src1 type must be HF in Gfx12.");
ERROR_IF(src2_type != BRW_TYPE_HF,
"DPAS src2 type must be HF in Gfx12.");
ERROR_IF(dst_type != BRW_TYPE_F,
"DPAS destination type must be F in Gfx12.");
ERROR_IF(src0_type != BRW_TYPE_F,
"DPAS src0 type must be F in Gfx12.");
} else {
ERROR_IF(src1_type != BRW_TYPE_HF &&
src1_type != BRW_TYPE_BF,
"DPAS src1 type must be HF or BF in Gfx20+.");
ERROR_IF(src2_type != BRW_TYPE_HF &&
src2_type != BRW_TYPE_BF,
"DPAS src2 type must be HF or BF in Gfx20+.");
ERROR_IF(src1_type != src2_type,
"DPAS src1 and src2 with types must match when using float types.");
ERROR_IF(dst_type != BRW_TYPE_F &&
dst_type != src1_type,
"DPAS destination type must be F or match Src1/Src2 in Gfx20+.");
ERROR_IF(src0_type != BRW_TYPE_F &&
src0_type != src1_type,
"DPAS src0 type must be F or match Src1/Src2 in Gfx20+.");
}
} else {
ERROR_IF(dst_type != BRW_TYPE_D &&
dst_type != BRW_TYPE_UD,

View file

@ -3183,22 +3183,32 @@ TEST_P(validation_test, dpas_types)
if (devinfo.verx10 < 125)
return;
if (devinfo.ver >= 20)
assert(devinfo.has_bfloat16);
#define TV(a, b, c, d, r) \
{ BRW_TYPE_ ## a, BRW_TYPE_ ## b, BRW_TYPE_ ## c, BRW_TYPE_ ## d, r }
static const struct {
const struct {
brw_reg_type dst_type;
brw_reg_type src0_type;
brw_reg_type src1_type;
brw_reg_type src2_type;
bool expected_result;
} test_vectors[] = {
TV( F, F, HF, HF, true),
TV( F, HF, HF, HF, false),
TV(HF, F, HF, HF, false),
TV(HF, HF, HF, HF, devinfo.ver >= 20),
TV( F, HF, HF, HF, devinfo.ver >= 20),
TV(HF, F, HF, HF, devinfo.ver >= 20),
TV( F, F, F, HF, false),
TV( F, F, HF, F, false),
TV( F, F, BF, BF, devinfo.ver >= 20),
TV(BF, BF, BF, BF, devinfo.ver >= 20),
TV(BF, F, BF, BF, devinfo.ver >= 20),
TV( F, BF, BF, BF, devinfo.ver >= 20),
TV(DF, DF, DF, DF, false),
TV(DF, DF, DF, F, false),
TV(DF, DF, F, DF, false),
@ -3252,16 +3262,41 @@ TEST_P(validation_test, dpas_types)
: BRW_EXECUTE_8);
for (unsigned i = 0; i < ARRAY_SIZE(test_vectors); i++) {
const auto &t = test_vectors[i];
/* We encode the instruction and then decode to validate. But our
* encode reasonably asserts BF types when unsupported. So skip those.
*
* TODO: Promote up the brw_eu_decoded_inst so that validation test can
* use those types too instead of encoding/decoding.
*/
if (!devinfo.has_bfloat16 &&
(t.dst_type == BRW_TYPE_BF ||
t.src0_type == BRW_TYPE_BF ||
t.src1_type == BRW_TYPE_BF ||
t.src2_type == BRW_TYPE_BF))
continue;
brw_DPAS(p,
BRW_SYSTOLIC_DEPTH_8,
8,
retype(brw_vec8_grf(0, 0), test_vectors[i].dst_type),
retype(brw_vec8_grf(16, 0), test_vectors[i].src0_type),
retype(brw_vec8_grf(32, 0), test_vectors[i].src1_type),
retype(brw_vec8_grf(48, 0), test_vectors[i].src2_type));
retype(brw_vec8_grf(0, 0), t.dst_type),
retype(brw_vec8_grf(16, 0), t.src0_type),
retype(brw_vec8_grf(32, 0), t.src1_type),
retype(brw_vec8_grf(48, 0), t.src2_type));
EXPECT_EQ(test_vectors[i].expected_result, validate(p)) <<
"test vector index = " << i;
char *error = nullptr;
bool valid = validate(p, &error);
if (t.expected_result) {
EXPECT_TRUE(valid)
<< "Test vector index = " << i << " expected to succeed "
<< "but failed validation with error: '" << error << "'.";
} else {
EXPECT_FALSE(valid)
<< "Test vector index = " << i << " expected to "
<< "fail validation but succeeded.";
}
clear_instructions(p);
}