diff --git a/src/compiler/nir/nir_lower_cooperative_matrix.c b/src/compiler/nir/nir_lower_cooperative_matrix.c index dc748deea37..62dae837d2e 100644 --- a/src/compiler/nir/nir_lower_cooperative_matrix.c +++ b/src/compiler/nir/nir_lower_cooperative_matrix.c @@ -11,8 +11,8 @@ * Lower flexible size cooperative matrix operations down to operations at the supported granularity. */ struct split_box { - unsigned num_row_splits; - unsigned num_col_splits; + unsigned outer_rows; + unsigned outer_cols; }; struct split_mat { @@ -29,7 +29,7 @@ struct split_info { static unsigned get_num_splits_box(const struct split_box *box) { - return box->num_row_splits * box->num_col_splits; + return box->outer_rows * box->outer_cols; } static unsigned get_num_splits(struct split_mat *split) @@ -134,21 +134,21 @@ split_desc(struct glsl_cmat_description *desc, struct split_info *info, struct split_box *box) { unsigned split_rows = 0, split_cols = 0; - box->num_col_splits = 1; - box->num_row_splits = 1; + box->outer_cols = 1; + box->outer_rows = 1; get_lower_sizes(*desc, info->m_gran, info->n_gran, info->k_gran, &split_rows, &split_cols); if (split_rows) { - box->num_row_splits = desc->rows / split_rows; + box->outer_rows = desc->rows / split_rows; desc->rows = split_rows; } if (split_cols) { - box->num_col_splits = desc->cols / split_cols; + box->outer_cols = desc->cols / split_cols; desc->cols = split_cols; } - if (box->num_row_splits == 1 && box->num_col_splits == 1) + if (box->outer_rows == 1 && box->outer_cols == 1) return false; return true; } @@ -358,9 +358,9 @@ split_cmat_transpose(nir_builder *b, unsigned dst_splits = get_num_splits(dst_split); for (unsigned i = 0; i < dst_splits; i++) { - int r = i / dst_split->box.num_row_splits; - int c = i % dst_split->box.num_row_splits; - int out_idx = c * dst_split->box.num_col_splits + r; + int r = i / dst_split->box.outer_rows; + int c = i % dst_split->box.outer_rows; + int out_idx = c * dst_split->box.outer_cols + r; nir_deref_instr *dst_deref = recreate_derefs(b, &intr->src[0], dst_split->split_vars[out_idx]); nir_deref_instr *src_deref = recreate_derefs(b, &intr->src[1], src_split->split_vars[i]); b->cursor = nir_before_instr(instr); @@ -511,22 +511,22 @@ split_cmat_muladd(nir_builder *b, if (result_split) { assert(c_split); - m_splits = result_split->box.num_row_splits; - n_splits = result_split->box.num_col_splits; + m_splits = result_split->box.outer_rows; + n_splits = result_split->box.outer_cols; - assert(c_split->box.num_row_splits == m_splits); - assert(c_split->box.num_col_splits == n_splits); + assert(c_split->box.outer_rows == m_splits); + assert(c_split->box.outer_cols == n_splits); if (a_split) - assert(a_split->box.num_row_splits == m_splits); + assert(a_split->box.outer_rows == m_splits); if (b_split) - assert(b_split->box.num_col_splits == n_splits); + assert(b_split->box.outer_cols == n_splits); } - if (a_split && a_split->box.num_col_splits > 1) { + if (a_split && a_split->box.outer_cols > 1) { assert(b_split); - assert(b_split->box.num_row_splits == a_split->box.num_col_splits); - k_splits = a_split->box.num_col_splits; + assert(b_split->box.outer_rows == a_split->box.outer_cols); + k_splits = a_split->box.outer_cols; } for (unsigned m = 0; m < m_splits; m++) { @@ -641,14 +641,14 @@ split_cmat_call_reduce(nir_builder *b, call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def); } } else if (reduce & NIR_CMAT_REDUCE_ROW) { - for (unsigned i = 1; i < src_split->box.num_col_splits; i++) { + for (unsigned i = 1; i < src_split->box.outer_cols; i++) { nir_deref_instr *second_deref = temp_derefs[i]; b->cursor = nir_before_instr(instr); call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def); } } else if (reduce & NIR_CMAT_REDUCE_COLUMN) { - for (unsigned i = 1; i < src_split->box.num_row_splits; i++) { - nir_deref_instr *second_deref = temp_derefs[i * src_split->box.num_col_splits]; + for (unsigned i = 1; i < src_split->box.outer_rows; i++) { + nir_deref_instr *second_deref = temp_derefs[i * src_split->box.outer_cols]; b->cursor = nir_before_instr(instr); call_reduce_finish(b, call, reduce, &temp_derefs[0]->def, &temp_derefs[0]->def, &second_deref->def); } @@ -660,16 +660,16 @@ split_cmat_call_reduce(nir_builder *b, /* at this point temp_derefs should contain all the split reduced src matrices now to store them */ if (dst_split) { - for (unsigned r = 0; r < dst_split->box.num_row_splits; r++) { - for (unsigned c = 0; c < dst_split->box.num_col_splits; c++) { - int didx = r * dst_split->box.num_col_splits + c; + for (unsigned r = 0; r < dst_split->box.outer_rows; r++) { + for (unsigned c = 0; c < dst_split->box.outer_cols; c++) { + int didx = r * dst_split->box.outer_cols + c; int idx; if ((reduce & (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)) == (NIR_CMAT_REDUCE_ROW | NIR_CMAT_REDUCE_COLUMN)) idx = 0; else if (reduce & NIR_CMAT_REDUCE_ROW) - idx = r % (src_split ? src_split->box.num_row_splits : 1); + idx = r % (src_split ? src_split->box.outer_rows : 1); else if (reduce & NIR_CMAT_REDUCE_COLUMN) - idx = c % (src_split ? src_split->box.num_col_splits : 1); + idx = c % (src_split ? src_split->box.outer_cols : 1); else UNREACHABLE("Unknown NIR_CMAT_REDUCE_*"); @@ -691,8 +691,8 @@ split_cmat_call_reduce(nir_builder *b, int rows = 1, cols = 1; if (dst_split) { - rows = dst_split->box.num_row_splits; - cols = dst_split->box.num_col_splits; + rows = dst_split->box.outer_rows; + cols = dst_split->box.outer_cols; } for (unsigned r = 0; r < rows; r++) { @@ -700,8 +700,8 @@ split_cmat_call_reduce(nir_builder *b, int d_idx = c + r * cols; int src_top_left_col = c * 2; int src_top_left_row = r * 2; - int src_top_idx = src_top_left_col + src_top_left_row * src_split->box.num_col_splits; - int src_bottom_idx = src_top_left_col + (src_top_left_row + 1) * src_split->box.num_col_splits; + int src_top_idx = src_top_left_col + src_top_left_row * src_split->box.outer_cols; + int src_bottom_idx = src_top_left_col + (src_top_left_row + 1) * src_split->box.outer_cols; nir_deref_instr *src0_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx]); nir_deref_instr *src1_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_top_idx + 1]); nir_deref_instr *src2_deref = recreate_derefs(b, &call->params[1], src_split->split_vars[src_bottom_idx]); @@ -749,8 +749,8 @@ split_cmat_load_store(nir_builder *b, struct glsl_cmat_description desc = *glsl_get_cmat_description(split->split_vars[i]->type); unsigned row_offset, col_offset; - row_offset = (i % split->box.num_col_splits) * desc.cols; - col_offset = (i / split->box.num_col_splits) * desc.rows; + row_offset = (i % split->box.outer_cols) * desc.cols; + col_offset = (i / split->box.outer_cols) * desc.rows; if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) SWAP(row_offset, col_offset); @@ -793,8 +793,8 @@ split_cmat_call_per_element_op(nir_builder *b, return false; for (unsigned i = 0; i < splits; i++) { - int r = i / dst_split->box.num_col_splits; - int c = i % dst_split->box.num_col_splits; + int r = i / dst_split->box.outer_cols; + int c = i % dst_split->box.outer_cols; nir_deref_instr *dst_deref = recreate_derefs(b, &call->params[0], dst_split->split_vars[i]); nir_deref_instr *src_deref = recreate_derefs(b, &call->params[3], src_split->split_vars[i]); struct glsl_cmat_description cmat_desc = *glsl_get_cmat_description(src_split->split_vars[0]->type);