From 78102a47ada5624ccb25809aa9d0dd3485368b8d Mon Sep 17 00:00:00 2001 From: Jagrit Digani Date: Wed, 17 Jan 2024 12:42:39 -0800 Subject: [PATCH] Update GEMM (#424) * Organize and collect metal subroutine templates and elements in `metal/kernels/steel/` * Update gemm elements for better performance * Add split-K specialization for gemm * Add `addmm` primitive, op and bindings for fused matmul and bias addition * Update tests and benchmarks as needed --- benchmarks/python/blas/bench_gemm.py | 4 +- benchmarks/python/comparative/bench_mlx.py | 12 +- mlx/backend/accelerate/matmul.cpp | 51 +- mlx/backend/common/default_primitives.cpp | 52 +- mlx/backend/metal/conv.cpp | 6 +- mlx/backend/metal/kernels/CMakeLists.txt | 37 +- mlx/backend/metal/kernels/{gemm => }/conv.h | 0 mlx/backend/metal/kernels/conv.metal | 2 +- mlx/backend/metal/kernels/gemm/gemm.h | 538 ------------------ mlx/backend/metal/kernels/quantized.metal | 12 +- mlx/backend/metal/kernels/steel/gemm/gemm.h | 312 ++++++++++ .../gemm/kernels/steel_gemm.metal} | 44 +- .../steel/gemm/kernels/steel_gemm_addmm.metal | 260 +++++++++ .../gemm/kernels/steel_gemm_splitk.metal | 280 +++++++++ mlx/backend/metal/kernels/steel/gemm/loader.h | 160 ++++++ mlx/backend/metal/kernels/steel/gemm/mma.h | 264 +++++++++ mlx/backend/metal/kernels/steel/gemm/params.h | 79 +++ .../metal/kernels/steel/gemm/transforms.h | 63 ++ mlx/backend/metal/kernels/steel/host.h | 5 + mlx/backend/metal/kernels/steel/utils.h | 9 + mlx/backend/metal/matmul.cpp | 471 +++++++++++++-- mlx/backend/metal/matmul.h | 2 +- mlx/backend/no_metal/primitives.cpp | 1 + mlx/ops.cpp | 94 +++ mlx/ops.h | 8 + mlx/primitives.cpp | 46 ++ mlx/primitives.h | 23 + python/mlx/nn/layers/linear.py | 5 +- python/src/ops.cpp | 30 + python/tests/test_blas.py | 137 +++++ 30 files changed, 2361 insertions(+), 646 deletions(-) rename mlx/backend/metal/kernels/{gemm => }/conv.h (100%) delete mode 100644 mlx/backend/metal/kernels/gemm/gemm.h create mode 100644 mlx/backend/metal/kernels/steel/gemm/gemm.h rename mlx/backend/metal/kernels/{gemm.metal => steel/gemm/kernels/steel_gemm.metal} (72%) create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_addmm.metal create mode 100644 mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal create mode 100644 mlx/backend/metal/kernels/steel/gemm/loader.h create mode 100644 mlx/backend/metal/kernels/steel/gemm/mma.h create mode 100644 mlx/backend/metal/kernels/steel/gemm/params.h create mode 100644 mlx/backend/metal/kernels/steel/gemm/transforms.h create mode 100644 mlx/backend/metal/kernels/steel/host.h create mode 100644 mlx/backend/metal/kernels/steel/utils.h diff --git a/benchmarks/python/blas/bench_gemm.py b/benchmarks/python/blas/bench_gemm.py index 3681cd777..4914c40ba 100644 --- a/benchmarks/python/blas/bench_gemm.py +++ b/benchmarks/python/blas/bench_gemm.py @@ -166,13 +166,13 @@ def get_gflop_count(B, M, N, K): dtypes = ("float32", "float16") transposes = ("nn", "nt", "tn") shapes = ( + (16, 234, 768, 3072), + (1, 64, 64, 25344), (16, 1024, 1024, 1024), (1, 1024, 1024, 2048), (4, 1024, 1024, 4096), (4, 1024, 4096, 1024), (1, 4096, 4096, 4096), - (15, 1023, 1023, 1023), - (17, 1025, 1025, 1025), ) for dtype in dtypes: diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py index 8b96840f7..ecbb65cb4 100644 --- a/benchmarks/python/comparative/bench_mlx.py +++ b/benchmarks/python/comparative/bench_mlx.py @@ -257,6 +257,13 @@ def linear(w, b, x): mx.eval(ys) +def linear_fused(w, b, x): 
+ ys = [] + for i in range(10): + ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0)))) + mx.eval(ys) + + def rope(x): *_, N, D = x.shape ys = [] @@ -397,7 +404,10 @@ def selu(x): print(bench(quant_matmul[args.benchmark], *xs)) elif args.benchmark == "linear": - print(bench(linear, *xs)) + if args.fused: + print(bench(linear_fused, *xs)) + else: + print(bench(linear, *xs)) elif args.benchmark == "sum_axis": print(bench(reduction, "sum", axis, x)) diff --git a/mlx/backend/accelerate/matmul.cpp b/mlx/backend/accelerate/matmul.cpp index 5a38e3123..254f3fc4c 100644 --- a/mlx/backend/accelerate/matmul.cpp +++ b/mlx/backend/accelerate/matmul.cpp @@ -29,12 +29,16 @@ std::tuple check_transpose(const array& arr) { } } -inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) { +inline void matmul_cblas_general( + const array& a_pre, + const array& b_pre, + array& out, + float alpha = 1.0f, + float beta = 0.0f) { if (out.dtype() != float32) { throw std::runtime_error( "[matmul_cblas] on CPU currently only supports float32"); } - out.set_data(allocator::malloc_or_wait(out.nbytes())); auto [a_transposed, lda, a] = check_transpose(a_pre); auto [b_transposed, ldb, b] = check_transpose(b_pre); @@ -50,21 +54,34 @@ inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) { M, N, K, - 1.0f, // alpha + alpha, // alpha a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()), lda, b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()), ldb, - 0.0f, // beta + beta, // beta out.data() + M * N * i, out.shape(-1) // ldc ); } } -inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) { - // TODO: Update to utilize BNNS broadcasting +inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) { + if (out.dtype() != float32) { + throw std::runtime_error( + "[matmul_cblas] on CPU currently only supports float32"); + } out.set_data(allocator::malloc_or_wait(out.nbytes())); + return matmul_cblas_general(a_pre, b_pre, out); +} + +inline void matmul_bnns_general( + const array& a_pre, + const array& b_pre, + array& out, + float alpha = 1.0f, + float beta = 0.0f) { + // TODO: Update to utilize BNNS broadcasting auto [a_transposed, lda, a] = check_transpose(a_pre); auto [b_transposed, ldb, b] = check_transpose(b_pre); @@ -75,8 +92,8 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) { BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype()); const BNNSLayerParametersBroadcastMatMul gemm_params{ - /* float alpha = */ 1.0, - /* float beta = */ 0.0, + /* float alpha = */ alpha, + /* float beta = */ beta, /* bool transA = */ a_transposed, /* bool transB = */ b_transposed, /* bool quadratic = */ false, @@ -157,6 +174,12 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) { BNNSFilterDestroy(bnns_filter); } +inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) { + // TODO: Update to utilize BNNS broadcasting + out.set_data(allocator::malloc_or_wait(out.nbytes())); + return matmul_bnns_general(a_pre, b_pre, out); +} + } // namespace void Matmul::eval_cpu(const std::vector& inputs, array& out) { @@ -166,4 +189,16 @@ void Matmul::eval_cpu(const std::vector& inputs, array& out) { return matmul_bnns(inputs[0], inputs[1], out); } +void AddMM::eval_cpu(const std::vector& inputs, array& out) { + // Fill output with C + auto& c = inputs[2]; + CopyType ctype = c.data_size() == 1 ? 
CopyType::Scalar : CopyType::General; + copy(c, out, ctype); + + if (out.dtype() == float32) { + return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_); + } + return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_); +} + } // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index cecf64cee..1225fd1af 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -98,16 +98,14 @@ DEFAULT(Tanh) DEFAULT(Transpose) DEFAULT_MULTI(DivMod) -void Matmul::eval_cpu(const std::vector& inputs, array& out) { - if (out.dtype() != float32) { - throw std::runtime_error( - "[Matmul::eval_cpu] Currently only supports float32."); - } - out.set_data(allocator::malloc_or_wait(out.nbytes())); - - auto& a_pre = inputs[0]; - auto& b_pre = inputs[1]; +namespace { +inline void matmul_common_general( + const array& a_pre, + const array& b_pre, + array& out, + float alpha = 1.0f, + float beta = 0.0f) { auto check_transpose = [](const array& arr) { auto stx = arr.strides()[arr.ndim() - 2]; auto sty = arr.strides()[arr.ndim() - 1]; @@ -125,9 +123,10 @@ void Matmul::eval_cpu(const std::vector& inputs, array& out) { auto [a_transposed, lda, a] = check_transpose(a_pre); auto [b_transposed, ldb, b] = check_transpose(b_pre); - int M = a.shape(-2); - int N = b.shape(-1); - int K = a.shape(-1); + size_t M = a.shape(-2); + size_t N = b.shape(-1); + size_t K = a.shape(-1); + for (int i = 0; i < (a.size() / (M * K)); ++i) { cblas_sgemm( CblasRowMajor, @@ -136,16 +135,41 @@ void Matmul::eval_cpu(const std::vector& inputs, array& out) { M, N, K, - 1.0f, // alpha + alpha, // alpha a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()), lda, b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()), ldb, - 0.0f, // beta + beta, // beta out.data() + M * N * i, out.shape(-1) // ldc ); } } +} // namespace + +void Matmul::eval_cpu(const std::vector& inputs, array& out) { + if (out.dtype() != float32) { + throw std::runtime_error( + "[Matmul::eval_cpu] Currently only supports float32."); + } + out.set_data(allocator::malloc_or_wait(out.nbytes())); + return matmul_common_general(inputs[0], inputs[1], out); +} + +void AddMM::eval_cpu(const std::vector& inputs, array& out) { + if (out.dtype() != float32) { + throw std::runtime_error( + "[AddMM::eval_cpu] Currently only supports float32."); + } + + // Fill output with C + auto& c = inputs[2]; + CopyType ctype = c.data_size() == 1 ? 
CopyType::Scalar : CopyType::General; + copy(c, out, ctype); + + return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_); +} + } // namespace mlx::core diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp index 3377939ba..d632556f0 100644 --- a/mlx/backend/metal/conv.cpp +++ b/mlx/backend/metal/conv.cpp @@ -70,7 +70,7 @@ void explicit_gemm_conv_1D_gpu( // Perform gemm std::vector copies = {in_padded, in_strided}; - mlx_matmul( + return steel_matmul( s, d, /*a = */ in_strided, @@ -262,7 +262,7 @@ void explicit_gemm_conv_2D_gpu( // Perform gemm std::vector copies = {in_padded, in_strided}; - mlx_matmul( + return steel_matmul( s, d, /*a = */ in_strided, @@ -411,7 +411,7 @@ void winograd_conv_2D_gpu( copies_w.push_back(out_wg); { std::vector empty_copies; - mlx_matmul( + steel_matmul( s, d, /*a = */ inp_wg, diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index bc3da3018..2d271abb4 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -18,7 +18,6 @@ set( "binary_two" "conv" "copy" - "gemm" "gemv" "quantized" "random" @@ -30,26 +29,27 @@ set( "indexing" ) -function(build_kernel KERNEL) - set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal) - set(HEADERS_PADDED ${HEADERS}) - if(${KERNEL} STREQUAL "gemm") - set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h) - endif() - if(${KERNEL} STREQUAL "conv") - set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h) - endif() +function(build_kernel_base TARGET SRCFILE DEPS) add_custom_command( COMMAND xcrun -sdk macosx metal -Wall -Wextra -fno-fast-math -c ${SRCFILE} -I${PROJECT_SOURCE_DIR} - -o ${KERNEL}.air - DEPENDS ${SRCFILE} ${HEADERS_PADDED} - OUTPUT ${KERNEL}.air - COMMENT "Building ${KERNEL}.air" + -o ${TARGET}.air + DEPENDS ${SRCFILE} ${DEPS} + OUTPUT ${TARGET}.air + COMMENT "Building ${TARGET}.air" VERBATIM ) +endfunction(build_kernel_base) + +function(build_kernel KERNEL) + set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal) + set(HEADERS_PADDED ${HEADERS}) + if(${KERNEL} STREQUAL "conv") + set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h) + endif() + build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}") endfunction(build_kernel) foreach(KERNEL ${KERNELS}) @@ -57,6 +57,15 @@ foreach(KERNEL ${KERNELS}) set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR}) endforeach() +file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal) +file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h) + +foreach(KERNEL ${STEEL_KERNELS}) + cmake_path(GET KERNEL STEM TARGET) + build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}") + set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR}) +endforeach() + add_custom_command( OUTPUT ${MLX_METAL_PATH}/mlx.metallib COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib diff --git a/mlx/backend/metal/kernels/gemm/conv.h b/mlx/backend/metal/kernels/conv.h similarity index 100% rename from mlx/backend/metal/kernels/gemm/conv.h rename to mlx/backend/metal/kernels/conv.h diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 6dbe7bc3a..77c72c48c 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -5,7 +5,7 @@ #include "mlx/backend/metal/kernels/conv_params.h" #include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/gemm/conv.h" +#include 
"mlx/backend/metal/kernels/conv.h" using namespace metal; diff --git a/mlx/backend/metal/kernels/gemm/gemm.h b/mlx/backend/metal/kernels/gemm/gemm.h deleted file mode 100644 index 95d2e6497..000000000 --- a/mlx/backend/metal/kernels/gemm/gemm.h +++ /dev/null @@ -1,538 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once - -#include -#include -#include - -#define MLX_MTL_CONST static constant constexpr const - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// Loading helper -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - int BROWS, - int BCOLS, - int BK, - int vec_size, - int tgp_size, - bool transpose, - bool ldK, - int tgp_padding = 0> -struct BlockLoader { - // Destination dimensions - MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS; - MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding; - MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size; - - // Stride along block row within the block - MLX_MTL_CONST int bstride = tgp_size / n_vecs; - - // Leading dimension for src - const int src_ld; - // Stride along reduction axis between blocks - const int tstride; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - /* Constructor */ - METAL_FUNC BlockLoader( - const device T* src_, - const int src_ld_, - threadgroup T* dst_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tstride( - BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / n_vecs), - bj(vec_size * (thread_idx % n_vecs)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld + bj) {} - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { -#pragma clang loop unroll(full) - for (short i = 0; i < dst_fd; i += bstride) { -#pragma clang loop unroll(full) - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = src[i * src_ld + j]; - } - } - } - - /* Load from device memory into threadgroup memory - with bound checking */ - METAL_FUNC void load_safe(short2 src_tile_dim) const { - src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy; - - // Iterate over rows of block -#pragma clang loop unroll(full) - for (short i = 0; i < dst_fd; i += bstride) { - // Row is in bounds, we check against column - if ((bi + i) < src_tile_dim.y) { - // Use fast thread memory for bound checks - short tmp_idx[vec_size]; - T tmp_val[vec_size]; - - // Make sure tmp_idx only contains valid indices -#pragma clang loop unroll(full) - for (short j = 0; j < vec_size; j++) { - tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0; - } - - // Read all valid indices into tmp_val -#pragma clang loop unroll(full) - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = src[i * src_ld + tmp_idx[j]]; - } - - // Zero out unneeded values -#pragma clang loop unroll(full) - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = bj + j < src_tile_dim.x ? 
tmp_val[j] : T(0); - } - - // Copy values to threadgroup memory -#pragma clang loop unroll(full) - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = tmp_val[j]; - } - } - - // Row is out of bounds, we just fill tgp memory with zeros - else { -#pragma clang loop unroll(full) - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - src += tstride; - } -}; - -/////////////////////////////////////////////////////////////////////////////// -// Transforms -/////////////////////////////////////////////////////////////////////////////// - -template -struct TransformNone { - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } -}; - -template -struct AccumHelper { - typedef float accum_type; -}; - -/////////////////////////////////////////////////////////////////////////////// -// MMA helper -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - int tgp_padding_a = 0, - int tgp_padding_b = 0, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct BlockMMA { - // Warp tile size along M - MLX_MTL_CONST int TM = BM / (WM * 8); - // Warp tile size along N - MLX_MTL_CONST int TN = BN / (WN * 8); - - // Warp tile simdgroup matrix strides along M - MLX_MTL_CONST int TM_stride = 8 * WM; - // Warp tile simdgroup matrix strides along M - MLX_MTL_CONST int TN_stride = 8 * WN; - - // Leading dimensions of threadgroup A, B blocks - MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a; - MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b; - - // Strides of A, B along reduction axis - MLX_MTL_CONST short simd_stride_a = - transpose_a ? TM_stride : TM_stride * lda_tgp; - MLX_MTL_CONST short simd_stride_b = - transpose_b ? TN_stride * ldb_tgp : TN_stride; - - // Jump between elements - MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1; - MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1; - - // Offsets within threadgroup - const int tm; - const int tn; - - // Simdgroup matrices - simdgroup_matrix Asimd[TM]; - simdgroup_matrix Bsimd[TN]; - simdgroup_matrix results[TM * TN] = { - simdgroup_matrix(0)}; - - short sm; - short sn; - - /* Constructor */ - METAL_FUNC BlockMMA( - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { - short qid = simd_lane_id / 4; - sm = (qid & 4) + (simd_lane_id / 2) % 4; - sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - } - - /* (BM, BK) X (BK, BN) multiply accumulate function */ - METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { -// Iterate over BK in blocks of 8 -#pragma clang loop unroll(full) - for (short kk = 0; kk < BK; kk += 8) { - short2 offset_a = - transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm); - short2 offset_b = - transpose_b ? 
short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm); - - const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x; - const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x; - - simdgroup_barrier(mem_flags::mem_none); -// Load elements from threadgroup A as simdgroup matrices -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { - Asimd[i].thread_elements()[0] = static_cast(As__[0]); - Asimd[i].thread_elements()[1] = static_cast(As__[jump_a]); - As__ += simd_stride_a; - } - - simdgroup_barrier(mem_flags::mem_none); -// Load elements from threadgroup B as simdgroup matrices -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - Bsimd[j].thread_elements()[0] = static_cast(Bs__[0]); - Bsimd[j].thread_elements()[1] = static_cast(Bs__[jump_b]); - Bs__ += simd_stride_b; - } - - simdgroup_barrier(mem_flags::mem_none); -// Multiply and accumulate into result simdgroup matrices -#pragma clang loop unroll(full) - for (short i = 0; i < TM; i++) { -#pragma clang loop unroll(full) - for (short j = 0; j < TN; j++) { - simdgroup_multiply_accumulate( - results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]); - } - } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device T* C, const int ldc) const { -#pragma clang loop unroll(full) - for (int i = 0; i < TM; i++) { -#pragma clang loop unroll(full) - for (int j = 0; j < TN; j++) { - C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] = - Epilogue::apply(results[i * TN + j].thread_elements()[0]); - C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] = - Epilogue::apply(results[i * TN + j].thread_elements()[1]); - } - } - } - - METAL_FUNC void - store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const { -#pragma clang loop unroll(full) - for (int i = 0; i < TM; i++) { - if (tm + i * TM_stride + sm < dst_tile_dims.y) { -#pragma clang loop unroll(full) - for (int j = 0; j < TN; j++) { - if (tn + j * TN_stride + sn < dst_tile_dims.x) { - C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] = - Epilogue::apply(results[i * TN + j].thread_elements()[0]); - } - - if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) { - C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] = - Epilogue::apply(results[i * TN + j].thread_elements()[1]); - } - } - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct GEMMKernel { - MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T); - MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T); - MLX_MTL_CONST short tgp_mem_size_a = - transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); - MLX_MTL_CONST short tgp_mem_size_b = - transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); - MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; - - MLX_MTL_CONST short tgp_size = WM * WN * 32; - MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 
8 : 4; - - using loader_a_t = BlockLoader< - T, - BM, - BK, - BK, - vec_size, - tgp_size, - transpose_a, - true, - tgp_padding_a>; - using loader_b_t = BlockLoader< - T, - BK, - BN, - BK, - vec_size, - tgp_size, - transpose_b, - false, - tgp_padding_b>; - using mma_t = BlockMMA< - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - tgp_padding_a, - tgp_padding_b, - AccumType, - Epilogue>; - - /* Main kernel function */ - static METAL_FUNC void run( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* C [[buffer(2)]], - const constant int& M [[buffer(3)]], - const constant int& N [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& batch_stride_a [[buffer(6)]], - const constant int& batch_stride_b [[buffer(7)]], - const constant int& batch_stride_c [[buffer(8)]], - threadgroup T* tgp_memory [[threadgroup(0)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Pacifying compiler - (void)lid; - - // Adjust for batch - A += batch_stride_a * tid.z; - B += batch_stride_b * tid.z; - C += batch_stride_c * tid.z; - - // Adjust for transpose - const int lda_dev = transpose_a ? M : K; - const int ldb_dev = transpose_b ? K : N; - - // Find block in A, B, C - const int c_row = tid.y * BM; - const int c_col = tid.x * BN; - - A += transpose_a ? c_row : c_row * K; - B += transpose_b ? c_col * K : c_col; - C += c_row * N + c_col; - - // Prepare threadgroup memory for loading - threadgroup T* As = tgp_memory; - threadgroup T* Bs = tgp_memory + tgp_mem_size_a; - - // Prepare threadgroup loading operations - loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id); - loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup mma operation - mma_t mma_op(simd_group_id, simd_lane_id); - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (MN_aligned && K_aligned) { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Store results to device memory - mma_op.store_result(C, N); - return; - - } - /////////////////////////////////////////////////////////////////////////////// - // MN aligned, K unaligned loop - else if (MN_aligned && !K_aligned) { - // Main loop - int k = 0; - for (; k + BK <= K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - // Loop tail - threadgroup_barrier(mem_flags::mem_threadgroup); - - loader_a.load_safe(short2(K - k, BM)); - loader_b.load_safe(short2(BN, K - k)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - - // Store results to device memory - mma_op.store_result(C, N); - return; - - } - 
/////////////////////////////////////////////////////////////////////////////// - // MNK unaligned loop - else { // Loop over K - unaligned case - - short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row)); - - if (src_tile_dims.y == BM && src_tile_dims.x == BN) { - int k = 0; - for (; k + BK <= K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - if (k < K) { - loader_a.load_safe(short2(K - k, BM)); - loader_b.load_safe(short2(BN, K - k)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - - mma_op.store_result(C, N); - return; - - } else { - int k = 0; - for (; k + BK <= K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_safe(short2(BK, src_tile_dims.y)); - loader_b.load_safe(short2(src_tile_dims.x, BK)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - if (k < K) { - loader_a.load_safe(short2(K - k, src_tile_dims.y)); - loader_b.load_safe(short2(src_tile_dims.x, K - k)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - - threadgroup_barrier(mem_flags::mem_none); - mma_op.store_result_safe(C, N, src_tile_dims); - - return; - } - } - } -}; \ No newline at end of file diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal index 9cb54e0f8..294dbab5c 100644 --- a/mlx/backend/metal/kernels/quantized.metal +++ b/mlx/backend/metal/kernels/quantized.metal @@ -5,9 +5,10 @@ #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/defines.h" -#include "mlx/backend/metal/kernels/gemm/gemm.h" #include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" + using namespace metal; #define MLX_MTL_CONST static constant constexpr const @@ -239,8 +240,9 @@ template ; - using loader_x_t = BlockLoader; + using mma_t = mlx::steel::BlockMMA; + using loader_x_t = mlx::steel::BlockLoader; + threadgroup T scales_block[BN * groups_per_block]; threadgroup T biases_block[BN * groups_per_block]; @@ -392,8 +394,8 @@ template ; - using loader_x_t = BlockLoader; + using mma_t = mlx::steel::BlockMMA; + using loader_x_t = mlx::steel::BlockLoader; threadgroup T scales_block[BK * groups_per_block]; threadgroup T biases_block[BK * groups_per_block]; diff --git a/mlx/backend/metal/kernels/steel/gemm/gemm.h b/mlx/backend/metal/kernels/steel/gemm/gemm.h new file mode 100644 index 000000000..3a8f0280c --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/gemm.h @@ -0,0 +1,312 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/kernels/steel/gemm/loader.h" +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size]; + thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size]; + + if (!M_aligned) { + short2 tile_dims_A = + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + loader_a.set_mask(tile_dims_A, mask_A); + } + + if (!N_aligned) { + short2 tile_dims_B = + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + loader_b.set_mask(tile_dims_B, mask_B); + } + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(mask_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(mask_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.set_mask(tile_dims_A_last, mask_A); + loader_b.set_mask(tile_dims_B_last, mask_B); + + loader_a.load_safe(mask_A); + loader_b.load_safe(mask_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* C [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + + A += transpose_a ? c_row : c_row * params->lda; + B += transpose_b ? c_col * params->ldb : c_col; + C += c_row * params->ldc + c_col; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); + + thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size]; + thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size]; + + loader_a.set_mask(tile_dims_A, mask_A); + loader_b.set_mask(tile_dims_B, mask_B); + + loader_a.load_safe(mask_A); + loader_b.load_safe(mask_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(C, params->ldc); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(C, params->ldc); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/gemm.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal similarity index 72% rename from mlx/backend/metal/kernels/gemm.metal rename to mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal index df150b50d..fb051131c 100644 --- a/mlx/backend/metal/kernels/gemm.metal +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm.metal @@ -1,9 +1,10 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2024 Apple Inc. 
#include "mlx/backend/metal/kernels/bf16.h" -#include "mlx/backend/metal/kernels/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" using namespace metal; +using namespace mlx::steel; /////////////////////////////////////////////////////////////////////////////// // GEMM kernels @@ -23,26 +24,26 @@ template ; + using gemm_kernel = GEMMKernel; - threadgroup T tgp_memory[gemm_kernel::tgp_mem_size]; + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Adjust for batch + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + C += params->batch_stride_c * tid.z; gemm_kernel::run( A, B, C, - M, N, K, - batch_stride_a, batch_stride_b, batch_stride_c, - tgp_memory, + params, + As, Bs, simd_lane_id, simd_group_id, tid, lid ); } @@ -52,17 +53,12 @@ template ( \ const device itype *A [[buffer(0)]], \ const device itype *B [[buffer(1)]], \ device itype *C [[buffer(2)]], \ - const constant int &M [[buffer(3)]], \ - const constant int &N [[buffer(4)]], \ - const constant int &K [[buffer(5)]], \ - const constant int &batch_stride_a [[buffer(6)]], \ - const constant int &batch_stride_b [[buffer(7)]], \ - const constant int &batch_stride_c [[buffer(8)]], \ + const constant GEMMParams* params [[buffer(3)]], \ uint simd_lane_id [[thread_index_in_simdgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]], \ uint3 tid [[threadgroup_position_in_grid]], \ @@ -84,10 +80,10 @@ template > +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm( + const device T *A [[buffer(0)]], + const device T *B [[buffer(1)]], + const device T *C [[buffer(2)]], + device T *D [[buffer(3)]], + const constant GEMMAddMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + // Pacifying compiler + (void)lid; + + using gemm_kernel = + GEMMKernel; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Adjust for batch + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + C += params->batch_stride_c * tid.z; + D += params->batch_stride_d * tid.z; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + + A += transpose_a ? c_row : c_row * params->lda; + B += transpose_b ? 
c_col * params->ldb : c_col; + C += c_row * params->ldc + c_col * params->fdc; + D += c_row * params->ldd + c_col; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + const Epilogue epilogue_op(params->alpha, params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size]; + thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size]; + + loader_a.set_mask(tile_dims_A, mask_A); + loader_b.set_mask(tile_dims_B, mask_B); + + loader_a.load_safe(mask_A); + loader_b.load_safe(mask_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op); + return; + + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + return mma_op.store_result_safe( + D, params->ldd, + C, params->ldc, params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op); + + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + return mma_op.store_result_safe( + D, params->ldd, + C, params->ldc, params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op); + + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + return mma_op.store_result_safe( + D, params->ldd, + C, params->ldc, params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel initializations +/////////////////////////////////////////////////////////////////////////////// + +#define 
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \ + template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \ + [[kernel]] void addmm>( \ + const device itype *A [[buffer(0)]], \ + const device itype *B [[buffer(1)]], \ + const device itype *C [[buffer(2)]], \ + device itype *D [[buffer(3)]], \ + const constant GEMMAddMMParams* params [[buffer(4)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby) + +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ + instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ + instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ + instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2) + +instantiate_gemm_shapes_helper(float16, half, float16, half); +instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t); + +instantiate_gemm_shapes_helper(float32, float, float32, float); \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal new file mode 100644 index 000000000..873f5faf1 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal @@ -0,0 +1,280 @@ +// Copyright © 2024 Apple Inc. 
+ +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" + +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk( + const device T *A [[buffer(0)]], + const device T *B [[buffer(1)]], + device U *C [[buffer(2)]], + const constant GEMMSpiltKParams* params [[buffer(3)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + (void)lid; + + using gemm_kernel = GEMMKernel; + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + const int tid_x = tid.x; + const int tid_y = tid.y; + const int tid_z = tid.z; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int k_start = params->split_k_partition_size * tid_z; + + A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda); + B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb); + C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K % BK; + + if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if ((tid_z + 1) == (params->split_k_partitions)) { + int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK; + if(!K_aligned || gemm_k_iter_remaining > 0) + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iter_remaining, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + mma_op.store_result(C, params->ldc); + } else { + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// 
GEMM kernel initializations +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \ + template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \ + [[kernel]] void gemm_splitk( \ + const device itype *A [[buffer(0)]], \ + const device itype *B [[buffer(1)]], \ + device otype *C [[buffer(2)]], \ + const constant GEMMSpiltKParams* params [[buffer(3)]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \ + instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false) + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) + +instantiate_gemm_shapes_helper(float16, half, float32, float); +instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float); + +instantiate_gemm_shapes_helper(float32, float, float32, float); + +/////////////////////////////////////////////////////////////////////////////// +// Split k accumulation kernel +/////////////////////////////////////////////////////////////////////////////// + +template > +[[kernel]] void gemm_splitk_accum( + const device AccT *C_split [[buffer(0)]], + device OutT *D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) { + + // Ajust D and C + D += gid.x + gid.y * ldd; + C_split += gid.x + gid.y * ldd; + + int offset = 0; + AccT out = 0; + + for(int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + D[0] = Epilogue::apply(out); + +} + +template > +[[kernel]] void gemm_splitk_accum_axpby( + const device AccT *C_split [[buffer(0)]], + device OutT *D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& 
partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + const device OutT *C [[buffer(5)]], + const constant int& ldc [[buffer(6)]], + const constant int& fdc [[buffer(7)]], + const constant float& alpha [[buffer(8)]], + const constant float& beta [[buffer(9)]], + uint2 gid [[thread_position_in_grid]]) { + + // Ajust D and C + C += gid.x * fdc + gid.y * ldc; + D += gid.x + gid.y * ldd; + C_split += gid.x + gid.y * ldd; + + int offset = 0; + AccT out = 0; + + for(int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + Epilogue op(alpha, beta); + D[0] = op.apply(out, *C); + +} + +#define instantiate_accum(oname, otype, aname, atype) \ + template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \ + [[kernel]] void gemm_splitk_accum( \ + const device atype *C_split [[buffer(0)]], \ + device otype *D [[buffer(1)]], \ + const constant int& k_partitions [[buffer(2)]], \ + const constant int& partition_stride [[buffer(3)]], \ + const constant int& ldd [[buffer(4)]], \ + uint2 gid [[thread_position_in_grid]]); \ + template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \ + [[kernel]] void gemm_splitk_accum_axpby( \ + const device atype *C_split [[buffer(0)]], \ + device otype *D [[buffer(1)]], \ + const constant int& k_partitions [[buffer(2)]], \ + const constant int& partition_stride [[buffer(3)]], \ + const constant int& ldd [[buffer(4)]], \ + const device otype *C [[buffer(5)]], \ + const constant int& ldc [[buffer(6)]], \ + const constant int& fdc [[buffer(7)]], \ + const constant float& alpha [[buffer(8)]], \ + const constant float& beta [[buffer(9)]], \ + uint2 gid [[thread_position_in_grid]]); + +instantiate_accum(bfloat16, bfloat16_t, float32, float); +instantiate_accum(float16, half, float32, float); +instantiate_accum(float32, float, float32, float); \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/loader.h b/mlx/backend/metal/kernels/steel/gemm/loader.h new file mode 100644 index 000000000..5e52bbb33 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/loader.h @@ -0,0 +1,160 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void set_mask( + thread const short2& src_tile_dims, + thread bool mask[n_rows][vec_size]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + mask[i][j] = + ((bi + i) < src_tile_dims.y) && ((bj + j) < src_tile_dims.x); + } + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(const thread bool mask[n_rows][vec_size]) const { + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0, ii = 0; i < BROWS; i += TROWS, ii++) { + simdgroup_barrier(mem_flags::mem_none); + // Use fast thread memory for bound checks + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(mask[ii][j] ? i * src_ld + j : 0)]; + } + + simdgroup_barrier(mem_flags::mem_none); + + // Zero out uneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = mask[ii][j] ? tmp_val[j] : T(0); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/mma.h b/mlx/backend/metal/kernels/steel/gemm/mma.h new file mode 100644 index 000000000..6f58bfcaf --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -0,0 +1,264 @@ +// Copyright © 2024 Apple Inc. 
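Note on the BlockLoader above: each of the tgp_size threads in the threadgroup owns a vec_size-wide strip of one tile row, and load_safe zero-fills whatever falls outside the valid source extent. A rough NumPy restatement of that index math, under the same divisibility assumptions the template enforces (the function and argument names here are illustrative, not part of the kernel):

    import numpy as np

    def block_load(src, src_ld, BROWS, BCOLS, tgp_size, tile_dims):
        n_reads = (BROWS * BCOLS) // tgp_size    # vec_size: elements per thread per pass
        TCOLS = BCOLS // n_reads                 # threads needed to cover one tile row
        TROWS = tgp_size // TCOLS                # rows the whole group covers per pass
        dst = np.zeros((BROWS, BCOLS), dtype=src.dtype)
        for thread_idx in range(tgp_size):
            bi = thread_idx // TCOLS             # this thread's starting row
            bj = n_reads * (thread_idx % TCOLS)  # this thread's first column
            for i in range(0, BROWS, TROWS):     # hop down the tile
                for j in range(n_reads):
                    r, c = bi + i, bj + j
                    # load_safe: anything past the valid extent stays zero
                    if r < tile_dims[0] and c < tile_dims[1]:
                        dst[r, c] = src[r * src_ld + c]
        return dst

    # A 32x16 tile loaded by 128 threads (wm = wn = 2), with the source two rows
    # short of the tile: n_reads = 4, so each thread reads one 4-wide strip.
    A = np.arange(40 * 16, dtype=np.float32)     # row-major slab, leading dim 16
    tile = block_load(A, src_ld=16, BROWS=32, BCOLS=16,
                      tgp_size=128, tile_dims=(30, 16))

load_unsafe is the same mapping without the bounds check, vectorized through ReadVector.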
+ +#pragma once + +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = 8 * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Strides of A, B along reduction axis + STEEL_CONST short simd_stride_a = { + transpose_a ? TM_stride : TM_stride * lda_tgp}; + STEEL_CONST short simd_stride_b = { + transpose_b ? TN_stride * ldb_tgp : TN_stride}; + + // Jump between elements + STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1}; + STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1}; + + STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8}; + STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp}; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + // Offsets within threadgroup + const short tm; + const short tn; + + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + // Determine thread position in simdgroup matrix + short qid = simd_lane_id / 4; + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + + // Determine thread and simdgroup offset + As_offset = + transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp); + Bs_offset = + transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn)); + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of 8 + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += 8) { + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup A as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = + static_cast(As[i * simd_stride_a + 0]); + Asimd[i].thread_elements()[1] = + static_cast(As[i * simd_stride_a + jump_a]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Load elements from threadgroup B as simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = + static_cast(Bs[j * simd_stride_b + 0]); + Bsimd[j].thread_elements()[1] = + static_cast(Bs[j * simd_stride_b + jump_b]); + } + + simdgroup_barrier(mem_flags::mem_none); + + // Multiply and accumulate into result simdgroup matrices + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + short j_serp = (i % 2) ? 
(TN - 1 - j) : j; + + simdgroup_multiply_accumulate( + results[i * TN + j_serp], + Asimd[i], + Bsimd[j_serp], + results[i * TN + j_serp]); + } + } + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* C, const int ldc) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue + U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])}; + + // Write out C + C[offset] = outs[0]; + C[offset + 1] = outs[1]; + } + } + } + + METAL_FUNC void + store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn); + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset = (i * TM_stride) * ldc + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < dst_tile_dims.x) { + C[offset] = Epilogue::apply(accum[0]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + C[offset + 1] = Epilogue::apply(accum[1]); + } + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + U outs[2] = { + epilogue_op.apply(accum[0], C[offset_c]), + epilogue_op.apply(accum[1], C[offset_c + fdc])}; + + // Write out D + D[offset_d] = outs[0]; + D[offset_d + 1] = outs[1]; + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm + tm) * ldc + (tn + sn) * fdc; + D += (sm + tm) * ldd + tn + sn; + dst_tile_dims -= short2(tn + sn, sm + tm); + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = results[i * TN + j].thread_elements(); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue and output C + if (j * TN_stride < 
dst_tile_dims.x) { + D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]); + } + + if (j * TN_stride + 1 < dst_tile_dims.x) { + D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]); + } + } + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/mlx/backend/metal/kernels/steel/gemm/params.h b/mlx/backend/metal/kernels/steel/gemm/params.h new file mode 100644 index 000000000..d7e4db043 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/params.h @@ -0,0 +1,79 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// GEMM param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct GEMMParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + + const int tiles_n; + const int tiles_m; + + const int batch_stride_a; + const int batch_stride_b; + const int batch_stride_c; + + const int swizzle_log; + const int gemm_k_iterations_aligned; +}; + +struct GEMMSpiltKParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + + const int tiles_n; + const int tiles_m; + + const int split_k_partitions; + const int split_k_partition_stride; + const int split_k_partition_size; + + const int gemm_k_iterations_aligned; +}; + +struct GEMMAddMMParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + const int ldd; + + const int tiles_n; + const int tiles_m; + + const int batch_stride_a; + const int batch_stride_b; + const int batch_stride_c; + const int batch_stride_d; + + const int swizzle_log; + const int gemm_k_iterations_aligned; + + const float alpha; + const float beta; + + const int fdc; +}; + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/gemm/transforms.h b/mlx/backend/metal/kernels/steel/gemm/transforms.h new file mode 100644 index 000000000..100f34925 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/transforms.h @@ -0,0 +1,63 @@ +// Copyright © 2024 Apple Inc. 
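Note on BlockMMA above: the K reduction advances in 8-deep slices of the threadgroup tiles, accumulates in float (AccumType), and only applies the epilogue when the result is stored, which is how the addmm path folds alpha, beta, and C into the same kernel. A NumPy sketch of that behaviour with the simdgroup-level tiling collapsed into whole-tile products (illustrative names, not the kernel's API):

    import numpy as np

    def block_mma_store(As, Bs, C, alpha=1.0, beta=1.0):
        BM, BK = As.shape
        _, BN = Bs.shape
        acc = np.zeros((BM, BN), dtype=np.float32)   # AccumType = float
        for kk in range(0, BK, 8):                   # one 8-deep slice per mma step
            acc += (As[:, kk:kk + 8].astype(np.float32)
                    @ Bs[kk:kk + 8, :].astype(np.float32))
        # Epilogue fused into the store: TransformAxpby(alpha, beta) against C.
        return (alpha * acc + beta * C.astype(np.float32)).astype(C.dtype)

    As = np.random.rand(32, 16).astype(np.float16)
    Bs = np.random.rand(16, 32).astype(np.float16)
    C = np.random.rand(32, 32).astype(np.float16)
    D = block_mma_store(As, Bs, C, alpha=0.5, beta=2.0)

The plain store_result overload corresponds to the TransformNone epilogue, i.e. a cast of the accumulator with no C term.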
+
+#pragma once
+
+#include "mlx/backend/metal/kernels/steel/utils.h"
+
+///////////////////////////////////////////////////////////////////////////////
+// Transforms and Epilogues
+///////////////////////////////////////////////////////////////////////////////
+
+namespace mlx {
+namespace steel {
+
+template <typename OutT, typename InT>
+struct TransformNone {
+  static METAL_FUNC OutT apply(InT x) {
+    return static_cast<OutT>(x);
+  }
+
+  static METAL_FUNC OutT apply(InT x, OutT) {
+    return static_cast<OutT>(x);
+  }
+};
+
+template <typename OutT, typename InT>
+struct TransformAdd {
+  TransformAdd(const float, const float) {}
+
+  static METAL_FUNC OutT apply(InT x, OutT c) {
+    return static_cast<OutT>(x) + c;
+  }
+};
+
+template <typename OutT, typename InT>
+struct TransformAxpby {
+  const float alpha;
+  const float beta;
+
+  TransformAxpby(const float alpha_, const float beta_)
+      : alpha(alpha_), beta(beta_) {}
+
+  METAL_FUNC OutT apply(InT x, OutT c) const {
+    return static_cast<OutT>(x * alpha + (beta * c));
+  }
+};
+
+template <typename T>
+struct AccumHelper {
+  typedef float accum_type;
+};
+
+struct BlockSwizzle {
+  static METAL_FUNC int2
+  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
+    const int tid_x = (tid.x) >> swizzle_log;
+    const int tid_y =
+        ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
+    return int2(tid_x, tid_y);
+  }
+};
+
+} // namespace steel
+} // namespace mlx
\ No newline at end of file
diff --git a/mlx/backend/metal/kernels/steel/host.h b/mlx/backend/metal/kernels/steel/host.h
new file mode 100644
index 000000000..6fb4e54c9
--- /dev/null
+++ b/mlx/backend/metal/kernels/steel/host.h
@@ -0,0 +1,5 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/metal/kernels/steel/gemm/params.h"
\ No newline at end of file
diff --git a/mlx/backend/metal/kernels/steel/utils.h b/mlx/backend/metal/kernels/steel/utils.h
new file mode 100644
index 000000000..a4b6aa261
--- /dev/null
+++ b/mlx/backend/metal/kernels/steel/utils.h
@@ -0,0 +1,9 @@
+// Copyright © 2024 Apple Inc.
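Note on BlockSwizzle above: the swizzle folds groups of 2**swizzle_log neighbouring threadgroups along x into the y direction so they work on nearby output tiles; the host code currently launches with swizzle_log = 0, which makes it the identity. A small Python check of the mapping (function name is illustrative):

    def block_swizzle(tid_x, tid_y, swizzle_log):
        tile = 1 << swizzle_log
        return tid_x >> swizzle_log, (tid_y << swizzle_log) + (tid_x & (tile - 1))

    print(block_swizzle(5, 3, 0))  # (5, 3): identity when swizzle_log == 0
    print(block_swizzle(5, 3, 2))  # (1, 13): four x-neighbours share one x tile

On the host side the launch grid is adjusted to match: tm is divided by the tile (rounded up) and tn is multiplied by it, so the swizzled coordinates still cover every output tile.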
+ +#pragma once + +#include +#include "mlx/backend/metal/kernels/steel/host.h" + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") \ No newline at end of file diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index 0bce599d3..6d48f07a7 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -8,6 +8,7 @@ #include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/steel/host.h" #include "mlx/backend/metal/matmul.h" #include "mlx/backend/metal/mps/gemm.h" #include "mlx/backend/metal/utils.h" @@ -16,6 +17,10 @@ namespace mlx::core { +/////////////////////////////////////////////////////////////////////////////// +// MPS Matmul fallback +/////////////////////////////////////////////////////////////////////////////// + namespace { bool use_mps() { @@ -46,7 +51,9 @@ inline void mps_matmul( int ldb, bool transpose_a, bool transpose_b, - std::vector& copies) { + std::vector& copies, + float alpha = 1.0f, + float beta = 0.0f) { MPS::DataType mps_dtype = MPS::DataTypeFloat32; if (out.dtype() == float16) { @@ -121,7 +128,7 @@ inline void mps_matmul( auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc); auto kernel = MPS::MatrixMultiplication::alloc()->init( - d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0); + d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta); auto command_buffer = d.get_command_buffer(s.index); kernel->setBatchSize(batch_size_out); @@ -162,7 +169,7 @@ inline void mps_matmul( auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc); auto kernel = MPS::MatrixMultiplication::alloc()->init( - d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0); + d.mtl_device(), transpose_a, transpose_b, M, N, K, alpha, beta); auto command_buffer = d.get_command_buffer(s.index); for (int i = 0; i < batch_size_out; ++i) { @@ -186,7 +193,11 @@ inline void mps_matmul( } // namespace -void mlx_matmul( +/////////////////////////////////////////////////////////////////////////////// +// Steel matmul fallback +/////////////////////////////////////////////////////////////////////////////// + +void steel_matmul( const Stream& s, metal::Device& d, const array& a, @@ -201,6 +212,15 @@ void mlx_matmul( bool transpose_a, bool transpose_b, std::vector& copies) { + using namespace mlx::steel; + + // Coalesce (B, M, K) X (K, N) to (B*M, K) X (K, N) + if (batch_size_out > 1 && !transpose_a && + a.data_size() == batch_size_out * M * K && b.size() == K * N) { + M = M * batch_size_out; + batch_size_out = 1; + } + // Account for batch sizes and basic broadcasting int batch_size_a = a.data_size() / (M * K); int batch_size_b = b.data_size() / (K * N); @@ -209,11 +229,108 @@ void mlx_matmul( int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N; int matrix_stride_out = M * N; + ///////////////////////////////////////////////////////////////////////////// + // Split K specialization + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { + int bm = M < 40 ? 16 : 32; + int bn = N < 40 ? 16 : 32; + int bk = 16; + int wm = 2, wn = 2; + + int split_k_partitions = + _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 
8 : 16)); + int split_k_partition_stride = M * N; + int gemm_k_iterations = (K / bk) / split_k_partitions; + int split_k_partition_size = gemm_k_iterations * bk; + + array C_split({split_k_partitions, M, N}, float32, nullptr, {}); + C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); + copies.push_back(C_split); + + std::ostringstream kname; + kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" + << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn << "_MN_" + << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" + << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"; + + // Encode and dispatch gemm kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + GEMMSpiltKParams params{ + M, + N, + K, + lda, + ldb, + N, + tn, + tm, + split_k_partitions, + split_k_partition_stride, + split_k_partition_size, + gemm_k_iterations}; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); + + set_array_buffer(compute_encoder, a, 0); + set_array_buffer(compute_encoder, b, 1); + set_array_buffer(compute_encoder, C_split, 2); + + compute_encoder->setBytes(¶ms, sizeof(GEMMSpiltKParams), 3); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + // Do accum kernel + { + auto c_split_buf = + static_cast(C_split.buffer().ptr()); + const class MTL::Resource* const resources[1] = {c_split_buf}; + compute_encoder->memoryBarrier(resources, 1); + + auto kernel = d.get_kernel( + "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split)); + compute_encoder->setComputePipelineState(kernel); + + // Set the arguments for the kernel + set_array_buffer(compute_encoder, C_split, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2); + compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3); + compute_encoder->setBytes(&N, sizeof(int), 4); + + // Launch enough thread groups for each output + MTL::Size grid_dims = MTL::Size(N, M, 1); + MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); + } + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; + } + + ///////////////////////////////////////////////////////////////////////////// + // Regular kernel dispatch + // Determine dispatch kernel int bm = 32, bn = 32, bk = 16; int wm = 2, wn = 2; - if ((size_t)batch_size_out * M * N >= 2ul << 20) { + if ((size_t)batch_size_out * M * N >= 1ul << 20) { if (!transpose_a && transpose_b) { bm = 64; bn = (out.dtype() == float32) ? 64 : 32; @@ -224,10 +341,12 @@ void mlx_matmul( } } + // Prepare kernel name std::ostringstream kname; - kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') - << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm - << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_" + kname << "steel_gemm_" << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" + << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn << "_MN_" << ((M % bm == 0 && N % bn == 0) ? 
"t" : "n") << "aligned" << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"; @@ -236,34 +355,55 @@ void mlx_matmul( auto kernel = d.get_kernel(kname.str()); compute_encoder->setComputePipelineState(kernel); + // Use problem size to determine threadblock swizzle + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + // TODO: Explore device-based tuning for swizzle + int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); + + // Prepare steel matmul params + GEMMParams params{ + M, + N, + K, + lda, + ldb, + N, + tn, + tm, + matrix_stride_a, + matrix_stride_b, + matrix_stride_out, + swizzle_log, + (K / bk)}; + + // Prepare launch grid params + int tile = 1 << swizzle_log; + tm = (tm + tile - 1) / tile; + tn = tn * tile; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); + // Launch only 1 kernel in the case of simple batching / broadcasting if (batch_size_out == std::max(batch_size_a, batch_size_b) && (batch_size_a == batch_size_b || std::min(batch_size_a, batch_size_b) == 1)) { - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = - MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out); - set_array_buffer(compute_encoder, a, 0); set_array_buffer(compute_encoder, b, 1); set_array_buffer(compute_encoder, out, 2); - compute_encoder->setBytes(&M, sizeof(int), 3); - compute_encoder->setBytes(&N, sizeof(int), 4); - compute_encoder->setBytes(&K, sizeof(int), 5); - compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6); - compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7); - compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8); + compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 3); compute_encoder->dispatchThreadgroups(grid_dims, group_dims); - } else { // Other launch kernels with set offsets + } else { // Otherwise launch kernels with set offsets + + MTL::Size grid_dims_single = MTL::Size(tn, tm, 1); for (int i = 0; i < batch_size_out; ++i) { auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides()); auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides()); - MTL::Size group_dims = MTL::Size(32, wn, wm); - MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1); - auto a_buf = static_cast(a.buffer().ptr()); auto b_buf = static_cast(b.buffer().ptr()); auto out_buf = static_cast(out.buffer().ptr()); @@ -272,13 +412,8 @@ void mlx_matmul( compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1); compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2); - compute_encoder->setBytes(&M, sizeof(int), 3); - compute_encoder->setBytes(&N, sizeof(int), 4); - compute_encoder->setBytes(&K, sizeof(int), 5); - compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6); - compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7); - compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8); - compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + compute_encoder->setBytes(¶ms, sizeof(GEMMParams), 3); + compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims); } } @@ -300,6 +435,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { auto& a_pre = inputs[0]; auto& b_pre = inputs[1]; + ///////////////////////////////////////////////////////////////////////////// + // Init checks and prep + // Keep a vector with copies to be cleared in the completed buffer to release // the arrays std::vector copies; @@ -328,6 +466,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { auto batch_size_out = out.size() / (M * N); + 
///////////////////////////////////////////////////////////////////////////// + // Gemv specialization + // Route to gemv if needed if (std::min(M, N) == 1) { // Collect problem info @@ -433,10 +574,13 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { return; } - d.end_encoding(s.index); + ///////////////////////////////////////////////////////////////////////////// + // Gemm specialization if (use_mps()) { - mps_matmul( + d.end_encoding(s.index); + + return mps_matmul( s, d, a, @@ -451,10 +595,9 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { a_transposed, b_transposed, copies); - return; } - mlx_matmul( + return steel_matmul( s, d, a, @@ -471,4 +614,266 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { copies); } +void AddMM::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (!is_floating_point(out.dtype())) { + throw std::runtime_error( + "[matmul] Does not yet support non-floating point types."); + } + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + auto& c_pre = inputs[2]; + + ///////////////////////////////////////////////////////////////////////////// + // Init checks and prep + + // Keep a vector with copies to be cleared in the completed buffer to release + // the arrays + std::vector copies; + auto check_transpose = [&copies, &s](const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (stx == arr.shape(-1) && sty == 1) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + size_t stx = arr.shape(-1); + return std::make_tuple(false, stx, arr_copy); + } + }; + + auto [transpose_a, a_cols, a] = check_transpose(a_pre); + auto [transpose_b, b_cols, b] = check_transpose(b_pre); + + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + auto batch_size_out = out.size() / (M * N); + + array c = c_pre; + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + int matrix_stride_c = c.ndim() <= 2 ? 0 : c.strides()[c.ndim() - 3]; + + int lda = a_cols; + int ldb = b_cols; + + using namespace mlx::steel; + + // Account for batch sizes and basic broadcasting + int batch_size_a = a.data_size() / (M * K); + int batch_size_b = b.data_size() / (K * N); + + int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K; + int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N; + int matrix_stride_out = M * N; + + int _tm = M / 16; + int _tn = N / 16; + int _tk = K / 16; + + ///////////////////////////////////////////////////////////////////////////// + // Split K specialization + + if (batch_size_out == 1 && (_tm * _tn) <= 32 && _tk >= 8) { + int bm = M < 40 ? 16 : 32; + int bn = N < 40 ? 16 : 32; + int bk = 16; + int wm = 2, wn = 2; + + int split_k_partitions = + _tk < 16 ? 2 : (_tk < 32 ? 4 : (_tk < 64 ? 
8 : 16)); + int split_k_partition_stride = M * N; + int gemm_k_iterations = (K / bk) / split_k_partitions; + int split_k_partition_size = gemm_k_iterations * bk; + + array C_split({split_k_partitions, M, N}, float32, nullptr, {}); + C_split.set_data(allocator::malloc_or_wait(C_split.nbytes())); + copies.push_back(C_split); + + std::ostringstream kname; + kname << "steel_gemm_splitk_" << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" + << type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn << "_MN_" + << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" + << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"; + + // Encode and dispatch gemm kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + GEMMSpiltKParams params{ + M, + N, + K, + lda, + ldb, + N, + tn, + tm, + split_k_partitions, + split_k_partition_stride, + split_k_partition_size, + gemm_k_iterations}; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, split_k_partitions); + + set_array_buffer(compute_encoder, a, 0); + set_array_buffer(compute_encoder, b, 1); + set_array_buffer(compute_encoder, C_split, 2); + + compute_encoder->setBytes(¶ms, sizeof(GEMMSpiltKParams), 3); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + // Do accum kernel + { + auto kernel = d.get_kernel( + "steel_gemm_splitk_accum_" + type_to_name(out) + "_" + + type_to_name(C_split) + "_axpby"); + compute_encoder->setComputePipelineState(kernel); + + // Set the arguments for the kernel + set_array_buffer(compute_encoder, C_split, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&split_k_partitions, sizeof(int), 2); + compute_encoder->setBytes(&split_k_partition_stride, sizeof(int), 3); + compute_encoder->setBytes(&N, sizeof(int), 4); + set_array_buffer(compute_encoder, c, 5); + compute_encoder->setBytes(&ldc, sizeof(int), 6); + compute_encoder->setBytes(&fdc, sizeof(int), 7); + compute_encoder->setBytes(&alpha_, sizeof(float), 8); + compute_encoder->setBytes(&beta_, sizeof(float), 9); + + // Launch enough thread groups for each output + MTL::Size grid_dims = MTL::Size(N, M, 1); + MTL::Size group_dims = MTL::Size(std::min(1024, N * M), 1, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); + } + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; + } + + ///////////////////////////////////////////////////////////////////////////// + // Regular addmm dispatch + + // Determine dispatch kernel + int bm = 32, bn = 32, bk = 16; + int wm = 2, wn = 2; + + if ((size_t)batch_size_out * M * N >= 1ul << 20) { + if (!transpose_a && transpose_b) { + bm = 64; + bn = (out.dtype() == float32) ? 64 : 32; + bk = (out.dtype() == float32) ? 16 : 32; + } else { + bm = 64; + bn = 64; + } + } + + std::ostringstream kname; + kname << "steel_addmm_" << (transpose_a ? 't' : 'n') + << (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_" + << type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn << "_MN_" + << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" + << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned" + << ((alpha_ == 1. && beta_ == 1.) ? 
"_add" : "_axpby"); + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + // TODO: Explore device-based tuning for swizzle + int swizzle_log = 0; // tm >= 6 ? 3 : (tm <= 3 ? 0 : 2); + + GEMMAddMMParams params{ + M, + N, + K, + lda, + ldb, + ldc, + N, + tn, + tm, + matrix_stride_a, + matrix_stride_b, + matrix_stride_c, + matrix_stride_out, + swizzle_log, + (K / bk), + alpha_, + beta_, + fdc}; + + int tile = 1 << swizzle_log; + tm = (tm + tile - 1) / tile; + tn = tn * tile; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); + + // Launch only 1 kernel in the case of simple batching / broadcasting + if (batch_size_out == std::max(batch_size_a, batch_size_b) && + (batch_size_a == batch_size_b || + std::min(batch_size_a, batch_size_b) == 1)) { + set_array_buffer(compute_encoder, a, 0); + set_array_buffer(compute_encoder, b, 1); + set_array_buffer(compute_encoder, c, 2); + set_array_buffer(compute_encoder, out, 3); + + compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 4); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } else { // Otherwise launch kernels with set offsets + + MTL::Size grid_dims_single = MTL::Size(tn, tm, 1); + + for (int i = 0; i < batch_size_out; ++i) { + auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides()); + auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides()); + auto c_off = elem_to_loc(M * N * i, c.shape(), c.strides()); + + auto a_buf = static_cast(a.buffer().ptr()); + auto b_buf = static_cast(b.buffer().ptr()); + auto c_buf = static_cast(c.buffer().ptr()); + auto out_buf = static_cast(out.buffer().ptr()); + + compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0); + compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1); + compute_encoder->setBuffer(c_buf, c_off * c.itemsize(), 2); + compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 3); + + compute_encoder->setBytes(¶ms, sizeof(GEMMAddMMParams), 4); + compute_encoder->dispatchThreadgroups(grid_dims_single, group_dims); + } + } + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; +} + } // namespace mlx::core diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h index 78a0e7f26..1ebccf0e1 100644 --- a/mlx/backend/metal/matmul.h +++ b/mlx/backend/metal/matmul.h @@ -12,7 +12,7 @@ namespace mlx::core { -void mlx_matmul( +void steel_matmul( const Stream& s, metal::Device& d, const array& a, diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index a6902ba3a..899f7caff 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -17,6 +17,7 @@ namespace mlx::core { NO_GPU(Abs) NO_GPU(Add) +NO_GPU(AddMM) NO_GPU(Arange) NO_GPU(ArcCos) NO_GPU(ArcCosh) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 7fcf17403..a67e2f220 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3057,4 +3057,98 @@ array inner(const array& a, const array& b, StreamOrDevice s /* = {} */) { return tensordot(a, b, {{-1}, {-1}}, s); } +/** Compute D = beta * C + alpha * (A @ B) */ +array addmm( + array c, + array a, + array b, + const float& alpha /* = 1.f */, + const float& beta /* = 1.f */, + StreamOrDevice s /* = {} */) { + // Divert in the case of vector-matrix multiplication + // TODO: Add the 
needed specializtion + if (a.ndim() == 1 || b.ndim() == 1) { + array X = matmul(a, b, s); + array alpha_arr = array(alpha, X.dtype()); + array aX = multiply(alpha_arr, X, s); + + array beta_arr = array(beta, c.dtype()); + array bY = multiply(beta_arr, c, s); + return add(aX, bY, s); + } + + if (a.ndim() == 0 || b.ndim() == 0) { + throw std::invalid_argument( + "[addmm] Got 0 dimension input. Inputs must " + "have at least one dimension."); + } + + if (a.shape(-1) != b.shape(-2)) { + std::ostringstream msg; + msg << "[addmm] Last dimension of first input with shape " << a.shape() + << " must match second to last dimension of" + << " second input with shape " << b.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + // Type promotion + auto out_type = result_type({a, b, c}); + if (!is_floating_point(out_type) || is_complex(out_type)) { + std::ostringstream msg; + msg << "[addmm] Only real floating point types are supported but " + << c.dtype() << ", " << a.dtype() << " and " << b.dtype() + << " were provided which results in " << out_type + << ", which is not a real floating point type."; + throw std::invalid_argument(msg.str()); + } + + a = astype(a, out_type, s); + b = astype(b, out_type, s); + c = astype(c, out_type, s); + + // We can batch the multiplication by reshaping a + if (a.ndim() > 2 && b.ndim() == 2 && c.ndim() <= 1) { + std::vector out_shape = a.shape(); + a = reshape(a, {-1, out_shape.back()}, s); + out_shape.back() = b.shape(-1); + c = broadcast_to(c, {a.shape(0), b.shape(1)}, s); + auto out = array( + {a.shape(0), b.shape(1)}, + out_type, + std::make_unique(to_stream(s), alpha, beta), + {a, b, c}); + return reshape(out, out_shape, s); + } + + if (a.ndim() > 2 || b.ndim() > 2) { + std::vector bsx_a(a.shape().begin(), a.shape().end() - 2); + std::vector bsx_b(b.shape().begin(), b.shape().end() - 2); + auto inner_shape = broadcast_shapes(bsx_a, bsx_b); + + // Broadcast a + inner_shape.push_back(a.shape(-2)); + inner_shape.push_back(a.shape(-1)); + a = broadcast_to(a, inner_shape, s); + + // Broadcast b + *(inner_shape.end() - 2) = b.shape(-2); + *(inner_shape.end() - 1) = b.shape(-1); + b = broadcast_to(b, inner_shape, s); + } + + auto out_shape = a.shape(); + out_shape.back() = b.shape(-1); + + auto c_broadcast_shape = broadcast_shapes(c.shape(), out_shape); + c = broadcast_to(c, c_broadcast_shape, s); + + auto out = array( + out_shape, + out_type, + std::make_unique(to_stream(s), alpha, beta), + {a, b, c}); + + return out; +} + } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index 16f85e147..865617cd9 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1122,4 +1122,12 @@ std::unordered_map load_gguf( void save_gguf(std::string file, std::unordered_map a); +/** Compute D = beta * C + alpha * (A @ B) */ +array addmm( + array c, + array a, + array b, + const float& alpha = 1.f, + const float& beta = 1.f, + StreamOrDevice s = {}); } // namespace mlx::core diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 9cb85c67b..cc91d7147 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -124,6 +124,52 @@ std::pair, std::vector> Add::vmap( return {{add(a, b, stream())}, {to_ax}}; } +std::vector AddMM::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector&) { + std::vector vjps; + auto& cotan = cotangents[0]; + std::vector reorder(cotan.ndim()); + std::iota(reorder.begin(), reorder.end(), 0); + std::iter_swap(reorder.end() - 1, reorder.end() - 2); + for (auto arg : argnums) { + if (arg == 0) { + 
// M X N * (K X N).T -> M X K + auto cotan_scaled = cotan; + if (alpha_ != 1.) { + auto alpha_arr = array(alpha_, cotan.dtype()); + cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream())); + } + vjps.push_back(matmul( + cotan_scaled, transpose(primals[1], reorder, stream()), stream())); + } else if (arg == 1) { + // (M X K).T * M X N -> K X N + auto cotan_scaled = cotan; + if (alpha_ != 1.) { + auto alpha_arr = array(alpha_, cotan.dtype()); + cotan_scaled = (multiply(alpha_arr, cotan_scaled, stream())); + } + vjps.push_back(matmul( + transpose(primals[0], reorder, stream()), cotan_scaled, stream())); + } else { + auto cotan_scaled = cotan; + if (beta_ != 1.) { + auto beta_arr = array(beta_, cotan.dtype()); + cotan_scaled = (multiply(beta_arr, cotan_scaled, stream())); + } + vjps.push_back(cotan_scaled); + } + } + return vjps; +} + +bool AddMM::is_equivalent(const Primitive& other) const { + const AddMM& a_other = static_cast(other); + return (alpha_ == a_other.alpha_ && beta_ == a_other.beta_); +} + bool Arange::is_equivalent(const Primitive& other) const { const Arange& a_other = static_cast(other); return ( diff --git a/mlx/primitives.h b/mlx/primitives.h index 2ef30e4ad..c0a176417 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -171,6 +171,29 @@ class Add : public UnaryPrimitive { void eval(const std::vector& inputs, array& out); }; +class AddMM : public UnaryPrimitive { + public: + explicit AddMM(Stream stream, float alpha, float beta) + : UnaryPrimitive(stream), alpha_(alpha), beta_(beta){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_PRINT(AddMM) + + bool is_equivalent(const Primitive& other) const override; + + private: + const float alpha_; + const float beta_; +}; + class Arange : public UnaryPrimitive { public: explicit Arange(Stream stream, double start, double stop, double step) diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 77b340721..42d2fce79 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -63,9 +63,10 @@ def _extra_repr(self) -> str: return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}" def __call__(self, x: mx.array) -> mx.array: - x = x @ self.weight.T if "bias" in self: - x = x + self.bias + x = mx.addmm(self.bias, x, self.weight.T) + else: + x = x @ self.weight.T return x diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 9db273b40..a85c94d86 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3476,4 +3476,34 @@ void init_ops(py::module_& m) { Returns: result (array): The tiled array. )pbdoc"); + m.def( + "addmm", + &addmm, + "c"_a, + "a"_a, + "b"_a, + py::pos_only(), + "alpha"_a = 1.0f, + "beta"_a = 1.0f, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0, *, stream: Union[None, Stream, Device] = None) -> array + + Matrix multiplication with addition and optional scaling. + + Perform the (possibly batched) matrix multiplication of two arrays and add to the result + with optional scaling factors. + + Args: + c (array): Input array or scalar. + a (array): Input array or scalar. + b (array): Input array or scalar. 
+ alpha (float, optional): Scaling factor for the + matrix product of ``a`` and ``b`` (default: ``1``) + beta (float, optional): Scaling factor for ``c`` (default: ``1``) + + Returns: + array: ``alpha * (a @ b) + beta * c`` + )pbdoc"); } diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index bc8b27f51..fe2346fea 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -74,6 +74,7 @@ def test_matmul_shapes(self): if mx.default_device() == mx.gpu: shapes += [ (16, 768, 768, 128), + (1, 64, 64, 4096), ] for dtype in self.dtypes: @@ -444,3 +445,139 @@ def test_matrix_vector_edgecases(self): list(c_npy.shape), list(c_mlx.shape) ) self.assertTrue(np.array_equal(c_mlx, c_npy)) + + def test_addmm(self): + np.random.seed(0) + # Batched matmul + alpha = 0.5 + beta = 2.0 + + # Regular batched case + a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 16)).astype(np.float32) + + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)): + c_npy = np.ones(c_shape).astype(np.float32) + c_mlx = mx.array(c_npy) + + d_npy = alpha * (a_npy @ b_npy) + beta * c_npy + d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) + + self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) + self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) + + # Batched and transposed matmul + b_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) + b_mlx = mx.array(b_npy) + + for c_shape in ((1,), (32, 1, 128), (1, 128)): + c_npy = np.ones(c_shape).astype(np.float32) + c_mlx = mx.array(c_npy) + + b_np_t = np.transpose(b_npy, (0, 2, 1)) + b_mx_t = mx.transpose(b_mlx, (0, 2, 1)) + + d_npy = alpha * (a_npy @ b_np_t) + beta * c_npy + d_mlx = mx.addmm(c_mlx, a_mlx, b_mx_t, alpha, beta) + + self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) + self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) + + # # Batched matmul with simple broadcast + a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32) + + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + for c_shape in ((1,), (1, 16), (32, 1, 16), (1, 128, 16)): + c_npy = np.ones(c_shape).astype(np.float32) + c_mlx = mx.array(c_npy) + + d_npy = alpha * (a_npy @ b_npy) + beta * c_npy + d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) + + self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) + self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) + + # Matmul with vector + a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32) + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + for c_shape in ( + (1,), + (32, 128), + ): + c_npy = np.ones(c_shape).astype(np.float32) + c_mlx = mx.array(c_npy) + + d_npy = alpha * (a_npy @ b_npy) + beta * c_npy + d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) + + self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) + self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) + + # Split K specializtion + a_npy = np.random.normal(0.0, 1.0 / 128, (64, 4096)).astype(np.float32) + b_npy = np.random.normal(0.0, 1.0 / 128, (4096, 32)).astype(np.float32) + + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + for c_shape in ((1,), (1, 32), (64, 1), (64, 32)): + c_npy = np.ones(c_shape).astype(np.float32) + c_mlx = mx.array(c_npy) + + d_npy = alpha * (a_npy @ b_npy) + beta * c_npy + 
d_mlx = mx.addmm(c_mlx, a_mlx, b_mlx, alpha, beta) + + self.assertListEqual(list(d_npy.shape), list(d_mlx.shape)) + self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5)) + + def test_addmm_grad(self): + def make_ref_addmm(alpha, beta): + return lambda c, a, b: alpha * (a @ b) + beta * c + + def make_addmm(alpha, beta): + return lambda c, a, b: mx.addmm(c, a, b, alpha, beta) + + # B, M, N, K + shapes = ((1, 64, 32, 128), (4, 28, 24, 47), (1, 1, 24, 47)) + + alpha = 2.0 + beta = 0.5 + + f_test = make_addmm(alpha, beta) + f_ref = make_ref_addmm(alpha, beta) + + for B, M, N, K in shapes: + cotan = mx.ones((B, M, N)) + c = mx.random.normal((B, M, N)) + a = mx.random.normal((B, M, K)) + b = mx.random.normal((B, K, N)) + + out_ref, dout_ref = mx.vjp( + f_ref, + [c, a, b], + [ + cotan, + ], + ) + out_test, dout_test = mx.vjp( + f_test, + [c, a, b], + [ + cotan, + ], + ) + + self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-5).item()) + + for r, t in zip(dout_ref, dout_test): + self.assertListEqual(r.shape, t.shape) + self.assertTrue(mx.allclose(r, t, atol=1e-5).item())
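For reference, the behaviour test_addmm above exercises, as a short usage sketch (shapes here are illustrative):

    import mlx.core as mx

    a = mx.random.normal((8, 64))
    b = mx.random.normal((64, 32))
    c = mx.random.normal((32,))   # broadcasts against the (8, 32) product

    # alpha * (a @ b) + beta * c, dispatched on the GPU as one fused steel_addmm
    # kernel instead of a matmul followed by a separate scale-and-add.
    d = mx.addmm(c, a, b, 0.5, 2.0)
    print(mx.allclose(d, 0.5 * (a @ b) + 2.0 * c, atol=1e-5))

This is also what the nn.Linear change above relies on: x @ W.T + b becomes mx.addmm(self.bias, x, self.weight.T).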
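The "Split K specialization" case in test_addmm targets the new split-K dispatch path in matmul.cpp. A rough Python restatement of that heuristic (function name and dictionary keys are illustrative; it returns None when the regular tiled kernel is used instead):

    def splitk_plan(M, N, K, batch_size_out=1):
        tm, tn, tk = M // 16, N // 16, K // 16
        if not (batch_size_out == 1 and tm * tn <= 32 and tk >= 8):
            return None                      # regular steel_gemm path
        bm = 16 if M < 40 else 32
        bn = 16 if N < 40 else 32
        bk = 16
        partitions = 2 if tk < 16 else 4 if tk < 32 else 8 if tk < 64 else 16
        gemm_k_iterations = (K // bk) // partitions
        return {
            "block": (bm, bn, bk),
            "split_k_partitions": partitions,
            "split_k_partition_size": gemm_k_iterations * bk,
            "split_k_partition_stride": M * N,   # one (M, N) accumulator each
        }

    # The shape added to test_matmul_shapes, (1, 64, 64, 4096): small output,
    # large K, so K is split 16 ways and reduced by the accum kernel afterwards.
    print(splitk_plan(64, 64, 4096))

The partial products land in the (partitions, M, N) C_split buffer and steel_gemm_splitk_accum (or its _axpby variant for addmm) then reduces them into the output.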