diff --git a/src/virtio/vulkan/vn_device.c b/src/virtio/vulkan/vn_device.c index 74356c29c4d..d90ba19ed49 100644 --- a/src/virtio/vulkan/vn_device.c +++ b/src/virtio/vulkan/vn_device.c @@ -447,6 +447,7 @@ vn_device_init(struct vn_device *dev, dev->instance = instance; dev->physical_device = physical_dev; + dev->device_mask = 1; dev->renderer = instance->renderer; dev->primary_ring = instance->ring.ring; @@ -455,6 +456,11 @@ vn_device_init(struct vn_device *dev, if (!create_info) return VK_ERROR_OUT_OF_HOST_MEMORY; + const VkDeviceGroupDeviceCreateInfo *group = vk_find_struct_const( + create_info->pNext, DEVICE_GROUP_DEVICE_CREATE_INFO); + if (group && group->physicalDeviceCount) + dev->device_mask = (1 << group->physicalDeviceCount) - 1; + result = vn_call_vkCreateDevice(dev->primary_ring, physical_dev_handle, create_info, NULL, &dev_handle); diff --git a/src/virtio/vulkan/vn_device.h b/src/virtio/vulkan/vn_device.h index 0f1b7e128ee..d80aa0c5a51 100644 --- a/src/virtio/vulkan/vn_device.h +++ b/src/virtio/vulkan/vn_device.h @@ -28,6 +28,7 @@ struct vn_device { struct vn_instance *instance; struct vn_physical_device *physical_device; + uint32_t device_mask; struct vn_renderer *renderer; struct vn_ring *primary_ring; diff --git a/src/virtio/vulkan/vn_queue.c b/src/virtio/vulkan/vn_queue.c index 60a92d81b78..81933fec6eb 100644 --- a/src/virtio/vulkan/vn_queue.c +++ b/src/virtio/vulkan/vn_queue.c @@ -319,6 +319,8 @@ static void vn_fix_device_group_cmd_count(struct vn_queue_submission *submit, uint32_t batch_index) { + struct vk_queue *queue_vk = vk_queue_from_handle(submit->queue_handle); + struct vn_device *dev = (void *)queue_vk->base.device; const VkSubmitInfo *src_batch = &submit->submit_batches[batch_index]; struct vn_submit_info_pnext_fix *pnext_fix = submit->temp.pnexts; VkBaseOutStructure *dst = @@ -343,9 +345,9 @@ vn_fix_device_group_cmd_count(struct vn_queue_submission *submit, sizeof(uint32_t) * orig_cmd_count); } - /* Set feedback cmd device masks to 0 */ + /* Set the group device mask. Unlike sync2, zero means skip. */ for (uint32_t i = orig_cmd_count; i < new_cmd_count; i++) { - submit->temp.dev_masks[i] = 0; + submit->temp.dev_masks[i] = dev->device_mask; } pnext_fix->group.commandBufferCount = new_cmd_count;