lavapipe: add NV_cooperative_matrix2 reductions support

This adds lowering for the coopmat2 reduce, reduce-finish and 2x2 reduce calls and enables the cooperativeMatrixReductions feature.

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38964>
This commit is contained in:
Dave Airlie 2025-12-16 08:28:06 +10:00 committed by Marge Bot
parent 6d53931cf4
commit 2db1a624e3
4 changed files with 200 additions and 2 deletions

View file

@ -874,6 +874,7 @@ lvp_get_features(const struct lvp_physical_device *pdevice,
.cooperativeMatrixFlexibleDimensions = true,
.cooperativeMatrixConversions = true,
.cooperativeMatrixReductions = true,
};
}

View file

@ -356,7 +356,18 @@ lvp_shader_lower(struct lvp_device *pdevice, nir_shader *nir, struct lvp_pipelin
NIR_PASS(_, nir, nir_opt_dce);
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL);
}
NIR_PASS(_, nir, lvp_nir_lower_cooperative_matrix);
NIR_PASS(progress, nir, lvp_nir_lower_cooperative_matrix);
if (progress) {
NIR_PASS(_, nir, nir_opt_dce);
NIR_PASS(progress, nir, nir_inline_functions);
nir_remove_non_entrypoints(nir); /* remove the late inlined functions */
if (progress) {
NIR_PASS(_, nir, nir_opt_copy_prop_vars);
NIR_PASS(_, nir, nir_opt_copy_prop);
}
NIR_PASS(_, nir, nir_opt_deref);
NIR_PASS(_, nir, nir_opt_dce);
}
const struct nir_lower_compute_system_values_options compute_system_values = {0};
NIR_PASS(_, nir, nir_lower_compute_system_values, &compute_system_values);

View file

@ -441,6 +441,172 @@ lower_cmat_bitcast(nir_builder *b, nir_intrinsic_instr *intr)
return true;
}
/**
 * Lower a cooperative-matrix "reduce finish" call.
 *
 * Combines the two source matrices (call params 1 and 2) with the call's
 * callee and stores the resulting vector through the destination deref
 * (call param 0).  The callee is invoked as (out deref, value0, value1).
 *
 * - NIR_CMAT_REDUCE_COLUMN: only component 0 of each source is combined and
 *   the result is replicated into every component.  This branch wins when
 *   both flags are set.
 * - NIR_CMAT_REDUCE_ROW: the sources are combined component by component.
 *
 * The call instruction is removed afterwards.  Always returns true.
 */
static bool
lower_cmat_reduce_finish_call(nir_builder *b, nir_cmat_call_instr *call)
{
   nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
   nir_function *fnptr = call->callee;
   nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);

   /* Checked inline so that no local is left unused when assert() is
    * compiled out.
    */
   assert(glsl_get_cmat_description(nir_src_as_deref(call->params[1])->type)->use ==
          GLSL_CMAT_USE_ACCUMULATOR);
   /* Without a row or column flag no component would be produced and
    * nir_vec() below would be handed NULL sources.
    */
   assert(reduce & (NIR_CMAT_REDUCE_COLUMN | NIR_CMAT_REDUCE_ROW));

   nir_def *src0 = load_cmat_src(b, call->params[1]);
   nir_def *src1 = load_cmat_src(b, call->params[2]);

   nir_def *comps[NIR_MAX_VEC_COMPONENTS] = { NULL };

   if (reduce & NIR_CMAT_REDUCE_COLUMN) {
      nir_variable *col_tmp =
         nir_local_variable_create(b->impl,
                                   glsl_get_bare_type(fnptr->params[0].type),
                                   "col_tmp");

      /* All of the rows contain the same data, so just combine the first
       * row of each source.
       */
      nir_def *row_accum0 = nir_channel(b, src0, 0);
      nir_def *row_accum1 = nir_channel(b, src1, 0);

      nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp);
      nir_call(b, fnptr, &col_tmp_deref->def, row_accum0, row_accum1);

      nir_def *first_col = nir_load_deref(b, col_tmp_deref);
      for (unsigned i = 0; i < CMAT_LEN; i++) {
         comps[i] = first_col;
      }
   } else if (reduce & NIR_CMAT_REDUCE_ROW) {
      for (unsigned i = 0; i < CMAT_LEN; ++i) {
         nir_def *row0_accum = nir_channel(b, src0, i);
         nir_def *row1_accum = nir_channel(b, src1, i);

         /* A fresh temporary per component receives the callee's result. */
         nir_variable *row_tmp =
            nir_local_variable_create(b->impl,
                                      glsl_get_bare_type(fnptr->params[0].type),
                                      "row_tmp");
         nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp);
         nir_call(b, fnptr, &row_tmp_deref->def, row0_accum, row1_accum);

         comps[i] = nir_load_deref(b, row_tmp_deref);
      }
   }

   nir_def *mat = nir_vec(b, comps, CMAT_LEN);
   nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
   nir_instr_remove(&call->instr);
   return true;
}
/**
 * Lower a cooperative-matrix "reduce" call.
 *
 * Loads the source matrix (call param 1), applies the requested reductions
 * using the call's callee as the combine function, and stores the resulting
 * vector through the destination deref (call param 0).  The callee is
 * invoked as (out deref, value0, value1).
 *
 * - NIR_CMAT_REDUCE_COLUMN: the CMAT_LEN components held by the invocation
 *   are folded into one value, which is then replicated to every component.
 * - NIR_CMAT_REDUCE_ROW: each component is combined across invocations by
 *   shuffling from lane_id + 1, lane_id + 2, lane_id + 4, ... while the
 *   step is below CMAT_LEN.
 *
 * Both flags may be set; the column step runs first and feeds the row step.
 * NIR_CMAT_REDUCE_2X2 must not be set here.
 *
 * The call instruction is removed afterwards.  Always returns true.
 */
