Kernel EP-02
SIMD
What is SIMD
SIMD = one instruction operates on multiple data elements.
# instead of a scalar loop
for i in range(8):
    c[i] = a[i] + b[i]

# the CPU does
c_vec = a_vec + b_vec  # 8 floats at once (AVX)
Typical vector widths are:
| ISA | Float32 per register |
|---|---|
| SSE | 4 |
| AVX2 | 8 |
| AVX-512 | 16 |
| NEON (ARM, 128-bit) | 4 |
Why SIMD Matters for ML Kernels
SIMD is a natural fit for elementwise operations, reductions, and dot products, which is why it matters so much for ML kernels. Performance suffers, however, when a kernel branches per element, accesses memory with large strides, or loads misaligned data.
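For a concrete (hypothetical) contrast, the first loop below is the kind of elementwise code compilers vectorize readily, while the second mixes a strided access pattern with a data-dependent branch and is much harder to vectorize profitably:

```cpp
// Vectorizes cleanly: pure elementwise work over contiguous memory.
void scale(const float *x, float *y, int n)
{
    for (int i = 0; i < n; i++)
        y[i] = 2.0f * x[i];
}

// Much harder to vectorize well: strided loads plus a data-dependent branch.
void threshold_strided(const float *x, float *y, int n, int stride)
{
    for (int i = 0; i < n; i++)
        y[i] = (x[i * stride] > 0.0f) ? x[i * stride] : 0.0f;
}
```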
SIMD Example
Consider this C++ code:
// simd.cpp
void add(float *a, float *b, float *c, int n)
{
for (int i = 0; i < n; i++)
{
c[i] = a[i] + b[i];
}
}
If we compile it with g++ -O2 -g -c simd.cpp -o simd.o and inspect the result with objdump -d -S simd.o, we get the listing below (compiled on an ARM64 machine, so the vector instructions are NEON). It can look confusing at first because the same loop appears several times: the compiler emitted both a scalar version and a vectorized version, plus runtime checks that decide which one to run.
0000000000000000 <ltmp0>:
; for (int i = 0; i < n; i++)
0: 7100047f cmp w3, #0x1
4: 5400020b b.lt 0x44 <ltmp0+0x44>
8: 2a0303e8 mov w8, w3
c: 71003c7f cmp w3, #0xf
10: 540001c8 b.hi 0x48 <ltmp0+0x48>
14: d2800009 mov x9, #0x0 ; =0
; for (int i = 0; i < n; i++)
18: d37ef52c lsl x12, x9, #2
1c: 8b0c004a add x10, x2, x12
20: 8b0c002b add x11, x1, x12
24: 8b0c000c add x12, x0, x12
28: cb090108 sub x8, x8, x9
; c[i] = a[i] + b[i];
2c: bc404580 ldr s0, [x12], #0x4
30: bc404561 ldr s1, [x11], #0x4
34: 1e212800 fadd s0, s0, s1
38: bc004540 str s0, [x10], #0x4
; for (int i = 0; i < n; i++)
3c: f1000508 subs x8, x8, #0x1
40: 54ffff61 b.ne 0x2c <ltmp0+0x2c>
; }
44: d65f03c0 ret
48: d2800009 mov x9, #0x0 ; =0
; for (int i = 0; i < n; i++)
4c: cb00004a sub x10, x2, x0
50: f101015f cmp x10, #0x40
54: 54fffe23 b.lo 0x18 <ltmp0+0x18>
58: cb01004a sub x10, x2, x1
5c: f101015f cmp x10, #0x40
60: 54fffdc3 b.lo 0x18 <ltmp0+0x18>
64: 927c6909 and x9, x8, #0x7ffffff0
68: 9100804a add x10, x2, #0x20
6c: 9100802b add x11, x1, #0x20
70: 9100800c add x12, x0, #0x20
74: aa0903ed mov x13, x9
; c[i] = a[i] + b[i];
78: ad7f0580 ldp q0, q1, [x12, #-0x20]
7c: acc20d82 ldp q2, q3, [x12], #0x40
80: ad7f1564 ldp q4, q5, [x11, #-0x20]
84: acc21d66 ldp q6, q7, [x11], #0x40
88: 4e24d400 fadd.4s v0, v0, v4
8c: 4e25d421 fadd.4s v1, v1, v5
90: 4e26d442 fadd.4s v2, v2, v6
94: 4e27d463 fadd.4s v3, v3, v7
98: ad3f0540 stp q0, q1, [x10, #-0x20]
9c: ac820d42 stp q2, q3, [x10], #0x40
; for (int i = 0; i < n; i++)
a0: f10041ad subs x13, x13, #0x10
a4: 54fffea1 b.ne 0x78 <ltmp0+0x78>
a8: eb08013f cmp x9, x8
ac: 54fffb61 b.ne 0x18 <ltmp0+0x18>
b0: 17ffffe5 b 0x44 <ltmp0+0x44>
First, let’s go over the register conventions used in this listing:
| Register | Meaning (typical) |
|---|---|
| x0 | pointer a |
| x1 | pointer b |
| x2 | pointer c |
| w3 | n (32-bit int) |
| x8–x15 | temporaries |
| v0–v31 | SIMD / FP registers |
Entry & Early exit
0: 7100047f cmp w3, #0x1 // if n < 1
4: 5400020b b.lt 0x44 <ltmp0+0x44> // directly return
8: 2a0303e8 mov w8, w3 // store n into x8
c: 71003c7f cmp w3, #0xf // if n > 15
10: 540001c8 b.hi 0x48 <ltmp0+0x48> // branch to vector path
14: d2800009 mov x9, #0x0 // i = 0
Scalar Loop (for small n)
18: d37ef52c lsl x12, x9, #2 // i * 4 (by shifting left by 2)
1c: 8b0c004a add x10, x2, x12 // &c[i], store to x10
20: 8b0c002b add x11, x1, x12 // &b[i], store to x11
24: 8b0c000c add x12, x0, x12 // &a[i], store to x12
28: cb090108 sub x8, x8, x9 // x8 = n - i
// loop start
2c: bc404580 ldr s0, [x12], #0x4 // load a[i]
30: bc404561 ldr s1, [x11], #0x4 // load b[i]
34: 1e212800 fadd s0, s0, s1 // a[i] + b[i]
38: bc004540 str s0, [x10], #0x4 // store to c[i]
3c: f1000508 subs x8, x8, #0x1 // x8 = x8 - 1
40: 54ffff61 b.ne 0x2c <ltmp0+0x2c> // branch to loop start
Vectorized loop (for large n)
When n >= 16 (and the aliasing guard described below passes), the compiler takes the vectorized path, which processes 16 elements of a, b and c per iteration.
Why 16 floats? Look at the instruction
fadd.4s v0, v0, v4
It adds 4 floats at once, because each q register holds 128 bits = 16 bytes = 4 floats. Unrolling the loop to use 4 such registers per array (e.g. q0, q1, q2, q3 for a) lets one iteration handle 16 floats.
Then why stop at 4 registers? That is a tradeoff made by the compiler: unrolling further would hide more latency, but it would also use more registers, increasing register pressure and scheduling work. (The intrinsics sketch after the register overview below shows what this unrolled body looks like in C++.)
What are q registers?
ARM has two register files: general-purpose registers (GPRs) and SIMD / FP registers.
- GPRs: x0–x30 (64-bit), accessed as w0–w30 when only the lower 32 bits are needed
- SIMD / FP: v0–v31 (128-bit)
GPRs hold pointers, integers, addresses and so on, while the SIMD / FP registers hold vector and floating-point data. Each SIMD register can be accessed at several widths:
| Name | Width | Used for |
|---|---|---|
| qN | 128-bit | SIMD vectors |
| dN | 64-bit | double / 2 floats |
| sN | 32-bit | float |
| hN | 16-bit | half |
| bN | 8-bit | byte |
And the different views are used by different instructions:
fadd s0, s0, s1 // scalar float
fadd.4s v0, v0, v4 // 4 floats in parallel
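Before returning to the assembly, here is a sketch of what the unrolled vector body roughly corresponds to in NEON intrinsics (my approximation, not the compiler's actual output; it assumes n is a multiple of 16 and that the arrays do not overlap). Each float32x4_t value maps to one q register, and the four independent adds mirror the four fadd.4s instructions:

```cpp
#include <arm_neon.h>

// Sketch of the unrolled vector body (assumes n % 16 == 0 and no aliasing).
void add_vectorized(const float *a, const float *b, float *c, int n)
{
    for (int i = 0; i < n; i += 16)
    {
        // Each float32x4_t lives in one 128-bit q register (4 floats).
        float32x4_t a0 = vld1q_f32(a + i);
        float32x4_t a1 = vld1q_f32(a + i + 4);
        float32x4_t a2 = vld1q_f32(a + i + 8);
        float32x4_t a3 = vld1q_f32(a + i + 12);
        float32x4_t b0 = vld1q_f32(b + i);
        float32x4_t b1 = vld1q_f32(b + i + 4);
        float32x4_t b2 = vld1q_f32(b + i + 8);
        float32x4_t b3 = vld1q_f32(b + i + 12);

        // Four independent adds, one per register pair (like the fadd.4s block).
        vst1q_f32(c + i, vaddq_f32(a0, b0));
        vst1q_f32(c + i + 4, vaddq_f32(a1, b1));
        vst1q_f32(c + i + 8, vaddq_f32(a2, b2));
        vst1q_f32(c + i + 12, vaddq_f32(a3, b3));
    }
}
```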
Now let’s look at the assembly code breakdown:
First, as a guard before entering the SIMD path, the code checks for aliasing. Aliasing means two pointers refer to overlapping memory. Concretely, it checks whether c starts within 64 bytes (one 16-float batch) after a or after b; if so, the vector stores to c could overwrite input elements that have not been loaded yet, and the code falls back to the scalar path.
4c: cb00004a sub x10, x2, x0 // x10 = c - a
50: f101015f cmp x10, #0x40 // compare x10 with 0x40 (64 bytes)
54: 54fffe23 b.lo 0x18 <ltmp0+0x18> // if lower, branch to scalar path
58: cb01004a sub x10, x2, x1 // x10 = c - b
5c: f101015f cmp x10, #0x40 // compare x10 with 0x40 (64 bytes)
60: 54fffdc3 b.lo 0x18 <ltmp0+0x18> // if lower, branch to scalar path
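In C-like terms, the guard is roughly the following (a sketch of the idea, not the compiler's literal logic; the subtractions and comparisons are unsigned, matching the b.lo branches):

```cpp
#include <cstdint>

// Returns true when it is safe to take the 16-float vector path.
static bool safe_for_vector_path(const float *a, const float *b, const float *c)
{
    uintptr_t ca = (uintptr_t)c - (uintptr_t)a; // c - a, in bytes
    uintptr_t cb = (uintptr_t)c - (uintptr_t)b; // c - b, in bytes
    // If c starts less than 64 bytes (one 16-float batch) after a or b,
    // vector stores to c could clobber inputs not yet loaded.
    return ca >= 0x40 && cb >= 0x40;
}
```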
Then the compiler uses an addressing trick: the pointers to a, b and c are advanced by 8 floats (32 bytes, since a float is 4 bytes), so each one points to the middle of its 64-byte batch. The first half is then read with a negative offset and the second half with a post-incremented load (for example, ldp q0, q1, [x12, #-0x20]).
64: 927c6909 and x9, x8, #0x7ffffff0 // x9 = n - n % 16 (n rounded down to a multiple of 16)
68: 9100804a add x10, x2, #0x20 // x10 = c + 32 bytes
6c: 9100802b add x11, x1, #0x20 // x11 = b + 32 bytes
70: 9100800c add x12, x0, #0x20 // x12 = a + 32 bytes
74: aa0903ed mov x13, x9 // x13 = loop counter (number of elements handled by the vector loop)
Why read from the “middle” instead of simply reading from the start of each 64-byte batch, like this?
ldp q0, q1, [x12] // read first 8 floats
ldp q2, q3, [x12, #32] // read second 8 floats
add x12, x12, #64 // advance 64 bytes
The problem with the naive approach is:
- Address generation depends on x12
- add creates a dependency chain
- Harder to overlap load latency with FP ops
So if we read from the “middle”:
ldp q0, q1, [x12, #-0x20]
ldp q2, q3, [x12], #0x40
We get one instruction that both loads and advances the pointer, so the next addresses are available earlier: there is no separate add sitting in the dependency chain.
Hey! It’s similar to “fusion” in kernel programming
ldp stands for “load pair of registers”: it loads two registers with a single instruction (here two q registers), whereas ldr (load register) loads just one. stp is the matching “store pair”.
78: ad7f0580 ldp q0, q1, [x12, #-0x20] // load a[0..7] to q0 and q1
7c: acc20d82 ldp q2, q3, [x12], #0x40 // load a[8..15] to q2 and q3, advance x12 by 64 bytes
80: ad7f1564 ldp q4, q5, [x11, #-0x20] // load b[0..7] to q4 and q5
84: acc21d66 ldp q6, q7, [x11], #0x40 // load b[8..15] to q6 and q7, advance x11 by 64 bytes
88: 4e24d400 fadd.4s v0, v0, v4 // v0 = v0 + v4 (a[0..3] + b[0..3])
8c: 4e25d421 fadd.4s v1, v1, v5 // v1 = v1 + v5 (a[4..7] + b[4..7])
90: 4e26d442 fadd.4s v2, v2, v6 // v2 = v2 + v6 (a[8..11] + b[8..11])
94: 4e27d463 fadd.4s v3, v3, v7 // v3 = v3 + v7 (a[12..15] + b[12..15])
98: ad3f0540 stp q0, q1, [x10, #-0x20] // store c[0..7]
9c: ac820d42 stp q2, q3, [x10], #0x40 // store c[8..15], advance x10 by 64 bytes
a0: f10041ad subs x13, x13, #0x10 // remaining -= 16 (16 elements processed this iteration)
a4: 54fffea1 b.ne 0x78 <ltmp0+0x78> // branch to loop start
a8: eb08013f cmp x9, x8 // compare the rounded-down count with n
ac: 54fffb61 b.ne 0x18 <ltmp0+0x18> // if they differ, leftover elements remain: branch to the scalar path for cleanup
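Putting it all together, the generated code behaves roughly like the following C++ (an illustrative reconstruction, not the literal machine code): an early exit, size and aliasing guards, a 16-wide main loop, and a scalar loop that serves both as the fallback and as the tail cleanup.

```cpp
#include <cstdint>

void add_like_the_compiler(float *a, float *b, float *c, int n)
{
    if (n < 1)
        return; // early exit

    int i = 0;
    bool no_alias = ((uintptr_t)c - (uintptr_t)a) >= 0x40 &&
                    ((uintptr_t)c - (uintptr_t)b) >= 0x40;

    if (n >= 16 && no_alias)
    {
        int limit = n & ~15; // n rounded down to a multiple of 16
        for (; i < limit; i += 16)
        {
            // 16 floats per iteration; emitted as 4x ldp / fadd.4s / stp.
            for (int j = 0; j < 16; ++j)
                c[i + j] = a[i + j] + b[i + j];
        }
    }

    // Scalar path: small n, aliasing fallback, or leftover tail.
    for (; i < n; ++i)
        c[i] = a[i] + b[i];
}
```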