mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-01-26 10:00:22 +01:00
vtn/opencl: support fp16 builtins
If we can't find an appropriate builtin in the libclc library, we add our own wrapper at runtime that executes the op in fp32 space. Libclc has varying support for fp16 opcodes, and with a libclc prior to llvm-19 it does not work as well as with newer ones. Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io> Reviewed-by: Adam Jackson <ajax@redhat.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34053>
This commit is contained in:
parent
ca01635075
commit
aa5a981b83
1 changed files with 137 additions and 6 deletions
|
|
@ -56,6 +56,7 @@ static void
|
|||
vtn_opencl_mangle(const char *in_name,
|
||||
uint32_t const_mask,
|
||||
int ntypes, struct vtn_type **src_types,
|
||||
bool upcast_fp16,
|
||||
char **outstring)
|
||||
{
|
||||
char local_name[256] = "";
|
||||
|
|
@ -101,6 +102,11 @@ vtn_opencl_mangle(const char *in_name,
|
|||
}
|
||||
|
||||
const char *suffix = NULL;
|
||||
|
||||
enum glsl_base_type glsl_base_type = glsl_get_base_type(type);
|
||||
if (glsl_base_type == GLSL_TYPE_FLOAT16 && upcast_fp16)
|
||||
glsl_base_type = GLSL_TYPE_FLOAT;
|
||||
|
||||
switch (base_type) {
|
||||
case vtn_base_type_sampler: suffix = "11ocl_sampler"; break;
|
||||
case vtn_base_type_event: suffix = "9ocl_event"; break;
|
||||
|
|
@ -120,7 +126,6 @@ vtn_opencl_mangle(const char *in_name,
|
|||
[GLSL_TYPE_BOOL] = "b",
|
||||
[GLSL_TYPE_ERROR] = NULL,
|
||||
};
|
||||
enum glsl_base_type glsl_base_type = glsl_get_base_type(type);
|
||||
assert(glsl_base_type < ARRAY_SIZE(primitives) && primitives[glsl_base_type]);
|
||||
suffix = primitives[glsl_base_type];
|
||||
break;
|
||||
|
|
@ -134,13 +139,15 @@ vtn_opencl_mangle(const char *in_name,
|
|||
|
||||
static nir_function *mangle_and_find(struct vtn_builder *b,
|
||||
const char *name,
|
||||
uint8_t try_fp16_lowering,
|
||||
uint32_t const_mask,
|
||||
uint32_t num_srcs,
|
||||
struct vtn_type **src_types)
|
||||
{
|
||||
char *mname;
|
||||
char *fp16_name = NULL;
|
||||
|
||||
vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname);
|
||||
vtn_opencl_mangle(name, const_mask, num_srcs, src_types, false, &mname);
|
||||
|
||||
/* try and find in current shader first. */
|
||||
nir_function *found = nir_shader_get_function_for_name(b->shader, mname);
|
||||
|
|
@ -148,6 +155,14 @@ static nir_function *mangle_and_find(struct vtn_builder *b,
|
|||
/* if not found here find in clc shader and create a decl mirroring it */
|
||||
if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) {
|
||||
found = nir_shader_get_function_for_name(b->options->clc_shader, mname);
|
||||
|
||||
/* try upcasting fp16 */
|
||||
if (!found && try_fp16_lowering) {
|
||||
fp16_name = mname;
|
||||
vtn_opencl_mangle(name, const_mask, num_srcs, src_types, true, &mname);
|
||||
found = nir_shader_get_function_for_name(b->options->clc_shader, mname);
|
||||
}
|
||||
|
||||
if (found) {
|
||||
nir_function *decl = nir_function_create(b->shader, mname);
|
||||
decl->num_params = found->num_params;
|
||||
|
|
@ -157,16 +172,59 @@ static nir_function *mangle_and_find(struct vtn_builder *b,
|
|||
decl->params[i].name = ralloc_strdup(b->shader, found->params[i].name);
|
||||
}
|
||||
found = decl;
|
||||
|
||||
if (fp16_name) {
|
||||
nir_function *fp16_decl = nir_function_create(b->shader, fp16_name);
|
||||
found = fp16_decl;
|
||||
|
||||
fp16_decl->num_params = decl->num_params;
|
||||
fp16_decl->params = ralloc_array(b->shader, nir_parameter, fp16_decl->num_params);
|
||||
for (unsigned i = 0; i < fp16_decl->num_params; i++) {
|
||||
fp16_decl->params[i] = decl->params[i];
|
||||
if (try_fp16_lowering & (1 << i)) {
|
||||
fp16_decl->params[i].type = glsl_f16vec_type(glsl_get_vector_elements(fp16_decl->params[i].type));
|
||||
if (!fp16_decl->params[i].is_return) {
|
||||
assert(fp16_decl->params[i].bit_size == 32);
|
||||
fp16_decl->params[i].bit_size = 16;
|
||||
}
|
||||
}
|
||||
}
|
||||
fp16_decl->impl = nir_function_impl_create(fp16_decl);
|
||||
|
||||
nir_builder nb_saved = b->nb;
|
||||
b->nb = nir_builder_at(nir_before_impl(fp16_decl->impl));
|
||||
|
||||
nir_variable *ret_tmp = nir_local_variable_create(b->nb.impl, glsl_get_bare_type(decl->params[0].type), "return_tmp");
|
||||
nir_deref_instr *ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
|
||||
nir_call_instr *call = nir_call_instr_create(b->nb.shader, decl);
|
||||
|
||||
call->params[0] = nir_src_for_ssa(&ret_deref->def);
|
||||
for (unsigned i = 1; i < fp16_decl->num_params; i++) {
|
||||
nir_def *param = nir_load_param(&b->nb, i);
|
||||
if (try_fp16_lowering & (1 << i))
|
||||
param = nir_f2f32(&b->nb, param);
|
||||
call->params[i] = nir_src_for_ssa(param);
|
||||
}
|
||||
nir_builder_instr_insert(&b->nb, &call->instr);
|
||||
|
||||
nir_def *res_val = nir_f2f16(&b->nb, nir_load_deref(&b->nb, ret_deref));
|
||||
nir_def *ret = nir_load_param(&b->nb, 0);
|
||||
ret_deref = nir_build_deref_cast(&b->nb, ret, nir_var_function_temp, fp16_decl->params[1].type, 0);
|
||||
nir_store_deref(&b->nb, ret_deref, res_val, -1);
|
||||
b->nb = nb_saved;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!found)
|
||||
vtn_fail("Can't find clc function %s\n", mname);
|
||||
free(mname);
|
||||
free(fp16_name);
|
||||
return found;
|
||||
}
|
||||
|
||||
static bool call_mangled_function(struct vtn_builder *b,
|
||||
const char *name,
|
||||
uint8_t try_fp16_lowering,
|
||||
uint32_t const_mask,
|
||||
uint32_t num_srcs,
|
||||
struct vtn_type **src_types,
|
||||
|
|
@ -174,7 +232,7 @@ static bool call_mangled_function(struct vtn_builder *b,
|
|||
nir_def **srcs,
|
||||
nir_deref_instr **ret_deref_ptr)
|
||||
{
|
||||
nir_function *found = mangle_and_find(b, name, const_mask, num_srcs, src_types);
|
||||
nir_function *found = mangle_and_find(b, name, try_fp16_lowering, const_mask, num_srcs, src_types);
|
||||
if (!found)
|
||||
return false;
|
||||
|
||||
|
|
@ -422,6 +480,77 @@ get_signed_type(struct vtn_builder *b, struct vtn_type *t)
|
|||
glsl_get_vector_elements(t->type)));
|
||||
}
|
||||
|
||||
static uint8_t fp16_lowering_supported(enum OpenCLstd_Entrypoints opcode)
|
||||
{
|
||||
/* libclc has very limited fp16 compatibility */
|
||||
switch (opcode) {
|
||||
case OpenCLstd_Acos:
|
||||
case OpenCLstd_Acosh:
|
||||
case OpenCLstd_Acospi:
|
||||
case OpenCLstd_Asin:
|
||||
case OpenCLstd_Asinh:
|
||||
case OpenCLstd_Asinpi:
|
||||
case OpenCLstd_Atan:
|
||||
case OpenCLstd_Atan2:
|
||||
case OpenCLstd_Atanh:
|
||||
case OpenCLstd_Atanpi:
|
||||
case OpenCLstd_Atan2pi:
|
||||
case OpenCLstd_Cbrt:
|
||||
case OpenCLstd_Cos:
|
||||
case OpenCLstd_Cosh:
|
||||
case OpenCLstd_Cospi:
|
||||
case OpenCLstd_Degrees:
|
||||
case OpenCLstd_Distance:
|
||||
case OpenCLstd_Erf:
|
||||
case OpenCLstd_Erfc:
|
||||
case OpenCLstd_Exp:
|
||||
case OpenCLstd_Exp2:
|
||||
case OpenCLstd_Exp10:
|
||||
case OpenCLstd_Expm1:
|
||||
case OpenCLstd_Fma:
|
||||
case OpenCLstd_Fmod:
|
||||
case OpenCLstd_Fract:
|
||||
case OpenCLstd_Hypot:
|
||||
case OpenCLstd_Ilogb:
|
||||
case OpenCLstd_Length:
|
||||
case OpenCLstd_Lgamma:
|
||||
case OpenCLstd_Log:
|
||||
case OpenCLstd_Log2:
|
||||
case OpenCLstd_Log10:
|
||||
case OpenCLstd_Log1p:
|
||||
case OpenCLstd_Logb:
|
||||
case OpenCLstd_Modf:
|
||||
case OpenCLstd_Pow:
|
||||
case OpenCLstd_Powr:
|
||||
case OpenCLstd_Radians:
|
||||
case OpenCLstd_Remainder:
|
||||
case OpenCLstd_Smoothstep:
|
||||
case OpenCLstd_Step:
|
||||
case OpenCLstd_Sin:
|
||||
case OpenCLstd_Sinh:
|
||||
case OpenCLstd_Sinpi:
|
||||
case OpenCLstd_Tan:
|
||||
case OpenCLstd_Tanh:
|
||||
case OpenCLstd_Tanpi:
|
||||
case OpenCLstd_Tgamma:
|
||||
return 0xff;
|
||||
case OpenCLstd_Frexp:
|
||||
case OpenCLstd_Ldexp:
|
||||
case OpenCLstd_Lgamma_r:
|
||||
case OpenCLstd_Pown:
|
||||
case OpenCLstd_Remquo:
|
||||
case OpenCLstd_Rootn:
|
||||
/* second argument shouldn't be touched at all */
|
||||
return 0xff ^ (1 << 2);
|
||||
/* the second argument is a pointer to a float
|
||||
* a new enough libclc supports it though
|
||||
*/
|
||||
case OpenCLstd_Sincos:
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
|
||||
int num_srcs,
|
||||
|
|
@ -463,7 +592,8 @@ handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
|
|||
|
||||
nir_deref_instr *ret_deref = NULL;
|
||||
|
||||
if (!call_mangled_function(b, name, 0, num_srcs, src_types,
|
||||
uint8_t try_fp16_lowering = fp16_lowering_supported(opcode);
|
||||
if (!call_mangled_function(b, name, try_fp16_lowering, 0, num_srcs, src_types,
|
||||
dest_type, srcs, &ret_deref))
|
||||
return NULL;
|
||||
|
||||
|
|
@ -567,7 +697,8 @@ handle_special(struct vtn_builder *b, uint32_t opcode,
|
|||
return nir_ldexp(nb, srcs[0], srcs[1]);
|
||||
case OpenCLstd_Fma:
|
||||
/* FIXME: the software implementation only supports fp32 for now. */
|
||||
if (nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32)
|
||||
if ((nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32) ||
|
||||
(nb->shader->options->lower_ffma16 && srcs[0]->bit_size == 16))
|
||||
break;
|
||||
return nir_ffma(nb, srcs[0], srcs[1], srcs[2]);
|
||||
case OpenCLstd_Rotate:
|
||||
|
|
@ -608,7 +739,7 @@ handle_core(struct vtn_builder *b, uint32_t opcode,
|
|||
src_types[i]->storage_class);
|
||||
}
|
||||
}
|
||||
if (!call_mangled_function(b, "async_work_group_strided_copy", (1 << 1), num_srcs, src_types, dest_type, srcs, &ret_deref))
|
||||
if (!call_mangled_function(b, "async_work_group_strided_copy", false, (1 << 1), num_srcs, src_types, dest_type, srcs, &ret_deref))
|
||||
return NULL;
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue