Source code for kvboost.compat
"""
Model Compatibility
===================
KVBoost's position-ID stitching assumes RoPE positional encoding where
positions are passed explicitly via position_ids and the model correctly
handles non-contiguous sequences.
This module tracks which HuggingFace architectures are known-safe, which
are known-broken, and provides runtime validation for untested models.
Known-broken architectures:
- ALiBi models (MPT, Falcon, BLOOM): positional bias is added directly to
  attention scores based on token distance, so there is no position_ids
  tensor to inject
- Learned absolute embeddings (GPT-2, GPT-Neo): position info is baked into
  token representations at the input layer and cannot be corrected by KV
  stitching
- GPT-NeoX: rotary embeddings, but listed as unsupported because its HF
  implementation is treated as not accepting position_ids
- Sliding window attention (Mistral with sliding_window != None): stitched KV
  can include tokens from outside the window that the mask should have hidden
"""
from __future__ import annotations
import logging
import warnings
from typing import Optional
# Logger for this module.
log = logging.getLogger(name=__name__)

# Architectures verified to work correctly with KV cache stitching.
# Every entry below is a RoPE model, the positional scheme that the
# position_ids-based stitching relies on.
SUPPORTED_ARCHITECTURES = {
    # Llama family
    "LlamaForCausalLM",
    # Qwen family
    # NOTE(review): Qwen2.5 checkpoints may load as Qwen2ForCausalLM, in which
    # case the Qwen2_5 entry never matches -- confirm it is still needed.
    "Qwen2ForCausalLM",
    "Qwen2_5ForCausalLM",
    # Gemma family
    "GemmaForCausalLM",
    "Gemma2ForCausalLM",
    # Mistral is safe only with full attention; a configured sliding window
    # is rejected separately by check_model_compatibility().
    "MistralForCausalLM",
    # Phi family
    "PhiForCausalLM",
    "Phi3ForCausalLM",
    # Others
    "StableLmForCausalLM",
    "InternLMForCausalLM",
    "InternLM2ForCausalLM",
}
# Architectures known to be incompatible with KV cache stitching, mapped to
# the human-readable reason that check_model_compatibility() reports.
UNSUPPORTED_ARCHITECTURES = {
    # --- Learned absolute positional embeddings ---
    "GPT2LMHeadModel": (
        "GPT-2 uses learned absolute positional embeddings. "
        "Position info is baked into token representations at the "
        "embedding layer — KV cache stitching cannot correct for "
        "position mismatches."
    ),
    "GPTNeoForCausalLM": "GPT-Neo uses learned absolute positional embeddings.",
    # --- Rotary, but position_ids treated as not injectable ---
    # NOTE(review): newer transformers releases may accept position_ids for
    # GPT-NeoX -- confirm against the pinned version before relaxing this.
    "GPTNeoXForCausalLM": (
        "GPT-NeoX uses rotary embeddings but the HF implementation "
        "does not accept position_ids — KV stitching may produce "
        "incorrect positions."
    ),
    # --- ALiBi (distance-based attention bias, no position_ids) ---
    "MPTForCausalLM": (
        "MPT uses ALiBi positional encoding. Positional bias is added "
        "directly to attention scores based on token distance — there "
        "is no position_ids tensor to inject, so KV cache stitching "
        "cannot produce correct positions."
    ),
    # NOTE(review): some Falcon configs may use rotary rather than ALiBi --
    # confirm whether this reason holds for every Falcon checkpoint.
    "FalconForCausalLM": "Falcon uses ALiBi positional encoding.",
    "BloomForCausalLM": "BLOOM uses ALiBi positional encoding.",
}
[docs]
def check_model_compatibility(model, strict: bool = True) -> None:
    """
    Validate that a model's architecture is compatible with KV cache stitching.

    The architecture is identified by the model's class name and checked, in
    order, against UNSUPPORTED_ARCHITECTURES, the Mistral sliding-window
    special case, and SUPPORTED_ARCHITECTURES.

    Args:
        model: A HuggingFace CausalLM model instance.
        strict: If True (default), raise ValueError for known-incompatible
            models. If False, emit a warning for them instead. Architectures
            that are merely untested produce a warning in both modes.

    Raises:
        ValueError: If ``strict`` is True and the architecture is known to be
            incompatible (including MistralForCausalLM with a sliding window).

    Warns:
        UserWarning: For known-incompatible models when ``strict`` is False,
            and for architectures listed in neither table.
    """
    # NOTE(review): a wrapped model (e.g. torch.compile or adapter wrappers)
    # has a different class name and will be reported as untested -- confirm
    # callers pass the bare HF model.
    arch = type(model).__name__

    # Collect the reason (if any) this architecture is known to be broken, so
    # the raise-or-warn decision below is made in exactly one place.
    problem: Optional[str] = None
    if arch in UNSUPPORTED_ARCHITECTURES:
        problem = (
            f"KVBoost does not support {arch}: "
            f"{UNSUPPORTED_ARCHITECTURES[arch]}"
        )
    elif arch == "MistralForCausalLM":
        # Mistral is only safe with full attention; a configured sliding
        # window means stitched KV would expose tokens the mask should hide.
        # The getattr chain tolerates objects that lack a ``config``.
        config = getattr(model, "config", None)
        sliding_window = getattr(config, "sliding_window", None)
        if sliding_window is not None:
            problem = (
                f"MistralForCausalLM with sliding_window={sliding_window} is not "
                "supported. KV cache stitching breaks the sliding window mask "
                "assumption — tokens outside the window that should be invisible "
                "will be included in the stitched KV."
            )
        # NOTE(review): other supported architectures may also configure a
        # sliding window (e.g. Gemma2, Phi3) -- confirm whether they need the
        # same check.

    if problem is not None:
        if strict:
            raise ValueError(problem)
        # stacklevel=3 points the warning at the caller of whatever invoked
        # this check (presumably from_pretrained() -- confirm).
        warnings.warn(problem, stacklevel=3)
        return

    if arch not in SUPPORTED_ARCHITECTURES:
        # Emitted regardless of ``strict``; the message must not promise that
        # strict=False silences it.
        warnings.warn(
            f"KVBoost has not been tested with {arch}. Output correctness is "
            "not guaranteed. Run engine.verify_correctness() to validate "
            "before trusting cached outputs. Use warnings.filterwarnings() "
            "to silence this warning.",
            stacklevel=3,
        )