brw: Consider bfloat16 in lower regioning pass

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
Caio Oliveira 2025-03-19 10:06:26 -07:00 committed by Marge Bot
parent 5936768ce0
commit 2c31516b3e
3 changed files with 72 additions and 6 deletions

View file

@ -1082,7 +1082,7 @@ has_dst_aligned_region_restriction(const intel_device_info *devinfo,
(brw_type_size_bytes(exec_type) == 4 && is_dword_multiply))
return intel_device_info_is_9lp(devinfo) || devinfo->verx10 >= 125;
else if (brw_type_is_float(dst_type))
else if (brw_type_is_float_or_bfloat(dst_type))
return devinfo->verx10 >= 125;
else

View file

@ -345,6 +345,20 @@ is_unordered(const intel_device_info *devinfo, const brw_inst *inst)
inst->dst.type == BRW_TYPE_DF));
}
/**
 * Report whether any operand of \p inst has a bfloat type.
 *
 * The destination type is examined first, followed by each of the
 * inst->sources source operands in order.  Scanning stops at the first
 * operand whose type brw_type_is_bfloat() accepts.
 *
 * \return true if the destination or at least one source is bfloat-typed,
 *         false otherwise.
 */
static inline bool
has_bfloat_operands(const brw_inst *inst)
{
   bool found = brw_type_is_bfloat(inst->dst.type);

   /* Only look at the sources while no bfloat operand has been seen yet. */
   for (int s = 0; !found && s < inst->sources; s++)
      found = brw_type_is_bfloat(inst->src[s].type);

   return found;
}
/* Prototype only; the definition is not part of this header.
 *
 * NOTE(review): presumably reports whether the destination of inst,
 * evaluated as if it had type dst_type, is subject to the aligned-region
 * restriction on the given device -- confirm against the definition.
 */
bool has_dst_aligned_region_restriction(const intel_device_info *devinfo,
const brw_inst *inst,
brw_reg_type dst_type);

View file

@ -52,7 +52,10 @@ namespace {
required_src_byte_stride(const intel_device_info *devinfo, const brw_inst *inst,
unsigned i)
{
if (has_dst_aligned_region_restriction(devinfo, inst)) {
if (devinfo->has_bfloat16 && has_bfloat_operands(inst)) {
return brw_type_size_bytes(inst->src[i].type);
} else if (has_dst_aligned_region_restriction(devinfo, inst)) {
return MAX2(brw_type_size_bytes(inst->dst.type),
byte_stride(inst->dst));
@ -143,7 +146,7 @@ namespace {
* that requires it to have some particular alignment.
*/
unsigned
required_dst_byte_stride(const brw_inst *inst)
required_dst_byte_stride(const intel_device_info *devinfo, const brw_inst *inst)
{
if (inst->dst.is_accumulator()) {
/* If the destination is an accumulator, insist that we leave the
@ -159,6 +162,11 @@ namespace {
* and fix the sources of the multiply instead of the destination.
*/
return inst->dst.hstride * brw_type_size_bytes(inst->dst.type);
} else if (devinfo->has_bfloat16 && has_bfloat_operands(inst)) {
/* Prefer packed since it can be used as a source. */
return brw_type_size_bytes(inst->dst.type);
} else if (brw_type_size_bytes(inst->dst.type) < get_exec_type_size(inst) &&
!is_byte_raw_mov(inst)) {
return get_exec_type_size(inst);
@ -201,6 +209,8 @@ namespace {
unsigned
required_dst_byte_offset(const intel_device_info *devinfo, const brw_inst *inst)
{
assert(!brw_type_is_bfloat(inst->dst.type));
for (unsigned i = 0; i < inst->sources; i++) {
if (!is_uniform(inst->src[i]) && !inst->is_control_source(i))
if (reg_offset(inst->src[i]) % (reg_unit(devinfo) * REG_SIZE) !=
@ -314,6 +324,25 @@ namespace {
const unsigned dst_byte_offset = reg_offset(inst->dst) % (reg_unit(devinfo) * REG_SIZE);
const unsigned src_byte_offset = reg_offset(inst->src[i]) % (reg_unit(devinfo) * REG_SIZE);
if (devinfo->has_bfloat16 && has_bfloat_operands(inst)) {
if (brw_type_is_bfloat(inst->src[i].type)) {
const unsigned half_register = REG_SIZE * reg_unit(devinfo) / 2;
const unsigned offset = reg_offset(inst->src[i]);
/* Region restrictions described by PRM
*
* Bfloat16 source must be packed.
*
* Bfloat16 source must have register offset 0 or half of GRF register.
*/
return !(byte_stride(inst->src[i]) == 2 && (offset == 0 || offset == half_register));
} else {
assert(inst->src[i].type == BRW_TYPE_F);
/* Restrict Float sources mixed with BFloats to also be aligned and packed. */
return !is_uniform(inst->src[i]) && src_byte_offset != 0 && byte_stride(inst->src[i]) != 4;
}
}
return (has_dst_aligned_region_restriction(devinfo, inst) &&
!is_uniform(inst->src[i]) &&
(byte_stride(inst->src[i]) != required_src_byte_stride(devinfo, inst, i) ||
@ -333,6 +362,29 @@ namespace {
{
if (is_send(inst)) {
return false;
} else if (devinfo->has_bfloat16 && has_bfloat_operands(inst)) {
const unsigned stride = byte_stride(inst->dst);
const unsigned offset = reg_offset(inst->dst);
const unsigned half_register = REG_SIZE * reg_unit(devinfo) / 2;
/* Region restrictions described by PRM
*
* Packed bfloat16 destination must have register offset of 0 or half of GRF register.
*
* Unpacked bfloat16 destination must have stride 2 and register offset 0 or 1.
*
* Note numbers above are in terms of elements (2 bytes).
*/
if (inst->dst.type == BRW_TYPE_BF) {
return !(stride == 2 && (offset == 0 || offset == half_register)) &&
!(stride == 4 && (offset == 0 || offset == 2));
} else {
assert(inst->dst.type == BRW_TYPE_F);
/* Restrict a Float destination mixed with BFloats to also be aligned and packed. */
return !(stride == 4 && offset == 0);
}
} else {
const brw_reg_type exec_type = get_exec_type(inst);
const unsigned dst_byte_offset = reg_offset(inst->dst) % (reg_unit(devinfo) * REG_SIZE);
@ -340,10 +392,10 @@ namespace {
brw_type_size_bytes(inst->dst.type) < brw_type_size_bytes(exec_type);
return (has_dst_aligned_region_restriction(devinfo, inst) &&
(required_dst_byte_stride(inst) != byte_stride(inst->dst) ||
(required_dst_byte_stride(devinfo, inst) != byte_stride(inst->dst) ||
required_dst_byte_offset(devinfo, inst) != dst_byte_offset)) ||
(is_narrowing_conversion &&
required_dst_byte_stride(inst) != byte_stride(inst->dst));
required_dst_byte_stride(devinfo, inst) != byte_stride(inst->dst));
}
}
@ -615,7 +667,7 @@ namespace {
brw_type_is_float(inst->dst.type));
const brw_builder ibld(inst);
const unsigned stride = required_dst_byte_stride(inst) /
const unsigned stride = required_dst_byte_stride(v->devinfo, inst) /
brw_type_size_bytes(inst->dst.type);
assert(stride > 0);
brw_reg tmp = ibld.vgrf(inst->dst.type, stride);