Source code for kvboost.models

"""
Core data structures for chunk-level KV cache reuse.
KV tensors live in HuggingFace past_key_values format:
  tuple[num_layers] of (key, value)
  key/value shape: [batch=1, num_heads, seq_len, head_dim]

Cache Key Design
----------------
Two-tier keying (inspired by vLLM v1 block hashes):

  prefix_hash  : hash(parent_hash, token_ids) — unique to the full prefix
                 chain. Same tokens at different positions or with different
                 preceding context get different keys. Used for exact match.

  content_hash : hash(token_ids) — content-only, position-independent.
                 Used for approximate reuse: same tokens in a different
                 context can still be a cache hit, but the chunk is flagged
                 for mandatory full recompute (not just boundary repair).

This resolves the RoPE position collision bug (same tokens at different
positions reusing KV encoded at wrong positions) and the cross-chunk
attention contamination bug (KV vectors encoding the wrong preceding context).
"""

from __future__ import annotations
import time
import hashlib
from dataclasses import dataclass, field
from typing import Optional, Tuple, List

# HF past_key_values type alias
PastKVType = Tuple[Tuple["torch.Tensor", "torch.Tensor"], ...]


[docs] @dataclass class CachedChunk: """ A single cached chunk: a slice of tokenized text + its KV tensors. position_start / position_end are the absolute token positions at which this chunk was originally encoded. They drive position_ids on reuse so RoPE offsets stay consistent. """ chunk_id: str # primary key (prefix_hash for new, content_hash for legacy) text: str # original text (debugging / display) token_ids: List[int] # tokenized form past_key_values: PastKVType # extracted KV tensors (on CPU by default) position_start: int # absolute position of first token position_end: int # absolute position of last token + 1 prefix_hash: str = "" # hash(parent_hash, tokens) — positional + contextual content_hash: str = "" # hash(tokens) — content-only, position-independent created_at: float = field(default_factory=time.time) access_count: int = 0 recomputed: bool = False # True if boundary recompute was applied @property def length(self) -> int: return self.position_end - self.position_start
[docs] def touch(self) -> None: self.access_count += 1
[docs] def memory_bytes(self) -> int: total = 0 for layer in self.past_key_values: total += layer[0].nelement() * layer[0].element_size() total += layer[1].nelement() * layer[1].element_size() return total
def __repr__(self) -> str: mb = self.memory_bytes() / 1e6 return ( f"CachedChunk(id={self.chunk_id[:8]}, " f"pos=[{self.position_start},{self.position_end}), " f"len={self.length}, mem={mb:.2f}MB, hits={self.access_count})" )
[docs] @dataclass class AssembledPrompt: """ Result of stitching cached chunks + live (uncached) tail tokens. cached_past_kv : merged KV tensors for all cached tokens cached_length : number of tokens covered by cached_past_kv live_token_ids : tokens that still need a fresh forward pass live_position_ids : absolute positions for each live token chunk_boundaries : list of (start, end) for each reused chunk (used by SelectiveRecompute to find seam positions) cache_hit_ratio : fraction of total tokens served from cache has_approximate : True if any chunk was matched by content_hash (not prefix_hash) — signals that full recompute is needed, not just boundary repair """ full_token_ids: List[int] cached_past_kv: Optional[PastKVType] cached_length: int live_token_ids: List[int] live_position_ids: List[int] chunk_boundaries: List[Tuple[int, int]] cache_hit_ratio: float has_approximate: bool = False @property def total_length(self) -> int: return len(self.full_token_ids)
[docs] @dataclass class WarmResult: """Diagnostic returned by engine.warm() to help catch alignment issues.""" chunks_stored: int token_count: int chunk_size: int chunk_boundary_aligned: bool partial_tail_tokens: int alignment_warning: Optional[str] = None def __repr__(self) -> str: aligned = "aligned" if self.chunk_boundary_aligned else f"partial tail={self.partial_tail_tokens}" return f"WarmResult(stored={self.chunks_stored}, tokens={self.token_count}, {aligned})"
# ── Hashing helpers ────────────────────────────────────────────────
[docs] def content_hash_from_tokens(token_ids: List[int]) -> str: """Content-only hash. Same tokens always produce the same key.""" raw = b"".join(t.to_bytes(4, "little") for t in token_ids) return hashlib.sha256(raw).hexdigest()
[docs] def chained_hash(token_ids: List[int], parent_hash: Optional[str] = None) -> str: """ Prefix-chained hash (vLLM-style). key = SHA256(parent_hash || token_bytes) Same tokens with different parent hashes produce different keys, so the same text at different positions in different conversations correctly gets separate cache entries. """ parent = (parent_hash or "root").encode("utf-8") tokens = b"".join(t.to_bytes(4, "little") for t in token_ids) return hashlib.sha256(parent + tokens).hexdigest()
# Backwards compatibility alias chunk_id_from_tokens = content_hash_from_tokens