diff --git a/src/amd/common/ac_shader_util.c b/src/amd/common/ac_shader_util.c index cba42f176e6..fc95f4ac008 100644 --- a/src/amd/common/ac_shader_util.c +++ b/src/amd/common/ac_shader_util.c @@ -916,7 +916,7 @@ void ac_set_reg_cu_en(void *cs, unsigned reg_offset, uint32_t value, uint32_t cl } /* Return the register value and tune bytes_per_wave to increase scratch performance. */ -void ac_get_scratch_tmpring_size(const struct radeon_info *info, bool compute, +void ac_get_scratch_tmpring_size(const struct radeon_info *info, unsigned bytes_per_wave, unsigned *max_seen_bytes_per_wave, uint32_t *tmpring_size) { @@ -949,8 +949,8 @@ void ac_get_scratch_tmpring_size(const struct radeon_info *info, bool compute, *max_seen_bytes_per_wave = MAX2(*max_seen_bytes_per_wave, bytes_per_wave); unsigned max_scratch_waves = info->max_scratch_waves; - if (info->gfx_level >= GFX11 && !compute) - max_scratch_waves /= info->num_se; /* WAVES is per SE for SPI_TMPRING_SIZE. */ + if (info->gfx_level >= GFX11) + max_scratch_waves /= info->num_se; /* WAVES is per SE */ /* TODO: We could decrease WAVES to make the whole buffer fit into the infinity cache. */ *tmpring_size = S_0286E8_WAVES(max_scratch_waves) | diff --git a/src/amd/common/ac_shader_util.h b/src/amd/common/ac_shader_util.h index 87996654d93..6552bb8fa14 100644 --- a/src/amd/common/ac_shader_util.h +++ b/src/amd/common/ac_shader_util.h @@ -166,7 +166,7 @@ void ac_set_reg_cu_en(void *cs, unsigned reg_offset, uint32_t value, uint32_t cl unsigned value_shift, const struct radeon_info *info, void set_sh_reg(void*, unsigned, uint32_t)); -void ac_get_scratch_tmpring_size(const struct radeon_info *info, bool compute, +void ac_get_scratch_tmpring_size(const struct radeon_info *info, unsigned bytes_per_wave, unsigned *max_seen_bytes_per_wave, uint32_t *tmpring_size); diff --git a/src/gallium/drivers/radeonsi/si_compute.c b/src/gallium/drivers/radeonsi/si_compute.c index 5b0b5a721f8..4263ad624ab 100644 --- a/src/gallium/drivers/radeonsi/si_compute.c +++ b/src/gallium/drivers/radeonsi/si_compute.c @@ -566,7 +566,7 @@ static bool si_switch_compute_shader(struct si_context *sctx, struct si_compute } unsigned tmpring_size; - ac_get_scratch_tmpring_size(&sctx->screen->info, true, + ac_get_scratch_tmpring_size(&sctx->screen->info, config->scratch_bytes_per_wave, &sctx->max_seen_compute_scratch_bytes_per_wave, &tmpring_size); diff --git a/src/gallium/drivers/radeonsi/si_state_shaders.cpp b/src/gallium/drivers/radeonsi/si_state_shaders.cpp index c2c09185f8a..16012344abd 100644 --- a/src/gallium/drivers/radeonsi/si_state_shaders.cpp +++ b/src/gallium/drivers/radeonsi/si_state_shaders.cpp @@ -4054,7 +4054,7 @@ static bool si_update_scratch_relocs(struct si_context *sctx) bool si_update_spi_tmpring_size(struct si_context *sctx, unsigned bytes) { unsigned spi_tmpring_size; - ac_get_scratch_tmpring_size(&sctx->screen->info, false, bytes, + ac_get_scratch_tmpring_size(&sctx->screen->info, bytes, &sctx->max_seen_scratch_bytes_per_wave, &spi_tmpring_size); unsigned scratch_needed_size = sctx->max_seen_scratch_bytes_per_wave *