mesa/src/intel/executor/examples/matmul.lua
Caio Oliveira 74859c19fb
Some checks are pending
macOS-CI / macOS-CI (dri) (push) Waiting to run
macOS-CI / macOS-CI (xlib) (push) Waiting to run
intel/executor: Add a matrix multiplication example
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37805>
2025-10-11 01:02:45 +00:00

337 lines
9.3 KiB
Lua

-- Copyright © 2025 Intel Corporation
-- SPDX-License-Identifier: MIT
local HELP_MESSAGE = [[
Matrix Multiplication using DPAS
Usage: executor matmul.lua FORMAT A_FILE B_FILE [C_FILE]
Perform matrix multiplication D = A * B + C using DPAS instruction.
If C is not provided, will be equivalent to all zeros.
Input files have values separated by spaces, with the rows
separated by newlines. Values are expected to be valid for
the format, e.g. if a matrix should contain UB (unsigned byte),
values should be between 0-255. Float-pointing data is
expected to be in "raw" form as hexadecimal values.
Matrices smaller than the maximum dimensions will be automatically
zero-padded to the required size for DPAS computation.
Maximum dimensions are limited by data format and hardware version.
Gfx20+
- HF/F: A max 8x16, B max 16x16, C/D max 8x16
- BF/F: A max 8x16, B max 16x16, C/D max 8x16
- UB/UD: A max 8x32, B max 32x16, C/D max 8x16
Gfx125
- HF/F: A max 8x16, B max 16x8, C/D max 8x8
- BF/F: A max 8x16, B max 16x8, C/D max 8x8
- UB/UD: A max 8x32, B max 32x8, C/D max 8x8
If `octave` is installed, it will be used to verify the results.
]]
-- TODO: Change this program to load matrix data from memory
-- instead of setting data as immediates in the shader code.
if not devinfo.has_dpas then
print("DPAS not supported on this platform.")
os.exit(1)
end
local verify_results = false
do
local handle = io.popen("which octave 2>/dev/null")
local octave_path = handle:read("*a")
handle:close()
verify_results = octave_path and #octave_path > 0
end
local gen = require("mod/gen")
local matrix = require("mod/matrix")
local fp = require("mod/fp")
local format_ab, format_cd, a_file, b_file, c_file
for i = 1, #arg do
if arg[i] == "--help" or arg[i] == "-h" then
print(HELP_MESSAGE)
os.exit(0)
elseif not format_ab then
local format = arg[i]:upper()
format_ab = format:sub(1, format:find("/") - 1)
format_cd = format:sub(format:find("/") + 1)
elseif not a_file then a_file = arg[i]
elseif not b_file then b_file = arg[i]
elseif not c_file then c_file = arg[i]
end
end
if not format_ab or not format_cd or not a_file or not b_file then
print("Usage: executor matmul.lua FORMAT A_FILE B_FILE [C_FILE]")
print("Use --help for more information")
os.exit(1)
end
if not ((format_ab == "HF" and format_cd == "F") or
(format_ab == "BF" and format_cd == "F") or
(format_ab == "UB" and format_cd == "UD")) then
print("Error: format must be 'HF/F', 'BF/F', or 'UB/UD', got '" .. format_ab .. "/" .. format_cd .. "'")
print("Use --help for more information")
os.exit(1)
end
local function read_matrix(m, filename, max_rows, max_cols)
local file = io.open(filename, "r")
if not file then
error("Failed to open file: " .. filename)
end
-- Read entire file content first
local raw_content = file:read("*all")
file:close()
local rows_data = {}
local cols = nil
-- Parse from the raw content string
for line in raw_content:gmatch("[^\r\n]+") do
local row = {}
for val in line:gmatch("%S+") do
local num = tonumber(val)
if not num then
error(string.format(
"Error reading matrix from '%s': invalid number in row %d: " .. val,
filename, #rows_data + 1))
end
table.insert(row, num)
end
if #row > 0 then
if cols == nil then
cols = #row
elseif #row ~= cols then
error(string.format(
"Error reading matrix from %s: inconsistent number of columns (%d) in row %d",
filename, #row, #rows_data + 1))
end
table.insert(rows_data, row)
end
end
local rows = #rows_data
if rows > max_rows then
error(string.format(
"Error reading matrix from %s: too many rows (%d), maximum is %d",
filename, rows, max_rows))
end
if cols > max_cols then
error(string.format(
"Error reading matrix from %s: too many columns (%d), maximum is %d",
filename, cols, max_cols))
end
for i = 1, rows do
for j = 1, cols do
-- Matrix indices are zero indexed.
m:set(i-1, j-1, rows_data[i][j])
end
end
return rows, cols, raw_content
end
local exec_size = devinfo.ver >= 20 and 16 or 8
local packing_factor
if format_ab == "HF" or format_ab == "BF" then
packing_factor = 2
elseif format_ab == "UB" then
packing_factor = 4
end
local max_a = { rows = 8, cols = packing_factor * 8 }
local max_b = { rows = packing_factor * 8, cols = exec_size }
local max_c = { rows = 8, cols = exec_size }
local a = matrix.new(max_a.rows, max_a.cols, 0)
local b = matrix.new(max_b.rows, max_b.cols, 0)
local c = matrix.new(max_c.rows, max_c.cols, 0)
local actual_a_rows, actual_a_cols, a_raw_content = read_matrix(a, a_file, max_a.rows, max_a.cols)
local actual_b_rows, actual_b_cols, b_raw_content = read_matrix(b, b_file, max_b.rows, max_b.cols)
if actual_a_cols ~= actual_b_rows then
error(string.format(
"Matrix dimension mismatch: A is %dx%d, B is %dx%d. A columns (%d) must equal B rows (%d)",
actual_a_rows, actual_a_cols, actual_b_rows, actual_b_cols, actual_a_cols, actual_b_rows))
end
local actual_c_rows, actual_c_cols, c_raw_content
if c_file then
actual_c_rows, actual_c_cols, c_raw_content = read_matrix(c, c_file, max_c.rows, max_c.cols)
if actual_a_rows ~= actual_c_rows or actual_b_cols ~= actual_c_cols then
error(string.format(
"Matrix dimension mismatch: A*B would be %dx%d, but C is %dx%d",
actual_a_rows, actual_b_cols, actual_c_rows, actual_c_cols))
end
else
-- C defaults to zeros with dimensions matching A*B result
actual_c_rows, actual_c_cols = actual_a_rows, actual_b_cols
c_raw_content = nil
end
local exec_size = c.cols
local encode = function(m, fmt)
local f = nil
if fmt == "HF" then f = fp.encode_f16
elseif fmt == "BF" then f = fp.encode_bf16
elseif fmt == "F" then f = fp.encode_f32
end
if f then
m:apply(f)
end
end
encode(a, format_ab)
encode(b, format_ab)
encode(c, format_cd)
local buf = execute {
src =
[[]]
.. gen.mov_grf(format_ab, 10, a:to_row_major())
.. gen.mov_grf(format_ab, 20, b:to_interleaved_row_major(packing_factor))
.. gen.mov_grf(format_cd, 30, c:to_row_major())
.. string.format([[
dpas.8x8(%d) r40<1>%s r30<1>%s r20<1>%s r10<1>%s {A@1 $1};
@syncnop
]], exec_size, format_cd, format_cd, format_ab, format_ab)
.. gen.write_grfs(40, 8)
.. [[
@eot
]],
}
local d = matrix.from_row_major_buffer(8, exec_size, buf)
local d_print_fmt = nil
if string.find(format_cd, "F") then
d_print_fmt = "%.6f"
local f = nil
if format_cd == "HF" then f = fp.decode_f16
elseif format_cd == "BF" then f = fp.decode_bf16
elseif format_cd == "F" then f = fp.decode_f32
else
error("unsupported float format")
end
d:apply(f)
end
-- Just consider the actual rows, same as C matrix.
d:print_submatrix(actual_c_rows, actual_c_cols, d_print_fmt)
--
-- VERIFICATION USING OCTAVE.
--
if verify_results then
local function save_matrix_to_temp(m, rows, cols)
local filename = os.tmpname()
local file = io.open(filename, "w")
for i = 0, rows - 1 do
local row = {}
for j = 0, cols - 1 do
table.insert(row, tostring(m:get(i, j)))
end
file:write(table.concat(row, " ") .. "\n")
end
file:close()
return filename
end
local function save_raw_content_to_temp(raw_content)
local filename = os.tmpname()
local file = io.open(filename, "w")
file:write(raw_content)
file:close()
return filename
end
-- Save A and B raw contents to temp files for Octave
local a_for_octave = save_raw_content_to_temp(a_raw_content)
local b_for_octave = save_raw_content_to_temp(b_raw_content)
local d_for_octave = save_matrix_to_temp(d, actual_c_rows, actual_c_cols)
local c_load, c_for_octave
if c_raw_content then
c_for_octave = save_raw_content_to_temp(c_raw_content)
c_load = string.format("C = single(dlmread('%s'));", c_for_octave)
else
c_load = string.format("C = single(zeros(%d, %d));", actual_c_rows, actual_c_cols)
end
local tolerance = (format_cd == "F") and "1e-3" or "1e-6"
-- TODO: Currently values are rounded to what fits in F (32-bit), but for
-- better results we should have a way to round them to precision based on
-- format, and handle HF and BF16. Octave doesn't support those types
-- natively. See https://github.com/higham/chop for Matlab version of this.
local octave_script = string.format([[
A = single(dlmread('%s'));
B = single(dlmread('%s'));
%s
D = single(dlmread('%s'));
D_expected = A * B + C;
%% Use relative tolerance for better comparison across different magnitudes.
max_val = max(abs(D_expected(:)));
tol = max_val * %s;
if all(all(abs(D_expected - D) < tol))
exit(0);
else
disp('MISMATCH!');
disp('Octave result:');
disp(D_expected);
disp('DPAS result:');
disp(D);
exit(1);
endif
]], a_for_octave, b_for_octave, c_load, d_for_octave, tolerance)
local exit_code = os.execute(
[[octave --quiet --no-gui --eval "]]
.. octave_script ..
[[" 2>&1]])
-- Clean up temporary files
os.remove(a_for_octave)
os.remove(b_for_octave)
if c_for_octave then
os.remove(c_for_octave)
end
os.remove(d_for_octave)
if exit_code then
print("\nMatches Octave.")
else
print("\nMISMATCH with Octave!")
os.exit(1)
end
else
print("\nNOTE: Install `octave` to verify the results.")
end