Merge branch 'lvp-coop-mat2-hacks' into 'main'

Draft: lavapipe: add support for lots of parts of NV_cooperative_matrix2

See merge request mesa/mesa!38964

Commit 686cf1a63f: 11 changed files with 971 additions and 33 deletions
@@ -126,6 +126,7 @@ propagate_invariant_instr(nir_instr *instr, struct set *invariants)
   case nir_instr_type_jump:
   case nir_instr_type_undef:
   case nir_instr_type_load_const:
   case nir_instr_type_cmat_call:
      break; /* Nothing to do */

   case nir_instr_type_phi: {
@@ -2324,7 +2324,13 @@ static void emit_reduce(struct lp_build_nir_soa_context *bld, LLVMValueRef src,
   /* can't use llvm reduction intrinsics because of exec_mask */
   LLVMValueRef exec_mask = group_op_mask_vec(bld);
   nir_op reduction_op = nir_intrinsic_reduction_op(instr);

   bool is_flt = reduction_op == nir_op_fadd ||
                 reduction_op == nir_op_fmul ||
                 reduction_op == nir_op_fmin ||
                 reduction_op == nir_op_fmax;
   bool is_unsigned = reduction_op == nir_op_umin ||
                      reduction_op == nir_op_umax;
   struct lp_build_context *int_bld = get_int_bld(bld, true, bit_size, true);
   uint32_t cluster_size = 0;

   if (instr->intrinsic == nir_intrinsic_reduce)
@@ -2338,20 +2344,57 @@ static void emit_reduce(struct lp_build_nir_soa_context *bld, LLVMValueRef src,
      src = LLVMBuildZExt(builder, src, bld->uint8_bld.vec_type, "");
   }

   /* reduction addition optimisation passes - for coopmat */
   /* for i/fadd instead of doing it manually just zero out inactive lanes */
   if ((reduction_op == nir_op_iadd ||
        reduction_op == nir_op_fadd) &&
       cluster_size == bld->int_bld.type.length) {
      struct lp_build_context *vec_bld = is_flt ? get_flt_bld(bld, bit_size, true) :
                                                  get_int_bld(bld, is_unsigned, bit_size, true);
      char intrinsic[64];
      uint32_t length = vec_bld->type.length;
      uint32_t src_width = bit_size;

      src = LLVMBuildBitCast(builder, src, int_bld->vec_type, "");
      if (bit_size < 32)
         exec_mask = LLVMBuildTrunc(builder, exec_mask, int_bld->vec_type, "");
      LLVMValueRef masked_val = lp_build_and(int_bld, src, exec_mask);
      snprintf(intrinsic, sizeof intrinsic, "llvm.vector.reduce.%sadd.v%u%s%u",
               is_flt ? "f" : "",
               length, is_flt ? "f" : is_unsigned ? "u" : "i", src_width);

      if (is_flt)
         masked_val = LLVMBuildBitCast(builder, masked_val, vec_bld->vec_type, "");

      LLVMValueRef args[2];
      int num_args = is_flt ? 2 : 1;
      if (is_flt) {
         args[0] = lp_build_const_elem(gallivm, vec_bld->type, 0);
         args[1] = masked_val;
      } else {
         args[0] = masked_val;
      }
      LLVMValueRef res = lp_build_intrinsic(builder, intrinsic, vec_bld->elem_type, args, num_args, 0);

      res = lp_build_broadcast(gallivm, vec_bld->vec_type, res);
      LLVMValueRef swizzle[LP_MAX_VECTOR_LENGTH];
      for (uint32_t i = 0; i < vec_bld->type.length; i++)
         swizzle[i] = lp_build_const_int32(gallivm, i);

      LLVMValueRef undef = LLVMGetUndef(vec_bld->vec_type);
      result[0] = LLVMBuildShuffleVector(
         builder, res, undef, LLVMConstVector(swizzle, vec_bld->type.length), "");
      return;
   }

   LLVMValueRef res_store = NULL;
   LLVMValueRef scan_store;
   struct lp_build_context *int_bld = get_int_bld(bld, true, bit_size, true);

   res_store = lp_build_alloca(gallivm, int_bld->vec_type, "");
   scan_store = lp_build_alloca(gallivm, int_bld->elem_type, "");

   struct lp_build_context elem_bld;
   bool is_flt = reduction_op == nir_op_fadd ||
                 reduction_op == nir_op_fmul ||
                 reduction_op == nir_op_fmin ||
                 reduction_op == nir_op_fmax;
   bool is_unsigned = reduction_op == nir_op_umin ||
                      reduction_op == nir_op_umax;

   struct lp_build_context *vec_bld = is_flt ? get_flt_bld(bld, bit_size, true) :
                                               get_int_bld(bld, is_unsigned, bit_size, true);
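The coopmat fast path above relies on the gallivm execution mask being all-ones or all-zero per lane: when the cluster covers the whole vector, ANDing the source with exec_mask zeroes out inactive lanes, so a single llvm.vector.reduce.*add over the masked vector is equivalent to a per-lane conditional sum. A minimal scalar sketch of that idea (illustrative only, names are hypothetical, not code from this MR):

   #include <stdint.h>

   /* Each exec_mask element is either all-ones or zero, so the AND drops
    * inactive lanes and a plain full-vector sum gives the masked result. */
   static uint32_t
   masked_iadd_reduce(const uint32_t *lanes, const uint32_t *exec_mask, unsigned len)
   {
      uint32_t sum = 0;
      for (unsigned i = 0; i < len; i++)
         sum += lanes[i] & exec_mask[i];
      return sum;
   }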
@@ -74,6 +74,13 @@
#define LVP_SAMPLE_COUNTS (VK_SAMPLE_COUNT_1_BIT | VK_SAMPLE_COUNT_4_BIT | \
                           VK_SAMPLE_COUNT_8_BIT)

extern unsigned lp_native_vector_width;

static bool has_cooperative_matrix(void) {
   /* only support coopmat if we have 8 wide */
   return (lp_native_vector_width / 32) >= 8;
}

VKAPI_ATTR VkResult VKAPI_CALL lvp_EnumerateInstanceVersion(uint32_t* pApiVersion)
{
   *pApiVersion = LVP_API_VERSION;
@@ -124,6 +131,7 @@ static const struct vk_device_extension_table lvp_device_extensions_supported =
   .KHR_buffer_device_address = true,
   .KHR_create_renderpass2 = true,
   .KHR_compute_shader_derivatives = true,
   .KHR_cooperative_matrix = true,
   .KHR_copy_commands2 = true,
   .KHR_copy_memory_indirect = true,
   .KHR_dedicated_allocation = true,
@@ -300,6 +308,7 @@ static const struct vk_device_extension_table lvp_device_extensions_supported =
   .GOOGLE_decorate_string = true,
   .GOOGLE_hlsl_functionality1 = true,
   .GOOGLE_user_type = true,
   .NV_cooperative_matrix2 = true,
};

static bool
@@ -857,11 +866,18 @@ lvp_get_features(const struct lvp_physical_device *pdevice,
      /* VK_KHR_unified_image_layouts */
      .unifiedImageLayouts = true,
      .unifiedImageLayoutsVideo = true,

      /* VK_KHR_cooperative_matrix */
      .cooperativeMatrix = has_cooperative_matrix(),
      .cooperativeMatrixRobustBufferAccess = has_cooperative_matrix(),

      .cooperativeMatrixFlexibleDimensions = true,
      .cooperativeMatrixConversions = true,
      .cooperativeMatrixReductions = true,
      .cooperativeMatrixPerElementOperations = true,
   };
}

extern unsigned lp_native_vector_width;

static VkImageLayout lvp_host_copy_image_layouts[] = {
   VK_IMAGE_LAYOUT_GENERAL,
   VK_IMAGE_LAYOUT_COLOR_ATTACHMENT_OPTIMAL,
@@ -1304,6 +1320,9 @@ lvp_get_properties(const struct lvp_physical_device *device, struct vk_propertie

      /* VK_KHR_compute_shader_derivatives */
      .meshAndTaskShaderDerivatives = true,

      /* VK_NV_cooperative_matrix2 */
      .cooperativeMatrixFlexibleDimensionsMaxDimension = 1024,
   };

   /* Vulkan 1.0 */
@@ -1369,6 +1388,10 @@ lvp_get_properties(const struct lvp_physical_device *device, struct vk_propertie
   /* VK_EXT_mesh_shader */
   p->maxMeshPayloadAndSharedMemorySize = p->maxTaskPayloadSize + p->maxMeshSharedMemorySize; /* 28K min required */
   p->maxMeshPayloadAndOutputMemorySize = p->maxTaskPayloadSize + p->maxMeshOutputMemorySize; /* 47K min required */

   /* VK_KHR_cooperative_matrix */
   p->cooperativeMatrixSupportedStages = VK_SHADER_STAGE_COMPUTE_BIT;

}

static VkResult VKAPI_CALL
@@ -1495,8 +1518,6 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_CreateInstance(
      return vk_error(NULL, result);
   }

   instance->apiVersion = LVP_API_VERSION;

   instance->vk.physical_devices.enumerate = lvp_enumerate_physical_devices;
   instance->vk.physical_devices.destroy = lvp_destroy_physical_device;
@@ -2394,9 +2415,7 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_BindBufferMemory2(VkDevice _device,
      VK_FROM_HANDLE(lvp_buffer, buffer, pBindInfos[i].buffer);
      VkBindMemoryStatusKHR *status = (void*)vk_find_struct_const(&pBindInfos[i], BIND_MEMORY_STATUS_KHR);

      buffer->mem = mem;
      buffer->map = (char*)mem->map + pBindInfos[i].memoryOffset;
      buffer->offset = pBindInfos[i].memoryOffset;
      device->pscreen->resource_bind_backing(device->pscreen,
                                             buffer->bo,
                                             mem->pmem,
@@ -2867,3 +2886,85 @@ VKAPI_ATTR void VKAPI_CALL lvp_GetRenderingAreaGranularityKHR(
   VkExtent2D tile_size = {64, 64};
   *pGranularity = tile_size;
}

VKAPI_ATTR VkResult VKAPI_CALL lvp_GetPhysicalDeviceCooperativeMatrixPropertiesKHR(
   VkPhysicalDevice physicalDevice,
   uint32_t *pPropertyCount,
   VkCooperativeMatrixPropertiesKHR *pProperties)
{
   VK_OUTARRAY_MAKE_TYPED(VkCooperativeMatrixPropertiesKHR, out, pProperties, pPropertyCount);

   vk_outarray_append_typed(VkCooperativeMatrixPropertiesKHR, &out, p)
   {
      *p = (struct VkCooperativeMatrixPropertiesKHR){
         .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR,
         .MSize = 8,
         .NSize = 8,
         .KSize = 8,
         .AType = VK_COMPONENT_TYPE_FLOAT16_KHR,
         .BType = VK_COMPONENT_TYPE_FLOAT16_KHR,
         .CType = VK_COMPONENT_TYPE_FLOAT16_KHR,
         .ResultType = VK_COMPONENT_TYPE_FLOAT16_KHR,
         .saturatingAccumulation = false,
         .scope = VK_SCOPE_SUBGROUP_KHR
      };
   }

   vk_outarray_append_typed(VkCooperativeMatrixPropertiesKHR, &out, p)
   {
      *p = (struct VkCooperativeMatrixPropertiesKHR){
         .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR,
         .MSize = 8,
         .NSize = 8,
         .KSize = 8,
         .AType = VK_COMPONENT_TYPE_UINT8_KHR,
         .BType = VK_COMPONENT_TYPE_UINT8_KHR,
         .CType = VK_COMPONENT_TYPE_UINT32_KHR,
         .ResultType = VK_COMPONENT_TYPE_UINT32_KHR,
         .saturatingAccumulation = false,
         .scope = VK_SCOPE_SUBGROUP_KHR
      };
   }
   return vk_outarray_status(&out);
}

VKAPI_ATTR VkResult VKAPI_CALL
lvp_GetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(
   VkPhysicalDevice physicalDevice, uint32_t *pPropertyCount,
   VkCooperativeMatrixFlexibleDimensionsPropertiesNV *pProperties)
{
   VK_OUTARRAY_MAKE_TYPED(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, out, pProperties, pPropertyCount);

   vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p)
   {
      *p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){
         .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV,
         .MGranularity = 8,
         .NGranularity = 8,
         .KGranularity = 8,
         .AType = VK_COMPONENT_TYPE_FLOAT16_KHR,
         .BType = VK_COMPONENT_TYPE_FLOAT16_KHR,
         .CType = VK_COMPONENT_TYPE_FLOAT16_KHR,
         .ResultType = VK_COMPONENT_TYPE_FLOAT16_KHR,
         .saturatingAccumulation = false,
         .scope = VK_SCOPE_SUBGROUP_KHR
      };
   }

   vk_outarray_append_typed(VkCooperativeMatrixFlexibleDimensionsPropertiesNV, &out, p)
   {
      *p = (struct VkCooperativeMatrixFlexibleDimensionsPropertiesNV){
         .sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV,
         .MGranularity = 8,
         .NGranularity = 8,
         .KGranularity = 8,
         .AType = VK_COMPONENT_TYPE_UINT8_KHR,
         .BType = VK_COMPONENT_TYPE_UINT8_KHR,
         .CType = VK_COMPONENT_TYPE_UINT32_KHR,
         .ResultType = VK_COMPONENT_TYPE_UINT32_KHR,
         .saturatingAccumulation = false,
         .scope = VK_SCOPE_SUBGROUP_KHR
      };
   }
   return vk_outarray_status(&out);
}
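The two property-query entry points above advertise a single fixed subgroup-scope shape in two flavours: 8x8x8 with fp16 A/B/C/Result, and 8x8x8 with uint8 inputs accumulating into uint32. A caller-side sketch of how an application would enumerate these with the standard Vulkan count-then-fill idiom (illustrative only, not part of this MR; assumes the extension function prototype is available and phys_dev is a valid handle):

   #include <stdlib.h>
   #include <vulkan/vulkan.h>

   /* Hypothetical helper, not from this MR. */
   static VkCooperativeMatrixPropertiesKHR *
   query_coopmat_shapes(VkPhysicalDevice phys_dev, uint32_t *count)
   {
      *count = 0;
      vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(phys_dev, count, NULL);

      VkCooperativeMatrixPropertiesKHR *props = calloc(*count, sizeof(*props));
      for (uint32_t i = 0; i < *count; i++)
         props[i].sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
      vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(phys_dev, count, props);
      /* props[] now lists the MxNxK shapes and component types usable in shaders. */
      return props;
   }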
@@ -73,11 +73,11 @@ struct descriptor_buffer_offset {
struct lvp_render_attachment {
   struct lvp_image_view *imgv;
   VkResolveModeFlags resolve_mode;
   bool read_only;
   struct lvp_image_view *resolve_imgv;
   VkAttachmentLoadOp load_op;
   VkAttachmentStoreOp store_op;
   VkClearValue clear_value;
   bool read_only;
};

struct lvp_conditional_rendering_state {
@@ -2961,21 +2961,22 @@ static void handle_copy_query_pool_results(struct vk_cmd_queue_entry *cmd,
   unsigned result_size = copycmd->flags & VK_QUERY_RESULT_64_BIT ? 8 : 4;
   for (unsigned i = copycmd->first_query; i < copycmd->first_query + copycmd->query_count; i++) {
      unsigned offset = copycmd->dst_offset + (copycmd->stride * (i - copycmd->first_query));

      if (pool->base_type >= PIPE_QUERY_TYPES) {
         struct pipe_transfer *transfer;
         uint8_t *map = pipe_buffer_map(state->pctx, lvp_buffer_from_handle(copycmd->dst_buffer)->bo, PIPE_MAP_WRITE, &transfer);
         map += offset;

         void *data = &pool->queries;

         if (copycmd->flags & VK_QUERY_RESULT_64_BIT) {
            uint64_t *dst = (uint64_t *)map;
            uint64_t *src = (uint64_t *)pool->data;
            uint64_t *src = (uint64_t *)data;
            *dst = src[i];
            if (copycmd->flags & VK_QUERY_RESULT_WITH_AVAILABILITY_BIT)
               *(dst + 1) = 1;
         } else {
            uint32_t *dst = (uint32_t *)map;
            uint64_t *src = (uint64_t *)pool->data;
            uint64_t *src = (uint64_t *)data;
            *dst = (uint32_t) (src[i] & UINT32_MAX);
            if (copycmd->flags & VK_QUERY_RESULT_WITH_AVAILABILITY_BIT)
               *(dst + 1) = 1;
@@ -4542,8 +4543,8 @@ handle_write_acceleration_structures_properties(struct vk_cmd_queue_entry *cmd,
   struct vk_cmd_write_acceleration_structures_properties_khr *write = &cmd->u.write_acceleration_structures_properties_khr;

   VK_FROM_HANDLE(lvp_query_pool, pool, write->query_pool);

   uint64_t *dst = pool->data;
   void *data = &pool->queries;
   uint64_t *dst = data;
   dst += write->first_query;

   for (uint32_t i = 0; i < write->acceleration_structure_count; i++) {
@@ -38,10 +38,6 @@

#include "gallivm/lp_bld_debug.h"

#define SPIR_V_MAGIC_NUMBER 0x07230203

#define MAX_DYNAMIC_STATES 72

typedef void (*cso_destroy_func)(struct pipe_context*, void*);

static void
@@ -351,6 +347,26 @@ lvp_shader_lower(struct lvp_device *pdevice, nir_shader *nir, struct lvp_pipelin
   NIR_PASS(_, nir, nir_lower_system_values);
   NIR_PASS(_, nir, nir_lower_is_helper_invocation);

   bool progress;
   NIR_PASS(progress, nir, nir_lower_cooperative_matrix_flexible_dimensions, 8, 8, 8);
   if (progress) {
      NIR_PASS(_, nir, nir_opt_deref);
      NIR_PASS(_, nir, nir_opt_dce);
      NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL);
   }
   NIR_PASS(progress, nir, lvp_nir_lower_cooperative_matrix);
   if (progress) {
      NIR_PASS(_, nir, nir_opt_dce);
      NIR_PASS(progress, nir, nir_inline_functions);
      nir_remove_non_entrypoints(nir); /* remove the late inlined functions */
      if (progress) {
         NIR_PASS(_, nir, nir_opt_copy_prop_vars);
         NIR_PASS(_, nir, nir_opt_copy_prop);
      }
      NIR_PASS(_, nir, nir_opt_deref);
      NIR_PASS(_, nir, nir_opt_dce);
   }

   const struct nir_lower_compute_system_values_options compute_system_values = {0};
   NIR_PASS(_, nir, nir_lower_compute_system_values, &compute_system_values);
@@ -168,8 +168,6 @@ struct lvp_physical_device {
struct lvp_instance {
   struct vk_instance vk;

   uint32_t apiVersion;

   uint64_t debug_flags;

   struct pipe_loader_device *devs;
@@ -579,10 +577,8 @@ struct lvp_event {
struct lvp_buffer {
   struct vk_buffer vk;

   struct lvp_device_memory *mem;
   struct pipe_resource *bo;
   uint64_t total_size;
   uint64_t offset;
   void *map;
   struct pipe_transfer *transfer;
};
@@ -605,7 +601,6 @@ struct lvp_buffer_view {
struct lvp_query_pool {
   struct vk_query_pool vk;
   enum pipe_query_type base_type;
   void *data; /* Used by queries that are not implemented by pipe_query */
   struct pipe_query *queries[0];
};
@@ -81,7 +81,6 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_CreateQueryPool(
      return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);

   pool->base_type = pipeq;
   pool->data = &pool->queries;

   *pQueryPool = lvp_query_pool_to_handle(pool);
   return VK_SUCCESS;
@@ -122,6 +121,7 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_GetQueryPoolResults(

   device->vk.dispatch_table.DeviceWaitIdle(_device);

   void *data = &pool->queries;
   for (unsigned i = firstQuery; i < firstQuery + queryCount; i++) {
      uint8_t *dest = (uint8_t *)((char *)pData + (stride * (i - firstQuery)));
      union pipe_query_result result;
@@ -130,13 +130,13 @@ VKAPI_ATTR VkResult VKAPI_CALL lvp_GetQueryPoolResults(
      if (pool->base_type >= PIPE_QUERY_TYPES) {
         if (flags & VK_QUERY_RESULT_64_BIT) {
            uint64_t *dst = (uint64_t *)dest;
            uint64_t *src = (uint64_t *)pool->data;
            uint64_t *src = (uint64_t *)data;
            *dst = src[i];
            if (flags & VK_QUERY_RESULT_WITH_AVAILABILITY_BIT)
               *(dst + 1) = 1;
         } else {
            uint32_t *dst = (uint32_t *)dest;
            uint64_t *src = (uint64_t *)pool->data;
            uint64_t *src = (uint64_t *)data;
            *dst = src[i];
            if (flags & VK_QUERY_RESULT_WITH_AVAILABILITY_BIT)
               *(dst + 1) = 1;
@@ -12,6 +12,7 @@ lvp_entrypoints = custom_target(
)

liblvp_files = files(
  'nir/lvp_nir_lower_cooperative_matrix.c',
  'nir/lvp_nir_lower_exec_graph.c',
  'nir/lvp_nir_lower_input_attachments.c',
  'nir/lvp_nir_lower_pipeline_layout.c',
@@ -121,4 +121,6 @@ bool lvp_nir_lower_sparse_residency(struct nir_shader *shader);
bool lvp_nir_opt_robustness(struct nir_shader *shader, struct lvp_device *device,
                            struct vk_pipeline_robustness_state *robustness);

bool lvp_nir_lower_cooperative_matrix(nir_shader *shader);

#endif
@@ -0,0 +1,778 @@
/*
 * Copyright © 2025 Red Hat
 *
 * SPDX-License-Identifier: MIT
 */
#include "lvp_nir.h"

extern unsigned lp_native_vector_width;

#define MAX_CMAT_LEN 16
#define CMAT_LEN (lp_native_vector_width / 32)

/* This pass lowers cooperative matrix.
 *
 * for lavapipe we advertise 8x8 matrix.
 * This means we can store vec8[8] and get the backend to do the right thing.
 */
static unsigned
get_cmat_size(struct glsl_cmat_description matrix_desc)
{
   return matrix_desc.cols * matrix_desc.rows;
}

static unsigned
get_cmat_length(struct glsl_cmat_description matrix_desc)
{
   return get_cmat_size(matrix_desc) / CMAT_LEN;
}

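For the 8x8 case this means each matrix is held as one 8-wide vector per subgroup invocation: the A, C and result matrices keep element (row, col) in vector component row of lane col, while B is kept transposed so its rows spread across lanes. A small sketch of that mapping (illustrative only, not code from this MR; my reading of the load/store and muladd lowering below):

   #include <stdbool.h>

   struct cmat_slot { unsigned lane; unsigned component; };

   /* Where element (row, col) of an 8x8 matrix lives in the lowered
    * representation: lane selects the subgroup invocation, component
    * selects the element of that invocation's vec8. */
   static struct cmat_slot
   cmat_slot_for_element(unsigned row, unsigned col, bool is_b_matrix)
   {
      struct cmat_slot s;
      s.lane      = is_b_matrix ? row : col;
      s.component = is_b_matrix ? col : row;
      return s;
   }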
static const struct glsl_type *
remap_matrix_type(struct hash_table *mapping, const struct glsl_type *orig)
{
   struct hash_entry *entry = _mesa_hash_table_search(mapping, orig);

   if (entry)
      return entry->data;

   const struct glsl_type *new_type = orig;

   if (glsl_type_is_cmat(orig)) {
      struct glsl_cmat_description matrix_desc =
         *glsl_get_cmat_description(orig);

      new_type = glsl_vector_type(matrix_desc.element_type, get_cmat_length(matrix_desc));
   } else if (glsl_type_is_array(orig)) {
      const struct glsl_type *elem_type = glsl_get_array_element(orig);
      const struct glsl_type *new_elem_type =
         remap_matrix_type(mapping, elem_type);

      if (elem_type != new_elem_type) {
         new_type = glsl_array_type(new_elem_type, glsl_get_length(orig),
                                    glsl_get_explicit_stride(orig));
      }
   }
   _mesa_hash_table_insert(mapping, orig, (void *)new_type);
   return new_type;
}

static nir_def *
load_cmat_deref(nir_builder *b, nir_deref_instr *src)
{
   struct glsl_cmat_description matrix_desc =
      *glsl_get_cmat_description(src->type);

   return nir_build_load_deref(
      b, get_cmat_length(matrix_desc),
      glsl_base_type_bit_size(matrix_desc.element_type), &src->def, 0);
}

static ALWAYS_INLINE nir_def *
load_cmat_src(nir_builder *b, nir_src src)
{
   return load_cmat_deref(b, nir_src_as_deref(src));
}

static ALWAYS_INLINE struct glsl_cmat_description
cmat_src_desc(nir_src src)
{
   nir_deref_instr *deref = nir_src_as_deref(src);
   return *glsl_get_cmat_description(deref->type);
}

static void
store_cmat_deref(nir_builder *b, nir_deref_instr *dst, nir_def *val)
{
   ASSERTED struct glsl_cmat_description matrix_desc =
      *glsl_get_cmat_description(dst->type);

   assert(val->bit_size == glsl_base_type_bit_size(matrix_desc.element_type));
   assert(val->num_components == get_cmat_length(matrix_desc));

   nir_store_deref(b, dst, val, ~0);
}

static ALWAYS_INLINE void
store_cmat_src(nir_builder *b, nir_src dst_src, nir_def *val)
{
   store_cmat_deref(b, nir_src_as_deref(dst_src), val);
}

static bool
lower_cmat_copy(nir_builder *b, nir_intrinsic_instr *intr)
{
   nir_build_copy_deref(b, intr->src[0].ssa, intr->src[1].ssa);
   nir_instr_remove(&intr->instr);
   return true;
}

static nir_def *
convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use,
            enum glsl_cmat_use dst_use)
{
   nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {};
   nir_def *out_comps[NIR_MAX_VEC_COMPONENTS] = {};
   unsigned num_comps = src->num_components;
   for (unsigned i = 0; i < num_comps; i++) {
      comps[i] = nir_channel(b, src, i);
      out_comps[i] = nir_imm_zero(b, 1, comps[i]->bit_size);
   }

   nir_def *lane_id = nir_load_subgroup_invocation(b);

   /* construct the outer row */
   for (unsigned i = 0; i < num_comps; i++) {
      for (unsigned j = 0; j < CMAT_LEN; j++) {
         nir_def *else_val = out_comps[i];
         nir_def *active_lane = nir_ieq(b, lane_id, nir_imm_int(b, j));

         out_comps[i] = nir_read_invocation(b, comps[j], nir_imm_int(b, i));

         out_comps[i] = nir_bcsel(b, active_lane, out_comps[i], else_val);
      }
   }
   return nir_vec(b, out_comps, num_comps);
}

|
||||
convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, enum glsl_base_type dst_type)
|
||||
{
|
||||
if (dst_type == src_type)
|
||||
return src;
|
||||
|
||||
nir_op op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_type),
|
||||
nir_get_nir_type_for_glsl_base_type(dst_type), nir_rounding_mode_undef);
|
||||
|
||||
return nir_build_alu1(b, op, src);
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_convert(nir_builder *b,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
const bool transpose = intr->intrinsic == nir_intrinsic_cmat_transpose;
|
||||
struct glsl_cmat_description dst_desc = cmat_src_desc(intr->src[0]);
|
||||
struct glsl_cmat_description src_desc = cmat_src_desc(intr->src[1]);
|
||||
|
||||
enum glsl_base_type dst_element_type = dst_desc.element_type;
|
||||
enum glsl_base_type src_element_type = src_desc.element_type;
|
||||
|
||||
enum glsl_cmat_use dst_use = dst_desc.use;
|
||||
enum glsl_cmat_use src_use = src_desc.use;
|
||||
|
||||
nir_def *cmat = load_cmat_src(b, intr->src[1]);
|
||||
|
||||
if (dst_use == GLSL_CMAT_USE_ACCUMULATOR)
|
||||
dst_use = GLSL_CMAT_USE_A;
|
||||
if (src_use == GLSL_CMAT_USE_ACCUMULATOR)
|
||||
src_use = GLSL_CMAT_USE_A;
|
||||
|
||||
if (transpose) {
|
||||
if (src_use == GLSL_CMAT_USE_A && dst_use == GLSL_CMAT_USE_B)
|
||||
src_use = dst_use;
|
||||
if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_A)
|
||||
src_use = dst_use;
|
||||
}
|
||||
|
||||
nir_def *ret = cmat;
|
||||
if (dst_use != src_use) {
|
||||
ret = convert_use(b, cmat, src_use, dst_use);
|
||||
}
|
||||
ret = convert_base_type(b, ret, src_element_type, dst_element_type);
|
||||
store_cmat_src(b, intr->src[0], ret);
|
||||
nir_instr_remove(&intr->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_load_store(nir_builder *b,
|
||||
struct hash_table *type_mapping,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
const bool is_load = intr->intrinsic == nir_intrinsic_cmat_load;
|
||||
const struct glsl_cmat_description desc = cmat_src_desc(intr->src[!is_load]);
|
||||
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
|
||||
nir_deref_instr *cmat_deref = nir_src_as_deref(intr->src[!is_load]);
|
||||
nir_deref_instr *deref = nir_src_as_deref(intr->src[is_load]);
|
||||
nir_def *stride = intr->src[2].ssa;
|
||||
|
||||
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
unsigned type_size_B = glsl_base_type_bit_size(desc.element_type) / 8;
|
||||
const uint32_t ptr_stride = glsl_get_bit_size(deref->type) / 8 * glsl_get_vector_elements(deref->type);
|
||||
deref = nir_build_deref_cast(b, &deref->def, deref->modes, deref->type, ptr_stride);
|
||||
const struct glsl_type *cmat_type = remap_matrix_type(type_mapping, cmat_deref->type);
|
||||
cmat_deref = nir_build_deref_cast(b, &cmat_deref->def, cmat_deref->modes,
|
||||
cmat_type, 0);
|
||||
|
||||
/* store B matrix transposed */
|
||||
if (desc.use == GLSL_CMAT_USE_B)
|
||||
layout =
|
||||
layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
||||
|
||||
unsigned idx_bits = deref->def.bit_size;
|
||||
nir_def *vars[MAX_CMAT_LEN];
|
||||
|
||||
if (is_load) {
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++) {
|
||||
vars[i] = nir_undef(b, 1, 16);
|
||||
}
|
||||
} else {
|
||||
nir_def *src = load_cmat_src(b, intr->src[!is_load]);
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++) {
|
||||
vars[i] = nir_channel(b, src, i);
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++) {
|
||||
nir_def *col_offset = lane_id;
|
||||
nir_def *row_offset = nir_imm_int(b, i);
|
||||
|
||||
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
|
||||
SWAP(col_offset, row_offset);
|
||||
}
|
||||
|
||||
col_offset = nir_imul(b, col_offset, stride);
|
||||
col_offset = nir_u2uN(b, col_offset, idx_bits);
|
||||
row_offset = nir_u2uN(b, row_offset, idx_bits);
|
||||
|
||||
nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(b, deref, col_offset);
|
||||
|
||||
iter_deref = nir_build_deref_cast(b, &iter_deref->def,
|
||||
deref->modes,
|
||||
glsl_scalar_type(desc.element_type),
|
||||
type_size_B);
|
||||
iter_deref = nir_build_deref_ptr_as_array(b, iter_deref, row_offset);
|
||||
|
||||
if (is_load) {
|
||||
vars[i] = nir_load_deref(b, iter_deref);
|
||||
} else {
|
||||
nir_store_deref(b, iter_deref, vars[i], ~0);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_load) {
|
||||
nir_def *mat = nir_vec(b, vars, CMAT_LEN);
|
||||
nir_store_deref(b, cmat_deref, mat, nir_component_mask(mat->num_components));
|
||||
}
|
||||
nir_instr_remove(&intr->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_construct(nir_builder *b,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
nir_def *elem = intr->src[1].ssa;
|
||||
|
||||
nir_def *r = nir_replicate(b, elem, get_cmat_length(desc));
|
||||
|
||||
nir_store_deref(b, dst_deref, r, nir_component_mask(r->num_components));
|
||||
nir_instr_remove(&intr->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_extract(nir_builder *b,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
nir_def *mat = load_cmat_src(b, intr->src[0]);
|
||||
nir_def *index = intr->src[1].ssa;
|
||||
nir_def *elem = nir_vector_extract(b, mat, index);
|
||||
nir_def_replace(&intr->def, elem);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_insert(nir_builder *b,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
nir_def *elem = intr->src[1].ssa;
|
||||
nir_def *mat = load_cmat_src(b, intr->src[2]);
|
||||
nir_def *index = intr->src[3].ssa;
|
||||
|
||||
nir_def *r = nir_vector_insert(b, mat, elem, index);
|
||||
store_cmat_src(b, intr->src[0], r);
|
||||
|
||||
nir_instr_remove(&intr->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_binary_op(nir_builder *b,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
nir_def *src_a = load_cmat_src(b, intr->src[1]);
|
||||
nir_def *src_b = load_cmat_src(b, intr->src[2]);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
|
||||
nir_def *ret = nir_build_alu2(b, op, src_a, src_b);
|
||||
store_cmat_src(b, intr->src[0], ret);
|
||||
|
||||
nir_instr_remove(&intr->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_unary_op(nir_builder *b,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
nir_def *src = load_cmat_src(b, intr->src[1]);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
|
||||
nir_def *ret = nir_build_alu1(b, op, src);
|
||||
store_cmat_src(b, intr->src[0], ret);
|
||||
|
||||
nir_instr_remove(&intr->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_scalar_op(nir_builder *b,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
nir_def *src_a = load_cmat_src(b, intr->src[1]);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
|
||||
nir_def *ret = nir_build_alu2(b, op, src_a, intr->src[2].ssa);
|
||||
store_cmat_src(b, intr->src[0], ret);
|
||||
|
||||
nir_instr_remove(&intr->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_length(nir_builder *b,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
nir_def_replace(&intr->def, nir_imm_int(b, CMAT_LEN));
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_muladd(nir_builder *b,
|
||||
nir_intrinsic_instr *intr)
|
||||
{
|
||||
const struct glsl_cmat_description a_desc = cmat_src_desc(intr->src[1]);
|
||||
const struct glsl_cmat_description b_desc = cmat_src_desc(intr->src[2]);
|
||||
const struct glsl_cmat_description c_desc = cmat_src_desc(intr->src[3]);
|
||||
nir_def *cmat_a = load_cmat_src(b, intr->src[1]);
|
||||
nir_def *cmat_b = load_cmat_src(b, intr->src[2]);
|
||||
nir_def *cmat_c = load_cmat_src(b, intr->src[3]);
|
||||
|
||||
unsigned a_length = get_cmat_length(a_desc);
|
||||
unsigned b_length = get_cmat_length(b_desc);
|
||||
unsigned c_length = get_cmat_length(c_desc);
|
||||
nir_def *a_comps[NIR_MAX_VEC_COMPONENTS];
|
||||
nir_def *b_comps[NIR_MAX_VEC_COMPONENTS];
|
||||
nir_def *c_comps[NIR_MAX_VEC_COMPONENTS];
|
||||
nir_def *d_comps[NIR_MAX_VEC_COMPONENTS];
|
||||
|
||||
for (unsigned i = 0; i < a_length; i++)
|
||||
a_comps[i] = nir_channel(b, cmat_a, i);
|
||||
|
||||
for (unsigned i = 0; i < b_length; i++)
|
||||
b_comps[i] = nir_channel(b, cmat_b, i);
|
||||
|
||||
for (unsigned i = 0; i < c_length; i++)
|
||||
c_comps[i] = nir_channel(b, cmat_c, i);
|
||||
|
||||
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
int accum_bit_size = glsl_base_type_bit_size(c_desc.element_type);
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++) {
|
||||
nir_def *ref = nir_imm_zero(b, 1, glsl_base_type_bit_size(c_desc.element_type));
|
||||
for (unsigned j = 0; j < CMAT_LEN; j++) {
|
||||
nir_def *outer_else_val = ref;
|
||||
ref = nir_imm_zero(b, 1, glsl_base_type_bit_size(c_desc.element_type));
|
||||
|
||||
nir_def *a_i = a_comps[i];
|
||||
nir_def *b_j = b_comps[j]; /* B is stored transposed */
|
||||
nir_def *val;
|
||||
if (glsl_base_type_is_integer(c_desc.element_type)) {
|
||||
a_i = nir_u2uN(b, a_i, accum_bit_size);
|
||||
b_j = nir_u2uN(b, b_j, accum_bit_size);
|
||||
val = nir_imul(b, a_i, b_j);
|
||||
ref = nir_iadd(b, ref,val);
|
||||
} else {
|
||||
val = nir_fmul(b, a_i, b_j);
|
||||
ref = nir_fadd(b, ref, val);
|
||||
}
|
||||
|
||||
if (glsl_base_type_is_integer(c_desc.element_type)) {
|
||||
ref = nir_reduce(b, ref, .reduction_op = nir_op_iadd);
|
||||
} else {
|
||||
ref = nir_reduce(b, ref, .reduction_op = nir_op_fadd);
|
||||
}
|
||||
|
||||
nir_def *lane = nir_ieq_imm(b, lane_id, j);
|
||||
ref = nir_bcsel(b, lane, ref, outer_else_val);
|
||||
}
|
||||
|
||||
if (glsl_base_type_is_integer(c_desc.element_type)) {
|
||||
ref = nir_iadd(b, ref, c_comps[i]);
|
||||
} else {
|
||||
ref = nir_fadd(b, ref, c_comps[i]);
|
||||
}
|
||||
d_comps[i] = ref;
|
||||
}
|
||||
nir_def *ret = nir_vec(b, d_comps, CMAT_LEN);
|
||||
store_cmat_src(b, intr->src[0], ret);
|
||||
nir_instr_remove(&intr->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_reduce_finish_call(nir_builder *b, nir_cmat_call_instr *call)
|
||||
{
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
|
||||
nir_deref_instr *src0_deref = nir_src_as_deref(call->params[1]);
|
||||
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src0_deref->type);
|
||||
nir_function *fnptr = call->callee;
|
||||
nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);
|
||||
nir_def *src0 = load_cmat_src(b, call->params[1]);
|
||||
nir_def *src1 = load_cmat_src(b, call->params[2]);
|
||||
|
||||
assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR);
|
||||
|
||||
nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {};
|
||||
|
||||
if (reduce & NIR_CMAT_REDUCE_COLUMN) {
|
||||
nir_variable *col_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "col_tmp");
|
||||
/* All of the rows contains the same data, so just reduce both first rows. */
|
||||
nir_def *row_accum0 = nir_channel(b, src0, 0);
|
||||
nir_def *row_accum1 = nir_channel(b, src1, 0);
|
||||
|
||||
nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp);
|
||||
|
||||
nir_call(b, fnptr, &col_tmp_deref->def, row_accum0, row_accum1);
|
||||
|
||||
nir_def *first_col = nir_load_deref(b, col_tmp_deref);
|
||||
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++)
|
||||
comps[i] = first_col;
|
||||
} else if (reduce & NIR_CMAT_REDUCE_ROW) {
|
||||
for (unsigned i = 0; i < CMAT_LEN; ++i) {
|
||||
nir_def *row0_accum = nir_channel(b, src0, i);
|
||||
nir_def *row1_accum = nir_channel(b, src1, i);
|
||||
|
||||
nir_variable *row_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "row_tmp");
|
||||
nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp);
|
||||
|
||||
nir_call(b, fnptr, &row_tmp_deref->def, row0_accum, row1_accum);
|
||||
|
||||
nir_def *row = nir_load_deref(b, row_tmp_deref);
|
||||
comps[i] = row;
|
||||
}
|
||||
}
|
||||
nir_def *mat = nir_vec(b, comps, CMAT_LEN);
|
||||
nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
|
||||
nir_instr_remove(&call->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_reduce_call(nir_builder *b, nir_cmat_call_instr *call)
|
||||
{
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
|
||||
nir_deref_instr *src_deref = nir_src_as_deref(call->params[1]);
|
||||
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
|
||||
nir_cmat_reduce reduce = nir_cmat_call_reduce_flags(call);
|
||||
nir_def *src = load_cmat_src(b, call->params[1]);
|
||||
nir_function *fnptr = call->callee;
|
||||
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
|
||||
assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR);
|
||||
|
||||
nir_def *comps[NIR_MAX_VEC_COMPONENTS] = {};
|
||||
for (unsigned i = 0; i < CMAT_LEN; ++i) {
|
||||
comps[i] = nir_channel(b, src, i);
|
||||
}
|
||||
|
||||
if (reduce & NIR_CMAT_REDUCE_COLUMN) {
|
||||
nir_variable *col_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "col_tmp");
|
||||
|
||||
nir_deref_instr *col_tmp_deref = nir_build_deref_var(b, col_tmp);
|
||||
nir_store_deref(b, col_tmp_deref, comps[0], 1);
|
||||
|
||||
for (unsigned i = 1; i < CMAT_LEN; i++) {
|
||||
nir_def *col_accum_val = nir_load_deref(b, col_tmp_deref);
|
||||
nir_call(b, fnptr, &col_tmp_deref->def, col_accum_val, comps[i]);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++)
|
||||
comps[i] = nir_load_deref(b, col_tmp_deref);
|
||||
}
|
||||
|
||||
if (reduce & NIR_CMAT_REDUCE_ROW) {
|
||||
for (unsigned i = 0; i < CMAT_LEN; ++i) {
|
||||
nir_def *row_accum = comps[i];
|
||||
nir_variable *row_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "row_tmp");
|
||||
nir_deref_instr *row_tmp_deref = nir_build_deref_var(b, row_tmp);
|
||||
nir_store_deref(b, row_tmp_deref, row_accum, 1);
|
||||
|
||||
for (unsigned j = 1; j < CMAT_LEN; j *= 2) {
|
||||
nir_def *prev_row_accum_val = nir_load_deref(b, row_tmp_deref);
|
||||
nir_def *this_row = nir_shuffle(b, prev_row_accum_val, nir_iadd(b, lane_id, nir_imm_int(b, j)));
|
||||
|
||||
nir_call(b, fnptr, &row_tmp_deref->def, prev_row_accum_val, this_row);
|
||||
}
|
||||
row_tmp_deref = nir_build_deref_var(b, row_tmp);
|
||||
comps[i] = nir_load_deref(b, row_tmp_deref);
|
||||
}
|
||||
}
|
||||
|
||||
/* this should be lowered earlier */
|
||||
assert(!(reduce & NIR_CMAT_REDUCE_2X2));
|
||||
nir_def *mat = nir_vec(b, comps, CMAT_LEN);
|
||||
nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
|
||||
nir_instr_remove(&call->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_reduce_2x2_call(nir_builder *b, nir_cmat_call_instr *call)
|
||||
{
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
|
||||
nir_deref_instr *src_deref = nir_src_as_deref(call->params[1]);
|
||||
struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
|
||||
nir_function *fnptr = call->callee;
|
||||
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
assert(src_desc.use == GLSL_CMAT_USE_ACCUMULATOR);
|
||||
|
||||
nir_def *comps[NIR_MAX_VEC_COMPONENTS];
|
||||
|
||||
nir_def *src_components[4][NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned m = 0; m < 4; m++) {
|
||||
nir_def *src = load_cmat_src(b, call->params[m + 1]);
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++) {
|
||||
src_components[m][i] = nir_channel(b, src, i);
|
||||
}
|
||||
}
|
||||
nir_variable *qd_tmp = nir_local_variable_create(b->impl, glsl_get_bare_type(fnptr->params[0].type), "qd_tmp");
|
||||
nir_deref_instr *qd_tmp_deref = nir_build_deref_var(b, qd_tmp);
|
||||
|
||||
for (unsigned m = 0; m < 4; m++) {
|
||||
for (unsigned i = 0; i < CMAT_LEN / 2; i++) {
|
||||
nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i * 2], src_components[m][i * 2 + 1]);
|
||||
src_components[m][i] = nir_load_deref(b, qd_tmp_deref);
|
||||
|
||||
nir_def *other_col = nir_shuffle_down(b, src_components[m][i], nir_imm_int(b, 1));
|
||||
nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i], other_col);
|
||||
src_components[m][i] = nir_load_deref(b, qd_tmp_deref);
|
||||
}
|
||||
}
|
||||
|
||||
nir_def *even = nir_inverse_ballot_imm(b, 0x5555555555555555, 32);
|
||||
for (unsigned m = 0; m < 2; m++) {
|
||||
for (unsigned i = 0; i < CMAT_LEN / 2; i++) {
|
||||
nir_def *m0_comp = src_components[m * 2][i];
|
||||
nir_def *m1_comp = nir_shuffle_up(b, src_components[m * 2 + 1][i], nir_imm_int(b, 1));
|
||||
|
||||
nir_def *combined = nir_bcsel(b, even, m0_comp, m1_comp);
|
||||
comps[m * (CMAT_LEN / 2) + i] = combined;
|
||||
}
|
||||
}
|
||||
|
||||
nir_def *low_lane_id = nir_ilt_imm(b, lane_id, 4);
|
||||
nir_def *new_lane_id_lo = nir_imul_imm(b, lane_id, 2);
|
||||
nir_def *new_lane_id_hi = nir_iadd_imm(b, nir_imul_imm(b, nir_iadd_imm(b, lane_id, -4), 2), 1);
|
||||
nir_def *new_lane_id = nir_bcsel(b, low_lane_id, new_lane_id_lo, new_lane_id_hi);
|
||||
|
||||
for (unsigned m = 0; m < CMAT_LEN; m++) {
|
||||
comps[m] = nir_shuffle(b, comps[m], new_lane_id);
|
||||
}
|
||||
nir_def *mat = nir_vec(b, comps, CMAT_LEN);
|
||||
nir_store_deref(b, dst_deref, mat, nir_component_mask(mat->num_components));
|
||||
nir_instr_remove(&call->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_cmat_per_element_op_call(nir_builder *b, nir_cmat_call_instr *call)
|
||||
{
|
||||
nir_def *src = load_cmat_src(b, call->params[3]);
|
||||
nir_deref_instr *dst_deref = nir_src_as_deref(call->params[0]);
|
||||
nir_function *fnptr = call->callee;
|
||||
nir_def *lane_id = nir_load_subgroup_invocation(b);
|
||||
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
|
||||
nir_variable *elem_tmp = nir_local_variable_create(b->impl, glsl_get_cmat_element(dst_deref->type), "elemtmp");
|
||||
nir_deref_instr *elem_deref = nir_build_deref_var(b, elem_tmp);
|
||||
|
||||
nir_def *comps[NIR_MAX_VEC_COMPONENTS];
|
||||
|
||||
for (unsigned i = 0; i < CMAT_LEN; i++) {
|
||||
nir_def *src_elem = nir_channel(b, src, i);
|
||||
nir_call_instr *new_call = nir_call_instr_create(b->shader, fnptr);
|
||||
|
||||
nir_def *row_val = nir_imm_int(b, i);
|
||||
nir_def *col_val = lane_id;
|
||||
|
||||
if (desc.use == GLSL_CMAT_USE_B)
|
||||
SWAP(col_val, row_val);
|
||||
|
||||
row_val = nir_iadd(b, call->params[1].ssa, row_val);
|
||||
col_val = nir_iadd(b, call->params[2].ssa, col_val);
|
||||
|
||||
new_call->params[0] = nir_src_for_ssa(&elem_deref->def);
|
||||
new_call->params[1] = nir_src_for_ssa(row_val);
|
||||
new_call->params[2] = nir_src_for_ssa(col_val);
|
||||
new_call->params[3] = nir_src_for_ssa(src_elem);
|
||||
|
||||
for (unsigned p = 4; p < call->num_params; p++) {
|
||||
nir_deref_instr *deref = nir_src_as_deref(call->params[p]);
|
||||
nir_def *def = call->params[p].ssa;
|
||||
if (deref) {
|
||||
if (glsl_type_is_cmat(deref->type)) {
|
||||
def = nir_build_load_deref(b, get_cmat_length(desc),
|
||||
glsl_base_type_bit_size(desc.element_type), def);
|
||||
def = nir_channel(b, def, i);
|
||||
}
|
||||
}
|
||||
new_call->params[p] = nir_src_for_ssa(def);
|
||||
}
|
||||
nir_builder_instr_insert(b, &new_call->instr);
|
||||
comps[i] = nir_build_load_deref(b, 1, glsl_base_type_bit_size(desc.element_type), &elem_deref->def, 0);
|
||||
}
|
||||
|
||||
nir_def *mat = nir_vec(b, comps, CMAT_LEN);
|
||||
nir_store_deref(b, dst_deref, mat, nir_component_mask(src->num_components));
|
||||
|
||||
nir_instr_remove(&call->instr);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
lower_impl(nir_function_impl *impl,
|
||||
struct hash_table *type_mapping)
|
||||
{
|
||||
bool progress = false;
|
||||
/* Remap all cmat temp var to array of scalars */
|
||||
nir_foreach_function_temp_variable(var, impl) {
|
||||
const struct glsl_type *new_type =
|
||||
remap_matrix_type(type_mapping, var->type);
|
||||
if (new_type != var->type) {
|
||||
var->type = new_type;
|
||||
progress = true;
|
||||
}
|
||||
}
|
||||
|
||||
/* Iterate in reverse order so that lowering can still use the matrix types from the derefs before we change it. */
|
||||
nir_builder b = nir_builder_create(impl);
|
||||
nir_foreach_block_reverse_safe (block, impl) {
|
||||
nir_foreach_instr_reverse_safe (instr, block) {
|
||||
b.cursor = nir_before_instr(instr);
|
||||
|
||||
switch (instr->type) {
|
||||
case nir_instr_type_intrinsic: {
|
||||
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
|
||||
switch (intr->intrinsic) {
|
||||
case nir_intrinsic_cmat_length:
|
||||
progress |= lower_cmat_length(&b, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_construct:
|
||||
progress |= lower_cmat_construct(&b, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_extract:
|
||||
progress |= lower_cmat_extract(&b, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_insert:
|
||||
progress |= lower_cmat_insert(&b, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_load:
|
||||
case nir_intrinsic_cmat_store:
|
||||
progress |= lower_cmat_load_store(&b, type_mapping, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_binary_op:
|
||||
progress |= lower_cmat_binary_op(&b, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_unary_op:
|
||||
progress |= lower_cmat_unary_op(&b, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_scalar_op:
|
||||
progress |= lower_cmat_scalar_op(&b, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_muladd:
|
||||
progress |= lower_cmat_muladd(&b, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_copy:
|
||||
progress |= lower_cmat_copy(&b, intr);
|
||||
break;
|
||||
case nir_intrinsic_cmat_convert:
|
||||
case nir_intrinsic_cmat_transpose:
|
||||
progress |= lower_cmat_convert(&b, intr);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_instr_type_deref: {
|
||||
nir_deref_instr *deref = nir_instr_as_deref(instr);
|
||||
const struct glsl_type *new_type =
|
||||
remap_matrix_type(type_mapping, deref->type);
|
||||
|
||||
if (new_type != deref->type) {
|
||||
deref->type = new_type;
|
||||
progress = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case nir_instr_type_cmat_call: {
|
||||
nir_cmat_call_instr *call = nir_instr_as_cmat_call(instr);
|
||||
switch (call->op) {
|
||||
case nir_cmat_call_op_reduce:
|
||||
progress |= lower_cmat_reduce_call(&b, call);
|
||||
break;
|
||||
case nir_cmat_call_op_reduce_finish:
|
||||
progress |= lower_cmat_reduce_finish_call(&b, call);
|
||||
break;
|
||||
case nir_cmat_call_op_reduce_2x2:
|
||||
progress |= lower_cmat_reduce_2x2_call(&b, call);
|
||||
break;
|
||||
case nir_cmat_call_op_per_element_op:
|
||||
progress |= lower_cmat_per_element_op_call(&b, call);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return nir_progress(progress, impl, nir_metadata_none);
|
||||
}
|
||||
|
||||
bool
|
||||
lvp_nir_lower_cooperative_matrix(nir_shader *shader)
|
||||
{
|
||||
bool progress = false;
|
||||
|
||||
if (!shader->info.cs.has_cooperative_matrix)
|
||||
return false;
|
||||
|
||||
struct hash_table *type_mapping = _mesa_pointer_hash_table_create(NULL);
|
||||
/* Remap all cmat shader temp var to array of vectors */
|
||||
nir_foreach_variable_with_modes(var, shader, nir_var_shader_temp) {
|
||||
const struct glsl_type *new_type =
|
||||
remap_matrix_type(type_mapping, var->type);
|
||||
|
||||
if (new_type != var->type) {
|
||||
var->type = new_type;
|
||||
progress = true;
|
||||
}
|
||||
}
|
||||
|
||||
progress |= lower_impl(nir_shader_get_entrypoint(shader), type_mapping);
|
||||
|
||||
_mesa_hash_table_destroy(type_mapping, NULL);
|
||||
|
||||
nir_foreach_function_impl(fnim, shader)
|
||||
nir_progress(progress, fnim, 0);
|
||||
return progress;
|
||||
}
|
||||
|
|
@@ -168,7 +168,7 @@ vk_spirv_to_nir(struct vk_device *device,
   NIR_PASS(_, nir, nir_opt_deref);

   /* Pick off the single entrypoint that we want */
   nir_remove_non_entrypoints(nir);
   nir_remove_non_cmat_call_entrypoints(nir);

   /* Now that we've deleted all but the main function, we can go ahead and
    * lower the rest of the constant initializers. We do this here so that