diff --git a/src/compiler/nir/nir_lower_cooperative_matrix.c b/src/compiler/nir/nir_lower_cooperative_matrix.c index 20a56f92389..dc748deea37 100644 --- a/src/compiler/nir/nir_lower_cooperative_matrix.c +++ b/src/compiler/nir/nir_lower_cooperative_matrix.c @@ -10,9 +10,13 @@ /* * Lower flexible size cooperative matrix operations down to operations at the supported granularity. */ -struct split_mat { +struct split_box { unsigned num_row_splits; unsigned num_col_splits; +}; + +struct split_mat { + struct split_box box; nir_variable **split_vars; }; @@ -23,6 +27,23 @@ struct split_info { unsigned k_gran; }; +static unsigned get_num_splits_box(const struct split_box *box) +{ + return box->num_row_splits * box->num_col_splits; +} + +static unsigned get_num_splits(struct split_mat *split) +{ + return get_num_splits_box(&split->box); +} + +static unsigned split_alloc_vars(struct split_mat *split) +{ + unsigned num_split = get_num_splits_box(&split->box); + split->split_vars = ralloc_array(split, struct nir_variable *, num_split); + return num_split; +} + static struct split_mat *find_split(struct hash_table *split_mats, nir_intrinsic_instr *intr, int idx) { @@ -108,6 +129,30 @@ get_lower_sizes(struct glsl_cmat_description desc, unsigned m_gran, *split_cols_out = split_cols; } +static bool +split_desc(struct glsl_cmat_description *desc, struct split_info *info, + struct split_box *box) +{ + unsigned split_rows = 0, split_cols = 0; + box->num_col_splits = 1; + box->num_row_splits = 1; + + get_lower_sizes(*desc, info->m_gran, info->n_gran, info->k_gran, &split_rows, &split_cols); + + if (split_rows) { + box->num_row_splits = desc->rows / split_rows; + desc->rows = split_rows; + } + if (split_cols) { + box->num_col_splits = desc->cols / split_cols; + desc->cols = split_cols; + } + + if (box->num_row_splits == 1 && box->num_col_splits == 1) + return false; + return true; +} + static bool split_cmat_construct(nir_builder *b, nir_intrinsic_instr *intr, @@ -118,7 +163,7 @@ split_cmat_construct(nir_builder *b, if (!dst_split) return false; - unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits; + unsigned splits = get_num_splits(dst_split); if (splits <= 1) return false; @@ -143,7 +188,7 @@ split_cmat_copy(nir_builder *b, if (!dst_split) return false; - unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits; + unsigned splits = get_num_splits(dst_split); if (splits <= 1) return false; @@ -210,7 +255,7 @@ split_cmat_insert(nir_builder *b, if (!dst_split) return false; - unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits; + unsigned splits = get_num_splits(dst_split); if (splits <= 1) return false; @@ -246,7 +291,7 @@ split_cmat_extract(nir_builder *b, if (!src_split) return false; - unsigned splits = src_split->num_col_splits * src_split->num_row_splits; + unsigned splits = get_num_splits(src_split); if (splits <= 1) return false; @@ -282,7 +327,7 @@ split_cmat_convert(nir_builder *b, assert(dst_split && src_split); - unsigned splits = src_split->num_col_splits * src_split->num_row_splits; + unsigned splits = get_num_splits(src_split); for (unsigned i = 0; i < splits; i++) { nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[i]); nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]); @@ -310,15 +355,16 @@ split_cmat_transpose(nir_builder *b, assert(dst_split && src_split); - for (unsigned r = 0; r < src_split->num_row_splits; r++) { - for (unsigned c = 0; c < src_split->num_col_splits; c++) { - int in_idx = r * src_split->num_col_splits + c; - int out_idx = c * dst_split->num_col_splits + r; - nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[out_idx]); - nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[in_idx]); - b->cursor = nir_before_instr(instr); - nir_cmat_transpose(b, &dst_deref->def, &src_deref->def, .fp_math_ctrl = nir_intrinsic_fp_math_ctrl(intr)); - } + unsigned dst_splits = get_num_splits(dst_split); + + for (unsigned i = 0; i < dst_splits; i++) { + int r = i / dst_split->box.num_row_splits; + int c = i % dst_split->box.num_row_splits; + int out_idx = c * dst_split->box.num_col_splits + r; + nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[out_idx]); + nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]); + b->cursor = nir_before_instr(instr); + nir_cmat_transpose(b, &dst_deref->def, &src_deref->def, .fp_math_ctrl = nir_intrinsic_fp_math_ctrl(intr)); } nir_instr_remove(instr); return true; @@ -336,7 +382,7 @@ split_cmat_bitcast(nir_builder *b, if (!dst_split) return false; - unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits; + unsigned splits = get_num_splits(dst_split); if (splits <= 1) return false; @@ -365,7 +411,7 @@ split_cmat_binary_op(nir_builder *b, if (!dst_split) return false; - unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits; + unsigned splits = get_num_splits(dst_split); if (splits <= 1) return false; @@ -397,7 +443,7 @@ split_cmat_unary_op(nir_builder *b, if (!dst_split) return false; - unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits; + unsigned splits = get_num_splits(dst_split); if (splits <= 1) return false; @@ -426,7 +472,7 @@ split_cmat_scalar_op(nir_builder *b, if (!dst_split) return false; - unsigned splits = dst_split->num_col_splits * dst_split->num_row_splits; + unsigned splits = get_num_splits(dst_split); if (splits <= 1) return false; @@ -465,22 +511,22 @@ split_cmat_muladd(nir_builder *b, if (result_split) { assert(c_split); - m_splits = result_split->num_row_splits; - n_splits = result_split->num_col_splits; + m_splits = result_split->box.num_row_splits; + n_splits = result_split->box.num_col_splits; - assert(c_split->num_row_splits == m_splits); - assert(c_split->num_col_splits == n_splits); + assert(c_split->box.num_row_splits == m_splits); + assert(c_split->box.num_col_splits == n_splits); if (a_split) - assert(a_split->num_row_splits == m_splits); + assert(a_split->box.num_row_splits == m_splits); if (b_split) - assert(b_split->num_col_splits == n_splits); + assert(b_split->box.num_col_splits == n_splits); } - if (a_split && a_split->num_col_splits > 1) { + if (a_split && a_split->box.num_col_splits > 1) { assert(b_split); - assert(b_split->num_row_splits == a_split->num_col_splits); - k_splits = a_split->num_col_splits; + assert(b_split->box.num_row_splits == a_split->box.num_col_splits); + k_splits = a_split->box.num_col_splits; } for (unsigned m = 0; m < m_splits; m++) { @@ -568,7 +614,7 @@ split_cmat_call_reduce(nir_builder *b, /* for each source split - reduce it by itself. */ int src_splits = 1; if (src_split) - src_splits = src_split->num_col_splits * src_split->num_row_splits; + src_splits = get_num_splits(src_split); nir_deref_instr **temp_derefs = ralloc_array(NULL, nir_deref_instr *, src_splits); const struct glsl_type *temp_type = nir_deref_instr_get_variable(nir_src_as_deref(call->params[1]))->type; @@ -595,14 +641,14 @@ split_cmat_call_reduce(nir_builder *b, call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def); } } else if (reduce & NIR_CMAT_REDUCE_ROW) { - for (unsigned i = 1; i < src_split->num_col_splits; i++) { + for (unsigned i = 1; i < src_split->box.num_col_splits; i++) { nir_deref_instr *second_deref = temp_derefs[i]; b->cursor = nir_before_instr(instr); call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def); } } else if (reduce & NIR_CMAT_REDUCE_COLUMN) { - for (unsigned i = 1; i < src_split->num_row_splits; i++) { - nir_deref_instr *second_deref = temp_derefs[i * src_split->num_col_splits]; + for (unsigned i = 1; i < src_split->box.num_row_splits; i++) { + nir_deref_instr *second_deref = temp_derefs[i * src_split->box.num_col_splits]; b->cursor = nir_before_instr(instr); call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def); } @@ -614,16 +660,16 @@ split_cmat_call_reduce(nir_builder *b, /* at this point temp_derefs should contain all the split reduced src matrices now to store them */ if (dst_split) { - for (unsigned r = 0; r < dst_split->num_row_splits; r++) { - for (unsigned c = 0; c < dst_split->num_col_splits; c++) { - int didx = r * dst_split->num_col_splits + c; + for (unsigned r = 0; r < dst_split->box.num_row_splits; r++) { + for (unsigned c = 0; c < dst_split->box.num_col_splits; c++) { + int didx = r * dst_split->box.num_col_splits + c; int idx; if ((reduce & (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)) == (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)) idx = 0; else if (reduce & NIR_CMAT_REDUCE_ROW) - idx = r % (src_split ? src_split->num_row_splits : 1); + idx = r % (src_split ? src_split->box.num_row_splits : 1); else if (reduce & NIR_CMAT_REDUCE_COLUMN) - idx = c % (src_split ? src_split->num_col_splits : 1); + idx = c % (src_split ? src_split->box.num_col_splits : 1); else UNREACHABLE("Unknown NIR_CMAT_REDUCE_*"); @@ -645,8 +691,8 @@ split_cmat_call_reduce(nir_builder *b, int rows = 1, cols = 1; if (dst_split) { - rows = dst_split->num_row_splits; - cols = dst_split->num_col_splits; + rows = dst_split->box.num_row_splits; + cols = dst_split->box.num_col_splits; } for (unsigned r = 0; r < rows; r++) { @@ -654,8 +700,8 @@ split_cmat_call_reduce(nir_builder *b, int d_idx = c + r * cols; int src_top_left_col = c * 2; int src_top_left_row = r * 2; - int src_top_idx = src_top_left_col + src_top_left_row * src_split->num_col_splits; - int src_bottom_idx = src_top_left_col + (src_top_left_row + 1) * src_split->num_col_splits; + int src_top_idx = src_top_left_col + src_top_left_row * src_split->box.num_col_splits; + int src_bottom_idx = src_top_left_col + (src_top_left_row + 1) * src_split->box.num_col_splits; nir_deref_instr *src0_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx]); nir_deref_instr *src1_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx + 1]); nir_deref_instr *src2_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_bottom_idx]); @@ -685,7 +731,7 @@ split_cmat_load_store(nir_builder *b, return false; struct split_mat *split = entry->data; - unsigned splits = split->num_row_splits * split->num_col_splits; + unsigned splits = get_num_splits(split); for (unsigned i = 0; i < splits; i++) { nir_deref_instr *new_deref = recreate_derefs(b, &intr->src[!is_load], split->split_vars[i]); nir_deref_instr *ptr_deref; @@ -703,8 +749,8 @@ split_cmat_load_store(nir_builder *b, struct glsl_cmat_description desc = *glsl_get_cmat_description(split->split_vars[i]->type); unsigned row_offset, col_offset; - row_offset = (i % split->num_col_splits) * desc.cols; - col_offset = (i / split->num_col_splits) * desc.rows; + row_offset = (i % split->box.num_col_splits) * desc.cols; + col_offset = (i / split->box.num_col_splits) * desc.rows; if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) SWAP(row_offset, col_offset); @@ -742,33 +788,32 @@ split_cmat_call_per_element_op(nir_builder *b, return false; assert(src_split); - int splits = dst_split->num_col_splits * dst_split->num_row_splits; + int splits = get_num_splits(dst_split); if (splits <= 1) return false; - for (unsigned r = 0; r < dst_split->num_row_splits; r++) { - for (unsigned c = 0; c < dst_split->num_col_splits; c++) { - int idx = r * dst_split->num_col_splits + c; - nir_deref_instr *dst_deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[idx]); - nir_deref_instr *src_deref = recreate_derefs(b, &call->params[3], src_split->split_vars[idx]); - struct glsl_cmat_description cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type); - nir_cmat_call_instr *new_call = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_per_element_op, call->callee); - new_call->params[0] = nir_src_for_ssa(&dst_deref->def); - new_call->params[1] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.rows * r)); - new_call->params[2] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.cols * c)); - new_call->params[3] = nir_src_for_ssa(&src_deref->def); + for (unsigned i = 0; i < splits; i++) { + int r = i / dst_split->box.num_col_splits; + int c = i % dst_split->box.num_col_splits; + nir_deref_instr *dst_deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[i]); + nir_deref_instr *src_deref = recreate_derefs(b, &call->params[3], src_split->split_vars[i]); + struct glsl_cmat_description cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type); + nir_cmat_call_instr *new_call = nir_cmat_call_instr_create(b->shader, nir_cmat_call_op_per_element_op, call->callee); + new_call->params[0] = nir_src_for_ssa(&dst_deref->def); + new_call->params[1] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.rows * r)); + new_call->params[2] = nir_src_for_ssa(nir_imm_int(b, cmat_desc.cols * c)); + new_call->params[3] = nir_src_for_ssa(&src_deref->def); - for (unsigned i = 4; i < call->num_params; i++) { - if (nir_src_as_deref(call->params[i])) { - struct split_mat *src1_split = find_call_split(info->split_mats, call, i); - nir_deref_instr *src1_deref = src1_split ? recreate_derefs(b, &call->params[i], src1_split->split_vars[idx]) : nir_src_as_deref(call->params[i]); - new_call->params[i] = src1_deref ? nir_src_for_ssa(&src1_deref->def) : call->params[i]; - } else - new_call->params[i] = call->params[i]; - } - b->cursor = nir_before_instr(instr); - nir_builder_instr_insert(b, &new_call->instr); + for (unsigned j = 4; j < call->num_params; j++) { + if (nir_src_as_deref(call->params[j])) { + struct split_mat *src1_split = find_call_split(info->split_mats, call, j); + nir_deref_instr *src1_deref = src1_split ? recreate_derefs(b, &call->params[j], src1_split->split_vars[i]) : nir_src_as_deref(call->params[j]); + new_call->params[j] = src1_deref ? nir_src_for_ssa(&src1_deref->def) : call->params[j]; + } else + new_call->params[j] = call->params[j]; } + b->cursor = nir_before_instr(instr); + nir_builder_instr_insert(b, &new_call->instr); } nir_instr_remove(instr); return true; @@ -868,34 +913,19 @@ split_var(nir_shader *shader, } struct glsl_cmat_description desc = *glsl_get_cmat_description(type); - unsigned split_rows = 0, split_cols = 0; - get_lower_sizes(desc, info->m_gran, info->n_gran, info->k_gran, &split_rows, &split_cols); - - unsigned num_row_split = 1, num_col_split = 1; - - if (split_rows) { - num_row_split = desc.rows / split_rows; - desc.rows = split_rows; - } - if (split_cols) { - num_col_split = desc.cols / split_cols; - desc.cols = split_cols; - } - - if (num_row_split == 1 && num_col_split == 1) - return; + struct split_box box; + if (!split_desc(&desc, info, &box)) + return; const struct glsl_type *new_type = glsl_type_wrap_in_arrays(glsl_cmat_type(&desc), var->type); - struct split_mat *split_mat = ralloc(info->split_mats, struct split_mat); + struct split_mat *split_mat = rzalloc(info->split_mats, struct split_mat); if (!split_mat) return; - unsigned num_split = num_row_split * num_col_split; - split_mat->num_row_splits = num_row_split; - split_mat->num_col_splits = num_col_split; - split_mat->split_vars = ralloc_array(split_mat, struct nir_variable *, num_split); + split_mat->box = box; + unsigned num_split = split_alloc_vars(split_mat); for (unsigned i = 0; i < num_split; i++) { if (!nir_variable_is_global(var)) { split_mat->split_vars[i] = nir_local_variable_create(impl,