Sparse Autoencoders (SAEs)¶
Abstract base class¶
interpreto.concepts.methods.SAEExplainer
¶
SAEExplainer(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)
Bases: ConceptAutoEncoderExplainer[SAE], Generic[_SAE_co]
Code: concepts/methods/overcomplete.py
Implementation of a concept explainer using an
overcomplete.sae.SAE variant as concept_model.
Attributes:

| Name | Type | Description |
|---|---|---|
| `model_with_split_points` | `ModelWithSplitPoints` | The model to apply the explanation on. It should have at least one split point on which a concept explainer can be trained. |
| `split_point` | `str \| None` | The split point used to train the `concept_model`. |
| `concept_model` | `SAE` | An Overcomplete SAE variant for concept extraction. |
| `is_fitted` | `bool` | Whether the `concept_model` was fitted on model activations. |
| `has_differentiable_concept_encoder` | `bool` | Whether the `encode_activations` operation is differentiable. |
| `has_differentiable_concept_decoder` | `bool` | Whether the `decode_concepts` operation is differentiable. |
Examples:
>>> import datasets
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from interpreto import ModelWithSplitPoints
>>> from interpreto.concepts import VanillaSAE
>>> from interpreto.concepts.interpretations import TopKInputs
>>> CLS_TOKEN = ModelWithSplitPoints.activation_granularities.CLS_TOKEN
>>> WORD = ModelWithSplitPoints.activation_granularities.WORD
...
>>> dataset = datasets.load_dataset("stanfordnlp/imdb")["train"]["text"][:1000]
>>> repo_id = "Qwen/Qwen3-0.6B"
>>> model = AutoModelForCausalLM.from_pretrained(repo_id, device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained(repo_id)
...
>>> # 1. Split your model in two parts
>>> splitted_model = ModelWithSplitPoints(
...     model, tokenizer=tokenizer, split_points=[5],
... )
...
>>> # 2. Compute a dataset of activations
>>> activations = splitted_model.get_activations(
...     dataset, activation_granularity=WORD
... )
...
>>> # 3. Fit a concept model on the dataset
>>> explainer = VanillaSAE(splitted_model, nb_concepts=100, device="cuda")
>>> explainer.fit(activations, lr=1e-3, nb_epochs=20, batch_size=1024)
...
>>> # 4. Interpret the concepts
>>> interpreter = TopKInputs(
...     concept_explainer=explainer,
...     activation_granularity=WORD,
... )
>>> interpretations = interpreter.interpret(
...     inputs=dataset, latent_activations=activations
... )
...
>>> # Print the interpretations
>>> for concept_id, words in interpretations.items():
...     print(f"Concept {concept_id}: {list(words.keys()) if words else None}")
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model_with_split_points` | `ModelWithSplitPoints` | The model to apply the explanation on. It should have at least one split point on which a concept explainer can be trained. | *required* |
| `nb_concepts` | `int` | Size of the SAE concept space. | *required* |
| `split_point` | `str \| None` | The split point used to train the `concept_model`. | `None` |
| `encoder_module` | `Module \| str \| None` | Encoder module to use to construct the SAE, see the Overcomplete SAE documentation. | `None` |
| `dictionary_params` | `dict \| None` | Dictionary parameters to use to construct the SAE, see the Overcomplete SAE documentation. | `None` |
| `device` | `device \| str` | Device to use for the `concept_model`. | `'cpu'` |
| `**kwargs` | `dict` | Additional keyword arguments to pass to the SAE constructor. | `{}` |
Source code in interpreto/concepts/methods/overcomplete.py
fit
¶
fit(activations, *, use_amp=False, batch_size=1024, criterion=MSELoss, optimizer_class=Adam, optimizer_kwargs={}, scheduler_class=None, scheduler_kwargs={}, lr=0.001, nb_epochs=20, clip_grad=None, monitoring=None, device=None, max_nan_fallbacks=5, overwrite=False)
Fit an Overcomplete SAE model on the given activations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `activations` | `Tensor \| dict[str, Tensor]` | The activations used for fitting the `concept_model`. | *required* |
| `use_amp` | `bool` | Whether to use automatic mixed precision for fitting. | `False` |
| `batch_size` | `int` | Batch size for the training of the `concept_model`. | `1024` |
| `criterion` | `SAELoss` | Loss criterion for the training of the `concept_model`. | `MSELoss` |
| `optimizer_class` | `type[Optimizer]` | Optimizer for the training of the `concept_model`. | `Adam` |
| `optimizer_kwargs` | `dict` | Keyword arguments to pass to the optimizer. | `{}` |
| `scheduler_class` | `type[LRScheduler] \| None` | Learning rate scheduler for the training of the `concept_model`. | `None` |
| `scheduler_kwargs` | `dict` | Keyword arguments to pass to the scheduler. | `{}` |
| `lr` | `float` | Learning rate for the training of the `concept_model`. | `0.001` |
| `nb_epochs` | `int` | Number of epochs for the training of the `concept_model`. | `20` |
| `clip_grad` | `float \| None` | Gradient clipping value for the training of the `concept_model`. | `None` |
| `monitoring` | `int \| None` | Monitoring frequency for the training of the `concept_model`. | `None` |
| `device` | `device \| str` | Device to use for the training of the `concept_model`. | `None` |
| `max_nan_fallbacks` | `int \| None` | Maximum number of fallbacks to use when NaNs are encountered during training. Ignored if `use_amp` is `False`. | `5` |
| `overwrite` | `bool` | Whether to overwrite the current model if it has already been fitted. | `False` |
Returns:

| Type | Description |
|---|---|
| `dict` | A dictionary with training history logs. |
Source code in interpreto/concepts/methods/overcomplete.py
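For reference, here is a hedged sketch of a customized training run, reusing the `explainer` and `activations` objects from the class-level example above; the optimizer and clipping settings shown are illustrative choices, not defaults:

>>> from torch.optim import Adam
>>> history = explainer.fit(
...     activations,
...     optimizer_class=Adam,
...     optimizer_kwargs={"weight_decay": 1e-5},  # illustrative extra optimizer setting
...     lr=1e-3,
...     nb_epochs=20,
...     batch_size=1024,
...     clip_grad=1.0,  # clip gradients during training
...     device="cuda",
... )
>>> # `history` is a dictionary of training history logs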
encode_activations
¶
encode_activations(activations)
Encode the given activations using the concept_model encoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `activations` | `Tensor` | The activations to encode. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | The encoded concept activations. |
Source code in interpreto/concepts/methods/overcomplete.py
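A short sketch, assuming `latent_activations` is an `(n, d_model)` activation tensor extracted at the explainer's split point (for example, one entry of the dictionary returned by `ModelWithSplitPoints.get_activations`):

>>> concept_activations = explainer.encode_activations(latent_activations)
>>> # concept_activations has one column per concept, i.e. shape (n, nb_concepts)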
decode_concepts
¶
decode_concepts(concepts)
Decode the given concepts using the concept_model decoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `concepts` | `Tensor` | The concepts to decode. | *required* |

Returns:

| Type | Description |
|---|---|
| `Tensor` | The decoded concept activations. |
Source code in interpreto/concepts/methods/overcomplete.py
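Continuing the encoding sketch above, decoding maps concept activations back to the model's latent space, yielding the SAE reconstruction of the original activations:

>>> reconstructed = explainer.decode_concepts(concept_activations)
>>> # reconstructed has the same shape as the activations passed to encode_activations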
get_dictionary
¶
Get the dictionary learned by the fitted concept_model.
Returns:

| Type | Description |
|---|---|
| `Tensor` | The dictionary learned by the fitted `concept_model`, with one row per concept. |
Source code in interpreto/concepts/base.py
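A minimal sketch; the orientation of the returned tensor (one concept direction per row) follows the usual Overcomplete dictionary convention and is an assumption here:

>>> dictionary = explainer.get_dictionary()
>>> # assumed shape: (nb_concepts, d_model), one dictionary row per concept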
interpret
¶
Deprecated API for concept interpretation.
Interpretation methods should now be instantiated directly with the fitted concept explainer. For example:
TopKInputs(concept_explainer).interpret(inputs, latent_activations)
This method is kept only for backwards compatibility and will always
raise a NotImplementedError.
Source code in interpreto/concepts/base.py
concept_output_gradient
¶
concept_output_gradient(inputs, targets=None, split_point=None, activation_granularity=TOKEN, aggregation_strategy=MEAN, concepts_x_gradients=True, normalization=True, tqdm_bar=False, batch_size=None)
Compute the gradients of the predictions with respect to the concepts.
To clarify what this function does, let us fix some notation. Suppose the initial model was split such that \(f = g \circ h\). The concept model was therefore fitted on \(A = h(X)\), with \(X\) a dataset of samples. The resulting concept model encoder and decoder are denoted \(t\) and \(t^{-1}\); \(t\) can be seen as a projection from the latent space to the concept space. Hence, the function going from the inputs to the concepts is \(f_{ic} = t \circ h\), and the function going from the concepts to the outputs is \(f_{co} = g \circ t^{-1}\).
Given a set of samples \(X\) and the functions \((h, t, t^{-1}, g)\), this method first computes \(C = t(A) = t \circ h(X)\), then returns \(\nabla f_{co}(C)\).
In practice, all computations are done by ModelWithSplitPoints._get_concept_output_gradients,
which relies on NNsight. The current method only forwards \(t\) and \(t^{-1}\),
i.e. the self.encode_activations and self.decode_concepts methods, respectively.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `list[str] \| Tensor \| BatchEncoding` | The input data, either a list of samples, the tokenized input, or a batch of samples. | *required* |
| `targets` | `list[int] \| None` | Specify which outputs of the model should be used to compute the gradients. Note that \(f_{co}\) often has several outputs; by default, gradients are computed for each output. | `None` |
| `split_point` | `str \| None` | The split point used to train the `concept_model`. | `None` |
| `activation_granularity` | `ActivationGranularity` | The granularity of the activations to use for the attribution. It is highly recommended to use the same granularity as the one used to compute the fitting activations. | `TOKEN` |
| `aggregation_strategy` | `GranularityAggregationStrategy` | Strategy to aggregate token activations into larger input granularities. Applied for granularities coarser than tokens. | `MEAN` |
| `concepts_x_gradients` | `bool` | Whether the resulting gradients should be multiplied by the concept activations. `True` by default (similarly to attributions), because of its mathematical properties; the output is then \(C * \nabla f_{co}(C)\). | `True` |
| `normalization` | `bool` | Whether to normalize the gradients. Gradients are normalized over the concept (c) and sequence-length (g) dimensions, so that for a given sample-target-granularity triple, the sum of the absolute values of the gradients equals 1. The granular elements depend on the `activation_granularity` argument. | `True` |
| `tqdm_bar` | `bool` | Whether to display a progress bar. | `False` |
| `batch_size` | `int \| None` | Batch size for the model. It may differ from the batch size used for fitting. | `None` |
Returns:

| Type | Description |
|---|---|
| `list[Float[Tensor, 't g c']]` | The gradients of the model output with respect to the concept activations. The list length corresponds to the number of inputs. Each tensor has shape (t, g, c), with t the target dimension, g the number of granularity elements in one input, and c the number of concepts. |
Source code in interpreto/concepts/base.py
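An illustrative sketch, reusing `explainer`, `dataset`, and `WORD` from the class-level example above (the slice size is arbitrary):

>>> gradients = explainer.concept_output_gradient(
...     inputs=dataset[:10],
...     activation_granularity=WORD,  # match the granularity of the fitting activations
...     tqdm_bar=True,
... )
>>> len(gradients)  # one (t, g, c) tensor per input
10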
List of available SAEs¶
interpreto.concepts.methods.BatchTopKSAEConcepts
¶
BatchTopKSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)
Bases: SAEExplainer[BatchTopKSAE]
Code: concepts/methods/overcomplete.py
ConceptAutoEncoderExplainer with the BatchTopK SAE from Bussmann et al. (2024)1 as concept model.
BatchTopK SAE implementation from overcomplete.sae.BatchTopKSAE class.
1. Bussmann, B., Leask, P., Nanda, N. BatchTopK Sparse Autoencoders. arXiv preprint, 2024. ↩
interpreto.concepts.methods.JumpReLUSAEConcepts
¶
JumpReLUSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)
Bases: SAEExplainer[JumpSAE]
Code: concepts/methods/overcomplete.py
ConceptAutoEncoderExplainer with the JumpReLU SAE from Rajamanoharan et al. (2024)1 as concept model.
JumpReLU SAE implementation from the overcomplete.sae.JumpSAE class.
1. Rajamanoharan, S. et al. Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders. arXiv preprint, 2024. ↩
interpreto.concepts.methods.MpSAEConcepts
¶
MpSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)
Bases: SAEExplainer[MpSAE]
Code: concepts/methods/overcomplete.py
ConceptAutoEncoderExplainer with the MpSAE from Costa et al. (2025)1 as concept model.
Matching Pursuit SAE implementation from overcomplete.sae.MpSAE class.
1. Costa, V., Fel, T., Lubana, E. S., Tolooshams, B., Ba, D. From Flat to Hierarchical: Extracting Sparse Representations with Matching Pursuit. arXiv preprint arXiv:2506.03093, 2025. ↩
interpreto.concepts.methods.TopKSAEConcepts
¶
TopKSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)
Bases: SAEExplainer[TopKSAE]
Code: concepts/methods/overcomplete.py
ConceptAutoEncoderExplainer with the TopK SAE from Gao et al. (2024)1 as concept model.
TopK SAE implementation from overcomplete.sae.TopKSAE class.
1. Gao, L. et al. Scaling and evaluating sparse autoencoders. The Thirteenth International Conference on Learning Representations, 2025. ↩
interpreto.concepts.methods.VanillaSAEConcepts
¶
VanillaSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)
Bases: SAEExplainer[SAE]
Code: concepts/methods/overcomplete.py
ConceptAutoEncoderExplainer with the Vanilla SAE from Cunningham et al. (2023)1 and Bricken et al. (2023)2 as concept model.
Vanilla SAE implementation from overcomplete.sae.SAE class.
1. Huben, R., Cunningham, H., Smith, L. R., Ewart, A., Sharkey, L. Sparse Autoencoders Find Highly Interpretable Features in Language Models. The Twelfth International Conference on Learning Representations, 2024. ↩
2. Bricken, T. et al. Towards Monosemanticity: Decomposing Language Models With Dictionary Learning. Transformer Circuits Thread, 2023. ↩
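All of the variants above share the constructor signature of SAEExplainer, so swapping one SAE for another is a one-line change. A minimal sketch, reusing splitted_model and activations from the class-level example:

>>> from interpreto.concepts.methods import TopKSAEConcepts
>>> explainer = TopKSAEConcepts(splitted_model, nb_concepts=100, device="cuda")
>>> explainer.fit(activations)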
Loss Functions¶
These functions can be passed as the criterion argument in the fit method of the SAEExplainer class. MSELoss is the default loss function.
interpreto.concepts.methods.SAELossClasses
¶
Bases: Enum
Enumeration of possible loss functions for SAEs.
To pass as the criterion parameter of SAEExplainer.fit().
Attributes:

| Name | Type | Description |
|---|---|---|
| `MSE` | `type[SAELoss]` | Mean Squared Error loss. |
| `DeadNeuronsReanimation` | `type[SAELoss]` | Loss function promoting the reanimation of dead neurons. |
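For instance, a minimal sketch of fitting with the dead-neurons reanimation loss, reusing explainer and activations from the class-level example; the sketch passes the enum member directly, as suggested above (if fit expects the underlying loss class instead, pass SAELossClasses.DeadNeuronsReanimation.value):

>>> from interpreto.concepts.methods import SAELossClasses
>>> explainer.fit(
...     activations,
...     criterion=SAELossClasses.DeadNeuronsReanimation,  # promotes reanimation of dead neurons
... )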