/*
 * Patch summary (leading text only; the patch content below is unchanged).
 *
 * lvp_device.c, lvp_get_features():
 *   Sets the cooperativeMatrixConversions feature bit to true.
 *
 * lvp_nir_lower_cooperative_matrix.c:
 *   convert_use() (new, static):
 *     Splits src into scalar channels and seeds every output channel with a
 *     zero of the same bit size. For each output channel i and each j in
 *     [0, CMAT_LEN) it emits nir_read_invocation(comps[j], i) and keeps that
 *     value (via nir_bcsel) only where the subgroup invocation index equals
 *     j; otherwise the previously built value is kept. The channels are
 *     re-assembled with nir_vec() and returned.
 *     Net effect, presumably: invocation j ends up with, in channel i, the
 *     value invocation i held in channel j, i.e. a transpose across
 *     invocations and channels. This relies on nir_read_invocation reading
 *     from the given invocation index; confirm against the NIR docs.
 *     NOTE(review): src_use and dst_use are accepted but never read here.
 *     NOTE(review): comps[] is only filled for src->num_components entries,
 *     while the inner loop indexes comps[j] for j < CMAT_LEN; this looks
 *     like it assumes num_components >= CMAT_LEN. TODO confirm.
 *
 *   lower_cmat_convert():
 *     Now also handles nir_intrinsic_cmat_transpose (the matching case label
 *     is added in lower_impl()). It reads the "use" of the destination and
 *     source matrix descriptions and treats GLSL_CMAT_USE_ACCUMULATOR as
 *     GLSL_CMAT_USE_A. For a transpose whose source and destination uses are
 *     A and B in either order, src_use is set to dst_use so that
 *     convert_use() is skipped. In all other cases where the uses still
 *     differ, convert_use() is applied first. convert_base_type() then runs
 *     on the result, which is stored to the destination before the
 *     intrinsic is removed.
 *
 * NOTE(review): this copy of the patch has its line breaks and indentation
 * collapsed; recover the original formatting before trying to apply it.
 */
diff --git a/src/gallium/frontends/lavapipe/lvp_device.c b/src/gallium/frontends/lavapipe/lvp_device.c index 27a79b6c309..8e134d7ff32 100644 --- a/src/gallium/frontends/lavapipe/lvp_device.c +++ b/src/gallium/frontends/lavapipe/lvp_device.c @@ -873,6 +873,7 @@ lvp_get_features(const struct lvp_physical_device *pdevice, .cooperativeMatrixRobustBufferAccess = has_cooperative_matrix(), .cooperativeMatrixFlexibleDimensions = true, + .cooperativeMatrixConversions = true, }; } diff --git a/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c b/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c index 6f08a3813c6..a3f2ceb8765 100644 --- a/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c +++ b/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c @@ -106,6 +106,35 @@ lower_cmat_copy(nir_builder *b, nir_intrinsic_instr *intr) return true; } +static nir_def * +convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, + enum glsl_cmat_use dst_use) +{ + nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {}; + nir_def *out_comps[NIR_MAX_VEC_COMPONENTS] = {}; + unsigned num_comps = src->num_components; + for (unsigned i = 0; i < num_comps; i++) { + comps[i] = nir_channel(b, src, i); + out_comps[i] = nir_imm_zero(b, 1, comps[i]->bit_size); + } + + nir_def *lane_id = nir_load_subgroup_invocation(b); + + /* construct the outer row */ + for (unsigned i = 0; i < num_comps; i++) { + + for (unsigned j = 0; j < CMAT_LEN; j++) { + nir_def *else_val = out_comps[i]; + nir_def *active_lane = nir_ieq(b, lane_id, nir_imm_int(b, j)); + + out_comps[i] = nir_read_invocation(b, comps[j], nir_imm_int(b, i)); + + out_comps[i] = nir_bcsel(b, active_lane, out_comps[i], else_val); + } + } + return nir_vec(b, out_comps, num_comps); +} + static nir_def * convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, enum glsl_base_type dst_type) { @@ -122,14 +151,35 @@ static bool lower_cmat_convert(nir_builder *b, 
nir_intrinsic_instr *intr) { + const bool transpose = intr->intrinsic == nir_intrinsic_cmat_transpose; struct glsl_cmat_description dst_desc = cmat_src_desc(intr->src[0]); struct glsl_cmat_description src_desc = cmat_src_desc(intr->src[1]); enum glsl_base_type dst_element_type = dst_desc.element_type; enum glsl_base_type src_element_type = src_desc.element_type; + + enum glsl_cmat_use dst_use = dst_desc.use; + enum glsl_cmat_use src_use = src_desc.use; + nir_def *cmat = load_cmat_src(b, intr->src[1]); - nir_def *ret = convert_base_type(b, cmat, src_element_type, dst_element_type); + if (dst_use == GLSL_CMAT_USE_ACCUMULATOR) + dst_use = GLSL_CMAT_USE_A; + if (src_use == GLSL_CMAT_USE_ACCUMULATOR) + src_use = GLSL_CMAT_USE_A; + + if (transpose) { + if (src_use == GLSL_CMAT_USE_A && dst_use == GLSL_CMAT_USE_B) + src_use = dst_use; + if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_A) + src_use = dst_use; + } + + nir_def *ret = cmat; + if (dst_use != src_use) { + ret = convert_use(b, cmat, src_use, dst_use); + } + ret = convert_base_type(b, ret, src_element_type, dst_element_type); store_cmat_src(b, intr->src[0], ret); nir_instr_remove(&intr->instr); return true; @@ -448,6 +498,7 @@ lower_impl(nir_function_impl *impl, progress |= lower_cmat_copy(&b, intr); break; case nir_intrinsic_cmat_convert: + case nir_intrinsic_cmat_transpose: progress |= lower_cmat_convert(&b, intr); break; case nir_intrinsic_cmat_bitcast: