lavapipe: add NV_cooperative_matrix2 conversions support

This adds support for cooperative matrix conversions and transposes.

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38964>
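A rough feel for what the new convert_use() lowering in the second hunk does: output component i on lane j takes component j of the source from lane i, i.e. the per-lane layout of the matrix is transposed across the subgroup. The stand-alone C sketch below models that lane/component transpose with a plain array standing in for the subgroup; the names transpose_lanes, LANES and COMPS are made up for the example and are not part of the patch.

/* Minimal sketch, not Mesa code: LANES stands in for the subgroup size,
 * COMPS for the per-lane component count, and in[lane][comp] for the
 * per-lane slice of the matrix. */
#include <stdio.h>

#define LANES 4
#define COMPS 4

static void
transpose_lanes(const int in[LANES][COMPS], int out[LANES][COMPS])
{
   for (int i = 0; i < COMPS; i++) {
      for (int j = 0; j < LANES; j++) {
         /* Same exchange as the lowering: lane j's component i is
          * component j read from lane i. */
         out[j][i] = in[i][j];
      }
   }
}

int
main(void)
{
   int in[LANES][COMPS], out[LANES][COMPS];

   for (int l = 0; l < LANES; l++)
      for (int c = 0; c < COMPS; c++)
         in[l][c] = l * COMPS + c;

   transpose_lanes(in, out);

   for (int l = 0; l < LANES; l++) {
      for (int c = 0; c < COMPS; c++)
         printf("%3d ", out[l][c]);
      printf("\n");
   }
   return 0;
}

In the lowering itself the same exchange is built from nir_read_invocation() and nir_bcsel(), so each lane only keeps the value selected for its own lane_id.
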
Dave Airlie 2025-12-15 16:34:36 +10:00 committed by Marge Bot
parent 485728e2cf
commit 58f7fa3f6c
2 changed files with 53 additions and 1 deletion


@@ -873,6 +873,7 @@ lvp_get_features(const struct lvp_physical_device *pdevice,
      .cooperativeMatrixRobustBufferAccess = has_cooperative_matrix(),
      .cooperativeMatrixFlexibleDimensions = true,
      .cooperativeMatrixConversions = true,
   };
}


@@ -106,6 +106,35 @@ lower_cmat_copy(nir_builder *b, nir_intrinsic_instr *intr)
   return true;
}

static nir_def *
convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use,
            enum glsl_cmat_use dst_use)
{
   nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {};
   nir_def *out_comps[NIR_MAX_VEC_COMPONENTS] = {};
   unsigned num_comps = src->num_components;
   for (unsigned i = 0; i < num_comps; i++) {
      comps[i] = nir_channel(b, src, i);
      out_comps[i] = nir_imm_zero(b, 1, comps[i]->bit_size);
   }
   nir_def *lane_id = nir_load_subgroup_invocation(b);
   /* construct the outer row: lane j's component i takes component j
    * from lane i, i.e. a lane/component transpose of the source layout */
   for (unsigned i = 0; i < num_comps; i++) {
      for (unsigned j = 0; j < CMAT_LEN; j++) {
         nir_def *else_val = out_comps[i];
         nir_def *active_lane = nir_ieq(b, lane_id, nir_imm_int(b, j));
         out_comps[i] = nir_read_invocation(b, comps[j], nir_imm_int(b, i));
         out_comps[i] = nir_bcsel(b, active_lane, out_comps[i], else_val);
      }
   }
   return nir_vec(b, out_comps, num_comps);
}

static nir_def *
convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, enum glsl_base_type dst_type)
{
@@ -122,14 +151,35 @@ static bool
lower_cmat_convert(nir_builder *b,
                   nir_intrinsic_instr *intr)
{
   const bool transpose = intr->intrinsic == nir_intrinsic_cmat_transpose;
   struct glsl_cmat_description dst_desc = cmat_src_desc(intr->src[0]);
   struct glsl_cmat_description src_desc = cmat_src_desc(intr->src[1]);
   enum glsl_base_type dst_element_type = dst_desc.element_type;
   enum glsl_base_type src_element_type = src_desc.element_type;
   enum glsl_cmat_use dst_use = dst_desc.use;
   enum glsl_cmat_use src_use = src_desc.use;
   nir_def *cmat = load_cmat_src(b, intr->src[1]);
   nir_def *ret = convert_base_type(b, cmat, src_element_type, dst_element_type);
   if (dst_use == GLSL_CMAT_USE_ACCUMULATOR)
      dst_use = GLSL_CMAT_USE_A;
   if (src_use == GLSL_CMAT_USE_ACCUMULATOR)
      src_use = GLSL_CMAT_USE_A;
   if (transpose) {
      if (src_use == GLSL_CMAT_USE_A && dst_use == GLSL_CMAT_USE_B)
         src_use = dst_use;
      if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_A)
         src_use = dst_use;
   }
   nir_def *ret = cmat;
   if (dst_use != src_use) {
      ret = convert_use(b, cmat, src_use, dst_use);
   }
   ret = convert_base_type(b, ret, src_element_type, dst_element_type);
   store_cmat_src(b, intr->src[0], ret);
   nir_instr_remove(&intr->instr);
   return true;
@@ -448,6 +498,7 @@ lower_impl(nir_function_impl *impl,
         progress |= lower_cmat_copy(&b, intr);
         break;
      case nir_intrinsic_cmat_convert:
      case nir_intrinsic_cmat_transpose:
         progress |= lower_cmat_convert(&b, intr);
         break;
      case nir_intrinsic_cmat_bitcast: