From 949f8572ec3959be0ac21aa8857515452c180500 Mon Sep 17 00:00:00 2001 From: Jesse Natalie Date: Wed, 11 Nov 2020 14:24:02 -0800 Subject: [PATCH] vtn/opencl: Fix alignment for half vload/vstore Reviewed-by: Jason Ekstrand Part-of: --- src/compiler/spirv/vtn_opencl.c | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index 127fc9e9a85..9a1b052cd92 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -632,17 +632,6 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]); struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer); - enum glsl_base_type ptr_base_type = - glsl_get_base_type(p->pointer->type->type); - if (base_type != ptr_base_type) { - vtn_fail_if(ptr_base_type != GLSL_TYPE_FLOAT16 || - (base_type != GLSL_TYPE_FLOAT && - base_type != GLSL_TYPE_DOUBLE), - "vload/vstore cannot do type conversion. " - "vload/vstore_half can only convert from half to other " - "floating-point types."); - } - struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS]; nir_ssa_def *ncomps[NIR_MAX_VEC_COMPONENTS]; @@ -652,6 +641,20 @@ _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned alignment = vec_aligned ? glsl_get_cl_alignment(type->type) : glsl_get_bit_size(type->type) / 8; + enum glsl_base_type ptr_base_type = + glsl_get_base_type(p->pointer->type->type); + if (base_type != ptr_base_type) { + vtn_fail_if(ptr_base_type != GLSL_TYPE_FLOAT16 || + (base_type != GLSL_TYPE_FLOAT && + base_type != GLSL_TYPE_DOUBLE), + "vload/vstore cannot do type conversion. " + "vload/vstore_half can only convert from half to other " + "floating-point types."); + + /* Above-computed alignment was for floats/doubles, not halves */ + alignment /= glsl_get_bit_size(type->type) / glsl_base_type_get_bit_size(ptr_base_type); + } + deref = nir_alignment_deref_cast(&b->nb, deref, alignment, 0); for (int i = 0; i < components; i++) {