# Source code for kvboost.compat

"""
Model Compatibility
===================
KVBoost's position-ID stitching assumes RoPE positional encoding where
positions are passed explicitly via position_ids and the model correctly
handles non-contiguous sequences.

This module tracks which HuggingFace architectures are known-safe, which
are known-broken, and provides runtime validation for untested models.

Known-broken architectures:
  - ALiBi models (MPT, Falcon, BLOOM): positional bias is added directly to
    attention scores based on token distance, so there is no position_ids
    tensor to inject
  - Learned absolute embeddings (GPT-2, GPT-Neo): position info is baked into
    token representations at the input layer and cannot be corrected by KV
    stitching
  - GPT-NeoX: uses rotary embeddings, but is listed as unsupported because
    the HF implementation is recorded as not accepting position_ids
  - Sliding window attention (Mistral with sliding_window set): stitched KV
    can include tokens from outside the window that should have been masked
"""

from __future__ import annotations

import logging
import warnings
from typing import Optional

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# HuggingFace class names verified to work correctly with KV cache stitching.
# Every entry uses RoPE positional encoding. Mistral is safe only with full
# attention (no sliding window).
SUPPORTED_ARCHITECTURES = {
    # Gemma family
    "GemmaForCausalLM",
    "Gemma2ForCausalLM",
    # InternLM family
    "InternLMForCausalLM",
    "InternLM2ForCausalLM",
    # Llama
    "LlamaForCausalLM",
    # Mistral (full attention only)
    "MistralForCausalLM",
    # Phi family
    "PhiForCausalLM",
    "Phi3ForCausalLM",
    # Qwen family
    "Qwen2ForCausalLM",
    "Qwen2_5ForCausalLM",
    # StableLM
    "StableLmForCausalLM",
}

# Maps a HuggingFace class name to a human-readable explanation of why the
# architecture is incompatible with KV cache stitching. The explanation is
# embedded verbatim in the error/warning raised for that architecture.
UNSUPPORTED_ARCHITECTURES = {
    # --- Learned absolute positional embeddings ---
    "GPT2LMHeadModel": (
        "GPT-2 uses learned absolute positional embeddings. "
        "Position info is baked into token representations at the embedding "
        "layer — KV cache stitching cannot correct for position mismatches."
    ),
    "GPTNeoForCausalLM": "GPT-Neo uses learned absolute positional embeddings.",
    # --- Rotary embeddings without explicit position_ids ---
    # NOTE(review): recent transformers releases appear to accept
    # position_ids in GPTNeoXForCausalLM.forward — confirm this entry is
    # still required.
    "GPTNeoXForCausalLM": (
        "GPT-NeoX uses rotary embeddings but the HF implementation "
        "does not accept position_ids — KV stitching may produce "
        "incorrect positions."
    ),
    # --- ALiBi positional bias ---
    "MPTForCausalLM": (
        "MPT uses ALiBi positional encoding. "
        "Positional bias is added directly to attention scores based on token "
        "distance — there is no position_ids tensor to inject, so KV cache "
        "stitching cannot produce correct positions."
    ),
    # NOTE(review): HF's FalconConfig exposes an `alibi` flag that looks to
    # default to False (rotary otherwise) — confirm every Falcon checkpoint
    # should be rejected.
    "FalconForCausalLM": "Falcon uses ALiBi positional encoding.",
    "BloomForCausalLM": "BLOOM uses ALiBi positional encoding.",
}


def check_model_compatibility(model, strict: bool = True) -> None:
    """
    Validate that a model's architecture is compatible with KV cache stitching.

    The architecture is identified by the model's class name
    (``type(model).__name__``) and checked, in order, against the
    known-incompatible table, the Mistral sliding-window special case, and
    the verified-compatible set.

    Args:
        model: A HuggingFace CausalLM model instance.
        strict: If True (default), raise ValueError for unsupported models
            and warn for untested ones. If False, only warn.

    Raises:
        ValueError: If ``strict`` is True and the model architecture is known
            to be incompatible (including Mistral configured with a sliding
            window).
    """
    arch = type(model).__name__

    # Collect the reason (if any) the model is known to be incompatible, so
    # the raise-or-warn decision is made in exactly one place below.
    problem = None
    if arch in UNSUPPORTED_ARCHITECTURES:
        reason = UNSUPPORTED_ARCHITECTURES[arch]
        problem = f"KVBoost does not support {arch}: {reason}"
    elif arch == "MistralForCausalLM":
        # Mistral is only safe with full attention. Read the config
        # defensively so an object without `config` (or without the
        # `sliding_window` field) is treated as having no sliding window
        # instead of raising AttributeError.
        config = getattr(model, "config", None)
        sliding_window = getattr(config, "sliding_window", None)
        if sliding_window is not None:
            problem = (
                f"MistralForCausalLM with sliding_window={sliding_window} is not "
                "supported. KV cache stitching breaks the sliding window mask "
                "assumption — tokens outside the window that should be invisible "
                "will be included in the stitched KV."
            )

    if problem is not None:
        if strict:
            raise ValueError(problem)
        # stacklevel=3 attributes the warning two frames above this function.
        warnings.warn(problem, stacklevel=3)
        return

    # Warn for untested architectures.
    # NOTE(review): the message says strict=False suppresses this warning, but
    # this function emits it regardless of `strict` — confirm whether the
    # caller (from_pretrained) skips this check when strict=False.
    if arch not in SUPPORTED_ARCHITECTURES:
        warnings.warn(
            f"KVBoost has not been tested with {arch}. Output correctness is "
            "not guaranteed. Run engine.verify_correctness() to validate "
            "before trusting cached outputs. Pass strict=False to "
            "from_pretrained() to suppress this warning.",
            stacklevel=3,
        )