brw: Add BRW_TYPE_BF validation

Reviewed-by: Rohan Garg <rohan.garg@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33664>
This commit is contained in:
Caio Oliveira 2024-09-17 09:41:58 -07:00 committed by Marge Bot
parent 9916cc1050
commit 62323a934b
3 changed files with 317 additions and 16 deletions

View file

@ -425,8 +425,9 @@ execution_type_for_type(enum brw_reg_type type)
case BRW_TYPE_V:
case BRW_TYPE_UV:
return BRW_TYPE_W;
default:
unreachable("invalid type");
return BRW_TYPE_INVALID;
}
}
@ -564,6 +565,43 @@ is_mixed_float(const brw_hw_decoded_inst *inst)
types_are_mixed_float(src1_type, dst_type);
}
static bool
is_pure_bfloat(const brw_hw_decoded_inst *inst)
{
if (inst_is_send(inst))
return false;
if (inst->num_sources == 0 && !inst->has_dst)
return false;
for (int i = 0; i < inst->num_sources; i++) {
if (!brw_type_is_bfloat(inst->src[i].type))
return false;
}
if (inst->has_dst && !brw_type_is_bfloat(inst->dst.type))
return false;
return true;
}
static bool
is_mixed_bfloat(const brw_hw_decoded_inst *inst)
{
if (inst_is_send(inst))
return false;
const int operands = inst->num_sources + inst->has_dst;
int bfloat = 0;
for (int i = 0; i < inst->num_sources; i++)
bfloat += brw_type_is_bfloat(inst->src[i].type);
if (inst->has_dst)
bfloat += brw_type_is_bfloat(inst->dst.type);
return bfloat > 0 && bfloat != operands;
}
/**
* Returns whether an instruction is an explicit or implicit conversion
* to/from byte.
@ -624,6 +662,10 @@ general_restrictions_based_on_operand_types(const struct brw_isa_info *isa,
enum brw_reg_type dst_type = inst->dst.type;
ERROR_IF(brw_type_is_bfloat(dst_type) &&
!devinfo->has_bfloat16,
"Bfloat destination, but platform does not support it");
ERROR_IF(dst_type == BRW_TYPE_DF &&
!devinfo->has_64bit_float,
"64-bit float destination, but platform does not support it");
@ -636,6 +678,10 @@ general_restrictions_based_on_operand_types(const struct brw_isa_info *isa,
for (unsigned s = 0; s < inst->num_sources; s++) {
enum brw_reg_type src_type = inst->src[s].type;
ERROR_IF(brw_type_is_bfloat(src_type) &&
!devinfo->has_bfloat16,
"Bfloat source, but platform does not support it");
ERROR_IF(src_type == BRW_TYPE_DF &&
!devinfo->has_64bit_float,
"64-bit float source, but platform does not support it");
@ -842,7 +888,7 @@ general_restrictions_based_on_operand_types(const struct brw_isa_info *isa,
* override the general rule for the ratio of sizes of the destination type
* and the execution type. We will add validation for those in a later patch.
*/
bool validate_dst_size_and_exec_size_ratio = !is_mixed_float(inst);
bool validate_dst_size_and_exec_size_ratio = !is_mixed_float(inst) && !is_mixed_bfloat(inst);
if (validate_dst_size_and_exec_size_ratio &&
exec_type_size > dst_type_size) {
@ -1016,6 +1062,21 @@ general_restrictions_on_region_parameters(const struct brw_isa_info *isa,
return error_msg;
}
static bool
is_multiplier_instruction(const brw_hw_decoded_inst *inst)
{
/* TODO: Complete this list. */
switch (inst->opcode) {
case BRW_OPCODE_MUL:
case BRW_OPCODE_MAC:
case BRW_OPCODE_MACH:
case BRW_OPCODE_MAD:
return true;
default:
return false;
}
}
static struct string
special_restrictions_for_mixed_float_mode(const struct brw_isa_info *isa,
const brw_hw_decoded_inst *inst)
@ -1024,6 +1085,60 @@ special_restrictions_for_mixed_float_mode(const struct brw_isa_info *isa,
struct string error_msg = { .str = NULL, .len = 0 };
ERROR_IF(is_pure_bfloat(inst),
"Instructions with pure bfloat16 operands are not supported.");
if (is_mixed_bfloat(inst)) {
ERROR_IF(devinfo->ver < 20 && inst->exec_size > 8,
"Execution size must not be greater than 8 in Gfx12.");
ERROR_IF(devinfo->ver >= 20 && inst->exec_size > 16,
"Execution size must not be greater than 8 in Gfx20+.");
for (int i = 0; i < inst->num_sources; i++) {
ERROR_IF(brw_type_is_bfloat(inst->src[i].type) &&
src_has_scalar_region(inst, i),
"Broadcast of bfloat16 scalar is not supported.");
}
if (is_multiplier_instruction(inst)) {
if (inst->num_sources == 2) {
ERROR_IF(brw_type_is_bfloat(inst->src[1].type),
"Bfloat16 not allowed in Src1 of 2-source instructions involving multiplier.");
} else if (inst->num_sources == 3) {
ERROR_IF(brw_type_is_bfloat(inst->src[2].type),
"Bfloat16 not allowed in Src2 of 3-source instructions involving multiplier.");
}
}
const unsigned half_offset = REG_SIZE * reg_unit(devinfo) / 2;
if (inst->has_dst && brw_type_is_bfloat(inst->dst.type)) {
unsigned dst_stride = inst->dst.hstride;
bool dst_is_packed = is_packed(inst->exec_size * dst_stride, inst->exec_size, dst_stride);
if (dst_is_packed) {
ERROR_IF(inst->dst.subnr != 0 && inst->dst.subnr != half_offset,
"Packed bfloat16 destination must have register offset 0 or half of GRF register.");
} else {
/* Offset in the restriction text is in terms of elements. */
const unsigned elem_size = brw_type_size_bytes(inst->dst.type);
ERROR_IF(dst_stride != 2 || (inst->dst.subnr != 0 &&
inst->dst.subnr != 1 * elem_size),
"Unpacked bfloat16 destination must have stride 2 and register offset 0 or 1.");
}
}
for (int i = 0; i < inst->num_sources; i++) {
if (brw_type_is_bfloat(inst->src[i].type)) {
bool src_is_packed = is_packed(inst->src[i].vstride, inst->src[i].width, inst->src[i].hstride);
ERROR_IF(!src_is_packed,
"Bfloat16 source must be packed");
ERROR_IF(inst->src[i].subnr != 0 && inst->src[i].subnr != half_offset,
"Bfloat16 source must have register offset 0 or half of GRF register.");
}
}
}
const unsigned opcode = inst->opcode;
if (inst->num_sources >= 3)
return error_msg;
@ -1640,13 +1755,18 @@ special_requirements_for_handling_double_precision_data_types(
*
* "Vx1 and VxH indirect addressing for Float, Half-Float, Double-Float and
* Quad-Word data must not be used."
*
* and
*
* "Vx1 and VxH indirect addressing for BFloat16 (...) data
* must not be used."
*/
if (devinfo->verx10 >= 125 &&
(brw_type_is_float(type) || brw_type_size_bytes(type) == 8)) {
(brw_type_is_float_or_bfloat(type) || brw_type_size_bytes(type) == 8)) {
ERROR_IF(address_mode == BRW_ADDRESS_REGISTER_INDIRECT_REGISTER &&
vstride == BRW_VERTICAL_STRIDE_ONE_DIMENSIONAL,
"Vx1 and VxH indirect addressing for Float, Half-Float, "
"Double-Float and Quad-Word data must not be used");
"Double-Float, Quad-Word, and Bfloat16 data must not be used");
}
}
@ -2442,6 +2562,18 @@ VSTRIDE_3SRC(unsigned vstride)
unreachable("invalid vstride");
}
static inline unsigned
brw_implied_width_for_3src_a1(unsigned v, unsigned h)
{
/* "Regioning Rules for Align1 Ternary Operations" */
/* TODO: Add remaining rules and de-duplicate with brw_disasm.c */
if (v == 0) return 1;
if (h == 0) return v;
return v/h;
}
static struct string
brw_hw_decode_inst(const struct brw_isa_info *isa,
brw_hw_decoded_inst *inst,
@ -2629,6 +2761,7 @@ brw_hw_decode_inst(const struct brw_isa_info *isa,
inst->src[0].subnr = brw_eu_inst_3src_a1_src0_subreg_nr(devinfo, raw);
inst->src[0].vstride = VSTRIDE_3SRC(brw_eu_inst_3src_a1_src0_vstride(devinfo, raw));
inst->src[0].hstride = STRIDE(brw_eu_inst_3src_a1_src0_hstride(devinfo, raw));
inst->src[0].width = brw_implied_width_for_3src_a1(inst->src[0].vstride, inst->src[0].hstride);
}
inst->src[1].file = brw_eu_inst_3src_a1_src1_reg_file(devinfo, raw);
@ -2639,6 +2772,7 @@ brw_hw_decode_inst(const struct brw_isa_info *isa,
inst->src[1].subnr = brw_eu_inst_3src_a1_src1_subreg_nr(devinfo, raw);
inst->src[1].vstride = VSTRIDE_3SRC(brw_eu_inst_3src_a1_src1_vstride(devinfo, raw));
inst->src[1].hstride = STRIDE(brw_eu_inst_3src_a1_src1_hstride(devinfo, raw));
inst->src[1].width = brw_implied_width_for_3src_a1(inst->src[1].vstride, inst->src[1].hstride);
inst->src[2].file = brw_eu_inst_3src_a1_src2_reg_file(devinfo, raw);
inst->src[2].type = brw_eu_inst_3src_a1_src2_type(devinfo, raw);
@ -2648,6 +2782,7 @@ brw_hw_decode_inst(const struct brw_isa_info *isa,
inst->src[2].nr = brw_eu_inst_3src_src2_reg_nr(devinfo, raw);
inst->src[2].subnr = brw_eu_inst_3src_a1_src2_subreg_nr(devinfo, raw);
inst->src[2].hstride = STRIDE(brw_eu_inst_3src_a1_src2_hstride(devinfo, raw));
inst->src[2].width = brw_implied_width_for_3src_a1(inst->src[2].vstride, inst->src[2].hstride);
}
} else {

View file

@ -345,9 +345,9 @@ brw_validate(const brw_shader &s)
brw_type_is_int(inst->src[1].type) +
brw_type_is_int(inst->src[2].type);
const unsigned float_sources =
brw_type_is_float(inst->src[0].type) +
brw_type_is_float(inst->src[1].type) +
brw_type_is_float(inst->src[2].type);
brw_type_is_float_or_bfloat(inst->src[0].type) +
brw_type_is_float_or_bfloat(inst->src[1].type) +
brw_type_is_float_or_bfloat(inst->src[2].type);
fsv_assert((integer_sources == 3 && float_sources == 0) ||
(integer_sources == 0 && float_sources == 3));

View file

@ -105,15 +105,13 @@ INSTANTIATE_TEST_SUITE_P(
);
static bool
validate(struct brw_codegen *p)
validate(struct brw_codegen *p, char **error = nullptr)
{
const bool print = getenv("TEST_DEBUG");
struct disasm_info *disasm = disasm_initialize(p->isa, NULL);
if (print) {
disasm_new_inst_group(disasm, 0);
disasm_new_inst_group(disasm, p->next_insn_offset);
}
struct inst_group *group = disasm_new_inst_group(disasm, 0);
disasm_new_inst_group(disasm, p->next_insn_offset);
bool ret = brw_validate_instructions(p->isa, p->store, 0,
p->next_insn_offset, disasm);
@ -121,7 +119,9 @@ validate(struct brw_codegen *p)
if (print) {
dump_assembly(p->store, 0, p->next_insn_offset, disasm, NULL);
}
ralloc_free(disasm);
if (error)
*error = group->error;
return ret;
}
@ -470,13 +470,17 @@ TEST_P(validation_test, invalid_type_encoding_3src_a1)
}
struct brw_reg g = retype(g0, test_case[i].type);
if (!brw_type_is_int(test_case[i].type)) {
if (brw_type_is_bfloat(test_case[i].type)) {
/* BF is more restrictive, so ensure the instruction is valid. */
brw_MAD(p, retype(g, BRW_TYPE_F), g, g, retype(g, BRW_TYPE_F));
} else if (!brw_type_is_int(test_case[i].type)) {
brw_MAD(p, g, g, g, g);
} else {
brw_BFE(p, g, g, g, g);
}
EXPECT_TRUE(validate(p));
char *error = NULL;
EXPECT_TRUE(validate(p, &error)) << "Unexpected validation failure: " << error;
clear_instructions(p);
}
@ -2385,7 +2389,9 @@ TEST_P(validation_test, qword_low_power_no_indirect_addressing)
if (intel_device_info_is_9lp(&devinfo)) {
EXPECT_EQ(inst[i].expected_result, validate(p));
} else {
EXPECT_TRUE(validate(p));
char *error = nullptr;
EXPECT_TRUE(validate(p, &error))
<< "Test index = " << i << " failed to validate: " << error;
}
clear_instructions(p);
@ -3718,3 +3724,163 @@ TEST_P(validation_test, scalar_register_restrictions)
clear_instructions(p);
}
}
TEST_P(validation_test, bfloat_restrictions)
{
/* Restrictions from ACM PRM, vol. 9, section "Register Region
* Restrictions", sub-section 7.
*/
if (!devinfo.has_bfloat16)
return;
struct test {
const char *error_pattern;
enum opcode opcode;
unsigned exec_size;
brw_reg dst, src0, src1, src2;
};
const char *PASS = nullptr;
const struct test tests[] = {
{ PASS,
BRW_OPCODE_MOV, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2),
brw_grf(BRW_TYPE_F, 20, 0, 1,1,0) },
{ "pure bfloat16 operands are not supported",
BRW_OPCODE_MOV, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2),
brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0) },
{ "Execution size must not be greater than",
BRW_OPCODE_MOV, 16 * reg_unit(&devinfo),
brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2),
brw_grf(BRW_TYPE_F, 20, 0, 1,1,0) },
{ PASS,
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2),
brw_grf(BRW_TYPE_F, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) },
{ "pure bfloat16 operands are not supported",
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2),
brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_BF, 30, 0, 1,1,0) },
{ "Broadcast of bfloat16 scalar is not supported",
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2),
brw_grf(BRW_TYPE_F, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_BF, 30, 0, 0,1,0) },
{ PASS,
BRW_OPCODE_MUL, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0),
brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) },
{ "Bfloat16 not allowed in Src1 of 2-source instructions involving multiplier",
BRW_OPCODE_MUL, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0),
brw_grf(BRW_TYPE_F, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_BF, 30, 0, 1,1,0) },
{ PASS,
BRW_OPCODE_MAD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2),
brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_BF, 30, 0, 1,1,0),
brw_grf(BRW_TYPE_F, 40, 0, 1,1,0) },
{ "Bfloat16 not allowed in Src2 of 3-source instructions involving multiplier",
BRW_OPCODE_MAD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2),
brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0),
brw_grf(BRW_TYPE_BF, 40, 0, 1,1,0) },
{ PASS,
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 1, 2,1,2),
brw_grf(BRW_TYPE_F, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) },
{ "Unpacked bfloat16 destination must have stride 2 and register offset 0 or 1",
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 3, 2,1,2),
brw_grf(BRW_TYPE_F, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) },
{ PASS,
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 8 * reg_unit(&devinfo), 1,1,0),
brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) },
{ "Packed bfloat16 destination must have register offset 0 or half of GRF register",
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 1, 1,1,0),
brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) },
{ "Bfloat16 source must be packed",
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0),
brw_grf(BRW_TYPE_BF, 20, 0, 2,1,2),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) },
{ PASS,
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0),
brw_grf(BRW_TYPE_BF, 20, 8 * reg_unit(&devinfo), 1,1,0),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) },
{ "Bfloat16 source must have register offset 0 or half of GRF register",
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0),
brw_grf(BRW_TYPE_BF, 20, 5, 1,1,0),
brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) },
};
for (unsigned i = 0; i < ARRAY_SIZE(tests); i++) {
const struct test &t = tests[i];
switch (tests[i].opcode) {
case BRW_OPCODE_MOV:
brw_MOV(p, t.dst, t.src0);
break;
case BRW_OPCODE_ADD:
brw_ADD(p, t.dst, t.src0, t.src1);
break;
case BRW_OPCODE_MUL:
brw_MUL(p, t.dst, t.src0, t.src1);
break;
case BRW_OPCODE_MAD:
brw_MAD(p, t.dst, t.src0, t.src1, t.src2);
break;
default:
unreachable("unexpected opcode in tests");
}
if (tests[i].opcode == BRW_OPCODE_MAD) {
brw_eu_inst_set_3src_exec_size(&devinfo, last_inst, cvt(t.exec_size) - 1);
} else {
brw_eu_inst_set_exec_size(&devinfo, last_inst, cvt(t.exec_size) - 1);
}
/* TODO: Expand this test logic to check validation error to other
* tests.
*/
char *error = nullptr;
bool valid = validate(p, &error);
if (t.error_pattern) {
EXPECT_FALSE(valid)
<< "Test vector index = " << i << " expected to "
<< "fail validation with error containing: '" << t.error_pattern << "' "
<< "but succeeded instead.";
if (error) {
EXPECT_TRUE(strstr(error, t.error_pattern))
<< "Test vector index = " << i << " expected to "
<< "fail validation with error containing: '" << t.error_pattern << "' "
<< "but error was: '" << error << "'.";
}
} else {
EXPECT_TRUE(valid)
<< "Test vector index = " << i << " expected to succeed "
<< "but failed validation with error: '" << error << "'.";
}
clear_instructions(p);
}
}