mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-28 08:10:09 +01:00
intel/cmat: Enable packed formats for unary, length, and construct
With this, a minimal test case passes:
void main()
{
coopmat<float16_t, gl_ScopeSubgroup, M, N, gl_MatrixUseA> matA;
coopmat<float, gl_ScopeSubgroup, M, N, gl_MatrixUseA> matR;
matA = coopmat<float16_t, gl_ScopeSubgroup, M, N, gl_MatrixUseA>(2.0);
matR = coopmat<float, gl_ScopeSubgroup, M, N, gl_MatrixUseA>(matA);
coopMatStore(matR, result, 0, N, gl_CooperativeMatrixLayoutRowMajor);
}
v2: Use nir_vec instead of explicit nir_vec{2,4}. Also fixes a typo in
one of the 4x8 cases.
v3: Use nir_pack_bits and nir_unpack_bits to dramatically simplify
coop_unary handling. This saved 67 lines of code.
v4: Allow packing factor 2 and packing factor 1 elements to be stored in
16-bit integers.
v5: Massive update to the comment in lower_cooperative_matrix_unary_op
with some suggestions from Caio. Add a comment and assertion around
`nir_def *v[4]`. Suggested by Caio.
Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25994>
This commit is contained in:
parent
75388a71c9
commit
0d314eb3cc
1 changed file with 87 additions and 5 deletions
|
|
@ -252,13 +252,83 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
|
|||
struct lower_cmat_state *state)
|
||||
{
|
||||
nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
|
||||
nir_def *src = nir_load_deref(b, nir_src_as_deref(intrin->src[1]));
|
||||
nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
|
||||
nir_def *results[NIR_MAX_VEC_COMPONENTS];
|
||||
const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
|
||||
|
||||
for (unsigned i = 0; i < num_components; i++) {
|
||||
nir_def *val = nir_channel(b, src, i);
|
||||
results[i] = nir_build_alu1(b, nir_intrinsic_alu_op(intrin), val);
|
||||
const struct glsl_type *dst_mat_type =
|
||||
get_coop_type_for_slice(state, dst_slice);
|
||||
const struct glsl_type *src_mat_type =
|
||||
get_coop_type_for_slice(state, src_slice);
|
||||
|
||||
const struct glsl_cmat_description dst_desc =
|
||||
*glsl_get_cmat_description(dst_mat_type);
|
||||
|
||||
const struct glsl_cmat_description src_desc =
|
||||
*glsl_get_cmat_description(src_mat_type);
|
||||
|
||||
const unsigned dst_bits = glsl_base_type_bit_size(dst_desc.element_type);
|
||||
const unsigned src_bits = glsl_base_type_bit_size(src_desc.element_type);
|
||||
|
||||
/* The type of the returned slice may be different from the type of the
|
||||
* input slice.
|
||||
*/
|
||||
const unsigned dst_packing_factor =
|
||||
get_packing_factor(dst_desc, dst_slice->type);
|
||||
|
||||
const unsigned src_packing_factor =
|
||||
get_packing_factor(src_desc, src_slice->type);
|
||||
|
||||
const nir_op op = nir_intrinsic_alu_op(intrin);
|
||||
|
||||
/* There are three possible cases:
|
||||
*
|
||||
* 1. dst_packing_factor == src_packing_factor. This is the common case,
|
||||
* and handling it is straightforward.
|
||||
*
|
||||
* 2. dst_packing_factor > src_packing_factor. This occurs when converting a
|
||||
* float32_t matrix slice to a packed float16_t slice. Loop over the size
|
||||
* of the destination slice, but read multiple entries from the source
|
||||
* slice on each iteration.
|
||||
*
|
||||
* 3. dst_packing_factor < src_packing_factor. This occurs when converting a
|
||||
* packed int8_t matrix slice to an int32_t slice. Loop over the size of
|
||||
* the source slice, but write multiple entries to the destination slice
|
||||
* on each iteration.
|
||||
*
|
||||
* Handle all cases by iterating over the total (non-packed) number of
|
||||
* elements in the slice. When dst_packing_factor values have been
|
||||
* calculated, store them.
|
||||
*/
|
||||
assert((dst_packing_factor * glsl_get_vector_elements(dst_slice->type)) ==
|
||||
(src_packing_factor * glsl_get_vector_elements(src_slice->type)));
|
||||
|
||||
/* Stores at most dst_packing_factor partial results. */
|
||||
nir_def *v[4];
|
||||
assert(dst_packing_factor <= 4);
|
||||
|
||||
for (unsigned i = 0; i < num_components * dst_packing_factor; i++) {
|
||||
const unsigned dst_chan_index = i % dst_packing_factor;
|
||||
const unsigned src_chan_index = i % src_packing_factor;
|
||||
const unsigned dst_index = i / dst_packing_factor;
|
||||
const unsigned src_index = i / src_packing_factor;
|
||||
|
||||
nir_def *src =
|
||||
nir_channel(b,
|
||||
nir_unpack_bits(b,
|
||||
nir_channel(b,
|
||||
nir_load_deref(b, src_slice),
|
||||
src_index),
|
||||
src_bits),
|
||||
src_chan_index);
|
||||
|
||||
v[dst_chan_index] = nir_build_alu1(b, op, src);
|
||||
|
||||
if (dst_chan_index == (dst_packing_factor - 1)) {
|
||||
results[dst_index] =
|
||||
nir_pack_bits(b, nir_vec(b, v, dst_packing_factor),
|
||||
dst_packing_factor * dst_bits);
|
||||
}
|
||||
}
|
||||
|
||||
nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
|
||||
|
|
@ -362,6 +432,17 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
|
|||
case nir_intrinsic_cmat_construct: {
|
||||
nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
|
||||
nir_def *src = intrin->src[1].ssa;
|
||||
|
||||
const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
|
||||
const struct glsl_cmat_description desc =
|
||||
*glsl_get_cmat_description(mat_type);
|
||||
const unsigned packing_factor = get_packing_factor(desc, slice->type);
|
||||
|
||||
if (packing_factor > 1) {
|
||||
src = nir_pack_bits(b, nir_replicate(b, src, packing_factor),
|
||||
packing_factor * glsl_base_type_get_bit_size(desc.element_type));
|
||||
}
|
||||
|
||||
const unsigned num_components = glsl_get_vector_elements(slice->type);
|
||||
|
||||
nir_store_deref(b, slice, nir_replicate(b, src, num_components),
|
||||
|
|
@ -385,7 +466,8 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
|
|||
const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin);
|
||||
const struct glsl_type *mat_type = glsl_cmat_type(&desc);
|
||||
const struct glsl_type *slice_type = get_slice_type(state, mat_type);
|
||||
return nir_imm_intN_t(b, glsl_get_vector_elements(slice_type), 32);
|
||||
return nir_imm_intN_t(b, (get_packing_factor(desc, slice_type) *
|
||||
glsl_get_vector_elements(slice_type)), 32);
|
||||
}
|
||||
|
||||
case nir_intrinsic_cmat_muladd:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue