diff --git a/src/compiler/nir/nir_propagate_invariant.c b/src/compiler/nir/nir_propagate_invariant.c index 096532842b6..0ba0af2ca50 100644 --- a/src/compiler/nir/nir_propagate_invariant.c +++ b/src/compiler/nir/nir_propagate_invariant.c @@ -126,6 +126,7 @@ propagate_invariant_instr(nir_instr *instr, struct set *invariants) case nir_instr_type_jump: case nir_instr_type_undef: case nir_instr_type_load_const: + case nir_instr_type_cmat_call: break; /* Nothing to do */ case nir_instr_type_phi: { diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c index b57325c5c55..ef8df2acd83 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c @@ -2324,7 +2324,13 @@ static void emit_reduce(struct lp_build_nir_soa_context *bld, LLVMValueRef src, /* can't use llvm reduction intrinsics because of exec_mask */ LLVMValueRef exec_mask = group_op_mask_vec(bld); nir_op reduction_op = nir_intrinsic_reduction_op(instr); - + bool is_flt = reduction_op == nir_op_fadd || + reduction_op == nir_op_fmul || + reduction_op == nir_op_fmin || + reduction_op == nir_op_fmax; + bool is_unsigned = reduction_op == nir_op_umin || + reduction_op == nir_op_umax; + struct lp_build_context *int_bld = get_int_bld(bld, true, bit_size, true); uint32_t cluster_size = 0; if (instr->intrinsic == nir_intrinsic_reduce) @@ -2338,20 +2344,57 @@ static void emit_reduce(struct lp_build_nir_soa_context *bld, LLVMValueRef src, src = LLVMBuildZExt(builder, src, bld->uint8_bld.vec_type, ""); } + /* Fast path for reduction addition, used by coopmat. */ + /* For iadd/fadd, instead of reducing manually, just zero out the inactive lanes and use the LLVM reduce intrinsic. */ + if ((reduction_op == nir_op_iadd || + reduction_op == nir_op_fadd) && + cluster_size == bld->int_bld.type.length) { + struct lp_build_context *vec_bld = is_flt ? get_flt_bld(bld, bit_size, true) : + get_int_bld(bld, is_unsigned, bit_size, true); + char intrinsic[64]; + uint32_t length = vec_bld->type.length; + uint32_t src_width = bit_size; + + src = LLVMBuildBitCast(builder, src, int_bld->vec_type, ""); + if (bit_size < 32) + exec_mask = LLVMBuildTrunc(builder, exec_mask, int_bld->vec_type, ""); + LLVMValueRef masked_val = lp_build_and(int_bld, src, exec_mask); + snprintf(intrinsic, sizeof intrinsic, "llvm.vector.reduce.%sadd.v%u%s%u", + is_flt ? "f" : "", + length, is_flt ? "f" : is_unsigned ? "u" : "i", src_width); + + if (is_flt) + masked_val = LLVMBuildBitCast(builder, masked_val, vec_bld->vec_type, ""); + + LLVMValueRef args[2]; + int num_args = is_flt ? 
2 : 1; + if (is_flt) { + args[0] = lp_build_const_elem(gallivm, vec_bld->type, 0); + args[1] = masked_val; + } else { + args[0] = masked_val; + } + LLVMValueRef res = lp_build_intrinsic(builder, intrinsic, vec_bld->elem_type, args, num_args, 0); + + res = lp_build_broadcast(gallivm, vec_bld->vec_type, res); + LLVMValueRef swizzle[LP_MAX_VECTOR_LENGTH]; + for (uint32_t i = 0; i < vec_bld->type.length; i++) + swizzle[i] = lp_build_const_int32(gallivm, i); + + LLVMValueRef undef = LLVMGetUndef(vec_bld->vec_type); + result[0] = LLVMBuildShuffleVector( + builder, res, undef, LLVMConstVector(swizzle, vec_bld->type.length), ""); + return; + } + LLVMValueRef res_store = NULL; LLVMValueRef scan_store; - struct lp_build_context *int_bld = get_int_bld(bld, true, bit_size, true); + res_store = lp_build_alloca(gallivm, int_bld->vec_type, ""); scan_store = lp_build_alloca(gallivm, int_bld->elem_type, ""); struct lp_build_context elem_bld; - bool is_flt = reduction_op == nir_op_fadd || - reduction_op == nir_op_fmul || - reduction_op == nir_op_fmin || - reduction_op == nir_op_fmax; - bool is_unsigned = reduction_op == nir_op_umin || - reduction_op == nir_op_umax; struct lp_build_context *vec_bld = is_flt ? get_flt_bld(bld, bit_size, true) : get_int_bld(bld, is_unsigned, bit_size, true); diff --git a/src/gallium/frontends/lavapipe/lvp_device.c b/src/gallium/frontends/lavapipe/lvp_device.c index 676435e17b0..ec45f0a2c79 100644 --- a/src/gallium/frontends/lavapipe/lvp_device.c +++ b/src/gallium/frontends/lavapipe/lvp_device.c @@ -74,6 +74,13 @@ #define LVP_SAMPLE_COUNTS (VK_SAMPLE_COUNT_1_BIT | VK_SAMPLE_COUNT_4_BIT | \ VK_SAMPLE_COUNT_8_BIT) +extern unsigned lp_native_vector_width; + +static bool has_cooperative_matrix(void) { + /* only support coopmat if we have 8 wide */ + return (lp_native_vector_width / 32) >= 8; +} + VKAPI_ATTR VkResult VKAPI_CALL lvp_EnumerateInstanceVersion(uint32_t* pApiVersion) { *pApiVersion = LVP_API_VERSION; @@ -124,6 +131,7 @@ static const struct vk_device_extension_table lvp_device_extensions_supported = .KHR_buffer_device_address = true, .KHR_create_renderpass2 = true, .KHR_compute_shader_derivatives = true, + .KHR_cooperative_matrix = true, .KHR_copy_commands2 = true, .KHR_copy_memory_indirect = true, .KHR_dedicated_allocation = true, @@ -300,6 +308,7 @@ static const struct vk_device_extension_table lvp_device_extensions_supported = .GOOGLE_decorate_string = true, .GOOGLE_hlsl_functionality1 = true, .GOOGLE_user_type = true, + .NV_cooperative_matrix2 = true, }; static bool @@ -857,11 +866,18 @@ lvp_get_features(const struct lvp_physical_device *pdevice, /* VK_KHR_unified_image_layouts */ .unifiedImageLayouts = true, .unifiedImageLayoutsVideo = true, + + /* VK_KHR_cooperative_matrix */ + .cooperativeMatrix = has_cooperative_matrix(), + .cooperativeMatrixRobustBufferAccess = has_cooperative_matrix(), + + .cooperativeMatrixFlexibleDimensions = true, + .cooperativeMatrixConversions = true, + .cooperativeMatrixReductions = true, + .cooperativeMatrixPerElementOperations = true, }; } -extern unsigned lp_native_vector_width; - static VkImageLayout lvp_host_copy_image_layouts[] = { VK_IMAGE_LAYOUT_GENERAL, VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL, @@ -1304,6 +1320,9 @@ lvp_get_properties(const struct lvp_physical_device *device, struct vk_propertie /* VK_KHR_compute_shader_derivatives */ .meshAndTaskShaderDerivatives = true, + + /* VK_NV_cooperative_matrix2 */ + .cooperativeMatrixFlexibleDimensionsMaxDimension = 1024, }; /* Vulkan 1.0 */ @@ -1369,6 +1388,10 @@ 
lvp_get_properties(const struct lvp_physical_device *device, struct vk_propertie /* VK_EXT_mesh_shader */ p->maxMeshPayloadAndSharedMemorySize = p->maxTaskPayloadSize + p->maxMeshSharedMemorySize; /* 28K min required */ p->maxMeshPayloadAndOutputMemorySize = p->maxTaskPayloadSize + p->maxMeshOutputMemorySize; /* 47K min required */ + + /* VK_KHR_cooperative_matrix */ + p->cooperativeMatrixSupportedStages = VK_SHADER_STAGE_COMPUTE_BIT; + } static VkResult VKAPI_CALL @@ -1495,8 +1518,6 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_CreateInstance( return vk_error(NULL, result); } - instance->apiVersion = LVP_API_VERSION; - instance->vk.physical_devices.enumerate = lvp_enumerate_physical_devices; instance->vk.physical_devices.destroy = lvp_destroy_physical_device; @@ -2394,9 +2415,7 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_BindBufferMemory2(VkDevice _device, VK_FROM_HANDLE(lvp_buffer, buffer, pBindInfos[i].buffer); VkBindMemoryStatusKHR *status = (void*)vk_find_struct_const(&pBindInfos[i], BIND_MEMORY_STATUS_KHR); - buffer->mem = mem; buffer->map = (char*)mem->map + pBindInfos[i].memoryOffset; - buffer->offset = pBindInfos[i].memoryOffset; device->pscreen->resource_bind_backing(device->pscreen, buffer->bo, mem->pmem, @@ -2867,3 +2886,85 @@ VKAPI_ATTR void VKAPI_CALL lvp_GetRenderingAreaGranularityKHR( VkExtent2D tile_size = {64, 64}; *pGranularity = tile_size; } + +VKAPI_ATTR VkResult VKAPI_CALL lvp_GetPhysicalDeviceCooperativeMatrixPropertiesKHR( + VkPhysicalDevice physicalDevice, + uint32_t *pPropertyCount, + VkCooperativeMatrixPropertiesKHR *pProperties) +{ + VK_OUTARRAY_MAKE_TYPED(VkCooperativeMatrixPropertiesKHR, out, pProperties, pPropertyCount); + + vk_outarray_append_typed(VkCooperativeMatrixPropertiesKHR, &out, p) + { + *p = (struct VkCooperativeMatrixPropertiesKHR){ + .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR, + .MSize = 8, + .NSize = 8, + .KSize = 8, + .AType = VK_COMPONENT_TYPE_FLOAT16_KHR, + .BType = VK_COMPONENT_TYPE_FLOAT16_KHR, + .CType = VK_COMPONENT_TYPE_FLOAT16_KHR, + .ResultType = VK_COMPONENT_TYPE_FLOAT16_KHR, + .saturatingAccumulation = false, + .scope = VK_SCOPE_SUBGROUP_KHR + }; + } + + vk_outarray_append_typed(VkCooperativeMatrixPropertiesKHR, &out, p) + { + *p = (struct VkCooperativeMatrixPropertiesKHR){ + .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR, + .MSize = 8, + .NSize = 8, + .KSize = 8, + .AType = VK_COMPONENT_TYPE_UINT8_KHR, + .BType = VK_COMPONENT_TYPE_UINT8_KHR, + .CType = VK_COMPONENT_TYPE_UINT32_KHR, + .ResultType = VK_COMPONENT_TYPE_UINT32_KHR, + .saturatingAccumulation = false, + .scope = VK_SCOPE_SUBGROUP_KHR + }; + } + return vk_outarray_status(&out); +} + +VKAPI_ATTR VkResult VKAPI_CALL +lvp_GetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV( + VkPhysicalDevice physicalDevice, uint32_t *pPropertyCount, + VkCooperativeMatrixFlexibleDimensionsPropertiesNV *pProperties) +{ + VK_OUTARRAY_MAKE_TYPED(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, out, pProperties, pPropertyCount); + + vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p) + { + *p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){ + .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV, + .MGranularity = 8, + .NGranularity = 8, + .KGranularity = 8, + .AType = VK_COMPONENT_TYPE_FLOAT16_KHR, + .BType = VK_COMPONENT_TYPE_FLOAT16_KHR, + .CType = VK_COMPONENT_TYPE_FLOAT16_KHR, + .ResultType = VK_COMPONENT_TYPE_FLOAT16_KHR, + .saturatingAccumulation = false, + .scope = VK_SCOPE_SUBGROUP_KHR 
+ }; + } + + vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p) + { + *p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){ + .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV, + .MGranularity = 8, + .NGranularity = 8, + .KGranularity = 8, + .AType = VK_COMPONENT_TYPE_UINT8_KHR, + .BType = VK_COMPONENT_TYPE_UINT8_KHR, + .CType = VK_COMPONENT_TYPE_UINT32_KHR, + .ResultType = VK_COMPONENT_TYPE_UINT32_KHR, + .saturatingAccumulation = false, + .scope = VK_SCOPE_SUBGROUP_KHR + }; + } + return vk_outarray_status(&out); +} diff --git a/src/gallium/frontends/lavapipe/lvp_execute.c b/src/gallium/frontends/lavapipe/lvp_execute.c index fd9ceca9f55..711a9e7b1f0 100644 --- a/src/gallium/frontends/lavapipe/lvp_execute.c +++ b/src/gallium/frontends/lavapipe/lvp_execute.c @@ -73,11 +73,11 @@ struct descriptor_buffer_offset { struct lvp_render_attachment { struct lvp_image_view *imgv; VkResolveModeFlags resolve_mode; + bool read_only; struct lvp_image_view *resolve_imgv; VkAttachmentLoadOp load_op; VkAttachmentStoreOp store_op; VkClearValue clear_value; - bool read_only; }; struct lvp_conditional_rendering_state { @@ -2961,21 +2961,22 @@ static void handle_copy_query_pool_results(struct vk_cmd_queue_entry *cmd, unsigned result_size = copycmd->flags & VK_QUERY_RESULT_64_BIT ? 8 : 4; for (unsigned i = copycmd->first_query; i < copycmd->first_query + copycmd->query_count; i++) { unsigned offset = copycmd->dst_offset + (copycmd->stride * (i - copycmd->first_query)); - if (pool->base_type >= PIPE_QUERY_TYPES) { struct pipe_transfer *transfer; uint8_t *map = pipe_buffer_map(state->pctx, lvp_buffer_from_handle(copycmd->dst_buffer)->bo, PIPE_MAP_WRITE, &transfer); map += offset; + void *data = &pool->queries; + if (copycmd->flags & VK_QUERY_RESULT_64_BIT) { uint64_t *dst = (uint64_t *)map; - uint64_t *src = (uint64_t *)pool->data; + uint64_t *src = (uint64_t *)data; *dst = src[i]; if (copycmd->flags & VK_QUERY_RESULT_WITH_AVAILABILITY_BIT) *(dst + 1) = 1; } else { uint32_t *dst = (uint32_t *)map; - uint64_t *src = (uint64_t *)pool->data; + uint64_t *src = (uint64_t *)data; *dst = (uint32_t) (src[i] & UINT32_MAX); if (copycmd->flags & VK_QUERY_RESULT_WITH_AVAILABILITY_BIT) *(dst + 1) = 1; @@ -4542,8 +4543,8 @@ handle_write_acceleration_structures_properties(struct vk_cmd_queue_entry *cmd, struct vk_cmd_write_acceleration_structures_properties_khr *write = &cmd->u.write_acceleration_structures_properties_khr; VK_FROM_HANDLE(lvp_query_pool, pool, write->query_pool); - - uint64_t *dst = pool->data; + void *data = &pool->queries; + uint64_t *dst = data; dst += write->first_query; for (uint32_t i = 0; i < write->acceleration_structure_count; i++) { diff --git a/src/gallium/frontends/lavapipe/lvp_pipeline.c b/src/gallium/frontends/lavapipe/lvp_pipeline.c index 7aea2456fac..0da2c8cbc07 100644 --- a/src/gallium/frontends/lavapipe/lvp_pipeline.c +++ b/src/gallium/frontends/lavapipe/lvp_pipeline.c @@ -38,10 +38,6 @@ #include "gallivm/lp_bld_debug.h" -#define SPIR_V_MAGIC_NUMBER 0x07230203 - -#define MAX_DYNAMIC_STATES 72 - typedef void (*cso_destroy_func)(struct pipe_context*, void*); static void @@ -351,6 +347,26 @@ lvp_shader_lower(struct lvp_device *pdevice, nir_shader *nir, struct lvp_pipelin NIR_PASS(_, nir, nir_lower_system_values); NIR_PASS(_, nir, nir_lower_is_helper_invocation); + bool progress; + NIR_PASS(progress, nir, nir_lower_cooperative_matrix_flexible_dimensions, 8, 8, 8); + if (progress) { + NIR_PASS(_, nir, nir_opt_deref); + 
NIR_PASS(_, nir, nir_opt_dce); + NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL); + } + NIR_PASS(progress, nir, lvp_nir_lower_cooperative_matrix); + if (progress) { + NIR_PASS(_, nir, nir_opt_dce); + NIR_PASS(progress, nir, nir_inline_functions); + nir_remove_non_entrypoints(nir); /* remove the late inlined functions */ + if (progress) { + NIR_PASS(_, nir, nir_opt_copy_prop_vars); + NIR_PASS(_, nir, nir_opt_copy_prop); + } + NIR_PASS(_, nir, nir_opt_deref); + NIR_PASS(_, nir, nir_opt_dce); + } + const struct nir_lower_compute_system_values_options compute_system_values = {0}; NIR_PASS(_, nir, nir_lower_compute_system_values, &compute_system_values); diff --git a/src/gallium/frontends/lavapipe/lvp_private.h b/src/gallium/frontends/lavapipe/lvp_private.h index c31a03b7021..6dc5193d7d9 100644 --- a/src/gallium/frontends/lavapipe/lvp_private.h +++ b/src/gallium/frontends/lavapipe/lvp_private.h @@ -168,8 +168,6 @@ struct lvp_physical_device { struct lvp_instance { struct vk_instance vk; - uint32_t apiVersion; - uint64_t debug_flags; struct pipe_loader_device *devs; @@ -579,10 +577,8 @@ struct lvp_event { struct lvp_buffer { struct vk_buffer vk; - struct lvp_device_memory *mem; struct pipe_resource *bo; uint64_t total_size; - uint64_t offset; void *map; struct pipe_transfer *transfer; }; @@ -605,7 +601,6 @@ struct lvp_buffer_view { struct lvp_query_pool { struct vk_query_pool vk; enum pipe_query_type base_type; - void *data; /* Used by queries that are not implemented by pipe_query */ struct pipe_query *queries[0]; }; diff --git a/src/gallium/frontends/lavapipe/lvp_query.c b/src/gallium/frontends/lavapipe/lvp_query.c index 267f0c965d8..b791432aad2 100644 --- a/src/gallium/frontends/lavapipe/lvp_query.c +++ b/src/gallium/frontends/lavapipe/lvp_query.c @@ -81,7 +81,6 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_CreateQueryPool( return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY); pool->base_type = pipeq; - pool->data = &pool->queries; *pQueryPool = lvp_query_pool_to_handle(pool); return VK_SUCCESS; @@ -122,6 +121,7 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_GetQueryPoolResults( device->vk.dispatch_table.DeviceWaitIdle(_device); + void *data = &pool->queries; for (unsigned i = firstQuery; i < firstQuery + queryCount; i++) { uint8_t *dest = (uint8_t *)((char *)pData + (stride * (i - firstQuery))); union pipe_query_result result; @@ -130,13 +130,13 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_GetQueryPoolResults( if (pool->base_type >= PIPE_QUERY_TYPES) { if (flags & VK_QUERY_RESULT_64_BIT) { uint64_t *dst = (uint64_t *)dest; - uint64_t *src = (uint64_t *)pool->data; + uint64_t *src = (uint64_t *)data; *dst = src[i]; if (flags & VK_QUERY_RESULT_WITH_AVAILABILITY_BIT) *(dst + 1) = 1; } else { uint32_t *dst = (uint32_t *)dest; - uint64_t *src = (uint64_t *)pool->data; + uint64_t *src = (uint64_t *)data; *dst = src[i]; if (flags & VK_QUERY_RESULT_WITH_AVAILABILITY_BIT) *(dst + 1) = 1; diff --git a/src/gallium/frontends/lavapipe/meson.build b/src/gallium/frontends/lavapipe/meson.build index 2a0c1bcc563..b87926e8022 100644 --- a/src/gallium/frontends/lavapipe/meson.build +++ b/src/gallium/frontends/lavapipe/meson.build @@ -12,6 +12,7 @@ lvp_entrypoints = custom_target( ) liblvp_files = files( + 'nir/lvp_nir_lower_cooperative_matrix.c', 'nir/lvp_nir_lower_exec_graph.c', 'nir/lvp_nir_lower_input_attachments.c', 'nir/lvp_nir_lower_pipeline_layout.c', diff --git a/src/gallium/frontends/lavapipe/nir/lvp_nir.h b/src/gallium/frontends/lavapipe/nir/lvp_nir.h index 
a4b6535b696..fbb76d9b9a1 100644 --- a/src/gallium/frontends/lavapipe/nir/lvp_nir.h +++ b/src/gallium/frontends/lavapipe/nir/lvp_nir.h @@ -121,4 +121,6 @@ bool lvp_nir_lower_sparse_residency(struct nir_shader *shader); bool lvp_nir_opt_robustness(struct nir_shader *shader, struct lvp_device *device, struct vk_pipeline_robustness_state *robustness); +bool lvp_nir_lower_cooperative_matrix(nir_shader *shader); + #endif diff --git a/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c b/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c new file mode 100644 index 00000000000..bafc58c0041 --- /dev/null +++ b/src/gallium/frontends/lavapipe/nir/lvp_nir_lower_cooperative_matrix.c @@ -0,0 +1,778 @@ +/* + * Copyright © 2025 Red Hat + * + * SPDX-License-Identifier: MIT + */ +#include "lvp_nir.h" + +extern unsigned lp_native_vector_width; + +#define MAX_CMAT_LEN 16 +#define CMAT_LEN (lp_native_vector_width / 32) + +/* This pass lowers cooperative matrix operations. + * + * For lavapipe we advertise 8x8 matrices, so each matrix can be stored as a + * vec8 per subgroup invocation (vec8[8] across the subgroup) and the backend + * will do the right thing with those vectors. + */ +static unsigned +get_cmat_size(struct glsl_cmat_description matrix_desc) +{ + return matrix_desc.cols * matrix_desc.rows; +} + +static unsigned +get_cmat_length(struct glsl_cmat_description matrix_desc) +{ + return get_cmat_size(matrix_desc) / CMAT_LEN; +} + +static const struct glsl_type * +remap_matrix_type(struct hash_table *mapping, const struct glsl_type *orig) +{ + struct hash_entry *entry = _mesa_hash_table_search(mapping, orig); + + if (entry) + return entry->data; + + const struct glsl_type *new_type = orig; + + if (glsl_type_is_cmat(orig)) { + struct glsl_cmat_description matrix_desc = + *glsl_get_cmat_description(orig); + + new_type = glsl_vector_type(matrix_desc.element_type, get_cmat_length(matrix_desc)); + } else if (glsl_type_is_array(orig)) { + const struct glsl_type *elem_type = glsl_get_array_element(orig); + const struct glsl_type *new_elem_type = + remap_matrix_type(mapping, elem_type); + + if (elem_type != new_elem_type) { + new_type = glsl_array_type(new_elem_type, glsl_get_length(orig), + glsl_get_explicit_stride(orig)); + } + } + _mesa_hash_table_insert(mapping, orig, (void *)new_type); + return new_type; +} + +static nir_def * +load_cmat_deref(nir_builder *b, nir_deref_instr *src) +{ + struct glsl_cmat_description matrix_desc = + *glsl_get_cmat_description(src->type); + + return nir_build_load_deref( + b, get_cmat_length(matrix_desc), + glsl_base_type_bit_size(matrix_desc.element_type), &src->def, 0); +} + +static ALWAYS_INLINE nir_def * +load_cmat_src(nir_builder *b, nir_src src) +{ + return load_cmat_deref(b, nir_src_as_deref(src)); +} + +static ALWAYS_INLINE struct glsl_cmat_description +cmat_src_desc(nir_src src) +{ + nir_deref_instr *deref = nir_src_as_deref(src); + return *glsl_get_cmat_description(deref->type); +} + +static void +store_cmat_deref(nir_builder *b, nir_deref_instr *dst, nir_def *val) +{ + ASSERTED struct glsl_cmat_description matrix_desc = + *glsl_get_cmat_description(dst->type); + + assert(val->bit_size == glsl_base_type_bit_size(matrix_desc.element_type)); + assert(val->num_components == get_cmat_length(matrix_desc)); + + nir_store_deref(b, dst, val, ~0); +} + +static ALWAYS_INLINE void +store_cmat_src(nir_builder *b, nir_src dst_src, nir_def *val) +{ + store_cmat_deref(b, nir_src_as_deref(dst_src), val); +} + +static bool +lower_cmat_copy(nir_builder *b, nir_intrinsic_instr *intr) +{ + nir_build_copy_deref(b, intr->src[0].ssa, 
intr->src[1].ssa); + nir_instr_remove(&intr->instr); + return true; +} + +static nir_def * +convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, + enum glsl_cmat_use dst_use) +{ + nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {}; + nir_def *out_comps[NIR_MAX_VEC_COMPONENTS] = {}; + unsigned num_comps = src->num_components; + for (unsigned i = 0; i < num_comps; i++) { + comps[i] = nir_channel(b, src, i); + out_comps[i] = nir_imm_zero(b, 1, comps[i]->bit_size); + } + + nir_def *lane_id = nir_load_subgroup_invocation(b); + + /* construct the outer row */ + for (unsigned i = 0; i < num_comps; i++) { + + for (unsigned j = 0; j < CMAT_LEN; j++) { + nir_def *else_val = out_comps[i]; + nir_def *active_lane = nir_ieq(b, lane_id, nir_imm_int(b, j)); + + out_comps[i] = nir_read_invocation(b, comps[j], nir_imm_int(b, i)); + + out_comps[i] = nir_bcsel(b, active_lane, out_comps[i], else_val); + } + } + return nir_vec(b, out_comps, num_comps); +} + +static nir_def * +convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, enum glsl_base_type dst_type) +{ + if (dst_type == src_type) + return src; + + nir_op op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_type), + nir_get_nir_type_for_glsl_base_type(dst_type), nir_rounding_mode_undef); + + return nir_build_alu1(b, op, src); +} + +static bool +lower_cmat_convert(nir_builder *b, + nir_intrinsic_instr *intr) +{ + const bool transpose = intr->intrinsic == nir_intrinsic_cmat_transpose; + struct glsl_cmat_description dst_desc = cmat_src_desc(intr->src[0]); + struct glsl_cmat_description src_desc = cmat_src_desc(intr->src[1]); + + enum glsl_base_type dst_element_type = dst_desc.element_type; + enum glsl_base_type src_element_type = src_desc.element_type; + + enum glsl_cmat_use dst_use = dst_desc.use; + enum glsl_cmat_use src_use = src_desc.use; + + nir_def *cmat = load_cmat_src(b, intr->src[1]); + + if (dst_use == GLSL_CMAT_USE_ACCUMULATOR) + dst_use = GLSL_CMAT_USE_A; + if (src_use == GLSL_CMAT_USE_ACCUMULATOR) + src_use = GLSL_CMAT_USE_A; + + if (transpose) { + if (src_use == GLSL_CMAT_USE_A && dst_use == GLSL_CMAT_USE_B) + src_use = dst_use; + if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_A) + src_use = dst_use; + } + + nir_def *ret = cmat; + if (dst_use != src_use) { + ret = convert_use(b, cmat, src_use, dst_use); + } + ret = convert_base_type(b, ret, src_element_type, dst_element_type); + store_cmat_src(b, intr->src[0], ret); + nir_instr_remove(&intr->instr); + return true; +} + +static bool +lower_cmat_load_store(nir_builder *b, + struct hash_table *type_mapping, + nir_intrinsic_instr *intr) +{ + const bool is_load = intr->intrinsic == nir_intrinsic_cmat_load; + const struct glsl_cmat_description desc = cmat_src_desc(intr->src[!is_load]); + enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr); + nir_deref_instr *cmat_deref = nir_src_as_deref(intr->src[!is_load]); + nir_deref_instr *deref = nir_src_as_deref(intr->src[is_load]); + nir_def *stride = intr->src[2].ssa; + + nir_def *lane_id = nir_load_subgroup_invocation(b); + unsigned type_size_B = glsl_base_type_bit_size(desc.element_type) / 8; + const uint32_t ptr_stride = glsl_get_bit_size(deref->type) / 8 * glsl_get_vector_elements(deref->type); + deref = nir_build_deref_cast(b, &deref->def, deref->modes, deref->type, ptr_stride); + const struct glsl_type *cmat_type = remap_matrix_type(type_mapping, cmat_deref->type); + cmat_deref = nir_build_deref_cast(b, &cmat_deref->def, cmat_deref->modes, + cmat_type, 0); + + /* store B 
matrix transposed */ + if (desc.use == GLSL_CMAT_USE_B) + layout = + layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR; + + unsigned idx_bits = deref->def.bit_size; + nir_def *vars[MAX_CMAT_LEN]; + + if (is_load) { + for (unsigned i = 0; i < CMAT_LEN; i++) { + vars[i] = nir_undef(b, 1, 16); + } + } else { + nir_def *src = load_cmat_src(b, intr->src[!is_load]); + for (unsigned i = 0; i < CMAT_LEN; i++) { + vars[i] = nir_channel(b, src, i); + } + } + for (unsigned i = 0; i < CMAT_LEN; i++) { + nir_def *col_offset = lane_id; + nir_def *row_offset = nir_imm_int(b, i); + + if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) { + SWAP(col_offset, row_offset); + } + + col_offset = nir_imul(b, col_offset, stride); + col_offset = nir_u2uN(b, col_offset, idx_bits); + row_offset = nir_u2uN(b, row_offset, idx_bits); + + nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(b, deref, col_offset); + + iter_deref = nir_build_deref_cast(b, &iter_deref->def, + deref->modes, + glsl_scalar_type(desc.element_type), + type_size_B); + iter_deref = nir_build_deref_ptr_as_array(b, iter_deref, row_offset); + + if (is_load) { + vars[i] = nir_load_deref(b, iter_deref); + } else { + nir_store_deref(b, iter_deref, vars[i], ~0); + } + } + + if (is_load) { + nir_def *mat = nir_vec(b, vars, CMAT_LEN); + nir_store_deref(b, cmat_deref, mat, nir_component_mask(mat->num_components)); + } + nir_instr_remove(&intr->instr); + return true; +} + +static bool +lower_cmat_construct(nir_builder *b, + nir_intrinsic_instr *intr) +{ + nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]); + struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type); + nir_def *elem = intr->src[1].ssa; + + nir_def *r = nir_replicate(b, elem, get_cmat_length(desc)); + + nir_store_deref(b, dst_deref, r, nir_component_mask(r->num_components)); + nir_instr_remove(&intr->instr); + return true; +} + +static bool +lower_cmat_extract(nir_builder *b, + nir_intrinsic_instr *intr) +{ + nir_def *mat = load_cmat_src(b, intr->src[0]); + nir_def *index = intr->src[1].ssa; + nir_def *elem = nir_vector_extract(b, mat, index); + nir_def_replace(&intr->def, elem); + return true; +} + +static bool +lower_cmat_insert(nir_builder *b, + nir_intrinsic_instr *intr) +{ + nir_def *elem = intr->src[1].ssa; + nir_def *mat = load_cmat_src(b, intr->src[2]); + nir_def *index = intr->src[3].ssa; + + nir_def *r = nir_vector_insert(b, mat, elem, index); + store_cmat_src(b, intr->src[0], r); + + nir_instr_remove(&intr->instr); + return true; +} + +static bool +lower_cmat_binary_op(nir_builder *b, + nir_intrinsic_instr *intr) +{ + nir_def *src_a = load_cmat_src(b, intr->src[1]); + nir_def *src_b = load_cmat_src(b, intr->src[2]); + nir_op op = nir_intrinsic_alu_op(intr); + + nir_def *ret = nir_build_alu2(b, op, src_a, src_b); + store_cmat_src(b, intr->src[0], ret); + + nir_instr_remove(&intr->instr); + return true; +} + +static bool +lower_cmat_unary_op(nir_builder *b, + nir_intrinsic_instr *intr) +{ + nir_def *src = load_cmat_src(b, intr->src[1]); + nir_op op = nir_intrinsic_alu_op(intr); + + nir_def *ret = nir_build_alu1(b, op, src); + store_cmat_src(b, intr->src[0], ret); + + nir_instr_remove(&intr->instr); + return true; +} + +static bool +lower_cmat_scalar_op(nir_builder *b, + nir_intrinsic_instr *intr) +{ + nir_def *src_a = load_cmat_src(b, intr->src[1]); + nir_op op = nir_intrinsic_alu_op(intr); + + nir_def *ret = nir_build_alu2(b, op, src_a, intr->src[2].ssa); + store_cmat_src(b, intr->src[0], ret); + + 
nir_instr_remove(&intr->instr); + return true; +} + +static bool +lower_cmat_length(nir_builder *b, + nir_intrinsic_instr *intr) +{ + nir_def_replace(&intr->def, nir_imm_int(b, CMAT_LEN)); + return true; +} + +static bool +lower_cmat_muladd(nir_builder *b, + nir_intrinsic_instr *intr) +{ + const struct glsl_cmat_description a_desc = cmat_src_desc(intr->src[1]); + const struct glsl_cmat_description b_desc = cmat_src_desc(intr->src[2]); + const struct glsl_cmat_description c_desc = cmat_src_desc(intr->src[3]); + nir_def *cmat_a = load_cmat_src(b, intr->src[1]); + nir_def *cmat_b = load_cmat_src(b, intr->src[2]); + nir_def *cmat_c = load_cmat_src(b, intr->src[3]); + + unsigned a_length = get_cmat_length(a_desc); + unsigned b_length = get_cmat_length(b_desc); + unsigned c_length = get_cmat_length(c_desc); + nir_def *a_comps[NIR_MAX_VEC_COMPONENTS]; + nir_def *b_comps[NIR_MAX_VEC_COMPONENTS]; + nir_def *c_comps[NIR_MAX_VEC_COMPONENTS]; + nir_def *d_comps[NIR_MAX_VEC_COMPONENTS]; + + for (unsigned i = 0; i < a_length; i++) + a_comps[i] = nir_channel(b, cmat_a, i); + + for (unsigned i = 0; i < b_length; i++) + b_comps[i] = nir_channel(b, cmat_b, i); + + for (unsigned i = 0; i < c_length; i++) + c_comps[i] = nir_channel(b, cmat_c, i); + + nir_def *lane_id = nir_load_subgroup_invocation(b); + int accum_bit_size = glsl_base_type_bit_size(c_desc.element_type); + for (unsigned i = 0; i < CMAT_LEN; i++) { + nir_def *ref = nir_imm_zero(b, 1, glsl_base_type_bit_size(c_desc.element_type)); + for (unsigned j = 0; j < CMAT_LEN; j++) { + nir_def *outer_else_val = ref; + ref = nir_imm_zero(b, 1, glsl_base_type_bit_size(c_desc.element_type)); + + nir_def *a_i = a_comps[i]; + nir_def *b_j = b_comps[j]; /* B is stored transposed */ + nir_def *val; + if (glsl_base_type_is_integer(c_desc.element_type)) { + a_i = nir_u2uN(b, a_i, accum_bit_size); + b_j = nir_u2uN(b, b_j, accum_bit_size); + val = nir_imul(b, a_i, b_j); + ref = nir_iadd(b, ref, val); + } else { + val = nir_fmul(b, a_i, b_j); + ref = nir_fadd(b, ref, val); + } + + if (glsl_base_type_is_integer(c_desc.element_type)) { + ref = nir_reduce(b, ref, .reduction_op = nir_op_iadd); + } else { + ref = nir_reduce(b, ref, .reduction_op = nir_op_fadd); + } + + nir_def *lane = nir_ieq_imm(b, lane_id, j); + ref = nir_bcsel(b, lane, ref, outer_else_val); + } + + if (glsl_base_type_is_integer(c_desc.element_type)) { + ref = nir_iadd(b, ref, c_comps[i]); + } else { + ref = nir_fadd(b, ref, c_comps[i]); + } + d_comps[i] = ref; + } + nir_def *ret = nir_vec(b, d_comps, CMAT_LEN); + store_cmat_src(b, intr->src[0], ret); + nir_instr_remove(&intr->instr); + return true; +} + +static bool +lower_cmat_reduce_finish_call(nir_builder *b, nir_cmat_call_instr *call) +{ + nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]); + nir_deref_instr *src0_deref = nir_src_as_deref(call->params[1]); + struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src0_deref->type); + nir_function *fnptr = call->callee; + nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call); + nir_def *src0 = load_cmat_src(b, call->params[1]); + nir_def *src1 = load_cmat_src(b, call->params[2]); + + assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR); + + nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {}; + + if (reduce & NIR_CMAT_REDUCE_COLUMN) { + nir_variable *col_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "col_tmp"); + /* All of the rows contain the same data, so just reduce both first rows. 
*/ + nir_def *row_accum0 = nir_channel(b, src0, 0); + nir_def *row_accum1 = nir_channel(b, src1, 0); + + nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp); + + nir_call(b, fnptr, &col_tmp_deref->def, row_accum0, row_accum1); + + nir_def *first_col = nir_load_deref(b, col_tmp_deref); + + for (unsigned i = 0; i < CMAT_LEN; i++) + comps[i] = first_col; + } else if (reduce & NIR_CMAT_REDUCE_ROW) { + for (unsigned i = 0; i < CMAT_LEN; ++i) { + nir_def *row0_accum = nir_channel(b, src0, i); + nir_def *row1_accum = nir_channel(b, src1, i); + + nir_variable *row_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "row_tmp"); + nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp); + + nir_call(b, fnptr, &row_tmp_deref->def, row0_accum, row1_accum); + + nir_def *row = nir_load_deref(b, row_tmp_deref); + comps[i] = row; + } + } + nir_def *mat = nir_vec(b, comps, CMAT_LEN); + nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components)); + nir_instr_remove(&call->instr); + return true; +} + +static bool +lower_cmat_reduce_call(nir_builder *b, nir_cmat_call_instr *call) +{ + nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]); + nir_deref_instr *src_deref = nir_src_as_deref(call->params[1]); + struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type); + nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call); + nir_def *src = load_cmat_src(b, call->params[1]); + nir_function *fnptr = call->callee; + nir_def *lane_id = nir_load_subgroup_invocation(b); + + assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR); + + nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {}; + for (unsigned i = 0; i < CMAT_LEN; ++i) { + comps[i] = nir_channel(b, src, i); + } + + if (reduce & NIR_CMAT_REDUCE_COLUMN) { + nir_variable *col_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "col_tmp"); + + nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp); + nir_store_deref(b, col_tmp_deref, comps[0], 1); + + for (unsigned i = 1; i < CMAT_LEN; i++) { + nir_def *col_accum_val = nir_load_deref(b, col_tmp_deref); + nir_call(b, fnptr, &col_tmp_deref->def, col_accum_val, comps[i]); + } + + for (unsigned i = 0; i < CMAT_LEN; i++) + comps[i] = nir_load_deref(b, col_tmp_deref); + } + + if (reduce & NIR_CMAT_REDUCE_ROW) { + for (unsigned i = 0; i < CMAT_LEN; ++i) { + nir_def *row_accum = comps[i]; + nir_variable *row_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "row_tmp"); + nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp); + nir_store_deref(b, row_tmp_deref, row_accum, 1); + + for (unsigned j = 1; j < CMAT_LEN; j *= 2) { + nir_def *prev_row_accum_val = nir_load_deref(b, row_tmp_deref); + nir_def *this_row = nir_shuffle(b, prev_row_accum_val, nir_iadd(b, lane_id, nir_imm_int(b, j))); + + nir_call(b, fnptr, &row_tmp_deref->def, prev_row_accum_val, this_row); + } + row_tmp_deref = nir_build_deref_var(b, row_tmp); + comps[i] = nir_load_deref(b, row_tmp_deref); + } + } + + /* this should be lowered earlier */ + assert(!(reduce & NIR_CMAT_REDUCE_2X2)); + nir_def *mat = nir_vec(b, comps, CMAT_LEN); + nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components)); + nir_instr_remove(&call->instr); + return true; +} + +static bool +lower_cmat_reduce_2x2_call(nir_builder *b, nir_cmat_call_instr *call) +{ + nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]); + nir_deref_instr *src_deref = nir_src_as_deref(call->params[1]); + 
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type); + nir_function *fnptr = call->callee; + nir_def *lane_id = nir_load_subgroup_invocation(b); + assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR); + + nir_def *comps[NIR_MAX_VEC_COMPONENTS]; + + nir_def *src_components[4][NIR_MAX_VEC_COMPONENTS]; + for (unsigned m = 0; m < 4; m++) { + nir_def *src = load_cmat_src(b, call->params[m + 1]); + for (unsigned i = 0; i < CMAT_LEN; i++) { + src_components[m][i] = nir_channel(b, src, i); + } + } + nir_variable *qd_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "qd_tmp"); + nir_deref_instr *qd_tmp_deref = nir_build_deref_var(b, qd_tmp); + + for (unsigned m = 0; m < 4; m++) { + for (unsigned i = 0; i < CMAT_LEN / 2; i++) { + nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i * 2], src_components[m][i * 2 + 1]); + src_components[m][i] = nir_load_deref(b, qd_tmp_deref); + + nir_def *other_col = nir_shuffle_down(b, src_components[m][i], nir_imm_int(b, 1)); + nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i], other_col); + src_components[m][i] = nir_load_deref(b, qd_tmp_deref); + } + } + + nir_def *even = nir_inverse_ballot_imm(b, 0x5555555555555555, 32); + for (unsigned m = 0; m < 2; m++) { + for (unsigned i = 0; i < CMAT_LEN / 2; i++) { + nir_def *m0_comp = src_components[m * 2][i]; + nir_def *m1_comp = nir_shuffle_up(b, src_components[m * 2 + 1][i], nir_imm_int(b, 1)); + + nir_def *combined = nir_bcsel(b, even, m0_comp, m1_comp); + comps[m * (CMAT_LEN / 2) + i] = combined; + } + } + + nir_def *low_lane_id = nir_ilt_imm(b, lane_id, 4); + nir_def *new_lane_id_lo = nir_imul_imm(b, lane_id, 2); + nir_def *new_lane_id_hi = nir_iadd_imm(b, nir_imul_imm(b, nir_iadd_imm(b, lane_id, -4), 2), 1); + nir_def *new_lane_id = nir_bcsel(b, low_lane_id, new_lane_id_lo, new_lane_id_hi); + + for (unsigned m = 0; m < CMAT_LEN; m++) { + comps[m] = nir_shuffle(b, comps[m], new_lane_id); + } + nir_def *mat = nir_vec(b, comps, CMAT_LEN); + nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components)); + nir_instr_remove(&call->instr); + return true; +} + +static bool +lower_cmat_per_element_op_call(nir_builder *b, nir_cmat_call_instr *call) +{ + nir_def *src = load_cmat_src(b, call->params[3]); + nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]); + nir_function *fnptr = call->callee; + nir_def *lane_id = nir_load_subgroup_invocation(b); + + struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type); + + nir_variable *elem_tmp = nir_local_variable_create(b->impl, glsl_get_cmat_element(dst_deref->type), "elemtmp"); + nir_deref_instr *elem_deref = nir_build_deref_var(b, elem_tmp); + + nir_def *comps[NIR_MAX_VEC_COMPONENTS]; + + for (unsigned i = 0; i < CMAT_LEN; i++) { + nir_def *src_elem = nir_channel(b, src, i); + nir_call_instr *new_call = nir_call_instr_create(b->shader, fnptr); + + nir_def *row_val = nir_imm_int(b, i); + nir_def *col_val = lane_id; + + if (desc.use == GLSL_CMAT_USE_B) + SWAP(col_val, row_val); + + row_val = nir_iadd(b, call->params[1].ssa, row_val); + col_val = nir_iadd(b, call->params[2].ssa, col_val); + + new_call->params[0] = nir_src_for_ssa(&elem_deref->def); + new_call->params[1] = nir_src_for_ssa(row_val); + new_call->params[2] = nir_src_for_ssa(col_val); + new_call->params[3] = nir_src_for_ssa(src_elem); + + for (unsigned p = 4; p < call->num_params; p++) { + nir_deref_instr *deref = nir_src_as_deref(call->params[p]); + nir_def *def = call->params[p].ssa; + if 
(deref) { + if (glsl_type_is_cmat(deref->type)) { + def = nir_build_load_deref(b, get_cmat_length(desc), + glsl_base_type_bit_size(desc.element_type), def); + def = nir_channel(b, def, i); + } + } + new_call->params[p] = nir_src_for_ssa(def); + } + nir_builder_instr_insert(b, &new_call->instr); + comps[i] = nir_build_load_deref(b, 1, glsl_base_type_bit_size(desc.element_type), &elem_deref->def, 0); + } + + nir_def *mat = nir_vec(b, comps, CMAT_LEN); + nir_store_deref(b, dst_deref, mat, nir_component_mask(src->num_components)); + + nir_instr_remove(&call->instr); + return true; +} + +static bool +lower_impl(nir_function_impl *impl, + struct hash_table *type_mapping) +{ + bool progress = false; + /* Remap all cmat temp var to array of scalars */ + nir_foreach_function_temp_variable(var, impl) { + const struct glsl_type *new_type = + remap_matrix_type(type_mapping, var->type); + if (new_type != var->type) { + var->type = new_type; + progress = true; + } + } + + /* Iterate in reverse order so that lowering can still use the matrix types from the derefs before we change it. */ + nir_builder b = nir_builder_create(impl); + nir_foreach_block_reverse_safe (block, impl) { + nir_foreach_instr_reverse_safe (instr, block) { + b.cursor = nir_before_instr(instr); + + switch (instr->type) { + case nir_instr_type_intrinsic: { + nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); + switch (intr->intrinsic) { + case nir_intrinsic_cmat_length: + progress |= lower_cmat_length(&b, intr); + break; + case nir_intrinsic_cmat_construct: + progress |= lower_cmat_construct(&b, intr); + break; + case nir_intrinsic_cmat_extract: + progress |= lower_cmat_extract(&b, intr); + break; + case nir_intrinsic_cmat_insert: + progress |= lower_cmat_insert(&b, intr); + break; + case nir_intrinsic_cmat_load: + case nir_intrinsic_cmat_store: + progress |= lower_cmat_load_store(&b, type_mapping, intr); + break; + case nir_intrinsic_cmat_binary_op: + progress |= lower_cmat_binary_op(&b, intr); + break; + case nir_intrinsic_cmat_unary_op: + progress |= lower_cmat_unary_op(&b, intr); + break; + case nir_intrinsic_cmat_scalar_op: + progress |= lower_cmat_scalar_op(&b, intr); + break; + case nir_intrinsic_cmat_muladd: + progress |= lower_cmat_muladd(&b, intr); + break; + case nir_intrinsic_cmat_copy: + progress |= lower_cmat_copy(&b, intr); + break; + case nir_intrinsic_cmat_convert: + case nir_intrinsic_cmat_transpose: + progress |= lower_cmat_convert(&b, intr); + break; + default: + break; + } + break; + } + case nir_instr_type_deref: { + nir_deref_instr *deref = nir_instr_as_deref(instr); + const struct glsl_type *new_type = + remap_matrix_type(type_mapping, deref->type); + + if (new_type != deref->type) { + deref->type = new_type; + progress = true; + } + break; + } + case nir_instr_type_cmat_call: { + nir_cmat_call_instr *call = nir_instr_as_cmat_call(instr); + switch (call->op) { + case nir_cmat_call_op_reduce: + progress |= lower_cmat_reduce_call(&b, call); + break; + case nir_cmat_call_op_reduce_finish: + progress |= lower_cmat_reduce_finish_call(&b, call); + break; + case nir_cmat_call_op_reduce_2x2: + progress |= lower_cmat_reduce_2x2_call(&b, call); + break; + case nir_cmat_call_op_per_element_op: + progress |= lower_cmat_per_element_op_call(&b, call); + break; + default: + break; + } + break; + } + default: + break; + } + } + } + return nir_progress(progress, impl, nir_metadata_none); +} + +bool +lvp_nir_lower_cooperative_matrix(nir_shader *shader) +{ + bool progress = false; + + if 
(!shader->info.cs.has_cooperative_matrix) + return false; + + struct hash_table *type_mapping = _mesa_pointer_hash_table_create(NULL); + /* Remap all cmat shader temp var to array of vectors */ + nir_foreach_variable_with_modes(var, shader, nir_var_shader_temp) { + const struct glsl_type *new_type = + remap_matrix_type(type_mapping, var->type); + + if (new_type != var->type) { + var->type = new_type; + progress = true; + } + } + + progress |= lower_impl(nir_shader_get_entrypoint(shader), type_mapping); + + _mesa_hash_table_destroy(type_mapping, NULL); + + nir_foreach_function_impl(fnim, shader) + nir_progress(progress, fnim, 0); + return progress; +} diff --git a/src/vulkan/runtime/vk_nir.c b/src/vulkan/runtime/vk_nir.c index b9cec0878e9..391b07616e7 100644 --- a/src/vulkan/runtime/vk_nir.c +++ b/src/vulkan/runtime/vk_nir.c @@ -168,7 +168,7 @@ vk_spirv_to_nir(struct vk_device *device, NIR_PASS(_, nir, nir_opt_deref); /* Pick off the single entrypoint that we want */ - nir_remove_non_entrypoints(nir); + nir_remove_non_cmat_call_entrypoints(nir); /* Now that we've deleted all but the main function, we can go ahead and * lower the rest of the constant initializers. We do this here so that