Source code for kvboost.selective_recompute

"""
SelectiveRecompute
==================
Problem
-------
When we concatenate KV caches from independently-cached chunks,
each chunk's stored KV was computed with attention only over its own
tokens (and whatever prefix preceded it at cache-time). After stitching
chunks together the seam positions have "stale" KV — they didn't attend
to cross-chunk context.

This matters primarily for the *keys and values* of the boundary tokens
(not the queries, which are recomputed fresh for every generation step).

Fix
---
For the last R tokens of each chunk boundary (except the final chunk),
re-run a forward pass using the full merged KV as prefix context and
replace those R positions in the stitched KV with freshly computed ones.

This is "selective" because we only recompute O(R * num_boundaries)
tokens rather than the full prompt.

Trade-off knobs
---------------
recompute_overlap   : R — how many tokens back from each seam to fix.
                      Larger → better quality, more compute.
                      Default 16 is a good balance for chunk_size >= 64.
skip_if_no_seams    : skip the module entirely for pure prefix hits
                      (no stitching was done).
"""

from __future__ import annotations

import logging
from typing import List, Optional, Tuple

import torch

from .models import PastKVType, AssembledPrompt
from .cache_manager import KVCacheManager

log = logging.getLogger(__name__)


[docs] class SelectiveRecompute: def __init__( self, recompute_overlap: int = 16, skip_if_no_seams: bool = True, device: str = "cpu", ): self.overlap = recompute_overlap self.skip_if_no_seams = skip_if_no_seams self.device = device
[docs] def apply( self, assembled: AssembledPrompt, model, ) -> AssembledPrompt: """ Optionally fix the KV seams in assembled.cached_past_kv. Returns a (possibly modified) AssembledPrompt. model must be a HuggingFace CausalLM. """ boundaries = assembled.chunk_boundaries # Only seams between chunks need fixing (not the final boundary) seams = boundaries[:-1] # each seam is the END of a non-final chunk if self.skip_if_no_seams and not seams: return assembled # pure prefix — no stitching seams if assembled.cached_past_kv is None: return assembled new_kv = self._recompute_seams( assembled.full_token_ids, assembled.cached_past_kv, seams, model, ) return AssembledPrompt( full_token_ids=assembled.full_token_ids, cached_past_kv=new_kv, cached_length=assembled.cached_length, live_token_ids=assembled.live_token_ids, live_position_ids=assembled.live_position_ids, chunk_boundaries=assembled.chunk_boundaries, cache_hit_ratio=assembled.cache_hit_ratio, )
# ------------------------------------------------------------------ # Core recomputation # ------------------------------------------------------------------ def _recompute_seams( self, full_token_ids: List[int], merged_kv: PastKVType, seams: List[Tuple[int, int]], model, ) -> PastKVType: """ For each (start, end) seam in `seams`, re-encode the last `self.overlap` tokens of that seam position using the preceding merged_kv as context. Returns updated merged_kv with those positions patched. """ current_kv = merged_kv for _start, end_pos in seams: recompute_start = max(0, end_pos - self.overlap) recompute_tokens = full_token_ids[recompute_start:end_pos] if not recompute_tokens: continue # Prefix = everything before recompute_start prefix_kv = KVCacheManager.slice_kv(current_kv, 0, recompute_start) prefix_len = recompute_start # Build inputs input_ids = torch.tensor( [recompute_tokens], dtype=torch.long, device=self.device ) pos_ids = torch.arange( prefix_len, prefix_len + len(recompute_tokens), dtype=torch.long, device=self.device ).unsqueeze(0) # Move prefix_kv to model device model_device = next(model.parameters()).device from transformers import DynamicCache as _DynCache prefix_kv_dev = _DynCache() for layer_k, layer_v in prefix_kv: prefix_kv_dev.update( layer_k.to(model_device), layer_v.to(model_device), len(prefix_kv_dev), ) input_ids = input_ids.to(model_device) pos_ids = pos_ids.to(model_device) with torch.no_grad(): out = model( input_ids=input_ids, past_key_values=prefix_kv_dev, position_ids=pos_ids, use_cache=True, ) # Normalize DynamicCache → plain tuple[(key, val), ...] raw_kv = out.past_key_values if hasattr(raw_kv, "layers"): raw_kv = tuple((l.keys, l.values) for l in raw_kv.layers) elif hasattr(raw_kv, "key_cache") and hasattr(raw_kv, "value_cache"): raw_kv = tuple(zip(raw_kv.key_cache, raw_kv.value_cache)) else: raw_kv = tuple((layer[0], layer[1]) for layer in raw_kv) fresh_kv = tuple( (layer[0].to(self.device), layer[1].to(self.device)) for layer in raw_kv ) # fresh_kv now covers [0, recompute_start + overlap) # Splice it back: take fresh_kv up to end_pos, then append # whatever was in current_kv after end_pos tail_kv = KVCacheManager.slice_kv(current_kv, end_pos, KVCacheManager.kv_seq_len(current_kv)) if KVCacheManager.kv_seq_len(tail_kv) > 0: current_kv = KVCacheManager.merge_kv_list([fresh_kv, tail_kv]) else: current_kv = fresh_kv log.debug( "Recomputed seam at pos %d%d (%d tokens)", recompute_start, end_pos, len(recompute_tokens), ) return current_kv