# Source code for kvboost.cache_manager

"""
KVCacheManager
==============
Two-tier storage with two-tier keying:

  Storage tiers:
    Tier 1 (hot)  : in-process dict, tensors on CPU RAM
    Tier 2 (cold) : mmap'd files on disk via torch.save / torch.load

  Key tiers:
    prefix_hash   : hash(parent_hash, tokens) — exact match, positionally correct
    content_hash  : hash(tokens) — approximate match, needs full recompute

  Lookup order:
    1. Try prefix_hash (exact) → use directly
    2. Try content_hash (approximate) → use but flag for full recompute

Eviction policy: least-frequently-used (by how often a chunk has been stored),
with least-recently-used tie-breaking among equal frequencies.
"""

from __future__ import annotations

import logging
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import torch

from .models import CachedChunk, PastKVType, content_hash_from_tokens, chained_hash
from .kv_quantize import QuantizedKV, quantize_kv, dequantize_kv
from .disk_tier import DiskTier

log = logging.getLogger(__name__)


class ChunkMatch:
    """
    Outcome of a cache lookup.

    Attributes:
        chunk: The cached chunk that was found.
        approximate: ``False`` for an exact prefix-hash match; ``True`` when the
            chunk was found only through the content-hash index and its KV must
            be recomputed before use.
    """

    # Fixed attribute set: no per-instance __dict__ for these small result objects.
    __slots__ = ("chunk", "approximate")

    def __init__(self, chunk: "CachedChunk", approximate: bool = False):
        self.chunk, self.approximate = chunk, approximate
