nir/coopmat: move the row/col into a box and add some helpers.

This makes adding workgroup scope easier. It only creates the split_box
struct, moves the row/column split counts into it, and adds some helpers.

This also rewrites some nested row/column (r/c) loops as a single loop over a flat index i, from which r and c are calculated.

Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41500>
This commit is contained in:
David Airlie 2026-05-12 17:29:14 +10:00 committed by Marge Bot
parent eaf6207e06
commit 6062bcde56

View file

@ -10,9 +10,13 @@
/*
* Lower flexible size cooperative matrix operations down to operations at the supported granularity.
*/
struct split_mat {
/*
 * How many pieces a cooperative matrix is divided into along each dimension.
 * The total number of pieces is num_row_splits * num_col_splits.
 */
struct split_box {
/* Number of pieces along the row dimension (original rows / lowered rows). */
unsigned num_row_splits;
/* Number of pieces along the column dimension (original cols / lowered cols). */
unsigned num_col_splits;
};
/*
 * Bookkeeping for one cooperative matrix variable that has been split into
 * smaller matrices.
 */
struct split_mat {
/* Row/column split counts for this matrix. */
struct split_box box;
/*
 * One variable per piece; holds num_row_splits * num_col_splits entries,
 * indexed as r * num_col_splits + c by the lowering code.
 */
nir_variable **split_vars;
};
@ -23,6 +27,23 @@ struct split_info {
unsigned k_gran;
};
/*
 * Total number of pieces described by a split box: the row split count
 * multiplied by the column split count.
 */
static unsigned get_num_splits_box(const struct split_box *box)
{
   const unsigned row_pieces = box->num_row_splits;
   const unsigned col_pieces = box->num_col_splits;

   return row_pieces * col_pieces;
}
static unsigned get_num_splits(struct split_mat *split)
{
return get_num_splits_box(&split->box);
}
/*
 * Allocate the split_vars array for a split matrix.
 *
 * split->box must already be filled in: the array gets one slot per piece
 * (row splits * column splits) and is allocated with `split` as the ralloc
 * context. The allocation result is not checked here, and the slots are
 * left for the caller to populate.
 *
 * Returns the number of slots in the array.
 */
static unsigned split_alloc_vars(struct split_mat *split)
{
   const unsigned count = get_num_splits(split);

   split->split_vars = ralloc_array(split, struct nir_variable *, count);

   return count;
}
static struct split_mat *find_split(struct hash_table *split_mats,
nir_intrinsic_instr *intr, int idx)
{
@ -108,6 +129,30 @@ get_lower_sizes(struct glsl_cmat_description desc, unsigned m_gran,
*split_cols_out = split_cols;
}
/*
 * Work out how the matrix described by `desc` is split for the granularities
 * in `info`.
 *
 * `box` is always written: each dimension starts at one piece and is replaced
 * by (original size / lowered size) when get_lower_sizes() reports a non-zero
 * lowered size for that dimension. In that case `desc` is also shrunk in
 * place to the lowered size.
 *
 * Returns true when the matrix is split in at least one dimension, false when
 * it stays a single 1x1 piece.
 */
static bool
split_desc(struct glsl_cmat_description *desc, struct split_info *info,
           struct split_box *box)
{
   unsigned lowered_rows = 0;
   unsigned lowered_cols = 0;

   /* Default: no split in either dimension. */
   box->num_col_splits = 1;
   box->num_row_splits = 1;

   get_lower_sizes(*desc, info->m_gran, info->n_gran, info->k_gran,
                   &lowered_rows, &lowered_cols);

   /* A size left at zero means that dimension is not lowered. */
   if (lowered_rows != 0) {
      box->num_row_splits = desc->rows / lowered_rows;
      desc->rows = lowered_rows;
   }

   if (lowered_cols != 0) {
      box->num_col_splits = desc->cols / lowered_cols;
      desc->cols = lowered_cols;
   }

   return box->num_row_splits != 1 || box->num_col_splits != 1;
}
static bool
split_cmat_construct(nir_builder *b,
nir_intrinsic_instr *intr,
@ -118,7 +163,7 @@ split_cmat_construct(nir_builder *b,
if (!dst_split)
return false;
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
unsigned splits = get_num_splits(dst_split);
if (splits <= 1)
return false;
@ -143,7 +188,7 @@ split_cmat_copy(nir_builder *b,
if (!dst_split)
return false;
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
unsigned splits = get_num_splits(dst_split);
if (splits <= 1)
return false;
@ -210,7 +255,7 @@ split_cmat_insert(nir_builder *b,
if (!dst_split)
return false;
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
unsigned splits = get_num_splits(dst_split);
if (splits <= 1)
return false;
@ -246,7 +291,7 @@ split_cmat_extract(nir_builder *b,
if (!src_split)
return false;
unsigned splits = src_split->num_col_splits * src_split->num_row_splits;
unsigned splits = get_num_splits(src_split);
if (splits <= 1)
return false;
@ -282,7 +327,7 @@ split_cmat_convert(nir_builder *b,
assert(dst_split && src_split);
unsigned splits = src_split->num_col_splits * src_split->num_row_splits;
unsigned splits = get_num_splits(src_split);
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
@ -310,15 +355,16 @@ split_cmat_transpose(nir_builder *b,
assert(dst_split && src_split);
for (unsigned r = 0; r < src_split->num_row_splits; r++) {
for (unsigned c = 0; c < src_split->num_col_splits; c++) {
int in_idx = r * src_split->num_col_splits + c;
int out_idx = c * dst_split->num_col_splits + r;
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[out_idx]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[in_idx]);
b->cursor = nir_before_instr(instr);
nir_cmat_transpose(b, &dst_deref->def, &src_deref->def, .fp_math_ctrl = nir_intrinsic_fp_math_ctrl(intr));
}
unsigned dst_splits = get_num_splits(dst_split);
for (unsigned i = 0; i < dst_splits; i++) {
int r = i / dst_split->box.num_row_splits;
int c = i % dst_split->box.num_row_splits;
int out_idx = c * dst_split->box.num_col_splits + r;
nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[out_idx]);
nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]);
b->cursor = nir_before_instr(instr);
nir_cmat_transpose(b, &dst_deref->def, &src_deref->def, .fp_math_ctrl = nir_intrinsic_fp_math_ctrl(intr));
}
nir_instr_remove(instr);
return true;
@ -336,7 +382,7 @@ split_cmat_bitcast(nir_builder *b,
if (!dst_split)
return false;
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
unsigned splits = get_num_splits(dst_split);
if (splits <= 1)
return false;
@ -365,7 +411,7 @@ split_cmat_binary_op(nir_builder *b,
if (!dst_split)
return false;
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
unsigned splits = get_num_splits(dst_split);
if (splits <= 1)
return false;
@ -397,7 +443,7 @@ split_cmat_unary_op(nir_builder *b,
if (!dst_split)
return false;
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
unsigned splits = get_num_splits(dst_split);
if (splits <= 1)
return false;
@ -426,7 +472,7 @@ split_cmat_scalar_op(nir_builder *b,
if (!dst_split)
return false;
unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits;
unsigned splits = get_num_splits(dst_split);
if (splits <= 1)
return false;
@ -465,22 +511,22 @@ split_cmat_muladd(nir_builder *b,
if (result_split) {
assert(c_split);
m_splits = result_split->num_row_splits;
n_splits = result_split->num_col_splits;
m_splits = result_split->box.num_row_splits;
n_splits = result_split->box.num_col_splits;
assert(c_split->num_row_splits == m_splits);
assert(c_split->num_col_splits == n_splits);
assert(c_split->box.num_row_splits == m_splits);
assert(c_split->box.num_col_splits == n_splits);
if (a_split)
assert(a_split->num_row_splits == m_splits);
assert(a_split->box.num_row_splits == m_splits);
if (b_split)
assert(b_split->num_col_splits == n_splits);
assert(b_split->box.num_col_splits == n_splits);
}
if (a_split && a_split->num_col_splits > 1) {
if (a_split && a_split->box.num_col_splits > 1) {
assert(b_split);
assert(b_split->num_row_splits == a_split->num_col_splits);
k_splits = a_split->num_col_splits;
assert(b_split->box.num_row_splits == a_split->box.num_col_splits);
k_splits = a_split->box.num_col_splits;
}
for (unsigned m = 0; m < m_splits; m++) {
@ -568,7 +614,7 @@ split_cmat_call_reduce(nir_builder *b,
/* for each source split - reduce it by itself. */
int src_splits = 1;
if (src_split)
src_splits = src_split->num_col_splits * src_split->num_row_splits;
src_splits = get_num_splits(src_split);
nir_deref_instr **temp_derefs = ralloc_array(NULL, nir_deref_instr *, src_splits);
const struct glsl_type *temp_type = nir_deref_instr_get_variable(nir_src_as_deref(call->params[1]))->type;
@ -595,14 +641,14 @@ split_cmat_call_reduce(nir_builder *b,
call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def);
}
} else if (reduce & NIR_CMAT_REDUCE_ROW) {
for (unsigned i = 1; i < src_split->num_col_splits; i++) {
for (unsigned i = 1; i < src_split->box.num_col_splits; i++) {
nir_deref_instr *second_deref = temp_derefs[i];
b->cursor = nir_before_instr(instr);
call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def);
}
} else if (reduce & NIR_CMAT_REDUCE_COLUMN) {
for (unsigned i = 1; i < src_split->num_row_splits; i++) {
nir_deref_instr *second_deref = temp_derefs[i * src_split->num_col_splits];
for (unsigned i = 1; i < src_split->box.num_row_splits; i++) {
nir_deref_instr *second_deref = temp_derefs[i * src_split->box.num_col_splits];
b->cursor = nir_before_instr(instr);
call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def);
}
@ -614,16 +660,16 @@ split_cmat_call_reduce(nir_builder *b,
/* at this point temp_derefs should contain all the split reduced src matrices
now to store them */
if (dst_split) {
for (unsigned r = 0; r < dst_split->num_row_splits; r++) {
for (unsigned c = 0; c < dst_split->num_col_splits; c++) {
int didx = r * dst_split->num_col_splits + c;
for (unsigned r = 0; r < dst_split->box.num_row_splits; r++) {
for (unsigned c = 0; c < dst_split->box.num_col_splits; c++) {
int didx = r * dst_split->box.num_col_splits + c;
int idx;
if ((reduce & (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)) == (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN))
idx = 0;
else if (reduce & NIR_CMAT_REDUCE_ROW)
idx = r % (src_split ? src_split->num_row_splits : 1);
idx = r % (src_split ? src_split->box.num_row_splits : 1);
else if (reduce & NIR_CMAT_REDUCE_COLUMN)
idx = c % (src_split ? src_split->num_col_splits : 1);
idx = c % (src_split ? src_split->box.num_col_splits : 1);
else
UNREACHABLE("Unknown NIR_CMAT_REDUCE_*");
@ -645,8 +691,8 @@ split_cmat_call_reduce(nir_builder *b,
int rows = 1, cols = 1;
if (dst_split) {
rows = dst_split->num_row_splits;
cols = dst_split->num_col_splits;
rows = dst_split->box.num_row_splits;
cols = dst_split->box.num_col_splits;
}
for (unsigned r = 0; r < rows; r++) {
@ -654,8 +700,8 @@ split_cmat_call_reduce(nir_builder *b,
int d_idx = c + r * cols;
int src_top_left_col = c * 2;
int src_top_left_row = r * 2;
int src_top_idx = src_top_left_col + src_top_left_row * src_split->num_col_splits;
int src_bottom_idx = src_top_left_col + (src_top_left_row + 1) * src_split->num_col_splits;
int src_top_idx = src_top_left_col + src_top_left_row * src_split->box.num_col_splits;
int src_bottom_idx = src_top_left_col + (src_top_left_row + 1) * src_split->box.num_col_splits;
nir_deref_instr *src0_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx]);
nir_deref_instr *src1_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx + 1]);
nir_deref_instr *src2_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_bottom_idx]);
@ -685,7 +731,7 @@ split_cmat_load_store(nir_builder *b,
return false;
struct split_mat *split = entry->data;
unsigned splits = split->num_row_splits * split->num_col_splits;
unsigned splits = get_num_splits(split);
for (unsigned i = 0; i < splits; i++) {
nir_deref_instr *new_deref = recreate_derefs(b, &intr->src[!is_load], split->split_vars[i]);
nir_deref_instr *ptr_deref;
@ -703,8 +749,8 @@ split_cmat_load_store(nir_builder *b,
struct glsl_cmat_description desc = *glsl_get_cmat_description(split->split_vars[i]->type);
unsigned row_offset, col_offset;
row_offset = (i % split->num_col_splits) * desc.cols;
col_offset = (i / split->num_col_splits) * desc.rows;
row_offset = (i % split->box.num_col_splits) * desc.cols;
col_offset = (i / split->box.num_col_splits) * desc.rows;
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR)
SWAP(row_offset, col_offset);
@ -742,33 +788,32 @@ split_cmat_call_per_element_op(nir_builder *b,
return false;
assert(src_split);
int splits = dst_split->num_col_splits * dst_split->num_row_splits;
int splits = get_num_splits(dst_split);
if (splits <= 1)
return false;
for (unsigned r = 0; r < dst_split->num_row_splits; r++) {
for (unsigned c = 0; c < dst_split->num_col_splits; c++) {
int idx = r * dst_split->num_col_splits + c;
nir_deref_instr *dst_deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[idx]);
nir_deref_instr *src_deref = recreate_derefs(b, &call->params[3], src_split->split_vars[idx]);
struct glsl_cmat_description cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type);
nir_cmat_call_instr *new_call = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_per_element_op, call->callee);
new_call->params[0] = nir_src_for_ssa(&dst_deref->def);
new_call->params[1] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.rows * r));
new_call->params[2] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.cols * c));
new_call->params[3] = nir_src_for_ssa(&src_deref->def);
for (unsigned i = 0; i < splits; i++) {
int r = i / dst_split->box.num_col_splits;
int c = i % dst_split->box.num_col_splits;
nir_deref_instr *dst_deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[i]);
nir_deref_instr *src_deref = recreate_derefs(b, &call->params[3], src_split->split_vars[i]);
struct glsl_cmat_description cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type);
nir_cmat_call_instr *new_call = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_per_element_op, call->callee);
new_call->params[0] = nir_src_for_ssa(&dst_deref->def);
new_call->params[1] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.rows * r));
new_call->params[2] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.cols * c));
new_call->params[3] = nir_src_for_ssa(&src_deref->def);
for (unsigned i = 4; i < call->num_params; i++) {
if (nir_src_as_deref(call->params[i])) {
struct split_mat *src1_split = find_call_split(info->split_mats, call, i);
nir_deref_instr *src1_deref = src1_split ? recreate_derefs(b, &call->params[i], src1_split->split_vars[idx]) : nir_src_as_deref(call->params[i]);
new_call->params[i] = src1_deref ? nir_src_for_ssa(&src1_deref->def) : call->params[i];
} else
new_call->params[i] = call->params[i];
}
b->cursor = nir_before_instr(instr);
nir_builder_instr_insert(b, &new_call->instr);
for (unsigned j = 4; j < call->num_params; j++) {
if (nir_src_as_deref(call->params[j])) {
struct split_mat *src1_split = find_call_split(info->split_mats, call, j);
nir_deref_instr *src1_deref = src1_split ? recreate_derefs(b, &call->params[j], src1_split->split_vars[i]) : nir_src_as_deref(call->params[j]);
new_call->params[j] = src1_deref ? nir_src_for_ssa(&src1_deref->def) : call->params[j];
} else
new_call->params[j] = call->params[j];
}
b->cursor = nir_before_instr(instr);
nir_builder_instr_insert(b, &new_call->instr);
}
nir_instr_remove(instr);
return true;
@ -868,34 +913,19 @@ split_var(nir_shader *shader,
}
struct glsl_cmat_description desc = *glsl_get_cmat_description(type);
unsigned split_rows = 0, split_cols = 0;
get_lower_sizes(desc, info->m_gran, info->n_gran, info->k_gran, &split_rows, &split_cols);
unsigned num_row_split = 1, num_col_split = 1;
if (split_rows) {
num_row_split = desc.rows / split_rows;
desc.rows = split_rows;
}
if (split_cols) {
num_col_split = desc.cols / split_cols;
desc.cols = split_cols;
}
if (num_row_split == 1 && num_col_split == 1)
return;
struct split_box box;
if (!split_desc(&desc, info, &box))
return;
const struct glsl_type *new_type = glsl_type_wrap_in_arrays(glsl_cmat_type(&desc), var->type);
struct split_mat *split_mat = ralloc(info->split_mats, struct split_mat);
struct split_mat *split_mat = rzalloc(info->split_mats, struct split_mat);
if (!split_mat)
return;
unsigned num_split = num_row_split * num_col_split;
split_mat->num_row_splits = num_row_split;
split_mat->num_col_splits = num_col_split;
split_mat->split_vars = ralloc_array(split_mat, struct nir_variable *, num_split);
split_mat->box = box;
unsigned num_split = split_alloc_vars(split_mat);
for (unsigned i = 0; i < num_split; i++) {
if (!nir_variable_is_global(var)) {
split_mat->split_vars[i] = nir_local_variable_create(impl,