static bool
lower_cmat_reduce_call(nir_builder *b, nir_cmat_call_instr *call)
{
   nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
   nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);
   nir_function *fnptr = call->callee;

   /* Preconditions are checked before any IR is emitted.  The description
    * is queried inline so that no local is left unused when assert() is
    * compiled out.
    */
   assert(glsl_get_cmat_description(nir_src_as_deref(call->params[1])->type)->use ==
          GLSL_CMAT_USE_ACCUMULATOR);
   /* this should be lowered earlier */
   assert(!(reduce & NIR_CMAT_REDUCE_2X2));

   nir_def *src = load_cmat_src(b, call->params[1]);
   nir_def *lane_id = nir_load_subgroup_invocation(b);

   nir_def *comps[NIR_MAX_VEC_COMPONENTS] = { NULL };
   for (unsigned i = 0; i < CMAT_LEN; ++i) {
      comps[i] = nir_channel(b, src, i);
   }

   if (reduce & NIR_CMAT_REDUCE_COLUMN) {
      nir_variable *col_tmp =
         nir_local_variable_create(b->impl,
                                   glsl_get_bare_type(fnptr->params[0].type),
                                   "col_tmp");
      nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp);

      /* Seed the accumulator with component 0, then fold in the rest. */
      nir_store_deref(b, col_tmp_deref, comps[0], 1);
      for (unsigned i = 1; i < CMAT_LEN; i++) {
         nir_def *col_accum_val = nir_load_deref(b, col_tmp_deref);
         nir_call(b, fnptr, &col_tmp_deref->def, col_accum_val, comps[i]);
      }

      for (unsigned i = 0; i < CMAT_LEN; i++) {
         comps[i] = nir_load_deref(b, col_tmp_deref);
      }
   }

   if (reduce & NIR_CMAT_REDUCE_ROW) {
      for (unsigned i = 0; i < CMAT_LEN; ++i) {
         nir_variable *row_tmp =
            nir_local_variable_create(b->impl,
                                      glsl_get_bare_type(fnptr->params[0].type),
                                      "row_tmp");
         nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp);
         nir_store_deref(b, row_tmp_deref, comps[i], 1);

         /* NOTE(review): the shuffle index lane_id + j is neither wrapped
          * nor clamped here; confirm what nir_shuffle yields for lanes
          * where it runs past the last invocation.
          */
         for (unsigned j = 1; j < CMAT_LEN; j *= 2) {
            nir_def *prev_row_accum_val = nir_load_deref(b, row_tmp_deref);
            nir_def *this_row =
               nir_shuffle(b, prev_row_accum_val,
                           nir_iadd(b, lane_id, nir_imm_int(b, j)));
            nir_call(b, fnptr, &row_tmp_deref->def, prev_row_accum_val, this_row);
         }

         /* The deref built above is still valid; reuse it as the column
          * path does instead of emitting a second deref_var.
          */
         comps[i] = nir_load_deref(b, row_tmp_deref);
      }
   }

   nir_def *mat = nir_vec(b, comps, CMAT_LEN);
   nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
   nir_instr_remove(&call->instr);
   return true;
}
/**
 * Lower a cooperative-matrix "reduce 2x2" call.
 *
 * Takes four source matrices (call params 1..4), combines neighbouring
 * elements with the call's callee, and stores one result vector through the
 * destination deref (call param 0).  The callee is invoked as
 * (out deref, value0, value1).
 *
 * The lowering runs in three phases:
 *  1. Per source matrix, components 2*i and 2*i+1 are combined, and the
 *     result is combined again with the same value shuffled down by one
 *     lane.  This leaves CMAT_LEN / 2 values per source.
 *  2. Sources are merged pairwise (0 with 1, 2 with 3): lanes selected by
 *     the 0x5555... ballot mask take the value from the first source of the
 *     pair, the other lanes take the second source shuffled up by one lane.
 *  3. Every component is permuted across lanes with a final shuffle.
 *
 * The call instruction is removed afterwards.  Always returns true.
 */
