Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
62062ff
Optimize `all_reduce` by porting the shared memory kernel of deepspee…
chunyuan-w Feb 20, 2025
d51b479
add norm kernels for CPU
mingfeima Mar 6, 2025
481b1ac
add silu_and_mul kernels for CPU
mingfeima Mar 6, 2025
a252d1a
add grouped topk kernels for CPU
mingfeima Mar 6, 2025
29668f7
add decode attention kernels for CPU
mingfeima Mar 6, 2025
6934dda
add fused moe kernels for CPU
mingfeima Mar 7, 2025
be7a5c5
decode attention: fix non-contiguous k_buffer and v_buffer
mingfeima Mar 7, 2025
ea61ccd
add fused moe kernels for CPU: part 2
mingfeima Mar 7, 2025
8c335dc
decode attention: fix seq_len req_pool_indices dtypes using int64_t
mingfeima Mar 7, 2025
e51ccd3
add extend attention kernel for CPU
mingfeima Mar 10, 2025
8fe2983
extend attention: fix bug in MLA when k_extend and k_buffer have diff…
mingfeima Mar 11, 2025
da4fe82
fused moe: fix when w13 has OC not multiples of 64
mingfeima Mar 11, 2025
b4db2d8
add weight packed linear for bfloat16/float16 on CPU
mingfeima Mar 11, 2025
d0b5fc2
weight_packed_linear: remove out as input parameter
mingfeima Mar 12, 2025
235c6cd
convert_weight_packed: use int64_t for stride to avoid overflow
mingfeima Mar 15, 2025
33a7009
add int8_scaled_mm for int8 W8A8 on CPU
mingfeima Mar 17, 2025
0df5370
add biased_grouped_topk for CPU
mingfeima Mar 18, 2025
ba23156
Add record_function for profiling (#14)
yanbing-j Mar 18, 2025
ebc341a
moe: apply avx512-bf16 tinygemm when M is small
mingfeima Mar 19, 2025
b5de4d0
grouped_topk: add support for num_experts = 160, config from DeepSeekV2
mingfeima Mar 19, 2025
1e7ef35
moe: change indexing from int32 to int64 to avoid overflow
mingfeima Mar 20, 2025
33b8be8
int8_scaled_mm: move dequant to per_token_quant_int8
mingfeima Mar 21, 2025
5edc328
Add fused_moe int8 w8a8 support for CPU
mingfeima Mar 22, 2025
4096183
fused_add_rmsnorm: replace at::zeros with at::empty
mingfeima Mar 24, 2025
8556c74
mv cpu source files from src/sgl-kernel/csrc/cpu to csrc/cpu
mingfeima Mar 25, 2025
a40b671
biased_grouped_topk: fix correction_bias dtype, should be bfloat16 in…
mingfeima Mar 25, 2025
4497115
Add bmm AMX and avx512-bf16 kernels on CPU
mingfeima Mar 25, 2025
47db4b7
Add RECORD_FUNCTION in bmm_cpu, int8 mm, per token quant (#22)
chunyuan-w Mar 25, 2025
6884f38
Add rope.cpp and torch_extension_cpu.cpp from 47bc8df
mingfeima Apr 8, 2025
c3e4c89
Add shared_expert for intel AMX
mingfeima Mar 27, 2025
33b2d6d
Move empty from python to C++ (#25)
yanbing-j Mar 27, 2025
460382b
int8_scaled_mm: fuse quant A to reduce python overhead.
mingfeima Mar 28, 2025
2dcc0de
decode_attention: use int64_t for indexing to avoid overflow in stride
mingfeima Mar 28, 2025
cc65cbe
use setup_cpu.py for now
mingfeima Apr 8, 2025
d9ad48b
remove debug print
mingfeima Apr 8, 2025
42ec13e
Add qkv_proj_with_rope which fused qkv projection with segment gemm a…
mingfeima Apr 1, 2025
add5dc4
optimize all_gather (#33)
blzheng Apr 1, 2025
2ad92c6
apply pre-commit format changes
mingfeima Apr 8, 2025
80e6fe2
Merge branch 'main' into pr_native_kernels_for_cpu
zhyncs Apr 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions sgl-kernel/csrc/cpu/activation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#include "common.h"
#include "vec.h"

namespace {

// Row-wise "activation and multiply": for each token row, the input holds
// [x | y] (each of length `dim`); the output row is f(x) * y elementwise.
// `f` is the scalar activation and `vf` its SIMD counterpart — they must
// compute the same function so the vector main loop and scalar tail agree.
//   output : [num_tokens, dim]
//   input  : [num_tokens, 2 * dim] — indexed as i * 2 * dim, so rows are
//            assumed fully contiguous (TODO confirm callers enforce this)
template <typename scalar_t, typename func_t, typename vec_func_t>
void act_and_mul_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    int64_t num_tokens,
    int64_t dim,
    const func_t& f,
    const vec_func_t& vf) {
  // bVec carries reduced-precision lanes (bf16/fp16); fVec carries float lanes
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int64_t kVecSize = bVec::size();
  // grain size 0: let ATen choose how to chunk the token range across threads
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs: x is the first half of the row, y the second half
      const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
      const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
      scalar_t* __restrict__ output_ptr = output + i * dim;

      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= dim - kVecSize; d += kVecSize) {
        // widen one reduced-precision vector into two float vectors
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        bVec y_bvec = bVec::loadu(input_other_ptr + d);
        fVec y_fvec0, y_fvec1;
        std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);

        // apply the activation in float precision, then gate with y
        x_fvec0 = vf(x_fvec0);
        x_fvec1 = vf(x_fvec1);

        x_fvec0 = x_fvec0 * y_fvec0;
        x_fvec1 = x_fvec1 * y_fvec1;

        // narrow back to scalar_t and store
        x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        x_bvec.store(output_ptr + d);
      }
      // scalar tail: remaining dim % kVecSize elements
#pragma GCC unroll 4
      for (; d < dim; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float y_val = static_cast<float>(input_other_ptr[d]);
        output_ptr[d] = f(x_val) * y_val;
      }
    }
  });
}

} // anonymous namespace

