
SplitterForClassification

SplitterForClassification is a specialized version of BaseSplitter designed for Hugging Face *ForSequenceClassification models. It simplifies setup by automatically identifying the classification head as the split point and fixing the granularity to the [CLS] token.

When to Use

Use SplitterForClassification instead of ModelWithSplitPoints when:

  • Your model is a Hugging Face *ForSequenceClassification model.
  • You want to extract CLS-token activations without manually specifying a split point.
  • You want a cleaner, faster concept pipeline for classification tasks.

Additional Benefit

It unlocks the inputs-to-concepts attribution workflow, which is not available with ModelWithSplitPoints.

Quick Example

from interpreto import SplitterForClassification

splitter = SplitterForClassification(
    "nateraw/bert-base-uncased-emotion",
    batch_size=32,
    device_map="cuda",
)

# A small dataset of raw texts
texts = ["I feel great today!", "This is so frustrating."]

# Compute activations and predictions on the dataset
activations, predictions = splitter.get_activations(texts, tqdm_bar=True)

API Reference

interpreto.SplitterForClassification

SplitterForClassification(model_or_repo_id, split_point=None, tokenizer=None, config=None, batch_size=1, device_map=None, **kwargs)

Bases: BaseSplitter

A BaseSplitter specialization for sequence classification models.

Provides optimized implementations of activation extraction and concept gradient computation by exploiting the known structure of classification models: a backbone followed by a single classification head.

The split point is always the classification head, and activations are the CLS-token representations fed into that head.

The wrapper loads a sequence classification model and automatically identifies its classification head as the split point. This simplifies the concept pipeline for classification models by removing the need to manually specify split points and forcing the granularity to be the [CLS] token.
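Reading activations at the *input* of the classification head can be pictured in plain PyTorch with a forward pre-hook. This is a minimal sketch, not interpreto's implementation (which traces the model through `self.trace(...)`); `ToyClassifier` and `capture_head_input` are hypothetical names, and treating the first token as [CLS] is an assumption that matches BERT-style encoders.

```python
import torch
from torch import nn


class ToyClassifier(nn.Module):
    """Hypothetical stand-in for a *ForSequenceClassification model."""

    def __init__(self, vocab_size: int = 100, hidden: int = 16, n_classes: int = 3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.encoder = nn.Linear(hidden, hidden)
        self.classifier = nn.Linear(hidden, n_classes)  # the split point

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = torch.tanh(self.encoder(self.embed(input_ids)))  # (n, l, d)
        cls = hidden[:, 0, :]  # first-token ([CLS]) representation, (n, d)
        return self.classifier(cls)  # logits, (n, n_classes)


def capture_head_input(model: nn.Module, head_name: str, input_ids: torch.Tensor):
    """Run one forward pass and return (input of the head, logits)."""
    captured = {}

    def pre_hook(module, args):
        # `args` holds the head's positional inputs; the first one is the activation.
        captured["activations"] = args[0].detach()

    handle = getattr(model, head_name).register_forward_pre_hook(pre_hook)
    try:
        with torch.no_grad():
            logits = model(input_ids)
    finally:
        handle.remove()  # always detach the hook, even on error
    return captured["activations"], logits


torch.manual_seed(0)
model = ToyClassifier().eval()
input_ids = torch.randint(0, 100, (4, 7))
acts, logits = capture_head_input(model, "classifier", input_ids)
print(acts.shape, logits.shape)  # torch.Size([4, 16]) torch.Size([4, 3])
```

Because the head is the last module, `model.classifier(acts)` reproduces `logits`, which is the property a head-input split point relies on.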

Parameters:

  • model_or_repo_id (str | PreTrainedModel, required): A Hugging Face model ID or a pre-loaded PreTrainedModel instance. Must be a sequence classification model.
  • split_point (str | None, default None): Name of the classification head module. If None, it is auto-detected by searching for common names ("classifier", "classification_head", "score"); for most models the auto-detection can be trusted. Unlike other splitters, this one takes the activations at the input of the split_point module, not at its output.
  • tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast | None, default None): The tokenizer associated with the model. If None, it is loaded from the model repo.
  • config (PretrainedConfig | None, default None): Model configuration. If None, it is loaded automatically.
  • batch_size (int, default 1): Batch size for activation extraction and gradient computation.
  • device_map (torch.device | str | None, default None): Device on which to load the model (e.g., "cuda" or "cpu").
  • **kwargs (default {}): Additional keyword arguments forwarded to BaseSplitter.
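The split_point auto-detection can be pictured as a lookup over the listed candidate names. The helper below is hypothetical and the real search may differ (for example, in candidate order or in how nested modules are handled); it only checks direct child modules. In Hugging Face models these heads are usually top-level attributes, such as `classifier` on BERT or `score` on GPT-2.

```python
from torch import nn

# Candidate names taken from the split_point description above.
CANDIDATE_HEAD_NAMES = ("classifier", "classification_head", "score")


def find_classification_head(model: nn.Module) -> str:
    """Hypothetical helper: return the first candidate that is a direct child module."""
    children = dict(model.named_children())
    for name in CANDIDATE_HEAD_NAMES:
        if name in children:
            return name
    raise ValueError(
        f"No classification head found among {CANDIDATE_HEAD_NAMES}; "
        "pass split_point explicitly."
    )


class DecoderStyleClassifier(nn.Module):
    """Toy model whose head is called `score`."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.score = nn.Linear(8, 2, bias=False)


print(find_classification_head(DecoderStyleClassifier()))  # score
```

When none of the names match, passing split_point explicitly is the way out, which is why the sketch raises instead of guessing.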

Raises:

  • ValueError: If model_or_repo_id is a PreTrainedModel that is not a sequence classification model.

Example
from interpreto import SplitterForClassification

splitter = SplitterForClassification(
    "nateraw/bert-base-uncased-emotion",
    batch_size=32,
    device_map="cuda",
)
Source code in interpreto/concepts/splitters/splitter_for_classification.py
def __init__(
    self,
    model_or_repo_id: str | PreTrainedModel,
    split_point: str | None = None,
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None,
    config: PretrainedConfig | None = None,
    batch_size: int = 1,
    device_map: torch.device | str | None = None,
    **kwargs,
):
    """Initialize a SplitterForClassification model wrapper.

    The wrapper loads a sequence classification model and automatically identifies
    its classification head as the split point. This simplifies the concept pipeline
    for classification models by removing the need to manually specify split points and
    forcing the granularity to be the [CLS] token.

    Args:
        model_or_repo_id (str | PreTrainedModel): A Hugging Face model ID or a pre-loaded
            ``PreTrainedModel`` instance. Must be a sequence classification model.
        split_point (str | None): Name of the classification head module.
            If None, auto-detected by searching for common names (``"classifier"``,
            ``"classification_head"``, ``"score"``).
            For most models, one can trust the auto-detection.
            Note that, unlike other splitters, this one uses the input of
            the split_point module, not its output.
        tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast | None): The tokenizer
            associated with the model. If None, it is loaded from the model repo.
        config (PretrainedConfig | None): Model configuration. If None, loaded automatically.
        batch_size (int): Batch size for activation extraction and gradient computation.
        device_map (torch.device | str | None): Device on which to load the model
            (e.g., ``"cuda"`` or ``"cpu"``).
        **kwargs: Additional keyword arguments forwarded to ``BaseSplitter``.

    Raises:
        ValueError: If ``model_or_repo_id`` is a PreTrainedModel that is not a
            sequence classification model.

    Example:
        ```python
        from interpreto import SplitterForClassification

        splitter = SplitterForClassification(
            "nateraw/bert-base-uncased-emotion",
            batch_size=32,
            device_map="cuda",
        )
        ```
    """
    if isinstance(model_or_repo_id, PreTrainedModel):
        if "ForSequenceClassification" not in model_or_repo_id.__class__.__name__:
            raise ValueError(
                "The provided model is not a sequence classification model. "
                "Please provide a `*ForSequenceClassification` model, e.g. one loaded with "
                "`AutoModelForSequenceClassification`."
            )

    # Pass a placeholder split_point; our overridden setter skips walk_modules validation.
    # The real split point is resolved after super().__init__() loads the model,
    # because the split_point setter needs access to self._model.
    super().__init__(
        model_or_repo_id,
        split_point=split_point,
        config=config,
        tokenizer=tokenizer,
        automodel=AutoModelForSequenceClassification,  # type: ignore
        batch_size=batch_size,
        device_map=device_map,
        **kwargs,
    )

split_point (writable property)

split_point

inputs_to_activations

inputs_to_activations(inputs=None, **kwargs)

Compute latent activations (CLS-token representations) from raw inputs.

Runs the model backbone up to the classification head and extracts the input representation that would be fed to the classifier.

This method does not batch its inputs; it is meant to be called by other methods and classes. In particular, it is used in the forward pass of ModelForInputsToConcepts, which is batched by InputsToConceptsInferenceWrapper.
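Stopping the forward pass once the head input is captured (the `tracer.stop()` call in the source below) has a rough plain-PyTorch analogue: raise a custom exception from a forward pre-hook and catch it. This is a sketch under that assumption, not how the tracing library implements it; `StopForward` and `head_input_early_stop` are hypothetical names, and the toy model is illustrative only. Like the method it mirrors, it handles a single batch with no batching loop.

```python
import torch
from torch import nn


class StopForward(Exception):
    """Raised from a hook to abort the rest of the forward pass."""


def head_input_early_stop(model: nn.Module, head: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    captured = {}

    def pre_hook(module, args):
        captured["activations"] = args[0].detach()
        raise StopForward  # skip the head and everything after it

    handle = head.register_forward_pre_hook(pre_hook)
    try:
        with torch.no_grad():
            model(inputs)
    except StopForward:
        pass  # expected: the pass was cut short on purpose
    finally:
        handle.remove()
    return captured["activations"]


torch.manual_seed(0)
# Toy model: bag-of-embeddings backbone followed by a linear head.
model = nn.Sequential(nn.EmbeddingBag(100, 16), nn.Tanh(), nn.Linear(16, 3))
head_calls = []
model[2].register_forward_hook(lambda module, args, output: head_calls.append(1))

acts = head_input_early_stop(model, model[2], torch.randint(0, 100, (5, 7)))
print(acts.shape, len(head_calls))  # torch.Size([5, 16]) 0
```

The empty `head_calls` list shows the head's own forward never ran, which is the saving early stopping is meant to provide.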

Parameters:

  • inputs (list[str] | Tensor | BatchEncoding | dict[str, Tensor] | None, default None): Raw model inputs. Can be a list of strings, a tensor of input IDs, a BatchEncoding, or a dictionary of tensors.
  • **kwargs (default {}): Additional keyword arguments forwarded to the trace context (e.g., truncation=True).

Returns:

  • Float[torch.Tensor, "n d"]: The CLS-token activations of shape (n_samples, hidden_dim).

Raises:

  • ValueError: If inputs is None and no keyword arguments are given.
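The private `__extract_cls_token` helper called in the source below is not shown on this page. A minimal sketch of what such a reduction could look like, assuming the [CLS] token sits at position 0 (the BERT/RoBERTa convention; other architectures may pool a different token). `extract_cls_token` is a hypothetical name.

```python
import torch


def extract_cls_token(activations: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reduce (n, l, d) activations to (n, d)."""
    if activations.ndim == 3:
        return activations[:, 0, :]  # assumes [CLS] is the first token
    return activations  # already one vector per sample


# Head inputs that still carry the sequence axis: 2 samples, 5 tokens, 4 dims.
acts_3d = torch.arange(2 * 5 * 4, dtype=torch.float32).reshape(2, 5, 4)
cls = extract_cls_token(acts_3d)
print(cls.shape)  # torch.Size([2, 4])
print(cls[1])     # tensor([20., 21., 22., 23.])
```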

Source code in interpreto/concepts/splitters/splitter_for_classification.py
def inputs_to_activations(
    self, inputs: list[str] | torch.Tensor | BatchEncoding | dict[str, torch.Tensor] | None = None, **kwargs
) -> Float[torch.Tensor, "n d"]:
    """Compute latent activations (CLS-token representations) from raw inputs.

    Runs the model backbone up to the classification head and extracts the
    input representation that would be fed to the classifier.

    This method does not batch its inputs; it is meant to be called by other methods/classes.
    In particular, it is used by the ``ModelForInputsToConcepts`` forward,
    which is batched in the ``InputsToConceptsInferenceWrapper``.

    Args:
        inputs (list[str] | torch.Tensor | BatchEncoding | dict[str, torch.Tensor] | None):
            Raw model inputs. Can be a list of strings, a tensor of input IDs,
            a BatchEncoding, or a dictionary of tensors.
        **kwargs: Additional keyword arguments forwarded to the trace context
            (e.g., ``truncation=True``).

    Returns:
        Float[torch.Tensor, "n d"]: The CLS-token activations of shape
            ``(n_samples, hidden_dim)``.

    Raises:
        ValueError: If both ``inputs`` and ``kwargs`` are empty.
    """
    if inputs is None and len(kwargs) == 0:
        raise ValueError("Either inputs or kwargs must be provided.")

    with self.trace(inputs, **kwargs) as tracer:
        activations = getattr(self, self.split_point).input.save()
        tracer.stop()  # we only needed the CLS token, no need to complete the forward pass

    # force two dimensions
    if activations.ndim == 3:
        activations = self.__extract_cls_token(activations)
    return activations

activations_to_outputs

activations_to_outputs(activations)

Compute classification logits from latent activations.

Since the activations are the inputs of the classification head, this method simply passes them through the head to obtain the output logits.
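Splitting the model at the head input means quantities defined on activations, such as gradients of a class logit, only require the head. A minimal sketch with a hypothetical toy backbone and linear head (not interpreto code): for a linear head, the gradient of the predicted logit with respect to the activations equals the matching row of the head's weight matrix.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy split: backbone -> activations (n, d) -> linear head -> logits.
backbone = nn.Sequential(nn.EmbeddingBag(100, 16), nn.Tanh())
classifier = nn.Linear(16, 3)
input_ids = torch.randint(0, 100, (4, 7))

# Activations are computed once; the backbone is not needed after this point.
with torch.no_grad():
    activations = backbone(input_ids)
activations.requires_grad_(True)

# activations_to_outputs-style step: only the head maps activations to logits.
logits = classifier(activations)
predicted = logits.argmax(dim=-1)  # (n,) predicted class indices

# Each sample's logit depends only on its own row, so one backward pass
# over the summed selected logits yields per-sample gradients.
logits.gather(1, predicted.unsqueeze(1)).sum().backward()
grads = activations.grad  # (n, d)

# For a linear head, d logit_c / d activations = weight[c].
print(torch.allclose(grads, classifier.weight[predicted]))  # True
```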

Parameters:

  • activations (Float[torch.Tensor, "n d"], required): Latent activations of shape (n_samples, hidden_dim).

Returns:

  • Float[torch.Tensor, "n cls"]: Classification logits of shape (n_samples, n_classes).

Source code in interpreto/concepts/splitters/splitter_for_classification.py
def activations_to_outputs(
    self,
    activations: Float[torch.Tensor, "n d"],
) -> Float[torch.Tensor, "n cls"]:
    """Compute classification logits from latent activations.

    Since the activations are the inputs of the classification head, this method
    simply passes them through the head to obtain the output logits.

    Args:
        activations (Float[torch.Tensor, "n d"]): Latent activations of shape
            ``(n_samples, hidden_dim)``.

    Returns:
        Float[torch.Tensor, "n cls"]: Classification logits of shape
            ``(n_samples, n_classes)``.
    """
    return getattr(self, self.split_point)(activations).logits

get_activations

get_activations(inputs, tqdm_bar=False, forward_kwargs={}, **kwargs)

Extract CLS-token activations and predictions for a dataset of inputs.

Iterates over the inputs in batches, extracting the activations at the classification head input and the model predictions.
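The batching loop behind get_activations can be approximated in plain PyTorch: slice the inputs into chunks of batch_size, capture the head input with a forward pre-hook, take the argmax of the logits as the prediction, and concatenate on the CPU. This is a sketch with a hypothetical helper and toy model; the method itself traces the model instead and also reduces 3-D head inputs to the CLS token.

```python
import torch
from torch import nn


def batched_head_inputs(model: nn.Module, head: nn.Module, input_ids: torch.Tensor, batch_size: int):
    """Hypothetical helper returning (activations (n, d), predictions (n,))."""
    activations, predictions = [], []
    captured = {}

    def pre_hook(module, args):
        captured["x"] = args[0]  # input of the head for the current batch

    handle = head.register_forward_pre_hook(pre_hook)
    model.eval()
    try:
        with torch.no_grad():
            for start in range(0, len(input_ids), batch_size):
                batch = input_ids[start:start + batch_size]  # last batch may be smaller
                logits = model(batch)
                activations.append(captured["x"].detach().cpu())
                predictions.append(logits.argmax(dim=-1).cpu())
    finally:
        handle.remove()
    return torch.cat(activations, dim=0), torch.cat(predictions, dim=0)


torch.manual_seed(0)
model = nn.Sequential(nn.EmbeddingBag(100, 16), nn.Tanh(), nn.Linear(16, 3))
input_ids = torch.randint(0, 100, (10, 7))
acts, preds = batched_head_inputs(model, model[2], input_ids, batch_size=4)
print(acts.shape, preds.shape)  # torch.Size([10, 16]) torch.Size([10])
```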

Parameters:

  • inputs (list[str] | Int[torch.Tensor, "n l"], required): Raw text inputs or tokenized input IDs.
  • tqdm_bar (bool, default False): Whether to display a progress bar.
  • forward_kwargs (dict[str, Any], default {}): Additional keyword arguments for the model forward pass (e.g., {"truncation": True}).
  • **kwargs (default {}): Unused, kept for API compatibility with ModelWithSplitPoints.

Returns:

  • tuple[LatentActivations, torch.Tensor]: The activations tensor of shape (n_samples, hidden_dim) and the predicted class indices of shape (n_samples,).

Source code in interpreto/concepts/splitters/splitter_for_classification.py
def get_activations(
    self,
    inputs: list[str] | Int[torch.Tensor, "n l"],
    tqdm_bar: bool = False,
    forward_kwargs: dict[str, Any] = {},
    **kwargs,
) -> tuple[LatentActivations, torch.Tensor]:
    """Extract CLS-token activations and predictions for a dataset of inputs.

    Iterates over the inputs in batches, extracting the activations at the
    classification head input and the model predictions.

    Args:
        inputs (list[str] | Int[torch.Tensor, "n l"]): Raw text inputs or
            tokenized input IDs.
        tqdm_bar (bool): Whether to display a progress bar.
        forward_kwargs (dict[str, Any]): Additional keyword arguments for
            the model forward pass (e.g., ``{"truncation": True}``).
        **kwargs: Unused, kept for API compatibility with ``ModelWithSplitPoints``.

    Returns:
        tuple[LatentActivations, torch.Tensor]: The activations tensor of shape
            ``(n_samples, hidden_dim)`` and predicted class indices of shape ``(n_samples,)``.
    """
    activations = []
    predictions = []
    classification_head = getattr(self, self.split_point)

    self._model.eval()
    with torch.no_grad():
        for i in tqdm(range(0, len(inputs), self.batch_size), disable=not tqdm_bar):
            # extract and prepare a batch of inputs
            end_idx = min(i + self.batch_size, len(inputs))
            batch = inputs[i:end_idx]
            if isinstance(batch, torch.Tensor):
                batch = {"input_ids": batch}

            # get activations and predictions for the batch
            with self.trace(batch, **forward_kwargs):
                batch_activations = classification_head.input.save()
                batch_predictions = self.output.logits.argmax(dim=-1).save()  # type: ignore

            # force two dimensions
            if batch_activations.ndim == 3:
                batch_activations = self.__extract_cls_token(batch_activations)

            # Materialize outside the trace. This is necessary to avoid memory leaks.
            activations.append(batch_activations.detach().cpu().clone())
            predictions.append(batch_predictions.detach().cpu().clone())

            del batch, batch_activations, batch_predictions

    activations = torch.cat(activations, dim=0)
    predictions = torch.cat(predictions, dim=0)

    # free memory
    torch.cuda.empty_cache()
    gc.collect()
    return activations, predictions

get_latent_shape

get_latent_shape()

Get the shape of the latent activations.

Uses a quick trace with a dummy input to determine the classifier input shape.
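Determining the latent shape from a dummy input amounts to one forward pass that records the shape of the head input. A plain-PyTorch sketch with a hypothetical helper and toy model follows. As the source below shows, get_latent_shape returns the raw head-input shape without any CLS reduction, so for a head that receives the full sequence the shape keeps the token axis; the toy head here is of that kind.

```python
import torch
from torch import nn


class SequenceHead(nn.Module):
    """Toy head that receives (n, l, d) and pools the first token itself."""

    def __init__(self, hidden: int = 16, n_classes: int = 3):
        super().__init__()
        self.out = nn.Linear(hidden, n_classes)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.out(features[:, 0, :])


def head_input_shape(model: nn.Module, head: nn.Module, dummy_input: torch.Tensor) -> torch.Size:
    """Hypothetical helper: shape of the head input for one dummy forward pass."""
    shapes = []
    handle = head.register_forward_pre_hook(lambda module, args: shapes.append(args[0].shape))
    try:
        with torch.no_grad():
            model(dummy_input)
    finally:
        handle.remove()
    return shapes[0]


model = nn.Sequential(nn.Embedding(100, 16), SequenceHead())
dummy = torch.tensor([[1, 2, 3]])  # one short dummy sequence of 3 token IDs
print(head_input_shape(model, model[1], dummy))  # torch.Size([1, 3, 16])
```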

Returns:

  • torch.Size: Shape of the activations at the classification head input.

Source code in interpreto/concepts/splitters/splitter_for_classification.py
def get_latent_shape(self) -> torch.Size:
    """Get the shape of the latent activations.

    Uses a quick trace with a dummy input to determine the classifier input shape.

    Returns:
        torch.Size: Shape of the activations at the classification head input.
    """
    with self.trace("scan") as tracer:
        shape = getattr(self, self.split_point).input.shape.save()
        tracer.stop()
    return shape