diff --git a/src/intel/executor/examples/matmul.lua b/src/intel/executor/examples/matmul.lua new file mode 100644 index 00000000000..92910630bfa --- /dev/null +++ b/src/intel/executor/examples/matmul.lua @@ -0,0 +1,337 @@ +-- Copyright © 2025 Intel Corporation +-- SPDX-License-Identifier: MIT + +local HELP_MESSAGE = [[ +Matrix Multiplication using DPAS + +Usage: executor matmul.lua FORMAT A_FILE B_FILE [C_FILE] + +Perform matrix multiplication D = A * B + C using DPAS instruction. +If C is not provided, will be equivalent to all zeros. + +Input files have values separated by spaces, with the rows +separated by newlines. Values are expected to be valid for +the format, e.g. if a matrix should contain UB (unsigned byte), +values should be between 0-255. Float-pointing data is +expected to be in "raw" form as hexadecimal values. + +Matrices smaller than the maximum dimensions will be automatically +zero-padded to the required size for DPAS computation. + +Maximum dimensions are limited by data format and hardware version. + + Gfx20+ + - HF/F: A max 8x16, B max 16x16, C/D max 8x16 + - BF/F: A max 8x16, B max 16x16, C/D max 8x16 + - UB/UD: A max 8x32, B max 32x16, C/D max 8x16 + + Gfx125 + - HF/F: A max 8x16, B max 16x8, C/D max 8x8 + - BF/F: A max 8x16, B max 16x8, C/D max 8x8 + - UB/UD: A max 8x32, B max 32x8, C/D max 8x8 + +If `octave` is installed, it will be used to verify the results. +]] + +-- TODO: Change this program to load matrix data from memory +-- instead of setting data as immediates in the shader code. + +if not devinfo.has_dpas then + print("DPAS not supported on this platform.") + os.exit(1) +end + +local verify_results = false +do + local handle = io.popen("which octave 2>/dev/null") + local octave_path = handle:read("*a") + handle:close() + verify_results = octave_path and #octave_path > 0 +end + +local gen = require("mod/gen") +local matrix = require("mod/matrix") +local fp = require("mod/fp") + +local format_ab, format_cd, a_file, b_file, c_file + +for i = 1, #arg do + if arg[i] == "--help" or arg[i] == "-h" then + print(HELP_MESSAGE) + os.exit(0) + elseif not format_ab then + local format = arg[i]:upper() + format_ab = format:sub(1, format:find("/") - 1) + format_cd = format:sub(format:find("/") + 1) + elseif not a_file then a_file = arg[i] + elseif not b_file then b_file = arg[i] + elseif not c_file then c_file = arg[i] + end +end + +if not format_ab or not format_cd or not a_file or not b_file then + print("Usage: executor matmul.lua FORMAT A_FILE B_FILE [C_FILE]") + print("Use --help for more information") + os.exit(1) +end + +if not ((format_ab == "HF" and format_cd == "F") or + (format_ab == "BF" and format_cd == "F") or + (format_ab == "UB" and format_cd == "UD")) then + print("Error: format must be 'HF/F', 'BF/F', or 'UB/UD', got '" .. format_ab .. "/" .. format_cd .. "'") + print("Use --help for more information") + os.exit(1) +end + +local function read_matrix(m, filename, max_rows, max_cols) + local file = io.open(filename, "r") + if not file then + error("Failed to open file: " .. filename) + end + + -- Read entire file content first + local raw_content = file:read("*all") + file:close() + + local rows_data = {} + local cols = nil + + -- Parse from the raw content string + for line in raw_content:gmatch("[^\r\n]+") do + local row = {} + for val in line:gmatch("%S+") do + local num = tonumber(val) + if not num then + error(string.format( + "Error reading matrix from '%s': invalid number in row %d: " .. val, + filename, #rows_data + 1)) + end + + table.insert(row, num) + end + + if #row > 0 then + if cols == nil then + cols = #row + elseif #row ~= cols then + error(string.format( + "Error reading matrix from %s: inconsistent number of columns (%d) in row %d", + filename, #row, #rows_data + 1)) + end + table.insert(rows_data, row) + end + end + + local rows = #rows_data + + if rows > max_rows then + error(string.format( + "Error reading matrix from %s: too many rows (%d), maximum is %d", + filename, rows, max_rows)) + end + + if cols > max_cols then + error(string.format( + "Error reading matrix from %s: too many columns (%d), maximum is %d", + filename, cols, max_cols)) + end + + for i = 1, rows do + for j = 1, cols do + -- Matrix indices are zero indexed. + m:set(i-1, j-1, rows_data[i][j]) + end + end + + return rows, cols, raw_content +end + +local exec_size = devinfo.ver >= 20 and 16 or 8 +local packing_factor + +if format_ab == "HF" or format_ab == "BF" then + packing_factor = 2 +elseif format_ab == "UB" then + packing_factor = 4 +end + +local max_a = { rows = 8, cols = packing_factor * 8 } +local max_b = { rows = packing_factor * 8, cols = exec_size } +local max_c = { rows = 8, cols = exec_size } + +local a = matrix.new(max_a.rows, max_a.cols, 0) +local b = matrix.new(max_b.rows, max_b.cols, 0) +local c = matrix.new(max_c.rows, max_c.cols, 0) + +local actual_a_rows, actual_a_cols, a_raw_content = read_matrix(a, a_file, max_a.rows, max_a.cols) +local actual_b_rows, actual_b_cols, b_raw_content = read_matrix(b, b_file, max_b.rows, max_b.cols) + +if actual_a_cols ~= actual_b_rows then + error(string.format( + "Matrix dimension mismatch: A is %dx%d, B is %dx%d. A columns (%d) must equal B rows (%d)", + actual_a_rows, actual_a_cols, actual_b_rows, actual_b_cols, actual_a_cols, actual_b_rows)) +end + +local actual_c_rows, actual_c_cols, c_raw_content +if c_file then + actual_c_rows, actual_c_cols, c_raw_content = read_matrix(c, c_file, max_c.rows, max_c.cols) + if actual_a_rows ~= actual_c_rows or actual_b_cols ~= actual_c_cols then + error(string.format( + "Matrix dimension mismatch: A*B would be %dx%d, but C is %dx%d", + actual_a_rows, actual_b_cols, actual_c_rows, actual_c_cols)) + end +else + -- C defaults to zeros with dimensions matching A*B result + actual_c_rows, actual_c_cols = actual_a_rows, actual_b_cols + c_raw_content = nil +end + +local exec_size = c.cols + +local encode = function(m, fmt) + local f = nil + if fmt == "HF" then f = fp.encode_f16 + elseif fmt == "BF" then f = fp.encode_bf16 + elseif fmt == "F" then f = fp.encode_f32 + end + + if f then + m:apply(f) + end +end + +encode(a, format_ab) +encode(b, format_ab) +encode(c, format_cd) + +local buf = execute { + src = + [[]] + .. gen.mov_grf(format_ab, 10, a:to_row_major()) + .. gen.mov_grf(format_ab, 20, b:to_interleaved_row_major(packing_factor)) + .. gen.mov_grf(format_cd, 30, c:to_row_major()) + .. string.format([[ + + dpas.8x8(%d) r40<1>%s r30<1>%s r20<1>%s r10<1>%s {A@1 $1}; + @syncnop + + ]], exec_size, format_cd, format_cd, format_ab, format_ab) + .. gen.write_grfs(40, 8) + .. [[ + @eot + ]], +} + +local d = matrix.from_row_major_buffer(8, exec_size, buf) + +local d_print_fmt = nil +if string.find(format_cd, "F") then + d_print_fmt = "%.6f" + + local f = nil + if format_cd == "HF" then f = fp.decode_f16 + elseif format_cd == "BF" then f = fp.decode_bf16 + elseif format_cd == "F" then f = fp.decode_f32 + else + error("unsupported float format") + end + + d:apply(f) +end + +-- Just consider the actual rows, same as C matrix. +d:print_submatrix(actual_c_rows, actual_c_cols, d_print_fmt) + +-- +-- VERIFICATION USING OCTAVE. +-- + +if verify_results then + local function save_matrix_to_temp(m, rows, cols) + local filename = os.tmpname() + local file = io.open(filename, "w") + for i = 0, rows - 1 do + local row = {} + for j = 0, cols - 1 do + table.insert(row, tostring(m:get(i, j))) + end + file:write(table.concat(row, " ") .. "\n") + end + file:close() + return filename + end + + local function save_raw_content_to_temp(raw_content) + local filename = os.tmpname() + local file = io.open(filename, "w") + file:write(raw_content) + file:close() + return filename + end + + -- Save A and B raw contents to temp files for Octave + local a_for_octave = save_raw_content_to_temp(a_raw_content) + local b_for_octave = save_raw_content_to_temp(b_raw_content) + local d_for_octave = save_matrix_to_temp(d, actual_c_rows, actual_c_cols) + + local c_load, c_for_octave + if c_raw_content then + c_for_octave = save_raw_content_to_temp(c_raw_content) + c_load = string.format("C = single(dlmread('%s'));", c_for_octave) + else + c_load = string.format("C = single(zeros(%d, %d));", actual_c_rows, actual_c_cols) + end + + local tolerance = (format_cd == "F") and "1e-3" or "1e-6" + + -- TODO: Currently values are rounded to what fits in F (32-bit), but for + -- better results we should have a way to round them to precision based on + -- format, and handle HF and BF16. Octave doesn't support those types + -- natively. See https://github.com/higham/chop for Matlab version of this. + + local octave_script = string.format([[ +A = single(dlmread('%s')); +B = single(dlmread('%s')); +%s +D = single(dlmread('%s')); + +D_expected = A * B + C; + +%% Use relative tolerance for better comparison across different magnitudes. +max_val = max(abs(D_expected(:))); +tol = max_val * %s; + +if all(all(abs(D_expected - D) < tol)) + exit(0); +else + disp('MISMATCH!'); + disp('Octave result:'); + disp(D_expected); + disp('DPAS result:'); + disp(D); + exit(1); +endif +]], a_for_octave, b_for_octave, c_load, d_for_octave, tolerance) + + local exit_code = os.execute( + [[octave --quiet --no-gui --eval "]] + .. octave_script .. + [[" 2>&1]]) + + -- Clean up temporary files + os.remove(a_for_octave) + os.remove(b_for_octave) + if c_for_octave then + os.remove(c_for_octave) + end + os.remove(d_for_octave) + + if exit_code then + print("\nMatches Octave.") + else + print("\nMISMATCH with Octave!") + os.exit(1) + end +else + print("\nNOTE: Install `octave` to verify the results.") +end diff --git a/src/intel/executor/examples/mod/fp.lua b/src/intel/executor/examples/mod/fp.lua new file mode 100644 index 00000000000..e3662bb4d3a --- /dev/null +++ b/src/intel/executor/examples/mod/fp.lua @@ -0,0 +1,128 @@ +-- Copyright © 2025 Intel Corporation +-- SPDX-License-Identifier: MIT +-- +-- Encode and decode floating point types. +-- +-- Just enough to get basic usage of the examples. If this get serious might +-- be a good idea to implement it using the existing Mesa routines and exposing +-- as part of executor API. + +local M = {} + +M.decode_float = function(bits, mantissa_bits, exponent_bits, bias) + local total_bits = 1 + exponent_bits + mantissa_bits + + local sign_mask = 1 << (total_bits - 1) + local exponent_mask = ((1 << exponent_bits) - 1) << mantissa_bits + local mantissa_mask = (1 << mantissa_bits) - 1 + + local sign = (bits & sign_mask) ~= 0 + local exponent = (bits & exponent_mask) >> mantissa_bits + local mantissa = bits & mantissa_mask + + if exponent == 0 then + if mantissa == 0 then + return sign and "-0.0" or "0.0" + else + -- Subnormal number. They don't have implicit leading 1, so + -- the number corresponds to "0.mantissa * 2^(1-bias)". + local value = mantissa / (1 << mantissa_bits) + value = value * (2 ^ (1 - bias)) + if sign then value = -value end + return string.format("%.17g", value) + end + elseif exponent == (1 << exponent_bits) - 1 then + if mantissa == 0 then + return sign and "-inf" or "inf" + else + return "nan" + end + else + -- Normal numbers have implicit leading 1, so the + -- number corresponds to "1.mantissa * 2^(exponent-bias)". + local value = 1.0 + (mantissa / (1 << mantissa_bits)) + value = value * (2 ^ (exponent - bias)) + if sign then value = -value end + return string.format("%.17g", value) + end +end + +M.decode_f32 = function(bits) + return M.decode_float(bits, 23, 8, 127) +end + +M.decode_f16 = function(bits) + return M.decode_float(bits, 10, 5, 15) +end + +M.decode_bf16 = function(bits) + return M.decode_float(bits, 7, 8, 127) +end + +M.encode_float = function(value_str, mantissa_bits, exponent_bits, bias) + local value = tonumber(value_str) + if not value then + return nil + end + + local total_bits = 1 + exponent_bits + mantissa_bits + local max_exponent = (1 << exponent_bits) - 1 + + local sign_bit = value < 0 and (1 << (total_bits - 1)) or 0 + local signed_inf = sign_bit | max_exponent << mantissa_bits + + -- Handle various special cases first: signed zero, NaN, Inf/-Inf. + if value == 0.0 then + return sign_bit + elseif value ~= value then + return (1 << (total_bits - 1)) | (max_exponent << mantissa_bits) | 1 + elseif value == math.huge or value == -math.huge then + return signed_inf + end + + -- Do math with the absolute value from now on. + if sign_bit ~= 0 then value = -value end + + local exponent = math.floor(math.log(value) / math.log(2)) + + local mantissa_value = value / (2 ^ exponent) - 1.0 + + exponent = exponent + bias + + if exponent <= 0 then + -- Subnormal: no implicit leading 1, use minimum exponent + mantissa_value = value / (2 ^ (1 - bias)) + exponent = 0 + elseif exponent >= max_exponent then + -- Value too large to represent. + return signed_inf + end + + local mantissa = math.floor(mantissa_value * (1 << mantissa_bits) + 0.5) + + if mantissa >= (1 << mantissa_bits) then + -- Rounding caused mantissa to overflow, increment exponent. + mantissa = 0 + exponent = exponent + 1 + if exponent >= max_exponent then + -- Value too large to represent. + return signed_inf + end + end + + return sign_bit | (exponent << mantissa_bits) | mantissa +end + +M.encode_f32 = function(value_str) + return M.encode_float(value_str, 23, 8, 127) +end + +M.encode_f16 = function(value_str) + return M.encode_float(value_str, 10, 5, 15) +end + +M.encode_bf16 = function(value_str) + return M.encode_float(value_str, 7, 8, 127) +end + +return M diff --git a/src/intel/executor/examples/mod/matrix.lua b/src/intel/executor/examples/mod/matrix.lua index 382d2aa4117..19b3594211f 100644 --- a/src/intel/executor/examples/mod/matrix.lua +++ b/src/intel/executor/examples/mod/matrix.lua @@ -22,20 +22,24 @@ Matrix.set = function(self, i, j, value) self.data[i][j] = value end -Matrix.print = function(self, fmt) +Matrix.print_submatrix = function(self, rows, cols, fmt) local fmt = fmt or "%4u" - print(string.format("# %dx%d matrix", self.rows, self.cols)) + print(string.format("# %dx%d matrix", rows, cols)) io.write("[\n") - for i = 0, self.rows - 1 do - for j = 0, self.cols - 1 do + for i = 0, rows - 1 do + for j = 0, cols - 1 do io.write(string.format(fmt, self.data[i][j])) - if j < self.cols - 1 then io.write(" ") end + if j < cols - 1 then io.write(" ") end end io.write("\n") end io.write("]\n") end +Matrix.print = function(self, fmt) + self:print_submatrix(self.rows, self.cols, fmt) +end + -- "Interleaved" row major is like row major except that -- elements from `packing_factor` rows are packed together. -- @@ -136,4 +140,12 @@ M.from_row_major_buffer = function(rows, cols, data) return self end +Matrix.apply = function(self, func) + for i = 0, self.rows - 1 do + for j = 0, self.cols - 1 do + self.data[i][j] = func(self.data[i][j]) + end + end +end + return M diff --git a/src/intel/executor/examples/test_matmul.sh b/src/intel/executor/examples/test_matmul.sh new file mode 100755 index 00000000000..af482013f5e --- /dev/null +++ b/src/intel/executor/examples/test_matmul.sh @@ -0,0 +1,179 @@ +#!/bin/bash + +cd "$(dirname "${BASH_SOURCE[0]}")" + +if ! command -v executor &> /dev/null; then + echo "ERROR: executor command not found." >&2 + exit 1 +fi + +set -e + +executor matmul.lua UB/UD \ + <(cat <