static bool
lower_cmat_reduce_2x2_call(nir_builder *b, nir_cmat_call_instr *call)
{
   nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
   nir_function *fnptr = call->callee;

   /* Checked inline so that no local is left unused when assert() is
    * compiled out.
    */
   assert(glsl_get_cmat_description(nir_src_as_deref(call->params[1])->type)->use ==
          GLSL_CMAT_USE_ACCUMULATOR);

   nir_def *lane_id = nir_load_subgroup_invocation(b);

   nir_def *comps[NIR_MAX_VEC_COMPONENTS] = { NULL };
   nir_def *src_components[4][NIR_MAX_VEC_COMPONENTS];

   for (unsigned m = 0; m < 4; m++) {
      nir_def *src = load_cmat_src(b, call->params[m + 1]);
      for (unsigned i = 0; i < CMAT_LEN; i++) {
         src_components[m][i] = nir_channel(b, src, i);
      }
   }

   /* One temporary is shared by every callee invocation below. */
   nir_variable *qd_tmp =
      nir_local_variable_create(b->impl,
                                glsl_get_bare_type(fnptr->params[0].type),
                                "qd_tmp");
   nir_deref_instr *qd_tmp_deref = nir_build_deref_var(b, qd_tmp);

   /* Phase 1.  Slot i is overwritten in place; that is safe because later
    * iterations only read slots 2*i and 2*i+1, which are always >= the
    * slots already written.
    */
   for (unsigned m = 0; m < 4; m++) {
      for (unsigned i = 0; i < CMAT_LEN / 2; i++) {
         nir_call(b, fnptr, &qd_tmp_deref->def,
                  src_components[m][i * 2], src_components[m][i * 2 + 1]);
         src_components[m][i] = nir_load_deref(b, qd_tmp_deref);

         nir_def *other_col =
            nir_shuffle_down(b, src_components[m][i], nir_imm_int(b, 1));
         nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i], other_col);
         src_components[m][i] = nir_load_deref(b, qd_tmp_deref);
      }
   }

   /* Phase 2.  Bits 0, 2, 4, ... of the mask are set. */
   nir_def *even = nir_inverse_ballot_imm(b, 0x5555555555555555, 32);
   for (unsigned m = 0; m < 2; m++) {
      for (unsigned i = 0; i < CMAT_LEN / 2; i++) {
         nir_def *m0_comp = src_components[m * 2][i];
         nir_def *m1_comp =
            nir_shuffle_up(b, src_components[m * 2 + 1][i], nir_imm_int(b, 1));
         comps[m * (CMAT_LEN / 2) + i] = nir_bcsel(b, even, m0_comp, m1_comp);
      }
   }

   /* Phase 3.  Lanes below 4 read from lane 2 * id, the remaining lanes
    * read from lane 2 * (id - 4) + 1.
    * NOTE(review): the literal 4 presumably stands for CMAT_LEN / 2 -
    * confirm against the definition of CMAT_LEN.
    */
   nir_def *low_lane_id = nir_ilt_imm(b, lane_id, 4);
   nir_def *new_lane_id_lo = nir_imul_imm(b, lane_id, 2);
   nir_def *new_lane_id_hi =
      nir_iadd_imm(b, nir_imul_imm(b, nir_iadd_imm(b, lane_id, -4), 2), 1);
   nir_def *new_lane_id = nir_bcsel(b, low_lane_id, new_lane_id_lo, new_lane_id_hi);

   for (unsigned m = 0; m < CMAT_LEN; m++) {
      comps[m] = nir_shuffle(b, comps[m], new_lane_id);
   }

   nir_def *mat = nir_vec(b, comps, CMAT_LEN);
   nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
   nir_instr_remove(&call->instr);
   return true;
}
static bool
lower_impl(nir_function_impl *impl,
struct hash_table *type_mapping)
@ -520,6 +686,23 @@ lower_impl(nir_function_impl *impl,
}
break;
}
case nir_instr_type_cmat_call: {
nir_cmat_call_instr *call = nir_instr_as_cmat_call(instr);
switch (call->op) {
case nir_cmat_call_op_reduce:
progress |= lower_cmat_reduce_call(&b, call);
break;
case nir_cmat_call_op_reduce_finish:
progress |= lower_cmat_reduce_finish_call(&b, call);
break;
case nir_cmat_call_op_reduce_2x2:
progress |= lower_cmat_reduce_2x2_call(&b, call);
break;
default:
break;
}
break;
}
default:
break;
}
@ -551,5 +734,8 @@ lvp_nir_lower_cooperative_matrix(nir_shader *shader)
progress |= lower_impl(nir_shader_get_entrypoint(shader), type_mapping);
_mesa_hash_table_destroy(type_mapping, NULL);
nir_foreach_function_impl(fnim, shader)
nir_progress(progress, fnim, 0);
return progress;
}

View file

@ -168,7 +168,7 @@ vk_spirv_to_nir(struct vk_device *device,
NIR_PASS(_, nir, nir_opt_deref);
/* Pick off the single entrypoint that we want */
nir_remove_non_entrypoints(nir);
nir_remove_non_cmat_call_entrypoints(nir);
/* Now that we've deleted all but the main function, we can go ahead and
* lower the rest of the constant initializers. We do this here so that