// input : {num_tokens, 2 * d}
// output : {num_tokens, d}
//
// Computes out[..., :d] = silu(input[..., :d]) * input[..., d:], where
// silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
  // The kernel indexes rows as `input + i * 2 * dim`, i.e. it assumes a fully
  // contiguous input; validate that, and that the last dim splits evenly into
  // the [x | y] halves. Previously unchecked, which silently produced garbage
  // for non-contiguous or odd-sized inputs.
  CHECK_INPUT(input);
  auto sizes = input.sizes().vec();
  int64_t last_dim = input.ndimension() - 1;
  TORCH_CHECK(sizes[last_dim] % 2 == 0, "silu_and_mul: expect last dim to be even, got ", sizes[last_dim]);
  int64_t d = sizes[last_dim] / 2;
  sizes[last_dim] = d;
  // collapse all leading dims into a single "token" dim; guard the d == 0
  // case to avoid dividing by a zero-sized last dim
  int64_t num_tokens = d == 0 ? 0 : input.numel() / input.size(-1);
  at::Tensor out = at::empty(sizes, input.options());

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    act_and_mul_kernel_impl(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        num_tokens,
        d,
        // scalar and vectorized SiLU; both must match (see kernel contract)
        [](float x) { return x / (1.f + std::exp(-x)); },
        [](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
  });
  return out;
}
122 changes: 122 additions & 0 deletions sgl-kernel/csrc/cpu/bmm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"