class KVCacheManager:
    """
    KV-cache store with an in-process hot tier and an optional disk tier.

    Hot-tier entries are keyed by ``chunk.prefix_hash`` (falling back to
    ``chunk.chunk_id``). A secondary ``content_hash -> key`` index supports
    approximate lookups. When ``kv_cache_bits < 16`` the hot tier keeps a
    quantized copy of each chunk's KV tensors and dequantizes on every read.

    Args:
        max_chunks: Capacity of the hot tier. The disk tier, if enabled, is
            created with twice this capacity.
        disk_dir: Directory for the optional ``DiskTier``; a falsy value
            disables the cold tier.
        device: Device that stored KV tensors are moved to.
        kv_cache_bits: 16 = no quantization, 8 = int8, 4 = int4.
    """

    def __init__(
        self,
        max_chunks: int = 64,
        disk_dir: Optional[str] = None,
        device: str = "cpu",
        kv_cache_bits: int = 16,
    ):
        self.max_chunks = max_chunks
        self.device = device
        self.kv_cache_bits = kv_cache_bits  # 16 = no quantization, 8 = int8, 4 = int4

        # Primary store keyed by prefix_hash (exact match); insertion order
        # doubles as recency order because hits call move_to_end().
        self._hot: OrderedDict[str, CachedChunk] = OrderedDict()
        # Quantized storage: key -> QuantizedKV (only populated when bits < 16).
        self._quantized: Dict[str, QuantizedKV] = {}
        # Secondary index: content_hash -> prefix_hash for approximate lookup.
        self._content_index: Dict[str, str] = {}
        # Frequency counter used by eviction: set to 1 on insert and bumped each
        # time store() sees an already-hot key. NOTE(review): presumably callers
        # re-store shared chunks (e.g. system prompts) once per generate() call,
        # which is what protects them from eviction — confirm against the caller.
        self._frequency: Dict[str, int] = {}

        # Optional disk tier (cold tier can hold more than hot).
        self._disk: Optional[DiskTier] = None
        if disk_dir:
            self._disk = DiskTier(
                cache_dir=disk_dir,
                max_chunks=max_chunks * 2,
            )

        # Stats
        self.hits = 0
        self.misses = 0
        self.approximate_hits = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def store(self, chunk: CachedChunk) -> None:
        """
        Insert ``chunk`` into the hot tier, evicting one entry if at capacity.

        If the chunk's key is already hot, the stored chunk is NOT replaced;
        its recency is refreshed and its frequency count incremented.

        Side effects on ``chunk``: ``past_key_values`` is moved to
        ``self.device``, and when quantization is enabled it is replaced by an
        empty tuple (the quantized copy lives in ``self._quantized``).
        """
        key = self._store_key(chunk)
        if key in self._hot:
            self._hot.move_to_end(key)
            self._frequency[key] = self._frequency.get(key, 0) + 1
            return

        # Move KV tensors to the storage device.
        chunk.past_key_values = self._move_kv(chunk.past_key_values, self.device)

        # Quantize for compressed storage (if enabled).
        if self.kv_cache_bits < 16:
            qkv = quantize_kv(chunk.past_key_values, bits=self.kv_cache_bits)
            self._quantized[key] = qkv
            # Release the full-precision tensors; reads dequantize on demand.
            chunk.past_key_values = ()  # empty sentinel
            log.debug(
                "Stored chunk %s quantized int%d (%.2fMB → %.2fMB)",
                key[:8],
                self.kv_cache_bits,
                qkv.memory_bytes() / 1e6 * (16 / self.kv_cache_bits),
                qkv.memory_bytes() / 1e6,
            )

        if len(self._hot) >= self.max_chunks:
            self._evict_lfu()

        self._hot[key] = chunk
        self._frequency[key] = 1

        # Index by content_hash for approximate lookup (latest writer wins).
        if chunk.content_hash:
            self._content_index[chunk.content_hash] = key

        if self.kv_cache_bits >= 16:
            log.debug("Stored chunk %s (%.2fMB)", key[:8], chunk.memory_bytes() / 1e6)

    def get(self, chunk_id: str) -> Optional[CachedChunk]:
        """
        Retrieve a chunk by key (exact match only), dequantizing if needed.

        Checks the hot tier first, then the disk tier. A chunk found on disk
        is promoted into the hot tier via ``store()``. Each call increments
        exactly one of ``hits`` / ``misses``.

        Returns:
            The chunk (a dequantized copy when stored quantized), or ``None``.
        """
        if chunk_id in self._hot:
            self.hits += 1
            return self._access_hot(chunk_id)

        if self._disk is not None and self._disk.contains(chunk_id):
            chunk = self._disk.read(chunk_id, device=self.device)
            # A failed read is a miss, not a hit.
            if chunk is not None:
                self.hits += 1
                # Read back under the key store() actually uses rather than
                # re-entering get(chunk_id): that double-counted the hit and
                # could recurse forever if the two keys differed.
                key = self._store_key(chunk)
                self.store(chunk)  # promote to hot
                return self._access_hot(key)

        self.misses += 1
        return None

    def _dequantize_chunk(self, chunk: CachedChunk, key: str) -> CachedChunk:
        """Return a shallow copy of ``chunk`` with KV rebuilt from quantized storage.

        The stored chunk itself is left untouched (it keeps its empty sentinel).
        """
        import copy

        out = copy.copy(chunk)
        out.past_key_values = dequantize_kv(self._quantized[key])
        return out

    def get_by_content(self, content_hash: str) -> Optional[ChunkMatch]:
        """
        Look up by content hash (approximate match).

        Returns:
            ``ChunkMatch(approximate=True)`` if the content index points at a
            hot entry, else ``None``. Only ``approximate_hits`` is counted; a
            failed lookup does not increment ``misses``.
        """
        prefix_key = self._content_index.get(content_hash)
        if prefix_key is None or prefix_key not in self._hot:
            return None
        self.approximate_hits += 1
        return ChunkMatch(chunk=self._access_hot(prefix_key), approximate=True)

    def lookup(self, token_ids: List[int], parent_hash: Optional[str] = None) -> Optional[ChunkMatch]:
        """
        Two-tier lookup for one chunk of tokens.

        1. Prefix-chained hash (exact): correct position and context.
        2. Content hash (approximate): flagged for full recompute.

        NOTE: when step 1 misses, ``get()`` has already incremented ``misses``,
        so a lookup that ends in an approximate hit is counted in both
        ``misses`` and ``approximate_hits``.

        Returns:
            ``ChunkMatch`` or ``None``.
        """
        p_hash = chained_hash(token_ids, parent_hash)
        chunk = self.get(p_hash)
        if chunk is not None:
            return ChunkMatch(chunk=chunk, approximate=False)

        c_hash = content_hash_from_tokens(token_ids)
        return self.get_by_content(c_hash)

    def get_or_none(self, token_ids: List[int]) -> Optional[CachedChunk]:
        """
        Legacy API: look up by content hash only (backwards compat).

        Tries the content index first, then a direct hot-tier key equal to the
        content hash (chunks stored with the old keying). Counts ``hits`` /
        ``misses``. Quantized chunks are dequantized, like every other read
        path; previously this returned the empty ``()`` KV sentinel.
        """
        c_hash = content_hash_from_tokens(token_ids)
        prefix_key = self._content_index.get(c_hash)
        if prefix_key is not None and prefix_key in self._hot:
            key = prefix_key
        elif c_hash in self._hot:
            key = c_hash
        else:
            self.misses += 1
            return None
        self.hits += 1
        return self._access_hot(key)

    def build_prefix_kv(
        self, token_ids: List[int], chunk_size: int
    ) -> Tuple[Optional[PastKVType], int]:
        """
        Greedily assemble the longest cached prefix using chained hashes.

        Only exact matches are used (no approximate fallback in prefix mode),
        and only whole chunks. A trailing partial chunk is never matched.

        Returns:
            ``(merged_kv, n_tokens_covered)``, or ``(None, 0)`` if the first
            chunk misses.

        Raises:
            ValueError: if ``chunk_size`` is not positive (a zero size would
                otherwise loop forever whenever a lookup hits).
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        chunks: List[CachedChunk] = []
        pos = 0
        parent_hash = None
        while pos + chunk_size <= len(token_ids):
            slice_ids = token_ids[pos : pos + chunk_size]
            p_hash = chained_hash(slice_ids, parent_hash)
            chunk = self.get(p_hash)
            if chunk is None:
                break
            chunks.append(chunk)
            parent_hash = p_hash
            pos += chunk_size

        if not chunks:
            return None, 0

        merged = self.merge_kv_list([c.past_key_values for c in chunks])
        return merged, pos

    def find_matching_chunks(
        self, token_ids: List[int], chunk_size: int
    ) -> List[Tuple[int, ChunkMatch]]:
        """
        Scan every whole chunk of ``token_ids`` with the two-tier lookup.

        The prefix-hash chain is always extended over the actual tokens, hit or
        miss. An exact match after a gap is therefore still positionally
        correct. The chain used to be reset to ``None`` after a miss, so later
        chunks were hashed as if they started a prompt. Any chunk cached at the
        start of another prompt then came back flagged exact, with the wrong
        positions and context.

        Returns:
            ``(start_pos, ChunkMatch)`` pairs in order. Each match carries
            ``approximate=True/False``.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        results: List[Tuple[int, ChunkMatch]] = []
        parent_hash = None
        for start in range(0, len(token_ids) - chunk_size + 1, chunk_size):
            slice_ids = token_ids[start : start + chunk_size]
            match = self.lookup(slice_ids, parent_hash)
            if match is not None:
                results.append((start, match))
            parent_hash = chained_hash(slice_ids, parent_hash)
        return results

    def invalidate(self, chunk_id: str) -> None:
        """Drop ``chunk_id`` from the hot tier, its indexes, and the disk tier."""
        chunk = self._hot.pop(chunk_id, None)
        # Only drop the content-index entry if it still points at this chunk
        # (a later store() with the same content may have re-pointed it).
        if chunk is not None and self._content_index.get(chunk.content_hash) == chunk_id:
            del self._content_index[chunk.content_hash]
        self._quantized.pop(chunk_id, None)
        self._frequency.pop(chunk_id, None)
        if self._disk is not None:
            self._disk.remove(chunk_id)

    def stats(self) -> Dict:
        """
        Return hit/miss counters, hot-tier size, and (if enabled) disk stats.

        ``hot_memory_mb`` includes quantized storage. Quantized chunks hold an
        empty ``past_key_values`` sentinel, so their bytes live in
        ``self._quantized``.
        """
        total = self.hits + self.misses + self.approximate_hits
        hot_bytes = sum(c.memory_bytes() for c in self._hot.values())
        hot_bytes += sum(q.memory_bytes() for q in self._quantized.values())
        result = {
            "hot_chunks": len(self._hot),
            "hot_memory_mb": round(hot_bytes / 1e6, 2),
            "cache_hits": self.hits,
            "approximate_hits": self.approximate_hits,
            "cache_misses": self.misses,
            "hit_rate": round((self.hits + self.approximate_hits) / max(total, 1), 3),
            "exact_hit_rate": round(self.hits / max(total, 1), 3),
        }
        if self._disk is not None:
            result.update(self._disk.stats())
        return result

    # ------------------------------------------------------------------
    # Static helpers
    # ------------------------------------------------------------------

    @staticmethod
    def merge_kv_list(kv_list: List[PastKVType]) -> PastKVType:
        """
        Concatenate per-layer (key, value) tensors along dim 2.

        Dim 2 is treated as the sequence axis, matching ``slice_kv`` and
        ``kv_seq_len``. A single-element list is returned as-is (not copied).

        Raises:
            ValueError: if ``kv_list`` is empty.
        """
        if not kv_list:
            raise ValueError("kv_list is empty")
        if len(kv_list) == 1:
            return kv_list[0]

        num_layers = len(kv_list[0])
        merged = []
        for layer_idx in range(num_layers):
            keys = torch.cat([kv[layer_idx][0] for kv in kv_list], dim=2)
            vals = torch.cat([kv[layer_idx][1] for kv in kv_list], dim=2)
            merged.append((keys, vals))
        return tuple(merged)

    @staticmethod
    def slice_kv(kv: PastKVType, start: int, end: int) -> PastKVType:
        """Slice every layer's 4-D key/value tensors to ``[start:end]`` on dim 2."""
        return tuple(
            (layer[0][:, :, start:end, :], layer[1][:, :, start:end, :])
            for layer in kv
        )

    @staticmethod
    def kv_seq_len(kv: PastKVType) -> int:
        """Sequence length (dim 2) of the first layer's key tensor."""
        return kv[0][0].shape[2]

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    @staticmethod
    def _store_key(chunk: CachedChunk) -> str:
        """Hot-tier key for ``chunk``: its prefix hash, falling back to its id."""
        return chunk.prefix_hash or chunk.chunk_id

    def _access_hot(self, key: str) -> CachedChunk:
        """Touch a hot entry, mark it most-recently-used, and return a readable chunk."""
        chunk = self._hot[key]
        chunk.touch()
        self._hot.move_to_end(key)
        if key in self._quantized:
            return self._dequantize_chunk(chunk, key)
        return chunk

    def _evict_lfu(self) -> None:
        """
        Evict the hot entry with the lowest frequency count.

        Ties go to the least-recently-used entry (earliest in the OrderedDict).
        With a disk tier, the victim is dequantized if needed and written there
        before it is removed from the hot tier.
        """
        if not self._hot:
            return

        # min() returns the first minimal key in iteration order, i.e. the LRU
        # entry among those sharing the lowest frequency.
        victim_id = min(self._hot, key=lambda cid: self._frequency.get(cid, 0))
        min_freq = self._frequency.get(victim_id, 0)
        victim = self._hot[victim_id]
        log.debug("Evicting chunk %s (freq=%d)", victim_id[:8], min_freq)

        # Clean up the content index only if it still points at the victim.
        if self._content_index.get(victim.content_hash) == victim_id:
            del self._content_index[victim.content_hash]

        # Demote to the disk tier in full precision.
        if self._disk is not None:
            if victim_id in self._quantized:
                victim = self._dequantize_chunk(victim, victim_id)
            self._disk.write(victim)

        del self._hot[victim_id]
        self._frequency.pop(victim_id, None)
        self._quantized.pop(victim_id, None)

    @staticmethod
    def _move_kv(kv: PastKVType, device: str) -> PastKVType:
        """Move every layer's key/value tensors to ``device``."""
        return tuple((layer[0].to(device), layer[1].to(device)) for layer in kv)