SplitterForClassification
SplitterForClassification is a specialized version of BaseSplitter designed for
*ForSequenceClassification Hugging Face models. It simplifies setup by automatically identifying the
classification head as the split point and fixing the granularity to the [CLS] token.
When to Use
Use SplitterForClassification instead of ModelWithSplitPoints when:
- Your model is a Hugging Face *ForSequenceClassification model.
- You want to extract CLS-token activations without manually specifying a split point.
- You want a cleaner, faster concept pipeline for classification tasks.
Additional Gain
It also unlocks the inputs-to-concepts attribution workflow, which ModelWithSplitPoints does not support.
Quick Example
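```python
from interpreto import SplitterForClassification

# Wrap a sequence classification checkpoint; the split point is detected automatically
splitter = SplitterForClassification(
    "nateraw/bert-base-uncased-emotion",
    batch_size=32,
    device_map="cuda",
)

# Example inputs (any list of strings works)
texts = ["I am so happy today!", "This is terrible news."]

# Compute activations and predictions on a dataset
activations, predictions = splitter.get_activations(texts, tqdm_bar=True)
```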
API Reference
interpreto.SplitterForClassification
SplitterForClassification(model_or_repo_id, split_point=None, tokenizer=None, config=None, batch_size=1, device_map=None, **kwargs)
Bases: BaseSplitter
A BaseSplitter specialization for sequence classification models.
Provides optimized implementations of activation extraction and concept gradient computation by exploiting the known structure of classification models: a backbone followed by a single classification head.
The split point is always the classification head, and activations are the CLS-token representations fed into that head.
The wrapper loads a sequence classification model and automatically identifies its classification head as the split point. This simplifies the concept pipeline for classification models by removing the need to manually specify split points and forcing the granularity to be the [CLS] token.
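To make the split concrete, the following toy sketch models it in plain Python. It is not the interpreto API. The names toy_backbone, cls_activation, make_linear_head and full_model are hypothetical stand-ins. Position 0 plays the role of the [CLS] token, and the sketch assumes the head reads that hidden state directly. Composing the two halves reproduces the output of the unsplit model.

```python
from typing import Callable, List

Vector = List[float]


def toy_backbone(token_ids: List[int]) -> List[Vector]:
    """Hypothetical backbone: returns one 2-d hidden state per token."""
    return [[float(t), 0.5 * float(t)] for t in token_ids]


def cls_activation(hidden_states: List[Vector]) -> Vector:
    """Keep only the position-0 hidden state, which plays the role of [CLS]."""
    return hidden_states[0]


def make_linear_head(weights: List[Vector], bias: Vector) -> Callable[[Vector], Vector]:
    """Hypothetical classification head computing logits = W x + b."""
    def head(x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]
    return head


head = make_linear_head(weights=[[1.0, 0.0], [0.0, 2.0]], bias=[0.0, 1.0])


def full_model(token_ids: List[int]) -> Vector:
    """The unsplit model: backbone, then [CLS] selection, then head."""
    return head(cls_activation(toy_backbone(token_ids)))


token_ids = [3, 7, 8]  # position 0 stands in for the [CLS] token
activations = cls_activation(toy_backbone(token_ids))  # inputs -> activations
logits = head(activations)  # activations -> outputs
assert logits == full_model(token_ids)  # splitting at the head loses nothing
```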
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model_or_repo_id` | `str \| PreTrainedModel` | A Hugging Face model ID or a pre-loaded `PreTrainedModel`. | required |
| `split_point` | `str \| None` | Name of the classification head module. If None, it is auto-detected by searching for common names (…). | None |
| `tokenizer` | `PreTrainedTokenizer \| PreTrainedTokenizerFast \| None` | The tokenizer associated with the model. If None, it is loaded from the model repo. | None |
| `config` | `PretrainedConfig \| None` | Model configuration. If None, it is loaded automatically. | None |
| `batch_size` | `int` | Batch size for activation extraction and gradient computation. | 1 |
| `device_map` | `device \| str \| None` | Device on which to load the model (e.g., …). | None |
| `**kwargs` | | Additional keyword arguments forwarded to … | {} |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If … |
Source code in interpreto/concepts/splitters/splitter_for_classification.py
inputs_to_activations
Compute latent activations (CLS-token representations) from raw inputs.
Runs the model backbone up to the classification head and extracts the input representation that would be fed to the classifier.
This method does not batch its inputs; it is meant to be called by other methods and classes.
In particular, it is used in the forward pass of ModelForInputsToConcepts,
which is batched by InputsToConceptsInferenceWrapper.
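The split of responsibilities described above (an unbatched core, with batching handled by the caller) can be sketched generically. The names batched and fake_inputs_to_activations are hypothetical and serve only as illustration. The real InputsToConceptsInferenceWrapper is not shown here and may work differently.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def batched(fn: Callable[[Sequence[T]], List[R]], inputs: Sequence[T], batch_size: int) -> List[R]:
    """Call an unbatched function on consecutive chunks and concatenate the results."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    results: List[R] = []
    for start in range(0, len(inputs), batch_size):
        # fn sees at most batch_size inputs per call
        results.extend(fn(inputs[start:start + batch_size]))
    return results


def fake_inputs_to_activations(texts: Sequence[str]) -> List[List[float]]:
    """Hypothetical stand-in for inputs_to_activations: one 1-d 'activation' per text."""
    return [[float(len(text))] for text in texts]


activations = batched(fake_inputs_to_activations, ["a", "bb", "ccc"], batch_size=2)
print(activations)  # [[1.0], [2.0], [3.0]]
```

Keeping the core unbatched lets each caller choose its own batch size and progress reporting, without duplicating the extraction logic.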
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| … | `list[str] \| Tensor \| BatchEncoding \| dict[str, Tensor] \| None` | Raw model inputs: a list of strings, a tensor of input IDs, a BatchEncoding, or a dictionary of tensors. | None |
| `**kwargs` | | Additional keyword arguments forwarded to the trace context (e.g., …). | {} |
Returns:
| Type | Description |
|---|---|
| `Float[Tensor, 'n d']` | The CLS-token activations of shape `(n, d)`. |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If both … |
activations_to_outputs
activations_to_outputs(activations)
Compute classification logits from latent activations.
Because the activations are the inputs of the classification head, this method simply passes them through the head to obtain output logits.
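The following plain-Python sketch shows the shape bookkeeping. It assumes the head is a single linear layer; real heads may contain more, such as dropout or extra layers. The name linear_head_forward is hypothetical. The sketch maps an n × d batch of activations to an n × cls matrix of logits.

```python
from typing import List

Matrix = List[List[float]]


def linear_head_forward(activations: Matrix, weight: Matrix, bias: List[float]) -> Matrix:
    """Apply logits = activations @ weight.T + bias.

    activations: n rows of d features; weight: cls rows of d features; bias: cls values.
    Returns n rows of cls logits.
    """
    d = len(weight[0])
    if any(len(row) != d for row in activations):
        raise ValueError("every activation row must have d features")
    return [
        [sum(a * w for a, w in zip(row, w_row)) + b for w_row, b in zip(weight, bias)]
        for row in activations
    ]


acts = [[1.0, 2.0], [0.0, 1.0], [3.0, 0.0]]  # n = 3, d = 2
weight = [[1.0, 1.0], [2.0, 0.0], [0.0, -1.0], [0.5, 0.5]]  # cls = 4
bias = [0.0, 0.0, 1.0, 0.0]
logits = linear_head_forward(acts, weight, bias)
print(len(logits), len(logits[0]))  # 3 4
```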
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `activations` | `Float[Tensor, 'n d']` | Latent activations of shape `(n, d)`. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Tensor, 'n cls']` | Classification logits of shape `(n, cls)`. |
get_activations
get_activations(inputs, tqdm_bar=False, forward_kwargs={}, **kwargs)
Extract CLS-token activations and predictions for a dataset of inputs.
Iterates over the inputs in batches, extracting the activations at the classification head input and the model predictions.
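Here is a rough plain-Python sketch of that loop, under stated assumptions. The hypothetical encode function plays the role of the backbone plus [CLS] selection, and head plays the classification head. Predictions are taken as the arg-max class of each logit row. The text above does not say whether the second returned value holds class indices or raw logits, so treat that choice as an assumption. The name get_activations_sketch is also hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def get_activations_sketch(
    inputs: Sequence[str],
    encode: Callable[[Sequence[str]], List[Vector]],
    head: Callable[[Vector], Vector],
    batch_size: int = 1,
) -> Tuple[List[Vector], List[int]]:
    """Collect head-input activations and arg-max predictions, one batch at a time."""
    activations: List[Vector] = []
    predictions: List[int] = []
    for start in range(0, len(inputs), batch_size):
        batch_acts = encode(inputs[start:start + batch_size])  # inputs -> activations
        for act in batch_acts:
            logits = head(act)  # activations -> logits
            # arg-max class index (assumed prediction format)
            predictions.append(max(range(len(logits)), key=logits.__getitem__))
        activations.extend(batch_acts)
    return activations, predictions


# Hypothetical toy components: a 2-d "activation" per text and a 2-class head.
def encode(texts: Sequence[str]) -> List[Vector]:
    return [[float(len(t)), float(t.count("!"))] for t in texts]


def head(x: Vector) -> Vector:
    return [x[0], 3.0 * x[1]]  # class 1 gains 3 points per "!"


acts, preds = get_activations_sketch(["hello", "wow!!", "ok"], encode, head, batch_size=2)
```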
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `list[str] \| Int[Tensor, 'n l']` | Raw text inputs or tokenized input IDs. | required |
| `tqdm_bar` | `bool` | Whether to display a progress bar. | False |
| `forward_kwargs` | `dict[str, Any]` | Additional keyword arguments for the model forward pass (e.g., …). | {} |
| `**kwargs` | | Unused, kept for API compatibility with … | {} |
Returns:
| Type | Description |
|---|---|
| `tuple[LatentActivations, Tensor]` | The activations tensor of shape … |
get_latent_shape
Get the shape of the latent activations.
Uses a quick trace with a dummy input to determine the classifier input shape.
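The idea is to run the activation extractor once on a throwaway input and read off the shape of the result, instead of hard-coding the hidden size. The plain-Python sketch below uses latent_shape and toy_extractor as hypothetical names and an arbitrary placeholder string as the dummy input. The real method's dummy input and exact returned shape are not specified here. For instance, it is unclear whether a leading batch dimension of 1 is kept.

```python
from typing import Callable, List, Sequence, Tuple


def latent_shape(
    inputs_to_activations: Callable[[Sequence[str]], List[List[float]]],
    dummy_input: str = "placeholder",
) -> Tuple[int, int]:
    """Infer (batch, hidden) by running the extractor on a single dummy input."""
    activations = inputs_to_activations([dummy_input])
    return len(activations), len(activations[0])


def toy_extractor(texts: Sequence[str]) -> List[List[float]]:
    """Hypothetical extractor returning a 4-d activation per input, whatever its content."""
    return [[0.0] * 4 for _ in texts]


shape = latent_shape(toy_extractor)
print(shape)  # (1, 4)
```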
Returns:
| Type | Description |
|---|---|
| `Size` | Shape of the activations at the classification head input. |