diff --git a/src/gallium/frontends/lavapipe/lvp_device.c b/src/gallium/frontends/lavapipe/lvp_device.c
index 8e134d7ff32..3759c224466 100644
--- a/src/gallium/frontends/lavapipe/lvp_device.c
+++ b/src/gallium/frontends/lavapipe/lvp_device.c
@@ -874,6 +874,7 @@ lvp_get_features(const struct lvp_physical_device *pdevice,
 
       .cooperativeMatrixFlexibleDimensions = true,
       .cooperativeMatrixConversions = true,
+      .cooperativeMatrixReductions = true,
    };
 }
 
diff --git a/src/gallium/frontends/lavapipe/lvp_pipeline.c b/src/gallium/frontends/lavapipe/lvp_pipeline.c
index 8e1385063f5..d5bb0040d25 100644
--- a/src/gallium/frontends/lavapipe/lvp_pipeline.c
+++ b/src/gallium/frontends/lavapipe/lvp_pipeline.c
@@ -356,7 +356,18 @@ lvp_shader_lower(struct lvp_device *pdevice, nir_shader *nir, struct lvp_pipelin
       NIR_PASS(_, nir, nir_opt_dce);
       NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL);
    }
-   NIR_PASS(_, nir, lvp_nir_lower_cooperative_matrix);
+   NIR_PASS(progress, nir, lvp_nir_lower_cooperative_matrix);
+   if (progress) {
+      NIR_PASS(_, nir, nir_opt_dce);
+      NIR_PASS(progress, nir, nir_inline_functions);
+      nir_remove_non_entrypoints(nir); /* remove the late inlined functions */
+      if (progress) {
+         NIR_PASS(_, nir, nir_opt_copy_prop_vars);
+         NIR_PASS(_, nir, nir_opt_copy_prop);
+      }
+      NIR_PASS(_, nir, nir_opt_deref);
+      NIR_PASS(_, nir, nir_opt_dce);
+   }
 
    const struct nir_lower_compute_system_values_options compute_system_values = {0};
    NIR_PASS(_, nir, nir_lower_compute_system_values, &compute_system_values);
diff --git a/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c b/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c
index a3f2ceb8765..daa0b631c5b 100644
--- a/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c
+++ b/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c
@@ -441,6 +441,172 @@ lower_cmat_bitcast(nir_builder *b, nir_intrinsic_instr *intr)
    return true;
 }
 
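+/* Lower a reduce "finish" call: combine the two source accumulators with the
+ * callee, either merging their first rows and broadcasting the result across
+ * every row (column reduce) or merging them row by row (row reduce).
+ */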
+static bool
+lower_cmat_reduce_finish_call(nir_builder *b, nir_cmat_call_instr *call)
+{
+   nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
+   nir_deref_instr *src0_deref = nir_src_as_deref(call->params[1]);
+   struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src0_deref->type);
+   nir_function *fnptr = call->callee;
+   nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);
+   nir_def *src0 = load_cmat_src(b, call->params[1]);
+   nir_def *src1 = load_cmat_src(b, call->params[2]);
+
+   assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR);
+
+   nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {};
+
+   if (reduce & NIR_CMAT_REDUCE_COLUMN) {
+      nir_variable *col_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "col_tmp");
+      /* All of the rows contain the same data, so just reduce both first rows. */
+      nir_def *row_accum0 = nir_channel(b, src0, 0);
+      nir_def *row_accum1 = nir_channel(b, src1, 0);
+
+      nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp);
+
+      nir_call(b, fnptr, &col_tmp_deref->def, row_accum0, row_accum1);
+
+      nir_def *first_col = nir_load_deref(b, col_tmp_deref);
+
+      for (unsigned i = 0; i < CMAT_LEN; i++)
+         comps[i] = first_col;
+   } else if (reduce & NIR_CMAT_REDUCE_ROW) {
+      for (unsigned i = 0; i < CMAT_LEN; ++i) {
+         nir_def *row0_accum = nir_channel(b, src0, i);
+         nir_def *row1_accum = nir_channel(b, src1, i);
+
+         nir_variable *row_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "row_tmp");
+         nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp);
+
+         nir_call(b, fnptr, &row_tmp_deref->def, row0_accum, row1_accum);
+
+         nir_def *row = nir_load_deref(b, row_tmp_deref);
+         comps[i] = row;
+      }
+   }
+   nir_def *mat = nir_vec(b, comps, CMAT_LEN);
+   nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
+   nir_instr_remove(&call->instr);
+   return true;
+}
+
+static bool
+lower_cmat_reduce_call(nir_builder *b, nir_cmat_call_instr *call)
+{
+   nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
+   nir_deref_instr *src_deref = nir_src_as_deref(call->params[1]);
+   struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
+   nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);
+   nir_def *src = load_cmat_src(b, call->params[1]);
+   nir_function *fnptr = call->callee;
+   nir_def *lane_id = nir_load_subgroup_invocation(b);
+
+   assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR);
+
+   nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {};
+   for (unsigned i = 0; i < CMAT_LEN; ++i) {
+      comps[i] = nir_channel(b, src, i);
+   }
+
+   if (reduce & NIR_CMAT_REDUCE_COLUMN) {
+      nir_variable *col_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "col_tmp");
+
+      nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp);
+      nir_store_deref(b, col_tmp_deref, comps[0], 1);
+
+      for (unsigned i = 1; i < CMAT_LEN; i++) {
+         nir_def *col_accum_val = nir_load_deref(b, col_tmp_deref);
+         nir_call(b, fnptr, &col_tmp_deref->def, col_accum_val, comps[i]);
+      }
+
+      for (unsigned i = 0; i < CMAT_LEN; i++)
+         comps[i] = nir_load_deref(b, col_tmp_deref);
+   }
+
+   if (reduce & NIR_CMAT_REDUCE_ROW) {
+      for (unsigned i = 0; i < CMAT_LEN; ++i) {
+         nir_def *row_accum = comps[i];
+         nir_variable *row_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "row_tmp");
+         nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp);
+         nir_store_deref(b, row_tmp_deref, row_accum, 1);
+
+         for (unsigned j = 1; j < CMAT_LEN; j *= 2) {
+            nir_def *prev_row_accum_val = nir_load_deref(b, row_tmp_deref);
+            nir_def *this_row = nir_shuffle(b, prev_row_accum_val, nir_iadd(b, lane_id, nir_imm_int(b, j)));
+
+            nir_call(b, fnptr, &row_tmp_deref->def, prev_row_accum_val, this_row);
+         }
+         row_tmp_deref = nir_build_deref_var(b, row_tmp);
+         comps[i] = nir_load_deref(b, row_tmp_deref);
+      }
+   }
+
+   /* This should have been lowered earlier. */
+   assert(!(reduce & NIR_CMAT_REDUCE_2X2));
+   nir_def *mat = nir_vec(b, comps, CMAT_LEN);
+   nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
+   nir_instr_remove(&call->instr);
+   return true;
+}
+
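+/* Lower a 2x2 reduce call: for each of the four source accumulators, combine
+ * adjacent pairs of per-lane components and adjacent pairs of lanes with the
+ * callee, then interleave the four reduced quarters and remap the lanes into
+ * the destination accumulator.
+ */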
+static bool
+lower_cmat_reduce_2x2_call(nir_builder *b, nir_cmat_call_instr *call)
+{
+   nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
+   nir_deref_instr *src_deref = nir_src_as_deref(call->params[1]);
+   struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
+   nir_function *fnptr = call->callee;
+   nir_def *lane_id = nir_load_subgroup_invocation(b);
+   assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR);
+
+   nir_def *comps[NIR_MAX_VEC_COMPONENTS];
+
+   nir_def *src_components[4][NIR_MAX_VEC_COMPONENTS];
+   for (unsigned m = 0; m < 4; m++) {
+      nir_def *src = load_cmat_src(b, call->params[m + 1]);
+      for (unsigned i = 0; i < CMAT_LEN; i++) {
+         src_components[m][i] = nir_channel(b, src, i);
+      }
+   }
+   nir_variable *qd_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "qd_tmp");
+   nir_deref_instr *qd_tmp_deref = nir_build_deref_var(b, qd_tmp);
+
+   for (unsigned m = 0; m < 4; m++) {
+      for (unsigned i = 0; i < CMAT_LEN / 2; i++) {
+         nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i * 2], src_components[m][i * 2 + 1]);
+         src_components[m][i] = nir_load_deref(b, qd_tmp_deref);
+
+         nir_def *other_col = nir_shuffle_down(b, src_components[m][i], nir_imm_int(b, 1));
+         nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i], other_col);
+         src_components[m][i] = nir_load_deref(b, qd_tmp_deref);
+      }
+   }
+
+   nir_def *even = nir_inverse_ballot_imm(b, 0x5555555555555555, 32);
+   for (unsigned m = 0; m < 2; m++) {
+      for (unsigned i = 0; i < CMAT_LEN / 2; i++) {
+         nir_def *m0_comp = src_components[m * 2][i];
+         nir_def *m1_comp = nir_shuffle_up(b, src_components[m * 2 + 1][i], nir_imm_int(b, 1));
+
+         nir_def *combined = nir_bcsel(b, even, m0_comp, m1_comp);
+         comps[m * (CMAT_LEN / 2) + i] = combined;
+      }
+   }
+
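+   /* Repack the result: lanes 0-3 read from the even source lanes and
+    * lanes 4-7 read from the odd ones. */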
+   nir_def *low_lane_id = nir_ilt_imm(b, lane_id, 4);
+   nir_def *new_lane_id_lo = nir_imul_imm(b, lane_id, 2);
+   nir_def *new_lane_id_hi = nir_iadd_imm(b, nir_imul_imm(b, nir_iadd_imm(b, lane_id, -4), 2), 1);
+   nir_def *new_lane_id = nir_bcsel(b, low_lane_id, new_lane_id_lo, new_lane_id_hi);
+
+   for (unsigned m = 0; m < CMAT_LEN; m++) {
+      comps[m] = nir_shuffle(b, comps[m], new_lane_id);
+   }
+   nir_def *mat = nir_vec(b, comps, CMAT_LEN);
+   nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
+   nir_instr_remove(&call->instr);
+   return true;
+}
+
 static bool
 lower_impl(nir_function_impl *impl,
            struct hash_table *type_mapping)
@@ -520,6 +686,23 @@ lower_impl(nir_function_impl *impl,
          }
          break;
       }
+      case nir_instr_type_cmat_call: {
+         nir_cmat_call_instr *call = nir_instr_as_cmat_call(instr);
+         switch (call->op) {
+         case nir_cmat_call_op_reduce:
+            progress |= lower_cmat_reduce_call(&b, call);
+            break;
+         case nir_cmat_call_op_reduce_finish:
+            progress |= lower_cmat_reduce_finish_call(&b, call);
+            break;
+         case nir_cmat_call_op_reduce_2x2:
+            progress |= lower_cmat_reduce_2x2_call(&b, call);
+            break;
+         default:
+            break;
+         }
+         break;
+      }
       default:
          break;
       }
@@ -551,5 +734,8 @@ lvp_nir_lower_cooperative_matrix(nir_shader *shader)
    progress |= lower_impl(nir_shader_get_entrypoint(shader), type_mapping);
 
    _mesa_hash_table_destroy(type_mapping, NULL);
+
+   nir_foreach_function_impl(fnim, shader)
+      nir_progress(progress, fnim, 0);
    return progress;
 }
diff --git a/src/vulkan/runtime/vk_nir.c b/src/vulkan/runtime/vk_nir.c
index b9cec0878e9..391b07616e7 100644
--- a/src/vulkan/runtime/vk_nir.c
+++ b/src/vulkan/runtime/vk_nir.c
@@ -168,7 +168,7 @@ vk_spirv_to_nir(struct vk_device *device,
    NIR_PASS(_, nir, nir_opt_deref);
 
    /* Pick off the single entrypoint that we want */
-   nir_remove_non_entrypoints(nir);
+   nir_remove_non_cmat_call_entrypoints(nir);
 
    /* Now that we've deleted all but the main function, we can go ahead and
    * lower the rest of the constant initializers. We do this here so that