ModelWithSplitPoints

interpreto.ModelWithSplitPoints

ModelWithSplitPoints(model_or_repo_id, split_points, *args, automodel=None, tokenizer=None, config=None, batch_size=1, device_map=None, output_tuple_index=None, **kwargs)

Bases: LanguageModel

Code: model_wrapping/model_with_split_points.py

The ModelWithSplitPoints class is a wrapper around your HuggingFace model. Its goal is to let you split your model at specified locations and extract activations there.

It is one of the key components of the Concept-Based Explainers framework in Interpreto: every Interpreto concept explainer is built around a ModelWithSplitPoints object, since splitting the model is the first step of the concept-based explanation process.

It is based on the LanguageModel class from NNsight and inherits its functionalities. In a sense, LanguageModel wraps the HuggingFace model, and ModelWithSplitPoints in turn wraps LanguageModel.

We often shorten the ModelWithSplitPoints class as MWSP and instances as mwsp.

Parameters:

Name Type Description Default

model_or_repo_id

str | PreTrainedModel

One of:

  • A str corresponding to the ID of the model that should be loaded from the HF Hub.
  • A str corresponding to the local path of a folder containing a compatible checkpoint.
  • A preloaded transformers.PreTrainedModel object. If a string is provided, an automodel should also be provided.
required

split_points

str | Sequence[str] | int | Sequence[int]

One or more split locations inside the model. One of the following:

  • A str corresponding to the path of a split point inside the model.
  • An int corresponding to the n-th layer.
  • A Sequence[str] or Sequence[int] corresponding to multiple split points.

Example: split_points='cls.predictions.transform.LayerNorm' corresponds to a split after the LayerNorm layer in the MLM head (assuming a BertForMaskedLM model as input).

required
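A minimal sketch of equivalent ways to specify split points (the module paths assume a BERT model and are illustrative):

>>> from transformers import AutoModelForMaskedLM
>>> from interpreto import ModelWithSplitPoints
>>> # by layer index, resolved internally to a module path
>>> mwsp = ModelWithSplitPoints("bert-base-uncased", split_points=4, automodel=AutoModelForMaskedLM)
>>> # by explicit module path inside the model
>>> mwsp = ModelWithSplitPoints("bert-base-uncased", split_points="bert.encoder.layer.4.output", automodel=AutoModelForMaskedLM)
>>> # several split points at once
>>> mwsp = ModelWithSplitPoints("bert-base-uncased", split_points=[4, 8], automodel=AutoModelForMaskedLM)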

automodel

type[AutoModel]

Huggingface AutoClass corresponding to the desired type of model (e.g. AutoModelForSequenceClassification).

⚠ automodel must be defined if model_or_repo_id is a str, since the model class cannot be known otherwise.

None

config

PretrainedConfig

Custom configuration for the loaded model. If not specified, it will be instantiated with the default configuration for the model.

None

tokenizer

PreTrainedTokenizer | PreTrainedTokenizerFast | None

Custom tokenizer for the loaded model. If not specified, it will be instantiated with the default tokenizer for the model.

⚠ If model_or_repo_id is a transformers.PreTrainedModel object, then tokenizer must be defined.

None

batch_size

int

Batch size for the model.

1

device_map

device | str | None

Device map for the model. Directly passed to the model.

None

output_tuple_index

int | None

If the output at the split point is a tuple, this is the index of the hidden states within that tuple. If None, an element with three dimensions is searched for; an error is raised if no such element is found, or if several are found.

None
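A sketch of when this matters: many transformer layers return a (hidden_states, ...) tuple, and the module path and index below are illustrative:

>>> mwsp = ModelWithSplitPoints(
...     "bert-base-uncased",
...     split_points="bert.encoder.layer.4",  # this module returns a tuple
...     automodel=AutoModelForMaskedLM,
...     output_tuple_index=0,  # explicitly select the (n, l, d) hidden states
... )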

Attributes:

Name Type Description
activation_granularities ActivationGranularity

Enumeration of the available granularities for the get_activations method.

aggregation_strategies GranularityAggregationStrategy

Enumeration of the available aggregation strategies for the get_activations method.

automodel type[AutoModel]

The AutoClass corresponding to the loaded model type.

batch_size int

Batch size for the model.

output_tuple_index int | None

If the output at the split point is a tuple, this is the index of the hidden states within that tuple. If None, an element with three dimensions is searched for; an error is raised if no such element is found, or if several are found.

repo_id str

Either the model id in the HF Hub, or the path from which the model was loaded.

split_points list[str]

Getters/setters for model paths corresponding to split points inside the loaded model. They automatically handle validation, sorting, and the resolution of int paths to strings, as sketched below.
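For instance, a hypothetical reassignment after construction, relying on the validation and resolution described above:

>>> mwsp.split_points = 2  # the setter resolves the int to a module path and validates it
>>> print(mwsp.split_points)  # e.g. ['bert.encoder.layer.2'] (illustrative)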

tokenizer PreTrainedTokenizer

Tokenizer for the loaded model, either given by the user or loaded from the repo_id.

_model PreTrainedModel

Huggingface transformers model wrapped by NNSight.

Examples:

Minimal example with gpt2:

>>> from transformers import AutoModelForCausalLM
>>> from interpreto import ModelWithSplitPoints
>>> model_with_split_points = ModelWithSplitPoints(
...     "gpt2",
...     split_points=10,  # split at the 10th layer
...     automodel=AutoModelForCausalLM,
...     device_map="auto",
... )
>>> activations_dict = model_with_split_points.get_activations(
...     inputs="interpreto is magic",
...     activation_granularity=ModelWithSplitPoints.activation_granularities.TOKEN,  # highly recommended for generation
... )

Load the model from its repository ID, split it at the output of an encoder layer, and get the [CLS] token activations at that split point.

>>> import torch
>>> from datasets import load_dataset
>>> from transformers import AutoModelForSequenceClassification
>>> from interpreto import ModelWithSplitPoints
>>> # load and split the model
>>> model_with_split_points = ModelWithSplitPoints(
...     "bert-base-uncased",
...     split_points="bert.encoder.layer.1.output",
...     automodel=AutoModelForSequenceClassification,
...     batch_size=64,
...     device_map="cuda" if torch.cuda.is_available() else "cpu",
... )
>>> # get activations
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations_dict = model_with_split_points.get_activations(
...     dataset,
...     activation_granularity=ModelWithSplitPoints.activation_granularities.CLS_TOKEN,  # highly recommended for classification
... )

Load the model, pass it to ModelWithSplitPoints, split it at the tenth layer, get word-level activations, skip special tokens, and aggregate token activations into words by mean.

>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from datasets import load_dataset
>>> from interpreto import ModelWithSplitPoints as MWSP
>>> # load the model
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
>>> # wrap and split the model at the 10th layer
>>> model_with_split_points = MWSP(
...     model,
...     tokenizer=tokenizer,
...     split_points=10,  # split at the 10th layer
...     batch_size=16,
...     device_map="cuda",  # "auto" is unsupported when passing a preloaded model
... )
>>> # get activations at the word granularity
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations = model_with_split_points.get_activations(
...     dataset,
...     activation_granularity=MWSP.activation_granularities.WORD,
...     aggregation_strategy=MWSP.aggregation_strategies.MEAN,  # average tokens activations by words
... )

Most of the work is forwarded to the LanguageModel class initialization from NNsight.

Raises:

Type Description
InitializationError(ValueError)

If the model cannot be loaded because of a missing tokenizer or automodel.

ValueError

If the device_map is set to 'auto' while a preloaded model instance is provided.

TypeError

If the model_or_repo_id is not a str or a transformers.PreTrainedModel.

Source code in interpreto/model_wrapping/model_with_split_points.py
def __init__(
    self,
    model_or_repo_id: str | PreTrainedModel,
    split_points: str | int | list[str] | list[int] | tuple[str] | tuple[int],
    *args: tuple[Any],
    automodel: type[AutoModel] | None = None,
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | None = None,
    config: PretrainedConfig | None = None,
    batch_size: int = 1,
    device_map: torch.device | str | None = None,
    output_tuple_index: int | None = None,
    **kwargs,
) -> None:
    # For parameters list, see class docstring. It was moved to change the order in the documentation.
    """Initialize a ModelWithSplitPoints object.

    Most of the work is forwarded to the `LanguageModel` class initialization from NNsight.

    Raises:
        InitializationError (ValueError): If the model cannot be loaded because of a missing `tokenizer` or `automodel`.
        ValueError: If the `device_map` is set to 'auto' while a preloaded model instance is provided.
        TypeError: If the `model_or_repo_id` is not a `str` or a `transformers.PreTrainedModel`.
    """
    if isinstance(model_or_repo_id, PreTrainedModel):
        if tokenizer is None:
            raise InitializationError(
                "Tokenizer is not set. When providing a model instance, the tokenizer must be set."
            )
    elif isinstance(model_or_repo_id, str):  # Repository ID
        if automodel is None:
            raise InitializationError(
                "Model autoclass not found.\n"
                "The model class can be omitted if a pre-loaded model is passed to `model_or_repo_id` "
                "param.\nIf an HF Hub ID is used, the corresponding autoclass must be specified in `automodel`.\n"
                "Example: ModelWithSplitPoints('bert-base-uncased', automodel=AutoModelForMaskedLM, ...)"
            )
    else:
        raise TypeError(
            f"Invalid model_or_repo_id type: {type(model_or_repo_id)}. "
            "Expected `str` or `transformers.PreTrainedModel`."
        )

    # Handles model loading through nnsight.LanguageModel._load
    super().__init__(
        model_or_repo_id,
        *args,
        config=config,
        tokenizer=tokenizer,  # type: ignore (under specification from NNsight)
        automodel=automodel,  # type: ignore (under specification from NNsight)
        device_map=device_map,
        **kwargs,
    )

    # set split points
    self._model_paths = list(walk_modules(self._model))
    self.split_points = split_points  # this uses the setter which handles validation
    self._model: PreTrainedModel  # specify type of `_model` attribute from NNsight
    if self.repo_id is None:
        self.repo_id = self._model.config.name_or_path  # type: ignore  (under specification from NNsight)
    self.batch_size = batch_size

    if not isinstance(model_or_repo_id, str):
        # `device_map` is ignored by `nnsight` in this case, hence we manage it manually
        if device_map is not None:
            if device_map == "auto":
                raise ValueError(
                    "'auto' device_map is only supported when loading a generation model from a repository id. "
                    "Please specify a device_map, e.g. 'cuda' or 'cpu'."
                )
            # pass the provided model to the specified device
            self.to(device_map)  # type: ignore  (under specification from NNsight)
        else:
            # we leave the model on its device
            pass

    if self.tokenizer is None:
        raise ValueError("Tokenizer is not set. When providing a model instance, the tokenizer must be set.")
    self.output_tuple_index = output_tuple_index

activation_granularities class-attribute instance-attribute

activation_granularities = ActivationGranularity

aggregation_strategies class-attribute instance-attribute

aggregation_strategies = GranularityAggregationStrategy

get_activations

Get intermediate activations for all model split points on the given inputs.

Optionally also includes the model predictions in the returned activations dictionary (when include_predicted_classes=True).

Parameters:

Name Type Description Default

inputs list[str] | torch.Tensor | BatchEncoding

Inputs to the model forward pass before or after tokenization. In the case of a torch.Tensor, we assume a batch dimension and token ids.

required

activation_granularity

ActivationGranularity

Selection strategy for activations. In the model, activations have the shape (n, l, d), where d is the model dimension. This parameter specifies which elements of these tensors are selected. If the granularity is larger than tokens, i.e. words and sentences, the activations are aggregated. The parameter aggregation_strategy specifies how the activations are aggregated.

It is highly recommended to use CLS_TOKEN for classification tasks and TOKEN for other tasks.

Available options are:

  • ModelWithSplitPoints.activation_granularities.ALL: the raw activations are returned as is (n, l, d). They are padded manually so that each batch of activations can be concatenated.

  • ModelWithSplitPoints.activation_granularities.ALL_TOKENS: the raw activations are flattened (n x l, d). Hence, each token activation is now considered as a separate element. This includes special tokens such as [CLS], [SEP], [EOS], [PAD], etc.

  • ModelWithSplitPoints.activation_granularities.CLS_TOKEN: for each sample, only the first token (e.g. [CLS]) activation is returned (n, d). This will raise an error if the model is not ForSequenceClassification.

  • ModelWithSplitPoints.activation_granularities.SAMPLE: special tokens are removed and the remaining ones are aggregated on the whole sample (n, d).

  • ModelWithSplitPoints.activation_granularities.SENTENCE: special tokens are removed and the remaining ones are aggregated by sentences. Then the activations are flattened. (n x g, d) where g is the number of sentences in the input. The split is defined by interpreto.commons.granularity.Granularity.SENTENCE. Requires spacy to be installed.

  • ModelWithSplitPoints.activation_granularities.TOKEN: the raw activations are flattened, but the special tokens are removed. (n x g, d) where g is the number of non-special tokens in the input. This is the default granularity.

  • ModelWithSplitPoints.activation_granularities.WORD: the special tokens are removed and the remaining ones are aggregated by words. Then the activations are flattened. (n x g, d) where g is the number of words in the input. The split is defined by interpreto.commons.granularity.Granularity.WORD.

required
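A quick sketch of the resulting shapes, assuming a wrapped BERT-base model mwsp (d = 768) and MWSP as an alias of ModelWithSplitPoints:

>>> texts = ["interpreto is magic", "so is nnsight"]
>>> acts_all = mwsp.get_activations(texts, activation_granularity=MWSP.activation_granularities.ALL)  # values of shape (n, l, 768)
>>> acts_cls = mwsp.get_activations(texts, activation_granularity=MWSP.activation_granularities.CLS_TOKEN)  # (n, 768)
>>> acts_tok = mwsp.get_activations(texts, activation_granularity=MWSP.activation_granularities.TOKEN)  # (total non-special tokens, 768)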

aggregation_strategy

GranularityAggregationStrategy

Strategy to aggregate token activations into larger input granularities. Applied for the WORD, SENTENCE and SAMPLE activation strategies. Token activations of shape n * (l, d) are aggregated along the sequence length dimension, then concatenated into (n x g, d) tensors.

Existing strategies are:

  • ModelWithSplitPoints.aggregation_strategies.SUM: Token activations are summed along the sequence length dimension.

  • ModelWithSplitPoints.aggregation_strategies.MEAN: Token activations are averaged along the sequence length dimension.

  • ModelWithSplitPoints.aggregation_strategies.MAX: The maximum of the token activations along the sequence length dimension is selected.

  • ModelWithSplitPoints.aggregation_strategies.SIGNED_MAX: The maximum of the absolute value of the activations multiplied by its initial sign. signed_max([[-1, 0, 1, 2], [-3, 1, -2, 0]]) = [-3, 1, -2, 2]

MEAN
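As a concrete illustration of SIGNED_MAX, a small torch sketch reproducing the example above (the gather-based implementation is illustrative, not necessarily the library's):

>>> import torch
>>> x = torch.tensor([[-1., 0., 1., 2.], [-3., 1., -2., 0.]])
>>> idx = x.abs().argmax(dim=0)  # position of the largest magnitude per column
>>> x.gather(0, idx.unsqueeze(0)).squeeze(0)  # keep that element's original sign
tensor([-3.,  1., -2.,  2.])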

pad_side

str | None

'left' or 'right': the side on which padding is applied along dim=1; only used for the ALL strategy. Forced to 'right' for classification models and 'left' for causal LMs.

None

tqdm_bar

bool

Whether to display a progress bar.

False

include_predicted_classes

bool

Whether to include the predicted classes in the output dictionary. Only applicable for classification models.

False

model_forward_kwargs

dict

Additional keyword arguments passed to the model forward pass.

{}

Returns:

Type Description
dict[str, LatentActivations]

(dict[str, LatentActivations]) Dictionary with one key-value pair per split point defined for the model. Keys correspond to split names in self.split_points, while values correspond to the activations extracted at that split point for the given inputs.
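Assuming a wrapped model mwsp as in the Examples above, a sketch of reading the returned dictionary (with include_predicted_classes=True, a "predictions" key is also present):

>>> acts = mwsp.get_activations(texts, activation_granularity=MWSP.activation_granularities.TOKEN)
>>> for split_point, tensor in acts.items():
...     print(split_point, tensor.shape)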

Source code in interpreto/model_wrapping/model_with_split_points.py
def get_activations(  # noqa: PLR0912  # ignore too many branches  # too many special cases
    self,
    inputs: list[str] | torch.Tensor | BatchEncoding,
    activation_granularity: ActivationGranularity,
    aggregation_strategy: GranularityAggregationStrategy = GranularityAggregationStrategy.MEAN,
    pad_side: str | None = None,
    tqdm_bar: bool = False,
    include_predicted_classes: bool = False,
    model_forward_kwargs: dict[str, Any] = {},
) -> dict[str, LatentActivations]:
    """

    Get intermediate activations for all model split points on the given `inputs`.

    Optionally also includes the model predictions in the returned activations dictionary (when `include_predicted_classes=True`).

    Args:
        inputs (list[str] | torch.Tensor | BatchEncoding):
            Inputs to the model forward pass before or after tokenization.
            In the case of a `torch.Tensor`, we assume a batch dimension and token ids.

        activation_granularity (ActivationGranularity):
            Selection strategy for activations.
            In the model, activations have the shape `(n, l, d)`, where `d` is the model dimension.
            This parameter specifies which elements of these tensors are selected.
            If the granularity is larger than tokens, i.e. words and sentences, the activations are aggregated.
            The parameter `aggregation_strategy` specifies how the activations are aggregated.

            **It is highly recommended to use `CLS_TOKEN` for classification tasks and `TOKEN` for other tasks.**

            Available options are:

            - ``ModelWithSplitPoints.activation_granularities.ALL``:
                the raw activations are returned as is ``(n, l, d)``.
                They are padded manually so that each batch of activations can be concatenated.

            - ``ModelWithSplitPoints.activation_granularities.ALL_TOKENS``:
                the raw activations are flattened ``(n x l, d)``.
                Hence, each token activation is now considered as a separate element.
                This includes special tokens such as [CLS], [SEP], [EOS], [PAD], etc.

            - ``ModelWithSplitPoints.activation_granularities.CLS_TOKEN``:
                for each sample, only the first token (e.g. ``[CLS]``) activation is returned ``(n, d)``.
                This will raise an error if the model is not `ForSequenceClassification`.

            - ``ModelWithSplitPoints.activation_granularities.SAMPLE``:
                special tokens are removed and the remaining ones are aggregated on the whole sample ``(n, d)``.

            - ``ModelWithSplitPoints.activation_granularities.SENTENCE``:
                special tokens are removed and the remaining ones are aggregated by sentences.
                Then the activations are flattened.
                ``(n x g, d)`` where `g` is the number of sentences in the input.
                The split is defined by `interpreto.commons.granularity.Granularity.SENTENCE`.
                Requires `spacy` to be installed.

            - ``ModelWithSplitPoints.activation_granularities.TOKEN``:
                the raw activations are flattened, but the special tokens are removed.
                ``(n x g, d)`` where `g` is the number of non-special tokens in the input.
                This is the default granularity.

            - ``ModelWithSplitPoints.activation_granularities.WORD``:
                the special tokens are removed and the remaining ones are aggregated by words.
                Then the activations are flattened.
                ``(n x g, d)`` where `g` is the number of words in the input.
                The split is defined by `interpreto.commons.granularity.Granularity.WORD`.

        aggregation_strategy (GranularityAggregationStrategy):
            Strategy to aggregate token activations into larger input granularities.
            Applied for `WORD`, `SENTENCE` and `SAMPLE` activation strategies.
            Token activations of shape n * (l, d) are aggregated along the sequence length dimension,
            then concatenated into (n x g, d) tensors.

            Existing strategies are:

            - ``ModelWithSplitPoints.aggregation_strategies.SUM``:
                Token activations are summed along the sequence length dimension.

            - ``ModelWithSplitPoints.aggregation_strategies.MEAN``:
                Token activations are averaged along the sequence length dimension.

            - ``ModelWithSplitPoints.aggregation_strategies.MAX``:
                The maximum of the token activations along the sequence length dimension is selected.

            - ``ModelWithSplitPoints.aggregation_strategies.SIGNED_MAX``:
                The maximum of the absolute value of the activations multiplied by its initial sign.
                signed_max([[-1, 0, 1, 2], [-3, 1, -2, 0]]) = [-3, 1, -2, 2]

        pad_side (str | None):
            'left' or 'right': the side on which padding is applied along dim=1; only used for the ALL strategy.
            Forced to 'right' for classification models and 'left' for causal LMs.

        tqdm_bar (bool):
            Whether to display a progress bar.

        include_predicted_classes (bool):
            Whether to include the predicted classes in the output dictionary.
            Only applicable for classification models.

        model_forward_kwargs (dict):
            Additional keyword arguments passed to the model forward pass.

    Returns:
        (dict[str, LatentActivations]) Dictionary with one key-value pair per split point defined for the model.
            Keys correspond to split names in `self.split_points`, while values correspond to the activations
            extracted at that split point for the given `inputs`.
    """
    # set default pad side value and catch unsupported cases
    if self._model.__class__.__name__.endswith("ForSequenceClassification"):
        pad_side = "right"
    else:
        if self._model.__class__.__name__.endswith("ForCausalLM"):
            pad_side = "left"
        else:
            pad_side = pad_side or "left"
        if include_predicted_classes:
            raise ValueError(
                "`include_predicted_classes` is only supported for classification models. "
                f"Provided model is a {self._model.__class__.__name__}."
            )
    self.tokenizer.padding_side = pad_side

    # add padding token to vocabulary if not present (model and tokenizer)
    if not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None:
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.model.resize_token_embeddings(len(self.tokenizer))  # type: ignore  # weird huggingface typing

    # batch inputs
    if isinstance(inputs, BatchEncoding):
        batch_generator = []
        # manage key by key batching for BatchEncoding
        for i in range(0, len(inputs), self.batch_size):
            end_idx = min(i + self.batch_size, len(inputs))
            batch_generator.append({key: value[i:end_idx] for key, value in inputs.items()})
    else:  # sequence of inputs or tensors
        # create a generator for iterable of inputs and tensors
        batch_generator = (
            inputs[i : min(i + self.batch_size, len(inputs))] for i in range(0, len(inputs), self.batch_size)
        )

    # wrap generator in tqdm for progress bar
    tqdm_wrapped_batch_generator = tqdm(
        batch_generator,
        desc="Computing activations",
        unit="batch",
        total=ceil(len(inputs) / self.batch_size),
        disable=not tqdm_bar,
    )

    # initialize activations dictionary
    activations: dict = {}
    for split_point in self.split_points + ["predictions"]:
        activations[split_point] = []

    # iterate over batch of inputs
    with torch.no_grad():
        # several call of the same model should be grouped in an nnsight session
        with self.session():
            for batch_inputs in tqdm_wrapped_batch_generator:
                # ------------------------------------------------------------------------------
                # prepare inputs and compute granular indices
                if isinstance(batch_inputs, list):
                    # tokenize text inputs for granularity selection
                    # include "offsets_mapping" for sentence selection strategy
                    tokenized_inputs = self.tokenizer(
                        batch_inputs,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        return_offsets_mapping=True,
                    )

                    # special case for T5 in a generation setting
                    if isinstance(self.args[0], T5ForConditionalGeneration):
                        # TODO: find a way for this not to be necessary
                        tokenized_inputs["decoder_input_ids"] = tokenized_inputs["input_ids"]
                else:
                    # the input was already tokenized
                    tokenized_inputs = batch_inputs

                # get granularity indices
                granularity_indices = self._get_granularity_indices(tokenized_inputs, activation_granularity)

                # extract offset mapping not supported by forward but was necessary for sentence selection strategy
                if isinstance(tokenized_inputs, (BatchEncoding, dict)):  # noqa: UP038
                    tokenized_inputs.pop("offset_mapping", None)

                # ------------------------------------------------------------------------------
                # model forward pass with nnsight to extract activations and predictions

                # all model calls use trace with nnsight
                # call model forward pass and save split point outputs
                with self.trace(tokenized_inputs, **model_forward_kwargs) as tracer:
                    # nnsight quick way to obtain the activations for all split points
                    batch_activations = tracer.cache(modules=[self.get(sp) for sp in self.split_points])  # type: ignore  (under specification from NNsight)

                    # for classification optionally compute and save the predictions
                    if include_predicted_classes:
                        batch_predictions: Float[torch.Tensor, "n"] = (
                            self.output.logits.argmax(dim=-1).cpu().save()  # type: ignore  (under specification from NNsight)
                        )

                # free memory after each batch; necessary with nnsight, otherwise memory piles up
                torch.cuda.empty_cache()

                # ------------------------------------------------------------------------------
                # apply granularity selection and aggregation of activations and predictions
                for sp in self.split_points:
                    # extracting the activations for the current split point
                    sp_module = batch_activations["model." + sp]
                    output_name = "nns_output" if hasattr(sp_module, "nns_output") else "output"
                    batch_outputs = getattr(sp_module, output_name)

                    # manage the output tuple and extract the (n, l, d) activations from it
                    batch_sp_activations: Float[torch.Tensor, "n l d"] = self._manage_output_tuple(
                        batch_outputs, sp
                    )

                    # select relevant activations with respect to the granularity strategy
                    # potentially aggregate activations over the granularity elements
                    # this merges the `n` and `g` dimensions with `g` a subset of `n`
                    # shape (n, l, d) only for `ALL` granularity, thus raw activations
                    granular_activations: Float[torch.Tensor, "ng d"] | Float[torch.Tensor, "n l d"] = (
                        self._apply_selection_strategy(
                            activations=batch_sp_activations,
                            granularity_indices=granularity_indices,
                            activation_granularity=activation_granularity,
                            aggregation_strategy=aggregation_strategy,
                        )
                    )

                    activations[sp].append(granular_activations)

                if include_predicted_classes:
                    # for granularities outside of `ALL`
                    if granularity_indices is not None:
                        # adapt predictions to match the granularity indices
                        repeats: Float[torch.Tensor, "ng"] = torch.tensor(
                            [len(indices) for indices in granularity_indices]
                        )

                        # predictions have a shape (n,), which we convert to (ng,)
                        # by repeating each predicted class as many times as the number of granularity elements in a sample
                        repeated_predictions = torch.repeat_interleave(batch_predictions, repeats, dim=0)  # type: ignore  (ignore possibly unbound)
                        activations["predictions"].append(repeated_predictions)

    # ------------------------------------------------------------------------------------------
    # concat activation batches and validate that activations have the expected type
    for split_point in self.split_points:
        if activation_granularity == AG.ALL:
            # three dimensional tensor (n, l, d)
            activations[split_point] = ModelWithSplitPoints._pad_and_concat(
                activations[split_point], pad_side, 0.0
            )
        else:
            # two dimensional tensor (n*g, d)
            activations[split_point] = torch.cat(activations[split_point], dim=0)

    if include_predicted_classes:
        activations["predictions"] = torch.cat(activations["predictions"], dim=0)
    else:
        activations.pop("predictions", None)

    # validate that activations have the expected type
    for layer, act in activations.items():
        if not isinstance(act, torch.Tensor):
            raise RuntimeError(
                f"Invalid output for layer '{layer}'. Expected torch.Tensor activation, got {type(act)}: {act}"
            )
    return activations  # type: ignore

get_split_activations

get_split_activations(activations, split_point=None)

Extract activations for the specified split point. If no split point is specified, this works if and only if the model_with_split_points has a single split point. It also verifies that the given activations are valid for the model_with_split_points and split_point. Cases in which the activations are not valid include:

  • Activations are not a valid dictionary.
  • Specified split point does not exist in the activations.

Parameters:

Name Type Description Default

activations

dict[str, LatentActivations]

A dictionary with model paths as keys and the corresponding tensors as values.

required

split_point

str | None

The split point to extract activations from. If None, the split_point of the explainer is used.

None

Returns:

Type Description
LatentActivations

The activations for the explainer split point.

Examples:

>>> from transformers import AutoModelForSequenceClassification
>>> from interpreto import ModelWithSplitPoints as MWSP
>>> model = MWSP("bert-base-uncased", split_points=4,
...              automodel=AutoModelForSequenceClassification)
>>> activations_dict: dict[str, LatentActivations] = model.get_activations(
...     "interpreto is magic",
...     activation_granularity=MWSP.activation_granularities.ALL,
... )
>>> activations: LatentActivations = model.get_split_activations(activations_dict)
>>> activations.shape
torch.Size([1, 12, 768])

Raises:

Type Description
ValueError

If no split point is specified and the model_with_split_points has more than one split point.

TypeError

If the activations are not a valid dictionary.

ValueError

If the specified split point is not found in the activations.

Source code in interpreto/model_wrapping/model_with_split_points.py
def get_split_activations(
    self, activations: dict[str, LatentActivations], split_point: str | None = None
) -> LatentActivations:
    """
    Extract activations for the specified split point.
    If no split point is specified, it works if and only if the `model_with_split_points` has only one split point.
    Verify that the given activations are valid for the `model_with_split_points` and `split_point`.
    Cases in which the activations are not valid include:

    * Activations are not a valid dictionary.
    * Specified split point does not exist in the activations.

    Args:
        activations (dict[str, LatentActivations]): A dictionary with model paths as keys and the corresponding
            tensors as values.
        split_point (str | None): The split point to extract activations from.
            If None, the `split_point` of the explainer is used.

    Returns:
        (LatentActivations): The activations for the explainer split point.

    Examples:
        >>> from transformers import AutoModelForSequenceClassification
        >>> from interpreto import ModelWithSplitPoints as MWSP
        >>> model = MWSP("bert-base-uncased", split_points=4,
        ...              automodel=AutoModelForSequenceClassification)
        >>> activations_dict: dict[str, LatentActivations] = model.get_activations(
        ...     "interpreto is magic",
        ...     activation_granularity=MWSP.activation_granularities.ALL,
        ... )
        >>> activations: LatentActivations = model.get_split_activations(activations_dict)
        >>> activations.shape
        torch.Size([1, 12, 768])

    Raises:
        ValueError: If no split point is specified and the `model_with_split_points` has more than one split point.
        TypeError: If the activations are not a valid dictionary.
        ValueError: If the specified split point is not found in the activations.
    """
    if split_point is not None:
        local_split_point: str = split_point
    elif not self.split_points:
        raise ValueError(
            "The activations cannot correspond to `model_with_split_points` model. "
            "The `model_with_split_points` model do not have `split_point` defined. "
        )
    elif len(self.split_points) > 1:
        raise ValueError("Cannot determine the split point with multiple `model_with_split_points` split points. ")
    else:
        local_split_point: str = self.split_points[0]

    if not isinstance(activations, dict) or not all(isinstance(act, torch.Tensor) for act in activations.values()):
        raise TypeError(
            "Invalid activations for the concept explainer. "
            "Activations should be a dictionary of model paths and torch.Tensor activations. "
            f"Got: '{type(activations)}'"
        )
    activations_split_points: list[str] = list(activations.keys())  # type: ignore
    if local_split_point not in activations_split_points:
        raise ValueError(
            f"Fitted split point '{local_split_point}' not found in activations.\n"
            f"Available split_points: {', '.join(activations_split_points)}."
        )

    return activations[local_split_point]  # type: ignore

interpreto.model_wrapping.model_with_split_points.ActivationGranularity

Bases: Enum

Activation selection strategies for ModelWithSplitPoints.get_activations().

  • ALL: the raw activations are returned as is (n, l, d). They are padded manually so that each batch of activations can be concatenated.

  • ALL_TOKENS: the raw activations are flattened (n x l, d). Hence, each token activation is now considered as a separate element. This includes special tokens such as [CLS], [SEP], [EOS], [PAD], etc.

  • CLS_TOKEN: for each sample, only the first token (e.g. [CLS]) activation is returned (n, d). This will raise an error if the model is not ForSequenceClassification.

  • SAMPLE: special tokens are removed and the remaining ones are aggregated on the whole sample (n, d).

  • SENTENCE: special tokens are removed and the remaining ones are aggregated by sentences. Then the activations are flattened. (n x g, d) where g is the number of sentences in the input. The split is defined by interpreto.commons.granularity.Granularity.SENTENCE. Requires spacy to be installed.

  • TOKEN: the raw activations are flattened, but the special tokens are removed. (n x g, d) where g is the number of non-special tokens in the input. This is the default granularity.

  • WORD: the special tokens are removed and the remaining ones are aggregated by words. Then the activations are flattened. (n x g, d) where g is the number of words in the input. The split is defined by interpreto.commons.granularity.Granularity.WORD.

ALL class-attribute instance-attribute

ALL = 'all'

ALL_TOKENS class-attribute instance-attribute

ALL_TOKENS = ALL_TOKENS

CLS_TOKEN class-attribute instance-attribute

CLS_TOKEN = 'cls_token'

SAMPLE class-attribute instance-attribute

SAMPLE = 'sample'

SENTENCE class-attribute instance-attribute

SENTENCE = SENTENCE

TOKEN class-attribute instance-attribute

TOKEN = TOKEN

WORD class-attribute instance-attribute

WORD = WORD

interpreto.model_wrapping.model_with_split_points.GranularityAggregationStrategy

Bases: Enum

Enumeration of the available aggregation strategies for combining token-level scores into a single score for each unit of a higher-level granularity (e.g., word, sentence).

This is used in explainability methods to reduce token-based attributions according to a defined granularity.

Attributes:

Name Type Description
MEAN

Average of the token scores within each group.

MAX

Maximum token score within each group.

MIN

Minimum token score within each group.

SUM

Sum of all token scores within each group.

SIGNED_MAX

Selects the token with the highest absolute score and returns its signed value. For example, given scores [3, -1, 7], returns 7; for [3, -1, -7], returns -7.

MAX class-attribute instance-attribute

MAX = 'max'

MEAN class-attribute instance-attribute

MEAN = 'mean'

MIN class-attribute instance-attribute

MIN = 'min'

SIGNED_MAX class-attribute instance-attribute

SIGNED_MAX = 'signed_max'

SUM class-attribute instance-attribute

SUM = 'sum'

aggregate

aggregate(x, dim)

Aggregate activations.

Args:
    x (torch.Tensor): The tensor to aggregate, shape (l, d).
    dim (int): The dimension along which to aggregate.

Returns:
    torch.Tensor: The aggregated tensor, shape (1, d).

unfold

unfold(x, new_dim_length)

Unfold activations.

Args:
    x (torch.Tensor): The tensor to unfold, shape (1, d).
    new_dim_length (int): The new dimension length.

Returns:
    torch.Tensor: The unfolded tensor, shape (l, d).
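A minimal sketch of how these two helpers pair up, assuming aggregate reduces along the given dimension and unfold broadcasts the result back (shapes are illustrative):

>>> import torch
>>> from interpreto.model_wrapping.model_with_split_points import GranularityAggregationStrategy
>>> x = torch.randn(5, 768)  # l=5 token activations of dimension d=768
>>> pooled = GranularityAggregationStrategy.MEAN.aggregate(x, dim=0)  # -> shape (1, 768)
>>> GranularityAggregationStrategy.MEAN.unfold(pooled, new_dim_length=5).shape  # back to (l, d)
torch.Size([5, 768])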