Update GEMM (ml-explore#424)
* Organize and collect metal subroutine templates and elements in `metal/kernels/steel/`
* Update gemm elements for better performance 
* Add split-K specialization for gemm
* Add `addmm` primitive, op and bindings for fused matmul and bias addition (see the usage sketch below)
* Update tests and benchmarks as needed
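The `addmm` op added here fuses the bias addition into the GEMM call: `addmm(c, a, b, alpha, beta)` computes `alpha * (a @ b) + beta * c` without materializing the intermediate product separately. A minimal usage sketch (shapes and values are illustrative, not taken from this PR):

import mlx.core as mx

x = mx.random.normal((8, 256))    # activations
w = mx.random.normal((128, 256))  # linear-layer weight
b = mx.random.normal((128,))      # bias, broadcast against (8, 128)

# Fused path: alpha * (x @ w.T) + beta * b, with alpha = beta = 1 by default.
fused = mx.addmm(b, x, w.T)

# Unfused reference path.
reference = x @ w.T + b

print(mx.allclose(fused, reference))  # True, up to float tolerance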
jagrit06 authored Jan 17, 2024
1 parent 556cdf0 commit 78102a4
Showing 30 changed files with 2,361 additions and 646 deletions.
4 changes: 2 additions & 2 deletions benchmarks/python/blas/bench_gemm.py
@@ -166,13 +166,13 @@ def get_gflop_count(B, M, N, K):
dtypes = ("float32", "float16")
transposes = ("nn", "nt", "tn")
shapes = (
+    (16, 234, 768, 3072),
+    (1, 64, 64, 25344),
    (16, 1024, 1024, 1024),
    (1, 1024, 1024, 2048),
    (4, 1024, 1024, 4096),
    (4, 1024, 4096, 1024),
    (1, 4096, 4096, 4096),
-    (15, 1023, 1023, 1023),
-    (17, 1025, 1025, 1025),
)

for dtype in dtypes:
12 changes: 11 additions & 1 deletion benchmarks/python/comparative/bench_mlx.py
@@ -257,6 +257,13 @@ def linear(w, b, x):
    mx.eval(ys)


+def linear_fused(w, b, x):
+    ys = []
+    for i in range(10):
+        ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
+    mx.eval(ys)


def rope(x):
    *_, N, D = x.shape
    ys = []
@@ -397,7 +404,10 @@ def selu(x):
        print(bench(quant_matmul[args.benchmark], *xs))

    elif args.benchmark == "linear":
-        print(bench(linear, *xs))
+        if args.fused:
+            print(bench(linear_fused, *xs))
+        else:
+            print(bench(linear, *xs))

    elif args.benchmark == "sum_axis":
        print(bench(reduction, "sum", axis, x))
51 changes: 43 additions & 8 deletions mlx/backend/accelerate/matmul.cpp
@@ -29,12 +29,16 @@ std::tuple<bool, size_t, array> check_transpose(const array& arr) {
  }
}

-inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
+inline void matmul_cblas_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
  if (out.dtype() != float32) {
    throw std::runtime_error(
        "[matmul_cblas] on CPU currently only supports float32");
  }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto [a_transposed, lda, a] = check_transpose(a_pre);
  auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -50,21 +54,34 @@ inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
        M,
        N,
        K,
-        1.0f, // alpha
+        alpha, // alpha
        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
        lda,
        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
        ldb,
-        0.0f, // beta
+        beta, // beta
        out.data<float>() + M * N * i,
        out.shape(-1) // ldc
    );
  }
}
}

-inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
-  // TODO: Update to utilize BNNS broadcasting
+inline void matmul_cblas(const array& a_pre, const array& b_pre, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[matmul_cblas] on CPU currently only supports float32");
+  }
  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  return matmul_cblas_general(a_pre, b_pre, out);
+}
+
+inline void matmul_bnns_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
+  // TODO: Update to utilize BNNS broadcasting

  auto [a_transposed, lda, a] = check_transpose(a_pre);
  auto [b_transposed, ldb, b] = check_transpose(b_pre);
@@ -75,8 +92,8 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
  BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype());

  const BNNSLayerParametersBroadcastMatMul gemm_params{
-      /* float alpha = */ 1.0,
-      /* float beta = */ 0.0,
+      /* float alpha = */ alpha,
+      /* float beta = */ beta,
      /* bool transA = */ a_transposed,
      /* bool transB = */ b_transposed,
      /* bool quadratic = */ false,
@@ -157,6 +174,12 @@ inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
  BNNSFilterDestroy(bnns_filter);
}

+inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) {
+  // TODO: Update to utilize BNNS broadcasting
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  return matmul_bnns_general(a_pre, b_pre, out);
+}

} // namespace

void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -166,4 +189,16 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
  return matmul_bnns(inputs[0], inputs[1], out);
}

+void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  // Fill output with C
+  auto& c = inputs[2];
+  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
+  copy(c, out, ctype);
+
+  if (out.dtype() == float32) {
+    return matmul_cblas_general(inputs[0], inputs[1], out, alpha_, beta_);
+  }
+  return matmul_bnns_general(inputs[0], inputs[1], out, alpha_, beta_);
+}

} // namespace mlx::core
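The CPU `AddMM` above leans on standard BLAS semantics: the bias `C` is first copied into the output buffer, and the GEMM then accumulates on top of it, computing `out = alpha * (A @ B) + beta * out`, so `beta` scales the copied bias. A small numeric sketch of those semantics through the new op (hypothetical values; assumes the `alpha`/`beta` keywords exposed by this PR's binding):

import mlx.core as mx

a = mx.array([[1.0, 2.0], [3.0, 4.0]])
b = mx.array([[5.0, 6.0], [7.0, 8.0]])
c = mx.ones((2, 2))

# Mirrors the eval_cpu strategy: the output starts as a copy of c, then
# the GEMM accumulates alpha * (a @ b) onto beta * out.
out = mx.addmm(c, a, b, alpha=2.0, beta=3.0)

expected = 2.0 * (a @ b) + 3.0 * c  # [[41, 47], [89, 103]]
print(mx.allclose(out, expected))   # True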
52 changes: 38 additions & 14 deletions mlx/backend/common/default_primitives.cpp
@@ -98,16 +98,14 @@ DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT_MULTI(DivMod)

-void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
-  if (out.dtype() != float32) {
-    throw std::runtime_error(
-        "[Matmul::eval_cpu] Currently only supports float32.");
-  }
-  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-
-  auto& a_pre = inputs[0];
-  auto& b_pre = inputs[1];
+namespace {
+
+inline void matmul_common_general(
+    const array& a_pre,
+    const array& b_pre,
+    array& out,
+    float alpha = 1.0f,
+    float beta = 0.0f) {
  auto check_transpose = [](const array& arr) {
    auto stx = arr.strides()[arr.ndim() - 2];
    auto sty = arr.strides()[arr.ndim() - 1];
@@ -125,9 +123,10 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {

  auto [a_transposed, lda, a] = check_transpose(a_pre);
  auto [b_transposed, ldb, b] = check_transpose(b_pre);
-  int M = a.shape(-2);
-  int N = b.shape(-1);
-  int K = a.shape(-1);
+  size_t M = a.shape(-2);
+  size_t N = b.shape(-1);
+  size_t K = a.shape(-1);
+
  for (int i = 0; i < (a.size() / (M * K)); ++i) {
    cblas_sgemm(
        CblasRowMajor,
@@ -136,16 +135,41 @@ void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
        M,
        N,
        K,
-        1.0f, // alpha
+        alpha, // alpha
        a.data<float>() + elem_to_loc(M * K * i, a.shape(), a.strides()),
        lda,
        b.data<float>() + elem_to_loc(K * N * i, b.shape(), b.strides()),
        ldb,
-        0.0f, // beta
+        beta, // beta
        out.data<float>() + M * N * i,
        out.shape(-1) // ldc
    );
  }
}

+} // namespace
+
+void Matmul::eval_cpu(const std::vector<array>& inputs, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[Matmul::eval_cpu] Currently only supports float32.");
+  }
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  return matmul_common_general(inputs[0], inputs[1], out);
+}
+
+void AddMM::eval_cpu(const std::vector<array>& inputs, array& out) {
+  if (out.dtype() != float32) {
+    throw std::runtime_error(
+        "[AddMM::eval_cpu] Currently only supports float32.");
+  }
+
+  // Fill output with C
+  auto& c = inputs[2];
+  CopyType ctype = c.data_size() == 1 ? CopyType::Scalar : CopyType::General;
+  copy(c, out, ctype);
+
+  return matmul_common_general(inputs[0], inputs[1], out, alpha_, beta_);
+}

} // namespace mlx::core
6 changes: 3 additions & 3 deletions mlx/backend/metal/conv.cpp
@@ -70,7 +70,7 @@ void explicit_gemm_conv_1D_gpu(

  // Perform gemm
  std::vector<array> copies = {in_padded, in_strided};
-  mlx_matmul(
+  return steel_matmul(
      s,
      d,
      /*a = */ in_strided,
@@ -262,7 +262,7 @@ void explicit_gemm_conv_2D_gpu(

  // Perform gemm
  std::vector<array> copies = {in_padded, in_strided};
-  mlx_matmul(
+  return steel_matmul(
      s,
      d,
      /*a = */ in_strided,
@@ -411,7 +411,7 @@ void winograd_conv_2D_gpu(
  copies_w.push_back(out_wg);
  {
    std::vector<array> empty_copies;
-    mlx_matmul(
+    steel_matmul(
        s,
        d,
        /*a = */ inp_wg,
37 changes: 23 additions & 14 deletions mlx/backend/metal/kernels/CMakeLists.txt
@@ -18,7 +18,6 @@ set(
"binary_two"
"conv"
"copy"
"gemm"
"gemv"
"quantized"
"random"
@@ -30,33 +29,43 @@ set(
"indexing"
)

-function(build_kernel KERNEL)
-  set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
-  set(HEADERS_PADDED ${HEADERS})
-  if(${KERNEL} STREQUAL "gemm")
-    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/gemm.h)
-  endif()
-  if(${KERNEL} STREQUAL "conv")
-    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/gemm/conv.h)
-  endif()
+function(build_kernel_base TARGET SRCFILE DEPS)
  add_custom_command(
    COMMAND xcrun -sdk macosx metal -Wall -Wextra
            -fno-fast-math
            -c ${SRCFILE}
            -I${PROJECT_SOURCE_DIR}
-           -o ${KERNEL}.air
-    DEPENDS ${SRCFILE} ${HEADERS_PADDED}
-    OUTPUT ${KERNEL}.air
-    COMMENT "Building ${KERNEL}.air"
+           -o ${TARGET}.air
+    DEPENDS ${SRCFILE} ${DEPS}
+    OUTPUT ${TARGET}.air
+    COMMENT "Building ${TARGET}.air"
    VERBATIM
  )
+endfunction(build_kernel_base)
+
+function(build_kernel KERNEL)
+  set(SRCFILE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)
+  set(HEADERS_PADDED ${HEADERS})
+  if(${KERNEL} STREQUAL "conv")
+    set(HEADERS_PADDED ${HEADERS_PADDED} ${CMAKE_CURRENT_SOURCE_DIR}/conv.h)
+  endif()
+  build_kernel_base(${KERNEL} ${SRCFILE} "${HEADERS_PADDED}")
endfunction(build_kernel)

foreach(KERNEL ${KERNELS})
  build_kernel(${KERNEL})
  set(KERNEL_AIR ${KERNEL}.air ${KERNEL_AIR})
endforeach()

+file(GLOB_RECURSE STEEL_KERNELS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.metal)
+file(GLOB_RECURSE STEEL_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/steel/*.h)
+
+foreach(KERNEL ${STEEL_KERNELS})
+  cmake_path(GET KERNEL STEM TARGET)
+  build_kernel_base(${TARGET} ${KERNEL} "${STEEL_HEADERS}")
+  set(KERNEL_AIR ${TARGET}.air ${KERNEL_AIR})
+endforeach()

add_custom_command(
  OUTPUT ${MLX_METAL_PATH}/mlx.metallib
  COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib
File renamed without changes.
2 changes: 1 addition & 1 deletion mlx/backend/metal/kernels/conv.metal
@@ -5,7 +5,7 @@
#include "mlx/backend/metal/kernels/conv_params.h"
#include "mlx/backend/metal/kernels/bf16.h"

#include "mlx/backend/metal/kernels/gemm/conv.h"
#include "mlx/backend/metal/kernels/conv.h"

using namespace metal;
