ModelWithSplitPoints
ModelWithSplitPoints now uses the singular split_point argument/property because only one split point is supported.
The previous split_points argument/property remains temporarily available as a deprecated compatibility alias and emits a DeprecationWarning guiding users to split_point. It will be removed in version 1.0.0.
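The aliasing described above follows a common Python deprecation pattern. The sketch below is not Interpreto's actual implementation: the class `SplitterSketch`, its error messages, and the unwrapping of a one-element list are all hypothetical. It only illustrates how a deprecated `split_points` keyword can forward to `split_point` while emitting a `DeprecationWarning`.

```python
import warnings


class SplitterSketch:
    """Hypothetical stand-in, not the real ModelWithSplitPoints."""

    def __init__(self, split_point=None, *, split_points=None):
        if split_points is not None:
            # Warn the caller, pointing at their own line thanks to stacklevel=2.
            warnings.warn(
                "`split_points` is deprecated, use `split_point` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if split_point is not None:
                raise ValueError("Pass either `split_point` or `split_points`, not both.")
            # Assumption: a one-element list is unwrapped to its single entry.
            if isinstance(split_points, (list, tuple)):
                if len(split_points) != 1:
                    raise ValueError("Only one split point is supported.")
                split_points = split_points[0]
            split_point = split_points
        self.split_point = split_point


# New-style call: no warning.
new_style = SplitterSketch(split_point=10)

# Old-style call: still works, but emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_style = SplitterSketch(split_points=[10])
```

In user code, the migration amounts to renaming the keyword from `split_points` to `split_point` and passing a single value.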
Specificities
In comparison to SplitterForClassification and SplitterForGeneration, the ModelWithSplitPoints class is more versatile.
It is more complex to use, but it covers the two other splitter cases.
In both get_activations and _get_concept_output_gradients, one needs to specify:
- activation_granularity: specifies which of the (n, l, d) activations to return. It can be one of CLS_TOKEN, ALL_TOKENS, TOKEN, WORD, SENTENCE, or SAMPLE. Use activation_granularity=ModelWithSplitPoints.activation_granularities.TOKEN to specify it.
- aggregation_strategy: how activations should be aggregated (only for WORD, SENTENCE, or SAMPLE). It can be one of SUM, MEAN, MAX, or SIGNED_MAX. Use aggregation_strategy=ModelWithSplitPoints.aggregation_strategies.MEAN to specify it.
interpreto.ModelWithSplitPoints
ModelWithSplitPoints(model_or_repo_id, split_point=None, *args, split_points=None, automodel=None, tokenizer=None, config=None, batch_size=1, device_map=None, **kwargs)
Bases: BaseSplitter
Code: concepts/splitters/model_with_split_points.py
The ModelWithSplitPoints is a wrapper around your HuggingFace model.
Its goal is to allow you to split your model at specified locations and extract activations.
It is one of the key components of the Concept-Based Explainers framework in Interpreto.
Indeed, any Interpreto concept explainer is built around a ModelWithSplitPoints object, because splitting the model is the first step of the concept-based explanation process.
It is based on the LanguageModel class from NNsight and inherits its functionalities.
In a sense, the LanguageModel class is a wrapper around the HuggingFace model.
The ModelWithSplitPoints class is a wrapper around the LanguageModel class.
We often shorten the ModelWithSplitPoints class as MWSP and instances as mwsp.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_or_repo_id | str \| PreTrainedModel | One of: a model id on the HF Hub or a local path (str), or an already loaded model (PreTrainedModel). | required |
| split_point | str \| int | The split location inside the model. Either a module path (str) or a layer index (int). Example: "bert.encoder.layer.1.output" or 10. | None |
| split_points | str \| int \| list[str] \| list[int] (deprecated) | Backward-compatible alias for split_point. | None |
| automodel | type[AutoModel] | Huggingface AutoClass corresponding to the desired type of model (e.g. AutoModelForSequenceClassification). | None |
| config | PretrainedConfig | Custom configuration for the loaded model. If not specified, it will be instantiated with the default configuration for the model. | None |
| tokenizer | PreTrainedTokenizer \| PreTrainedTokenizerFast \| None | Custom tokenizer for the loaded model. If not specified, it will be instantiated with the default tokenizer for the model. | None |
| batch_size | int | Batch size for the model. | 1 |
| device_map | device \| str \| None | Device map for the model. Directly passed to the model. | None |
Attributes:
| Name | Type | Description |
|---|---|---|
| activation_granularities | ActivationGranularity | Enumeration of the available granularities for the get_activations method. |
| aggregation_strategies | GranularityAggregationStrategy | Enumeration of the available aggregation strategies for the get_activations method. |
| automodel | type[AutoModel] | The AutoClass corresponding to the loaded model type. |
| batch_size | int | Batch size for the model. |
| output_tuple_index | int \| None | If the output at the split point is a tuple, this is the index of the hidden state. If … |
| repo_id | str | Either the model id in the HF Hub, or the path from which the model was loaded. |
| tokenizer | PreTrainedTokenizer | Tokenizer for the loaded model, either given by the user or loaded from the repo_id. |
| _model | PreTrainedModel | Huggingface transformers model wrapped by NNsight. |
Examples:
Minimal example with gpt2:
>>> from transformers import AutoModelForCausalLM
>>> from interpreto import ModelWithSplitPoints
>>> model_with_split_points = ModelWithSplitPoints(
... "gpt2",
... split_point=10, # split at the 10th layer
... automodel=AutoModelForCausalLM,
... device_map="auto",
... )
>>> activations, _ = model_with_split_points.get_activations(
... inputs="interpreto is magic",
... activation_granularity=ModelWithSplitPoints.activation_granularities.TOKEN, # highly recommended for generation
... )
Load the model from its repository id, split it at the first layer, and get the raw activations for the first layer.
>>> import torch
>>> from datasets import load_dataset
>>> from transformers import AutoModelForSequenceClassification
>>> from interpreto import ModelWithSplitPoints
>>> # load and split the model
>>> model_with_split_points = ModelWithSplitPoints(
... "bert-base-uncased",
... split_point="bert.encoder.layer.1.output",
... automodel=AutoModelForSequenceClassification,
... batch_size=64,
... device_map="cuda" if torch.cuda.is_available() else "cpu",
... )
>>> # get activations
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations, _ = model_with_split_points.get_activations(
... dataset,
... activation_granularity=ModelWithSplitPoints.activation_granularities.CLS_TOKEN, # highly recommended for classification
... )
Load the model, then pass it to ModelWithSplitPoints, split it at the tenth layer,
get the word activations for that layer, skip special tokens, and aggregate token activations by mean into words.
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from datasets import load_dataset
>>> from interpreto import ModelWithSplitPoints as MWSP
>>> # load the model
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
>>> # wrap and split the model at the 10th layer
>>> model_with_split_points = MWSP(
... model,
... tokenizer=tokenizer,
... split_point=10, # split at the 10th layer
... batch_size=16,
... device_map="auto",
... )
>>> # get activations at the word granularity
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> activations, _ = model_with_split_points.get_activations(
... dataset,
... activation_granularity=MWSP.activation_granularities.WORD,
... aggregation_strategy=MWSP.aggregation_strategies.MEAN, # average tokens activations by words
... )
Most of the work is forwarded to the BaseSplitter class initialization, which is in turn a wrapper around the nnsight.LanguageModel class.
Raises:
| Type | Description |
|---|---|
| InitializationError(ValueError) | If the model cannot be loaded, because of a missing … |
| ValueError | If the … |
| TypeError | If the … |
Source code in interpreto/concepts/splitters/model_with_split_points.py
activation_granularities (class-attribute, instance-attribute)
activation_granularities = ActivationGranularity
aggregation_strategies (class-attribute, instance-attribute)
aggregation_strategies = GranularityAggregationStrategy
get_activations
get_activations(inputs, *, activation_granularity, aggregation_strategy=MEAN, pad_side=None, tqdm_bar=False, include_predicted_classes=False, flatten_activations=True, forward_kwargs={})
Get intermediate activations for the model split point on the given inputs.
Optionally include the model predictions in the returned tuple.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | | Inputs to the model forward pass before or after tokenization. In the case of a … | required |
| activation_granularity | ActivationGranularity | Selection strategy for activations. In the model, activations have the shape (n, l, d). It is highly recommended to use CLS_TOKEN for classification and TOKEN for generation. Available options are: CLS_TOKEN, ALL_TOKENS, TOKEN, WORD, SENTENCE, and SAMPLE. | required |
| aggregation_strategy | GranularityAggregationStrategy | Strategy to aggregate token activations into larger input granularities. Applied for the WORD, SENTENCE, and SAMPLE granularities. Existing strategies are: SUM, MEAN, MAX, and SIGNED_MAX. | MEAN |
| pad_side | str \| None | 'left' or 'right': side on which to apply padding along dim=1, only for ALL strategy. Forced right for classification models and left for causal LMs. | None |
| tqdm_bar | bool | Whether to display a progress bar. | False |
| include_predicted_classes | bool | Whether to include the predicted classes in the output tuple. Only applicable for classification models. | False |
| flatten_activations | bool | Whether to flatten the activations tensors. | True |
| forward_kwargs | dict | Additional keyword arguments passed to the model forward pass. | {} |
Returns:
Return type: tuple[LatentActivations, Tensor | None] | tuple[list[LatentActivations], list[Tensor] | None]

| Name | Type | Description |
|---|---|---|
| activations | LatentActivations \| list[LatentActivations] | The extracted activations, either flattened or in a sample-wise list. |
| predictions | Tensor \| list[Tensor] \| None | The predicted classes, if requested. |
get_latent_shape
Get the shape of the latent activations at the split point.
Use the scan operation from NNsight to get the shape of the activations.
It essentially builds the computation graph, but it is much quicker than a forward pass.
Returns:
| Type | Description |
|---|---|
| Size | torch.Size: Shape of the activations for the split point. |
interpreto.concepts.splitters.model_with_split_points.ActivationGranularity
Bases: Enum
Activation selection strategies for ModelWithSplitPoints.get_activations().
- ALL_TOKENS: the raw activations are flattened to (n x l, d). Hence, each token activation is now considered as a separate element. This includes special tokens such as [CLS], [SEP], [EOS], [PAD], etc.
- CLS_TOKEN: for each sample, only the first token (e.g. [CLS]) activation is returned, giving (n, d). This will raise an error if the model is not ForSequenceClassification.
- SAMPLE: special tokens are removed and the remaining ones are aggregated on the whole sample, giving (n, d).
- SENTENCE: special tokens are removed and the remaining ones are aggregated by sentences. Then the activations are flattened to (n x g, d), where g is the number of sentences in the input. The split is defined by interpreto.commons.granularity.Granularity.SENTENCE.
- TOKEN: the raw activations are flattened, but the special tokens are removed, giving (n x g, d), where g is the number of non-special tokens in the input. This is the default granularity.
- WORD: the special tokens are removed and the remaining ones are aggregated by words. Then the activations are flattened to (n x g, d), where g is the number of words in the input. The split is defined by interpreto.commons.granularity.Granularity.WORD.
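To make the shapes above concrete, here is a minimal pure-Python sketch on toy data. It does not call Interpreto: the helper names, the nested-list representation, and the hand-written special-token mask are all assumptions made for illustration. It mimics ALL_TOKENS, CLS_TOKEN, TOKEN, and SAMPLE (with a mean aggregation) on activations of shape (n=2, l=4, d=2).

```python
# Toy activations with shape (n=2, l=4, d=2).
activations = [
    [[1.0, 1.0], [2.0, 4.0], [4.0, 6.0], [9.0, 9.0]],
    [[0.0, 0.0], [6.0, 8.0], [2.0, 0.0], [9.0, 9.0]],
]
# Assumption: the first and last tokens of each sample are special tokens.
special_mask = [[True, False, False, True], [True, False, False, True]]


def all_tokens(acts):
    """Flatten every token, special ones included: (n x l, d)."""
    return [tok for sample in acts for tok in sample]


def cls_token(acts):
    """Keep only the first token of each sample: (n, d)."""
    return [sample[0] for sample in acts]


def token(acts, mask):
    """Flatten after dropping special tokens: (n x g, d)."""
    return [
        tok
        for sample, sample_mask in zip(acts, mask)
        for tok, is_special in zip(sample, sample_mask)
        if not is_special
    ]


def sample_mean(acts, mask):
    """Drop special tokens, then average the rest per sample: (n, d)."""
    out = []
    for sample, sample_mask in zip(acts, mask):
        kept = [tok for tok, is_special in zip(sample, sample_mask) if not is_special]
        out.append([sum(col) / len(kept) for col in zip(*kept)])
    return out


all_rows = all_tokens(activations)                 # 8 rows of size 2
cls_rows = cls_token(activations)                  # 2 rows of size 2
token_rows = token(activations, special_mask)      # 4 rows of size 2
sample_rows = sample_mean(activations, special_mask)  # 2 rows of size 2
```

WORD and SENTENCE would follow the same idea as SAMPLE, with one group per word or sentence instead of one group per sample.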
interpreto.concepts.splitters.model_with_split_points.GranularityAggregationStrategy
Bases: Enum
Enumeration of the available aggregation strategies for combining token-level scores into a single score for each unit of a higher-level granularity (e.g., word, sentence).
This is used in explainability methods to reduce token-based attributions according to a defined granularity.
Attributes:
| Name | Type | Description |
|---|---|---|
| MEAN | | Average of the token scores within each group. |
| MAX | | Maximum token score within each group. |
| MIN | | Minimum token score within each group. |
| SUM | | Sum of all token scores within each group. |
| SIGNED_MAX | | Selects the token with the highest absolute score and returns its signed value. For example, given scores [3, -1, 7], returns 7; for [3, -1, -7], returns -7. |
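The strategies in this table can be sketched on plain Python lists. This is an illustration only, not the library's code: the `ToyStrategy` enum and the `aggregate_scores` helper are hypothetical names, and the real enum operates on tensors rather than lists of numbers.

```python
from enum import Enum


class ToyStrategy(Enum):
    """Hypothetical mirror of the strategy names listed above."""

    MEAN = "mean"
    MAX = "max"
    MIN = "min"
    SUM = "sum"
    SIGNED_MAX = "signed_max"


def aggregate_scores(scores, strategy):
    """Reduce the scores of one group (e.g. the tokens of a word) to one value."""
    if strategy is ToyStrategy.MEAN:
        return sum(scores) / len(scores)
    if strategy is ToyStrategy.MAX:
        return max(scores)
    if strategy is ToyStrategy.MIN:
        return min(scores)
    if strategy is ToyStrategy.SUM:
        return sum(scores)
    if strategy is ToyStrategy.SIGNED_MAX:
        # Pick the entry with the largest absolute value, keeping its sign.
        return max(scores, key=abs)
    raise ValueError(f"Unknown strategy: {strategy}")


positive_case = aggregate_scores([3, -1, 7], ToyStrategy.SIGNED_MAX)
negative_case = aggregate_scores([3, -1, -7], ToyStrategy.SIGNED_MAX)
mean_case = aggregate_scores([3, -1, 7], ToyStrategy.MEAN)
```

The two SIGNED_MAX calls reproduce the examples given in the table: the first keeps 7, the second keeps -7, whereas MAX would return 3 for the second list.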
aggregate
Aggregate activations.
Args:
x (torch.Tensor): The tensor to aggregate, shape: (l, d).
Returns:
torch.Tensor: The aggregated tensor, shape (1, d).
unfold
Unfold activations.
Args:
x (torch.Tensor): The tensor to unfold, shape: (1, d).
new_dim_length (int): The new dimension length.
Returns:
torch.Tensor: The unfolded tensor, shape: (l, d).
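The shape contract of these two methods, (l, d) to (1, d) and back, can be illustrated with nested lists. The sketch below is an assumption-heavy toy: it uses a mean for aggregation and simply repeats the aggregated row when unfolding, which may differ from what the library does for each strategy. The function names `aggregate_mean` and `unfold_repeat` are hypothetical.

```python
def aggregate_mean(x):
    """Reduce a (l, d) nested list to (1, d) by averaging over the l rows."""
    length = len(x)
    return [[sum(col) / length for col in zip(*x)]]


def unfold_repeat(x, new_dim_length):
    """Expand a (1, d) nested list to (new_dim_length, d).

    Assumption: the single row is copied as-is; the real unfold may rescale
    values depending on the aggregation strategy.
    """
    return [list(x[0]) for _ in range(new_dim_length)]


token_activations = [[1.0, 2.0], [3.0, 6.0]]          # shape (l=2, d=2)
aggregated = aggregate_mean(token_activations)        # shape (1, 2)
unfolded = unfold_repeat(aggregated, new_dim_length=2)  # shape (2, 2)
```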