diff --git a/src/gallium/drivers/radeonsi/si_compute.c b/src/gallium/drivers/radeonsi/si_compute.c index c70112b843a..f69f2b4516f 100644 --- a/src/gallium/drivers/radeonsi/si_compute.c +++ b/src/gallium/drivers/radeonsi/si_compute.c @@ -380,7 +380,7 @@ static bool si_switch_compute_shader(struct si_context *sctx, struct si_compute simple_mtx_lock(&shader->selector->mutex); /* Update max_seen_compute_scratch_bytes_per_wave and compute_tmpring_size. */ - si_get_scratch_tmpring_size(sctx, config->scratch_bytes_per_wave, + si_get_scratch_tmpring_size(sctx, config->scratch_bytes_per_wave, true, &sctx->compute_tmpring_size); if (!si_setup_compute_scratch_buffer(sctx, shader)) diff --git a/src/gallium/drivers/radeonsi/si_pipe.c b/src/gallium/drivers/radeonsi/si_pipe.c index 882b8497b71..adafd57b03f 100644 --- a/src/gallium/drivers/radeonsi/si_pipe.c +++ b/src/gallium/drivers/radeonsi/si_pipe.c @@ -889,7 +889,7 @@ static struct pipe_context *si_create_context(struct pipe_screen *screen, unsign goto fail; /* Initialize compute_tmpring_size. */ - si_get_scratch_tmpring_size(sctx, 0, &sctx->compute_tmpring_size); + si_get_scratch_tmpring_size(sctx, 0, true, &sctx->compute_tmpring_size); return &sctx->b; fail: @@ -900,16 +900,22 @@ fail: void si_get_scratch_tmpring_size(struct si_context *sctx, unsigned bytes_per_wave, - unsigned *spi_tmpring_size) + bool is_compute, unsigned *spi_tmpring_size) { bytes_per_wave = ac_compute_scratch_wavesize(&sctx->screen->info, bytes_per_wave); - sctx->max_seen_scratch_bytes_per_wave = - MAX2(sctx->max_seen_scratch_bytes_per_wave, bytes_per_wave); + if (is_compute) { + sctx->max_seen_compute_scratch_bytes_per_wave = + MAX2(sctx->max_seen_compute_scratch_bytes_per_wave, bytes_per_wave); + } else { + sctx->max_seen_scratch_bytes_per_wave = + MAX2(sctx->max_seen_scratch_bytes_per_wave, bytes_per_wave); + } /* TODO: We could decrease WAVES to make the whole buffer fit into the infinity cache. */ ac_get_scratch_tmpring_size(&sctx->screen->info, sctx->screen->info.max_scratch_waves, - sctx->max_seen_scratch_bytes_per_wave, + is_compute ? sctx->max_seen_compute_scratch_bytes_per_wave + : sctx->max_seen_scratch_bytes_per_wave, spi_tmpring_size); } diff --git a/src/gallium/drivers/radeonsi/si_pipe.h b/src/gallium/drivers/radeonsi/si_pipe.h index a4501bda8b2..ea6a30b0bec 100644 --- a/src/gallium/drivers/radeonsi/si_pipe.h +++ b/src/gallium/drivers/radeonsi/si_pipe.h @@ -1623,7 +1623,7 @@ void si_init_aux_async_compute_ctx(struct si_screen *sscreen); struct si_context *si_get_aux_context(struct si_aux_context *ctx); void si_put_aux_context_flush(struct si_aux_context *ctx); void si_get_scratch_tmpring_size(struct si_context *sctx, unsigned bytes_per_wave, - unsigned *spi_tmpring_size); + bool is_compute, unsigned *spi_tmpring_size); void si_destroy_screen(struct pipe_screen *pscreen); /* si_perfcounters.c */ diff --git a/src/gallium/drivers/radeonsi/si_state_shaders.cpp b/src/gallium/drivers/radeonsi/si_state_shaders.cpp index 6459de467ea..50d66d364aa 100644 --- a/src/gallium/drivers/radeonsi/si_state_shaders.cpp +++ b/src/gallium/drivers/radeonsi/si_state_shaders.cpp @@ -4493,7 +4493,7 @@ static bool si_update_scratch_relocs(struct si_context *sctx) bool si_update_spi_tmpring_size(struct si_context *sctx, unsigned bytes) { unsigned spi_tmpring_size; - si_get_scratch_tmpring_size(sctx, bytes, &spi_tmpring_size); + si_get_scratch_tmpring_size(sctx, bytes, false, &spi_tmpring_size); unsigned scratch_needed_size = sctx->max_seen_scratch_bytes_per_wave * sctx->screen->info.max_scratch_waves;