intel/cmat: Enable packed formats for unary, length, and construct

With this, a minimum test case passes:

    void main()
    {
        coopmat<float16_t, gl_ScopeSubgroup, M, N, gl_MatrixUseA> matA;
        coopmat<float, gl_ScopeSubgroup, M, N, gl_MatrixUseA> matR;

        matA = coopmat<float16_t, gl_ScopeSubgroup, M, N, gl_MatrixUseA>(2.0);
        matR = coopmat<float, gl_ScopeSubgroup, M, N, gl_MatrixUseA>(matA);

        coopMatStore(matR, result, 0, N, gl_CooperativeMatrixLayoutRowMajor);
    }

v2: Use nir_vec instead of explicit nir_vec{2,4}. Also fixes a typo in
one of the 4x8 cases.

v3: Use nir_pack_bits and nir_unpack_bits to dramatically simplify
coop_unary handling. This saved 67 lines of code.

v4: Allow packing factor 2 and packing factor 1 elements be stored in
16-bit integers.

v5: Massive update to the comment in lower_cooperative_matrix_unary_op
with some suggestions from Caio. Add a comment and assertion around
`nir_def *v[4]`. Suggested by Caio.

Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25994>
This commit is contained in:
Ian Romanick 2023-07-13 11:05:16 -07:00
parent 75388a71c9
commit 0d314eb3cc

View file

@ -252,13 +252,83 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
struct lower_cmat_state *state)
{
nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
nir_def *src = nir_load_deref(b, nir_src_as_deref(intrin->src[1]));
nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
nir_def *results[NIR_MAX_VEC_COMPONENTS];
const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
for (unsigned i = 0; i < num_components; i++) {
nir_def *val = nir_channel(b, src, i);
results[i] = nir_build_alu1(b, nir_intrinsic_alu_op(intrin), val);
const struct glsl_type *dst_mat_type =
get_coop_type_for_slice(state, dst_slice);
const struct glsl_type *src_mat_type =
get_coop_type_for_slice(state, src_slice);
const struct glsl_cmat_description dst_desc =
*glsl_get_cmat_description(dst_mat_type);
const struct glsl_cmat_description src_desc =
*glsl_get_cmat_description(src_mat_type);
const unsigned dst_bits = glsl_base_type_bit_size(dst_desc.element_type);
const unsigned src_bits = glsl_base_type_bit_size(src_desc.element_type);
/* The type of the returned slice may be different from the type of the
* input slice.
*/
const unsigned dst_packing_factor =
get_packing_factor(dst_desc, dst_slice->type);
const unsigned src_packing_factor =
get_packing_factor(src_desc, src_slice->type);
const nir_op op = nir_intrinsic_alu_op(intrin);
/* There are three possible cases:
*
* 1. dst_packing_factor == src_packing_factor. This is the common case,
* and handling it is straightforward.
*
* 2. dst_packing_factor > src_packing_factor. This occurs when converting a
* float32_t matrix slice to a packed float16_t slice. Loop over the size
* of the destination slice, but read multiple entries from the source
* slice on each iteration.
*
* 3. dst_packing_factor < src_packing_factor. This occurs when converting a
* packed int8_t matrix slice to an int32_t slice. Loop over the size of
* the source slice, but write multiple entries to the destination slice
* on each iteration.
*
* Handle all cases by iterating over the total (non-packed) number of
* elements in the slice. When dst_packing_factor values have been
* calculated, store them.
*/
assert((dst_packing_factor * glsl_get_vector_elements(dst_slice->type)) ==
(src_packing_factor * glsl_get_vector_elements(src_slice->type)));
/* Stores at most dst_packing_factor partial results. */
nir_def *v[4];
assert(dst_packing_factor <= 4);
for (unsigned i = 0; i < num_components * dst_packing_factor; i++) {
const unsigned dst_chan_index = i % dst_packing_factor;
const unsigned src_chan_index = i % src_packing_factor;
const unsigned dst_index = i / dst_packing_factor;
const unsigned src_index = i / src_packing_factor;
nir_def *src =
nir_channel(b,
nir_unpack_bits(b,
nir_channel(b,
nir_load_deref(b, src_slice),
src_index),
src_bits),
src_chan_index);
v[dst_chan_index] = nir_build_alu1(b, op, src);
if (dst_chan_index == (dst_packing_factor - 1)) {
results[dst_index] =
nir_pack_bits(b, nir_vec(b, v, dst_packing_factor),
dst_packing_factor * dst_bits);
}
}
nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
@ -362,6 +432,17 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
case nir_intrinsic_cmat_construct: {
nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
nir_def *src = intrin->src[1].ssa;
const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
const struct glsl_cmat_description desc =
*glsl_get_cmat_description(mat_type);
const unsigned packing_factor = get_packing_factor(desc, slice->type);
if (packing_factor > 1) {
src = nir_pack_bits(b, nir_replicate(b, src, packing_factor),
packing_factor * glsl_base_type_get_bit_size(desc.element_type));
}
const unsigned num_components = glsl_get_vector_elements(slice->type);
nir_store_deref(b, slice, nir_replicate(b, src, num_components),
@ -385,7 +466,8 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin);
const struct glsl_type *mat_type = glsl_cmat_type(&desc);
const struct glsl_type *slice_type = get_slice_type(state, mat_type);
return nir_imm_intN_t(b, glsl_get_vector_elements(slice_type), 32);
return nir_imm_intN_t(b, (get_packing_factor(desc, slice_type) *
glsl_get_vector_elements(slice_type)), 32);
}
case nir_intrinsic_cmat_muladd: