nir_lower_mem_access_bit_sizes: add nir_mem_access_shift_method

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31904>
Rhys Perry authored 2024-10-30 14:41:25 +00:00; committed by Marge Bot
parent e2dd36c66e
commit 61752152f7
13 changed files with 125 additions and 22 deletions

View file

@@ -3252,6 +3252,7 @@ mem_access_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = MIN2(bytes / (bit_size / 8), 4),
.bit_size = bit_size,
.align = bit_size / 8,
.shift = nir_mem_access_shift_method_scalar,
};
}

View file

@@ -138,6 +138,7 @@ v3d_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = 1,
.bit_size = 32,
.align = 4,
.shift = nir_mem_access_shift_method_scalar,
};
}
@@ -177,6 +178,7 @@ v3d_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = num_components,
.bit_size = bit_size,
.align = (bit_size / 8) * (num_components == 3 ? 4 : num_components),
.shift = nir_mem_access_shift_method_scalar,
};
}

View file

@@ -5882,10 +5882,27 @@ bool nir_lower_explicit_io(nir_shader *shader,
nir_variable_mode modes,
nir_address_format);
typedef enum {
/* Use open-coded funnel shifts for each component. */
nir_mem_access_shift_method_scalar,
/* Prefer 64-bit shifts, which do the same with fewer instructions. Useful
 * if 64-bit shifts are cheap.
 */
nir_mem_access_shift_method_shift64,
/* If nir_op_alignbyte_amd can be used, this is the best option: just a
 * single nir_op_alignbyte_amd for each 32-bit component.
 */
nir_mem_access_shift_method_bytealign_amd,
} nir_mem_access_shift_method;
typedef struct {
uint8_t num_components;
uint8_t bit_size;
uint16_t align;
/* If a load's alignment is increased, this specifies how the data should be
* shifted before converting to the original bit size.
*/
nir_mem_access_shift_method shift;
} nir_mem_access_size_align;
/* clang-format off */
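For illustration, a hedged sketch of how a backend callback passed to
nir_lower_mem_access_bit_sizes() might use the new field. The callback body
below is hypothetical and mirrors the callbacks elsewhere in this diff; only
the nir_mem_access_size_align fields come from this patch. A backend with
cheap 64-bit shifts could return:

   /* Hypothetical callback body: widen unaligned 32-bit loads and ask the
    * lowering to realign the loaded data with 64-bit shifts. */
   return (nir_mem_access_size_align){
      .num_components = MIN2(MAX2(bytes / 4, 1), 4),
      .bit_size = 32,
      .align = 4,
      .shift = nir_mem_access_shift_method_shift64,
   };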

View file

@@ -68,6 +68,86 @@ dup_mem_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin,
return dup;
}
static nir_def *
shift_load_data_alignbyte_amd(nir_builder *b, nir_def *load, nir_def *offset)
{
/* We don't need to mask the offset by 0x3 because only the low 2 bits matter. */
nir_def *comps[NIR_MAX_VEC_COMPONENTS];
unsigned i = 0;
for (; i < load->num_components - 1; i++)
comps[i] = nir_alignbyte_amd(b, nir_channel(b, load, i + 1), nir_channel(b, load, i), offset);
/* Shift the last element. */
comps[i] = nir_alignbyte_amd(b, nir_channel(b, load, i), nir_channel(b, load, i), offset);
return nir_vec(b, comps, load->num_components);
}
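As a sanity check, here is a host-side C model (a hypothetical helper, not
part of Mesa) of what this emits per component, assuming nir_op_alignbyte_amd
returns the four contiguous bytes of the {hi, lo} concatenation selected by
the low two bits of the offset, like AMD's v_alignbyte_b32:

   #include <stdint.h>

   /* alignbyte(hi, lo, b) == (uint32_t)((((uint64_t)hi << 32) | lo) >> (8 * (b & 3))) */
   static uint32_t
   alignbyte_model(uint32_t hi, uint32_t lo, uint32_t byte_offset)
   {
      uint64_t qword = ((uint64_t)hi << 32) | lo;
      return (uint32_t)(qword >> (8 * (byte_offset & 0x3)));
   }

Component i of the result is alignbyte_model(load[i + 1], load[i], offset);
the last component passes itself as both halves, matching the loop above.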
static nir_def *
shift_load_data_shift64(nir_builder *b, nir_def *load, nir_def *offset, uint64_t align_mask)
{
nir_def *comps[NIR_MAX_VEC_COMPONENTS];
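/* shift_load_data() only takes this path when align_mask == 0x3, so masking
 * the offset with 0x3 is equivalent to masking with align_mask. */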
nir_def *shift = nir_imul_imm(b, nir_iand_imm(b, offset, 0x3), 8);
for (unsigned i = 0; i < load->num_components - 1; i++) {
nir_def *qword = nir_pack_64_2x32_split(
b, nir_channel(b, load, i), nir_channel(b, load, i + 1));
qword = nir_ushr(b, qword, shift);
comps[i] = nir_unpack_64_2x32_split_x(b, qword);
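/* The high half of the final qword yields the last component. */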
if (i == load->num_components - 2)
comps[i + 1] = nir_unpack_64_2x32_split_y(b, qword);
}
return nir_vec(b, comps, load->num_components);
}
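An equivalent host-side model (again hypothetical) makes the pairing easier
to see: each output dword i is the pair {load[i + 1], load[i]} packed into a
qword and shifted right by the padding bytes, one 64-bit shift per pair:

   #include <stdint.h>

   /* Mirror of shift_load_data_shift64() for an n-dword vector, n >= 2. */
   static void
   shift64_model(uint32_t *dst, const uint32_t *src, unsigned n, uint32_t byte_offset)
   {
      unsigned shift = 8 * (byte_offset & 0x3);
      for (unsigned i = 0; i + 1 < n; i++) {
         uint64_t qword = ((uint64_t)src[i + 1] << 32) | src[i];
         qword >>= shift;
         dst[i] = (uint32_t)qword;
         if (i == n - 2) /* high half of the last pair */
            dst[i + 1] = (uint32_t)(qword >> 32);
      }
   }

This needs fewer shift instructions than the scalar funnel-shift path, which
is why it is worthwhile on hardware where 64-bit shifts are cheap.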
static nir_def *
shift_load_data_scalar(nir_builder *b, nir_def *load, nir_def *offset, uint64_t align_mask)
{
nir_def *pad = nir_iand_imm(b, offset, align_mask);
nir_def *shift = nir_imul_imm(b, pad, 8);
nir_def *shifted = nir_ushr(b, load, shift);
if (load->num_components > 1) {
nir_def *rev_shift =
nir_isub_imm(b, load->bit_size, shift);
nir_def *rev_shifted = nir_ishl(b, load, rev_shift);
nir_def *comps[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 1; i < load->num_components; i++)
comps[i - 1] = nir_channel(b, rev_shifted, i);
comps[load->num_components - 1] =
nir_imm_zero(b, 1, load->bit_size);
rev_shifted = nir_vec(b, comps, load->num_components);
shifted = nir_bcsel(b, nir_ieq_imm(b, shift, 0), load,
nir_ior(b, shifted, rev_shifted));
}
return shifted;
}
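A matching host-side model of the open-coded funnel shift (hypothetical, and
fixed to 32-bit components for brevity; the NIR version above handles any bit
size), including the shift == 0 case that is selected separately because a
left shift by a full bit_size would be undefined:

   #include <stdint.h>

   /* Mirror of shift_load_data_scalar() for an n-dword vector. */
   static void
   scalar_model(uint32_t *dst, const uint32_t *src, unsigned n,
                uint32_t byte_offset, uint32_t align_mask)
   {
      unsigned shift = 8 * (byte_offset & align_mask);
      for (unsigned i = 0; i < n; i++) {
         uint32_t spill = (i + 1 < n) ? src[i + 1] : 0; /* zero-fill the tail */
         dst[i] = shift == 0 ? src[i]
                             : (src[i] >> shift) | (spill << (32 - shift));
      }
   }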
static nir_def *
shift_load_data(nir_builder *b, nir_def *load, nir_def *offset, uint64_t align_mask,
nir_mem_access_shift_method method)
{
bool use_alignbyte = method == nir_mem_access_shift_method_bytealign_amd &&
load->bit_size == 32 && align_mask == 0x3;
bool use_shift64 =
method == nir_mem_access_shift_method_shift64 && load->bit_size == 32 && align_mask == 0x3 &&
load->num_components >= 2;
offset = nir_u2u32(b, offset);
if (use_alignbyte)
return shift_load_data_alignbyte_amd(b, load, offset);
else if (use_shift64)
return shift_load_data_shift64(b, load, offset, align_mask);
else
return shift_load_data_scalar(b, load, offset, align_mask);
}
static bool
lower_mem_load(nir_builder *b, nir_intrinsic_instr *intrin,
nir_lower_mem_access_bit_sizes_cb mem_access_size_align_cb,
@@ -126,11 +206,10 @@ lower_mem_load(nir_builder *b, nir_intrinsic_instr *intrin,
       uint64_t align_mask = requested.align - 1;
       nir_def *chunk_offset = nir_iadd_imm(b, offset, chunk_start);
-      nir_def *pad = nir_u2u32(b, nir_iand_imm(b, chunk_offset, align_mask));
-      chunk_offset = nir_iand_imm(b, chunk_offset, ~align_mask);
+      nir_def *aligned_offset = nir_iand_imm(b, chunk_offset, ~align_mask);

       nir_intrinsic_instr *load =
-         dup_mem_intrinsic(b, intrin, chunk_offset,
+         dup_mem_intrinsic(b, intrin, aligned_offset,
                            requested.align, 0, NULL,
                            requested.num_components, requested.bit_size);
@@ -139,25 +218,8 @@ lower_mem_load(nir_builder *b, nir_intrinsic_instr *intrin,
          requested.num_components * requested.bit_size / 8;
       chunk_bytes = MIN2(bytes_left, requested_bytes - max_pad);

-      nir_def *shift = nir_imul_imm(b, pad, 8);
-      nir_def *shifted = nir_ushr(b, &load->def, shift);
-      if (load->def.num_components > 1) {
-         nir_def *rev_shift =
-            nir_isub_imm(b, load->def.bit_size, shift);
-         nir_def *rev_shifted = nir_ishl(b, &load->def, rev_shift);
-
-         nir_def *comps[NIR_MAX_VEC_COMPONENTS];
-         for (unsigned i = 1; i < load->def.num_components; i++)
-            comps[i - 1] = nir_channel(b, rev_shifted, i);
-
-         comps[load->def.num_components - 1] =
-            nir_imm_zero(b, 1, load->def.bit_size);
-         rev_shifted = nir_vec(b, comps, load->def.num_components);
-
-         shifted = nir_bcsel(b, nir_ieq_imm(b, shift, 0), &load->def,
-                             nir_ior(b, shifted, rev_shifted));
-      }
+      nir_def *shifted = shift_load_data(
+         b, &load->def, chunk_offset, align_mask, requested.shift);

       unsigned chunk_bit_size = MIN2(8 << (ffs(chunk_bytes) - 1), bit_size);
       unsigned chunk_num_components = chunk_bytes / (chunk_bit_size / 8);
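A worked example of the offset math above (numbers hypothetical): a chunk
starting at byte offset 6 whose requested access must be 4-byte aligned is
loaded from the rounded-down offset, and the extra leading bytes are shifted
off afterwards by shift_load_data():

   uint64_t align_mask     = 4 - 1;                       /* requested.align - 1 */
   uint64_t chunk_offset   = 6;                           /* unaligned chunk start */
   uint64_t pad            = chunk_offset & align_mask;   /* 2 bytes of padding */
   uint64_t aligned_offset = chunk_offset & ~align_mask;  /* load from offset 4 */
   /* shift_load_data() shifts the loaded vector right by pad * 8 == 16 bits. */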

View file

@@ -991,6 +991,7 @@ ir3_mem_access_size_align(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = MAX2(1, MIN2(bytes / (bit_size / 8), 4)),
.bit_size = bit_size,
.align = bit_size / 8,
.shift = nir_mem_access_shift_method_scalar,
};
}

View file

@@ -5728,12 +5728,14 @@ mem_access_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = MIN2(bytes / align, 4),
.bit_size = align * 8,
.align = align,
.shift = nir_mem_access_shift_method_scalar,
};
} else {
return (nir_mem_access_size_align){
.num_components = MIN2(bytes / (bit_size / 8), 4),
.bit_size = bit_size,
.align = bit_size / 8,
.shift = nir_mem_access_shift_method_scalar,
};
}
}
@@ -5753,6 +5755,7 @@ mem_access_scratch_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = MIN2(bytes / (bit_size / 8), 4),
.bit_size = bit_size,
.align = bit_size / 8,
.shift = nir_mem_access_shift_method_scalar,
};
}

View file

@@ -1517,6 +1517,7 @@ get_mem_access_size_align(nir_intrinsic_op intrin, uint8_t bytes,
.bit_size = 32,
.num_components = comps32,
.align = 4,
.shift = nir_mem_access_shift_method_scalar,
};
}
break;
@@ -1527,6 +1528,7 @@ get_mem_access_size_align(nir_intrinsic_op intrin, uint8_t bytes,
.bit_size = 32,
.num_components = 1,
.align = 4,
.shift = nir_mem_access_shift_method_scalar,
};
}
break;
@@ -1562,6 +1564,7 @@ get_mem_access_size_align(nir_intrinsic_op intrin, uint8_t bytes,
.bit_size = bytes * 8,
.num_components = 1,
.align = 1,
.shift = nir_mem_access_shift_method_scalar,
};
} else {
bytes = MIN2(bytes, 16);
@@ -1570,6 +1573,7 @@ get_mem_access_size_align(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = is_scratch ? 1 :
is_load ? DIV_ROUND_UP(bytes, 4) : bytes / 4,
.align = 4,
.shift = nir_mem_access_shift_method_scalar,
};
}
}

View file

@@ -1234,6 +1234,7 @@ get_mem_access_size_align(nir_intrinsic_op intrin, uint8_t bytes,
.bit_size = 32,
.num_components = comps32,
.align = 4,
.shift = nir_mem_access_shift_method_scalar,
};
}
break;
@@ -1269,6 +1270,7 @@ get_mem_access_size_align(nir_intrinsic_op intrin, uint8_t bytes,
.bit_size = bytes * 8,
.num_components = 1,
.align = 1,
.shift = nir_mem_access_shift_method_scalar,
};
} else {
bytes = MIN2(bytes, 16);
@@ -1277,6 +1279,7 @@ get_mem_access_size_align(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = is_scratch ? 1 :
is_load ? DIV_ROUND_UP(bytes, 4) : bytes / 4,
.align = 4,
.shift = nir_mem_access_shift_method_scalar,
};
}
}

View file

@@ -6263,6 +6263,7 @@ lower_mem_access_bit_sizes_cb(nir_intrinsic_op intrin,
.align = closest_bit_size / 8,
.bit_size = closest_bit_size,
.num_components = DIV_ROUND_UP(MIN2(bytes, 16) * 8, closest_bit_size),
.shift = nir_mem_access_shift_method_scalar,
};
}
@@ -6277,6 +6278,7 @@ lower_mem_access_bit_sizes_cb(nir_intrinsic_op intrin,
.align = min_bit_size / 8,
.bit_size = min_bit_size,
.num_components = MIN2(4, ideal_num_components),
.shift = nir_mem_access_shift_method_scalar,
};
}
@@ -6296,6 +6298,7 @@ lower_mem_access_bit_sizes_cb(nir_intrinsic_op intrin,
.align = bit_size / 8,
.bit_size = bit_size,
.num_components = MIN2(4, num_components),
.shift = nir_mem_access_shift_method_scalar,
};
}

View file

@@ -1437,6 +1437,7 @@ Converter::getMemAccessSizeAlign(nir_intrinsic_op intrin,
.num_components = (uint8_t) (bytes / (bit_size / 8)),
.bit_size = (uint8_t) bit_size,
.align = (uint16_t) bytes,
.shift = nir_mem_access_shift_method_scalar,
};
}

View file

@@ -857,6 +857,7 @@ nak_mem_access_size_align(nir_intrinsic_op intrin,
.bit_size = 32,
.num_components = 1,
.align = 4,
.shift = nir_mem_access_shift_method_scalar,
};
} else {
assert(align == 1);
@@ -864,6 +865,7 @@ nak_mem_access_size_align(nir_intrinsic_op intrin,
.bit_size = 8,
.num_components = 1,
.align = 1,
.shift = nir_mem_access_shift_method_scalar,
};
}
} else if (chunk_bytes < 4) {
@@ -871,12 +873,14 @@ nak_mem_access_size_align(nir_intrinsic_op intrin,
.bit_size = chunk_bytes * 8,
.num_components = 1,
.align = chunk_bytes,
.shift = nir_mem_access_shift_method_scalar,
};
} else {
return (nir_mem_access_size_align) {
.bit_size = 32,
.num_components = chunk_bytes / 4,
.align = chunk_bytes,
.shift = nir_mem_access_shift_method_scalar,
};
}
}

View file

@@ -4768,6 +4768,7 @@ mem_access_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = num_comps,
.bit_size = bit_size,
.align = bit_size / 8,
.shift = nir_mem_access_shift_method_scalar,
};
}

View file

@@ -374,6 +374,7 @@ mem_access_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
.num_components = num_comps,
.bit_size = bit_size,
.align = bit_size / 8,
.shift = nir_mem_access_shift_method_scalar,
};
}