diff --git a/src/util/os_memory_fd.c b/src/util/os_memory_fd.c index 86c80c9e8a1..eaf31bd45c4 100644 --- a/src/util/os_memory_fd.c +++ b/src/util/os_memory_fd.c @@ -66,6 +66,23 @@ get_driver_id_sha1_hash(uint8_t sha1[SHA1_DIGEST_LENGTH], const char *driver_id) _mesa_sha1_final(&sha1_ctx, sha1); } +static bool +get_fd_header(int fd, struct memory_header *header, char const *driver_id) +{ + lseek(fd, 0, SEEK_SET); + const int bytes_read = read(fd, header, sizeof(*header)); + if (bytes_read != sizeof(*header)) + return false; + + // Check the uuid we put after the sizes in order to verify that the fd + // is a memfd that we created and not some random fd. + uint8_t sha1[SHA1_DIGEST_LENGTH]; + get_driver_id_sha1_hash(sha1, driver_id); + + assert(SHA1_DIGEST_LENGTH >= UUID_SIZE); + return memcmp(header->uuid, sha1, UUID_SIZE) == 0; +} + /** * Imports memory from a file descriptor */ @@ -75,21 +92,9 @@ os_import_memory_fd(int fd, void **ptr, uint64_t *size, char const *driver_id) void *mapped_ptr; struct memory_header header; - lseek(fd, 0, SEEK_SET); - int bytes_read = read(fd, &header, sizeof(header)); - if(bytes_read != sizeof(header)) + if (!get_fd_header(fd, &header, driver_id)) return false; - // Check the uuid we put after the sizes in order to verify that the fd - // is a memfd that we created and not some random fd. - uint8_t sha1[SHA1_DIGEST_LENGTH]; - get_driver_id_sha1_hash(sha1, driver_id); - - assert(SHA1_DIGEST_LENGTH >= UUID_SIZE); - if (memcmp(header.uuid, sha1, UUID_SIZE)) { - return false; - } - mapped_ptr = mmap(NULL, header.size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); if (mapped_ptr == MAP_FAILED) { return false;