SplitterForGeneration¶
SplitterForGeneration is a BaseSplitter specialization for causal language models such as
*ForCausalLM and *LMHeadModel Hugging Face models. It keeps the split-point API explicit while
providing a simpler concept workflow for generation models than ModelWithSplitPoints.
When to Use¶
Use SplitterForGeneration when:
- Your model is a Hugging Face causal language model.
- You want token-level activations at a single split point.
- You want flattened token activations for concept fitting or sample-wise token activations for interpretation.
- You do not need word, sentence, sample, or custom granularity aggregation.
Use ModelWithSplitPoints instead when you need the full ActivationGranularity API, including word,
sentence, or sample aggregation.
Token Selection¶
By default, get_activations filters out padding and special tokens before returning activations. Pass
include_special_tokens=True to keep special tokens while still removing padding.
The return shape depends on flatten_activations:
- flatten_activations=True returns one tensor of shape (n_tokens, hidden_dim).
- flatten_activations=False returns one tensor per input, each with shape (n_tokens_i, hidden_dim).
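The selection rule and the two return shapes can be illustrated on toy tensors with plain PyTorch. This is a minimal sketch, not the library's implementation: the helper name select_token_activations, the toy token IDs, and the special-ID set are assumptions made for illustration only.

```python
import torch


def select_token_activations(hidden, input_ids, attention_mask, special_ids,
                             include_special_tokens=False, flatten_activations=True):
    """Hypothetical helper mirroring the documented filtering rule on toy tensors."""
    # Padding positions are always dropped.
    keep = attention_mask.bool()
    if not include_special_tokens:
        # Also drop positions whose token ID is in the special-ID set.
        special = torch.isin(input_ids, torch.tensor(sorted(special_ids)))
        keep = keep & ~special
    # One (n_tokens_i, hidden_dim) tensor per input.
    per_sample = [h[k] for h, k in zip(hidden, keep)]
    return torch.cat(per_sample, dim=0) if flatten_activations else per_sample


# Toy batch: 2 inputs, 4 positions, hidden_dim 3.
# Assumption for this sketch: ID 0 is padding and ID 1 is a special token.
input_ids = torch.tensor([[1, 5, 6, 7], [1, 8, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)

flat = select_token_activations(hidden, input_ids, attention_mask, {1})
per_sample = select_token_activations(
    hidden, input_ids, attention_mask, {1}, flatten_activations=False
)
with_special = select_token_activations(
    hidden, input_ids, attention_mask, {1}, include_special_tokens=True
)
```

Here flat stacks the kept tokens of both inputs into one tensor, per_sample keeps them separate, and with_special retains the special token of each input while still discarding padding.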
Quick Example¶
from interpreto import SplitterForGeneration

# Example inputs; any list of strings works.
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Interpretability starts with activations.",
]

splitter = SplitterForGeneration(
    "gpt2",
    split_point=10,
    batch_size=8,
    device_map="auto",
)

# Flattened token activations, suitable for concept fitting.
activations, _ = splitter.get_activations(texts, tqdm_bar=True)

# Sample-wise activations, useful when token alignment matters downstream.
sample_activations, _ = splitter.get_activations(
    texts,
    flatten_activations=False,
)
Concept Gradients¶
For concept-to-output gradients, SplitterForGeneration reintegrates decoded concept activations at the
selected split point, then differentiates generation logits with respect to the concept activations. The
returned gradients are sample-wise tensors of shape (n_targets, n_tokens_i, n_concepts).
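The gradient shape can be reproduced on a toy setup with PyTorch autograd. This is a sketch under stated assumptions rather than Interpreto's code: decoder stands in for a concept model's decoder, head stands in for the model layers after the split point, and target_ids is an arbitrary choice of output logits to differentiate.

```python
import torch

torch.manual_seed(0)
n_tokens, n_concepts, hidden_dim, vocab_size = 5, 4, 8, 11

# Hypothetical stand-ins: a concept decoder and the layers after the split point.
decoder = torch.nn.Linear(n_concepts, hidden_dim, bias=False)
head = torch.nn.Linear(hidden_dim, vocab_size)

# Concept activations for one sample, one row per token.
concepts = torch.randn(n_tokens, n_concepts, requires_grad=True)

# Decode concepts back to the latent space, then run the rest of the model.
latents = decoder(concepts)
logits = head(latents)  # shape (n_tokens, vocab_size)

# One gradient per target: here, the last-position logit of each chosen
# vocabulary ID (an arbitrary choice for this sketch).
target_ids = [2, 7, 9]
grads = []
for tid in target_ids:
    (g,) = torch.autograd.grad(logits[-1, tid], concepts, retain_graph=True)
    grads.append(g)

gradients = torch.stack(grads)  # shape (n_targets, n_tokens, n_concepts)
```

Because the toy layers act on each position independently, only the last token row receives a non-zero gradient here. In a real causal language model, attention mixes positions, so earlier token rows would generally be non-zero as well.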
API Reference¶
interpreto.SplitterForGeneration¶
SplitterForGeneration(model_or_repo_id, split_point, *, automodel=None, tokenizer=None, config=None, batch_size=1, device_map=None, **kwargs)
Bases: BaseSplitter
A BaseSplitter specialization for causal language models (generation).
Wraps a ForCausalLM model, splits it at a user-specified layer, and
provides activation extraction with simple token-level granularity.
Compared to ModelWithSplitPoints this class:
- Only supports two activation modes: include_special_tokens=True/False.
- Does not depend on interpreto.commons.granularity.Granularity.
- Uses tokenizer.all_special_ids directly for special-token filtering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_or_repo_id | str \| PreTrainedModel | A Hugging Face model ID or a pre-loaded CausalLM instance. | required |
| split_point | str \| int | The split location inside the model. | required |
| tokenizer | PreTrainedTokenizer \| PreTrainedTokenizerFast \| None | Tokenizer. Required when providing a model instance. | None |
| config | PretrainedConfig \| None | Model configuration. | None |
| batch_size | int | Batch size for batched operations. | 1 |
| device_map | device \| str \| None | Device on which to load the model. | None |
| **kwargs | | Additional keyword arguments forwarded to NNsight. | {} |
Example
Raises:
| Type | Description |
|---|---|
| TypeError | If |
Source code in interpreto/concepts/splitters/splitter_for_generation.py
get_activations¶
get_activations(inputs, include_special_tokens=False, flatten_activations=True, tqdm_bar=False, forward_kwargs={}, **kwargs)
Extract per-token activations at the split point for a list of text inputs.
Iterates over inputs in batches and delegates each batch to
inputs_to_activations.
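The batching loop described above can be pictured as follows. This is only a sketch: iter_batches, collect, and fake_inputs_to_activations are hypothetical names, and the real inputs_to_activations method may take different arguments and return tensors rather than integers.

```python
from typing import Callable, Iterator


def iter_batches(inputs: list[str], batch_size: int) -> Iterator[list[str]]:
    """Yield consecutive slices of at most batch_size inputs."""
    for start in range(0, len(inputs), batch_size):
        yield inputs[start:start + batch_size]


def collect(inputs: list[str], batch_size: int,
            per_batch: Callable[[list[str]], list]) -> list:
    """Run per_batch on every batch and gather per-sample results in input order."""
    results: list = []
    for batch in iter_batches(inputs, batch_size):
        results.extend(per_batch(batch))
    return results


# Hypothetical stand-in for the per-batch extraction: one result per input.
def fake_inputs_to_activations(batch: list[str]) -> list[int]:
    return [len(text.split()) for text in batch]


texts = ["a b c", "d e", "f", "g h i j", "k"]
batches = list(iter_batches(texts, batch_size=2))
lengths = collect(texts, batch_size=2, per_batch=fake_inputs_to_activations)
```

The last batch may be smaller than batch_size, and the per-sample results keep the order of the original inputs.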
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | list[str] | Raw text inputs. | required |
| include_special_tokens | bool | If True, return all token activations (including special tokens but not padding). If False (default), filter out special tokens. | False |
| flatten_activations | bool | If True (default), flatten the activations into a single tensor of shape (n*g, d), where g depends on whether special tokens are included. If False, return a list of sample-wise activations. | True |
| tqdm_bar | bool | Whether to display a progress bar. | False |
| forward_kwargs | dict[str, Any] | Additional kwargs for the model forward pass. | {} |
| **kwargs | | Unused, kept for API compatibility. | {} |
Returns:
| Name | Type | Description |
|---|---|---|
| activations | list[LatentActivations] \| LatentActivations | list[LatentActivations]: A list of tensors (one per sample, shape |
| predictions | None | |
get_latent_shape¶
Get the shape of the latent activations at the split point.
Uses a short real trace instead of NNsight's scan. Some causal LMs, such as Qwen3, run RoPE autocast checks that reject the fake/meta device used by scan.
Returns:
| Type | Description |
|---|---|
| Size | torch.Size: Shape of the activations at the split point (typically |
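The idea of reading the latent shape from a short real forward pass, rather than from a fake or meta-device scan, can be sketched with a plain PyTorch forward hook. This example does not use NNsight and is not the library's implementation: latent_shape_from_trace and the toy model are hypothetical stand-ins.

```python
import torch


def latent_shape_from_trace(model, layer, example_input):
    """Hypothetical helper: read an intermediate output shape from one real forward pass."""
    captured = {}

    def hook(_module, _inputs, output):
        # Record the shape of the layer output produced by real tensors.
        captured["shape"] = output.shape

    handle = layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        handle.remove()  # always detach the hook, even if the forward pass fails
    return captured["shape"]


# Toy stand-in for a language model: an embedding followed by two linear layers.
toy = torch.nn.Sequential(
    torch.nn.Embedding(50, 16),
    torch.nn.Linear(16, 16),
    torch.nn.Linear(16, 50),
)
short_input = torch.tensor([[3, 1, 4]])  # batch of 1, three token IDs
shape = latent_shape_from_trace(toy, toy[1], short_input)
```

Running the model on a very short input keeps the cost low while still exercising the same code paths as a normal forward pass.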