mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-03-18 16:40:34 +01:00
lavapipe: add NV_cooperative_matrix2 reductions support
This adds support for the coopmat2 reductions. Reviewed-by: Georg Lehmann <dadschoorse@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38964>
This commit is contained in:
parent
6d53931cf4
commit
2db1a624e3
4 changed files with 200 additions and 2 deletions
|
|
@ -874,6 +874,7 @@ lvp_get_features(const struct lvp_physical_device *pdevice,
|
|||
|
||||
.cooperativeMatrixFlexibleDimensions = true,
|
||||
.cooperativeMatrixConversions = true,
|
||||
.cooperativeMatrixReductions = true,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -356,7 +356,18 @@ lvp_shader_lower(struct lvp_device *pdevice, nir_shader *nir, struct lvp_pipelin
|
|||
NIR_PASS(_, nir, nir_opt_dce);
|
||||
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL);
|
||||
}
|
||||
NIR_PASS(_, nir, lvp_nir_lower_cooperative_matrix);
|
||||
NIR_PASS(progress, nir, lvp_nir_lower_cooperative_matrix);
|
||||
if (progress) {
|
||||
NIR_PASS(_, nir, nir_opt_dce);
|
||||
NIR_PASS(progress, nir, nir_inline_functions);
|
||||
nir_remove_non_entrypoints(nir); /* remove the late inlined functions */
|
||||
if (progress) {
|
||||
NIR_PASS(_, nir, nir_opt_copy_prop_vars);
|
||||
NIR_PASS(_, nir, nir_opt_copy_prop);
|
||||
}
|
||||
NIR_PASS(_, nir, nir_opt_deref);
|
||||
NIR_PASS(_, nir, nir_opt_dce);
|
||||
}
|
||||
|
||||
const struct nir_lower_compute_system_values_options compute_system_values = {0};
|
||||
NIR_PASS(_, nir, nir_lower_compute_system_values, &compute_system_values);
|
||||
|
|
|
|||
|
|
@ -441,6 +441,172 @@ lower_cmat_bitcast(nir_builder *b, nir_intrinsic_instr *intr)
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_reduce_finish_call(nir_builder *b, nir_cmat_call_instr *call)
|
||||
{
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
|
||||
nir_deref_instr *src0_deref = nir_src_as_deref(call->params[1]);
|
||||
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src0_deref->type);
|
||||
nir_function *fnptr = call->callee;
|
||||
nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);
|
||||
nir_def *src0 = load_cmat_src(b, call->params[1]);
|
||||
nir_def *src1 = load_cmat_src(b, call->params[2]);
|
||||
|
||||
assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR);
|
||||
|
||||
nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {};
|
||||
|
||||
if (reduce & NIR_CMAT_REDUCE_COLUMN) {
|
||||
nir_variable *col_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "col_tmp");
|
||||
/* All of the rows contains the same data, so just reduce both first rows. */
|
||||
nir_def *row_accum0 = nir_channel(b, src0, 0);
|
||||
nir_def *row_accum1 = nir_channel(b, src1, 0);
|
||||
|
||||
nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp);
|
||||
|
||||
nir_call(b, fnptr, &col_tmp_deref->def, row_accum0, row_accum1);
|
||||
|
||||
nir_def *first_col = nir_load_deref(b, col_tmp_deref);
|
||||
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++)
|
||||
comps[i] = first_col;
|
||||
} else if (reduce & NIR_CMAT_REDUCE_ROW) {
|
||||
for (unsigned i = 0; i < CMAT_LEN; ++i) {
|
||||
nir_def *row0_accum = nir_channel(b, src0, i);
|
||||
nir_def *row1_accum = nir_channel(b, src1, i);
|
||||
|
||||
nir_variable *row_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "row_tmp");
|
||||
nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp);
|
||||
|
||||
nir_call(b, fnptr, &row_tmp_deref->def, row0_accum, row1_accum);
|
||||
|
||||
nir_def *row = nir_load_deref(b, row_tmp_deref);
|
||||
comps[i] = row;
|
||||
}
|
||||
}
|
||||
nir_def *mat = nir_vec(b, comps, CMAT_LEN);
|
||||
nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
|
||||
nir_instr_remove(&call->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_reduce_call(nir_builder *b, nir_cmat_call_instr *call)
|
||||
{
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
|
||||
nir_deref_instr *src_deref = nir_src_as_deref(call->params[1]);
|
||||
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
|
||||
nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);
|
||||
nir_def *src = load_cmat_src(b, call->params[1]);
|
||||
nir_function *fnptr = call->callee;
|
||||
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
|
||||
assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR);
|
||||
|
||||
nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {};
|
||||
for (unsigned i = 0; i < CMAT_LEN; ++i) {
|
||||
comps[i] = nir_channel(b, src, i);
|
||||
}
|
||||
|
||||
if (reduce & NIR_CMAT_REDUCE_COLUMN) {
|
||||
nir_variable *col_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "col_tmp");
|
||||
|
||||
nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp);
|
||||
nir_store_deref(b, col_tmp_deref, comps[0], 1);
|
||||
|
||||
for (unsigned i = 1; i < CMAT_LEN; i++) {
|
||||
nir_def *col_accum_val = nir_load_deref(b, col_tmp_deref);
|
||||
nir_call(b, fnptr, &col_tmp_deref->def, col_accum_val, comps[i]);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++)
|
||||
comps[i] = nir_load_deref(b, col_tmp_deref);
|
||||
}
|
||||
|
||||
if (reduce & NIR_CMAT_REDUCE_ROW) {
|
||||
for (unsigned i = 0; i < CMAT_LEN; ++i) {
|
||||
nir_def *row_accum = comps[i];
|
||||
nir_variable *row_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "row_tmp");
|
||||
nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp);
|
||||
nir_store_deref(b, row_tmp_deref, row_accum, 1);
|
||||
|
||||
for (unsigned j = 1; j < CMAT_LEN; j *= 2) {
|
||||
nir_def *prev_row_accum_val = nir_load_deref(b, row_tmp_deref);
|
||||
nir_def *this_row = nir_shuffle(b, prev_row_accum_val, nir_iadd(b, lane_id, nir_imm_int(b, j)));
|
||||
|
||||
nir_call(b, fnptr, &row_tmp_deref->def, prev_row_accum_val, this_row);
|
||||
}
|
||||
row_tmp_deref = nir_build_deref_var(b, row_tmp);
|
||||
comps[i] = nir_load_deref(b, row_tmp_deref);
|
||||
}
|
||||
}
|
||||
|
||||
/* this should be lowered earlier */
|
||||
assert(!(reduce & NIR_CMAT_REDUCE_2X2));
|
||||
nir_def *mat = nir_vec(b, comps, CMAT_LEN);
|
||||
nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
|
||||
nir_instr_remove(&call->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_reduce_2x2_call(nir_builder *b, nir_cmat_call_instr *call)
|
||||
{
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
|
||||
nir_deref_instr *src_deref = nir_src_as_deref(call->params[1]);
|
||||
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
|
||||
nir_function *fnptr = call->callee;
|
||||
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR);
|
||||
|
||||
nir_def *comps[NIR_MAX_VEC_COMPONENTS];
|
||||
|
||||
nir_def *src_components[4][NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned m = 0; m < 4; m++) {
|
||||
nir_def *src = load_cmat_src(b, call->params[m + 1]);
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++) {
|
||||
src_components[m][i] = nir_channel(b, src, i);
|
||||
}
|
||||
}
|
||||
nir_variable *qd_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "qd_tmp");
|
||||
nir_deref_instr *qd_tmp_deref = nir_build_deref_var(b, qd_tmp);
|
||||
|
||||
for (unsigned m = 0; m < 4; m++) {
|
||||
for (unsigned i = 0; i < CMAT_LEN / 2; i++) {
|
||||
nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i * 2], src_components[m][i * 2 + 1]);
|
||||
src_components[m][i] = nir_load_deref(b, qd_tmp_deref);
|
||||
|
||||
nir_def *other_col = nir_shuffle_down(b, src_components[m][i], nir_imm_int(b, 1));
|
||||
nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i], other_col);
|
||||
src_components[m][i] = nir_load_deref(b, qd_tmp_deref);
|
||||
}
|
||||
}
|
||||
|
||||
nir_def *even = nir_inverse_ballot_imm(b, 0x5555555555555555, 32);
|
||||
for (unsigned m = 0; m < 2; m++) {
|
||||
for (unsigned i = 0; i < CMAT_LEN / 2; i++) {
|
||||
nir_def *m0_comp = src_components[m * 2][i];
|
||||
nir_def *m1_comp = nir_shuffle_up(b, src_components[m * 2 + 1][i], nir_imm_int(b, 1));
|
||||
|
||||
nir_def *combined = nir_bcsel(b, even, m0_comp, m1_comp);
|
||||
comps[m * (CMAT_LEN / 2) + i] = combined;
|
||||
}
|
||||
}
|
||||
|
||||
nir_def *low_lane_id = nir_ilt_imm(b, lane_id, 4);
|
||||
nir_def *new_lane_id_lo = nir_imul_imm(b, lane_id, 2);
|
||||
nir_def *new_lane_id_hi = nir_iadd_imm(b, nir_imul_imm(b, nir_iadd_imm(b, lane_id, -4), 2), 1);
|
||||
nir_def *new_lane_id = nir_bcsel(b, low_lane_id, new_lane_id_lo, new_lane_id_hi);
|
||||
|
||||
for (unsigned m = 0; m < CMAT_LEN; m++) {
|
||||
comps[m] = nir_shuffle(b, comps[m], new_lane_id);
|
||||
}
|
||||
nir_def *mat = nir_vec(b, comps, CMAT_LEN);
|
||||
nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
|
||||
nir_instr_remove(&call->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_impl(nir_function_impl *impl,
|
||||
struct hash_table *type_mapping)
|
||||
|
|
@ -520,6 +686,23 @@ lower_impl(nir_function_impl *impl,
|
|||
}
|
||||
break;
|
||||
}
|
||||
case nir_instr_type_cmat_call: {
|
||||
nir_cmat_call_instr *call = nir_instr_as_cmat_call(instr);
|
||||
switch (call->op) {
|
||||
case nir_cmat_call_op_reduce:
|
||||
progress |= lower_cmat_reduce_call(&b, call);
|
||||
break;
|
||||
case nir_cmat_call_op_reduce_finish:
|
||||
progress |= lower_cmat_reduce_finish_call(&b, call);
|
||||
break;
|
||||
case nir_cmat_call_op_reduce_2x2:
|
||||
progress |= lower_cmat_reduce_2x2_call(&b, call);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -551,5 +734,8 @@ lvp_nir_lower_cooperative_matrix(nir_shader *shader)
|
|||
progress |= lower_impl(nir_shader_get_entrypoint(shader), type_mapping);
|
||||
|
||||
_mesa_hash_table_destroy(type_mapping, NULL);
|
||||
|
||||
nir_foreach_function_impl(fnim, shader)
|
||||
nir_progress(progress, fnim, 0);
|
||||
return progress;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ vk_spirv_to_nir(struct vk_device *device,
|
|||
NIR_PASS(_, nir, nir_opt_deref);
|
||||
|
||||
/* Pick off the single entrypoint that we want */
|
||||
nir_remove_non_entrypoints(nir);
|
||||
nir_remove_non_cmat_call_entrypoints(nir);
|
||||
|
||||
/* Now that we've deleted all but the main function, we can go ahead and
|
||||
* lower the rest of the constant initializers. We do this here so that
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue