
Sparse Autoencoders (SAEs)

List of available SAEs

interpreto.concepts.methods.VanillaSAEConcepts

VanillaSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)

Bases: SAEExplainer[SAE]

Code: concepts/methods/overcomplete.py

ConceptAutoEncoderExplainer with the Vanilla SAE from Cunningham et al. (2023) [1] and Bricken et al. (2023) [2] as concept model.

Vanilla SAE implementation from overcomplete.sae.SAE class.


  1. Huben, R., Cunningham, H., Smith, L. R., Ewart, A., Sharkey, L. Sparse Autoencoders Find Highly Interpretable Features in Language Models. The Twelfth International Conference on Learning Representations, 2024. 

  2. Bricken, T. et al., Towards Monosemanticity: Decomposing Language Models With Dictionary Learning, Transformer Circuits Thread, 2023. 
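For orientation, here is a minimal usage sketch (not taken from the library's own examples): the model id, the split-point path, and the ModelWithSplitPoints constructor arguments are placeholder assumptions to adapt to your setup; only VanillaSAEConcepts and its documented parameters come from this page.

```python
# Hedged sketch: building a Vanilla SAE concept explainer.
# The model id and split-point path below are placeholders, and the
# ModelWithSplitPoints signature is assumed, not quoted from the docs.
from interpreto import ModelWithSplitPoints  # assumed import path
from interpreto.concepts.methods import VanillaSAEConcepts

model = ModelWithSplitPoints(
    "bert-base-uncased",                           # any HF model id
    split_points=["bert.encoder.layer.6.output"],  # hypothetical split point
)

explainer = VanillaSAEConcepts(
    model,
    nb_concepts=2048,  # size of the SAE concept space
    device="cpu",
)
```

The TopK, BatchTopK, and JumpReLU variants below share the same constructor signature, so the sketch applies to them unchanged.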

interpreto.concepts.methods.TopKSAEConcepts

TopKSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)

Bases: SAEExplainer[TopKSAE]

Code: concepts/methods/overcomplete.py

ConceptAutoEncoderExplainer with the TopK SAE from Gao et al. (2024) [1] as concept model.

TopK SAE implementation from overcomplete.sae.TopKSAE class.


  1. Gao, L. et al., Scaling and evaluating sparse autoencoders. The Thirteenth International Conference on Learning Representations, 2025. 

interpreto.concepts.methods.BatchTopKSAEConcepts

BatchTopKSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)

Bases: SAEExplainer[BatchTopKSAE]

Code: concepts/methods/overcomplete.py

ConceptAutoEncoderExplainer with the BatchTopK SAE from Bussmann et al. (2024) [1] as concept model.

BatchTopK SAE implementation from overcomplete.sae.BatchTopKSAE class.


  1. Bussmann, B., Leask, P., Nanda, N. BatchTopK Sparse Autoencoders. arXiv preprint, 2024. 

interpreto.concepts.methods.JumpReLUSAEConcepts

JumpReLUSAEConcepts(model_with_split_points, *, nb_concepts, split_point=None, encoder_module=None, dictionary_params=None, device='cpu', **kwargs)

Bases: SAEExplainer[JumpSAE]

Code: concepts/methods/overcomplete.py

ConceptAutoEncoderExplainer with the JumpReLU SAE from Rajamanoharan et al. (2024) [1] as concept model.

JumpReLU SAE implementation from overcomplete.sae.JumpSAE class.


  1. Rajamanoharan, S. et al., Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders. arXiv preprint, 2024. 

Abstract base class

interpreto.concepts.methods.SAEExplainer

Bases: ConceptAutoEncoderExplainer[SAE], Generic[_SAE_co]

Code: concepts/methods/overcomplete.py

Implementation of a concept explainer using an overcomplete.sae.SAE variant as concept_model.

Attributes:

- model_with_split_points (ModelWithSplitPoints): The model to apply the explanation on. It should have at least one split point on which concept_model can be fitted.
- split_point (str | None): The split point used to train the concept_model. Default: None, set only when the concept explainer is fitted.
- concept_model (SAE): An Overcomplete SAE variant for concept extraction.
- is_fitted (bool): Whether the concept_model was fit on model activations.
- has_differentiable_concept_encoder (bool): Whether the encode_activations operation is differentiable.
- has_differentiable_concept_decoder (bool): Whether the decode_concepts operation is differentiable.

Parameters:

- model_with_split_points (ModelWithSplitPoints): The model to apply the explanation on. It should have at least one split point on which a concept explainer can be trained. Required.
- nb_concepts (int): Size of the SAE concept space. Required.
- split_point (str | None): The split point used to train the concept_model. If None, tries to use the split point of model_with_split_points if a single one is defined. Default: None.
- encoder_module (Module | str | None): Encoder module to use to construct the SAE, see the Overcomplete SAE documentation. Default: None.
- dictionary_params (dict | None): Dictionary parameters to use to construct the SAE, see the Overcomplete SAE documentation. Default: None.
- device (device | str): Device to use for the concept_module. Default: 'cpu'.
- **kwargs (dict): Additional keyword arguments to pass to the concept_module. See the Overcomplete documentation of the provided concept_model_class for more details. Default: {}.

Source code in interpreto/concepts/methods/overcomplete.py
def __init__(
    self,
    model_with_split_points: ModelWithSplitPoints,
    *,
    nb_concepts: int,
    split_point: str | None = None,
    encoder_module: nn.Module | str | None = None,
    dictionary_params: dict | None = None,
    device: str = "cpu",
    **kwargs,
):
    """
    Initialize the concept bottleneck explainer based on the Overcomplete SAE framework.

    Args:
        model_with_split_points (ModelWithSplitPoints): The model to apply the explanation on.
            It should have at least one split point on which a concept explainer can be trained.
        nb_concepts (int): Size of the SAE concept space.
        split_point (str | None): The split point used to train the `concept_model`. If None, tries to use the
            split point of `model_with_split_points` if a single one is defined.
        encoder_module (nn.Module | str | None): Encoder module to use to construct the SAE, see [Overcomplete SAE documentation](https://kempnerinstitute.github.io/overcomplete/saes/vanilla/).
        dictionary_params (dict | None): Dictionary parameters to use to construct the SAE, see [Overcomplete SAE documentation](https://kempnerinstitute.github.io/overcomplete/saes/vanilla/).
        device (torch.device | str): Device to use for the `concept_module`.
        **kwargs (dict): Additional keyword arguments to pass to the `concept_module`.
            See the Overcomplete documentation of the provided `concept_model_class` for more details.
    """
    if not issubclass(self.concept_model_class, oc_sae.SAE):
        raise ValueError(
            "ConceptEncoderDecoder must be a subclass of `overcomplete.sae.SAE`.\n"
            "Use `interpreto.concepts.methods.SAEExplainerClasses` to get the list of available SAE methods."
        )
    self.model_with_split_points = model_with_split_points
    self.split_point: str = split_point  # type: ignore

    # TODO: this will be replaced with a scan and a better way to select how to pick activations based on model class
    shapes = self.model_with_split_points.get_latent_shape()
    concept_model = self.concept_model_class(
        input_shape=shapes[self.split_point][-1],
        nb_concepts=nb_concepts,
        encoder_module=encoder_module,
        dictionary_params=dictionary_params,
        device=device,
        **kwargs,
    )
    super().__init__(model_with_split_points, concept_model, self.split_point)
    self.has_differentiable_concept_encoder = True
    self.has_differentiable_concept_decoder = True

fit

fit(activations, *, use_amp=False, batch_size=1024, criterion=MSELoss, optimizer_class=Adam, scheduler_class=None, lr=0.001, nb_epochs=20, clip_grad=None, monitoring=None, device='cpu', max_nan_fallbacks=5, overwrite=False)

Fit an Overcomplete SAE model on the given activations.

Parameters:

- activations (Tensor | dict[str, Tensor]): The activations used for fitting the concept_model. If a dictionary is provided, the activation corresponding to split_point will be used. Required.
- use_amp (bool): Whether to use automatic mixed precision for fitting. Default: False.
- batch_size (int): Batch size of the dataloader used to iterate over the activations during training. Default: 1024.
- criterion (type[SAELoss]): Loss criterion for the training of the concept_model. Default: MSELoss.
- optimizer_class (type[Optimizer]): Optimizer for the training of the concept_model. Default: Adam.
- scheduler_class (type[LRScheduler] | None): Learning rate scheduler for the training of the concept_model. Default: None.
- lr (float): Learning rate for the training of the concept_model. Default: 0.001.
- nb_epochs (int): Number of epochs for the training of the concept_model. Default: 20.
- clip_grad (float | None): Gradient clipping value for the training of the concept_model. Default: None.
- monitoring (int | None): Monitoring frequency for the training of the concept_model. Default: None.
- device (device | str): Device to use for the training of the concept_model. Default: 'cpu'.
- max_nan_fallbacks (int | None): Maximum number of fallbacks to use when NaNs are encountered during training. Ignored if use_amp is False. Default: 5.
- overwrite (bool): Whether to overwrite the current model if it has already been fitted. Default: False.

Returns:

- dict: A dictionary with training history logs.

Source code in interpreto/concepts/methods/overcomplete.py
def fit(
    self,
    activations: LatentActivations | dict[str, LatentActivations],
    *,
    use_amp: bool = False,
    batch_size: int = 1024,
    criterion: type[SAELoss] = MSELoss,
    optimizer_class: type[torch.optim.Optimizer] = torch.optim.Adam,
    scheduler_class: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
    lr: float = 1e-3,
    nb_epochs: int = 20,
    clip_grad: float | None = None,
    monitoring: int | None = None,
    device: torch.device | str = "cpu",
    max_nan_fallbacks: int | None = 5,
    overwrite: bool = False,
) -> dict:
    """Fit an Overcomplete SAE model on the given activations.

    Args:
        activations (torch.Tensor | dict[str, torch.Tensor]): The activations used for fitting the `concept_model`.
            If a dictionary is provided, the activation corresponding to `split_point` will be used.
        use_amp (bool): Whether to use automatic mixed precision for fitting.
        batch_size (int): Batch size of the dataloader used to iterate over the activations during training.
        criterion (interpreto.concepts.SAELoss): Loss criterion for the training of the `concept_model`.
        optimizer_class (type[torch.optim.Optimizer]): Optimizer for the training of the `concept_model`.
        scheduler_class (type[torch.optim.lr_scheduler.LRScheduler] | None): Learning rate scheduler for the
            training of the `concept_model`.
        lr (float): Learning rate for the training of the `concept_model`.
        nb_epochs (int): Number of epochs for the training of the `concept_model`.
        clip_grad (float | None): Gradient clipping value for the training of the `concept_model`.
        monitoring (int | None): Monitoring frequency for the training of the `concept_model`.
        device (torch.device | str): Device to use for the training of the `concept_model`.
        max_nan_fallbacks (int | None): Maximum number of fallbacks to use when NaNs are encountered during
            training. Ignored if use_amp is False.
        overwrite (bool): Whether to overwrite the current model if it has already been fitted.
            Default: False.

    Returns:
        A dictionary with training history logs.
    """
    split_activations = self._prepare_fit(activations, overwrite=overwrite)
    dataloader = DataLoader(TensorDataset(split_activations.detach()), batch_size=batch_size, shuffle=True)
    optimizer_kwargs = {"lr": lr}
    optimizer = optimizer_class(self.concept_model.parameters(), **optimizer_kwargs)  # type: ignore
    train_params = {
        "model": self.concept_model,
        "dataloader": dataloader,
        "criterion": criterion(),
        "optimizer": optimizer,
        "nb_epochs": nb_epochs,
        "clip_grad": clip_grad,
        "monitoring": monitoring,
        "device": device,
    }
    if scheduler_class is not None:
        scheduler = scheduler_class(optimizer)
        train_params["scheduler"] = scheduler

    if use_amp:
        train_method = oc_sae.train.train_sae_amp
        train_params["max_nan_fallbacks"] = max_nan_fallbacks
    else:
        train_method = oc_sae.train_sae
    log = train_method(**train_params)
    self.concept_model.fitted = True
    return log
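Putting the signature together, a hedged sketch of a typical call (the get_activations helper and the texts corpus are assumptions about the surrounding workflow, not part of this reference):

```python
# Hedged sketch: fitting the SAE on cached latent activations.
# `get_activations` is a hypothetical way to collect activations at the
# split point; substitute whatever your pipeline uses.
texts = ["The movie was great.", "The plot made no sense."]
activations = model.get_activations(texts)  # hypothetical helper

logs = explainer.fit(
    activations,  # if a dict, the split_point entry is selected automatically
    batch_size=512,
    lr=1e-3,
    nb_epochs=20,
    device="cpu",
)
print(logs)  # training history logs
```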

encode_activations

encode_activations(activations)

Encode the given activations using the concept_model encoder.

Parameters:

- activations (Tensor): The activations to encode. Required.

Returns:

- Tensor: The encoded concept activations.

Source code in interpreto/concepts/methods/overcomplete.py
@check_fitted
def encode_activations(self, activations: LatentActivations) -> torch.Tensor:  # ConceptsActivations
    """Encode the given activations using the `concept_model` encoder.

    Args:
        activations (torch.Tensor): The activations to encode.

    Returns:
        The encoded concept activations.
    """
    # SAEs.encode returns both codes (concepts activations) and pre_codes (before relu)
    _, codes = super().encode_activations(activations.to(self.device))
    return codes

decode_concepts

decode_concepts(concepts)

Decode the given concepts using the concept_model decoder.

Parameters:

- concepts (Tensor): The concepts to decode. Required.

Returns:

- Tensor: The decoded concept activations.

Source code in interpreto/concepts/methods/overcomplete.py
@check_fitted
def decode_concepts(self, concepts: torch.Tensor) -> torch.Tensor:
    """Decode the given concepts using the `concept_model` decoder.

    Args:
        concepts (torch.Tensor): The concepts to decode.

    Returns:
        The decoded concept activations.
    """
    return self.concept_model.decode(concepts.to(self.device))  # type: ignore
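The two methods above form an encode/decode round trip. A hedged sketch, assuming a batch of 32 activations whose width matches the SAE input shape (768 here is a placeholder):

```python
# Hedged sketch: encode latent activations into concepts, then reconstruct.
import torch

latents = torch.randn(32, 768)  # stand-in; width must match the split point
concepts = explainer.encode_activations(latents)      # (32, nb_concepts)
reconstruction = explainer.decode_concepts(concepts)  # (32, 768)
mse = torch.nn.functional.mse_loss(reconstruction, latents)
```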

get_dictionary

get_dictionary()

Get the dictionary learned by the fitted concept_model.

Returns:

- Tensor: A torch.Tensor containing the learned dictionary.

Source code in interpreto/concepts/base.py
@check_fitted
def get_dictionary(self) -> torch.Tensor:  # TODO: add this to tests
    """Get the dictionary learned by the fitted `concept_model`.

    Returns:
        torch.Tensor: A `torch.Tensor` containing the learned dictionary.
    """
    return self.concept_model.get_dictionary()  # type: ignore
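Each row of the dictionary can be read as a concept direction in the latent space. A brief sketch; the (nb_concepts, hidden_size) shape is an assumption based on the usual SAE convention, not stated by this page:

```python
# Hedged sketch: inspecting the learned dictionary.
D = explainer.get_dictionary()  # assumed shape: (nb_concepts, hidden_size)
direction = D[42]               # direction of concept 42 in the latent space
norms = D.norm(dim=-1)          # per-concept dictionary norms
```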

interpret

interpret(interpretation_method, concepts_indices, inputs=None, latent_activations=None, concepts_activations=None, **kwargs)

Interpret the concept dimensions of the latent space in a human-readable format. The interpretation is a mapping between concept indices and an object that helps interpret them: a label, a description, examples, etc.

Parameters:

- interpretation_method (type[BaseConceptInterpretationMethod]): The interpretation method to use to interpret the concepts. Required.
- concepts_indices (int | list[int] | Literal['all']): The indices of the concepts to interpret. If "all", all concepts are interpreted. Required.
- inputs (list[str] | None): The inputs to use for the interpretation. Necessary if the source is not VOCABULARY, as examples are extracted from the inputs. Default: None.
- latent_activations (LatentActivations | dict[str, LatentActivations] | None): The latent activations to use for the interpretation. Necessary if the source is LATENT_ACTIVATIONS. Otherwise, it is computed from the inputs or ignored if the source is CONCEPT_ACTIVATIONS. Default: None.
- concepts_activations (ConceptsActivations | None): The concepts activations to use for the interpretation. Necessary if the source is not CONCEPT_ACTIVATIONS. Otherwise, it is computed from the latent activations. Default: None.
- **kwargs: Additional keyword arguments to pass to the interpretation method. Default: {}.

Returns:

- Mapping[int, Any]: A mapping between the concept indices and the interpretation of the concepts.

Source code in interpreto/concepts/base.py
@check_fitted
def interpret(
    self,
    interpretation_method: type[BaseConceptInterpretationMethod],
    concepts_indices: int | list[int] | Literal["all"],
    inputs: list[str] | None = None,
    latent_activations: dict[str, LatentActivations] | LatentActivations | None = None,
    concepts_activations: ConceptsActivations | None = None,
    **kwargs,
) -> Mapping[int, Any]:
    """
    Interpret the concepts dimensions in the latent space into a human-readable format.
    The interpretation is a mapping between the concepts indices and an object allowing to interpret them.
    It can be a label, a description, examples, etc.

    Args:
        interpretation_method: The interpretation method to use to interpret the concepts.
        concepts_indices (int | list[int] | Literal["all"]): The indices of the concepts to interpret.
            If "all", all concepts are interpreted.
        inputs (list[str] | None): The inputs to use for the interpretation.
            Necessary if the source is not `VOCABULARY`, as examples are extracted from the inputs.
        latent_activations (LatentActivations | dict[str, LatentActivations] | None): The latent activations to use for the interpretation.
            Necessary if the source is `LATENT_ACTIVATIONS`.
            Otherwise, it is computed from the inputs or ignored if the source is `CONCEPT_ACTIVATIONS`.
        concepts_activations (ConceptsActivations | None): The concepts activations to use for the interpretation.
            Necessary if the source is not `CONCEPT_ACTIVATIONS`. Otherwise, it is computed from the latent activations.
        **kwargs: Additional keyword arguments to pass to the interpretation method.

    Returns:
        Mapping[int, Any]: A mapping between the concepts indices and the interpretation of the concepts.
    """
    if concepts_indices == "all":
        concepts_indices = list(range(self.concept_model.nb_concepts))

    # verify
    if latent_activations is not None:
        split_latent_activations = self._sanitize_activations(latent_activations)
    else:
        split_latent_activations = None

    # initialize the interpretation method
    method = interpretation_method(
        model_with_split_points=self.model_with_split_points,
        split_point=self.split_point,
        concept_model=self.concept_model,
        **kwargs,
    )

    # compute the interpretation from inputs and activations
    return method.interpret(
        concepts_indices=concepts_indices,
        inputs=inputs,
        latent_activations=split_latent_activations,
        concepts_activations=concepts_activations,
    )
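A hedged sketch of a call; SomeInterpretationMethod is a placeholder name for whichever BaseConceptInterpretationMethod subclass your interpreto version provides, and texts stands in for a real corpus:

```python
# Hedged sketch: interpreting three concepts from input examples.
# SomeInterpretationMethod is a placeholder, not a real class name.
interpretations = explainer.interpret(
    interpretation_method=SomeInterpretationMethod,
    concepts_indices=[0, 5, 42],
    inputs=texts,
)
for index, interpretation in interpretations.items():
    print(index, interpretation)
```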

input_concept_attribution

input_concept_attribution(inputs, concept, attribution_method, **attribution_kwargs)

Attributes model inputs for a selected concept.

Parameters:

- inputs (ModelInputs): The input data, which can be a string, a list of tokens/words/clauses/sentences, or a dataset. Required.
- concept (int): Index identifying the position of the concept of interest (score in the ConceptsActivations tensor) for which relevant input elements should be retrieved. Required.
- attribution_method (type[AttributionExplainer]): The attribution method to obtain importance scores for input elements. Required.

Returns:

- list[float]: A list of attribution scores for each input.

Source code in interpreto/concepts/base.py
@check_fitted
def input_concept_attribution(
    self,
    inputs: ModelInputs,
    concept: int,
    attribution_method: type[AttributionExplainer],
    **attribution_kwargs,
) -> list[float]:
    """Attributes model inputs for a selected concept.

    Args:
        inputs (ModelInputs): The input data, which can be a string, a list of tokens/words/clauses/sentences
            or a dataset.
        concept (int): Index identifying the position of the concept of interest (score in the
            `ConceptsActivations` tensor) for which relevant input elements should be retrieved.
        attribution_method: The attribution method to obtain importance scores for input elements.

    Returns:
        A list of attribution scores for each input.
    """
    raise NotImplementedError("Input-to-concept attribution method is not implemented yet.")

concept_output_attribution

concept_output_attribution(inputs, concepts, target, attribution_method, **attribution_kwargs)

Computes the attribution of each concept for the logit of a target output element.

Parameters:

- inputs (ModelInputs): An input data-point for the model. Required.
- concepts (Tensor): Concept activation tensor. Required.
- target (int): The target class for which the concept output attribution should be computed. Required.
- attribution_method (type[AttributionExplainer]): The attribution method to obtain importance scores for input elements. Required.

Returns:

- list[float]: A list of attribution scores for each concept.

Source code in interpreto/concepts/base.py
@check_fitted
def concept_output_attribution(
    self,
    inputs: ModelInputs,
    concepts: ConceptsActivations,
    target: int,
    attribution_method: type[AttributionExplainer],
    **attribution_kwargs,
) -> list[float]:
    """Computes the attribution of each concept for the logit of a target output element.

    Args:
        inputs (ModelInputs): An input data-point for the model.
        concepts (torch.Tensor): Concept activation tensor.
        target (int): The target class for which the concept output attribution should be computed.
        attribution_method: The attribution method to obtain importance scores for input elements.

    Returns:
        A list of attribution scores for each concept.
    """
    raise NotImplementedError("Concept-to-output attribution method is not implemented yet.")

Loss Functions

These loss classes can be passed as the criterion argument of the fit method of the SAEExplainer class; MSELoss is the default.

interpreto.concepts.methods.SAELossClasses

Bases: Enum

Enumeration of possible loss functions for SAEs.

To pass as the criterion parameter of SAEExplainer.fit().

Attributes:

- MSE (type[SAELoss]): Mean Squared Error loss.
- DeadNeuronsReanimation (type[SAELoss]): Loss function promoting the reanimation of dead neurons.
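A hedged sketch of selecting a loss through the enum; since the members hold loss classes (type[SAELoss]) and fit expects a class as criterion, .value is passed here (an assumption about how the enum is meant to be unwrapped):

```python
# Hedged sketch: training with the dead-neurons-reanimation loss.
from interpreto.concepts.methods import SAELossClasses

logs = explainer.fit(
    activations,  # same activations as in the fit example above
    criterion=SAELossClasses.DeadNeuronsReanimation.value,  # a type[SAELoss]
)
```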