gallivm: add initial support for 16-bit float builder.

This is an initial patch that is needed for OpenCL and Vulkan
support for proper 16-bit floats.

This doesn't enable the cap bit yet

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/11816>
This commit is contained in:
Dave Airlie 2021-07-12 07:24:54 +10:00 committed by Marge Bot
parent 0418b98569
commit c396067366
9 changed files with 53 additions and 10 deletions

View file

@ -42,7 +42,7 @@
#include "lp_bld_type.h"
#include "lp_bld_const.h"
#include "lp_bld_init.h"
#include "lp_bld_limits.h"
unsigned
lp_mantissa(struct lp_type type)
@ -256,7 +256,7 @@ lp_build_one(struct gallivm_state *gallivm, struct lp_type type)
elem_type = lp_build_elem_type(gallivm, type);
if(type.floating && type.width == 16)
if(!lp_has_fp16() && type.floating && type.width == 16)
elems[0] = LLVMConstInt(elem_type, _mesa_float_to_half(1.0f), 0);
else if(type.floating)
elems[0] = LLVMConstReal(elem_type, 1.0);
@ -303,7 +303,7 @@ lp_build_const_elem(struct gallivm_state *gallivm,
LLVMTypeRef elem_type = lp_build_elem_type(gallivm, type);
LLVMValueRef elem;
if(type.floating && type.width == 16) {
if (!lp_has_fp16() && type.floating && type.width == 16) {
elem = LLVMConstInt(elem_type, _mesa_float_to_half((float)val), 0);
} else if(type.floating) {
elem = LLVMConstReal(elem_type, val);

View file

@ -120,6 +120,8 @@ lp_build_half_to_float(struct gallivm_state *gallivm,
else {
intrinsic = "llvm.x86.vcvtph2ps.256";
}
src = LLVMBuildBitCast(builder, src,
LLVMVectorType(LLVMInt16TypeInContext(gallivm->context), 8), "");
return lp_build_intrinsic_unary(builder, intrinsic,
lp_build_vec_type(gallivm, f32_type), src);
} else {
@ -193,6 +195,7 @@ lp_build_float_to_half(struct gallivm_state *gallivm,
if (length == 4) {
result = lp_build_extract_range(gallivm, result, 0, 4);
}
result = LLVMBuildBitCast(builder, result, lp_build_vec_type(gallivm, lp_type_float_vec(16, 16 * length)), "");
}
else {

View file

@ -945,6 +945,8 @@ lp_build_insert_soa_chan(struct lp_build_context *bld,
if (type.floating) {
if (chan_desc.size == 16) {
chan = lp_build_float_to_half(gallivm, rgba);
chan = LLVMBuildBitCast(builder, chan,
lp_build_vec_type(gallivm, lp_type_int_vec(16, 16 * type.length)), "");
chan = LLVMBuildZExt(builder, chan, bld->int_vec_type, "");
if (start)
chan = LLVMBuildShl(builder, chan,

View file

@ -86,6 +86,10 @@ lp_format_intrinsic(char *name,
c = 'f';
width = 64;
break;
case LLVMHalfTypeKind:
c = 'f';
width = 16;
break;
default:
unreachable("unexpected LLVMTypeKind");
}

View file

@ -34,7 +34,7 @@
#include "pipe/p_state.h"
#include "pipe/p_defines.h"
#include "util/u_cpu_detect.h"
/*
* TGSI translation limits.
@ -85,6 +85,11 @@
*/
#define LP_MAX_TGSI_LOOP_ITERATIONS 65535
static inline bool
lp_has_fp16(void)
{
return util_get_cpu_caps()->has_f16c;
}
/**
* Some of these limits are actually infinite (i.e., only limited by available

View file

@ -48,7 +48,7 @@ static LLVMValueRef cast_type(struct lp_build_nir_context *bld_base, LLVMValueRe
case nir_type_float:
switch (bit_size) {
case 16:
return LLVMBuildBitCast(builder, val, LLVMVectorType(LLVMHalfTypeInContext(bld_base->base.gallivm->context), bld_base->base.type.length), "");
return LLVMBuildBitCast(builder, val, bld_base->half_bld.vec_type, "");
case 32:
return LLVMBuildBitCast(builder, val, bld_base->base.vec_type, "");
case 64:
@ -241,6 +241,8 @@ static LLVMValueRef fcmp32(struct lp_build_nir_context *bld_base,
result = lp_build_cmp(flt_bld, compare, src[0], src[1]);
if (src_bit_size == 64)
result = LLVMBuildTrunc(builder, result, bld_base->int_bld.vec_type, "");
else if (src_bit_size == 16)
result = LLVMBuildSExt(builder, result, bld_base->int_bld.vec_type, "");
return result;
}
@ -307,6 +309,9 @@ static LLVMValueRef emit_b2f(struct lp_build_nir_context *bld_base,
"");
result = LLVMBuildBitCast(builder, result, bld_base->base.vec_type, "");
switch (bitsize) {
case 16:
result = LLVMBuildFPTrunc(builder, result, bld_base->half_bld.vec_type, "");
break;
case 32:
break;
case 64:
@ -545,7 +550,7 @@ do_quantize_to_f16(struct lp_build_nir_context *bld_base,
LLVMBuilderRef builder = gallivm->builder;
LLVMValueRef result, cond, cond2, temp;
result = LLVMBuildFPTrunc(builder, src, LLVMVectorType(LLVMHalfTypeInContext(gallivm->context), bld_base->base.type.length), "");
result = LLVMBuildFPTrunc(builder, src, bld_base->half_bld.vec_type, "");
result = LLVMBuildFPExt(builder, result, bld_base->base.vec_type, "");
temp = lp_build_abs(get_flt_bld(bld_base, 32), result);
@ -568,6 +573,9 @@ static LLVMValueRef do_alu_action(struct lp_build_nir_context *bld_base,
LLVMValueRef result;
switch (instr->op) {
case nir_op_b2f16:
result = emit_b2f(bld_base, src[0], 16);
break;
case nir_op_b2f32:
result = emit_b2f(bld_base, src[0], 32);
break;
@ -610,7 +618,7 @@ static LLVMValueRef do_alu_action(struct lp_build_nir_context *bld_base,
src[0] = LLVMBuildFPTrunc(builder, src[0],
bld_base->base.vec_type, "");
result = LLVMBuildFPTrunc(builder, src[0],
LLVMVectorType(LLVMHalfTypeInContext(gallivm->context), bld_base->base.type.length), "");
bld_base->half_bld.vec_type, "");
break;
case nir_op_f2f32:
if (src_bit_size[0] < 32)
@ -810,6 +818,10 @@ static LLVMValueRef do_alu_action(struct lp_build_nir_context *bld_base,
case nir_op_i2b32:
result = int_to_bool32(bld_base, src_bit_size[0], false, src[0]);
break;
case nir_op_i2f16:
result = LLVMBuildSIToFP(builder, src[0],
bld_base->half_bld.vec_type, "");
break;
case nir_op_i2f32:
result = lp_build_int_to_float(&bld_base->base, src[0]);
break;
@ -950,6 +962,10 @@ static LLVMValueRef do_alu_action(struct lp_build_nir_context *bld_base,
result = LLVMBuildBitCast(builder, tmp, bld_base->uint64_bld.vec_type, "");
break;
}
case nir_op_u2f16:
result = LLVMBuildUIToFP(builder, src[0],
bld_base->half_bld.vec_type, "");
break;
case nir_op_u2f32:
result = LLVMBuildUIToFP(builder, src[0], bld_base->base.vec_type, "");
break;

View file

@ -49,6 +49,7 @@ struct lp_build_nir_context
struct lp_build_context int8_bld;
struct lp_build_context uint16_bld;
struct lp_build_context int16_bld;
struct lp_build_context half_bld;
struct lp_build_context dbl_bld;
struct lp_build_context uint64_bld;
struct lp_build_context int64_bld;
@ -289,6 +290,8 @@ static inline struct lp_build_context *get_flt_bld(struct lp_build_nir_context *
switch (op_bit_size) {
case 64:
return &bld_base->dbl_bld;
case 16:
return &bld_base->half_bld;
default:
case 32:
return &bld_base->base;

View file

@ -2355,6 +2355,12 @@ void lp_build_nir_soa(struct gallivm_state *gallivm,
dbl_type.width *= 2;
lp_build_context_init(&bld.bld_base.dbl_bld, gallivm, dbl_type);
}
{
struct lp_type half_type;
half_type = type;
half_type.width /= 2;
lp_build_context_init(&bld.bld_base.half_bld, gallivm, half_type);
}
{
struct lp_type uint64_type;
uint64_type = lp_uint_type(type);

View file

@ -31,7 +31,7 @@
#include "lp_bld_type.h"
#include "lp_bld_const.h"
#include "lp_bld_init.h"
#include "lp_bld_limits.h"
LLVMTypeRef
lp_build_elem_type(struct gallivm_state *gallivm, struct lp_type type)
@ -39,7 +39,7 @@ lp_build_elem_type(struct gallivm_state *gallivm, struct lp_type type)
if (type.floating) {
switch(type.width) {
case 16:
return LLVMIntTypeInContext(gallivm->context, 16);
return lp_has_fp16() ? LLVMHalfTypeInContext(gallivm->context) : LLVMInt16TypeInContext(gallivm->context);
break;
case 32:
return LLVMFloatTypeInContext(gallivm->context);
@ -89,7 +89,7 @@ lp_check_elem_type(struct lp_type type, LLVMTypeRef elem_type)
if (type.floating) {
switch(type.width) {
case 16:
if(elem_kind != LLVMIntegerTypeKind)
if(elem_kind != (lp_has_fp16() ? LLVMHalfTypeKind : LLVMIntegerTypeKind))
return FALSE;
break;
case 32:
@ -259,6 +259,8 @@ lp_sizeof_llvm_type(LLVMTypeRef t)
return 8 * sizeof(float);
case LLVMDoubleTypeKind:
return 8 * sizeof(double);
case LLVMHalfTypeKind:
return 8 * sizeof(uint16_t);
case LLVMVectorTypeKind:
{
LLVMTypeRef elem = LLVMGetElementType(t);
@ -291,6 +293,8 @@ lp_typekind_name(LLVMTypeKind t)
return "LLVMVoidTypeKind";
case LLVMFloatTypeKind:
return "LLVMFloatTypeKind";
case LLVMHalfTypeKind:
return "LLVMHalfTypeKind";
case LLVMDoubleTypeKind:
return "LLVMDoubleTypeKind";
case LLVMX86_FP80TypeKind: