Sparse Autoencoders (SAEs)
List of available SAEs
interpreto.concepts.methods.VanillaSAEConcepts

VanillaSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)

Bases: SAEExplainer[SAE]
Code: concepts/methods/overcomplete.py

ConceptAutoEncoderExplainer with the Vanilla SAE from Cunningham et al. (2023)[1] and Bricken et al. (2023)[2] as concept model.
Vanilla SAE implementation from the overcomplete.sae.SAE class.

1. Huben, R., Cunningham, H., Smith, L. R., Ewart, A., Sharkey, L. Sparse Autoencoders Find Highly Interpretable Features in Language Models. The Twelfth International Conference on Learning Representations, 2024.
2. Bricken, T. et al., Towards Monosemanticity: Decomposing Language Models With Dictionary Learning, Transformer Circuits Thread, 2023.
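A minimal construction sketch. The checkpoint name, split point, and `ModelWithSplitPoints` constructor arguments below are placeholders and assumptions; adapt them to a model and module path that exist in your setup.

```python
from interpreto import ModelWithSplitPoints
from interpreto.concepts.methods import VanillaSAEConcepts

# Placeholder model and split point: substitute a checkpoint and a module
# path that exist in your environment. The constructor arguments shown here
# are assumptions; see the ModelWithSplitPoints documentation.
model_with_split_points = ModelWithSplitPoints(
    "bert-base-uncased",
    split_points=["bert.encoder.layer.6.output"],
)

explainer = VanillaSAEConcepts(
    model_with_split_points,
    nb_concepts=2_000,  # size of the SAE concept space
    device="cpu",       # split_point can stay None with a single split point
)
```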
interpreto.concepts.methods.TopKSAEConcepts

TopKSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)

Bases: SAEExplainer[TopKSAE]
Code: concepts/methods/overcomplete.py

ConceptAutoEncoderExplainer with the TopK SAE from Gao et al. (2024)[1] as concept model.
TopK SAE implementation from the overcomplete.sae.TopKSAE class.

1. Gao, L. et al., Scaling and evaluating sparse autoencoders. The Thirteenth International Conference on Learning Representations, 2025.
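Since `**kwargs` are forwarded to the underlying Overcomplete class, the sparsity level can plausibly be set at construction time. The `top_k` parameter name is an assumption about the `overcomplete.sae.TopKSAE` API; check its documentation before relying on it.

```python
# Sketch: forward the sparsity level to overcomplete.sae.TopKSAE via
# **kwargs. The `top_k` name is an assumption about the overcomplete API.
explainer = TopKSAEConcepts(
    model_with_split_points,  # as constructed in the sketch above
    nb_concepts=2_000,
    top_k=32,  # keep only the 32 largest concept activations per sample
)
```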
interpreto.concepts.methods.BatchTopKSAEConcepts

BatchTopKSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)

Bases: SAEExplainer[BatchTopKSAE]
Code: concepts/methods/overcomplete.py

ConceptAutoEncoderExplainer with the BatchTopK SAE from Bussmann et al. (2024)[1] as concept model.
BatchTopK SAE implementation from the overcomplete.sae.BatchTopKSAE class.

1. Bussmann, B., Leask, P., Nanda, N. BatchTopK Sparse Autoencoders. arXiv preprint, 2024.
interpreto.concepts.methods.JumpReLUSAEConcepts

JumpReLUSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)

Bases: SAEExplainer[JumpSAE]
Code: concepts/methods/overcomplete.py

ConceptAutoEncoderExplainer with the JumpReLU SAE from Rajamanoharan et al. (2024)[1] as concept model.
JumpReLU SAE implementation from the overcomplete.sae.JumpSAE class.

1. Rajamanoharan, S. et al., Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders. arXiv preprint, 2024.
Abstract base class
interpreto.concepts.methods.SAEExplainer

SAEExplainer(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)

Bases: ConceptAutoEncoderExplainer[SAE], Generic[_SAE_co]
Code: concepts/methods/overcomplete.py

Implementation of a concept explainer using an overcomplete.sae.SAE variant as concept_model.
Attributes:

| Name | Type | Description |
|---|---|---|
| model_with_split_points | ModelWithSplitPoints | The model to apply the explanation on. It should have at least one split point on which a concept explainer can be trained. |
| split_point | str \| None | The split point used to train the concept_model. |
| concept_model | SAE | An Overcomplete SAE variant for concept extraction. |
| is_fitted | bool | Whether the concept_model was fitted on model activations. |
| has_differentiable_concept_encoder | bool | Whether the encode_activations operation is differentiable. |
| has_differentiable_concept_decoder | bool | Whether the decode_concepts operation is differentiable. |
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model_with_split_points | ModelWithSplitPoints | The model to apply the explanation on. It should have at least one split point on which a concept explainer can be trained. | required |
| nb_concepts | int | Size of the SAE concept space. | required |
| split_point | str \| None | The split point used to train the concept_model. | None |
| encoder_module | Module \| str \| None | Encoder module to use to construct the SAE, see the Overcomplete SAE documentation. | None |
| dictionary_params | dict \| None | Dictionary parameters to use to construct the SAE, see the Overcomplete SAE documentation. | None |
| device | device \| str | Device to use for the concept_model. | 'cpu' |
| **kwargs | dict | Additional keyword arguments to pass to the SAE constructor. | {} |
Source code in interpreto/concepts/methods/overcomplete.py
fit
fit(activations, *, use_amp=False, batch_size=1024, criterion=MSELoss, optimizer_class=Adam, scheduler_class=None, lr=0.001, nb_epochs=20, clip_grad=None, monitoring=None, device='cpu', max_nan_fallbacks=5, overwrite=False)
Fit an Overcomplete SAE model on the given activations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| activations | Tensor \| dict[str, Tensor] | The activations used for fitting the concept_model. | required |
| use_amp | bool | Whether to use automatic mixed precision for fitting. | False |
| batch_size | int | Batch size used when iterating over the activations during fitting. | 1024 |
| criterion | SAELoss | Loss criterion for the training of the concept_model. | MSELoss |
| optimizer_class | type[Optimizer] | Optimizer for the training of the concept_model. | Adam |
| scheduler_class | type[LRScheduler] \| None | Learning rate scheduler for the training of the concept_model. | None |
| lr | float | Learning rate for the training of the concept_model. | 0.001 |
| nb_epochs | int | Number of epochs for the training of the concept_model. | 20 |
| clip_grad | float \| None | Gradient clipping value for the training of the concept_model. | None |
| monitoring | int \| None | Monitoring frequency for the training of the concept_model. | None |
| device | device \| str | Device to use for the training of the concept_model. | 'cpu' |
| max_nan_fallbacks | int \| None | Maximum number of fallbacks to use when NaNs are encountered during training. Ignored if use_amp is False. | 5 |
| overwrite | bool | Whether to overwrite the current model if it has already been fitted. | False |
Returns:

| Type | Description |
|---|---|
| dict | A dictionary with training history logs. |
Source code in interpreto/concepts/methods/overcomplete.py
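A hedged fitting sketch. The random tensor stands in for real latent activations (their extraction is out of scope here), and the (n_samples, d_model) shape is an assumption chosen to make the dimensions explicit.

```python
import torch
from torch.optim import AdamW  # any torch Optimizer subclass works here

# Stand-in activations: in practice these come from the wrapped model's
# split point. Shape assumption: (n_samples, d_model).
activations = torch.randn(10_000, 768)

logs = explainer.fit(
    activations,
    batch_size=512,
    optimizer_class=AdamW,
    lr=1e-3,
    nb_epochs=20,
)
print(logs)  # dictionary with training history logs
```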
encode_activations

encode_activations(activations)

Encode the given activations using the concept_model encoder.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| activations | Tensor | The activations to encode. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | The encoded concept activations. |
Source code in interpreto/concepts/methods/overcomplete.py
decode_concepts

decode_concepts(concepts)

Decode the given concepts using the concept_model decoder.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| concepts | Tensor | The concepts to decode. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | The decoded concept activations. |
Source code in interpreto/concepts/methods/overcomplete.py
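A roundtrip sketch combining both methods, assuming `explainer` was fitted on `activations` as in the sketch above: encode activations into concepts, decode them back, and check the reconstruction error.

```python
import torch

# Assumes `explainer` is fitted and `activations` has shape (n_samples, d_model).
concepts = explainer.encode_activations(activations)  # (n_samples, nb_concepts)
reconstruction = explainer.decode_concepts(concepts)  # (n_samples, d_model)

# A faithful SAE should reconstruct its inputs with low error.
mse = torch.mean((reconstruction - activations) ** 2)
print(f"reconstruction MSE: {mse.item():.4f}")
```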
get_dictionary

get_dictionary()

Get the dictionary learned by the fitted concept_model.

Returns:

| Type | Description |
|---|---|
| Tensor | The dictionary learned by the fitted concept_model. |
Source code in interpreto/concepts/base.py
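A short retrieval sketch; the (nb_concepts, d_model) shape is an assumption based on the decoder's role of mapping concepts back to the latent space, so verify it against your configuration.

```python
# Each row of the dictionary is one concept direction in the latent space
# (assumed shape: (nb_concepts, d_model)).
dictionary = explainer.get_dictionary()
print(dictionary.shape)
```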
interpret

interpret(interpretation_method, concepts_indices, inputs=None, latent_activations=None, concepts_activations=None, **kwargs)

Interpret the concept dimensions of the latent space in a human-readable format. The interpretation is a mapping between concept indices and an object that allows interpreting them: a label, a description, examples, etc.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| interpretation_method | type[BaseConceptInterpretationMethod] | The interpretation method to use to interpret the concepts. | required |
| concepts_indices | int \| list[int] \| Literal['all'] | The indices of the concepts to interpret. If "all", all concepts are interpreted. | required |
| inputs | list[str] \| None | The inputs to use for the interpretation. Whether they are required depends on the interpretation method's source. | None |
| latent_activations | LatentActivations \| dict[str, LatentActivations] \| None | The latent activations to use for the interpretation. Whether they are required depends on the interpretation method's source. | None |
| concepts_activations | ConceptsActivations \| None | The concepts activations to use for the interpretation. Whether they are required depends on the interpretation method's source. | None |
| **kwargs | | Additional keyword arguments to pass to the interpretation method. | {} |

Returns:

| Type | Description |
|---|---|
| Mapping[int, Any] | A mapping between the concept indices and the interpretation of the concepts. |
Source code in interpreto/concepts/base.py
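A call-pattern sketch with a hypothetical interpretation method; substitute a concrete BaseConceptInterpretationMethod subclass from interpreto, and supply whichever inputs or activations your chosen method requires.

```python
# `MyInterpretationMethod` is a hypothetical placeholder, not a real
# interpreto class: substitute a concrete BaseConceptInterpretationMethod
# subclass. `texts` and `activations` are assumed defined earlier.
interpretations = explainer.interpret(
    interpretation_method=MyInterpretationMethod,
    concepts_indices=[0, 42, 123],
    inputs=texts,                    # list[str] the concepts are grounded in
    latent_activations=activations,  # activations matching those inputs
)
for index, interpretation in interpretations.items():
    print(index, interpretation)
```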
input_concept_attribution

input_concept_attribution(inputs, concept, attribution_method, **attribution_kwargs)

Attributes model inputs for a selected concept.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | ModelInputs | The input data, which can be a string, a list of tokens/words/clauses/sentences, or a dataset. | required |
| concept | int | Index identifying the position of the concept of interest (its score in the concept activations). | required |
| attribution_method | type[AttributionExplainer] | The attribution method to obtain importance scores for input elements. | required |

Returns:

| Type | Description |
|---|---|
| list[float] | A list of attribution scores for each input. |
Source code in interpreto/concepts/base.py
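A call-pattern sketch with a hypothetical attribution method; substitute a concrete AttributionExplainer subclass from interpreto.

```python
# `MyAttributionMethod` is a hypothetical placeholder for a concrete
# AttributionExplainer subclass.
scores = explainer.input_concept_attribution(
    inputs="The movie was surprisingly good.",
    concept=42,  # index of the concept of interest
    attribution_method=MyAttributionMethod,
)
print(scores)  # one attribution score per input element
```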
concept_output_attribution

concept_output_attribution(inputs, concepts, target, attribution_method, **attribution_kwargs)

Computes the attribution of each concept for the logit of a target output element.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | ModelInputs | An input data-point for the model. | required |
| concepts | Tensor | Concept activation tensor. | required |
| target | int | The target class for which the concept output attribution should be computed. | required |
| attribution_method | type[AttributionExplainer] | The attribution method to obtain importance scores for input elements. | required |

Returns:

| Type | Description |
|---|---|
| list[float] | A list of attribution scores for each concept. |
Source code in interpreto/concepts/base.py
Loss Functions

These functions can be passed as the criterion argument in the fit method of the SAEExplainer class. MSELoss is the default loss function.

interpreto.concepts.methods.SAELossClasses

Bases: Enum

Enumeration of possible loss functions for SAEs. To pass as the criterion parameter of SAEExplainer.fit().

Attributes:

| Name | Type | Description |
|---|---|---|
| MSE | type[SAELoss] | Mean Squared Error loss. |
| DeadNeuronsReanimation | type[SAELoss] | Loss function promoting the reanimation of dead neurons. |
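A sketch of swapping the training criterion, assuming the enum members wrap the loss classes themselves, so that `.value` matches the `criterion=MSELoss` default in `fit`.

```python
from interpreto.concepts.methods import SAELossClasses

# Assumption: the enum member's `.value` is the loss class expected by
# `criterion`, mirroring the `criterion=MSELoss` default of `fit`.
logs = explainer.fit(
    activations,
    criterion=SAELossClasses.DeadNeuronsReanimation.value,
)
```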