
ModelWithSplitPoints

ModelWithSplitPoints now uses the singular split_point argument/property because only one split point is supported. The previous split_points argument/property remains temporarily available as a deprecated compatibility alias and emits a DeprecationWarning guiding users to split_point. It will be removed in version 1.0.0.
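The compatibility rules described above can be illustrated with a small standalone sketch. The function name `resolve_split_point` and the warning message are hypothetical and not part of Interpreto; the code only mimics the documented behaviour (one argument at a time, a DeprecationWarning for the alias, first element kept for a list or tuple).

```python
import warnings


def resolve_split_point(split_point=None, split_points=None):
    """Hypothetical helper mimicking the documented alias rules (not Interpreto's code)."""
    if split_point is not None and split_points is not None:
        raise ValueError("Specify only one of `split_point` or deprecated `split_points`.")
    if split_points is not None:
        # the deprecated alias still works, but points users to the new name
        warnings.warn(
            "`split_points` is deprecated, use `split_point` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        split_point = split_points
    # only one split point is supported: keep the first element of a list or tuple
    if isinstance(split_point, (list, tuple)):
        split_point = split_point[0]
    if split_point is None:
        raise TypeError("Missing required argument `split_point`.")
    return split_point


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    resolved = resolve_split_point(split_points=[10, 11])

print(resolved)  # 10
print(caught[0].category.__name__)  # DeprecationWarning
```

New code should simply pass `split_point=10`, which goes through the same function without any warning.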

Specificities

In comparison to SplitterForClassification and SplitterForGeneration, the ModelWithSplitPoints class is more versatile. It is more complex to use, but it covers the two other splitter cases.

In both get_activations and _get_concept_output_gradients, one needs to specify:

  • activation_granularity: specifies which of the (n, l, d) activations to return. It can be one of CLS_TOKEN, ALL_TOKENS, TOKEN, WORD, SENTENCE, or SAMPLE. Use activation_granularity=ModelWithSplitPoints.activation_granularities.TOKEN to specify it.
  • aggregation_strategy: how activations should be aggregated (only for WORD, SENTENCE, or SAMPLE). It can be one of SUM, MEAN, MAX, or SIGNED_MAX. Use aggregation_strategy=ModelWithSplitPoints.aggregation_strategies.MEAN to specify it.
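To make the interplay between the two arguments concrete, here is a minimal pure-Python sketch of the WORD granularity combined with the MEAN aggregation on a single sample. The helper `aggregate_by_groups` and the toy values are illustrative assumptions; Interpreto works on tensors and derives the token groups from the tokenizer.

```python
def aggregate_by_groups(token_activations, groups):
    """Average token activations of shape (l, d) within each group of token indices.

    Returns a list of shape (g, d), with one row per group.
    """
    aggregated = []
    for indices in groups:
        rows = [token_activations[i] for i in indices]
        # column-wise mean over the tokens of the group
        aggregated.append([sum(column) / len(column) for column in zip(*rows)])
    return aggregated


# one sample with l = 4 tokens and d = 2; index 0 plays the role of a special token
token_activations = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# two words: the first one is split into tokens 1 and 2, the second one is token 3
word_groups = [[1, 2], [3]]

word_activations = aggregate_by_groups(token_activations, word_groups)
print(word_activations)  # [[2.0, 3.0], [5.0, 6.0]]
```

The special token at index 0 belongs to no group, so it does not contribute to any word activation.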

interpreto.ModelWithSplitPoints

ModelWithSplitPoints(model_or_repo_id, split_point=None, *args, split_points=None, automodel=None, tokenizer=None, config=None, batch_size=1, device_map=None, **kwargs)

Bases: BaseSplitter

Code: concepts/splitters/model_with_split_points.py

The ModelWithSplitPoints is a wrapper around your HuggingFace model. Its goal is to allow you to split your model at specified locations and extract activations.

It is one of the key components of the Concept-Based Explainers framework in Interpreto. Indeed, any Interpreto concept explainer is built around a ModelWithSplitPoints object, because splitting the model is the first step of the concept-based explanation process.

It is based on the LanguageModel class from NNsight and inherits its functionalities. In a sense, the LanguageModel class is a wrapper around the HuggingFace model. The ModelWithSplitPoints class is a wrapper around the LanguageModel class.

We often shorten the ModelWithSplitPoints class as MWSP and instances as mwsp.

Parameters:

Name Type Description Default

model_or_repo_id

str | PreTrainedModel

One of:

  • A str corresponding to the ID of the model that should be loaded from the HF Hub.
  • A str corresponding to the local path of a folder containing a compatible checkpoint.
  • A preloaded transformers.PreTrainedModel object.

If a string is provided, an automodel must also be provided.

required

split_point

str | int

The split location inside the model. Either one of the following:

  • A str corresponding to the path of a split point inside the model.
  • An int corresponding to the n-th layer.

Example: split_point='cls.predictions.transform.LayerNorm' corresponds to a split after the LayerNorm layer in the MLM head (assuming a BertForMaskedLM model as input).

None

split_points

(str | int | list[str] | list[int], deprecated)

Backward-compatible alias for split_point. If a list/tuple is provided, only the first element is used.

None

automodel

type[AutoModel]

Huggingface AutoClass corresponding to the desired type of model (e.g. AutoModelForSequenceClassification).

⚠ automodel must be defined if model_or_repo_id is a str, since the model class cannot be known otherwise.

None

config

PretrainedConfig

Custom configuration for the loaded model. If not specified, it will be instantiated with the default configuration for the model.

None

tokenizer

PreTrainedTokenizer | PreTrainedTokenizerFast | None

Custom tokenizer for the loaded model. If not specified, it will be instantiated with the default tokenizer for the model.

⚠ If model_or_repo_id is a transformers.PreTrainedModel object, then tokenizer must be defined.

None

batch_size

int

Batch size for the model.

1

device_map

device | str | None

Device map for the model. Directly passed to the model.

None
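As an illustration of what a string split_point designates, the sketch below resolves a dotted path inside a nested object, treating numeric segments as list indices. The `resolve_path` helper and the toy model are hypothetical; they do not reproduce how Interpreto or NNsight locate modules.

```python
from types import SimpleNamespace


def resolve_path(root, path):
    """Follow a dotted path from `root`; numeric segments index into sequences."""
    current = root
    for segment in path.split("."):
        if segment.isdigit():
            current = current[int(segment)]
        else:
            current = getattr(current, segment)
    return current


# toy stand-in for a model with two encoder layers
toy_model = SimpleNamespace(
    encoder=SimpleNamespace(
        layer=[
            SimpleNamespace(output="layer-0 output"),
            SimpleNamespace(output="layer-1 output"),
        ]
    )
)

print(resolve_path(toy_model, "encoder.layer.1.output"))  # layer-1 output
```

Note that path indices start at 0 in this toy example, so the segment `1` designates the second element of the list.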

Attributes:

Name Type Description
activation_granularities ActivationGranularity

Enumeration of the available granularities for the get_activations method.

aggregation_strategies GranularityAggregationStrategy

Enumeration of the available aggregation strategies for the get_activations method.

automodel type[AutoModel]

The AutoClass corresponding to the loaded model type.

batch_size int

Batch size for the model.

output_tuple_index int | None

If the output at the split point is a tuple, this is the index of the hidden state. If None, an element with 3 dimensions is searched for. If not found, an error is raised. If several elements are found, an error is raised.

repo_id str

Either the model id in the HF Hub, or the path from which the model was loaded.

tokenizer PreTrainedTokenizer

Tokenizer for the loaded model, either given by the user or loaded from the repo_id.

_model PreTrainedModel

Hugging Face transformers model wrapped by NNsight.
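The search rule documented for output_tuple_index can be sketched as follows. This is only an illustration under simplifying assumptions: `find_hidden_state` is a hypothetical name, and the elements of the output tuple are represented by their shapes (tuples of ints) instead of real tensors.

```python
def find_hidden_state(output_shapes, output_tuple_index=None):
    """Return the index of the hidden state among the outputs of a split point."""
    if output_tuple_index is not None:
        # an explicit index short-circuits the search
        return output_tuple_index
    # otherwise, look for exactly one element with 3 dimensions, i.e. (n, l, d)
    candidates = [i for i, shape in enumerate(output_shapes) if len(shape) == 3]
    if len(candidates) != 1:
        raise ValueError(f"Expected exactly one 3-dimensional element, found {len(candidates)}.")
    return candidates[0]


# e.g. a layer returning (hidden_states, attention_weights)
shapes = [(2, 5, 8), (2, 4, 5, 5)]
print(find_hidden_state(shapes))  # 0
```

Setting output_tuple_index explicitly is the way out when the tuple contains zero or several 3-dimensional elements.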

Examples:

Minimal example with gpt2:

>>> from transformers import AutoModelForCausalLM
>>> from interpreto import ModelWithSplitPoints
>>> model_with_split_points = ModelWithSplitPoints(
...     "gpt2",
...     split_point=10,  # split at the 10th layer
...     automodel=AutoModelForCausalLM,
...     device_map="auto",
... )
>>> activations, _ = model_with_split_points.get_activations(
...     inputs="interpreto is magic",
...     activation_granularity=ModelWithSplitPoints.activation_granularities.TOKEN,  # highly recommended for generation
... )

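Load the model from its repository id, split it at the output of encoder layer 1, and get the CLS token activations at that split point.

>>> import torch
>>> from datasets import load_dataset
>>> from transformers import AutoModelForSequenceClassification
>>> from interpreto import ModelWithSplitPoints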
>>> # load and split the model
>>> model_with_split_points = ModelWithSplitPoints(
...     "bert-base-uncased",
...     split_point="bert.encoder.layer.1.output",
...     automodel=AutoModelForSequenceClassification,
...     batch_size=64,
...     device_map="cuda" if torch.cuda.is_available() else "cpu",
... )
>>> # get activations
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations, _ = model_with_split_points.get_activations(
...     dataset,
...     activation_granularity=ModelWithSplitPoints.activation_granularities.CLS_TOKEN,  # highly recommended for classification
... )

Load the model, pass it to ModelWithSplitPoints, split it at the tenth layer, and get word-level activations: special tokens are skipped and token activations are averaged into words.

>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from datasets import load_dataset
>>> from interpreto import ModelWithSplitPoints as MWSP
>>> # load the model
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
>>> # wrap and split the model at the 10th layer
>>> model_with_split_points = MWSP(
...     model,
...     tokenizer=tokenizer,
...     split_point=10,  # split at the 10th layer
...     batch_size=16,
...     device_map="auto",
... )
>>> # get activations at the word granularity
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations, _ = model_with_split_points.get_activations(
...     dataset,
...     activation_granularity=MWSP.activation_granularities.WORD,
...     aggregation_strategy=MWSP.aggregation_strategies.MEAN,  # average tokens activations by words
... )

Most of the work is forwarded to the BaseSplitter class initialization, which is in turn a wrapper around the nnsight.LanguageModel class.

Raises:

Type Description
InitializationError(ValueError)

If the model cannot be loaded, because of a missing tokenizer or automodel.

ValueError

If the device_map is set to 'auto' and the model is not a generation model.

TypeError

If the model_or_repo_id is not a str or a transformers.PreTrainedModel.

Source code in interpreto/concepts/splitters/model_with_split_points.py
def __init__(
    self,
    model_or_repo_id: str | PreTrainedModel,
    split_point: str | int | list[str] | list[int] | tuple[str, ...] | tuple[int, ...] | None = None,
    *args: tuple[Any],
    split_points: str | int | list[str] | list[int] | tuple[str, ...] | tuple[int, ...] | None = None,
    automodel: type[AutoModel] | None = None,
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None,
    config: PretrainedConfig | None = None,
    batch_size: int = 1,
    device_map: torch.device | str | None = None,
    **kwargs,
) -> None:
    # For parameters list, see class docstring. It was moved to change the order in the documentation.
    """Initialize a ModelWithSplitPoints object.

    Most of the work is forwarded to the `BaseSplitter` class initialization,
    which is in turn a wrapper around the `nnsight.LanguageModel` class.

    Raises:
        InitializationError (ValueError): If the model cannot be loaded, because of a missing `tokenizer` or `automodel`.
        ValueError: If the `device_map` is set to 'auto' and the model is not a generation model.
        TypeError: If the `model_or_repo_id` is not a `str` or a `transformers.PreTrainedModel`.
    """
    # Handle deprecated `split_points` parameter
    if split_point is not None and split_points is not None:
        raise ValueError("Specify only one of `split_point` or deprecated `split_points`.")
    if split_points is not None:
        split_point = self._deprecated_split_points_to_split_point(split_points)
    elif isinstance(split_point, list | tuple):
        split_point = self._deprecated_split_points_to_split_point(split_point)
    if split_point is None:
        raise TypeError("Missing required argument `split_point`.")

    # Delegate to BaseSplitter (handles validation, loading, split point, device, tokenizer)
    super().__init__(
        model_or_repo_id,
        split_point,
        *args,
        automodel=automodel,
        tokenizer=tokenizer,
        config=config,
        batch_size=batch_size,
        device_map=device_map,
        **kwargs,
    )

activation_granularities class-attribute instance-attribute

activation_granularities = ActivationGranularity

aggregation_strategies class-attribute instance-attribute

aggregation_strategies = GranularityAggregationStrategy

get_activations

Get intermediate activations for the model split point on the given inputs.

Optionally include the model predictions in the returned tuple.

Parameters:

Name Type Description Default

inputs list[str] | torch.Tensor | BatchEncoding

Inputs to the model forward pass before or after tokenization. In the case of a torch.Tensor, we assume a batch dimension and token ids.

required

activation_granularity

ActivationGranularity

Selection strategy for activations. In the model, activations have the shape (n, l, d), where n is the number of samples, l the sequence length, and d the model dimension. This parameter specifies which elements of these tensors are selected. If the granularity is larger than tokens, e.g. words, sentences, or whole samples, the activations are aggregated. The parameter aggregation_strategy specifies how the activations are aggregated.

It is highly recommended to use CLS_TOKEN for classification tasks and TOKEN for other tasks.

Available options are:

  • ModelWithSplitPoints.activation_granularities.ALL_TOKENS: the raw activations are flattened (n x l, d). Hence, each token activation is now considered as a separate element. This includes special tokens such as [CLS], [SEP], [EOS], [PAD], etc.

  • ModelWithSplitPoints.activation_granularities.CLS_TOKEN: for each sample, only the first token (e.g. [CLS]) activation is returned (n, d). This will raise an error if the model is not ForSequenceClassification.

  • ModelWithSplitPoints.activation_granularities.SAMPLE: special tokens are removed and the remaining ones are aggregated on the whole sample (n, d).

  • ModelWithSplitPoints.activation_granularities.SENTENCE: special tokens are removed and the remaining ones are aggregated by sentences. Then the activations are flattened. (n x g, d) where g is the number of sentences in the input. The split is defined by interpreto.commons.granularity.Granularity.SENTENCE.

  • ModelWithSplitPoints.activation_granularities.TOKEN: the raw activations are flattened, but the special tokens are removed. (n x g, d) where g is the number of non-special tokens in the input. This is the default granularity.

  • ModelWithSplitPoints.activation_granularities.WORD: the special tokens are removed and the remaining ones are aggregated by words. Then the activations are flattened. (n x g, d) where g is the number of words in the input. The split is defined by interpreto.commons.granularity.Granularity.WORD.

required

aggregation_strategy

GranularityAggregationStrategy

Strategy to aggregate token activations into larger input granularities. Applied for the WORD, SENTENCE, and SAMPLE activation granularities. The n token activations of shape (l, d) are aggregated along the sequence length dimension, then concatenated into a (ng, d) tensor.

Existing strategies are:

  • ModelWithSplitPoints.aggregation_strategies.SUM: Token activations are summed along the sequence length dimension.

  • ModelWithSplitPoints.aggregation_strategies.MEAN: Token activations are averaged along the sequence length dimension.

  • ModelWithSplitPoints.aggregation_strategies.MAX: The maximum of the token activations along the sequence length dimension is selected.

  • ModelWithSplitPoints.aggregation_strategies.SIGNED_MAX: The maximum of the absolute value of the activations multiplied by its initial sign. signed_max([[-1, 0, 1, 2], [-3, 1, -2, 0]]) = [-3, 1, -2, 2]

MEAN

pad_side

str | None

'left' or 'right': side on which to apply padding along dim=1, only for the ALL_TOKENS granularity. Forced to right for classification models and to left for causal LMs.

None

tqdm_bar

bool

Whether to display a progress bar.

False

include_predicted_classes

bool

Whether to include the predicted classes in the output tuple. Only applicable for classification models.

False

flatten_activations

bool

Whether to flatten the activations tensors.

  • If True, the activations will be flattened from (n, l, d) to (n x l, d). It allows storing the activations for the split point in a single tensor.

  • If False, a list of sample-wise activations will be returned.

True

forward_kwargs

dict

Additional keyword arguments passed to the model forward pass.

{}
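The rule stated for pad_side (forced for classification models and causal LMs, user-defined otherwise) can be summarised in a few lines. `resolve_pad_side` is a hypothetical helper written for this illustration; like the method's source, it keys on the suffix of the model class name.

```python
def resolve_pad_side(model_class_name, pad_side=None):
    """Return the padding side that would be used for a given model class name."""
    if model_class_name.endswith("ForSequenceClassification"):
        return "right"  # forced, whatever the user asked for
    if model_class_name.endswith("ForCausalLM"):
        return "left"  # forced, whatever the user asked for
    # any other model: honour the user's choice, defaulting to left
    return pad_side or "left"


print(resolve_pad_side("BertForSequenceClassification", pad_side="left"))  # right
print(resolve_pad_side("LlamaForCausalLM", pad_side="right"))  # left
print(resolve_pad_side("BertForMaskedLM", pad_side="right"))  # right
print(resolve_pad_side("BertForMaskedLM"))  # left
```

Because the check relies on the class name suffix, a model class whose name ends differently falls into the last branch, where the pad_side argument is taken into account.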

Returns:

Name Type Description
tuple[LatentActivations, Tensor | None] | tuple[list[LatentActivations], list[Tensor] | None]

activations LatentActivations | list[LatentActivations]

The extracted activations, either flattened into a single tensor or returned as a sample-wise list.

predictions Tensor | list[Tensor] | None

The predicted classes, if requested.
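When activations are flattened, each predicted class has to be repeated once per granularity element of its sample, so that predictions stay aligned with the activation rows. The pure-Python sketch below mirrors that bookkeeping; the function name and the toy values are assumptions made for illustration.

```python
def repeat_predictions(predictions, granularity_indices):
    """Repeat each sample's predicted class once per granularity element of that sample.

    `granularity_indices[i]` lists the groups of token indices found in sample `i`.
    """
    repeated = []
    for predicted_class, sample_groups in zip(predictions, granularity_indices):
        repeated.extend([predicted_class] * len(sample_groups))
    return repeated


# two samples: the first has 2 words, the second has 3 words
predictions = [1, 0]
granularity_indices = [[[1, 2], [3]], [[1], [2], [3, 4]]]

print(repeat_predictions(predictions, granularity_indices))  # [1, 1, 0, 0, 0]
```

With 5 words in total, the flattened activations have 5 rows, and the repeated predictions have 5 entries in the same order.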

Source code in interpreto/concepts/splitters/model_with_split_points.py
def get_activations(  # noqa: PLR0912  # ignore too many branches  # too many special cases
    self,
    inputs: list[str] | torch.Tensor | BatchEncoding,
    *,
    activation_granularity: ActivationGranularity,
    aggregation_strategy: GranularityAggregationStrategy = GranularityAggregationStrategy.MEAN,
    pad_side: str | None = None,
    tqdm_bar: bool = False,
    include_predicted_classes: bool = False,
    flatten_activations: bool = True,
    forward_kwargs: dict[str, Any] = {},
) -> tuple[LatentActivations, torch.Tensor | None] | tuple[list[LatentActivations], list[torch.Tensor] | None]:
    """

    Get intermediate activations for the model split point on the given `inputs`.

    Optionally include the model predictions in the returned tuple.

    Args:
        inputs (list[str] | torch.Tensor | BatchEncoding):
            Inputs to the model forward pass before or after tokenization.
            In the case of a `torch.Tensor`, we assume a batch dimension and token ids.

        activation_granularity (ActivationGranularity):
            Selection strategy for activations.
            In the model, activations have the shape `(n, l, d)`, where `d` is the model dimension.
            This parameter specifies which elements of these tensors are selected.
            If the granularity is larger than tokens, e.g. words, sentences, or whole samples, the activations are aggregated.
            The parameter `aggregation_strategy` specifies how the activations are aggregated.

            **It is highly recommended to use `CLS_TOKEN` for classification tasks and `TOKEN` for other tasks.**

            Available options are:

            - ``ModelWithSplitPoints.activation_granularities.ALL_TOKENS``:
                the raw activations are flattened ``(n x l, d)``.
                Hence, each token activation is now considered as a separate element.
                This includes special tokens such as [CLS], [SEP], [EOS], [PAD], etc.

            - ``ModelWithSplitPoints.activation_granularities.CLS_TOKEN``:
                for each sample, only the first token (e.g. ``[CLS]``) activation is returned ``(n, d)``.
                This will raise an error if the model is not `ForSequenceClassification`.

            - ``ModelWithSplitPoints.activation_granularities.SAMPLE``:
                special tokens are removed and the remaining ones are aggregated on the whole sample ``(n, d)``.

            - ``ModelWithSplitPoints.activation_granularities.SENTENCE``:
                special tokens are removed and the remaining ones are aggregated by sentences.
                Then the activations are flattened.
                ``(n x g, d)`` where `g` is the number of sentences in the input.
                The split is defined by `interpreto.commons.granularity.Granularity.SENTENCE`.

            - ``ModelWithSplitPoints.activation_granularities.TOKEN``:
                the raw activations are flattened, but the special tokens are removed.
                ``(n x g, d)`` where `g` is the number of non-special tokens in the input.
                This is the default granularity.

            - ``ModelWithSplitPoints.activation_granularities.WORD``:
                the special tokens are removed and the remaining ones are aggregated by words.
                Then the activations are flattened.
                ``(n x g, d)`` where `g` is the number of words in the input.
                The split is defined by `interpreto.commons.granularity.Granularity.WORD`.

        aggregation_strategy (GranularityAggregationStrategy):
            Strategy to aggregate token activations into larger input granularities.
            Applied for the `WORD`, `SENTENCE`, and `SAMPLE` activation granularities.
            The n token activations of shape (l, d) are aggregated along the sequence length dimension,
            then concatenated into a (ng, d) tensor.

            Existing strategies are:

            - ``ModelWithSplitPoints.aggregation_strategies.SUM``:
                Token activations are summed along the sequence length dimension.

            - ``ModelWithSplitPoints.aggregation_strategies.MEAN``:
                Token activations are averaged along the sequence length dimension.

            - ``ModelWithSplitPoints.aggregation_strategies.MAX``:
                The maximum of the token activations along the sequence length dimension is selected.

            - ``ModelWithSplitPoints.aggregation_strategies.SIGNED_MAX``:
                The maximum of the absolute value of the activations multiplied by its initial sign.
                signed_max([[-1, 0, 1, 2], [-3, 1, -2, 0]]) = [-3, 1, -2, 2]

        pad_side (str | None):
            'left' or 'right': side on which to apply padding along dim=1, only for the ALL_TOKENS granularity.
            Forced to right for classification models and to left for causal LMs.

        tqdm_bar (bool):
            Whether to display a progress bar.

        include_predicted_classes (bool):
            Whether to include the predicted classes in the output tuple.
            Only applicable for classification models.

        flatten_activations (bool):
            Whether to flatten the activations tensors.

            - If True, the activations will be flattened from (n, l, d) to (n x l, d).
                It allows storing the activations for the split point in a single tensor.

            - If False, a list of sample-wise activations will be returned.

        forward_kwargs (dict):
            Additional keyword arguments passed to the model forward pass.

    Returns:
        activations (LatentActivations | list[LatentActivations]):
            The extracted activations, either flattened into a single tensor or returned as a sample-wise list.
        predictions (torch.Tensor | list[torch.Tensor] | None):
            The predicted classes, if requested.
    """
    # set default pad side value and catch unsupported cases
    if self._model.__class__.__name__.endswith("ForSequenceClassification"):
        pad_side = "right"
    else:
        if self._model.__class__.__name__.endswith("ForCausalLM"):
            pad_side = "left"
        else:
            pad_side = pad_side or "left"
        if include_predicted_classes:
            raise ValueError(
                "`include_predicted_classes` is only supported for classification models. "
                f"Provided model is a {self._model.__class__.__name__}."
            )
    self.tokenizer.padding_side = pad_side

    # add padding token to vocabulary if not present (model and tokenizer)
    if not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None:
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.model.resize_token_embeddings(len(self.tokenizer))  # type: ignore  # weird huggingface typing

    # batch inputs
    if isinstance(inputs, BatchEncoding):
        batch_generator = []
        # manage key by key batching for BatchEncoding
        for i in range(0, len(inputs), self.batch_size):
            end_idx = min(i + self.batch_size, len(inputs))
            batch_generator.append({key: value[i:end_idx] for key, value in inputs.items()})
    elif isinstance(inputs, list | torch.Tensor):
        # create a generator for iterable of inputs and tensors
        batch_generator = (
            inputs[i : min(i + self.batch_size, len(inputs))] for i in range(0, len(inputs), self.batch_size)
        )
    else:
        raise TypeError(
            f"Invalid inputs type: {type(inputs)}. Expected: list[str] | torch.Tensor | BatchEncoding."
        )

    # wrap generator in tqdm for progress bar
    tqdm_wrapped_batch_generator = tqdm(
        batch_generator,
        desc="Computing activations",
        unit="batch",
        total=ceil(len(inputs) / self.batch_size),
        disable=not tqdm_bar,
    )

    # initialize activation and prediction storage
    activations: list[LatentActivations] = []
    predictions: list[torch.Tensor] = []

    sp_module = self.get(self._split_point)

    # iterate over batch of inputs
    with torch.no_grad():
        # several calls of the same model should be grouped in an nnsight session
        for batch_inputs in tqdm_wrapped_batch_generator:
            # ------------------------------------------------------------------------------
            # prepare inputs and compute granular indices
            if isinstance(batch_inputs, list):
                # tokenize text inputs for granularity selection
                # include "offsets_mapping" for sentence selection strategy
                tokenized_inputs = self.tokenizer(
                    batch_inputs,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    return_offsets_mapping=True,
                )

                # special case for T5 in a generation setting
                if isinstance(self.args[0], T5ForConditionalGeneration):
                    # TODO: find a way for this not to be necessary
                    tokenized_inputs["decoder_input_ids"] = tokenized_inputs["input_ids"]
            else:
                # the input was already tokenized
                tokenized_inputs = batch_inputs

            # get granularity indices
            granularity_indices: list[list[list[int]]] = self._get_granularity_indices(
                tokenized_inputs, activation_granularity
            )

            # extract offset mapping not supported by forward but was necessary for sentence selection strategy
            if isinstance(tokenized_inputs, (BatchEncoding, dict)):  # noqa: UP038
                tokenized_inputs.pop("offset_mapping", None)

            # ------------------------------------------------------------------------------
            # model forward pass with nnsight to extract activations and predictions

            # all model calls use trace with nnsight
            # call model forward pass and save split point outputs
            with self.trace(tokenized_inputs, **forward_kwargs) as tracer:
                output_name = "nns_output" if hasattr(sp_module, "nns_output") else "output"
                batch_outputs = getattr(sp_module, output_name).save()

                # for classification optionally compute and save the predictions
                if include_predicted_classes:
                    batch_predictions: Float[torch.Tensor, "n"] = (
                        self.output.logits.argmax(dim=-1).cpu().save()  # type: ignore  (under specification from NNsight)
                    )
                else:
                    tracer.stop()

            # free memory after each batch, necessary with nnsight, otherwise memory piles up
            torch.cuda.empty_cache()

            # ------------------------------------------------------------------------------
            # apply granularity selection and aggregation of activations and predictions
            # manage the output tuple and extract the (n, l, d) activations from it
            batch_sp_activations: Float[torch.Tensor, "n l d"] = self._manage_output_tuple(
                batch_outputs, self._split_point
            )

            # select relevant activations with respect to the granularity strategy
            # potentially aggregate activations over the granularity elements
            # this merges the `n` and `g` dimensions with `g` a subset of `n`
            # shape (n, l, d) only for `ALL` granularity, thus raw activations
            granular_activations: list[Float[torch.Tensor, "g d"]] = self._apply_selection_strategy(
                activations=batch_sp_activations,
                granularity_indices=granularity_indices,
                activation_granularity=activation_granularity,
                aggregation_strategy=aggregation_strategy,
            )

            activations.extend(
                act.detach().to(device="cpu", dtype=torch.float32, copy=True) for act in granular_activations
            )

            if include_predicted_classes:
                if not flatten_activations:
                    predictions.extend(
                        list(batch_predictions)  # type: ignore  (ignore possibly unbound)
                    )
                else:
                    # adapt predictions to match the granularity indices
                    repeats: Float[torch.Tensor, "ng"] = torch.tensor(
                        [len(indices) for indices in granularity_indices]
                    )

                    # predictions have a shape (n,), which we convert to (ng,)
                    # by repeating each predicted class as many times as the number of granularity elements in a sample
                    repeated_predictions = torch.repeat_interleave(
                        batch_predictions,  # type: ignore  (ignore possibly unbound)
                        repeats,
                        dim=0,
                    )
                    predictions.append(repeated_predictions)

    # ------------------------------------------------------------------------------------------
    # concat activation batches and validate that activations have the expected type
    if flatten_activations:
        # two dimensional tensor (n*g, d)
        flattened_activations = torch.cat(activations, dim=0)

        if include_predicted_classes:
            return flattened_activations, torch.cat(predictions, dim=0)
        return flattened_activations, None

    # validate that activations have the expected type
    if not all(isinstance(act, torch.Tensor) for act in activations):
        raise RuntimeError("Invalid output. Expected a list of torch.Tensor activations.")

    if include_predicted_classes:
        return activations, predictions
    return activations, None
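The batching step at the start of the method slices list or tensor inputs into chunks of at most batch_size elements. A standalone sketch of that slicing, with a hypothetical `iter_batches` helper and toy inputs:

```python
from math import ceil


def iter_batches(inputs, batch_size):
    """Yield consecutive slices of `inputs` containing at most `batch_size` elements."""
    for start in range(0, len(inputs), batch_size):
        # slicing past the end is safe in Python, so the last batch is simply shorter
        yield inputs[start : start + batch_size]


texts = ["a", "b", "c", "d", "e"]
batches = list(iter_batches(texts, batch_size=2))

print(batches)  # [['a', 'b'], ['c', 'd'], ['e']]
print(ceil(len(texts) / 2))  # 3, the total shown by the progress bar
```

Using a generator keeps only one batch of inputs in flight at a time, which matters when the dataset is large.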

get_latent_shape

get_latent_shape()

Get the shape of the latent activations at the split point.

Use the scan operation from NNsight to get the shape of the activations. It builds the computation graph without running the model, which is much quicker than a forward pass.

Returns:

Type Description
Size

torch.Size: Shape of the activations for the split point.

Source code in interpreto/concepts/splitters/model_with_split_points.py
def get_latent_shape(self) -> torch.Size:
    """Get the shape of the latent activations at the split point.

    Use the `scan` operation from NNsight to get the shape of the activations.
    It basically builds the computation graph, but it is much quicker than a forward.

    Returns:
        torch.Size: Shape of the activations for the split point.
    """
    shape = None
    with self.scan("scan"):
        curr_module = self.get(self._split_point)
        module_out_name = "nns_output" if hasattr(curr_module, "nns_output") else "output"
        module = getattr(curr_module, module_out_name)
        if isinstance(module, tuple):
            for candidate in module:
                if candidate.dim() == 3:
                    module = candidate
                    break
        shape = nnsight.save(module.shape)  # type: ignore  (under specification from NNsight)
    return shape

interpreto.concepts.splitters.model_with_split_points.ActivationGranularity

Bases: Enum

Activation selection strategies for ModelWithSplitPoints.get_activations().

  • ALL_TOKENS: the raw activations are flattened (n x l, d). Hence, each token activation is now considered as a separate element. This includes special tokens such as [CLS], [SEP], [EOS], [PAD], etc.

  • CLS_TOKEN: for each sample, only the first token (e.g. [CLS]) activation is returned (n, d). This will raise an error if the model is not ForSequenceClassification.

  • SAMPLE: special tokens are removed and the remaining ones are aggregated on the whole sample (n, d).

  • SENTENCE: special tokens are removed and the remaining ones are aggregate by sentences. Then the activations are flattened. (n x g, d) where g is the number of sentences in the input. The split is defined by interpreto.commons.granularity.Granularity.SENTENCE.

  • TOKEN: the raw activations are flattened, but the special tokens are removed. (n x g, d) where g is the number of non-special tokens in the input. This is the default granularity.

  • WORD: the special tokens are removed and the remaining ones are aggregate by words. Then the activations are flattened. (n x g, d) where g is the number of words in the input. The split is defined by interpreto.commons.granularity.Granularity.WORD.
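The shapes listed above can be checked with a small counting sketch. Given, for each sample, its padded length and its number of non-special tokens, words, and sentences, the hypothetical `flattened_rows` helper returns the first dimension of the flattened activations for each granularity. The dictionary keys and string names are illustrative and are not Interpreto identifiers.

```python
def flattened_rows(samples, granularity):
    """Return the number of rows of the flattened activations for a granularity name."""
    if granularity in ("CLS_TOKEN", "SAMPLE"):
        # one activation vector per sample: (n, d)
        return len(samples)
    key = {
        "ALL_TOKENS": "padded_length",  # (n x l, d), special tokens included
        "TOKEN": "tokens",  # (n x g, d), g = non-special tokens
        "WORD": "words",  # (n x g, d), g = words
        "SENTENCE": "sentences",  # (n x g, d), g = sentences
    }[granularity]
    return sum(sample[key] for sample in samples)


samples = [
    {"padded_length": 8, "tokens": 6, "words": 4, "sentences": 1},
    {"padded_length": 8, "tokens": 3, "words": 3, "sentences": 2},
]

for name in ("ALL_TOKENS", "CLS_TOKEN", "SAMPLE", "SENTENCE", "TOKEN", "WORD"):
    print(name, flattened_rows(samples, name))
```

In this toy batch, ALL_TOKENS yields 16 rows, TOKEN 9, WORD 7, SENTENCE 3, and both CLS_TOKEN and SAMPLE yield 2.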

ALL_TOKENS class-attribute instance-attribute

ALL_TOKENS = Granularity.ALL_TOKENS

CLS_TOKEN class-attribute instance-attribute

CLS_TOKEN = 'cls_token'

SAMPLE class-attribute instance-attribute

SAMPLE = 'sample'

SENTENCE class-attribute instance-attribute

SENTENCE = Granularity.SENTENCE

TOKEN class-attribute instance-attribute

TOKEN = Granularity.TOKEN

WORD class-attribute instance-attribute

WORD = Granularity.WORD

interpreto.concepts.splitters.model_with_split_points.GranularityAggregationStrategy

Bases: Enum

Enumeration of the available aggregation strategies for combining token-level scores into a single score for each unit of a higher-level granularity (e.g., word, sentence).

This is used in explainability methods to reduce token-based attributions according to a defined granularity.

Attributes:

Name Type Description
MEAN

Average of the token scores within each group.

MAX

Maximum token score within each group.

MIN

Minimum token score within each group.

SUM

Sum of all token scores within each group.

SIGNED_MAX

Selects the token with the highest absolute score and returns its signed value. For example, given scores [3, -1, 7], returns 7; for [3, -1, -7], returns -7.
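The SIGNED_MAX rule is easy to verify by hand. The following pure-Python sketch reproduces the examples given in this page; the function names are hypothetical, and Interpreto applies the same idea to tensors.

```python
def signed_max(values):
    """Return the value with the largest absolute value, keeping its sign."""
    return max(values, key=abs)


def signed_max_columns(rows):
    """Apply `signed_max` along the first dimension of a list of rows, column by column."""
    return [signed_max(column) for column in zip(*rows)]


print(signed_max([3, -1, 7]))  # 7
print(signed_max([3, -1, -7]))  # -7
print(signed_max_columns([[-1, 0, 1, 2], [-3, 1, -2, 0]]))  # [-3, 1, -2, 2]
```

Unlike MAX, this strategy keeps strongly negative activations instead of discarding them in favour of small positive ones.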

FIRST class-attribute instance-attribute

FIRST = 'first'

LAST class-attribute instance-attribute

LAST = 'last'

MAX class-attribute instance-attribute

MAX = 'max'

MEAN class-attribute instance-attribute

MEAN = 'mean'

MIN class-attribute instance-attribute

MIN = 'min'

SIGNED_MAX class-attribute instance-attribute

SIGNED_MAX = 'signed_max'

SUM class-attribute instance-attribute

SUM = 'sum'

aggregate

aggregate(x, dim)

Aggregate activations.

Args:
    x (torch.Tensor): The tensor to aggregate, shape (l, d).

Returns:
    torch.Tensor: The aggregated tensor, shape (1, d).

unfold

unfold(x, new_dim_length)

Unfold activations.

Args:
    x (torch.Tensor): The tensor to unfold, shape (1, d).
    new_dim_length (int): The new dimension length.

Returns:
    torch.Tensor: The unfolded tensor, shape (l, d).