diff --git a/src/intel/compiler/brw_eu_validate.c b/src/intel/compiler/brw_eu_validate.c index 1aeac37dc4b..8f69d93b1f0 100644 --- a/src/intel/compiler/brw_eu_validate.c +++ b/src/intel/compiler/brw_eu_validate.c @@ -425,8 +425,9 @@ execution_type_for_type(enum brw_reg_type type) case BRW_TYPE_V: case BRW_TYPE_UV: return BRW_TYPE_W; + default: - unreachable("invalid type"); + return BRW_TYPE_INVALID; } } @@ -564,6 +565,43 @@ is_mixed_float(const brw_hw_decoded_inst *inst) types_are_mixed_float(src1_type, dst_type); } +static bool +is_pure_bfloat(const brw_hw_decoded_inst *inst) +{ + if (inst_is_send(inst)) + return false; + /* An instruction with neither sources nor a destination is not pure bfloat. */ + if (inst->num_sources == 0 && !inst->has_dst) + return false; + /* Any non-bfloat source disqualifies the instruction. */ + for (int i = 0; i < inst->num_sources; i++) { + if (!brw_type_is_bfloat(inst->src[i].type)) + return false; + } + /* A destination, when present, must be bfloat as well. */ + if (inst->has_dst && !brw_type_is_bfloat(inst->dst.type)) + return false; + /* Not a send, has at least one operand, and every operand is bfloat. */ + return true; +} + /* Returns true when a non-send instruction has at least one bfloat operand (source or destination) but not all of its operands are bfloat. */ +static bool +is_mixed_bfloat(const brw_hw_decoded_inst *inst) +{ + if (inst_is_send(inst)) + return false; + /* Sources plus the destination; presumably has_dst is a 0/1 flag -- TODO confirm. */ + const int operands = inst->num_sources + inst->has_dst; + /* Count how many of those operands are bfloat. */ + int bfloat = 0; + for (int i = 0; i < inst->num_sources; i++) + bfloat += brw_type_is_bfloat(inst->src[i].type); + if (inst->has_dst) + bfloat += brw_type_is_bfloat(inst->dst.type); + /* Mixed means some, but not all, operands are bfloat. */ + return bfloat > 0 && bfloat != operands; +} + /** * Returns whether an instruction is an explicit or implicit conversion * to/from byte. 
@@ -624,6 +662,10 @@ general_restrictions_based_on_operand_types(const struct brw_isa_info *isa, enum brw_reg_type dst_type = inst->dst.type; + ERROR_IF(brw_type_is_bfloat(dst_type) && + !devinfo->has_bfloat16, + "Bfloat destination, but platform does not support it"); + ERROR_IF(dst_type == BRW_TYPE_DF && !devinfo->has_64bit_float, "64-bit float destination, but platform does not support it"); @@ -636,6 +678,10 @@ general_restrictions_based_on_operand_types(const struct brw_isa_info *isa, for (unsigned s = 0; s < inst->num_sources; s++) { enum brw_reg_type src_type = inst->src[s].type; + ERROR_IF(brw_type_is_bfloat(src_type) && + !devinfo->has_bfloat16, + "Bfloat source, but platform does not support it"); + ERROR_IF(src_type == BRW_TYPE_DF && !devinfo->has_64bit_float, "64-bit float source, but platform does not support it"); @@ -842,7 +888,7 @@ general_restrictions_based_on_operand_types(const struct brw_isa_info *isa, * override the general rule for the ratio of sizes of the destination type * and the execution type. We will add validation for those in a later patch. */ - bool validate_dst_size_and_exec_size_ratio = !is_mixed_float(inst); + bool validate_dst_size_and_exec_size_ratio = !is_mixed_float(inst) && !is_mixed_bfloat(inst); if (validate_dst_size_and_exec_size_ratio && exec_type_size > dst_type_size) { @@ -1016,6 +1062,21 @@ general_restrictions_on_region_parameters(const struct brw_isa_info *isa, return error_msg; } +static bool +is_multiplier_instruction(const brw_hw_decoded_inst *inst) +{ + /* TODO: Complete this list. 
*/ + switch (inst->opcode) { + case BRW_OPCODE_MUL: + case BRW_OPCODE_MAC: + case BRW_OPCODE_MACH: + case BRW_OPCODE_MAD: + return true; + default: + return false; + } +} + static struct string special_restrictions_for_mixed_float_mode(const struct brw_isa_info *isa, const brw_hw_decoded_inst *inst) @@ -1024,6 +1085,60 @@ special_restrictions_for_mixed_float_mode(const struct brw_isa_info *isa, struct string error_msg = { .str = NULL, .len = 0 }; + ERROR_IF(is_pure_bfloat(inst), + "Instructions with pure bfloat16 operands are not supported."); + /* The checks below apply only when bfloat16 and non-bfloat16 operands are mixed; the execution size limit is 8 before Gfx20 and 16 from Gfx20 on. */ + if (is_mixed_bfloat(inst)) { + ERROR_IF(devinfo->ver < 20 && inst->exec_size > 8, + "Execution size must not be greater than 8 in Gfx12."); + ERROR_IF(devinfo->ver >= 20 && inst->exec_size > 16, + "Execution size must not be greater than 16 in Gfx20+."); + /* A bfloat16 source may not use a scalar (broadcast) region. */ + for (int i = 0; i < inst->num_sources; i++) { + ERROR_IF(brw_type_is_bfloat(inst->src[i].type) && + src_has_scalar_region(inst, i), + "Broadcast of bfloat16 scalar is not supported."); + } + /* For multiplier opcodes the last source (Src1 of 2-source, Src2 of 3-source) must not be bfloat16. */ + if (is_multiplier_instruction(inst)) { + if (inst->num_sources == 2) { + ERROR_IF(brw_type_is_bfloat(inst->src[1].type), + "Bfloat16 not allowed in Src1 of 2-source instructions involving multiplier."); + } else if (inst->num_sources == 3) { + ERROR_IF(brw_type_is_bfloat(inst->src[2].type), + "Bfloat16 not allowed in Src2 of 3-source instructions involving multiplier."); + } + } + /* Sub-register offset of the second half of a register; presumably REG_SIZE * reg_unit() is the GRF size in bytes -- TODO confirm. */ + const unsigned half_offset = REG_SIZE * reg_unit(devinfo) / 2; + + if (inst->has_dst && brw_type_is_bfloat(inst->dst.type)) { + unsigned dst_stride = inst->dst.hstride; + bool dst_is_packed = is_packed(inst->exec_size * dst_stride, inst->exec_size, dst_stride); + + if (dst_is_packed) { + ERROR_IF(inst->dst.subnr != 0 && inst->dst.subnr != half_offset, + "Packed bfloat16 destination must have register offset 0 or half of GRF register."); + } else { + /* Offset in the restriction text is in terms of elements. 
*/ + const unsigned elem_size = brw_type_size_bytes(inst->dst.type); + ERROR_IF(dst_stride != 2 || (inst->dst.subnr != 0 && + inst->dst.subnr != 1 * elem_size), + "Unpacked bfloat16 destination must have stride 2 and register offset 0 or 1."); + } + } + + for (int i = 0; i < inst->num_sources; i++) { + if (brw_type_is_bfloat(inst->src[i].type)) { + bool src_is_packed = is_packed(inst->src[i].vstride, inst->src[i].width, inst->src[i].hstride); + ERROR_IF(!src_is_packed, + "Bfloat16 source must be packed"); + ERROR_IF(inst->src[i].subnr != 0 && inst->src[i].subnr != half_offset, + "Bfloat16 source must have register offset 0 or half of GRF register."); + } + } + } + const unsigned opcode = inst->opcode; if (inst->num_sources >= 3) return error_msg; @@ -1640,13 +1755,18 @@ special_requirements_for_handling_double_precision_data_types( * * "Vx1 and VxH indirect addressing for Float, Half-Float, Double-Float and * Quad-Word data must not be used." + * + * and + * + * "Vx1 and VxH indirect addressing for BFloat16 (...) data + * must not be used." 
*/ if (devinfo->verx10 >= 125 && - (brw_type_is_float(type) || brw_type_size_bytes(type) == 8)) { + (brw_type_is_float_or_bfloat(type) || brw_type_size_bytes(type) == 8)) { ERROR_IF(address_mode == BRW_ADDRESS_REGISTER_INDIRECT_REGISTER && vstride == BRW_VERTICAL_STRIDE_ONE_DIMENSIONAL, "Vx1 and VxH indirect addressing for Float, Half-Float, " - "Double-Float and Quad-Word data must not be used"); + "Double-Float, Quad-Word, and Bfloat16 data must not be used"); } } @@ -2442,6 +2562,18 @@ VSTRIDE_3SRC(unsigned vstride) unreachable("invalid vstride"); } +static inline unsigned +brw_implied_width_for_3src_a1(unsigned v, unsigned h) +{ + /* Region width implied by vertical stride v and horizontal stride h, following "Regioning Rules for Align1 Ternary Operations". The 3-source align1 decoder stores the result in src[i].width. */ + + /* TODO: Add remaining rules and de-duplicate with brw_disasm.c */ + /* v == 0 gives width 1; h == 0 gives width v, which also keeps the division below from dividing by zero. */ + if (v == 0) return 1; + if (h == 0) return v; + return v/h; +} + static struct string brw_hw_decode_inst(const struct brw_isa_info *isa, brw_hw_decoded_inst *inst, @@ -2629,6 +2761,7 @@ brw_hw_decode_inst(const struct brw_isa_info *isa, inst->src[0].subnr = brw_eu_inst_3src_a1_src0_subreg_nr(devinfo, raw); inst->src[0].vstride = VSTRIDE_3SRC(brw_eu_inst_3src_a1_src0_vstride(devinfo, raw)); inst->src[0].hstride = STRIDE(brw_eu_inst_3src_a1_src0_hstride(devinfo, raw)); + inst->src[0].width = brw_implied_width_for_3src_a1(inst->src[0].vstride, inst->src[0].hstride); } inst->src[1].file = brw_eu_inst_3src_a1_src1_reg_file(devinfo, raw); @@ -2639,6 +2772,7 @@ brw_hw_decode_inst(const struct brw_isa_info *isa, inst->src[1].subnr = brw_eu_inst_3src_a1_src1_subreg_nr(devinfo, raw); inst->src[1].vstride = VSTRIDE_3SRC(brw_eu_inst_3src_a1_src1_vstride(devinfo, raw)); inst->src[1].hstride = STRIDE(brw_eu_inst_3src_a1_src1_hstride(devinfo, raw)); + inst->src[1].width = brw_implied_width_for_3src_a1(inst->src[1].vstride, inst->src[1].hstride); inst->src[2].file = brw_eu_inst_3src_a1_src2_reg_file(devinfo, raw); inst->src[2].type = brw_eu_inst_3src_a1_src2_type(devinfo, raw); @@ -2648,6 +2782,7 @@ brw_hw_decode_inst(const 
struct brw_isa_info *isa, inst->src[2].nr = brw_eu_inst_3src_src2_reg_nr(devinfo, raw); inst->src[2].subnr = brw_eu_inst_3src_a1_src2_subreg_nr(devinfo, raw); inst->src[2].hstride = STRIDE(brw_eu_inst_3src_a1_src2_hstride(devinfo, raw)); + inst->src[2].width = brw_implied_width_for_3src_a1(inst->src[2].vstride, inst->src[2].hstride); } } else { diff --git a/src/intel/compiler/brw_validate.cpp b/src/intel/compiler/brw_validate.cpp index fac775af657..214722f62de 100644 --- a/src/intel/compiler/brw_validate.cpp +++ b/src/intel/compiler/brw_validate.cpp @@ -345,9 +345,9 @@ brw_validate(const brw_shader &s) brw_type_is_int(inst->src[1].type) + brw_type_is_int(inst->src[2].type); const unsigned float_sources = - brw_type_is_float(inst->src[0].type) + - brw_type_is_float(inst->src[1].type) + - brw_type_is_float(inst->src[2].type); + brw_type_is_float_or_bfloat(inst->src[0].type) + + brw_type_is_float_or_bfloat(inst->src[1].type) + + brw_type_is_float_or_bfloat(inst->src[2].type); fsv_assert((integer_sources == 3 && float_sources == 0) || (integer_sources == 0 && float_sources == 3)); diff --git a/src/intel/compiler/test_eu_validate.cpp b/src/intel/compiler/test_eu_validate.cpp index f2ff9283e67..d81d8c617b5 100644 --- a/src/intel/compiler/test_eu_validate.cpp +++ b/src/intel/compiler/test_eu_validate.cpp @@ -105,15 +105,13 @@ INSTANTIATE_TEST_SUITE_P( ); static bool -validate(struct brw_codegen *p) +validate(struct brw_codegen *p, char **error = nullptr) { const bool print = getenv("TEST_DEBUG"); struct disasm_info *disasm = disasm_initialize(p->isa, NULL); - if (print) { - disasm_new_inst_group(disasm, 0); - disasm_new_inst_group(disasm, p->next_insn_offset); - } + struct inst_group *group = disasm_new_inst_group(disasm, 0); + disasm_new_inst_group(disasm, p->next_insn_offset); bool ret = brw_validate_instructions(p->isa, p->store, 0, p->next_insn_offset, disasm); @@ -121,7 +119,9 @@ validate(struct brw_codegen *p) if (print) { dump_assembly(p->store, 0, 
p->next_insn_offset, disasm, NULL); } - ralloc_free(disasm); + + if (error) + *error = group->error; return ret; } @@ -470,13 +470,17 @@ TEST_P(validation_test, invalid_type_encoding_3src_a1) } struct brw_reg g = retype(g0, test_case[i].type); - if (!brw_type_is_int(test_case[i].type)) { + if (brw_type_is_bfloat(test_case[i].type)) { + /* BF is more restrictive, so ensure the instruction is valid. */ + brw_MAD(p, retype(g, BRW_TYPE_F), g, g, retype(g, BRW_TYPE_F)); + } else if (!brw_type_is_int(test_case[i].type)) { brw_MAD(p, g, g, g, g); } else { brw_BFE(p, g, g, g, g); } - EXPECT_TRUE(validate(p)); + char *error = NULL; + EXPECT_TRUE(validate(p, &error)) << "Unexpected validation failure: " << error; clear_instructions(p); } @@ -2385,7 +2389,9 @@ TEST_P(validation_test, qword_low_power_no_indirect_addressing) if (intel_device_info_is_9lp(&devinfo)) { EXPECT_EQ(inst[i].expected_result, validate(p)); } else { - EXPECT_TRUE(validate(p)); + char *error = nullptr; + EXPECT_TRUE(validate(p, &error)) + << "Test index = " << i << " failed to validate: " << error; } clear_instructions(p); @@ -3718,3 +3724,163 @@ TEST_P(validation_test, scalar_register_restrictions) clear_instructions(p); } } + +TEST_P(validation_test, bfloat_restrictions) +{ + /* Restrictions from ACM PRM, vol. 9, section "Register Region + * Restrictions", sub-section 7. 
+ */ + + if (!devinfo.has_bfloat16) + return; + + struct test { + const char *error_pattern; + enum opcode opcode; + unsigned exec_size; + brw_reg dst, src0, src1, src2; + }; + + const char *PASS = nullptr; + + const struct test tests[] = { + { PASS, + BRW_OPCODE_MOV, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2), + brw_grf(BRW_TYPE_F, 20, 0, 1,1,0) }, + + { "pure bfloat16 operands are not supported", + BRW_OPCODE_MOV, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2), + brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0) }, + + { "Execution size must not be greater than", + BRW_OPCODE_MOV, 16 * reg_unit(&devinfo), + brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2), + brw_grf(BRW_TYPE_F, 20, 0, 1,1,0) }, + + { PASS, + BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2), + brw_grf(BRW_TYPE_F, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) }, + + { "pure bfloat16 operands are not supported", + BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2), + brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_BF, 30, 0, 1,1,0) }, + + { "Broadcast of bfloat16 scalar is not supported", + BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2), + brw_grf(BRW_TYPE_F, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_BF, 30, 0, 0,1,0) }, + + { PASS, + BRW_OPCODE_MUL, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0), + brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) }, + + { "Bfloat16 not allowed in Src1 of 2-source instructions involving multiplier", + BRW_OPCODE_MUL, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0), + brw_grf(BRW_TYPE_F, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_BF, 30, 0, 1,1,0) }, + + { PASS, + BRW_OPCODE_MAD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2), + brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_BF, 30, 0, 1,1,0), + brw_grf(BRW_TYPE_F, 40, 0, 1,1,0) }, + + { "Bfloat16 not allowed in Src2 of 3-source instructions involving multiplier", + BRW_OPCODE_MAD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 2,1,2), + brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0), + brw_grf(BRW_TYPE_BF, 40, 0, 1,1,0) }, + + { PASS, + 
BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 1, 2,1,2), + brw_grf(BRW_TYPE_F, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) }, + + { "Unpacked bfloat16 destination must have stride 2 and register offset 0 or 1", + BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 3, 2,1,2), + brw_grf(BRW_TYPE_F, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) }, + + { PASS, + BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 8 * reg_unit(&devinfo), 1,1,0), + brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) }, + + { "Packed bfloat16 destination must have register offset 0 or half of GRF register", + BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 1, 1,1,0), + brw_grf(BRW_TYPE_BF, 20, 0, 1,1,0), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) }, + + { "Bfloat16 source must be packed", + BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0), + brw_grf(BRW_TYPE_BF, 20, 0, 2,1,2), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) }, + + { PASS, + BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0), + brw_grf(BRW_TYPE_BF, 20, 8 * reg_unit(&devinfo), 1,1,0), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) }, + + { "Bfloat16 source must have register offset 0 or half of GRF register", + BRW_OPCODE_ADD, 8, brw_grf(BRW_TYPE_BF, 10, 0, 1,1,0), + brw_grf(BRW_TYPE_BF, 20, 5, 1,1,0), + brw_grf(BRW_TYPE_F, 30, 0, 1,1,0) }, + }; + + for (unsigned i = 0; i < ARRAY_SIZE(tests); i++) { + const struct test &t = tests[i]; + + switch (tests[i].opcode) { + case BRW_OPCODE_MOV: + brw_MOV(p, t.dst, t.src0); + break; + case BRW_OPCODE_ADD: + brw_ADD(p, t.dst, t.src0, t.src1); + break; + case BRW_OPCODE_MUL: + brw_MUL(p, t.dst, t.src0, t.src1); + break; + case BRW_OPCODE_MAD: + brw_MAD(p, t.dst, t.src0, t.src1, t.src2); + break; + default: + unreachable("unexpected opcode in tests"); + } + + if (tests[i].opcode == BRW_OPCODE_MAD) { + brw_eu_inst_set_3src_exec_size(&devinfo, last_inst, cvt(t.exec_size) - 1); + } else { + brw_eu_inst_set_exec_size(&devinfo, last_inst, cvt(t.exec_size) - 1); + } + + /* TODO: Expand this test logic 
to check validation error to other + * tests. + */ + + char *error = nullptr; + bool valid = validate(p, &error); + + if (t.error_pattern) { + EXPECT_FALSE(valid) + << "Test vector index = " << i << " expected to " + << "fail validation with error containing: '" << t.error_pattern << "' " + << "but succeeded instead."; + + if (error) { + EXPECT_TRUE(strstr(error, t.error_pattern)) + << "Test vector index = " << i << " expected to " + << "fail validation with error containing: '" << t.error_pattern << "' " + << "but error was: '" << error << "'."; + } + } else { + EXPECT_TRUE(valid) + << "Test vector index = " << i << " expected to succeed " + << "but failed validation with error: '" << error << "'."; + } + + clear_instructions(p); + } +}