Source code for kvboost.engine

"""
InferenceEngine (exported as KVBoost)
=====================================
Ties together:
  model / tokenizer
  KVCacheManager
  ChunkRegistry
  PromptAssembler
  SelectiveRecompute

Exposes three generation modes for benchmarking:
  BASELINE        — standard HF generate, no caching
  PREFIX_CACHE    — exact prefix caching only (control)
  CHUNK_KV_REUSE  — full chunk-level KV reuse + selective recompute

Usage
-----
    from kvboost import KVBoost

    engine = KVBoost.from_pretrained("Qwen/Qwen2.5-3B")
    engine.warm("You are a helpful assistant.")
    result = engine.generate("You are a helpful assistant.\n\nHello!")
    print(result.output_text)
"""

from __future__ import annotations

import enum
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import DynamicCache

from .models import AssembledPrompt, CachedChunk, PastKVType, WarmResult, content_hash_from_tokens, chained_hash
from .cache_manager import KVCacheManager
from .chunk_registry import ChunkRegistry, ChunkStrategy
from .prompt_assembler import AssemblyMode, PromptAssembler
from .selective_recompute import SelectiveRecompute
from .cacheblend import CacheBlendRecompute
from .compat import check_model_compatibility, SUPPORTED_ARCHITECTURES

log = logging.getLogger(__name__)


class GenerationMode(str, enum.Enum):
    """Generation modes exposed by the engine for benchmarking.

    Members are also plain strings, so they compare equal to their values.
    """

    # Standard forward/decode loop with no KV reuse.
    BASELINE = "baseline"

    # Exact prefix caching only (control).
    PREFIX_CACHE = "prefix_cache"

    # Full chunk-level KV reuse plus selective recompute.
    CHUNK_KV_REUSE = "chunk_kv_reuse"
class RecomputeStrategy(str, enum.Enum):
    """How stitched chunk KV is repaired before decoding.

    Members are also plain strings, so they compare equal to their values.
    """

    # fix last R tokens at each seam (original)
    SELECTIVE = "selective"

    # fix top-k% most deviated tokens (smarter)
    CACHEBLEND = "cacheblend"

    # no recompute — fastest, slight quality risk
    NONE = "none"
@dataclass
class GenerationResult:
    """Outcome and timing metrics of a single generation call."""

    # Name of the generation mode that produced this result.
    mode: str
    # The prompt text that was submitted.
    prompt: str
    # Decoded text of the generated tokens.
    output_text: str
    # Number of tokens generated.
    generated_tokens: int
    # time-to-first-token
    ttft_ms: float
    # end-to-end
    total_ms: float
    # Generated tokens divided by elapsed seconds.
    tokens_per_sec: float
    # fraction of prompt tokens served from cache
    kv_reuse_ratio: float
    # Number of tokens in the encoded prompt.
    prompt_tokens: int
    # Number of prompt tokens whose KV came from the cache.
    cached_tokens: int
