# Source code for kvboost.cacheblend

"""
CacheBlend Recompute
====================
An alternative to SelectiveRecompute that fixes stale KV tensors
*statistically* rather than *spatially*.

Problem (same as SelectiveRecompute)
------------------------------------
When independently-cached chunks are stitched together, their KV tensors
are "stale" — each chunk was encoded without cross-chunk attention context.

Why SelectiveRecompute is suboptimal
------------------------------------
SelectiveRecompute blindly recomputes the last R tokens at every chunk seam.
This is both too broad (recomputes tokens that may not actually deviate) and
too narrow (misses tokens inside a chunk that depend on cross-chunk context).
In practice it costs ~79% of full prefill (Experiment 2: 553ms vs 703ms).

CacheBlend approach
-------------------
1. Concatenate all cached KV chunks into one assembled tensor
2. Run a single forward pass over the cached token IDs with the assembled
   KV as context — this produces "what KV *would* be with full attention"
3. Compare the updated KV against the cached KV per token per layer using
   cosine distance — tokens that deviate significantly are the ones that
   actually need fixing
4. Recompute only the top-k% most deviated tokens (default 15%)

The result: ~15% of tokens recomputed instead of O(R * num_seams), and
the tokens chosen are the ones that actually matter for output quality.

Reference: CacheBlend (USENIX ATC '25)

Trade-off knobs
---------------
recompute_ratio : fraction of tokens to recompute (default 0.15 = 15%)
                  Higher → better quality, more compute.
                  Lower  → faster, slight quality risk.
min_deviation   : minimum cosine distance to consider a token "deviated".
                  Tokens below this are never recomputed regardless of ratio.
"""

from __future__ import annotations

import logging
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F

from .models import PastKVType, AssembledPrompt
from .cache_manager import KVCacheManager

log = logging.getLogger(__name__)


