util: Generate correct format conversions for half floats.

This commit is contained in:
Michal Krol 2010-04-01 13:56:03 +02:00
parent b7bca4b28c
commit 943408533d

View file

@@ -231,66 +231,86 @@ def conversion_expr(src_channel, dst_channel, dst_native_type, value, clamp=True
if src_channel == dst_channel:
return value
if src_channel.type == FLOAT and dst_channel.type == FLOAT:
if src_channel.size == 64:
value = '(float)%s' % (value)
elif src_channel.size == 16:
value = 'util_half_to_float(%s)' % (value)
src_type = src_channel.type
src_size = src_channel.size
src_norm = src_channel.norm
if dst_channel.size == 16:
value = 'util_float_to_half(%s)' % (value)
elif dst_channel.size == 64:
value = '(double)%s' % (value)
return value
# Promote half to float
if src_type == FLOAT and src_size == 16:
value = 'util_half_to_float(%s)' % value
src_size = 32
if clamp:
value = clamp_expr(src_channel, dst_channel, dst_native_type, value)
if dst_channel.type != FLOAT or src_type != FLOAT:
value = clamp_expr(src_channel, dst_channel, dst_native_type, value)
if dst_channel.type == FLOAT:
if src_channel.norm:
one = get_one(src_channel)
if src_channel.size <= 23:
scale = '(1.0f/0x%x)' % one
else:
# bigger than single precision mantissa, use double
scale = '(1.0/0x%x)' % one
value = '(%s * %s)' % (value, scale)
return '(%s)%s' % (dst_native_type, value)
if src_channel.type == FLOAT:
if dst_channel.norm:
dst_one = get_one(dst_channel)
if dst_channel.size <= 23:
scale = '0x%x' % dst_one
else:
# bigger than single precision mantissa, use double
scale = '(double)0x%x' % dst_one
value = '(%s * %s)' % (value, scale)
return '(%s)%s' % (dst_native_type, value)
if src_channel.type in (SIGNED, UNSIGNED) and dst_channel.type in (SIGNED, UNSIGNED):
if not src_channel.norm and not dst_channel.norm:
if src_type in (SIGNED, UNSIGNED) and dst_channel.type in (SIGNED, UNSIGNED):
if not src_norm and not dst_channel.norm:
# neither is normalized -- just cast
return '(%s)%s' % (dst_native_type, value)
src_one = get_one(src_channel)
dst_one = get_one(dst_channel)
if src_one > dst_one and src_channel.norm and dst_channel.norm:
if src_one > dst_one and src_norm and dst_channel.norm:
# We can just bitshift
src_shift = get_one_shift(src_channel)
dst_shift = get_one_shift(dst_channel)
value = '(%s >> %s)' % (value, src_shift - dst_shift)
else:
# We need to rescale using an intermediate type big enough to hold the multiplication of both
tmp_native_type = intermediate_native_type(src_channel.size + dst_channel.size, src_channel.sign and dst_channel.sign)
tmp_native_type = intermediate_native_type(src_size + dst_channel.size, src_channel.sign and dst_channel.sign)
value = '((%s)%s)' % (tmp_native_type, value)
value = '(%s * 0x%x / 0x%x)' % (value, dst_one, src_one)
value = '(%s)%s' % (dst_native_type, value)
return value
assert False
# Promote to either float or double
if src_type != FLOAT:
if src_norm:
one = get_one(src_channel)
if src_size <= 23:
value = '(%s * (1.0f/0x%x))' % (value, one)
if dst_channel.size <= 32:
value = '(float)%s' % value
src_size = 32
else:
# bigger than single precision mantissa, use double
value = '(%s * (1.0/0x%x))' % (value, one)
src_size = 64
src_norm = False
else:
if src_size <= 23 or dst_channel.size <= 32:
value = '(float)%s' % value
src_size = 32
else:
# bigger than single precision mantissa, use double
value = '(double)%s' % value
src_size = 64
src_type = FLOAT
# Convert double or float to non-float
if dst_channel.type != FLOAT:
if dst_channel.norm:
dst_one = get_one(dst_channel)
if dst_channel.size <= 23:
value = '(%s * 0x%x)' % (value, dst_one)
else:
# bigger than single precision mantissa, use double
value = '(%s * (double)0x%x)' % (value, dst_one)
value = '(%s)%s' % (dst_native_type, value)
else:
# Cast double to float when converting to either half or float
if dst_channel.size <= 32 and src_size > 32:
value = '(float)%s' % value
src_size = 32
if dst_channel.size == 16:
value = 'util_float_to_half(%s)' % value
elif dst_channel.size == 64 and src_size < 64:
value = '(double)%s' % value
return value
def generate_unpack_kernel(format, dst_channel, dst_native_type):