#include "common.h"
#include "vec.h"

namespace {

template <typename scalar_t>
void rope_kernel_impl(
    scalar_t* __restrict__ q_pe_out,
    scalar_t* __restrict__ k_pe_out,
    int64_t* __restrict__ t_pos,
    scalar_t* __restrict__ q_pe,
    scalar_t* __restrict__ k_pe,
    scalar_t* __restrict__ t_emb_pos,
    int64_t seq_len,
    int64_t num_head,
    int64_t rotary_dim,
    int64_t HR,
    int64_t q_pe_stride_s,
    int64_t out_stride_qs,
    int64_t out_stride_ks,
    int64_t HK,
    int64_t k_pe_stride_s,
    int64_t q_pe_stride_n,
    int64_t out_stride_qn) {
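  // Each row of t_emb_pos packs the embedding table as [cos | sin]: the cos
  // values occupy the first HR/2 entries and the sin values the last HR/2,
  // so COFF below is the offset of the sin half within a row.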
  int64_t COFF = HR / 2;
  at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t seq{0}, head_id{0};
    data_index_init(begin, seq, seq_len, head_id, num_head);
    for (int64_t i = begin; i < end; ++i) {
      int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
      int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
      int64_t out_offset_k = seq * out_stride_ks;
      // step 0) get the rotary position embedding for the current position
      int64_t p = t_pos[seq];
      scalar_t* sin_start = t_emb_pos + p * HR + COFF;
      scalar_t* cos_start = t_emb_pos + p * HR;
      // step 1) apply_rotary_pos_emb to the rotary_dim elements in every
      // head of the query: each adjacent pair (in1, in2) is rotated by the
      // angle for this position, i.e. out1 = in1*cos - in2*sin and
      // out2 = in2*cos + in1*sin
      for (int64_t h = 0; h < rotary_dim; h += 2) {
        scalar_t cos = cos_start[h >> 1];
        scalar_t sin = sin_start[h >> 1];
        scalar_t in1 = q_pe[in_offset_q + h];
        scalar_t in2 = q_pe[in_offset_q + h + 1];
        scalar_t out1 = in1 * cos - in2 * sin;
        scalar_t out2 = in2 * cos + in1 * sin;
        q_pe_out[out_offset_q + h] = out1;
        q_pe_out[out_offset_q + h + 1] = out2;
      }
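      // step 2) apply the same rotation to the key; k_pe has a single
      // shared head (k_pe.size(1) == 1), so there is no head offset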
      int64_t k_pe_offset = seq * k_pe_stride_s;
      for (int64_t h = 0; h < HK; h += 2) {
        scalar_t cos = cos_start[h >> 1];
        scalar_t sin = sin_start[h >> 1];
        scalar_t in1_k = k_pe[k_pe_offset + h];
        scalar_t in2_k = k_pe[k_pe_offset + h + 1];
        scalar_t out1_k = in1_k * cos - in2_k * sin;
        scalar_t out2_k = in2_k * cos + in1_k * sin;
        k_pe_out[out_offset_k + h] = out1_k;
        k_pe_out[out_offset_k + h + 1] = out2_k;
      }
      // move to the next index
      data_index_step(seq, seq_len, head_id, num_head);
    }
  });
}
}  // namespace

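// Applies rotary position embedding to q_pe and k_pe. Expected shapes
// (enforced by the checks below):
//   t_pos:     [seq_len]                        int64 position ids
//   q_pe:      [seq_len, num_head, rotary_dim]
//   k_pe:      [seq_len, 1, rotary_dim]
//   t_emb_pos: [max_position, rotary_dim]       cos/sin lookup table
// Returns the rotated (q_pe_out, k_pe_out) as newly allocated tensors.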
std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
  RECORD_FUNCTION(
      "sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));
  CHECK_INPUT(t_pos);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
  CHECK_INPUT(t_emb_pos);
  CHECK_DIM(1, t_pos);
  CHECK_DIM(3, q_pe);
  CHECK_DIM(3, k_pe);
  CHECK_DIM(2, t_emb_pos);

  int64_t seq_len = q_pe.size(0);
  int64_t num_head = q_pe.size(1);
  int64_t rotary_dim = q_pe.size(2);
  int64_t HK = k_pe.size(2);
  int64_t HR = t_emb_pos.size(1);
  CHECK_EQ(HR, rotary_dim);
  CHECK_EQ(k_pe.size(0), seq_len);
  CHECK_EQ(k_pe.size(1), 1);
  CHECK_EQ(t_pos.size(0), seq_len);
  CHECK_EQ(HK, rotary_dim);

  at::Tensor q_pe_out = at::empty_like(q_pe);
  at::Tensor k_pe_out = at::empty_like(k_pe);
  int64_t q_pe_stride_s = q_pe.stride(0);
  int64_t q_pe_stride_n = q_pe.stride(1);
  int64_t k_pe_stride_s = k_pe.stride(0);
  int64_t out_stride_qs = q_pe_out.stride(0);
  int64_t out_stride_qn = q_pe_out.stride(1);
  int64_t out_stride_ks = k_pe_out.stride(0);

  const auto input_dtype = q_pe.scalar_type();
  TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
  TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
  TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
    rope_kernel_impl<scalar_t>(
        q_pe_out.data_ptr<scalar_t>(),
        k_pe_out.data_ptr<scalar_t>(),
        t_pos.data_ptr<int64_t>(),
        q_pe.data_ptr<scalar_t>(),
        k_pe.data_ptr<scalar_t>(),
        t_emb_pos.data_ptr<scalar_t>(),
        seq_len,
        num_head,
        rotary_dim,
        HR,
        q_pe_stride_s,
        out_stride_qs,
        out_stride_ks,
        HK,
        k_pe_stride_s,
        q_pe_stride_n,
        out_stride_qn);
  });
  return std::make_tuple(q_pe_out, k_pe_out);
}
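
// Usage sketch (illustrative only; shapes and dtypes follow the checks above,
// and the concrete sizes here are made up for the example -- a real t_emb_pos
// would hold precomputed cos/sin values rather than random data):
//   at::Tensor t_pos = at::arange(4, at::kLong);                  // [seq_len]
//   at::Tensor q_pe = at::randn({4, 8, 64}, at::kBFloat16);       // [seq_len, num_head, rotary_dim]
//   at::Tensor k_pe = at::randn({4, 1, 64}, at::kBFloat16);       // [seq_len, 1, rotary_dim]
//   at::Tensor t_emb_pos = at::randn({4096, 64}, at::kBFloat16);  // [max_position, rotary_dim], [cos | sin]
//   auto [q_out, k_out] = rotary_position_embedding_cpu(t_pos, q_pe, k_pe, t_emb_pos);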