Kernel EP-01
What is a kernel?
A kernel is a highly optimized function that performs a core numerical operation on large tensors. It is written close to the hardware, with performance as the primary goal. Examples:
- Matrix multiplication
- Softmax
- LayerNorm
- Convolution
- Attention
These operations run billions of times and dominate the runtime across a model's layers, so it is important that they are efficient.
Why do kernels exist?
Consider a naive multiplication of two N×N matrices, C = A × B:
C = [[0.0] * N for _ in range(N)]   # C must start as all zeros
for i in range(N):
    for j in range(N):
        for k in range(N):
            C[i][j] += A[i][k] * B[k][j]
- Ignores cache
- Ignores vector units
- Ignores GPU
- Re-reads memory excessively
Each element of A and B is used N times, but the loop does nothing to keep it close to the compute units between uses, so it may be reloaded from memory many times. How can we do better?
for i in rows:
    for j in cols:
        C[i, j] = dot(A[i, :], B[:, j])
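In this NumPy-style version, the innermost loop is replaced by a vectorized dot product instead of one scalar multiply-add at a time. The elements of a row and a column can be loaded and processed in batches by the vector (SIMD) units rather than one by one.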
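A small NumPy sketch comparing the triple loop, the one-dot-product-per-element version, and NumPy's built-in `A @ B` (which hands the whole product to an optimized BLAS routine). All three compute the same result; what differs is how much work each step does. The function names are illustrative, not from any library.

```python
import numpy as np


def matmul_triple_loop(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Naive GEMM: one scalar multiply-add per innermost iteration."""
    rows, inner = A.shape
    _, cols = B.shape
    C = np.zeros((rows, cols))
    for i in range(rows):
        for j in range(cols):
            for k in range(inner):
                C[i, j] += A[i, k] * B[k, j]
    return C


def matmul_dot(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """One vectorized dot product (row of A times column of B) per output element."""
    rows, _ = A.shape
    _, cols = B.shape
    C = np.zeros((rows, cols))
    for i in range(rows):
        for j in range(cols):
            C[i, j] = np.dot(A[i, :], B[:, j])
    return C


rng = np.random.default_rng(0)
A = rng.standard_normal((32, 16))
B = rng.standard_normal((16, 24))
print(np.allclose(matmul_triple_loop(A, B), A @ B))  # True
print(np.allclose(matmul_dot(A, B), A @ B))          # True
```

Timing the three on larger inputs is the easiest way to see the gap; the results agree, but the loop versions are far slower.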
Computation & Memory
Computation and memory are the two fundamental limits. Any kernel is bounded by one of them:
- Compute throughput: how many floating-point operations (FLOPs) the hardware can perform per second.
- Memory bandwidth: how fast data can be moved from memory to the compute units.
The key idea behind kernel design is that “hardware can do math much faster than it can move data.”
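One common way to make this concrete is the roofline model (a standard performance model, not something specific to these notes). Attainable throughput is the smaller of the peak compute rate and the arithmetic intensity (FLOPs performed per byte moved) times the memory bandwidth. The hardware numbers below are hypothetical and chosen only for illustration.

```python
def attainable_flops(ai: float, peak_flops: float, bandwidth: float) -> float:
    """Roofline model: throughput is capped by compute or by memory traffic.

    ai         -- arithmetic intensity, FLOPs per byte moved
    peak_flops -- peak compute throughput, FLOP/s
    bandwidth  -- memory bandwidth, bytes/s
    """
    return min(peak_flops, ai * bandwidth)


def limiting_factor(ai: float, peak_flops: float, bandwidth: float) -> str:
    """Report which of the two limits binds for a kernel with intensity `ai`."""
    return "compute-bound" if ai * bandwidth >= peak_flops else "memory-bound"


# Hypothetical machine: 10 TFLOP/s peak compute, 500 GB/s memory bandwidth.
PEAK = 10e12
BW = 500e9
print(PEAK / BW)  # 20.0: the FLOP/byte needed before compute becomes the limit

for ai in (0.25, 20.0, 100.0):
    tflops = attainable_flops(ai, PEAK, BW) / 1e12
    print(f"AI={ai:6.2f}  {tflops:6.3f} TFLOP/s  {limiting_factor(ai, PEAK, BW)}")
```

Below the ridge point (here 20 FLOP/byte) adding more math units would not help; only moving less data would.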
Let’s take matrix multiplication (GEMM) as an example. For an N×N matrix multiplication:
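- FLOPs: $2N^3$ ($N$ rows in A, $N$ columns in B, and one multiply plus one add per element of each row/column pair)
- Data size: $3N^2$ elements (A, B, and C each hold $N^2$)
Arithmetic intensity (ratio of math to memory traffic):
$$\text{AI} = \frac{2N^3}{3N^2} = \frac{2N}{3}$$
which is $O(N)$. Therefore, GEMM is compute-bound, because its computation cost grows faster than its memory cost.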
However, naive GEMM is memory-bound: because it re-loads operands every time it uses them, its memory traffic grows as $O(N^3)$ as well, so its arithmetic intensity drops to $O(1)$.
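To make the two scaling laws concrete, here is a small sketch that counts FLOPs and elements moved. The "naive" count assumes the worst case of zero reuse (every multiply-add reloads both operands, and C[i][j] stays in a register until it is written once); real caches catch some reuse even in the naive loop, so this is an idealized bound.

```python
def gemm_flops(n: int) -> int:
    """FLOPs of an n x n GEMM: n*n outputs, each with n multiplies and n adds."""
    return 2 * n ** 3


def ideal_traffic(n: int) -> int:
    """Elements moved with perfect reuse: read A and B once, write C once."""
    return 3 * n ** 2


def naive_traffic(n: int) -> int:
    """Elements moved with no reuse (assumed worst case): every multiply-add
    reloads A[i][k] and B[k][j]; C[i][j] is accumulated in a register and written once."""
    return 2 * n ** 3 + n ** 2


for n in (64, 1024):
    ideal = gemm_flops(n) / ideal_traffic(n)   # equals 2n/3, grows with n
    naive = gemm_flops(n) / naive_traffic(n)   # stays just below 1
    print(f"N={n:5d}  ideal AI={ideal:8.2f}  naive AI={naive:.3f}  FLOP/element")
```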
What happens if matrices don’t fit in cache?
We need to check whether the cache is big enough to hold the data being reused, for example both a row of A and a column of B. In the worst case they repeatedly evict each other, so no reuse actually happens and every access goes back to main memory.
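The standard remedy is tiling (also called blocking): compute C in small blocks so that the pieces of A and B being reused fit in cache together. Below is a minimal NumPy sketch, assuming a square tile size `tile` (the default 32 is an arbitrary placeholder, not a tuned value). NumPy's `@` on each block stands in for the small inner kernel that a real implementation would hand-optimize.

```python
import numpy as np


def tiled_matmul(A: np.ndarray, B: np.ndarray, tile: int = 32) -> np.ndarray:
    """Blocked GEMM: C = A @ B computed one (tile x tile) block at a time.

    Each block of A and B is loaded once per block product and reused across
    a whole block of outputs, which is the reuse a cache-aware kernel relies on.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.result_type(A, B))
    for i0 in range(0, M, tile):
        for j0 in range(0, N, tile):
            for k0 in range(0, K, tile):
                # Slices clip automatically at the matrix edges.
                C[i0:i0 + tile, j0:j0 + tile] += (
                    A[i0:i0 + tile, k0:k0 + tile] @ B[k0:k0 + tile, j0:j0 + tile]
                )
    return C


rng = np.random.default_rng(1)
A = rng.standard_normal((100, 70))
B = rng.standard_normal((70, 50))
print(np.allclose(tiled_matmul(A, B, tile=16), A @ B))  # True
```

The tile size is the knob: small enough that three tiles fit in cache, large enough that each loaded tile is reused many times.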
Kernels
MatMul Kernel
Softmax Kernel
Naive implementation:
y = np.exp(x) / np.sum(np.exp(x))
This is problematic: it does nothing to prevent numerical overflow (exp(x) becomes inf once x exceeds about 88 in float32), and it also computes exp(x) twice.
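A quick NumPy check of the overflow problem. With large but perfectly reasonable logits, `np.exp` overflows to inf, and inf/inf gives nan. The warnings are silenced with `np.errstate` only to keep the output clean.

```python
import numpy as np


def softmax_naive(x: np.ndarray) -> np.ndarray:
    """Textbook softmax with no max subtraction."""
    return np.exp(x) / np.sum(np.exp(x))


x = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)
with np.errstate(over="ignore", invalid="ignore"):
    y_naive = softmax_naive(x)
print(y_naive)  # [nan nan nan]: exp(1000) overflows to inf, and inf / inf is nan
```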
A correct (numerically stable) implementation should:
- Find max(x)
- Subtract the max from every element
- Compute exp
- Sum
- Normalize
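Subtracting the max is safe because softmax does not change when the same constant $m$ is subtracted from every input:

```latex
\frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
  = \frac{e^{-m}\, e^{x_i}}{e^{-m} \sum_j e^{x_j}}
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
```

With $m = \max_j x_j$, every exponent $x_i - m$ is at most 0, so each $e^{x_i - m} \le 1$ and nothing overflows. The largest term is exactly 1, so the sum is at least 1 and the final division is safe.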
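import math
m = max(x)                      # pass 1: find max
y = [0.0] * len(x)
for i in range(len(x)):         # pass 2: subtract max and exponentiate
    y[i] = math.exp(x[i] - m)   # x[i] - m <= 0, so y[i] <= 1 and exp cannot overflow
s = sum(y)                      # pass 3: sum
for i in range(len(y)):         # pass 4: normalize
    y[i] /= s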
This kernel makes multiple passes over the data: one to find the max, one to compute the exponentials, one to sum them, and one to normalize. Notice also that each step depends on the result of the previous one: the max must be known before exponentiating, and the sum before normalizing.
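Below is a vectorized NumPy version of the stable softmax, plus one known way to cut a pass: the "online softmax" trick (a published technique, swapped in here for illustration rather than derived in these notes). It computes the max and the sum together by rescaling the running sum whenever a larger max appears. Function names are illustrative.

```python
import math

import numpy as np


def softmax_stable(x: np.ndarray) -> np.ndarray:
    """Max-subtracted softmax over the last axis."""
    m = np.max(x, axis=-1, keepdims=True)          # pass 1: max
    e = np.exp(x - m)                              # pass 2: shift + exp
    return e / np.sum(e, axis=-1, keepdims=True)   # passes 3-4: sum, normalize


def softmax_online(x):
    """Online softmax: running max and running sum in a single pass.

    When a new maximum appears, the sum accumulated so far is rescaled by
    exp(old_max - new_max) so that it stays relative to the current max.
    """
    m = -math.inf   # running max (exp(-inf) == 0.0, so the first step is safe)
    d = 0.0         # running sum of exp(x_j - m)
    for v in x:
        m_new = max(m, v)
        d = d * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    return [math.exp(v - m) / d for v in x]        # second pass: normalize


x = [1000.0, 1001.0, 1002.0]
print(softmax_stable(np.array(x)))
print(softmax_online(x))
```

The online version trades the separate max pass for one extra `exp` per element, which is usually a good deal when the kernel is memory-bound.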
LayerNorm Kernel
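LayerNorm normalizes each row by subtracting the mean ($\mu$) and dividing by the standard deviation, with the variance ($\sigma^2$) offset by a small $\epsilon$ to prevent division by zero. It then scales and shifts the result by the learned parameters $\gamma$ and $\beta$:
$$y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}\,\gamma + \beta$$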
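import numpy as np
N = x.shape[-1]                                              # normalize over the last axis (each row)
mean = np.sum(x, axis=-1, keepdims=True) / N                 # pass 1: mean
var = np.sum((x - mean) ** 2, axis=-1, keepdims=True) / N    # pass 2: variance
y = (x - mean) / np.sqrt(var + eps)                          # pass 3: normalize
y = y * gamma + beta                                         # scale by gamma, shift by beta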
Notice that LayerNorm also makes multiple passes over the data while doing little math per element, so it is memory-bound too. To reduce the cost, we can “fuse” several operations into a single kernel, which cuts memory traffic. Think of it this way: before fusion we had 1-in-1-out components strictly chained one after another, each writing its result to memory for the next one to read back. A fused kernel is a single component that does the work of both, keeping intermediate results in registers or cache instead of round-tripping them through main memory.
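One way to shave a pass off LayerNorm is to compute the mean and the variance together. A minimal sketch using Welford's online algorithm (a standard, numerically stable one-pass method, swapped in here for illustration); it uses the same population variance (divide by N) as the pseudocode earlier in this section. Plain Python loops stand in for what a real kernel would do per row in registers, and the function name is hypothetical.

```python
import numpy as np


def layernorm_welford(x, gamma, beta, eps=1e-5):
    """LayerNorm of one row, with mean and variance gathered in a single pass."""
    mean = 0.0
    m2 = 0.0                          # running sum of squared deviations
    for count, v in enumerate(x, start=1):
        delta = v - mean
        mean += delta / count
        m2 += delta * (v - mean)      # uses the already-updated mean
    var = m2 / len(x)                 # population variance, as in var = sum((x - mean)^2) / N
    inv_std = 1.0 / (var + eps) ** 0.5
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]


x = [1.0, 2.0, 3.0, 4.0]
gamma = [1.0] * 4
beta = [0.0] * 4
xa = np.array(x)
two_pass = (xa - xa.mean()) / np.sqrt(xa.var() + 1e-5)
print(np.allclose(layernorm_welford(x, gamma, beta), two_pass))  # True
```

The naive alternative, accumulating sum(x) and sum(x^2) in one pass, is also one pass but can lose precision when the mean is large relative to the spread; Welford's update avoids that.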
Typical fused kernels
- LayerNorm + residual
- Softmax + scale
- Bias + activation (e.g., GELU)
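To illustrate the first item, here is a sketch contrasting an unfused residual add followed by LayerNorm, where the full intermediate `h` is materialized and re-read, with a fused version that handles one row at a time and keeps that row's intermediates local. In NumPy this only mimics the data flow; the real savings come from a compiled kernel holding the row in registers or shared memory. Function names are hypothetical.

```python
import numpy as np


def residual_layernorm_unfused(x, residual, gamma, beta, eps=1e-5):
    """Residual add, then LayerNorm, as separate whole-array steps.

    In an unfused pipeline each step would be its own kernel, and the full
    intermediate `h` would be written to memory and read back several times.
    """
    h = x + residual
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)          # population variance (divide by N)
    return (h - mean) / np.sqrt(var + eps) * gamma + beta


def residual_layernorm_fused(x, residual, gamma, beta, eps=1e-5):
    """Same math organized as one kernel: each row is read once and its
    intermediates live only in row-sized temporaries."""
    out = np.empty_like(x)
    for r in range(x.shape[0]):                  # e.g. one thread block per row on a GPU
        h = x[r] + residual[r]
        mean = h.mean()
        var = h.var()
        out[r] = (h - mean) / np.sqrt(var + eps) * gamma + beta
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 16))
residual = rng.standard_normal((8, 16))
gamma = rng.standard_normal(16)
beta = rng.standard_normal(16)
print(np.allclose(residual_layernorm_fused(x, residual, gamma, beta),
                  residual_layernorm_unfused(x, residual, gamma, beta)))  # True
```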
As opposed to LayerNorm, RMSNorm is more efficient memory-wise:
$$y = \frac{x}{\sqrt{\frac{1}{N}\sum_{i=1}^{N} x_i^2 + \epsilon}}\,\gamma$$
It only needs the root mean square, whereas LayerNorm needs both the mean and the variance, so it makes one pass less.
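A minimal NumPy sketch of RMSNorm over the last axis. There is no mean subtraction and no $\beta$, so one reduction (the mean of squares) replaces LayerNorm's two. The default `eps=1e-6` is an assumed placeholder, not a value taken from any specific model.

```python
import numpy as np


def rmsnorm(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm over the last axis: divide by the root mean square, then scale."""
    ms = np.mean(x * x, axis=-1, keepdims=True)   # pass 1: mean of squares
    return x / np.sqrt(ms + eps) * gamma          # pass 2: normalize and scale


x = np.array([[3.0, 4.0], [1.0, -1.0]])
y = rmsnorm(x, np.ones(2), eps=0.0)
print(np.mean(y ** 2, axis=-1))   # each row now has mean square 1
```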
oneDNN MatMul Kernel
Let’s look at an example of kernel code from the oneDNN CPU matmul:
auto ker = [=](const dims_t dst_dims_idx, dim_t m, dim_t n) {
    dims_t src_dims_idx, weights_dims_idx;
    utils::copy_dims_with_mask(src_dims_idx, dst_dims_idx, ndims, src_mask);
    utils::copy_dims_with_mask(
            weights_dims_idx, dst_dims_idx, ndims, wei_mask);
    src_dims_idx[ndims - 2] = m;
    weights_dims_idx[ndims - 1] = n;
    auto &src_k_dim = src_dims_idx[ndims - 1];
    auto &wei_k_dim = weights_dims_idx[ndims - 2];
    float res = 0.0f;
    for (dim_t i_group = 0; i_group < ngroups_k; i_group++) {
        float acc = 0.0f;
        for (dim_t k = 0; k < group_k; ++k) {
            src_k_dim = k + i_group * group_k;
            wei_k_dim = k + i_group * group_k;
            const auto src_off = src_d.off_v(src_dims_idx);
            const auto weights_off = weights_d.off_v(weights_dims_idx);
            const float s
                    = io::load_float_value(src_d.data_type(), src, src_off);
            float w = io::load_float_value(
                    weights_d.data_type(), weights, weights_off);
            // weights decompression should happen before the operation
            if (with_wei_decompression) {
                if (with_wei_zero_points) {
                    const dim_t wei_zp_offset
                            = matmul_helper_t::get_quant_off(
                                    weights_dims_idx, ndims, wei_zp_mask,
                                    wei_zp_group_k, wei_zp_group_n,
                                    wei_zp_md);
                    const auto wei_zp = io::load_float_value(
                            wei_zp_dt, wei_zero_points, wei_zp_offset);
                    w -= wei_zp;
                }
                if (with_wei_scales) {
                    const dim_t wei_scale_offset
                            = matmul_helper_t::get_quant_off(
                                    weights_dims_idx, ndims, wei_scale_mask,
                                    wei_scale_group_k, wei_scale_group_n,
                                    wei_scale_md);
                    const float wei_scale = io::load_float_value(
                            wei_scale_dt, wei_scales, wei_scale_offset);
                    w *= wei_scale;
                }
            }
            acc += s * w;
        }
        // apply scales after computing a group along K
        if (with_src_scales) {
            const dim_t src_scale_offset = matmul_helper_t::get_quant_off(
                    src_dims_idx, ndims, src_scale_mask, src_scale_group_m,
                    src_scale_group_k, src_scale_md);
            float src_scale = io::load_float_value(
                    src_scale_dt, src_scales, src_scale_offset);
            acc *= src_scale;
        }
        if (with_wei_scales && !with_wei_decompression) {
            const dim_t wei_scale_offset = matmul_helper_t::get_quant_off(
                    weights_dims_idx, ndims, wei_scale_mask,
                    wei_scale_group_k, wei_scale_group_n, wei_scale_md);
            const float wei_scale = io::load_float_value(
                    wei_scale_dt, wei_scales, wei_scale_offset);
            acc *= wei_scale;
        }
        res += acc;
    }
    return res;
};
The basic rule of matmul is to iterate through the rows of the first matrix, then through the columns of the second matrix, and compute a dot product for each (row, column) pair (more generally, iterate over the dimension shared by both matrices). So the main job of the kernel is to create an “iteration plan” that retrieves the right elements from both matrices and computes on them.
The inputs of the kernel are (dst_dims_idx, m, n). dst_dims_idx is the index of one element in the destination matrix, and m and n are that element's row and column. Each call returns a single float, the value of that one output element, and loops only over the shared K dimension; the iteration over m and n therefore happens in the caller.
The kernel assumes the source matrix has shape (..., M, K), the weight matrix has shape (..., K, N), and the destination matrix has shape (..., M, N).
src_dims_idx and weights_dims_idx are the indices into the source and weight matrices. You can view them as “coordinates” that point to an element in each matrix. Both are filled from the destination coordinates with copy_dims_with_mask; then the source row is set to m and the weight column to n. src_k_dim and wei_k_dim are references to the K position of each coordinate, so assigning to them moves both coordinates along K in lockstep.
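A Python sketch of the same coordinate bookkeeping, assuming no broadcasting (the batch part of the index is copied unchanged, a simplification of what `copy_dims_with_mask` decides in the original) and no quantization. Each call computes one destination element; the caller loops over all of them. The function name is hypothetical.

```python
import numpy as np


def ref_matmul_element(src, wei, dst_idx):
    """Compute dst[dst_idx] by walking coordinates the way `ker` does (simplified).

    src has shape (..., M, K), wei has shape (..., K, N); batch dims must match.
    """
    src_idx = list(dst_idx)          # start from the destination coordinates
    wei_idx = list(dst_idx)
    m, n = dst_idx[-2], dst_idx[-1]
    src_idx[-2] = m                  # source row    = destination row
    wei_idx[-1] = n                  # weight column = destination column
    res = 0.0
    for k in range(src.shape[-1]):
        src_idx[-1] = k              # like `src_k_dim = k`
        wei_idx[-2] = k              # like `wei_k_dim = k`
        res += src[tuple(src_idx)] * wei[tuple(wei_idx)]
    return res


rng = np.random.default_rng(0)
src = rng.standard_normal((2, 3, 4))
wei = rng.standard_normal((2, 4, 5))
dst = np.matmul(src, wei)
print(all(np.isclose(ref_matmul_element(src, wei, idx), dst[idx])
          for idx in np.ndindex(dst.shape)))  # True
```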
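group_k is the number of K elements in each group, and the kernel walks K as ngroups_k consecutive groups. Despite looking like tiling, these groups are not cache blocks: as the comment “apply scales after computing a group along K” and the src_scale_group_k / wei_scale_group_k lookups show, they are quantization groups, in which all elements share one scale. Because the scale is constant within a group, the kernel can sum the raw products into acc and multiply acc by the group's scales once, instead of scaling every product. (When weights are decompressed, the weight zero point and scale are instead applied to each w inside the inner loop.)
for (dim_t i_group = 0; i_group < ngroups_k; i_group++) {
    // iterate through each group along K
    float acc = 0.0f;
    for (dim_t k = 0; k < group_k; ++k) {
        // compute the K coordinate of this element in both matrices
        src_k_dim = k + i_group * group_k;
        wei_k_dim = k + i_group * group_k;
        //...
        // Then, if we ignore the decompression, it boils down to an element product
        acc += s * w;
    }
    // apply this group's scales to acc (omitted here), then accumulate
    // the partial sum of each group to get the final result
    res += acc;
}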
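A Python sketch of the grouped accumulation for one output element, following the path without weight decompression: products are summed per group, and each group's partial sum is multiplied by one source scale and one weight scale. It assumes K is a multiple of group_k and omits zero points; the name `grouped_dot` and the scale values are hypothetical. The check confirms that scaling per group gives the same result as dequantizing every element first.

```python
import numpy as np


def grouped_dot(src_row, wei_col, group_k, src_scales, wei_scales):
    """One output element with per-K-group scales (simplified mirror of `ker`)."""
    K = len(src_row)
    assert K % group_k == 0, "K must be a multiple of group_k in this sketch"
    ngroups_k = K // group_k
    res = 0.0
    for i_group in range(ngroups_k):
        acc = 0.0
        for k in range(group_k):
            kk = k + i_group * group_k          # like src_k_dim / wei_k_dim
            acc += float(src_row[kk]) * float(wei_col[kk])
        acc *= src_scales[i_group]              # one source scale per K-group
        acc *= wei_scales[i_group]              # one weight scale per K-group
        res += acc
    return res


rng = np.random.default_rng(0)
src_q = rng.integers(-128, 128, size=8)         # int8-range values (illustrative)
wei_q = rng.integers(-128, 128, size=8)
src_scales = np.array([0.5, 0.25])
wei_scales = np.array([0.1, 0.2])
# Reference: dequantize every element first, then take an ordinary dot product.
expected = np.dot(src_q * np.repeat(src_scales, 4), wei_q * np.repeat(wei_scales, 4))
print(np.isclose(grouped_dot(src_q, wei_q, 4, src_scales, wei_scales), expected))  # True
```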