
vLLM EP02 - PagedAttention and Implementation


Let’s learn PagedAttention today!

PagedAttention

The bigger picture here is efficient model serving. After we have trained a model, to make it useful we spend much more time performing inference (in other words, prediction) on it. Inference accounts for the majority of costs at companies like OpenAI and Cursor, so building a more cost-effective inference architecture has become a very relevant topic.

Attention is the key component in today’s LLM world: an LLM is basically a cascade of attention layers (with other components, for sure, but attention is the most important one). The attention mechanism is summarized in the equations below. Given an input sequence $x_i$, compute the query $q_i$, key $k_i$ and value $v_i$ with the weight matrices (the parameters we learned during training). Then we

  1. Compute the softmax to get the attention scores $a_{ij}$.
  2. Take the weighted sum of the values $v_j$, weighted by those attention scores, to get the output $o_i$.
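
Concretely, here is a sketch of the standard causal-attention formulas in the PagedAttention paper’s notation, where $W_q, W_k, W_v$ are the learned weight matrices and $d$ is the query/key dimension:

$$
q_i = W_q x_i, \qquad k_i = W_k x_i, \qquad v_i = W_v x_i
$$

$$
a_{ij} = \frac{\exp\left(q_i^\top k_j / \sqrt{d}\right)}{\sum_{t=1}^{i} \exp\left(q_i^\top k_t / \sqrt{d}\right)}, \qquad
o_i = \sum_{j=1}^{i} a_{ij}\, v_j
$$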

What’s tricky about LLMs is that they are “autoregressive”, meaning they generate tokens one after another (as opposed to diffusion models, where the entire output is generated at once). This dependency creates a barrier for performance optimization.

For example, given prompt tokens $(x_1, x_2, \dots, x_n)$, we generate the first token $x_{n+1}$. Then, based on $(x_1, x_2, \dots, x_{n+1})$, we generate the second output token $x_{n+2}$. Notice that during both generation steps, the $k$ and $v$ we computed for $x_1, x_2, \dots, x_n$ are always the same (the inputs are the same, and the weights are deterministic). Therefore, it’s natural to cache those results somewhere so we can reuse them.
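
Here is a minimal, illustrative sketch of that idea (the toy single-head “attention” and the cache lists are made up for illustration, not vLLM code): each decode step computes $q$, $k$, $v$ only for the newest token and reuses the cached $k$, $v$ of all earlier tokens.

import numpy as np

d = 4
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

def decode_step(x_new, k_cache, v_cache):
    """Attend from the newest token over all cached positions."""
    q = x_new @ W_q
    k_cache.append(x_new @ W_k)   # only the new token's K/V are computed...
    v_cache.append(x_new @ W_v)   # ...everything older is reused from the cache
    K, V = np.stack(k_cache), np.stack(v_cache)
    scores = K @ q / np.sqrt(d)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()      # softmax; the denominator spans every cached position
    return weights @ V

k_cache, v_cache = [], []
for _ in range(6):                # pretend: prompt tokens first, then generated ones
    out = decode_step(rng.standard_normal(d), k_cache, v_cache)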

To put it in more formal terms, each request goes through two phases. The prompt phase takes the whole user prompt and generates the first output token, as well as all the $k$ and $v$ for the prompt tokens. The autoregressive generation phase generates the remaining tokens sequentially, reusing the cached $k$ and $v$. This KV cache, and the memory it consumes, is what we focus on optimizing here.

The bottleneck for an effective KV cache is memory. As models get bigger and more attention layers are introduced, it becomes harder to keep all the KVs in cache. Even if you manage to do so, throughput is still limited because a large chunk of GPU memory is occupied by the KV cache, so fewer requests can be served concurrently.

Looking closer at memory management before PagedAttention, the researchers saw some prominent problems: reserved slots, internal fragmentation, and external fragmentation. Reserved slots and internal fragmentation are essentially memory speculatively reserved for the entire maximum output length, even though only part of it may ever be generated. External fragmentation is the unusable memory left between two requests in a contiguous memory space.

PagedAttention tries to solve the memory bottleneck by eliminating those fragmentations. First of all, fetching the KVs block by block and aggregating the partial results shouldn’t hurt correctness (notice that in the block-wise equation below, the softmax denominator is still the complete sum over all previous positions).
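
A sketch of the block-wise form, roughly following the paper’s notation: $B$ is the block size, $K_j = (k_{(j-1)B+1}, \dots, k_{jB})$ and $V_j = (v_{(j-1)B+1}, \dots, v_{jB})$ stack the keys and values of the $j$-th block, $A_{ij}$ is the row of attention scores of token $i$ over block $j$, and $\mathbf{1}$ is the all-ones vector:

$$
A_{ij} = \frac{\exp\left(q_i^\top K_j / \sqrt{d}\right)}{\sum_{t=1}^{\lceil i/B \rceil} \exp\left(q_i^\top K_t / \sqrt{d}\right)\mathbf{1}}, \qquad
o_i = \sum_{j=1}^{\lceil i/B \rceil} V_j A_{ij}^\top
$$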

So, we can divide the KVs into blocks and store them block by block, instead of storing them as one contiguous whole. Similar to how an OS kernel manages paging, we can have a page-table-like manager that tracks and manages those block addresses. We call it the “KV Cache Manager”.
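
As a toy illustration (the ToyKVCacheManager name and its methods are made up, not vLLM’s), a per-request block table maps logical block numbers to physical block ids, just like an OS page table maps virtual pages to physical frames:

BLOCK_SIZE = 16

class ToyKVCacheManager:
    def __init__(self, num_physical_blocks: int):
        self.free_blocks = list(range(num_physical_blocks))
        self.block_tables: dict[str, list[int]] = {}  # request id -> physical block ids

    def ensure_capacity(self, request_id: str, num_tokens: int) -> bool:
        """Make sure the request has enough blocks to hold num_tokens tokens."""
        table = self.block_tables.setdefault(request_id, [])
        needed = -(-num_tokens // BLOCK_SIZE) - len(table)  # ceil division
        if needed > len(self.free_blocks):
            return False                                    # caller may preempt someone
        table.extend(self.free_blocks.pop() for _ in range(needed))
        return True

    def free(self, request_id: str) -> None:
        self.free_blocks.extend(self.block_tables.pop(request_id, []))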

That’s it. That’s the essence of PagedAttention. Other topics like copy-on-write, reference counting and swapping come in naturally.

vLLM

Let’s talk about vLLM’s implementation of PagedAttention.

vLLM’s architecture has been laid out pretty clearly in this blog post, and I will extensively borrow the images there. There are mainly three layers: a FastAPI server that handles incoming HTTP requests, an AsyncLLM that runs detokenization, and an EngineCore that runs inference and scheduling. PagedAttention is relevant only to the EngineCore.

On an incoming request, the EngineCore delegates the request to its scheduler. The EngineCore runs step in a busy loop, which calls scheduler.schedule to decide which tokens to process (and which KV-cache blocks to allocate) before model_executor performs inference.

class EngineCore:
    def __init__(self):  # constructor arguments omitted for brevity
        # ...
        self.scheduler: SchedulerInterface = Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=self.structured_output_manager,
            include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
            log_stats=self.log_stats,
            block_size=scheduler_block_size,
        )

    # ...
    def add_request(self, request: Request, request_wave: int = 0):
        # ...
        self.scheduler.add_request(request)       

    # ...
    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.

        Returns tuple of outputs and a flag indicating whether the model
        was executed.
        """

        # Check for any requests remaining in the scheduler - unfinished,
        # or finished and not yet removed from the batch.
        if not self.scheduler.has_requests():
            return {}, False
        scheduler_output = self.scheduler.schedule()
        future = self.model_executor.execute_model(scheduler_output, non_block=True)
        # ... (elided: wait on `future` to obtain model_output)
        # Before processing the model output, process any aborts that happened
        # during the model execution.
        self._process_aborts_queue()
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )

        return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

The scheduler does the following things:

  1. Loop over the running requests
  2. Try to allocate new blocks for the new tokens (to put in the KV cache) via kv_cache_manager.allocate_slots
  3. If blocks cannot be allocated, try to make space by preempting other running requests. If space still cannot be made, stop scheduling.
  4. If blocks are allocated successfully, the newly allocated block IDs are recorded for the corresponding request.
class SchedulerInterface(ABC):
    def schedule(self) -> "SchedulerOutput":
        """Schedule the requests to process in this scheduling step.

        The scheduling decision is made at the iteration level. Each scheduling
        step corresponds to a single forward pass of the model. Therefore, this
        method is called repeatedly by a busy loop in the engine.

        Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
        that specifies how many tokens to process for each request in this
        scheduling step. For example, num_tokens can be as large as the number
        of prompt tokens for new requests, or it can be 1 for the requests that
        are auto-regressively generating new tokens one by one. Otherwise, it
        can be somewhere in between in case of chunked prefills, prefix caching,
        speculative decoding, etc.

        Additionally, the scheduler also returns useful data about each request
        or the batch as a whole. The model runner will use this information in
        preparing inputs to the model.

        Returns:
            A SchedulerOutput object containing information about the scheduled
            requests.
        """

class Scheduler(SchedulerInterface):
    def schedule(self) -> SchedulerOutput:

        scheduled_new_reqs: list[Request] = []
        scheduled_resumed_reqs: list[Request] = []
        scheduled_running_reqs: list[Request] = []
        preempted_reqs: list[Request] = []

        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
        num_scheduled_tokens: dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens

        # First, schedule the RUNNING requests.
        req_index = 0
        while req_index < len(self.running) and token_budget > 0:
            request = self.running[req_index]
            num_new_tokens = (
                request.num_tokens_with_spec
                + request.num_output_placeholders
                - request.num_computed_tokens
            )
            num_new_tokens = min(num_new_tokens, token_budget)
            num_new_tokens = min(
                num_new_tokens, self.max_model_len - 1 - request.num_computed_tokens
            )
            if num_new_tokens == 0:
                req_index += 1
                continue

            # Schedule newly needed KV blocks for the request.
            with record_function_or_nullcontext("schedule: allocate_slots"):
                while True:
                    new_blocks = self.kv_cache_manager.allocate_slots(
                        request,
                        num_new_tokens,
                        num_lookahead_tokens=self.num_lookahead_tokens,
                    )

                    if new_blocks is not None:
                        # The request can be scheduled.
                        break

                    # The request cannot be scheduled.
                    # Preempt the lowest-priority request.
                    if self.policy == SchedulingPolicy.PRIORITY:
                        preempted_req = max(
                            self.running,
                            key=lambda r: (r.priority, r.arrival_time),
                        )
                        self.running.remove(preempted_req)
                        # ...
                    else:
                        preempted_req = self.running.pop()
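                    # ... (elided: do the preemption; if `request` itself ends up
                    # preempted, the loop breaks here with new_blocks still None)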
            if new_blocks is None:
                # Cannot schedule this request.
                break

            # Schedule the request.
            scheduled_running_reqs.append(request)
            req_to_new_blocks[request.request_id] = new_blocks
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            req_index += 1
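
        # ... (elided: schedule WAITING requests with any remaining token budget,
        # then build the per-request data used below)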

        scheduler_output = SchedulerOutput(
            scheduled_new_reqs=new_reqs_data,
            scheduled_cached_reqs=cached_reqs_data,
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=total_num_scheduled_tokens,
            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
            scheduled_encoder_inputs=scheduled_encoder_inputs,
            num_common_prefix_blocks=num_common_prefix_blocks,
            preempted_req_ids={req.request_id for req in preempted_reqs},
            finished_req_ids=self.finished_req_ids,
            free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
        )
        # ...
        return scheduler_output

There are several notable methods on KVCacheManager:

  • get_computed_blocks: fetches the prefix-cache hit for a request.
  • allocate_slots: the main method called by the scheduler to allocate new blocks for a request.
  • free: frees all blocks used by a finished request.
class KVCacheManager:

    def allocate_slots(
        self,
        request: Request,
        num_new_tokens: int,
        num_new_computed_tokens: int = 0,
        new_computed_blocks: KVCacheBlocks | None = None,
        num_lookahead_tokens: int = 0,
        delay_cache_blocks: bool = False,
        num_encoder_tokens: int = 0,
    ) -> KVCacheBlocks | None:
        # ...
        self.coordinator.remove_skipped_blocks(
            request.request_id, request.num_computed_tokens
        )

        # ...
        num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
            request_id=request.request_id,
            num_tokens=num_tokens_need_slot,
            new_computed_blocks=new_computed_block_list,
            num_encoder_tokens=num_encoder_tokens,
        )
        
        # ...
        new_blocks = self.coordinator.allocate_new_blocks(
            request.request_id, num_tokens_need_slot, num_encoder_tokens
        )

        # ...
        self.coordinator.cache_blocks(request, num_tokens_to_cache)

        return self.create_kv_cache_blocks(new_blocks)

KVCacheCoordinator basically just calls into SingleTypeKVCacheManager, which delegates the block allocation request to BlockPool. SingleTypeKVCacheManager keeps per-request block lists for one attention spec. It decides how many new blocks a request needs (get_num_blocks_to_allocate), records cached prefix hits (save_new_computed_blocks), allocates growth (allocate_new_blocks), commits full blocks to the prefix cache (cache_blocks), frees finished requests (free), and trims sliding-window/cross-attention history by swapping blocks with the null_block (remove_skipped_blocks). It looks up prefix hits by walking through the request’s block hashes and asking the BlockPool for cached blocks (see SingleTypeKVCacheManager.find_longest_cache_hit, which is not included here but sketched after the code below).

class KVCacheCoordinator():
    def allocate_new_blocks(
        self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0
    ) -> tuple[list[KVCacheBlock], ...]:
        return tuple(
            manager.allocate_new_blocks(
                request_id,
                num_encoder_tokens
                if isinstance(manager, CrossAttentionManager)
                else num_tokens,
            )
            for manager in self.single_type_managers
        )

class SingleTypeKVCacheManager():
    def allocate_new_blocks(
        self, request_id: str, num_tokens: int
    ) -> list[KVCacheBlock]:
        req_blocks = self.req_to_blocks[request_id]
        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_new_blocks = num_required_blocks - len(req_blocks)
        if num_new_blocks <= 0:
            return []
        else:
            new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
            req_blocks.extend(new_blocks)
            return new_blocks
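
Since find_longest_cache_hit itself isn’t shown above, here is a rough illustrative sketch of the idea (simplified, not vLLM’s actual implementation): walk the request’s block hashes in order and stop at the first miss, so only a contiguous prefix of full blocks is ever reused.

def find_longest_cache_hit_sketch(block_hashes, cached_block_hash_to_block):
    """Return the cached blocks matching the longest prefix of block_hashes."""
    hit_blocks = []
    for block_hash in block_hashes:
        block = cached_block_hash_to_block.get(block_hash)
        if block is None:
            break                  # the longest cached prefix ends here
        hit_blocks.append(block)
    return hit_blocks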
    

KVCacheBlock is a data structure that stores block metadata (block_id, ref_cnt and hash). To maintain efficient (O(1)) push/pop/lookup, BlockPool maintains a doubly linked list of free KVCacheBlocks and a mapping from block hash to block. The hash-to-block mapping is used for caching: when a new request matches an existing prefix, the cached blocks are reused. After a request finishes, its blocks can be freed once their ref_cnt drops to 0. get_new_blocks hands out free blocks from the free queue, evicting any stale prefix-cache entry a block may still carry. The eviction order is LRU.


@dataclass
class KVCacheBlock:
    """KV-cache block metadata."""

    # Block ID, ranging from 0 to num_gpu_blocks - 1.
    block_id: int
    # Reference count.
    ref_cnt: int = 0
    # The hash key (block hash + group id) of the block, only available
    # when the block is full and cached.
    _block_hash: BlockHashWithGroupId | None = None

    # Used to construct a doubly linked list for free blocks.
    # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
    prev_free_block: "KVCacheBlock | None" = None
    next_free_block: "KVCacheBlock | None" = None

    # Whether the block is a null block that should never be cached.
    is_null: bool = False

class BlockPool:
    def __init__(self, num_gpu_blocks: int, enable_caching: bool):  # signature abridged
        self.blocks: list[KVCacheBlock] = [
            KVCacheBlock(idx) for idx in range(num_gpu_blocks)
        ]
        # Free block queue that constructs and manipulates a doubly linked
        # list of free blocks (including eviction candidates when caching is
        # enabled).
        self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)

    def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
        if num_blocks > self.get_num_free_blocks():
            raise ValueError(f"Cannot get {num_blocks} free blocks from the pool")

        ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
        if self.enable_caching:
            for block in ret:
                self._maybe_evict_cached_block(block)
                assert block.ref_cnt == 0
                block.ref_cnt += 1
        else:
            for block in ret:
                assert block.ref_cnt == 0
                block.ref_cnt += 1
        return ret

    def cache_full_blocks(
        self,
        request: Request,
        blocks: list[KVCacheBlock],
        num_cached_blocks: int,
        num_full_blocks: int,
        block_size: int,
        kv_cache_group_id: int,
    ) -> None:
        """Cache a list of full blocks for prefix caching.
        This function takes a list of blocks that will have their block hash
        metadata to be updated and cached. Given a request, it updates the
        metadata for each block and caching it in the
        `cached_block_hash_to_block`.
        The block hashes values are computed by the Request object immediately
        when it is created and when new tokens are appended.
        """
        # ...
        block_hashes: BlockHashList = request.block_hashes
        # ...

        new_block_hashes = block_hashes[num_cached_blocks:]
        new_hashes: list[ExternalBlockHash] | None = (
            [] if self.enable_kv_cache_events else None
        )
        for i, blk in enumerate(new_full_blocks):
            assert blk.block_hash is None
            block_hash = new_block_hashes[i]

            # Update and add the full block to the cache.
            block_hash_with_group_id = make_block_hash_with_group_id(
                block_hash, kv_cache_group_id
            )
            blk.block_hash = block_hash_with_group_id
            self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk)
            if new_hashes is not None:
                new_hashes.append(maybe_convert_block_hash(block_hash))

However, notice that a KVCacheBlock itself does not store the physical address where the data actually sits. The overall workflow looks like this:

When a request arrives, its prompt tokens are chunked and hashed (with optional MM/LoRA/salt/prompt-embed extra keys) via generate_block_hash_extra_keys and hash_block_tokens, so the block hashes are stored on the request for each future KVCacheBlock. Then KVCacheManager.get_computed_blocks asks the coordinator for the longest cache hit via SingleTypeKVCacheManager.find_longest_cache_hit, so only full cached blocks are reused. “Full cached blocks” are blocks that have all their token slots filled. On a cache hit, the BlockPool runs touch to increment the reference count for data sharing. When a request then needs to write tokens beyond the shared prefix, it effectively does copy-on-write (sketched after the list below):

  1. Use SingleTypeKVCacheManager.save_new_computed_blocks to create an entry for the request that holds shallow copies (shared references) of all the previously cached blocks.
  2. Use SingleTypeKVCacheManager.allocate_new_blocks and BlockPool.get_new_blocks to get fresh blocks for the new tokens.
  3. Once those new blocks fill up, BlockPool.cache_full_blocks attaches their hashes (computed on the Request) and inserts them into the hash-map / free-queue bookkeeping.
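
Here is a purely illustrative sketch of this sharing-plus-copy-on-write idea (ToyBlock and the helper names are made up): a prefix hit only bumps reference counts, and divergent tokens land in freshly allocated blocks instead of mutating the shared ones.

from dataclasses import dataclass

@dataclass
class ToyBlock:
    block_id: int
    ref_cnt: int = 0

pool = [ToyBlock(i) for i in range(8)]
free_ids = list(range(2, 8))      # pretend blocks 0 and 1 already hold request A's prefix

def touch(blocks):                # prefix-cache hit: share the blocks, don't copy them
    for b in blocks:
        b.ref_cnt += 1

def allocate(n):                  # growth beyond the shared prefix
    blocks = [pool[free_ids.pop(0)] for _ in range(n)]
    for b in blocks:
        b.ref_cnt = 1
    return blocks

# Request B matches A's first two blocks, then appends its own tokens:
shared = pool[:2]
touch(shared)                     # shared blocks: ref_cnt goes up, data is untouched
private = allocate(1)             # B's divergent tail lives in a fresh block (the "copy")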

Lifecycle of a Request

Terminology:

  • computed_blocks and num_computed_tokens: the blocks and the number of tokens that have already been computed for the request. This includes both cached blocks (from the prefix cache) and blocks computed in previous scheduling steps.
  • cached_blocks: blocks reused for the request from the prefix cache.
  • new_computed_blocks and num_new_computed_tokens: think of them as the “initialization” of computed_blocks. They are produced by SingleTypeKVCacheManager.find_longest_cache_hit exactly once, before the request starts running; at that moment, all computed_blocks are cached_blocks. These tokens are already computed and stored in the KV cache, so no new blocks need to be allocated for them (they are shared!).
  • num_new_tokens: the number of tokens to process in the current scheduling step, excluding previously cached or computed tokens. These tokens need new blocks to be allocated.

When a request is added to the EngineCore, it’s passed to the Scheduler, where it is first stored in the waiting queue. When the request gets picked up for scheduling, the scheduler first calls KVCacheManager.get_computed_blocks to find the prefix-cache hit, then calls KVCacheManager.allocate_slots to allocate new blocks for the request.

  • KVCacheManager.get_computed_blocks calls KVCacheCoordinator.find_longest_cache_hit, which uses SingleTypeKVCacheManager.find_longest_cache_hit to walk through the request’s block hashes and ask the BlockPool for cached blocks. As soon as a hash misses, it stops: the longest prefix-cache hit has been found. The blocks found are returned as new_computed_blocks along with num_new_computed_tokens, and are recorded as the request’s initial computed_blocks and cached_blocks.
  • KVCacheManager.allocate_slots first calls KVCacheCoordinator.get_num_blocks_to_allocate to compute how many new blocks the request needs, excluding cached blocks and skipped blocks (for sliding-window attention) from the calculation. It then calls KVCacheCoordinator.allocate_new_blocks, which uses SingleTypeKVCacheManager.allocate_new_blocks to call BlockPool.get_new_blocks and grab new blocks from the free list. Newly filled full blocks are committed to the prefix cache via SingleTypeKVCacheManager.cache_blocks. A rough sketch of how these two calls compose follows this list.
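
The sketch below is simplified glue (the schedule_waiting_request function is made up; the real logic lives in Scheduler.schedule), showing how the two calls compose when a waiting request is first scheduled:

def schedule_waiting_request(kv_cache_manager, request):
    # 1. Reuse whatever prefix is already cached (shared blocks, no new allocation).
    computed_blocks, num_computed_tokens = kv_cache_manager.get_computed_blocks(request)

    # 2. Allocate fresh blocks only for the tokens that still need to be computed.
    num_new_tokens = request.num_tokens - num_computed_tokens
    new_blocks = kv_cache_manager.allocate_slots(
        request,
        num_new_tokens,
        num_new_computed_tokens=num_computed_tokens,
        new_computed_blocks=computed_blocks,
    )
    if new_blocks is None:
        return None        # not enough free blocks; the request stays in WAITING

    return computed_blocks, new_blocks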

From the scheduler’s perspective, it always tries to schedule all the running requests in each scheduling step. If the token budget is not used up by the running requests, it also tries to schedule new requests from the waiting queue. After each scheduling step is executed, the scheduler calls Scheduler.update_from_output to update each request’s computed_blocks and num_computed_tokens. If a request is finished (judged by check_stop), it calls KVCacheManager.free to free all blocks used by that request.