class CacheBlendRecompute:
    """Repair stitched KV caches by patching only the most-deviated tokens.

    The cached tokens are run through the model once, the resulting key
    tensors are compared against the cached key tensors (cosine distance,
    averaged over layers), and the cached K/V entries of the most deviated
    token positions are overwritten with the freshly computed ones.

    Attributes:
        recompute_ratio: Fraction of cached tokens to patch. At least one
            token is always considered, whatever the ratio.
        min_deviation: Mean cosine distance a token must exceed to be
            eligible for patching.
        device: Stored for interface compatibility. Patched tensors stay on
            the device (and dtype) of the cached tensors they replace.
    """

    def __init__(
        self,
        recompute_ratio: float = 0.15,
        min_deviation: float = 0.01,
        device: str = "cpu",
    ) -> None:
        self.recompute_ratio = recompute_ratio
        self.min_deviation = min_deviation
        self.device = device

    def apply(
        self,
        assembled: AssembledPrompt,
        model,
    ) -> AssembledPrompt:
        """Fix stale KV tensors by deviation-guided selective recomputation.

        Same interface as SelectiveRecompute.apply().

        Args:
            assembled: Prompt whose ``cached_past_kv`` should be repaired.
            model: Callable model accepting ``input_ids``, ``position_ids``
                and ``use_cache`` and returning an object with
                ``past_key_values``; must expose ``parameters()``.

        Returns:
            ``assembled`` itself when there is no cached KV or the cached
            length is zero; otherwise a new ``AssembledPrompt`` that differs
            only in ``cached_past_kv``.
        """
        if assembled.cached_past_kv is None:
            return assembled
        if assembled.cached_length == 0:
            return assembled

        new_kv = self._deviation_recompute(
            assembled.full_token_ids,
            assembled.cached_past_kv,
            assembled.cached_length,
            model,
        )
        return AssembledPrompt(
            full_token_ids=assembled.full_token_ids,
            cached_past_kv=new_kv,
            cached_length=assembled.cached_length,
            live_token_ids=assembled.live_token_ids,
            live_position_ids=assembled.live_position_ids,
            chunk_boundaries=assembled.chunk_boundaries,
            cache_hit_ratio=assembled.cache_hit_ratio,
        )

    def _deviation_recompute(
        self,
        full_token_ids: List[int],
        cached_kv: PastKVType,
        cached_length: int,
        model,
    ) -> PastKVType:
        """Run the core CacheBlend steps and return the patched KV.

        1. Forward pass over the cached token IDs to obtain fresh KV.
        2. Measure per-token deviation (cosine distance on K tensors).
        3. Select the top-k most deviated positions above ``min_deviation``.
        4. Overwrite the cached K/V at those positions with the fresh values.

        Args:
            full_token_ids: Token IDs of the whole prompt; only the first
                ``cached_length`` are used.
            cached_kv: Per-layer ``(K, V)`` pairs of the stitched cache.
            cached_length: Number of leading tokens covered by ``cached_kv``.
            model: See :meth:`apply`.

        Returns:
            ``cached_kv`` itself when no token exceeds ``min_deviation``;
            otherwise a new tuple of ``(K, V)`` pairs (clones of the cached
            tensors with the selected positions replaced).
        """
        model_device = next(model.parameters()).device
        cached_tokens = full_token_ids[:cached_length]

        # ── Step 1: forward pass over the cached tokens ────────────
        # NOTE: no past KV is handed to the model here, so the fresh KV
        # comes from a plain forward pass over the cached token IDs.
        input_ids = torch.tensor(
            [cached_tokens], dtype=torch.long, device=model_device
        )
        pos_ids = torch.arange(
            0, cached_length, dtype=torch.long, device=model_device
        ).unsqueeze(0)

        with torch.no_grad():
            out = model(
                input_ids=input_ids,
                position_ids=pos_ids,
                use_cache=True,
            )
        updated_kv = self._extract_kv(out.past_key_values)

        # ── Step 2: per-token deviation averaged across layers ─────
        mean_dev = self._mean_key_deviation(cached_kv, updated_kv, model_device)

        # ── Step 3: select positions that need recomputation ───────
        num_tokens = mean_dev.shape[0]
        num_to_recompute = max(1, int(num_tokens * self.recompute_ratio))

        # Tokens at or below the threshold are never patched.
        above_min = mean_dev > self.min_deviation
        if not above_min.any():
            log.debug("CacheBlend: no tokens above min_deviation=%.3f, skipping recompute", self.min_deviation)
            return cached_kv

        # Push ineligible tokens below every valid distance (distances are
        # >= 0) so top-k can only pick eligible ones; k is capped at the
        # number of eligible tokens for the same reason.
        masked_dev = mean_dev.clone()
        masked_dev[~above_min] = -1.0
        _, top_indices = masked_dev.topk(
            min(num_to_recompute, above_min.sum().item())
        )
        top_indices = top_indices.sort().values  # sort by position for locality

        recompute_pct = len(top_indices) / num_tokens * 100
        log.debug(
            "CacheBlend: recomputing %d/%d tokens (%.1f%%), max_dev=%.4f, mean_dev=%.4f",
            len(top_indices),
            num_tokens,
            recompute_pct,
            mean_dev.max().item(),
            mean_dev[above_min].mean().item(),
        )

        # ── Step 4: patch cached KV at the selected positions ──────
        return self._patch_positions(cached_kv, updated_kv, top_indices)

    @staticmethod
    def _mean_key_deviation(
        cached_kv: PastKVType,
        updated_kv: PastKVType,
        device,
    ) -> torch.Tensor:
        """Return the per-token cosine distance of K, averaged over layers.

        Both K tensors are indexed as ``[1, num_heads, seq_len, head_dim]``;
        heads are flattened into the feature dimension before comparing.
        The result is a 1-D tensor of length ``seq_len`` where 0 means
        identical direction and 2 means opposite.
        """
        per_layer = []
        for layer_idx, cached_layer in enumerate(cached_kv):
            cached_k = cached_layer[0].to(device)
            updated_k = updated_kv[layer_idx][0].to(device)

            # [1, heads, seq, dim] -> [1, seq, heads * dim]
            seq_len = cached_k.shape[2]
            flat_cached = cached_k.permute(0, 2, 1, 3).reshape(1, seq_len, -1)
            flat_updated = updated_k.permute(0, 2, 1, 3).reshape(1, seq_len, -1)

            similarity = F.cosine_similarity(flat_cached, flat_updated, dim=-1)
            per_layer.append(1.0 - similarity)  # [1, seq_len]

        # Stack to [layers, 1, seq_len], average layers, drop the batch dim.
        return torch.stack(per_layer, dim=0).mean(dim=0).squeeze(0)

    @staticmethod
    def _patch_positions(
        cached_kv: PastKVType,
        updated_kv: PastKVType,
        positions: torch.Tensor,
    ) -> PastKVType:
        """Return clones of ``cached_kv`` with ``positions`` overwritten.

        The fresh tensors and the index are aligned with each cached
        tensor's own device and dtype before assignment. Indexed assignment
        requires matching device and dtype, and the result must remain a
        drop-in replacement for the cache that was passed in. The input
        tensors are not modified.
        """
        patched = []
        for layer_idx, cached_layer in enumerate(cached_kv):
            patched_k = cached_layer[0].clone()
            patched_v = cached_layer[1].clone()

            fresh_k = updated_kv[layer_idx][0].to(
                device=patched_k.device, dtype=patched_k.dtype
            )
            fresh_v = updated_kv[layer_idx][1].to(
                device=patched_v.device, dtype=patched_v.dtype
            )

            # ``.to`` is a no-op when the index already lives on that device.
            k_idx = positions.to(patched_k.device)
            v_idx = positions.to(patched_v.device)

            # Patch only the deviated positions along the sequence axis.
            patched_k[:, :, k_idx, :] = fresh_k[:, :, k_idx, :]
            patched_v[:, :, v_idx, :] = fresh_v[:, :, v_idx, :]
            patched.append((patched_k, patched_v))
        return tuple(patched)

    @staticmethod
    def _extract_kv(past_key_values) -> PastKVType:
        """Normalize HF past_key_values to tuple of (K, V) on CPU.

        Accepts, in order of preference: an object exposing ``key_cache``
        and ``value_cache`` sequences, an object exposing
        ``to_legacy_cache()``, or an iterable of per-layer ``(K, V)`` pairs.
        """
        if hasattr(past_key_values, "key_cache") and hasattr(past_key_values, "value_cache"):
            return tuple(
                (k.cpu(), v.cpu())
                for k, v in zip(past_key_values.key_cache, past_key_values.value_cache)
            )
        if hasattr(past_key_values, "to_legacy_cache"):
            legacy = past_key_values.to_legacy_cache()
            return tuple((layer[0].cpu(), layer[1].cpu()) for layer in legacy)
        return tuple((layer[0].cpu(), layer[1].cpu()) for layer in past_key_values)