namespace {

// Batched blocked matmul: out[b] = mat1[b] @ mat2[b]^T, tiled in
// BLOCK_M x BLOCK_N output blocks and parallelized over [B, MB, NB].
//   out  : [B, M, N] — batch/row strides passed explicitly (may be padded)
//   mat1 : [B, M, K] — batch/row strides passed explicitly
//   mat2 : [B, N, K] — assumed contiguous (strides derived below)
// `scale` is accepted but not used here — presumably reserved for future
// quantized/fp8 weights; TODO confirm when scaled bmm lands.
template <typename scalar_t>
void bmm_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ mat1,
    const scalar_t* __restrict__ mat2,
    int64_t B,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t mat1_strideB,
    int64_t mat1_strideM,
    int64_t out_strideB,
    int64_t out_strideM,
    float scale = 0.f) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  // number of tiles along M and N (last tile may be partial)
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  // mat2 contiguous in [B, N, K]
  int64_t mat2_strideB = N * K;
  int64_t mat2_strideN = K;

  // whether the brgemm path applies is decided from M (see can_use_brgemm)
  const bool use_brgemm = can_use_brgemm<scalar_t>(M);

  // parallel on [B, MB, NB]
  at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) {
    // decompose the flat start index into (batch, m-tile, n-tile) coords
    int64_t bs{0}, mb{0}, nb{0};
    data_index_init(begin, bs, B, mb, MB, nb, NB);

    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

    for (int i = begin; i < end; ++i) {
      UNUSED(i);
      // clamp the last tile in each dimension to the tensor boundary
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(N - nb_start, BLOCK_N);

      tinygemm_kernel<scalar_t>(
          /* A */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM,
          /* B */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */,
          /* C */ out + bs * out_strideB + mb_start * out_strideM + nb_start,
          /* Ctmp*/ Ctmp,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ mat1_strideM,
          /* ldb */ nb_size,
          /* ldc */ out_strideM,
          /* brg */ use_brgemm);

      // move to the next index
      data_index_step(bs, B, mb, MB, nb, NB);
    }

    if (use_brgemm) {
      // release per-thread brgemm state once this thread's tiles are done
      at::native::cpublas::brgemm_release();
    }
  });
}

} // anonymous namespace

// mat1 : [B, M, K]
// mat2 : [B, N, K] or [B, OC, IC]
// out : [B, M, N]
// scale: [] 0-dim tensor for per tensor quant
//
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale) {
  RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));

  // Validate layouts BEFORE packing: convert_weight_packed reads mat2, so
  // mat2 must be checked first (previously it was packed unchecked).
  // input and out could be non-contiguous, but need contiguous last dims;
  // weight needs to be contiguous in [OC, IC] order
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(out);
  CHECK_INPUT(mat2);
  CHECK_DIM(3, out);
  CHECK_DIM(3, mat1);
  CHECK_DIM(3, mat2);

  // repack the weight to VNNI layout unless the caller already did
  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);

  int64_t B = mat1.size(0);
  int64_t M = mat1.size(1);
  int64_t N = mat2.size(1);
  int64_t K = mat1.size(2);

  TORCH_CHECK(!scale.has_value(), "bmm: do not support fp8 weight for now.");
  TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be 32x.");

  int64_t mat1_strideB = mat1.stride(0);
  int64_t mat1_strideM = mat1.stride(1);
  int64_t out_strideB = out.stride(0);
  int64_t out_strideM = out.stride(1);

  // check shapes, including out's last dim (was previously unchecked)
  TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!");
  TORCH_CHECK(out.size(0) == B && out.size(1) == M && out.size(2) == N, "bmm: out shape mismatch!");

  AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] {
    bmm_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        mat1.data_ptr<scalar_t>(),
        packed_w.data_ptr<scalar_t>(),
        B,
        M,
        N,
        K,
        mat1_strideB,
        mat1_strideM,
        out_strideB,
        out_strideM);
  });
}
164 changes: 164 additions & 0 deletions sgl-kernel/csrc/cpu/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/record_function.h>

#if defined(_OPENMP)
#include <omp.h>
#endif

namespace {

// Dispatch a runtime bool into a constexpr bool named BOOL_NAME, so the
// lambda body can branch on it at compile time (e.g. via if constexpr).
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
  [&] {                                          \
    if (BOOL_V) {                                \
      constexpr bool BOOL_NAME = true;           \
      return __VA_ARGS__();                      \
    } else {                                     \
      constexpr bool BOOL_NAME = false;          \
      return __VA_ARGS__();                      \
    }                                            \
  }()

// dispatch: bfloat16, float16, int8_t — the dtypes supported as packed weights
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...)                     \
  [&] {                                                          \
    switch (TYPE) {                                              \
      case at::ScalarType::BFloat16: {                           \
        using packed_t = at::BFloat16;                           \
        return __VA_ARGS__();                                    \
      }                                                          \
      case at::ScalarType::Half: {                               \
        using packed_t = at::Half;                               \
        return __VA_ARGS__();                                    \
      }                                                          \
      case at::ScalarType::Char: {                               \
        using packed_t = int8_t;                                 \
        return __VA_ARGS__();                                    \
      }                                                          \
      default:                                                   \
        TORCH_CHECK(false, "Unsupported packed data type.\n");   \
    }                                                            \
  }()

