diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index f6bc3796333..0a2f49c4103 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -56,6 +56,7 @@ static void vtn_opencl_mangle(const char *in_name, uint32_t const_mask, int ntypes, struct vtn_type **src_types, + bool upcast_fp16, char **outstring) { char local_name[256] = ""; @@ -101,6 +102,11 @@ vtn_opencl_mangle(const char *in_name, } const char *suffix = NULL; + + enum glsl_base_type glsl_base_type = glsl_get_base_type(type); + if (glsl_base_type == GLSL_TYPE_FLOAT16 && upcast_fp16) + glsl_base_type = GLSL_TYPE_FLOAT; + switch (base_type) { case vtn_base_type_sampler: suffix = "11ocl_sampler"; break; case vtn_base_type_event: suffix = "9ocl_event"; break; @@ -120,7 +126,6 @@ vtn_opencl_mangle(const char *in_name, [GLSL_TYPE_BOOL] = "b", [GLSL_TYPE_ERROR] = NULL, }; - enum glsl_base_type glsl_base_type = glsl_get_base_type(type); assert(glsl_base_type < ARRAY_SIZE(primitives) && primitives[glsl_base_type]); suffix = primitives[glsl_base_type]; break; @@ -134,13 +139,15 @@ vtn_opencl_mangle(const char *in_name, static nir_function *mangle_and_find(struct vtn_builder *b, const char *name, + uint8_t try_fp16_lowering, uint32_t const_mask, uint32_t num_srcs, struct vtn_type **src_types) { char *mname; + char *fp16_name = NULL; - vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname); + vtn_opencl_mangle(name, const_mask, num_srcs, src_types, false, &mname); /* try and find in current shader first. */ nir_function *found = nir_shader_get_function_for_name(b->shader, mname); @@ -148,6 +155,14 @@ static nir_function *mangle_and_find(struct vtn_builder *b, /* if not found here find in clc shader and create a decl mirroring it */ if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) { found = nir_shader_get_function_for_name(b->options->clc_shader, mname); + + /* try upcasting fp16 */ + if (!found && try_fp16_lowering) { + fp16_name = mname; + vtn_opencl_mangle(name, const_mask, num_srcs, src_types, true, &mname); + found = nir_shader_get_function_for_name(b->options->clc_shader, mname); + } + if (found) { nir_function *decl = nir_function_create(b->shader, mname); decl->num_params = found->num_params; @@ -157,16 +172,59 @@ static nir_function *mangle_and_find(struct vtn_builder *b, decl->params[i].name = ralloc_strdup(b->shader, found->params[i].name); } found = decl; + + if (fp16_name) { + nir_function *fp16_decl = nir_function_create(b->shader, fp16_name); + found = fp16_decl; + + fp16_decl->num_params = decl->num_params; + fp16_decl->params = ralloc_array(b->shader, nir_parameter, fp16_decl->num_params); + for (unsigned i = 0; i < fp16_decl->num_params; i++) { + fp16_decl->params[i] = decl->params[i]; + if (try_fp16_lowering & (1 << i)) { + fp16_decl->params[i].type = glsl_f16vec_type(glsl_get_vector_elements(fp16_decl->params[i].type)); + if (!fp16_decl->params[i].is_return) { + assert(fp16_decl->params[i].bit_size == 32); + fp16_decl->params[i].bit_size = 16; + } + } + } + fp16_decl->impl = nir_function_impl_create(fp16_decl); + + nir_builder nb_saved = b->nb; + b->nb = nir_builder_at(nir_before_impl(fp16_decl->impl)); + + nir_variable *ret_tmp = nir_local_variable_create(b->nb.impl, glsl_get_bare_type(decl->params[0].type), "return_tmp"); + nir_deref_instr *ret_deref = nir_build_deref_var(&b->nb, ret_tmp); + nir_call_instr *call = nir_call_instr_create(b->nb.shader, decl); + + call->params[0] = nir_src_for_ssa(&ret_deref->def); + for (unsigned i = 1; i < fp16_decl->num_params; i++) { + nir_def *param = nir_load_param(&b->nb, i); + if (try_fp16_lowering & (1 << i)) + param = nir_f2f32(&b->nb, param); + call->params[i] = nir_src_for_ssa(param); + } + nir_builder_instr_insert(&b->nb, &call->instr); + + nir_def *res_val = nir_f2f16(&b->nb, nir_load_deref(&b->nb, ret_deref)); + nir_def *ret = nir_load_param(&b->nb, 0); + ret_deref = nir_build_deref_cast(&b->nb, ret, nir_var_function_temp, fp16_decl->params[1].type, 0); + nir_store_deref(&b->nb, ret_deref, res_val, -1); + b->nb = nb_saved; + } } } if (!found) vtn_fail("Can't find clc function %s\n", mname); free(mname); + free(fp16_name); return found; } static bool call_mangled_function(struct vtn_builder *b, const char *name, + uint8_t try_fp16_lowering, uint32_t const_mask, uint32_t num_srcs, struct vtn_type **src_types, @@ -174,7 +232,7 @@ static bool call_mangled_function(struct vtn_builder *b, nir_def **srcs, nir_deref_instr **ret_deref_ptr) { - nir_function *found = mangle_and_find(b, name, const_mask, num_srcs, src_types); + nir_function *found = mangle_and_find(b, name, try_fp16_lowering, const_mask, num_srcs, src_types); if (!found) return false; @@ -422,6 +480,77 @@ get_signed_type(struct vtn_builder *b, struct vtn_type *t) glsl_get_vector_elements(t->type))); } +static uint8_t fp16_lowering_supported(enum OpenCLstd_Entrypoints opcode) +{ + /* libclc has very limited fp16 compatibility */ + switch (opcode) { + case OpenCLstd_Acos: + case OpenCLstd_Acosh: + case OpenCLstd_Acospi: + case OpenCLstd_Asin: + case OpenCLstd_Asinh: + case OpenCLstd_Asinpi: + case OpenCLstd_Atan: + case OpenCLstd_Atan2: + case OpenCLstd_Atanh: + case OpenCLstd_Atanpi: + case OpenCLstd_Atan2pi: + case OpenCLstd_Cbrt: + case OpenCLstd_Cos: + case OpenCLstd_Cosh: + case OpenCLstd_Cospi: + case OpenCLstd_Degrees: + case OpenCLstd_Distance: + case OpenCLstd_Erf: + case OpenCLstd_Erfc: + case OpenCLstd_Exp: + case OpenCLstd_Exp2: + case OpenCLstd_Exp10: + case OpenCLstd_Expm1: + case OpenCLstd_Fma: + case OpenCLstd_Fmod: + case OpenCLstd_Fract: + case OpenCLstd_Hypot: + case OpenCLstd_Ilogb: + case OpenCLstd_Length: + case OpenCLstd_Lgamma: + case OpenCLstd_Log: + case OpenCLstd_Log2: + case OpenCLstd_Log10: + case OpenCLstd_Log1p: + case OpenCLstd_Logb: + case OpenCLstd_Modf: + case OpenCLstd_Pow: + case OpenCLstd_Powr: + case OpenCLstd_Radians: + case OpenCLstd_Remainder: + case OpenCLstd_Smoothstep: + case OpenCLstd_Step: + case OpenCLstd_Sin: + case OpenCLstd_Sinh: + case OpenCLstd_Sinpi: + case OpenCLstd_Tan: + case OpenCLstd_Tanh: + case OpenCLstd_Tanpi: + case OpenCLstd_Tgamma: + return 0xff; + case OpenCLstd_Frexp: + case OpenCLstd_Ldexp: + case OpenCLstd_Lgamma_r: + case OpenCLstd_Pown: + case OpenCLstd_Remquo: + case OpenCLstd_Rootn: + /* second argument shouldn't be touched at all */ + return 0xff ^ (1 << 2); + /* the second argument is a pointer to a float + * a new enough libclc supports it though + */ + case OpenCLstd_Sincos: + default: + return 0; + } +} + static nir_def * handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, int num_srcs, @@ -463,7 +592,8 @@ handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, nir_deref_instr *ret_deref = NULL; - if (!call_mangled_function(b, name, 0, num_srcs, src_types, + uint8_t try_fp16_lowering = fp16_lowering_supported(opcode); + if (!call_mangled_function(b, name, try_fp16_lowering, 0, num_srcs, src_types, dest_type, srcs, &ret_deref)) return NULL; @@ -567,7 +697,8 @@ handle_special(struct vtn_builder *b, uint32_t opcode, return nir_ldexp(nb, srcs[0], srcs[1]); case OpenCLstd_Fma: /* FIXME: the software implementation only supports fp32 for now. */ - if (nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32) + if ((nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32) || + (nb->shader->options->lower_ffma16 && srcs[0]->bit_size == 16)) break; return nir_ffma(nb, srcs[0], srcs[1], srcs[2]); case OpenCLstd_Rotate: @@ -608,7 +739,7 @@ handle_core(struct vtn_builder *b, uint32_t opcode, src_types[i]->storage_class); } } - if (!call_mangled_function(b, "async_work_group_strided_copy", (1 << 1), num_srcs, src_types, dest_type, srcs, &ret_deref)) + if (!call_mangled_function(b, "async_work_group_strided_copy", false, (1 << 1), num_srcs, src_types, dest_type, srcs, &ret_deref)) return NULL; break; }