Retep's

Kernel EP-02


SIMD

What is SIMD

SIMD (Single Instruction, Multiple Data) = one instruction operates on multiple data elements.

// instead of 8 scalar operations
for (int i = 0; i < 8; i++)
    c[i] = a[i] + b[i];

// the cpu does one vector operation
c_vec = a_vec + b_vec   // 8 floats at once (AVX)

Typical vector widths are:

ISA       Float32 per register
SSE       4
AVX2      8
AVX-512   16
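
On x86, the "a_vec + b_vec" line above maps to AVX2 intrinsics. Here is a minimal sketch, assuming n is a multiple of 8 (add_avx2 is a hypothetical name, not a library function):

#include <immintrin.h>

// Add 8 floats per iteration with AVX2 (assumes n % 8 == 0).
void add_avx2(const float *a, const float *b, float *c, int n)
{
    for (int i = 0; i < n; i += 8) {
        __m256 va = _mm256_loadu_ps(a + i);              // load 8 floats from a
        __m256 vb = _mm256_loadu_ps(b + i);              // load 8 floats from b
        _mm256_storeu_ps(c + i, _mm256_add_ps(va, vb));  // c[i..i+7] = a + b
    }
}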

Why SIMD Matters for ML Kernels

SIMD maps naturally onto elementwise operations, reductions, and dot products, which is why it matters so much for ML kernels. Performance suffers, however, when the kernel has branches, strided memory access, or misaligned loads.
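
For example, a loop like the sketch below (a made-up kernel, purely for illustration) mixes a data-dependent branch with stride-2 loads. Compilers can still vectorize it with masks and shuffles, but far less efficiently than the plain add:

// Data-dependent branch + stride-2 loads: vectorizable only with masks
// and shuffles, so throughput drops well below the elementwise case.
void hard_to_vectorize(const float *a, float *c, int n)
{
    for (int i = 0; i < n; i++) {
        if (a[2 * i] > 0.0f)       // branch decided per element
            c[i] = a[2 * i];       // strided access wastes vector lanes
        else
            c[i] = 0.0f;
    }
}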

SIMD Example

Consider this C++ code:

// simd.cpp
void add(float *a, float *b, float *c, int n)
{
    for (int i = 0; i < n; i++)
    {
        c[i] = a[i] + b[i];
    }
}

If we compile it to assembly with g++ -O2 -g -c simd.cpp -o simd.o (the output below is Arm64/NEON) and inspect it with objdump -d -S simd.o, we get the following code. It can be confusing at first because there appear to be several different implementations of the same loop. The reason is that the compiler emitted multiple versions (scalar + vectorized) plus a dispatcher that selects which one to run.

0000000000000000 <ltmp0>:
;     for (int i = 0; i < n; i++)
       0: 7100047f      cmp     w3, #0x1
       4: 5400020b      b.lt    0x44 <ltmp0+0x44>
       8: 2a0303e8      mov     w8, w3
       c: 71003c7f      cmp     w3, #0xf
      10: 540001c8      b.hi    0x48 <ltmp0+0x48>
      14: d2800009      mov     x9, #0x0                ; =0
;     for (int i = 0; i < n; i++)
      18: d37ef52c      lsl     x12, x9, #2
      1c: 8b0c004a      add     x10, x2, x12
      20: 8b0c002b      add     x11, x1, x12
      24: 8b0c000c      add     x12, x0, x12
      28: cb090108      sub     x8, x8, x9
;         c[i] = a[i] + b[i];
      2c: bc404580      ldr     s0, [x12], #0x4
      30: bc404561      ldr     s1, [x11], #0x4
      34: 1e212800      fadd    s0, s0, s1
      38: bc004540      str     s0, [x10], #0x4
;     for (int i = 0; i < n; i++)
      3c: f1000508      subs    x8, x8, #0x1
      40: 54ffff61      b.ne    0x2c <ltmp0+0x2c>
; }
      44: d65f03c0      ret
      48: d2800009      mov     x9, #0x0                ; =0
;     for (int i = 0; i < n; i++)
      4c: cb00004a      sub     x10, x2, x0
      50: f101015f      cmp     x10, #0x40
      54: 54fffe23      b.lo    0x18 <ltmp0+0x18>
      58: cb01004a      sub     x10, x2, x1
      5c: f101015f      cmp     x10, #0x40
      60: 54fffdc3      b.lo    0x18 <ltmp0+0x18>
      64: 927c6909      and     x9, x8, #0x7ffffff0
      68: 9100804a      add     x10, x2, #0x20
      6c: 9100802b      add     x11, x1, #0x20
      70: 9100800c      add     x12, x0, #0x20
      74: aa0903ed      mov     x13, x9
;         c[i] = a[i] + b[i];
      78: ad7f0580      ldp     q0, q1, [x12, #-0x20]
      7c: acc20d82      ldp     q2, q3, [x12], #0x40
      80: ad7f1564      ldp     q4, q5, [x11, #-0x20]
      84: acc21d66      ldp     q6, q7, [x11], #0x40
      88: 4e24d400      fadd.4s v0, v0, v4
      8c: 4e25d421      fadd.4s v1, v1, v5
      90: 4e26d442      fadd.4s v2, v2, v6
      94: 4e27d463      fadd.4s v3, v3, v7
      98: ad3f0540      stp     q0, q1, [x10, #-0x20]
      9c: ac820d42      stp     q2, q3, [x10], #0x40
;     for (int i = 0; i < n; i++)
      a0: f10041ad      subs    x13, x13, #0x10
      a4: 54fffea1      b.ne    0x78 <ltmp0+0x78>
      a8: eb08013f      cmp     x9, x8
      ac: 54fffb61      b.ne    0x18 <ltmp0+0x18>
      b0: 17ffffe5      b       0x44 <ltmp0+0x44>

First, let’s establish the register conventions used in this assembly:

Register   Meaning (typical)
x0         pointer a
x1         pointer b
x2         pointer c
w3         n (32-bit int)
x8–x15     temporaries
v0–v31     SIMD / FP registers

Entry & Early exit

       0: 7100047f      cmp     w3, #0x1            // if n < 1
       4: 5400020b      b.lt    0x44 <ltmp0+0x44>   // directly return
       8: 2a0303e8      mov     w8, w3              // store n into x8
       c: 71003c7f      cmp     w3, #0xf            // if n > 15
      10: 540001c8      b.hi    0x48 <ltmp0+0x48>   // branch to vector path
      14: d2800009      mov     x9, #0x0            // i = 0

Scalar Loop (for small n)

      18: d37ef52c      lsl     x12, x9, #2     // i * 4 (by shifting left by 2)
      1c: 8b0c004a      add     x10, x2, x12    // &c[i], store to x10
      20: 8b0c002b      add     x11, x1, x12    // &b[i], store to x11
      24: 8b0c000c      add     x12, x0, x12    // &a[i], store to x12
      28: cb090108      sub     x8, x8, x9      // x8 = n - i
      // loop start
      2c: bc404580      ldr     s0, [x12], #0x4 // load a[i]
      30: bc404561      ldr     s1, [x11], #0x4 // load b[i]
      34: 1e212800      fadd    s0, s0, s1      // a[i] + b[i]
      38: bc004540      str     s0, [x10], #0x4 // store to c[i]
      3c: f1000508      subs    x8, x8, #0x1    // x8 = x8 - 1
      40: 54ffff61      b.ne    0x2c <ltmp0+0x2c> // branch to loop start
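
In C terms, this scalar block corresponds to something like the following (a hedged sketch; add_scalar_tail and its pointer names are mine, not the compiler's):

// C-level reading of the scalar block: the lsl/add address setup starts
// each pointer at element i, the post-indexed ldr/str become pointer
// bumps, and subs/b.ne becomes the countdown loop.
void add_scalar_tail(float *a, float *b, float *c, long i, long n)
{
    float *pa = a + i, *pb = b + i, *pc = c + i;
    for (long remaining = n - i; remaining > 0; remaining--)
        *pc++ = *pa++ + *pb++;
}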

Vectorized loop (for large n)

When n >= 16, the compiler takes the vector path, which processes 16 elements of a, b and c per iteration.

Why 16 floats? Consider the instruction

fadd.4s v0, v0, v4

It processes 4 floats, because each q register holds 128 bits = 16 bytes = 4 floats. Using 4 such registers per input (e.g. q0, q1, q2, q3 for a) lets the loop process 16 floats at a time.

Then why 4 registers? This is a throughput-versus-register-pressure tradeoff made by the compiler: unrolling further would use more registers, increasing register pressure and scheduling complexity.
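
To make the 4-register unroll concrete, here is a hand-written NEON intrinsics sketch of one iteration of the vector body (add16 is a hypothetical helper, not what the compiler emitted; it assumes i + 16 <= n and no aliasing):

#include <arm_neon.h>

// One 16-float iteration of the vector body, written with NEON intrinsics.
void add16(const float *a, const float *b, float *c, int i)
{
    float32x4_t a0 = vld1q_f32(a + i);         // a[i..i+3]
    float32x4_t a1 = vld1q_f32(a + i + 4);     // a[i+4..i+7]
    float32x4_t a2 = vld1q_f32(a + i + 8);     // a[i+8..i+11]
    float32x4_t a3 = vld1q_f32(a + i + 12);    // a[i+12..i+15]
    float32x4_t b0 = vld1q_f32(b + i);
    float32x4_t b1 = vld1q_f32(b + i + 4);
    float32x4_t b2 = vld1q_f32(b + i + 8);
    float32x4_t b3 = vld1q_f32(b + i + 12);
    vst1q_f32(c + i,      vaddq_f32(a0, b0));  // four independent fadd.4s
    vst1q_f32(c + i + 4,  vaddq_f32(a1, b1));
    vst1q_f32(c + i + 8,  vaddq_f32(a2, b2));
    vst1q_f32(c + i + 12, vaddq_f32(a3, b3));
}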

What are q registers? ARM has two register files: General-purpose registers (GPRs) and SIMD / FP registers.

GPRs
x0–x30  (64-bit)
w0–w30  (lower 32 bits)

SIMD
v0–v31

GPRs are used for pointers, integers, addresses, etc. SIMD registers are used specifically for SIMD data. Each SIMD register can be viewed at several widths:

Name   Width     Used for
qN     128-bit   SIMD vectors
dN     64-bit    double / 2 floats
sN     32-bit    float
hN     16-bit    half
bN     8-bit     byte

And they appear in different instructions:

fadd s0, s0, s1         // scalar float
fadd.4s v0, v0, v4      // 4 floats in parallel

Now let’s look at the assembly code breakdown:

First, as a guard before the SIMD path, the compiler must make sure there is no aliasing. Aliasing means two pointers refer to the same (or overlapping) memory. The check compares whether c lies within 64 bytes of a or of b; if it does, vector loads and stores could step on each other, so we fall back to the scalar path.

      4c: cb00004a      sub     x10, x2, x0             // x10 = c - a
      50: f101015f      cmp     x10, #0x40              // compare with 0x40 = 64 bytes
      54: 54fffe23      b.lo    0x18 <ltmp0+0x18>       // if lower (unsigned), use scalar path
      58: cb01004a      sub     x10, x2, x1             // x10 = c - b
      5c: f101015f      cmp     x10, #0x40              // compare with 0x40 = 64 bytes
      60: 54fffdc3      b.lo    0x18 <ltmp0+0x18>       // if lower (unsigned), use scalar path
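
This guard exists because the compiler cannot prove from the signature that a, b and c never overlap. If we promise that ourselves with __restrict (a GCC/Clang extension; C has the standard restrict), the compiler is free to drop the runtime check. A minimal sketch:

// With __restrict we promise the buffers are disjoint, so the compiler
// can vectorize without emitting the runtime overlap guard.
void add_restrict(float *__restrict a, float *__restrict b,
                  float *__restrict c, int n)
{
    for (int i = 0; i < n; i++)
        c[i] = a[i] + b[i];
}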

Then the compiler uses a pointer-biasing trick. You can see it advances the pointers of a, b and c by 8 floats (32 bytes, since a float is 4 bytes). Each iteration then sits in the “middle” of its 64-byte block: a negative offset reads the first half and a zero-offset, post-indexed access reads the second half (for example, ldp q0, q1, [x12, #-0x20]).

      64: 927c6909      and     x9, x8, #0x7ffffff0     // x9 = n rounded down to a multiple of 16
      68: 9100804a      add     x10, x2, #0x20          // x10 = c + 32 bytes
      6c: 9100802b      add     x11, x1, #0x20          // x11 = b + 32 bytes
      70: 9100800c      add     x12, x0, #0x20          // x12 = a + 32 bytes
      74: aa0903ed      mov     x13, x9                 // x13 = vector loop counter

Why are we doing this instead of just reading from the “start” of the 64-byte block? Like:

ldp q0, q1, [x12]           // read first 8 floats
ldp q2, q3, [x12, #32]      // read second 8 floats
add x12, x12, #64           // advance 64 bytes

The problem with the naive approach is:

  • Address generation depends on x12
  • The extra add creates a dependency chain
  • It is harder to overlap load latency with FP ops

So if we read from the “middle”:

ldp q0, q1, [x12, #-0x20]
ldp q2, q3, [x12], #0x40

We get one instruction that both loads and advances the pointer, so address generation for the next iteration can start earlier: there is no dependency on a separate addition step.

Hey! It’s similar to “fusion” in kernel programming.

ldp stands for “load pair of registers”: it loads two registers with one instruction, where ldr (load register) loads one. Here the pair are q registers, so each ldp brings in 32 bytes. stp is the store counterpart.

      78: ad7f0580      ldp     q0, q1, [x12, #-0x20]   // load a[0..7] to q0 and q1  
      7c: acc20d82      ldp     q2, q3, [x12], #0x40    // load a[8..15] to q2 and q3, advance x12 by 64 bytes
      80: ad7f1564      ldp     q4, q5, [x11, #-0x20]   // load b[0..7] to q4 and q5
      84: acc21d66      ldp     q6, q7, [x11], #0x40    // load b[8..15] to q6 and q7, advance x11 by 64 bytes
      88: 4e24d400      fadd.4s v0, v0, v4              // v0 = v0 + v4 (a[0..3] + b[0..3])
      8c: 4e25d421      fadd.4s v1, v1, v5              // v1 = v1 + v5 (a[4..7], b[4..7])
      90: 4e26d442      fadd.4s v2, v2, v6              // v2 = v2 + v6 (a[8..11], b[8..11])
      94: 4e27d463      fadd.4s v3, v3, v7              // v3 = v3 + v7 (a[12..15], b[12..15])
      98: ad3f0540      stp     q0, q1, [x10, #-0x20]   // store c[0..7]
      9c: ac820d42      stp     q2, q3, [x10], #0x40    // store c[8..15], advance x10 by 64 bytes
      a0: f10041ad      subs    x13, x13, #0x10         // x13 -= 16 (we processed 16 elements in this batch)
      a4: 54fffea1      b.ne    0x78 <ltmp0+0x78>       // branch to loop start
      a8: eb08013f      cmp     x9, x8                  // if x9 != n (there are remainder elements)
      ac: 54fffb61      b.ne    0x18 <ltmp0+0x18>       // go to the scalar path for cleanup
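
Putting it all together, here is a rough C-level reading of the control flow the compiler emitted (a sketch, not the literal codegen; add_dispatch is my name, and the overlap test is shown as raw address arithmetic):

#include <cstdint>

// C-level sketch of the emitted structure: early exit, overlap guard,
// 16-wide main loop, scalar tail for the remainder.
void add_dispatch(float *a, float *b, float *c, int n)
{
    if (n < 1) return;                                       // early exit
    long i = 0;
    bool overlap = (std::uintptr_t)c - (std::uintptr_t)a < 0x40   // c within 64 bytes of a
                || (std::uintptr_t)c - (std::uintptr_t)b < 0x40;  // or of b (unsigned, like b.lo)
    if (n > 15 && !overlap) {
        long vec_n = (long)n & ~15L;                         // round n down to a multiple of 16
        for (; i < vec_n; i += 16)                           // 16 floats per iteration
            for (int k = 0; k < 16; k++)
                c[i + k] = a[i + k] + b[i + k];
    }
    for (; i < n; i++)                                       // scalar path / remainder cleanup
        c[i] = a[i] + b[i];
}

Compiling this sketch with -O2 and inspecting it the same way is a nice sanity check that the dispatch shape above is really what the compiler builds.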