etnaviv/ml: Add support for Logistic

Add a TP job that uses a lookup table to implement a piecewise linear
approximation of the logistic function.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34629>
Tomeu Vizoso, 2025-04-11 09:29:37 +02:00 (committed by Marge Bot)
parent 9c6cab0458
commit a8a2ce1d74
4 changed files with 204 additions and 3 deletions
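For orientation before the diffs: a piecewise linear LUT approximation samples the target function at a fixed number of knots and interpolates between them. The sketch below is a minimal host-side illustration in plain C, assuming uniform knot spacing; the real table built by create_log_lut_bo() further down is instead indexed by the bit pattern of a custom 21-bit float (see fp21()), and every name in the sketch is illustrative, not driver API.

#include <math.h>

#define LUT_LEN 1024

static float lut[LUT_LEN];

/* Sample the exact logistic function at LUT_LEN uniformly spaced knots. */
static void
build_logistic_lut(float lo, float hi)
{
   for (int i = 0; i < LUT_LEN; i++) {
      float x = lo + (hi - lo) * i / (LUT_LEN - 1);
      lut[i] = 1.0f / (1.0f + expf(-x));
   }
}

/* Piecewise linear evaluation: clamp outside the covered range, otherwise
 * interpolate between the two nearest knots. */
static float
logistic_pwl(float x, float lo, float hi)
{
   float t = (x - lo) / (hi - lo) * (LUT_LEN - 1);
   if (t <= 0.0f)
      return lut[0];
   if (t >= LUT_LEN - 1)
      return lut[LUT_LEN - 1];
   int i = (int)t;
   return lut[i] + (lut[i + 1] - lut[i]) * (t - i);
}

Because the logistic saturates to 0 and 1 within a few units of zero, most of the table budget can go to the region around zero while the tails are clamped, which is what the driver's table construction below does.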


@@ -480,6 +480,13 @@ lower_operations(struct etna_ml_subgraph *subgraph,
list_addtail(&operation->link, etna_operations);
break;
}
case PIPE_ML_OPERATION_TYPE_LOGISTIC: {
etna_ml_lower_logistic(subgraph, poperation, operation);
operation->input_tensors[0] = input_tensors[0];
operation->output_tensors[0] = poperation->output_tensors[0]->index;
list_addtail(&operation->link, etna_operations);
break;
}
default:
unreachable("Unsupported ML operation type");
}
@@ -600,6 +607,7 @@ count_tensors(const struct pipe_ml_operation *poperations,
case PIPE_ML_OPERATION_TYPE_RESHAPE:
case PIPE_ML_OPERATION_TYPE_RELU:
case PIPE_ML_OPERATION_TYPE_ABSOLUTE:
case PIPE_ML_OPERATION_TYPE_LOGISTIC:
break;
default:
unreachable("Unsupported ML operation type");
@@ -687,6 +695,7 @@ etna_ml_operation_supported(struct pipe_context *pcontext,
case PIPE_ML_OPERATION_TYPE_RESHAPE:
case PIPE_ML_OPERATION_TYPE_RELU:
case PIPE_ML_OPERATION_TYPE_ABSOLUTE:
case PIPE_ML_OPERATION_TYPE_LOGISTIC:
supported = true;
break;
default:


@@ -28,6 +28,7 @@ enum etna_ml_tp_type {
ETNA_ML_TP_PAD,
ETNA_ML_TP_RELU,
ETNA_ML_TP_ABSOLUTE,
ETNA_ML_TP_LOGISTIC,
};
enum etna_ml_tensor_layout {


@@ -726,7 +726,7 @@ split_pwl_lut(struct etna_ml_subgraph *subgraph, const struct etna_operation *op
unsigned tp_core, unsigned tp_cores_used, unsigned *in_dims, unsigned *out_dims)
{
unsigned remaining_in_size;
-unsigned dim_to_split = 2;
+const unsigned dim_to_split = 2;
remaining_in_size = in_dims[dim_to_split];
@@ -854,7 +854,10 @@ create_pwl_lut_config(struct etna_ml_subgraph *subgraph, const struct etna_opera
uint32_t scale;
-scale = fui(operation->input_scale / operation->output_scale);
+if (operation->tp_type == ETNA_ML_TP_LOGISTIC)
+   scale = fui(1.0 / operation->output_scale);
+else
+   scale = fui(operation->input_scale / operation->output_scale);
/* This should compensate for some loss of precision */
if ((scale >> 7 & 1) != 0 && (scale & 0x17f) != 0) {
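The logistic branch above drops input_scale from the factor because the logistic LUT folds it into the table itself: create_log_lut_bo() below multiplies every sample point by operation->input_scale before evaluating the function, and the table then emits real values in [0, 1], leaving only the 1 / output_scale requantization to apply here. A worked example with assumed values (not from the commit): for a TFLite-style uint8 output with output_scale = 1/256, the programmed factor is fui(256.0), and a table output of 0.5 requantizes to round(0.5 * 256) = 128.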
@@ -1084,13 +1087,43 @@ etna_ml_lower_absolute(struct etna_ml_subgraph *subgraph,
operation->output_channels;
}

void
etna_ml_lower_logistic(struct etna_ml_subgraph *subgraph,
const struct pipe_ml_operation *poperation,
struct etna_operation *operation)
{
operation->type = ETNA_JOB_TYPE_TP;
operation->tp_type = ETNA_ML_TP_LOGISTIC;
operation->stride = 1;
operation->input_count = 1;
operation->input_width = poperation->input_tensors[0]->dims[1];
operation->input_height = poperation->input_tensors[0]->dims[2];
operation->input_channels = poperation->input_tensors[0]->dims[3];
operation->input_tensor_sizes[0] = operation->input_width *
operation->input_height *
operation->input_channels;
operation->input_zero_point = etna_tensor_zero_point(poperation->input_tensors[0]);
operation->input_scale = poperation->input_tensors[0]->scale;
operation->output_count = 1;
operation->output_width = poperation->output_tensors[0]->dims[1];
operation->output_height = poperation->output_tensors[0]->dims[2];
operation->output_channels = poperation->output_tensors[0]->dims[3];
operation->output_zero_point = etna_tensor_zero_point(poperation->output_tensors[0]);
operation->output_scale = poperation->output_tensors[0]->scale;
operation->output_tensor_sizes[0] = operation->output_width *
operation->output_height *
operation->output_channels;
}
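End to end, the lowered operation has to reproduce the usual quantized-logistic semantics. Here is a plain C reference for exposition (a sketch only: the helper name, uint8 tensors and clamping range are assumptions matching TFLite-style quantization, not code from this commit):

#include <math.h>
#include <stdint.h>

/* Reference (non-driver) semantics: dequantize, apply the logistic
 * function, requantize. The TP job replaces the expf() part with the
 * lookup table built below. */
static uint8_t
logistic_q8_ref(uint8_t q_in, int in_zp, float in_scale,
                int out_zp, float out_scale)
{
   float x = (q_in - in_zp) * in_scale;       /* dequantize input */
   float y = 1.0f / (1.0f + expf(-x));        /* logistic, in [0, 1] */
   float q = roundf(y / out_scale) + out_zp;  /* requantize output */
   return (uint8_t)(q < 0.0f ? 0.0f : (q > 255.0f ? 255.0f : q));
}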
static struct etna_bo *
create_relu_lut_bo(struct etna_ml_subgraph *subgraph,
const struct etna_operation *operation)
{
struct pipe_context *context = subgraph->base.context;
struct etna_context *ctx = etna_context(context);
-unsigned lut_length = 1024;
+const unsigned lut_length = 1024;
struct etna_bo *pwl_lut = etna_bo_new(ctx->screen->dev,
lut_length * sizeof(uint32_t),
DRM_ETNA_GEM_CACHE_WC);
@@ -1150,6 +1183,146 @@ create_abs_lut_bo(struct etna_ml_subgraph *subgraph,
return pwl_lut;
}
/* Based on code from src/util/half_float.c in Mesa itself. */
static uint32_t
fp21(float val)
{
const fi_type fi = {val};
const int flt_m = fi.i & 0x7fffff;
const int flt_e = (fi.i >> 23) & 0xff;
const int flt_s = (fi.i >> 31) & 0x1;
int s, e, m = 0;
/* sign bit */
s = flt_s;
/* handle special cases */
if ((flt_e == 0) && (flt_m == 0)) {
/* zero */
/* m = 0; - already set */
e = 0;
}
else if ((flt_e == 0) && (flt_m != 0)) {
/* denorm -- denorm float maps to 0 fp21 */
/* m = 0; - already set */
e = 0;
}
else if ((flt_e == 0xff) && (flt_m == 0)) {
/* infinity */
/* m = 0; - already set */
e = 15;
}
else if ((flt_e == 0xff) && (flt_m != 0)) {
/* Retain the top bits of a NaN to make sure that the quiet/signaling
* status stays the same.
*/
m = flt_m >> 13;
if (!m)
m = 1;
e = 15;
}
else {
/* regular number */
const int new_exp = flt_e - 127;
if (new_exp < -14) {
/* The float32 lies in the range (0.0, min_normal21) and is rounded
* to a nearby fp21 value. The result will be either zero, subnormal,
* or normal.
*/
e = 0;
m = _mesa_lroundevenf((1 << 24) * fabsf(fi.f));
}
else if (new_exp > 15) {
/* saturate to the largest finite fp21 value */
e = 0x1e;
m = 0x7fff;
}
else {
/* The float32 lies in the range
* [min_normal21, max_normal21 + max_step21)
* and is rounded to a nearby fp21 value. The result will be
* either normal or infinite.
*/
e = new_exp + 15;
m = _mesa_lroundevenf(flt_m / (float) (1 << 8));
}
}
return (s << 20) | (e << 15) | m;
}
/* Based on code from src/util/half_float.c in Mesa itself. */
static float
fp32(uint32_t val)
{
union fi f32;
/* Exponent / Mantissa */
f32.ui = (val & 0xfffff) << 8;
/* Sign */
f32.ui |= (uint32_t)(val & 0x100000) << 11;
f32.ui += 0x38000000;
return f32.f;
}
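Taken together, the shifts in fp21() and fp32() describe a custom 21-bit float: 1 sign bit (bit 20), a 5-bit exponent biased by 15 (bits 15-19) and a 15-bit mantissa (bits 0-14). A standalone round-trip check against that reading (not part of the commit, just placed next to the two helpers above):

#include <assert.h>

/* 0.5f = 2^-1: biased exponent is -1 + 15 = 14, mantissa is 0, so the
 * encoding is 14 << 15 = 0x70000. */
static void
check_fp21_roundtrip(void)
{
   assert(fp21(0.5f) == 0x70000);
   /* Decoding: (0x70000 << 8) + 0x38000000 = 0x3f000000, i.e. 0.5f. */
   assert(fp32(0x70000) == 0.5f);
}

Note that fp32() assumes a normal, non-zero code: the unconditional bias addition would decode an all-zero bit pattern as 2^-15 rather than 0.0, which is harmless here since create_log_lut_bo() below only feeds it codes from 0x8000 upwards.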
static struct etna_bo *
create_log_lut_bo(struct etna_ml_subgraph *subgraph,
const struct etna_operation *operation)
{
struct pipe_context *context = subgraph->base.context;
struct etna_context *ctx = etna_context(context);
unsigned lut_table_len = 1024;
struct etna_bo *pwl_lut = etna_bo_new(ctx->screen->dev,
lut_table_len * sizeof(uint32_t),
DRM_ETNA_GEM_CACHE_WC);
etna_bo_cpu_prep(pwl_lut, DRM_ETNA_PREP_WRITE);
uint32_t *map = etna_bo_map(pwl_lut);
unsigned pos = 0;
uint32_t initial_val = fp21(0.5);
for (int i = 0; i < 16; i++)
map[pos++] = initial_val;
for (int cur_val = 0x8000; cur_val < 0xf8000; cur_val += 0x800) {
float val = fp32(cur_val);
val *= operation->input_scale;
val *= -1.0;
val = expf(val);
val = 1.0f / (val + 1.0f);
map[pos++] = fp21(val);
}
uint32_t middle_val = fp21(1.0);
for (int i = 0; i < 16; i++)
map[pos++] = middle_val;
uint32_t middle_2_val = fp21(0.5);
for (int i = 0; i < 17; i++)
map[pos++] = middle_2_val;
for (int cur_val = 0x108800; cur_val < 0x1f8000; cur_val += 0x800) {
float val = fp32(cur_val);
val *= operation->input_scale;
val *= -1.0;
val = expf(val);
val = 1.0f / (val + 1.0f);
map[pos++] = fp21(val);
}
uint32_t final_val = fp21(0.0);
for (int i = 0; i < 16; i++)
map[pos++] = final_val;
etna_bo_cpu_fini(pwl_lut);
return pwl_lut;
}
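The table layout follows from the 0x800 stride: it drops the low 11 of the 21 code bits, so the 1024 entries are indexed by the top 10 bits of the fp21 input (sign, exponent, and the 4 high mantissa bits). The first 512 entries cover positive inputs, the last 512 (sign bit 0x100000 set) cover negative ones, and the segment lengths account for every entry:

16 entries of 0.5 for inputs near +0
+ (0xf8000 - 0x8000) / 0x800 = 480 sampled entries for positive inputs
+ 16 entries of 1.0 for saturated large positive inputs
+ 17 entries of 0.5 entering the negative half
+ (0x1f8000 - 0x108800) / 0x800 = 479 sampled entries for negative inputs
+ 16 entries of 0.0 for saturated large negative inputs
= 1024 entries, matching lut_table_len.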
void
etna_ml_compile_operation_tp(struct etna_ml_subgraph *subgraph,
const struct etna_operation *operation,
@@ -1226,6 +1399,19 @@ etna_ml_compile_operation_tp(struct etna_ml_subgraph *subgraph,
}
break;
}
case ETNA_ML_TP_LOGISTIC: {
unsigned tp_cores_used = etna_ml_get_core_info(ctx)->tp_core_count;
if (operation->input_width < 6)
tp_cores_used = 1;
ML_DBG("logistic: input_width %d tp_cores_used %d\n", operation->input_width, tp_cores_used);
instruction->pwl_lut = create_log_lut_bo(subgraph, operation);
for (unsigned i = 0; i < tp_cores_used; i++) {
instruction->configs[i] = create_pwl_lut_config(subgraph, operation, i, tp_cores_used, instruction->pwl_lut);
}
break;
}
}
instruction->type = ETNA_JOB_TYPE_TP;
instruction->tp_type = operation->tp_type;


@@ -36,6 +36,11 @@ etna_ml_lower_absolute(struct etna_ml_subgraph *subgraph,
const struct pipe_ml_operation *pad,
struct etna_operation *operation);
void
etna_ml_lower_logistic(struct etna_ml_subgraph *subgraph,
const struct pipe_ml_operation *poperation,
struct etna_operation *operation);
void
etna_ml_compile_operation_tp(struct etna_ml_subgraph *subgraph,
const struct etna_operation *operation,