
SplitterForGeneration

SplitterForGeneration is a BaseSplitter specialization for causal language models such as *ForCausalLM and *LMHeadModel Hugging Face models. It keeps the split-point API explicit while providing a simpler concept workflow for generation models than ModelWithSplitPoints.

When to Use

Use SplitterForGeneration when:

  • Your model is a Hugging Face causal language model.
  • You want token-level activations at a single split point.
  • You want flattened token activations for concept fitting or sample-wise token activations for interpretation.
  • You do not need word, sentence, sample, or custom granularity aggregation.

Use ModelWithSplitPoints instead when you need the full ActivationGranularity API, including word, sentence, or sample aggregation.

Token Selection

By default, get_activations filters out padding and special tokens before returning activations. Pass include_special_tokens=True to keep special tokens while still removing padding.
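The helper that builds this token mask (`_tokenize_and_get_mask`) is not shown on this page. The pure-Python sketch below illustrates the rule described above. It assumes padding is identified through the tokenizer's `attention_mask` and special tokens through `tokenizer.all_special_ids`. `token_mask` is a hypothetical name, and the token ids are illustrative, with 50256 standing in for an EOS token that is also used as the pad token.

```python
def token_mask(input_ids, attention_mask, special_ids, include_special_tokens=False):
    """Return one boolean per token: True if that token's activation is kept.

    Padding (attention_mask == 0) is always dropped; special tokens are
    dropped unless include_special_tokens is True. Hypothetical helper.
    """
    special = set(special_ids)
    return [
        bool(attn) and (include_special_tokens or tok not in special)
        for tok, attn in zip(input_ids, attention_mask)
    ]


# One right-padded row: two ordinary tokens, a real EOS, then two pad
# positions that reuse the EOS id (illustrative ids).
ids = [15496, 995, 50256, 50256, 50256]
attn = [1, 1, 1, 0, 0]
special_ids = [50256]

print(token_mask(ids, attn, special_ids))
# [True, True, False, False, False]
print(token_mask(ids, attn, special_ids, include_special_tokens=True))
# [True, True, True, False, False]
```

Relying on the attention mask rather than the token id is what lets include_special_tokens=True keep a genuine EOS while still dropping padding that shares its id.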

The return shape depends on flatten_activations:

  • flatten_activations=True returns one tensor of shape (n_tokens, hidden_dim).
  • flatten_activations=False returns one tensor per input, each with shape (n_tokens_i, hidden_dim).
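Both layouts contain the same rows. The flattened tensor is the sample-wise tensors concatenated along the token axis (the source below uses `torch.cat(all_activations, dim=0)`). In the sketch, nested Python lists stand in for tensors. The hypothetical `unflatten` helper shows how per-sample token counts recover the sample-wise view from the flattened one, which can help when token alignment matters after concept fitting.

```python
from itertools import accumulate


def flatten(sample_acts):
    """Concatenate per-sample [n_tokens_i][hidden_dim] lists along the token axis."""
    return [row for sample in sample_acts for row in sample]


def unflatten(flat, lengths):
    """Split a flattened [n_tokens][hidden_dim] list back into per-sample chunks."""
    ends = list(accumulate(lengths))
    starts = [0] + ends[:-1]
    return [flat[s:e] for s, e in zip(starts, ends)]


# Two samples with hidden_dim = 2: 3 kept tokens and 1 kept token.
sample_acts = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[0.7, 0.8]],
]
flat = flatten(sample_acts)
lengths = [len(s) for s in sample_acts]

print(len(flat), len(flat[0]))  # 4 2, i.e. (n_tokens, hidden_dim)
assert unflatten(flat, lengths) == sample_acts
```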

Quick Example

from interpreto import SplitterForGeneration

texts = ["Hello world!", "Interpreto is magic"]

splitter = SplitterForGeneration(
    "gpt2",
    split_point=10,
    batch_size=8,
    device_map="auto",
)

# Flattened token activations, suitable for concept fitting.
activations, _ = splitter.get_activations(texts, tqdm_bar=True)

# Sample-wise activations, useful when token alignment matters downstream.
sample_activations, _ = splitter.get_activations(
    texts,
    flatten_activations=False,
)

Concept Gradients

For concept-to-output gradients, SplitterForGeneration reintegrates decoded concept activations at the selected split point, then differentiates generation logits with respect to the concept activations. The returned gradients are sample-wise tensors of shape (n_targets, n_tokens_i, n_concepts).
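One way to read the shape (n_targets, n_tokens_i, n_concepts) is as the Jacobian below. The notation is ours, not the library's. Z^{(i)} holds the concept activations of sample i over its n_i tokens. dec is the concept model's decoder, f_{>s} is the part of the model after the split point, and y_k is the k-th of K = n_targets target logits.

```latex
G^{(i)}_{k,t,c} = \frac{\partial\, y_k\!\left(f_{>s}\big(\mathrm{dec}(Z^{(i)})\big)\right)}{\partial Z^{(i)}_{t,c}},
\qquad G^{(i)} \in \mathbb{R}^{K \times n_i \times C},
\quad k = 1..K,\; t = 1..n_i,\; c = 1..C
```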

API Reference

interpreto.SplitterForGeneration

SplitterForGeneration(model_or_repo_id, split_point, *, automodel=None, tokenizer=None, config=None, batch_size=1, device_map=None, **kwargs)

Bases: BaseSplitter

A BaseSplitter specialization for causal language models (generation).

Wraps a causal language model (*ForCausalLM or *LMHeadModel), splits it at a user-specified layer, and provides activation extraction with simple token-level granularity.

Compared to ModelWithSplitPoints, this class:

  • Only supports two activation modes: include_special_tokens=True/False.
  • Does not depend on interpreto.commons.granularity.Granularity.
  • Uses tokenizer.all_special_ids directly for special-token filtering.

Parameters:

  • model_or_repo_id (str | PreTrainedModel, required): A Hugging Face model ID or a pre-loaded causal LM instance.
  • split_point (str | int, required): The split location inside the model.
  • automodel (type[AutoModel] | None, default None): Auto class passed to BaseSplitter; AutoModelForCausalLM is used when None.
  • tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast | None, default None): Tokenizer. Required when providing a model instance.
  • config (PretrainedConfig | None, default None): Model configuration.
  • batch_size (int, default 1): Batch size for batched operations.
  • device_map (device | str | None, default None): Device on which to load the model.
  • **kwargs (default {}): Additional keyword arguments forwarded to NNsight.

Example
from interpreto import SplitterForGeneration

splitter = SplitterForGeneration(
    "gpt2",
    split_point=10,
    batch_size=8,
    device_map="auto",
)
activations, _ = splitter.get_activations(
    ["Hello world!", "Interpreto is magic"],
)

Raises:

  • TypeError: If model_or_repo_id is a PreTrainedModel that is not a causal language model.
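The check in the source below only inspects the class name of a pre-loaded model. A repo ID string skips it and is loaded with AutoModelForCausalLM unless another automodel is given. The stand-alone sketch reproduces the same rule with a hypothetical helper, looks_like_causal_lm, so it runs without transformers installed.

```python
def looks_like_causal_lm(class_name: str) -> bool:
    """Mirror the name-based test in SplitterForGeneration.__init__ (hypothetical helper)."""
    return "ForCausalLM" in class_name or "LMHeadModel" in class_name


for name in ["GPT2LMHeadModel", "LlamaForCausalLM", "BertForSequenceClassification"]:
    print(name, looks_like_causal_lm(name))
# GPT2LMHeadModel True
# LlamaForCausalLM True
# BertForSequenceClassification False
```

Because the test is name-based, a causal LM whose class name follows neither convention would also be rejected with TypeError.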

Source code in interpreto/concepts/splitters/splitter_for_generation.py
def __init__(
    self,
    model_or_repo_id: str | PreTrainedModel,
    split_point: str | int,
    *,
    automodel: type[AutoModel] | None = None,
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None,
    config: PretrainedConfig | None = None,
    batch_size: int = 1,
    device_map: torch.device | str | None = None,
    **kwargs,
):
    """Initialize a SplitterForGeneration model wrapper.

    Raises:
        TypeError: If ``model_or_repo_id`` is a PreTrainedModel that is not a CausalLM.
    """
    if isinstance(model_or_repo_id, PreTrainedModel):
        class_name = model_or_repo_id.__class__.__name__
        if "ForCausalLM" not in class_name and "LMHeadModel" not in class_name:
            raise TypeError(
                "The provided model is not a causal language model. "
                "Please provide a model that inherits from `transformers.*ForCausalLM` "
                "or `*LMHeadModel`."
            )

    super().__init__(
        model_or_repo_id,
        split_point,
        config=config,
        tokenizer=tokenizer,
        automodel=automodel if automodel is not None else AutoModelForCausalLM,  # type: ignore
        batch_size=batch_size,
        device_map=device_map,
        **kwargs,
    )

    # Ensure a pad token is available
    self.tokenizer.pad_token = self.tokenizer.eos_token

get_activations

get_activations(inputs, include_special_tokens=False, flatten_activations=True, tqdm_bar=False, forward_kwargs={}, **kwargs)

Extract per-token activations at the split point for a list of text inputs.

Iterates over inputs in batches, traces each batch up to the split point, and keeps the activations selected by the token mask.
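The batching in the source below computes n_batches = ceil(len(inputs) / batch_size) and slices inputs[start : min(start + batch_size, len(inputs))], so the last batch may be shorter. A minimal pure-Python sketch of that loop follows. iter_batches is a hypothetical name, and the model call is omitted.

```python
from math import ceil


def iter_batches(inputs, batch_size):
    """Yield consecutive slices of at most batch_size items, as get_activations does."""
    for start in range(0, len(inputs), batch_size):
        # Python slicing already clamps the end index; min() mirrors the source.
        yield inputs[start : min(start + batch_size, len(inputs))]


texts = ["a", "b", "c", "d", "e"]
batches = list(iter_batches(texts, batch_size=2))
print(batches)  # [['a', 'b'], ['c', 'd'], ['e']]
print(len(batches) == ceil(len(texts) / 2))  # True
```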

Parameters:

  • inputs (list[str], required): Raw text inputs.
  • include_special_tokens (bool, default False): If True, return activations for all tokens, including special tokens but not padding. If False, filter out special tokens using tokenizer.all_special_ids.
  • flatten_activations (bool, default True): If True, concatenate the activations into a single tensor of shape (n_tokens, d), where n_tokens is the total number of kept tokens across all inputs and therefore depends on include_special_tokens. If False, return a list of sample-wise activations.
  • tqdm_bar (bool, default False): Whether to display a progress bar.
  • forward_kwargs (dict[str, Any], default {}): Additional keyword arguments for the model forward pass.
  • **kwargs (default {}): Unused, kept for API compatibility.

Returns:

  • activations (list[LatentActivations] | LatentActivations): A single tensor of shape (n_tokens, d) if flatten_activations=True; otherwise a list of tensors, one per sample, each of shape (l_i, d). Activations are returned on the CPU as float32.
  • predictions (None): Always None (placeholder; generation models have no predicted classes).

Source code in interpreto/concepts/splitters/splitter_for_generation.py
def get_activations(
    self,
    inputs: list[str],
    include_special_tokens: bool = False,
    flatten_activations: bool = True,
    tqdm_bar: bool = False,
    forward_kwargs: dict[str, Any] = {},
    **kwargs,
) -> tuple[list[LatentActivations] | LatentActivations, None]:
    """Extract per-token activations at the split point for a list of text inputs.

    Iterates over inputs in batches, traces each batch up to the split point,
    and keeps the activations selected by the token mask.

    Args:
        inputs (list[str]): Raw text inputs.
        include_special_tokens (bool): If True, return all token activations
            (including special tokens but not padding).  If False (default),
            filter out special tokens using ``tokenizer.all_special_ids``.
        flatten_activations (bool): If True (default), concatenate the activations
            into a single tensor of shape (n_tokens, d), where n_tokens depends on
            ``include_special_tokens``. If False, return a list of sample-wise activations.
        tqdm_bar (bool): Whether to display a progress bar.
        forward_kwargs (dict[str, Any]): Additional kwargs for the model forward pass.
        **kwargs: Unused, kept for API compatibility.

    Returns:
        activations (list[LatentActivations] | LatentActivations):
            A single tensor of shape ``(n_tokens, d)`` if ``flatten_activations=True``;
            otherwise a list of tensors, one per sample, each of shape ``(l_i, d)``.
        predictions (None): ``None`` (placeholder, no predicted classes for generation models).
    """
    n_batches = ceil(len(inputs) / self.batch_size)
    batch_iter = tqdm(
        range(0, len(inputs), self.batch_size),
        desc="Computing activations",
        unit="batch",
        total=n_batches,
        disable=not tqdm_bar,
    )

    sp_module = self.get(self._split_point)
    output_name = "nns_output" if hasattr(sp_module, "nns_output") else "output"

    all_activations: list[LatentActivations] = []

    with torch.no_grad():
        for start in batch_iter:
            batch_texts = inputs[start : min(start + self.batch_size, len(inputs))]

            # extract non-special tokens mask
            tokenized, tokens_mask = self._tokenize_and_get_mask(batch_texts, include_special_tokens)

            # forward till the split point
            with self.trace(tokenized, **forward_kwargs) as tracer:
                batch_outputs = getattr(sp_module, output_name).save()
                tracer.stop()

            batch_acts: Float[torch.Tensor, "n l d"] = self._manage_output_tuple(batch_outputs, self._split_point)

            # filter out special tokens and expose public activations as float32.
            for acts, mask in zip(batch_acts, tokens_mask, strict=True):
                all_activations.append(acts[mask].detach().to(device="cpu", dtype=torch.float32, copy=True))

    torch.cuda.empty_cache()
    gc.collect()

    if flatten_activations:
        return torch.cat(all_activations, dim=0), None

    return all_activations, None

get_latent_shape

get_latent_shape()

Get the shape of the latent activations at the split point.

Uses a short real trace (on the input text "scan") instead of NNsight's scan. Some causal LMs, such as Qwen3, run RoPE autocast checks that reject the fake/meta device used by scan.

Returns:

  • torch.Size: Shape of the activations at the split point (typically (1, l, d)).

Source code in interpreto/concepts/splitters/splitter_for_generation.py
def get_latent_shape(self) -> torch.Size:
    """Get the shape of the latent activations at the split point.

    Uses a short real trace instead of NNsight's scan. Some causal LMs, such as
    Qwen3, run RoPE autocast checks that reject the fake/meta device used by
    scan.

    Returns:
        torch.Size: Shape of the activations at the split point (typically ``(1, l, d)``).
    """
    with self.trace("scan") as tracer:
        curr_module = self.get(self._split_point)
        module_out_name = "nns_output" if hasattr(curr_module, "nns_output") else "output"
        module = getattr(curr_module, module_out_name)
        if isinstance(module, tuple):
            for candidate in module:
                if candidate.dim() == 3:
                    module = candidate
                    break
        shape = module.shape.save()  # type: ignore
        tracer.stop()
    return shape
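Some Hugging Face decoder layers return a tuple whose elements include the hidden states. When the split-point output is a tuple, get_latent_shape keeps the first 3-D element. The sketch below reproduces that selection without torch. FakeTensor and pick_hidden_states are hypothetical stand-ins, and the shapes are illustrative. Like the source, it returns the output unchanged if no 3-D element is found.

```python
class FakeTensor:
    """Stand-in for a tensor: only .shape and .dim() are needed here."""

    def __init__(self, *shape):
        self.shape = tuple(shape)

    def dim(self):
        return len(self.shape)


def pick_hidden_states(output):
    """Return the first 3-D element of a tuple output, else the output itself."""
    if isinstance(output, tuple):
        for candidate in output:
            if candidate.dim() == 3:
                return candidate
    return output


# A layer output as (hidden_states, attention_weights), illustrative shapes.
layer_output = (FakeTensor(1, 4, 768), FakeTensor(1, 12, 4, 4))
print(pick_hidden_states(layer_output).shape)  # (1, 4, 768)
```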