
Model Wrapping

interpreto.model_wrapping.ModelWithSplitPoints

ModelWithSplitPoints(model_or_repo_id, split_points, *args, model_autoclass=None, tokenizer=None, config=None, batch_size=1, device_map=None, **kwargs)

Bases: LanguageModel

Code: model_wrapping/model_with_split_points.py

Generalized NNsight.LanguageModel wrapper around encoder-only, decoder-only, and encoder-decoder language models. Handles splitting the model at specified locations and extracting activations.

Inputs can be in the form of:

* One (`str`) or more (`list[str]`) prompts, including batched prompts (`list[list[str]]`).

* One (`list[int] or torch.Tensor`) or more (`list[list[int]] or torch.Tensor`) tokenized prompts.

* Direct model inputs (`dict[str, Any]`).
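A minimal sketch of these input forms passed to get_activations (documented below), assuming a model_with_split_points wrapper built as in the examples further down; using the wrapper's tokenizer attribute to build pre-tokenized inputs is an assumption of this sketch:

>>> texts = ["The movie was great.", "It was dull."]
>>> # Raw prompts (list[str])
>>> acts = model_with_split_points.get_activations(texts)
>>> # Pre-tokenized prompts as a batched tensor of token ids
>>> encoding = model_with_split_points.tokenizer(texts, return_tensors="pt", padding=True)
>>> acts = model_with_split_points.get_activations(encoding["input_ids"])
>>> # Direct model inputs (dict[str, Any] / BatchEncoding)
>>> acts = model_with_split_points.get_activations(encoding)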

Attributes:

* model_autoclass (type): The AutoClass corresponding to the loaded model type.

* split_points (list[str]): Getter/setter for model paths corresponding to split points inside the loaded model. Automatically handles validation, sorting, and resolving int paths to strings.

* repo_id (str): Either the model id on the HF Hub, or the path from which the model was loaded.

* generator (Envoy | None): If the model is generative, a generator is provided to handle multi-step inference. None for encoder-only models.

* _model (PreTrainedModel): Huggingface transformers model wrapped by NNsight.

* _model_paths (list[str]): List of cached valid paths inside _model, used to validate split_points.

* _split_points (list[str]): List of split points; should be accessed through the getter/setter.
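For illustration, the public attributes can be inspected after construction (a small sketch assuming the wrapper from the first example below; the printed values are illustrative):

>>> model_with_split_points.split_points   # canonical module paths, e.g. ['bert.encoder.layer.1.output']
>>> model_with_split_points.repo_id        # 'bert-base-uncased'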

Examples:

Load the model from its repository id, split it at an early encoder layer, and get the raw token activations at that split point.

>>> from datasets import load_dataset
>>> from transformers import AutoModelForMaskedLM
>>> from interpreto import ModelWithSplitPoints
>>> # load and split the model
>>> model_with_split_points = ModelWithSplitPoints(
...     "bert-base-uncased",
...     split_points="bert.encoder.layer.1.output",
...     model_autoclass=AutoModelForMaskedLM,
...     batch_size=64,
...     device_map="auto",
... )
>>> # get activations
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations = model_with_split_points.get_activations(
...     dataset,
...     activation_granularity=ModelWithSplitPoints.activation_granularities.TOKEN,
... )

Load the model and tokenizer manually, pass them to ModelWithSplitPoints, split the model at an MLP block, and get word-level activations: special tokens are skipped and token activations are aggregated into words by mean.

>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from datasets import load_dataset
>>> from interpreto import ModelWithSplitPoints
>>> # load the model
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> # wrap and split the model
>>> model_with_split_points = ModelWithSplitPoints(
...     model,
...     tokenizer=tokenizer,
...     split_points="transformer.h.1.mlp",
...     batch_size=16,
...     device_map="auto",
... )
>>> # get activations
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations = model_with_split_points.get_activations(
...     dataset,
...     activation_granularity=ModelWithSplitPoints.activation_granularities.WORD,
...     aggregation_strategy=ModelWithSplitPoints.aggregation_strategies.MEAN,
... )

Parameters:

* model_or_repo_id (str | PreTrainedModel, required): One of:

  • A str corresponding to the ID of the model to load from the HF Hub.
  • A str corresponding to the local path of a folder containing a compatible checkpoint.
  • A preloaded transformers.PreTrainedModel object.

  If a string is provided, a model_autoclass should also be provided.

* split_points (str | Sequence[str] | int | Sequence[int], required): One or more split locations inside the model. Either the path is provided explicitly (str), or an int is used as shorthand for splitting at the n-th layer. Example: split_points='cls.predictions.transform.LayerNorm' corresponds to a split after the LayerNorm layer in the MLM head (assuming a BertForMaskedLM model in input). See the sketch after this parameter list.

* model_autoclass (Type, default: None): Huggingface AutoClass corresponding to the desired type of model (e.g. AutoModelForSequenceClassification). ⚠ model_autoclass must be defined if model_or_repo_id is a str, since the model class cannot be known otherwise.

* config (PretrainedConfig, default: None): Custom configuration for the loaded model. If not specified, the model is instantiated with its default configuration.

* tokenizer (PreTrainedTokenizer, default: None): Custom tokenizer for the loaded model. If not specified, the model's default tokenizer is instantiated.

* batch_size (int, default: 1): Batch size for the model.

* device_map (str | None, default: None): Device map for the model. Passed directly to the model.
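As referenced in the split_points description above, a brief sketch of the two ways to specify split locations (the module paths and the AutoClass are illustrative; they depend on the actual model):

>>> from transformers import AutoModelForSequenceClassification
>>> from interpreto import ModelWithSplitPoints
>>> # Explicit module path (str): split after a specific submodule
>>> mwsp = ModelWithSplitPoints(
...     "bert-base-uncased",
...     split_points="bert.encoder.layer.4.output",
...     model_autoclass=AutoModelForSequenceClassification,
... )
>>> # Integer shorthand: split at the n-th layer, resolved internally to a module path
>>> mwsp = ModelWithSplitPoints(
...     "bert-base-uncased",
...     split_points=4,
...     model_autoclass=AutoModelForSequenceClassification,
... )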

get_activations

get_activations(inputs, activation_granularity=ALL_TOKENS, aggregation_strategy=MEAN, pad_side='left', **kwargs)

Get intermediate activations for all model split points.

Parameters:

* inputs (list[str] | Tensor | BatchEncoding, required): Inputs to the model forward pass, before or after tokenization. In the case of a torch.Tensor, a batch dimension and token ids are assumed.

* activation_granularity (ActivationGranularity, default: ALL_TOKENS): Selection strategy for activations. Options are:

  • ModelWithSplitPoints.activation_granularities.ALL: the activations are returned as is (batch, seq_len, d_model). They are padded manually so that each batch of activations can be concatenated.
  • ModelWithSplitPoints.activation_granularities.CLS_TOKEN: only the first token (e.g. [CLS]) activation is returned (batch, d_model).
  • ModelWithSplitPoints.activation_granularities.ALL_TOKENS: every token activation is treated as a separate element (batch x seq_len, d_model).
  • ModelWithSplitPoints.activation_granularities.TOKEN: special tokens are removed.
  • ModelWithSplitPoints.activation_granularities.WORD: aggregate by words following the split defined by interpreto.commons.granularity.Granularity.WORD.
  • ModelWithSplitPoints.activation_granularities.SENTENCE: aggregate by sentences following the split defined by interpreto.commons.granularity.Granularity.SENTENCE. Requires spacy to be installed.
  • ModelWithSplitPoints.activation_granularities.SAMPLE: activations are aggregated over the whole sample.

* aggregation_strategy (AggregationProtocol, default: MEAN): Strategy to aggregate token activations into larger input granularities. Applied for the WORD, SENTENCE, and SAMPLE granularities. Token activations of shape n * (l, d) are aggregated along the sequence length dimension, then concatenated into (ng, d) tensors. Existing strategies are:

  • ModelWithSplitPoints.aggregation_strategies.SUM: token activations are summed along the sequence length dimension.
  • ModelWithSplitPoints.aggregation_strategies.MEAN: token activations are averaged along the sequence length dimension.
  • ModelWithSplitPoints.aggregation_strategies.MAX: the maximum of the token activations along the sequence length dimension is selected.
  • ModelWithSplitPoints.aggregation_strategies.SIGNED_MAX: the maximum of the absolute value of the activations, multiplied by its initial sign, e.g. signed_max([[-1, 0, 1, 2], [-3, 1, -2, 0]]) = [-3, 1, -2, 2]. See the sketch after this parameter list.

* pad_side (str, default: 'left'): 'left' or 'right'; the side on which to apply padding along dim=1. Only used with the ALL granularity.

* **kwargs: Additional keyword arguments passed to the model forward pass.
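A minimal sketch of the SIGNED_MAX aggregation described above, matching the documented example (the signed_max helper and its implementation are illustrative, not the library's internal code):

>>> import torch
>>> def signed_max(x, dim=0):
...     # keep the element with the largest absolute value along `dim`,
...     # preserving its original sign
...     idx = x.abs().argmax(dim=dim, keepdim=True)
...     return x.gather(dim, idx).squeeze(dim)
>>> acts = torch.tensor([[-1.0, 0.0, 1.0, 2.0], [-3.0, 1.0, -2.0, 0.0]])
>>> signed_max(acts, dim=0)  # tensor([-3., 1., -2., 2.])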

Returns:

* dict[str, LatentActivations]: Dictionary with one (key, value) pair for each split point defined for the model. Keys correspond to split names in self.split_points, while values correspond to the activations extracted at that split point for the given inputs.
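A short sketch of consuming the returned dictionary, assuming the wrapper and dataset from the first example above (tensor-like values are an assumption here):

>>> activations = model_with_split_points.get_activations(dataset)
>>> list(activations.keys())  # one key per split point, e.g. ['bert.encoder.layer.1.output']
>>> latent = activations["bert.encoder.layer.1.output"]
>>> latent.shape  # (n_tokens, d_model) with the default ALL_TOKENS granularity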

get_latent_shape

get_latent_shape(inputs=None)

Get the shape of the latent activations at the specified split point.
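A possible usage sketch (hedged: this page does not specify the return structure, which is assumed here to describe the activation shape at each configured split point):

>>> model_with_split_points.get_latent_shape(["A short probe sentence."])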