mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-02-15 19:40:30 +01:00
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37805>
337 lines
9.3 KiB
Lua
337 lines
9.3 KiB
Lua
-- Copyright © 2025 Intel Corporation
|
|
-- SPDX-License-Identifier: MIT
|
|
|
|
local HELP_MESSAGE = [[
|
|
Matrix Multiplication using DPAS
|
|
|
|
Usage: executor matmul.lua FORMAT A_FILE B_FILE [C_FILE]
|
|
|
|
Perform matrix multiplication D = A * B + C using DPAS instruction.
|
|
If C is not provided, will be equivalent to all zeros.
|
|
|
|
Input files have values separated by spaces, with the rows
|
|
separated by newlines. Values are expected to be valid for
|
|
the format, e.g. if a matrix should contain UB (unsigned byte),
|
|
values should be between 0-255. Float-pointing data is
|
|
expected to be in "raw" form as hexadecimal values.
|
|
|
|
Matrices smaller than the maximum dimensions will be automatically
|
|
zero-padded to the required size for DPAS computation.
|
|
|
|
Maximum dimensions are limited by data format and hardware version.
|
|
|
|
Gfx20+
|
|
- HF/F: A max 8x16, B max 16x16, C/D max 8x16
|
|
- BF/F: A max 8x16, B max 16x16, C/D max 8x16
|
|
- UB/UD: A max 8x32, B max 32x16, C/D max 8x16
|
|
|
|
Gfx125
|
|
- HF/F: A max 8x16, B max 16x8, C/D max 8x8
|
|
- BF/F: A max 8x16, B max 16x8, C/D max 8x8
|
|
- UB/UD: A max 8x32, B max 32x8, C/D max 8x8
|
|
|
|
If `octave` is installed, it will be used to verify the results.
|
|
]]
|
|
|
|
-- TODO: Change this program to load matrix data from memory
|
|
-- instead of setting data as immediates in the shader code.
|
|
|
|
if not devinfo.has_dpas then
|
|
print("DPAS not supported on this platform.")
|
|
os.exit(1)
|
|
end
|
|
|
|
local verify_results = false
|
|
do
|
|
local handle = io.popen("which octave 2>/dev/null")
|
|
local octave_path = handle:read("*a")
|
|
handle:close()
|
|
verify_results = octave_path and #octave_path > 0
|
|
end
|
|
|
|
local gen = require("mod/gen")
|
|
local matrix = require("mod/matrix")
|
|
local fp = require("mod/fp")
|
|
|
|
local format_ab, format_cd, a_file, b_file, c_file
|
|
|
|
for i = 1, #arg do
|
|
if arg[i] == "--help" or arg[i] == "-h" then
|
|
print(HELP_MESSAGE)
|
|
os.exit(0)
|
|
elseif not format_ab then
|
|
local format = arg[i]:upper()
|
|
format_ab = format:sub(1, format:find("/") - 1)
|
|
format_cd = format:sub(format:find("/") + 1)
|
|
elseif not a_file then a_file = arg[i]
|
|
elseif not b_file then b_file = arg[i]
|
|
elseif not c_file then c_file = arg[i]
|
|
end
|
|
end
|
|
|
|
if not format_ab or not format_cd or not a_file or not b_file then
|
|
print("Usage: executor matmul.lua FORMAT A_FILE B_FILE [C_FILE]")
|
|
print("Use --help for more information")
|
|
os.exit(1)
|
|
end
|
|
|
|
if not ((format_ab == "HF" and format_cd == "F") or
|
|
(format_ab == "BF" and format_cd == "F") or
|
|
(format_ab == "UB" and format_cd == "UD")) then
|
|
print("Error: format must be 'HF/F', 'BF/F', or 'UB/UD', got '" .. format_ab .. "/" .. format_cd .. "'")
|
|
print("Use --help for more information")
|
|
os.exit(1)
|
|
end
|
|
|
|
local function read_matrix(m, filename, max_rows, max_cols)
|
|
local file = io.open(filename, "r")
|
|
if not file then
|
|
error("Failed to open file: " .. filename)
|
|
end
|
|
|
|
-- Read entire file content first
|
|
local raw_content = file:read("*all")
|
|
file:close()
|
|
|
|
local rows_data = {}
|
|
local cols = nil
|
|
|
|
-- Parse from the raw content string
|
|
for line in raw_content:gmatch("[^\r\n]+") do
|
|
local row = {}
|
|
for val in line:gmatch("%S+") do
|
|
local num = tonumber(val)
|
|
if not num then
|
|
error(string.format(
|
|
"Error reading matrix from '%s': invalid number in row %d: " .. val,
|
|
filename, #rows_data + 1))
|
|
end
|
|
|
|
table.insert(row, num)
|
|
end
|
|
|
|
if #row > 0 then
|
|
if cols == nil then
|
|
cols = #row
|
|
elseif #row ~= cols then
|
|
error(string.format(
|
|
"Error reading matrix from %s: inconsistent number of columns (%d) in row %d",
|
|
filename, #row, #rows_data + 1))
|
|
end
|
|
table.insert(rows_data, row)
|
|
end
|
|
end
|
|
|
|
local rows = #rows_data
|
|
|
|
if rows > max_rows then
|
|
error(string.format(
|
|
"Error reading matrix from %s: too many rows (%d), maximum is %d",
|
|
filename, rows, max_rows))
|
|
end
|
|
|
|
if cols > max_cols then
|
|
error(string.format(
|
|
"Error reading matrix from %s: too many columns (%d), maximum is %d",
|
|
filename, cols, max_cols))
|
|
end
|
|
|
|
for i = 1, rows do
|
|
for j = 1, cols do
|
|
-- Matrix indices are zero indexed.
|
|
m:set(i-1, j-1, rows_data[i][j])
|
|
end
|
|
end
|
|
|
|
return rows, cols, raw_content
|
|
end
|
|
|
|
local exec_size = devinfo.ver >= 20 and 16 or 8
|
|
local packing_factor
|
|
|
|
if format_ab == "HF" or format_ab == "BF" then
|
|
packing_factor = 2
|
|
elseif format_ab == "UB" then
|
|
packing_factor = 4
|
|
end
|
|
|
|
local max_a = { rows = 8, cols = packing_factor * 8 }
|
|
local max_b = { rows = packing_factor * 8, cols = exec_size }
|
|
local max_c = { rows = 8, cols = exec_size }
|
|
|
|
local a = matrix.new(max_a.rows, max_a.cols, 0)
|
|
local b = matrix.new(max_b.rows, max_b.cols, 0)
|
|
local c = matrix.new(max_c.rows, max_c.cols, 0)
|
|
|
|
local actual_a_rows, actual_a_cols, a_raw_content = read_matrix(a, a_file, max_a.rows, max_a.cols)
|
|
local actual_b_rows, actual_b_cols, b_raw_content = read_matrix(b, b_file, max_b.rows, max_b.cols)
|
|
|
|
if actual_a_cols ~= actual_b_rows then
|
|
error(string.format(
|
|
"Matrix dimension mismatch: A is %dx%d, B is %dx%d. A columns (%d) must equal B rows (%d)",
|
|
actual_a_rows, actual_a_cols, actual_b_rows, actual_b_cols, actual_a_cols, actual_b_rows))
|
|
end
|
|
|
|
local actual_c_rows, actual_c_cols, c_raw_content
|
|
if c_file then
|
|
actual_c_rows, actual_c_cols, c_raw_content = read_matrix(c, c_file, max_c.rows, max_c.cols)
|
|
if actual_a_rows ~= actual_c_rows or actual_b_cols ~= actual_c_cols then
|
|
error(string.format(
|
|
"Matrix dimension mismatch: A*B would be %dx%d, but C is %dx%d",
|
|
actual_a_rows, actual_b_cols, actual_c_rows, actual_c_cols))
|
|
end
|
|
else
|
|
-- C defaults to zeros with dimensions matching A*B result
|
|
actual_c_rows, actual_c_cols = actual_a_rows, actual_b_cols
|
|
c_raw_content = nil
|
|
end
|
|
|
|
local exec_size = c.cols
|
|
|
|
local encode = function(m, fmt)
|
|
local f = nil
|
|
if fmt == "HF" then f = fp.encode_f16
|
|
elseif fmt == "BF" then f = fp.encode_bf16
|
|
elseif fmt == "F" then f = fp.encode_f32
|
|
end
|
|
|
|
if f then
|
|
m:apply(f)
|
|
end
|
|
end
|
|
|
|
encode(a, format_ab)
|
|
encode(b, format_ab)
|
|
encode(c, format_cd)
|
|
|
|
local buf = execute {
|
|
src =
|
|
[[]]
|
|
.. gen.mov_grf(format_ab, 10, a:to_row_major())
|
|
.. gen.mov_grf(format_ab, 20, b:to_interleaved_row_major(packing_factor))
|
|
.. gen.mov_grf(format_cd, 30, c:to_row_major())
|
|
.. string.format([[
|
|
|
|
dpas.8x8(%d) r40<1>%s r30<1>%s r20<1>%s r10<1>%s {A@1 $1};
|
|
@syncnop
|
|
|
|
]], exec_size, format_cd, format_cd, format_ab, format_ab)
|
|
.. gen.write_grfs(40, 8)
|
|
.. [[
|
|
@eot
|
|
]],
|
|
}
|
|
|
|
local d = matrix.from_row_major_buffer(8, exec_size, buf)
|
|
|
|
local d_print_fmt = nil
|
|
if string.find(format_cd, "F") then
|
|
d_print_fmt = "%.6f"
|
|
|
|
local f = nil
|
|
if format_cd == "HF" then f = fp.decode_f16
|
|
elseif format_cd == "BF" then f = fp.decode_bf16
|
|
elseif format_cd == "F" then f = fp.decode_f32
|
|
else
|
|
error("unsupported float format")
|
|
end
|
|
|
|
d:apply(f)
|
|
end
|
|
|
|
-- Just consider the actual rows, same as C matrix.
|
|
d:print_submatrix(actual_c_rows, actual_c_cols, d_print_fmt)
|
|
|
|
--
|
|
-- VERIFICATION USING OCTAVE.
|
|
--
|
|
|
|
if verify_results then
|
|
local function save_matrix_to_temp(m, rows, cols)
|
|
local filename = os.tmpname()
|
|
local file = io.open(filename, "w")
|
|
for i = 0, rows - 1 do
|
|
local row = {}
|
|
for j = 0, cols - 1 do
|
|
table.insert(row, tostring(m:get(i, j)))
|
|
end
|
|
file:write(table.concat(row, " ") .. "\n")
|
|
end
|
|
file:close()
|
|
return filename
|
|
end
|
|
|
|
local function save_raw_content_to_temp(raw_content)
|
|
local filename = os.tmpname()
|
|
local file = io.open(filename, "w")
|
|
file:write(raw_content)
|
|
file:close()
|
|
return filename
|
|
end
|
|
|
|
-- Save A and B raw contents to temp files for Octave
|
|
local a_for_octave = save_raw_content_to_temp(a_raw_content)
|
|
local b_for_octave = save_raw_content_to_temp(b_raw_content)
|
|
local d_for_octave = save_matrix_to_temp(d, actual_c_rows, actual_c_cols)
|
|
|
|
local c_load, c_for_octave
|
|
if c_raw_content then
|
|
c_for_octave = save_raw_content_to_temp(c_raw_content)
|
|
c_load = string.format("C = single(dlmread('%s'));", c_for_octave)
|
|
else
|
|
c_load = string.format("C = single(zeros(%d, %d));", actual_c_rows, actual_c_cols)
|
|
end
|
|
|
|
local tolerance = (format_cd == "F") and "1e-3" or "1e-6"
|
|
|
|
-- TODO: Currently values are rounded to what fits in F (32-bit), but for
|
|
-- better results we should have a way to round them to precision based on
|
|
-- format, and handle HF and BF16. Octave doesn't support those types
|
|
-- natively. See https://github.com/higham/chop for Matlab version of this.
|
|
|
|
local octave_script = string.format([[
|
|
A = single(dlmread('%s'));
|
|
B = single(dlmread('%s'));
|
|
%s
|
|
D = single(dlmread('%s'));
|
|
|
|
D_expected = A * B + C;
|
|
|
|
%% Use relative tolerance for better comparison across different magnitudes.
|
|
max_val = max(abs(D_expected(:)));
|
|
tol = max_val * %s;
|
|
|
|
if all(all(abs(D_expected - D) < tol))
|
|
exit(0);
|
|
else
|
|
disp('MISMATCH!');
|
|
disp('Octave result:');
|
|
disp(D_expected);
|
|
disp('DPAS result:');
|
|
disp(D);
|
|
exit(1);
|
|
endif
|
|
]], a_for_octave, b_for_octave, c_load, d_for_octave, tolerance)
|
|
|
|
local exit_code = os.execute(
|
|
[[octave --quiet --no-gui --eval "]]
|
|
.. octave_script ..
|
|
[[" 2>&1]])
|
|
|
|
-- Clean up temporary files
|
|
os.remove(a_for_octave)
|
|
os.remove(b_for_octave)
|
|
if c_for_octave then
|
|
os.remove(c_for_octave)
|
|
end
|
|
os.remove(d_for_octave)
|
|
|
|
if exit_code then
|
|
print("\nMatches Octave.")
|
|
else
|
|
print("\nMISMATCH with Octave!")
|
|
os.exit(1)
|
|
end
|
|
else
|
|
print("\nNOTE: Install `octave` to verify the results.")
|
|
end
|