class InferenceEngine:
    """
    Couples a causal LM and its tokenizer with the chunk-level KV cache.

    The engine owns a KVCacheManager, a ChunkRegistry, a PromptAssembler and
    the two recompute strategies (SelectiveRecompute, CacheBlendRecompute),
    and exposes generation in the modes listed in GenerationMode.
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        chunk_size: int = 128,
        max_chunks: int = 128,
        recompute_overlap: int = 16,
        recompute_strategy: RecomputeStrategy = RecomputeStrategy.SELECTIVE,
        recompute_ratio: float = 0.15,
        kv_cache_bits: int = 16,
        disk_cache_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """
        Build the engine and its sub-systems.

        Args:
            model: Causal LM; it is moved to `device`.
            tokenizer: Tokenizer matching `model`.
            chunk_size: Chunk size handed to the ChunkRegistry.
            max_chunks: Capacity handed to the KVCacheManager.
            recompute_overlap: Overlap handed to SelectiveRecompute.
            recompute_strategy: RecomputeStrategy member or its string value.
            recompute_ratio: Ratio handed to CacheBlendRecompute.
            kv_cache_bits: Forwarded to the KVCacheManager.
            disk_cache_dir: Forwarded to the KVCacheManager as `disk_dir`.
            device: Target device; when None, picks "mps", then "cuda",
                then "cpu", in that order of availability.
        """
        if device is None:
            if torch.backends.mps.is_available():
                device = "mps"
            elif torch.cuda.is_available():
                device = "cuda"
            else:
                device = "cpu"

        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        # Accepts either the enum member or its string value.
        self.recompute_strategy = RecomputeStrategy(recompute_strategy)

        # Sub-systems (CPU storage for cache tensors, move to device on use)
        self.cache_manager = KVCacheManager(
            max_chunks=max_chunks,
            disk_dir=disk_cache_dir,
            device="cpu",
            kv_cache_bits=kv_cache_bits,
        )
        self.chunk_registry = ChunkRegistry(
            chunk_size=chunk_size,
            strategy=ChunkStrategy.FIXED,
        )
        self.assembler = PromptAssembler(
            cache_manager=self.cache_manager,
            chunk_registry=self.chunk_registry,
            mode=AssemblyMode.CHUNK_REUSE,
        )
        self.selective_recompute = SelectiveRecompute(
            recompute_overlap=recompute_overlap,
            skip_if_no_seams=True,
            device="cpu",
        )
        self.cacheblend_recompute = CacheBlendRecompute(
            recompute_ratio=recompute_ratio,
            device="cpu",
        )

    # ------------------------------------------------------------------
    # Factory
    # ------------------------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        strict: bool = True,
        **kwargs,
    ) -> "InferenceEngine":
        """
        Load a HuggingFace model and create a KVBoost engine.

        The model is loaded in float16 and switched to eval mode. If the
        tokenizer has no pad token, its EOS token is used as the pad token.

        Args:
            model_name: Any HF decoder-only causal LM (must use RoPE).
            strict: If True (default), raise on unsupported architectures
                    and warn on untested ones. Set False to skip checks.
            **kwargs: Passed to InferenceEngine.__init__().

        Returns:
            A new engine wrapping the loaded model and tokenizer.
        """
        log.info("Loading model %s ...", model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        model.eval()

        check_model_compatibility(model, strict=strict)

        return cls(model=model, tokenizer=tokenizer, **kwargs)

    # ------------------------------------------------------------------
    # Public generate API
    # ------------------------------------------------------------------

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 64,
        mode: GenerationMode = GenerationMode.CHUNK_KV_REUSE,
        temperature: float = 1.0,
        do_sample: bool = False,
    ) -> GenerationResult:
        """
        Generate a completion for `prompt` using the requested mode.

        Args:
            prompt: Prompt text.
            max_new_tokens: Max tokens to generate.
            mode: One of GenerationMode (or its string value).
            temperature: Sampling temperature (used when sampling).
            do_sample: Greedy (False) or sampling (True).

        Returns:
            GenerationResult with output text and timing metrics.

        Raises:
            ValueError: If `mode` is not a known generation mode.
        """
        token_ids = self._encode(prompt)

        if mode == GenerationMode.BASELINE:
            return self._generate_baseline(
                prompt, token_ids, max_new_tokens, temperature, do_sample
            )
        elif mode == GenerationMode.PREFIX_CACHE:
            return self._generate_prefix_cache(
                prompt, token_ids, max_new_tokens, temperature, do_sample
            )
        elif mode == GenerationMode.CHUNK_KV_REUSE:
            return self._generate_chunk_reuse(
                prompt, token_ids, max_new_tokens, temperature, do_sample
            )
        raise ValueError(f"Unknown mode {mode}")

    def generate_batch(
        self,
        prompts: List[str],
        max_new_tokens: int = 64,
        temperature: float = 1.0,
        do_sample: bool = False,
    ) -> List[GenerationResult]:
        """
        Generate responses for multiple prompts sharing a common prefix.

        Loads shared prefix KV once, runs batched prefill and decode.
        An empty `prompts` list returns an empty list; a single prompt is
        delegated to generate().

        Args:
            prompts: List of prompts (should share a common prefix for best results).
            max_new_tokens: Max tokens to generate per prompt.
            temperature: Sampling temperature.
            do_sample: Greedy (False) or sampling (True).

        Returns:
            List of GenerationResult, one per prompt.
        """
        # An empty batch would otherwise fail on max() over an empty sequence.
        if not prompts:
            return []

        if len(prompts) == 1:
            return [
                self.generate(
                    prompts[0],
                    max_new_tokens,
                    temperature=temperature,
                    do_sample=do_sample,
                )
            ]

        from .batch import (
            find_common_chunk_prefix,
            broadcast_kv,
            pad_and_mask,
            batched_decode,
        )

        t0 = time.perf_counter()
        batch_size = len(prompts)

        # Tokenize all prompts
        all_token_ids = [self._encode(p) for p in prompts]

        # Find shared chunk-aligned prefix
        common_len = find_common_chunk_prefix(
            all_token_ids, self.chunk_registry.chunk_size
        )

        # Load shared prefix KV from cache
        shared_kv = None
        if common_len > 0:
            assembled = self.assembler.assemble(all_token_ids[0][:common_len + 1])
            shared_kv = assembled.cached_past_kv
            # The assembler decides how much was actually served from cache.
            common_len = assembled.cached_length

        # Collect suffix token IDs (non-shared tail of each prompt)
        # NOTE(review): a prompt that is exactly the shared prefix yields an
        # empty suffix, and `real_len - 1` below would then index the last
        # padded position — confirm find_common_chunk_prefix rules this out.
        suffix_ids_list = [ids[common_len:] for ids in all_token_ids]

        # Pad suffixes and build attention masks.
        # A pad id of 0 is valid, so test for None rather than truthiness.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        # NOTE(review): the masks are not passed to the model or to
        # batched_decode, so padded positions of shorter prompts stay in the
        # KV cache during decode — confirm this is intended.
        padded_suffixes, _attn_masks = pad_and_mask(suffix_ids_list, pad_id)
        max_suffix_len = max(len(s) for s in suffix_ids_list)

        # Build batched input tensors
        suffix_input = torch.tensor(
            padded_suffixes, dtype=torch.long, device=self.device
        )
        pos_ids = torch.arange(
            common_len,
            common_len + max_suffix_len,
            dtype=torch.long,
            device=self.device,
        ).unsqueeze(0).expand(batch_size, -1)

        # Broadcast shared KV across batch (zero-copy expand)
        batched_past = None
        if shared_kv is not None:
            shared_kv_device = tuple(
                (k.to(self.device), v.to(self.device)) for k, v in shared_kv
            )
            batched_past = broadcast_kv(shared_kv_device, batch_size)

        # Batched prefill
        with torch.no_grad():
            out = self.model(
                input_ids=suffix_input,
                past_key_values=self._as_cache(batched_past),
                position_ids=pos_ids,
                use_cache=True,
            )
        first_token_time = time.perf_counter()
        past_kv = self._normalize_past_kv(out.past_key_values)

        # Sample first token per sequence (using each sequence's last real token logits)
        first_tokens = []
        for b in range(batch_size):
            real_len = len(suffix_ids_list[b])
            logits_b = out.logits[b, real_len - 1, :].unsqueeze(0)
            first_tokens.append(self._sample(logits_b, temperature, do_sample))
        first_tokens_t = torch.tensor(
            first_tokens, dtype=torch.long, device=self.device
        )

        # Batched decode
        generated_ids, _ = batched_decode(
            model=self.model,
            past_kv=past_kv,
            first_tokens=first_tokens_t,
            start_pos=common_len + max_suffix_len,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.tokenizer.eos_token_id,
            temperature=temperature,
            do_sample=do_sample,
            device=self.device,
        )
        t1 = time.perf_counter()

        # Store prompt chunks for future reuse
        for ids in all_token_ids:
            self._store_prompt_chunks(ids)

        # Build results; timing is shared by every sequence in the batch.
        results = []
        ttft = (first_token_time - t0) * 1000
        total_ms = (t1 - t0) * 1000
        hit_ratio = common_len / max(max(len(ids) for ids in all_token_ids), 1)

        for b in range(batch_size):
            output_text = self.tokenizer.decode(
                generated_ids[b], skip_special_tokens=True
            )
            results.append(GenerationResult(
                mode="chunk_kv_reuse_batch",
                prompt=prompts[b],
                output_text=output_text,
                generated_tokens=len(generated_ids[b]),
                ttft_ms=ttft,
                total_ms=total_ms,
                tokens_per_sec=len(generated_ids[b]) / max(t1 - t0, 1e-6),
                kv_reuse_ratio=hit_ratio,
                prompt_tokens=len(all_token_ids[b]),
                cached_tokens=common_len,
            ))

        return results

    def generate_many(
        self,
        prompts: List[str],
        max_new_tokens: int = 64,
        temperature: float = 1.0,
        do_sample: bool = False,
    ) -> List[GenerationResult]:
        """
        Like generate_batch(), but auto-groups prompts by shared prefix.

        Prompts without shared prefixes are processed individually.

        Args:
            prompts: List of prompts (may or may not share prefixes).
            max_new_tokens: Max tokens to generate per prompt.
            temperature: Sampling temperature.
            do_sample: Greedy (False) or sampling (True).

        Returns:
            List of GenerationResult in the same order as input prompts.
        """
        from .batch import group_by_prefix

        all_token_ids = [self._encode(p) for p in prompts]
        groups = group_by_prefix(
            prompts, all_token_ids, self.chunk_registry.chunk_size,
        )

        # Slots are filled by original index so output order matches input.
        results: List[Optional[GenerationResult]] = [None] * len(prompts)

        for group_indices in groups.values():
            group_prompts = [prompts[i] for i in group_indices]

            if len(group_prompts) == 1:
                group_results = [self.generate(
                    group_prompts[0],
                    max_new_tokens,
                    temperature=temperature,
                    do_sample=do_sample,
                )]
            else:
                group_results = self.generate_batch(
                    group_prompts,
                    max_new_tokens,
                    temperature=temperature,
                    do_sample=do_sample,
                )

            for idx, result in zip(group_indices, group_results):
                results[idx] = result

        return results

    # ------------------------------------------------------------------
    # Cache population helper
    # ------------------------------------------------------------------

    def warm(self, text: str, position_offset: int = 0) -> WarmResult:
        """
        Encode `text` and cache all its fixed-size chunks.

        Returns a WarmResult with diagnostics including alignment warnings.
        The result is truthy (usable as int) via chunks_stored.

        Call this for your system prompt / few-shot examples / documents
        BEFORE calling generate() so the cache is already populated.

        Args:
            text: Text to tokenize and cache.
            position_offset: Position id assigned to the first token.
        """
        token_ids = self._encode(text)
        chunks_added = self._cache_missing_chunks(
            token_ids, text=text, position_offset=position_offset
        )

        # Build diagnostic
        chunk_size = self.chunk_registry.chunk_size
        n_tokens = len(token_ids)
        partial_tail = n_tokens % chunk_size
        # NOTE(review): a tail shorter than min_chunk_tokens is reported as
        # aligned — confirm this matches how ChunkRegistry.split treats
        # short tails.
        aligned = (
            partial_tail == 0
            or partial_tail < self.chunk_registry.min_chunk_tokens
        )

        warning = None
        if not aligned:
            warning = (
                f"Prompt length {n_tokens} tokens is not a multiple of "
                f"chunk_size {chunk_size}. The last {partial_tail} tokens "
                f"will not be cached and must be recomputed on every "
                f"generate() call."
            )
            log.warning("warm(): %s", warning)

        return WarmResult(
            chunks_stored=chunks_added,
            token_count=n_tokens,
            chunk_size=chunk_size,
            chunk_boundary_aligned=aligned,
            partial_tail_tokens=partial_tail,
            alignment_warning=warning,
        )

    # Keep old name as alias
    warm_chunks = warm

    # ------------------------------------------------------------------
    # Generation implementations
    # ------------------------------------------------------------------

    def _generate_baseline(
        self,
        prompt: str,
        token_ids: List[int],
        max_new_tokens: int,
        temperature: float,
        do_sample: bool,
    ) -> GenerationResult:
        """Plain prefill + token-by-token decode with no cached KV reuse."""
        input_ids = torch.tensor([token_ids], dtype=torch.long, device=self.device)
        t0 = time.perf_counter()
        first_token_time = None
        generated = []

        with torch.no_grad():
            past = None
            cur_ids = input_ids
            for step in range(max_new_tokens):
                out = self.model(
                    input_ids=cur_ids,
                    past_key_values=self._as_cache(past),
                    use_cache=True,
                )
                if first_token_time is None:
                    first_token_time = time.perf_counter()
                # Normalize: newer transformers returns DynamicCache, not plain tuple
                past = self._normalize_past_kv(out.past_key_values)
                next_token = self._sample(out.logits[:, -1, :], temperature, do_sample)
                generated.append(next_token)
                if next_token == self.tokenizer.eos_token_id:
                    break
                cur_ids = torch.tensor(
                    [[next_token]], dtype=torch.long, device=self.device
                )

        t1 = time.perf_counter()
        output_text = self.tokenizer.decode(generated, skip_special_tokens=True)
        ttft = (first_token_time - t0) * 1000 if first_token_time else 0
        total_ms = (t1 - t0) * 1000
        tps = len(generated) / max((t1 - t0), 1e-6)

        return GenerationResult(
            mode="baseline",
            prompt=prompt,
            output_text=output_text,
            generated_tokens=len(generated),
            ttft_ms=ttft,
            total_ms=total_ms,
            tokens_per_sec=tps,
            kv_reuse_ratio=0.0,
            prompt_tokens=len(token_ids),
            cached_tokens=0,
        )

    def _generate_prefix_cache(
        self,
        prompt: str,
        token_ids: List[int],
        max_new_tokens: int,
        temperature: float,
        do_sample: bool,
    ) -> GenerationResult:
        """Standard prefix caching: reuse contiguous leading chunks."""
        merged_kv, covered = self.cache_manager.build_prefix_kv(
            token_ids, self.chunk_registry.chunk_size
        )
        # Tokens past the covered prefix still need a forward pass.
        live_ids = token_ids[covered:]
        return self._decode_with_kv(
            prompt, token_ids, merged_kv, covered, live_ids,
            max_new_tokens, temperature, do_sample, mode_name="prefix_cache"
        )

    def _generate_chunk_reuse(
        self,
        prompt: str,
        token_ids: List[int],
        max_new_tokens: int,
        temperature: float,
        do_sample: bool,
    ) -> GenerationResult:
        """Full chunk-level KV reuse + recompute (strategy-dependent)."""
        assembled = self.assembler.assemble(token_ids)

        # Apply recompute strategy when multiple chunks are stitched
        if len(assembled.chunk_boundaries) > 1:
            if assembled.has_approximate:
                # Approximate matches (content-only key) have wrong position
                # encodings and/or wrong preceding context — always use
                # CacheBlend to fix the full KV, not just boundaries
                log.debug("Approximate match detected — forcing CacheBlend recompute")
                assembled = self.cacheblend_recompute.apply(assembled, self.model)
            elif self.recompute_strategy == RecomputeStrategy.SELECTIVE:
                assembled = self.selective_recompute.apply(assembled, self.model)
            elif self.recompute_strategy == RecomputeStrategy.CACHEBLEND:
                assembled = self.cacheblend_recompute.apply(assembled, self.model)
            # NONE: skip recompute entirely

        return self._decode_with_kv(
            prompt, token_ids,
            assembled.cached_past_kv,
            assembled.cached_length,
            assembled.live_token_ids,
            max_new_tokens, temperature, do_sample,
            mode_name="chunk_kv_reuse",
            hit_ratio=assembled.cache_hit_ratio,
        )

    # ------------------------------------------------------------------
    # Shared decode loop
    # ------------------------------------------------------------------

    def _decode_with_kv(
        self,
        prompt: str,
        full_token_ids: List[int],
        past_kv: Optional[PastKVType],
        cached_len: int,
        live_ids: List[int],
        max_new_tokens: int,
        temperature: float,
        do_sample: bool,
        mode_name: str,
        hit_ratio: Optional[float] = None,
    ) -> GenerationResult:
        """
        Prefill the un-cached prompt tail on top of `past_kv`, then decode.

        Args:
            prompt: Original prompt text (reported in the result).
            full_token_ids: All prompt token ids.
            past_kv: Cached per-layer (key, value) pairs, or None.
            cached_len: Number of prompt tokens covered by `past_kv`.
            live_ids: Prompt tokens that still need a forward pass.
            max_new_tokens: Decode budget; the first token is always
                produced by the prefill step.
            temperature: Sampling temperature.
            do_sample: Greedy (False) or sampling (True).
            mode_name: Value stored in GenerationResult.mode.
            hit_ratio: Reuse ratio to report; when None it is computed as
                cached_len / len(full_token_ids).

        Returns:
            GenerationResult for this call. As a side effect, the prompt's
            un-cached chunks are stored in the cache after decoding.
        """
        t0 = time.perf_counter()
        first_token_time = None
        generated = []

        # Move past_kv to model device
        if past_kv is not None:
            past_kv = tuple(
                (layer[0].to(self.device), layer[1].to(self.device))
                for layer in past_kv
            )

        # ----- encode live tokens (prompt tail) -------------------------
        if live_ids:
            input_ids = torch.tensor([live_ids], dtype=torch.long, device=self.device)
            pos_ids = torch.arange(
                cached_len,
                cached_len + len(live_ids),
                dtype=torch.long,
                device=self.device,
            ).unsqueeze(0)

            with torch.no_grad():
                out = self.model(
                    input_ids=input_ids,
                    past_key_values=self._as_cache(past_kv),
                    position_ids=pos_ids,
                    use_cache=True,
                )
            first_token_time = time.perf_counter()
            past_kv = self._normalize_past_kv(out.past_key_values)
            next_token = self._sample(out.logits[:, -1, :], temperature, do_sample)
            # An EOS here is detected by the decode loop below.
            generated.append(next_token)
        else:
            # All tokens were cached — run a dummy forward to get first logits
            # by feeding the last cached token again at its position
            last_id = full_token_ids[-1] if full_token_ids else 0
            input_ids = torch.tensor([[last_id]], dtype=torch.long, device=self.device)
            pos_ids = torch.tensor(
                [[cached_len - 1]], dtype=torch.long, device=self.device
            )
            # Trim past_kv to exclude last position so re-encoding is valid
            trimmed_kv: Optional[PastKVType] = None
            if past_kv is not None and KVCacheManager.kv_seq_len(past_kv) > 1:
                trimmed_kv = KVCacheManager.slice_kv(past_kv, 0, cached_len - 1)
                trimmed_kv = tuple(
                    (layer[0].to(self.device), layer[1].to(self.device))
                    for layer in trimmed_kv
                )
            with torch.no_grad():
                out = self.model(
                    input_ids=input_ids,
                    past_key_values=self._as_cache(trimmed_kv),
                    position_ids=pos_ids,
                    use_cache=True,
                )
            first_token_time = time.perf_counter()
            past_kv = self._normalize_past_kv(out.past_key_values)
            next_token = self._sample(out.logits[:, -1, :], temperature, do_sample)
            generated.append(next_token)

        # ----- autoregressive decode ------------------------------------
        cur_pos = cached_len + len(live_ids)
        while len(generated) < max_new_tokens:
            if generated[-1] == self.tokenizer.eos_token_id:
                break
            cur_ids = torch.tensor(
                [[generated[-1]]], dtype=torch.long, device=self.device
            )
            pos_ids = torch.tensor([[cur_pos]], dtype=torch.long, device=self.device)
            with torch.no_grad():
                out = self.model(
                    input_ids=cur_ids,
                    past_key_values=self._as_cache(past_kv),
                    position_ids=pos_ids,
                    use_cache=True,
                )
            past_kv = self._normalize_past_kv(out.past_key_values)
            next_token = self._sample(out.logits[:, -1, :], temperature, do_sample)
            generated.append(next_token)
            cur_pos += 1

        t1 = time.perf_counter()

        # ----- store newly computed chunks into cache -------------------
        # Done after t1 so cache population is excluded from the timings.
        self._store_prompt_chunks(full_token_ids)

        output_text = self.tokenizer.decode(generated, skip_special_tokens=True)
        ttft = (first_token_time - t0) * 1000 if first_token_time else 0
        total_ms = (t1 - t0) * 1000
        tps = len(generated) / max(t1 - t0, 1e-6)
        actual_hit = (
            hit_ratio
            if hit_ratio is not None
            else (cached_len / max(len(full_token_ids), 1))
        )

        return GenerationResult(
            mode=mode_name,
            prompt=prompt,
            output_text=output_text,
            generated_tokens=len(generated),
            ttft_ms=ttft,
            total_ms=total_ms,
            tokens_per_sec=tps,
            kv_reuse_ratio=actual_hit,
            prompt_tokens=len(full_token_ids),
            cached_tokens=cached_len,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _encode(self, text: str) -> List[int]:
        """Tokenize `text` with special tokens added."""
        return self.tokenizer.encode(text, add_special_tokens=True)

    @staticmethod
    def _as_cache(past_kv):
        """Convert tuple-of-tuples KV to DynamicCache for newer transformers."""
        # None and objects that already look like a Cache pass through.
        if past_kv is None or hasattr(past_kv, "get_seq_length"):
            return past_kv
        cache = DynamicCache()
        for layer_k, layer_v in past_kv:
            # len(cache) is used as the index of the layer being appended.
            cache.update(layer_k, layer_v, len(cache))
        return cache

    @staticmethod
    def _normalize_past_kv(past_key_values):
        """
        Normalize the model's returned past_key_values for reuse.

        Returns:
          • None when given None
          • the object unchanged when it exposes `get_seq_length`
            (Cache-style objects are passed back to the model as-is)
          • otherwise a tuple[(key_Tensor, val_Tensor), ...] with one entry
            per layer, built via `.to_legacy_cache()` when available or by
            iterating the value directly
        """
        if past_key_values is None:
            return None
        if hasattr(past_key_values, "get_seq_length"):
            return past_key_values
        if hasattr(past_key_values, "to_legacy_cache"):
            legacy = past_key_values.to_legacy_cache()
            return tuple((layer[0], layer[1]) for layer in legacy)
        return tuple((layer[0], layer[1]) for layer in past_key_values)

    def _encode_to_kv(
        self, token_ids: List[int], position_offset: int = 0
    ) -> PastKVType:
        """Run a forward pass and return only the KV cache (on CPU)."""
        input_ids = torch.tensor([token_ids], dtype=torch.long, device=self.device)
        pos_ids = torch.arange(
            position_offset,
            position_offset + len(token_ids),
            dtype=torch.long,
            device=self.device,
        ).unsqueeze(0)

        with torch.no_grad():
            out = self.model(
                input_ids=input_ids,
                position_ids=pos_ids,
                use_cache=True,
            )
        kv = out.past_key_values
        # Extract (k, v) tuples for CPU storage; the attribute layout of the
        # returned cache object differs between transformers versions.
        if hasattr(kv, "layers"):
            return tuple(
                (layer.keys.cpu(), layer.values.cpu()) for layer in kv.layers
            )
        if hasattr(kv, "key_cache") and hasattr(kv, "value_cache"):
            return tuple(
                (k.cpu(), v.cpu()) for k, v in zip(kv.key_cache, kv.value_cache)
            )
        return tuple((layer[0].cpu(), layer[1].cpu()) for layer in kv)

    def _cache_missing_chunks(
        self,
        token_ids: List[int],
        text: Optional[str] = None,
        position_offset: int = 0,
    ) -> int:
        """
        Encode and store every chunk of `token_ids` not already cached.

        Chunks are keyed by a hash chained over the preceding chunks.

        Args:
            token_ids: Token ids to split into chunks.
            text: Source text forwarded to ChunkRegistry.split when given.
            position_offset: Position id of the first token.

        Returns:
            Number of chunks newly stored.
        """
        if text is None:
            spans = self.chunk_registry.split(token_ids)
        else:
            spans = self.chunk_registry.split(token_ids, text)

        stored = 0
        pos = position_offset
        parent_hash = None
        for _start, _end, slice_ids in spans:
            p_hash = chained_hash(slice_ids, parent_hash)
            if self.cache_manager.get(p_hash) is None:
                kv = self._encode_to_kv(slice_ids, position_offset=pos)
                chunk = CachedChunk(
                    chunk_id=p_hash,
                    text=self.tokenizer.decode(slice_ids),
                    token_ids=slice_ids,
                    past_key_values=kv,
                    position_start=pos,
                    position_end=pos + len(slice_ids),
                    prefix_hash=p_hash,
                    content_hash=content_hash_from_tokens(slice_ids),
                )
                self.cache_manager.store(chunk)
                stored += 1
            # Advance the chain and position whether or not we stored.
            parent_hash = p_hash
            pos += len(slice_ids)
        return stored

    def _store_prompt_chunks(self, token_ids: List[int]) -> None:
        """Cache all un-cached fixed-size chunks from this prompt."""
        self._cache_missing_chunks(token_ids)

    @staticmethod
    def _sample(logits: "torch.Tensor", temperature: float, do_sample: bool) -> int:
        """
        Pick the next token id from `logits`.

        Greedy decoding returns the argmax of the unscaled logits. A
        temperature of zero or below is treated as greedy so the logits are
        never divided by zero. Otherwise the logits are scaled by
        `temperature` and one token is drawn from the softmax distribution.
        """
        if not do_sample or temperature <= 0:
            return logits.argmax(dim=-1).item()
        if temperature != 1.0:
            logits = logits / temperature
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, 1).item()

    def cache_stats(self) -> Dict:
        """Return the cache manager's statistics."""
        return self.cache_manager.stats()

    def verify_correctness(self, max_new_tokens: int = 32) -> bool:
        """
        Quick self-test: runs greedy decode on a synthetic prompt with both
        BASELINE and CHUNK_KV_REUSE, verifies identical output.

        Returns True if outputs match, False otherwise.
        Use this to validate untested model architectures before
        trusting cached outputs in production.

        Note: the synthetic prefix is warmed into the cache as a side effect.
        """
        test_prefix = (
            "The following is a factual statement about mathematics. "
            "Two plus two equals four. Three times three equals nine. "
            "The square root of sixteen is four. Pi is approximately "
            "three point one four one five nine. Euler's number e is "
            "approximately two point seven one eight."
        )
        test_query = "\n\nQuestion: What is two plus two?\nAnswer:"
        prompt = test_prefix + test_query

        # Warm the prefix
        self.warm(test_prefix)

        # Run both modes with greedy decoding
        r_base = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            mode=GenerationMode.BASELINE,
            do_sample=False,
        )
        r_cached = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            mode=GenerationMode.CHUNK_KV_REUSE,
            do_sample=False,
        )

        match = r_base.output_text == r_cached.output_text
        arch = type(self.model).__name__

        if match:
            log.info(
                "verify_correctness PASSED for %s — "
                "baseline and cached outputs are identical",
                arch,
            )
        else:
            log.warning(
                "verify_correctness FAILED for %s — "
                "outputs differ!\n baseline: %r\n cached: %r",
                arch,
                r_base.output_text[:100],
                r_cached.output_text[:100],
            )

        return match