// silence unused-variable warnings for intentionally unused values
#define UNUSED(x) (void)(x)

#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")

#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// fixed message: added the missing space after the tensor name (was
// "mat1must be...") and corrected the "dimention" typo
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")

#define CHECK_INPUT(x) \
  CHECK_CPU(x);        \
  CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
  CHECK_CPU(x);                            \
  CHECK_LAST_DIM_CONTIGUOUS(x)

#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

// parallel routines
// default work granularity used when partitioning loops
constexpr int GRAIN_SIZE = 1024;

// Integer ceiling division: smallest q such that q * y >= x
// (for non-negative x and positive y).
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) {
  const T biased = x + y - T(1);  // bias so truncating division rounds up
  return biased / y;
}

// Partition [0, n) across `nth` workers; worker `ith` receives the half-open
// range [n_start, n_end). Follows the PyTorch ATen pattern: every worker gets
// a chunk of ceil(n / nth) items, with the final chunk clamped to n.
// (An alternative onednn-style partition previously lived here under #if 0.)
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
  const T chunk = (n + nth - 1) / nth;  // ceil(n / nth), same as div_up
  n_start = ith * chunk;
  n_end = std::min(n_start + chunk, n);
}

// Split [0, n) into one contiguous chunk per OpenMP thread and invoke
// f(chunk_begin, chunk_end) on each; degrades to a single serial call
// f(0, n) when built without OpenMP.
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel
  {
    const int num_threads = omp_get_num_threads();
    const int thread_id = omp_get_thread_num();
    int chunk_begin, chunk_end;
    balance211(n, num_threads, thread_id, chunk_begin, chunk_end);
    f(chunk_begin, chunk_end);
  }
#else
  f(0, n);
#endif
}

// data indexing for dimension collapse
//
// data_index_init(offset, i0, D0, i1, D1, ...) decomposes the flattened
// index `offset` over dims (D0, D1, ...) listed outermost-first, storing
// each coordinate into the paired reference. Returns the residual
// offset / (D0 * D1 * ...), which is 0 for in-range offsets.
template <typename T>
inline T data_index_init(T offset) {
  // base case: no dims remain — hand the residual offset back up the chain
  return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T& idx, const T& dim_size, Args&&... args) {
  // strip the inner dims first, then peel off this dim's coordinate
  offset = data_index_init(offset, std::forward<Args>(args)...);
  idx = offset % dim_size;
  return offset / dim_size;
}

// data_index_step(i0, D0, i1, D1, ...) advances the multi-dimensional
// counter by one in row-major order (innermost dim fastest). Returns true
// iff the whole counter wrapped back to all zeros.
inline bool data_index_step() {
  // base case: the empty inner counter always carries into its parent
  return true;
}

template <typename T, typename... Args>
inline bool data_index_step(T& idx, const T& dim_size, Args&&... args) {
  // only advance this dim when every inner dim wrapped around
  if (!data_index_step(std::forward<Args>(args)...)) {
    return false;
  }
  idx = (idx + 1 == dim_size) ? T(0) : idx + 1;
  return idx == 0;
}

// forced unroll for perf critical path

#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif

// Compile-time unrolled loop: Unroll<n>{}(f, args...) expands to
//   f(0_c, args...); f(1_c, args...); ...; f((n-1)_c, args...);
// in ascending order, where each index is passed as
// std::integral_constant<int, k> so the callee can use it as a constant.
template <int n>
struct Unroll {
  static_assert(n >= 1, "Unroll requires at least one iteration");

  template <typename Func, typename... Args>
  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    if constexpr (n > 1) {
      // emit iterations [0, n-1) first, preserving ascending order
      Unroll<n - 1>{}(f, args...);
    }
    f(std::integral_constant<int, n - 1>{}, args...);
  }
};

} // anonymous namespace
Loading
Loading