ModelWithSplitPoints
interpreto.ModelWithSplitPoints
ModelWithSplitPoints(model_or_repo_id, split_points, *args, automodel=None, tokenizer=None, config=None, batch_size=1, device_map=None, output_tuple_index=None, **kwargs)
Bases: LanguageModel
Code: model_wrapping/model_with_split_points.py
ModelWithSplitPoints is a wrapper around your HuggingFace model.
It lets you split the model at specified locations and extract the activations at those locations.
It is one of the key components of the Concept-Based Explainers framework in Interpreto:
every Interpreto concept explainer is built around a ModelWithSplitPoints object,
because splitting the model is the first step of the concept-based explanation process.
It is based on the LanguageModel class from NNsight and inherits its functionalities.
In a sense, the LanguageModel class is a wrapper around the HuggingFace model,
and the ModelWithSplitPoints class is a wrapper around the LanguageModel class.
We often shorten the ModelWithSplitPoints class as MWSP and its instances as mwsp.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_or_repo_id | str \| PreTrainedModel | One of: | required |
| split_points | str \| Sequence[str] \| int \| Sequence[int] | One or more split locations inside the model. Either one of the following: Example: | required |
| automodel | type[AutoModel] | Huggingface AutoClass corresponding to the desired type of model (e.g. | None |
| config | PretrainedConfig | Custom configuration for the loaded model. If not specified, it will be instantiated with the default configuration for the model. | None |
| tokenizer | PreTrainedTokenizer \| PreTrainedTokenizerFast \| None | Custom tokenizer for the loaded model. If not specified, it will be instantiated with the default tokenizer for the model. | None |
| batch_size | int | Batch size for the model. | 1 |
| device_map | device \| str \| None | Device map for the model. Directly passed to the model. | None |
| output_tuple_index | int \| None | If the output at the split point is a tuple, this is the index of the hidden state. If | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| activation_granularities | ActivationGranularity | Enumeration of the available granularities for the |
| aggregation_strategies | GranularityAggregationStrategy | Enumeration of the available aggregation strategies for the |
| automodel | type[AutoModel] | The AutoClass corresponding to the loaded model type. |
| batch_size | int | Batch size for the model. |
| output_tuple_index | int \| None | If the output at the split point is a tuple, this is the index of the hidden state. If |
| repo_id | str | Either the model id in the HF Hub, or the path from which the model was loaded. |
| split_points | list[str] | Getter/setters for model paths corresponding to split points inside the loaded model. Automatically handle validation, sorting and resolving int paths to strings. |
| tokenizer | PreTrainedTokenizer | Tokenizer for the loaded model, either given by the user or loaded from the repo_id. |
| _model | PreTrainedModel | Huggingface transformers model wrapped by NNSight. |
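The split_points attribute is described as validating, sorting, and resolving integer paths to strings. A minimal sketch of what such a normalization step could look like is given below. The function name, the `valid_paths` argument, and the `transformer.h.{}` naming scheme are all assumptions made for illustration; they are not taken from Interpreto, and the real resolution depends on the wrapped architecture.

```python
from __future__ import annotations

from collections.abc import Sequence


def resolve_split_points(
    split_points: str | int | Sequence[str | int],
    valid_paths: list[str],
    layer_template: str = "transformer.h.{}",  # hypothetical naming scheme
) -> list[str]:
    """Illustrative only: normalize split points to sorted, validated string paths."""
    # Accept a single value or a sequence of values.
    if isinstance(split_points, (str, int)):
        split_points = [split_points]

    resolved = []
    for point in split_points:
        # Integers are interpreted as layer indices and expanded to a path.
        path = layer_template.format(point) if isinstance(point, int) else point
        if path not in valid_paths:
            raise ValueError(f"Unknown split point: {path!r}")
        resolved.append(path)

    # Deduplicate, then sort by position in the model (forward-pass order).
    return sorted(set(resolved), key=valid_paths.index)


paths = [f"transformer.h.{i}" for i in range(12)]
print(resolve_split_points([10, "transformer.h.2"], paths))
```

Sorting by position in the list of valid paths, rather than alphabetically, keeps layer 2 before layer 10.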
Examples:
Minimal example with gpt2:
>>> from transformers import AutoModelForCausalLM
>>> from interpreto import ModelWithSplitPoints
>>> model_with_split_points = ModelWithSplitPoints(
... "gpt2",
... split_points=10, # split at the 10th layer
... automodel=AutoModelForCausalLM,
... device_map="auto",
... )
>>> activations_dict = model_with_split_points.get_activations(
... inputs="interpreto is magic",
... activation_granularity=ModelWithSplitPoints.activation_granularities.TOKEN, # highly recommended for generation
... )
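Load the model from its repository id, split it at bert.encoder.layer.1.output, and get the [CLS] token activations at that split point.
>>> import torch
>>> from datasets import load_dataset
>>> from transformers import AutoModelForSequenceClassification
>>> from interpreto import ModelWithSplitPoints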
>>> # load and split the model
>>> model_with_split_points = ModelWithSplitPoints(
... "bert-base-uncased",
... split_points="bert.encoder.layer.1.output",
... automodel=AutoModelForSequenceClassification,
... batch_size=64,
... device_map="cuda" if torch.cuda.is_available() else "cpu",
... )
>>> # get activations
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations_dict = model_with_split_points.get_activations(
... dataset,
... activation_granularity=ModelWithSplitPoints.activation_granularities.CLS_TOKEN, # highly recommended for classification
... )
Load the model, then pass it to ModelWithSplitPoints, split it at the tenth layer,
get the word activations for that layer, skip special tokens, and aggregate token activations by mean into words.
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from datasets import load_dataset
>>> from interpreto import ModelWithSplitPoints as MWSP
>>> # load the model
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
>>> # wrap and split the model at the 10th layer
>>> model_with_split_points = MWSP(
... model,
... tokenizer=tokenizer,
... split_points=10, # split at the 10th layer
... batch_size=16,
... device_map="auto",
... )
>>> # get activations at the word granularity
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations = model_with_split_points.get_activations(
... dataset,
... activation_granularity=MWSP.activation_granularities.WORD,
... aggregation_strategy=MWSP.aggregation_strategies.MEAN, # average tokens activations by words
... )
Most of the work is forwarded to the LanguageModel class initialization from NNsight.
Raises:
| Type | Description |
|---|---|
| InitializationError(ValueError) | If the model cannot be loaded, because of a missing |
| ValueError | If the |
| TypeError | If the |
Source code in interpreto/model_wrapping/model_with_split_points.py
activation_granularities
class-attribute
instance-attribute
activation_granularities = ActivationGranularity
aggregation_strategies
class-attribute
instance-attribute
aggregation_strategies = GranularityAggregationStrategy
get_activations
get_activations(inputs, activation_granularity, aggregation_strategy=MEAN, pad_side=None, tqdm_bar=False, include_predicted_classes=False, model_forward_kwargs={})
Get intermediate activations for all model split points on the given inputs.
Also include the model predictions in the returned activations dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | | Inputs to the model forward pass before or after tokenization. In the case of a | required |
| activation_granularity | ActivationGranularity | Selection strategy for activations. In the model, activations have the shape It is highly recommended to use Available options are: | required |
| aggregation_strategy | GranularityAggregationStrategy | Strategy to aggregate token activations into larger inputs granularities. Applied for Existing strategies are: | MEAN |
| pad_side | str \| None | 'left' or 'right': side on which to apply padding along dim=1, only for the ALL strategy. Forced right for classification models and left for causal LMs. | None |
| tqdm_bar | bool | Whether to display a progress bar. | False |
| include_predicted_classes | bool | Whether to include the predicted classes in the output dictionary. Only applicable for classification models. | False |
| model_forward_kwargs | dict | Additional keyword arguments passed to the model forward pass. | {} |
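The pad_side parameter only matters for the ALL strategy, where sequences of different lengths must share the same length along dim=1 before they can be concatenated. The sketch below illustrates left versus right padding on plain Python lists. The `pad_sequences` helper and the zero padding value are assumptions made for illustration, not Interpreto's implementation.

```python
def pad_sequences(sequences, pad_side="right", pad_value=0.0):
    """Pad a list of (l_i, d) nested lists to a common length along the token axis."""
    if pad_side not in ("left", "right"):
        raise ValueError("pad_side must be 'left' or 'right'")
    max_len = max(len(seq) for seq in sequences)
    d = len(sequences[0][0])
    padded = []
    for seq in sequences:
        # One padding row per missing token position.
        filler = [[pad_value] * d for _ in range(max_len - len(seq))]
        padded.append(filler + seq if pad_side == "left" else seq + filler)
    return padded


batch = [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [[4.0, 4.0]]]
right = pad_sequences(batch, pad_side="right")
left = pad_sequences(batch, pad_side="left")
print(right[1])  # real token first, padding after
print(left[1])   # padding first, real token last
```

Right padding keeps the first token (e.g. [CLS]) at index 0, while left padding keeps the last real token at the end of the sequence, which is consistent with the sides forced for classification models and causal LMs respectively.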
Returns:
| Type | Description |
|---|---|
| dict[str, LatentActivations] | Dictionary having one key, value pair for each split point defined for the model. Keys correspond to split names in |
Source code in interpreto/model_wrapping/model_with_split_points.py
get_split_activations
get_split_activations(activations, split_point=None)
Extract activations for the specified split point.
If no split point is specified, it works if and only if the model_with_split_points has only one split point.
Verify that the given activations are valid for the model_with_split_points and split_point.
Cases in which the activations are not valid include:
- Activations are not a valid dictionary.
- Specified split point does not exist in the activations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| activations | dict[str, LatentActivations] | A dictionary with model paths as keys and the corresponding tensors as values. | required |
| split_point | str \| None | The split point to extract activations from. If None, the | None |
Returns:
| Type | Description |
|---|---|
| LatentActivations | The activations for the explainer split point. |
Examples:
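>>> from transformers import AutoModelForSequenceClassification
>>> from interpreto import ModelWithSplitPoints as MWSP
>>> model = MWSP(
... "bert-base-uncased",
... split_points=4,
... automodel=AutoModelForSequenceClassification,
... )
>>> # activations_dict is a dict[str, LatentActivations]
>>> activations_dict = model.get_activations(
... "interpreto is magic",
... activation_granularity=MWSP.activation_granularities.ALL, # raw (n, l, d) activations
... )
>>> # activations is a LatentActivations
>>> activations = model.get_split_activations(activations_dict)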
>>> activations.shape
torch.Size([1, 12, 768])
Raises:
| Type | Description |
|---|---|
| ValueError | If no split point is specified and the |
| TypeError | If the activations are not a valid dictionary. |
| ValueError | If the specified split point is not found in the activations. |
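The checks documented for get_split_activations can be sketched as a small standalone function. This is an illustration of the described behavior, not the library's source; `select_split_activations` and its `model_split_points` argument are hypothetical names.

```python
def select_split_activations(activations, model_split_points, split_point=None):
    """Return the activations of one split point, mirroring the documented checks."""
    # Without an explicit split point, the choice is only unambiguous
    # when the model has exactly one split point.
    if split_point is None:
        if len(model_split_points) != 1:
            raise ValueError("split_point is required when the model has several split points")
        split_point = model_split_points[0]

    if not isinstance(activations, dict):
        raise TypeError("activations must be a dictionary keyed by split point")

    if split_point not in activations:
        raise ValueError(f"split point {split_point!r} not found in activations")

    return activations[split_point]


acts = {"bert.encoder.layer.4": [[0.1, 0.2]]}
print(select_split_activations(acts, ["bert.encoder.layer.4"]))
```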
Source code in interpreto/model_wrapping/model_with_split_points.py
interpreto.model_wrapping.model_with_split_points.ActivationGranularity
Bases: Enum
Activation selection strategies for ModelWithSplitPoints.get_activations().
- ALL: the raw activations are returned as is (n, l, d). They are padded manually so that each batch of activations can be concatenated.
- ALL_TOKENS: the raw activations are flattened (n x l, d). Hence, each token activation is now considered as a separate element. This includes special tokens such as [CLS], [SEP], [EOS], [PAD], etc.
- CLS_TOKEN: for each sample, only the first token (e.g. [CLS]) activation is returned (n, d). This will raise an error if the model is not ForSequenceClassification.
- SAMPLE: special tokens are removed and the remaining ones are aggregated on the whole sample (n, d).
- SENTENCE: special tokens are removed and the remaining ones are aggregated by sentences. Then the activations are flattened (n x g, d), where g is the number of sentences in the input. The split is defined by interpreto.commons.granularity.Granularity.SENTENCE. Requires spacy to be installed.
- TOKEN: the raw activations are flattened, but the special tokens are removed (n x g, d), where g is the number of non-special tokens in the input. This is the default granularity.
- WORD: the special tokens are removed and the remaining ones are aggregated by words. Then the activations are flattened (n x g, d), where g is the number of words in the input. The split is defined by interpreto.commons.granularity.Granularity.WORD.
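To make the shapes listed above concrete, here is a small sketch on plain Python lists with n=2 samples, l=4 tokens, and d=2. It assumes that the first and last tokens of each sample are special tokens and that there is no padding. The mask and helper names are hypothetical, and mean is used as the aggregation for SAMPLE.

```python
# Toy activations with shape (n=2, l=4, d=2).
activations = [
    [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],
    [[0.5, 0.5], [5.0, 6.0], [7.0, 8.0], [9.5, 9.5]],
]
# Assumption: the first and last token of each sample are special tokens.
special_mask = [[True, False, False, True], [True, False, False, True]]

# ALL: keep the (n, l, d) structure as is.
all_acts = activations

# ALL_TOKENS: flatten to (n x l, d), special tokens included.
all_tokens = [tok for sample in activations for tok in sample]

# CLS_TOKEN: first token of each sample, shape (n, d).
cls_token = [sample[0] for sample in activations]

# TOKEN: flatten but drop special tokens, shape (n x g, d).
token = [
    tok
    for sample, mask in zip(activations, special_mask)
    for tok, is_special in zip(sample, mask)
    if not is_special
]


def mean_rows(rows):
    """Average a list of d-dimensional rows column by column."""
    return [sum(col) / len(rows) for col in zip(*rows)]


# SAMPLE: aggregate the non-special tokens of each sample (mean here), shape (n, d).
sample = [
    mean_rows([tok for tok, is_special in zip(s, m) if not is_special])
    for s, m in zip(activations, special_mask)
]

print(len(all_tokens), len(cls_token), len(token), sample)
```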
interpreto.model_wrapping.model_with_split_points.GranularityAggregationStrategy
Bases: Enum
Enumeration of the available aggregation strategies for combining token-level scores into a single score for each unit of a higher-level granularity (e.g., word, sentence).
This is used in explainability methods to reduce token-based attributions according to a defined granularity.
Attributes:
| Name | Type | Description |
|---|---|---|
| MEAN | | Average of the token scores within each group. |
| MAX | | Maximum token score within each group. |
| MIN | | Minimum token score within each group. |
| SUM | | Sum of all token scores within each group. |
| SIGNED_MAX | | Selects the token with the highest absolute score and returns its signed value. For example, given scores [3, -1, 7], returns 7; for [3, -1, -7], returns -7. |
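Each strategy above reduces a group of token scores to a single value. A minimal sketch on plain Python lists follows; the `aggregate_scores` helper and its string-based strategy names are hypothetical and only mirror the descriptions in the table.

```python
def aggregate_scores(scores, strategy):
    """Reduce a non-empty list of token scores to a single score."""
    if strategy == "MEAN":
        return sum(scores) / len(scores)
    if strategy == "MAX":
        return max(scores)
    if strategy == "MIN":
        return min(scores)
    if strategy == "SUM":
        return sum(scores)
    if strategy == "SIGNED_MAX":
        # Pick the score with the largest magnitude and keep its sign.
        return max(scores, key=abs)
    raise ValueError(f"Unknown strategy: {strategy}")


print(aggregate_scores([3, -1, 7], "SIGNED_MAX"))   # 7
print(aggregate_scores([3, -1, -7], "SIGNED_MAX"))  # -7
print(aggregate_scores([3, -1, 7], "MEAN"))         # 3.0
```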
aggregate
Aggregate activations.
Args:
x (torch.Tensor): The tensor to aggregate, shape: (l, d).
Returns:
torch.Tensor: The aggregated tensor, shape: (1, d).
unfold
Unfold activations.
Args:
x (torch.Tensor): The tensor to unfold, shape: (1, d).
new_dim_length (int): The new dimension length.
Returns:
torch.Tensor: The unfolded tensor, shape: (l, d).
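The two shape contracts, (l, d) to (1, d) for aggregate and (1, d) to (l, d) for unfold, can be sketched with mean aggregation on a plain list of rows. How unfold redistributes values for each strategy is not stated here, so repeating the aggregated row new_dim_length times is an assumption made purely for illustration, and both helper names are hypothetical.

```python
def aggregate_mean(x):
    """(l, d) -> (1, d): average the rows of x column by column."""
    return [[sum(col) / len(x) for col in zip(*x)]]


def unfold_repeat(x, new_dim_length):
    """(1, d) -> (l, d): repeat the single row new_dim_length times (assumed behavior)."""
    return [list(x[0]) for _ in range(new_dim_length)]


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # shape (3, 2)
word = aggregate_mean(tokens)                  # shape (1, 2)
restored = unfold_repeat(word, len(tokens))    # shape (3, 2)
print(word, restored)
```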