diff --git a/src/intel/compiler/brw_eu_validate.c b/src/intel/compiler/brw_eu_validate.c index 8f69d93b1f0..7d07875e92b 100644 --- a/src/intel/compiler/brw_eu_validate.c +++ b/src/intel/compiler/brw_eu_validate.c @@ -1085,6 +1085,10 @@ special_restrictions_for_mixed_float_mode(const struct brw_isa_info *isa, struct string error_msg = { .str = NULL, .len = 0 }; + /* See instruction_restrictions() for DPAS operand type validation. */ + if (inst->opcode == BRW_OPCODE_DPAS) + return error_msg; + ERROR_IF(is_pure_bfloat(inst), "Instructions with pure bfloat16 operands are not supported."); @@ -2191,14 +2195,31 @@ instruction_restrictions(const struct brw_isa_info *isa, if (brw_eu_inst_dpas_3src_exec_type(devinfo, inst->raw) == BRW_ALIGN1_3SRC_EXEC_TYPE_FLOAT) { - ERROR_IF(dst_type != BRW_TYPE_F, - "DPAS destination type must be F."); - ERROR_IF(src0_type != BRW_TYPE_F, - "DPAS src0 type must be F."); - ERROR_IF(src1_type != BRW_TYPE_HF, - "DPAS src1 type must be HF."); - ERROR_IF(src2_type != BRW_TYPE_HF, - "DPAS src2 type must be HF."); + if (devinfo->ver < 20) { + ERROR_IF(src1_type != BRW_TYPE_HF, + "DPAS src1 type must be HF in Gfx12."); + ERROR_IF(src2_type != BRW_TYPE_HF, + "DPAS src2 type must be HF in Gfx12."); + ERROR_IF(dst_type != BRW_TYPE_F, + "DPAS destination type must be F in Gfx12."); + ERROR_IF(src0_type != BRW_TYPE_F, + "DPAS src0 type must be F in Gfx12."); + } else { + ERROR_IF(src1_type != BRW_TYPE_HF && + src1_type != BRW_TYPE_BF, + "DPAS src1 type must be HF or BF in Gfx20+."); + ERROR_IF(src2_type != BRW_TYPE_HF && + src2_type != BRW_TYPE_BF, + "DPAS src2 type must be HF or BF in Gfx20+."); + ERROR_IF(src1_type != src2_type, + "DPAS src1 and src2 with types must match when using float types."); + ERROR_IF(dst_type != BRW_TYPE_F && + dst_type != src1_type, + "DPAS destination type must be F or match Src1/Src2 in Gfx20+."); + ERROR_IF(src0_type != BRW_TYPE_F && + src0_type != src1_type, + "DPAS src0 type must be F or match Src1/Src2 in Gfx20+."); + } } else { ERROR_IF(dst_type != BRW_TYPE_D && dst_type != BRW_TYPE_UD, diff --git a/src/intel/compiler/test_eu_validate.cpp b/src/intel/compiler/test_eu_validate.cpp index d81d8c617b5..361d0434100 100644 --- a/src/intel/compiler/test_eu_validate.cpp +++ b/src/intel/compiler/test_eu_validate.cpp @@ -3183,22 +3183,32 @@ TEST_P(validation_test, dpas_types) if (devinfo.verx10 < 125) return; + if (devinfo.ver >= 20) + assert(devinfo.has_bfloat16); + #define TV(a, b, c, d, r) \ { BRW_TYPE_ ## a, BRW_TYPE_ ## b, BRW_TYPE_ ## c, BRW_TYPE_ ## d, r } - static const struct { + const struct { brw_reg_type dst_type; brw_reg_type src0_type; brw_reg_type src1_type; brw_reg_type src2_type; + bool expected_result; } test_vectors[] = { TV( F, F, HF, HF, true), - TV( F, HF, HF, HF, false), - TV(HF, F, HF, HF, false), + TV(HF, HF, HF, HF, devinfo.ver >= 20), + TV( F, HF, HF, HF, devinfo.ver >= 20), + TV(HF, F, HF, HF, devinfo.ver >= 20), TV( F, F, F, HF, false), TV( F, F, HF, F, false), + TV( F, F, BF, BF, devinfo.ver >= 20), + TV(BF, BF, BF, BF, devinfo.ver >= 20), + TV(BF, F, BF, BF, devinfo.ver >= 20), + TV( F, BF, BF, BF, devinfo.ver >= 20), + TV(DF, DF, DF, DF, false), TV(DF, DF, DF, F, false), TV(DF, DF, F, DF, false), @@ -3252,16 +3262,41 @@ TEST_P(validation_test, dpas_types) : BRW_EXECUTE_8); for (unsigned i = 0; i < ARRAY_SIZE(test_vectors); i++) { + const auto &t = test_vectors[i]; + + /* We encode the instruction and then decode to validate. But our + * encode reasonably asserts BF types when unsupported. So skip those. + * + * TODO: Promote up the brw_eu_decoded_inst so that validation test can + * use those types too instead of encoding/decoding. + */ + if (!devinfo.has_bfloat16 && + (t.dst_type == BRW_TYPE_BF || + t.src0_type == BRW_TYPE_BF || + t.src1_type == BRW_TYPE_BF || + t.src2_type == BRW_TYPE_BF)) + continue; + brw_DPAS(p, BRW_SYSTOLIC_DEPTH_8, 8, - retype(brw_vec8_grf(0, 0), test_vectors[i].dst_type), - retype(brw_vec8_grf(16, 0), test_vectors[i].src0_type), - retype(brw_vec8_grf(32, 0), test_vectors[i].src1_type), - retype(brw_vec8_grf(48, 0), test_vectors[i].src2_type)); + retype(brw_vec8_grf(0, 0), t.dst_type), + retype(brw_vec8_grf(16, 0), t.src0_type), + retype(brw_vec8_grf(32, 0), t.src1_type), + retype(brw_vec8_grf(48, 0), t.src2_type)); - EXPECT_EQ(test_vectors[i].expected_result, validate(p)) << - "test vector index = " << i; + char *error = nullptr; + bool valid = validate(p, &error); + + if (t.expected_result) { + EXPECT_TRUE(valid) + << "Test vector index = " << i << " expected to succeed " + << "but failed validation with error: '" << error << "'."; + } else { + EXPECT_FALSE(valid) + << "Test vector index = " << i << " expected to " + << "fail validation but succeeded."; + } clear_instructions(p); }