
TopKInputs or MaxAct

Generalization of the maximally activating inputs method used in Towards Monosemanticity: Decomposing Language Models With Dictionary Learning by Bricken et al. (2023)

interpreto.concepts.interpretations.TopKInputs

TopKInputs(*, model_with_split_points, concept_model, activation_granularity=WORD, source, split_point=None, k=5)

Bases: BaseConceptInterpretationMethod

Source code in interpreto/concepts/interpretations/topk_inputs.py

Implementation of the Top-K Inputs concept interpretation method, also called MaxAct. It associates each concept with the inputs that activate it the most. This is arguably the most natural way to interpret a concept, and several papers use it without describing it explicitly. We nonetheless reference Bricken et al. (2023) 1 from Anthropic and their post on Transformer Circuits.


  1. Trenton Bricken, Adly Templeton, Joshua Batson, Brian Chen, Adam Jermyn, Tom Conerly, Nicholas L Turner, Cem Anil, Carson Denison, Amanda Askell, Robert Lasenby, Yifan Wu, Shauna Kravec, Nicholas Schiefer, Tim Maxwell, Nicholas Joseph, Alex Tamkin, Karina Nguyen, Brayden McLean, Josiah E Burke, Tristan Hume, Shan Carter, Tom Henighan, and Chris Olah. Towards Monosemanticity: Decomposing Language Models With Dictionary Learning. Transformer Circuits, 2023.
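The core idea is simple enough to sketch independently of interpreto's internals: given a matrix of concept activations over a set of granular inputs, keep the k inputs with the highest activation for each concept. The sketch below is purely illustrative (the function and variable names are not interpreto's); it also mirrors the behavior documented further down, where a concept whose activations are all zero is mapped to None.

import torch

def naive_topk_inputs(
    inputs: list[str],                    # one string per granular input (token, word, sentence, ...)
    concepts_activations: torch.Tensor,   # shape (n_inputs, n_concepts)
    k: int = 5,
) -> dict[int, list[str] | None]:
    interpretations: dict[int, list[str] | None] = {}
    n_inputs, n_concepts = concepts_activations.shape
    for concept_idx in range(n_concepts):
        scores = concepts_activations[:, concept_idx]
        if torch.all(scores == 0):
            # dead concept: nothing activates it, so there is nothing to show
            interpretations[concept_idx] = None
            continue
        top = torch.topk(scores, k=min(k, n_inputs))
        interpretations[concept_idx] = [inputs[i] for i in top.indices.tolist()]
    return interpretations

TopKInputs adds on top of this idea the machinery to compute the concept activations from the requested source and granularity.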

Attributes:

model_with_split_points (ModelWithSplitPoints): The model with split points to use for the interpretation.

split_point (str): The split point to use for the interpretation.

concept_model (ConceptModelProtocol): The concept model to use for the interpretation.

activation_granularity (ActivationGranularity): The granularity at which the interpretation is computed. Allowed values are TOKEN, WORD, SENTENCE, and SAMPLE. Ignored when the source is VOCABULARY.

source (InterpretationSources): TopKInputs always needs concept activations and inputs, but depending on which variables you already have, some of them can be reused instead of recomputed. The source specifies which activations the computation starts from; everything downstream is derived from it. Supported sources are listed below (see the sketch after the Examples for how the arguments change with the source):

- `CONCEPTS_ACTIVATIONS`: use this if you already have the concept activations corresponding to the inputs.

- `LATENT_ACTIVATIONS`: in most cases you have already computed latent activations to fit the concept explainer; if the granularity matches, you can reuse them instead of recomputing everything.

- `INPUTS`: activations are computed from the text inputs; you can choose the granularity freely.

- `VOCABULARY`: the tokenizer vocabulary tokens are used as inputs. This forces a `TOKEN` granularity.

k (int): The number of most activating inputs kept for each concept.

Examples:

>>> from datasets import load_dataset
>>> from transformers import AutoModelForMaskedLM
>>> from interpreto import ModelWithSplitPoints
>>> from interpreto.concepts import NeuronsAsConcepts
>>> from interpreto.concepts.interpretations import TopKInputs
>>> # load and split the model
>>> split = "bert.encoder.layer.1.output"
>>> model_with_split_points = ModelWithSplitPoints(
...     "hf-internal-testing/tiny-random-bert",
...     split_points=[split],
...     model_autoclass=AutoModelForMaskedLM,
...     batch_size=4,
... )
>>> # NeuronsAsConcepts does not need to be fitted
>>> concept_model = NeuronsAsConcepts(model_with_split_points=model_with_split_points, split_point=split)
>>> # extracting concept interpretations
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> all_top_k_words = concept_model.interpret(
...     interpretation_method=TopKInputs,
...     activation_granularity=TopKInputs.activation_granularities.WORD,
...     source=TopKInputs.sources.INPUTS,
...     k=2,
...     concepts_indices="all",
...     inputs=dataset,
... )
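The example above starts from the INPUTS source. The following is a rough sketch, not verbatim interpreto documentation, of how the required arguments change with the other sources. It assumes the concept_model and dataset from the example above, plus hypothetical precomputed latent_acts and concept_acts; the exact enum member names (e.g. CONCEPTS_ACTIVATIONS vs CONCEPT_ACTIVATIONS) should be checked against your interpreto version.

# reuse latent activations computed when fitting the concept explainer
# (they must have the same granularity as the inputs)
top_words_from_latents = concept_model.interpret(
    interpretation_method=TopKInputs,
    activation_granularity=TopKInputs.activation_granularities.WORD,
    source=TopKInputs.sources.LATENT_ACTIVATIONS,
    k=2,
    concepts_indices="all",
    inputs=dataset,
    latent_activations=latent_acts,  # assumed precomputed
)

# start directly from concept activations you already have
top_words_from_concepts = concept_model.interpret(
    interpretation_method=TopKInputs,
    activation_granularity=TopKInputs.activation_granularities.WORD,
    source=TopKInputs.sources.CONCEPTS_ACTIVATIONS,
    k=2,
    concepts_indices="all",
    inputs=dataset,
    concepts_activations=concept_acts,  # assumed precomputed
)

# use the tokenizer vocabulary as inputs; no `inputs` needed, TOKEN granularity is forced
top_tokens_from_vocab = concept_model.interpret(
    interpretation_method=TopKInputs,
    source=TopKInputs.sources.VOCABULARY,
    k=2,
    concepts_indices="all",
)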
Source code in interpreto/concepts/interpretations/topk_inputs.py
def __init__(
    self,
    *,
    model_with_split_points: ModelWithSplitPoints,
    concept_model: ConceptModelProtocol,
    activation_granularity: ActivationGranularity = ActivationGranularity.WORD,
    source: InterpretationSources,
    split_point: str | None = None,
    k: int = 5,
):
    super().__init__(
        model_with_split_points=model_with_split_points, concept_model=concept_model, split_point=split_point
    )

    if source not in InterpretationSources:
        raise ValueError(f"The source {source} is not supported. Supported sources: {InterpretationSources}")

    if activation_granularity not in (
        ActivationGranularity.TOKEN,
        ActivationGranularity.WORD,
        ActivationGranularity.SENTENCE,
        ActivationGranularity.SAMPLE,
    ):
        raise ValueError(
            f"The granularity {activation_granularity} is not supported. Supported `activation_granularities`: TOKEN, WORD, SENTENCE, and SAMPLE"
        )

    self.activation_granularity = activation_granularity
    self.source = source
    self.k = k

interpret

Gives the interpretation of the concept dimensions in the latent space in a human-readable format. The interpretation is a mapping between concept indices and a list of inputs that help interpret them. The granularity of input examples is determined by the activation_granularity class attribute.

The returned inputs are the most activating inputs for the concepts.

The required arguments depend on the source class attribute.

If all activations are zero, the corresponding concept interpretation is set to None.

Parameters:

concepts_indices (int | list[int]): The indices of the concepts to interpret. Required.

inputs (list[str] | None): The inputs to use for the interpretation. Necessary if the source is not VOCABULARY, as examples are extracted from the inputs. Defaults to None.

latent_activations (Float[Tensor, 'nl d'] | None): The latent activations to use for the interpretation. Necessary if the source is LATENT_ACTIVATIONS. Otherwise, they are computed from the inputs, or ignored if the source is CONCEPT_ACTIVATIONS. Defaults to None.

concepts_activations (Float[Tensor, 'nl cpt'] | None): The concepts activations to use for the interpretation. Necessary if the source is CONCEPT_ACTIVATIONS. Otherwise, they are computed from the latent activations. Defaults to None.

Returns:

Mapping[int, Any]: The interpretation of the concepts indices, mapping each concept index to its most activating inputs (or None if the concept never activates).

Raises:

ValueError: If the arguments do not correspond to the specified source.
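As a small usage note, the returned mapping can be iterated directly; this sketch assumes all_top_k_words was obtained as in the Examples above and skips concepts whose interpretation is None:

for concept_idx, top_inputs in all_top_k_words.items():
    if top_inputs is None:
        continue  # all activations were zero for this concept
    print(f"concept {concept_idx}: {top_inputs}")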

Source code in interpreto/concepts/interpretations/topk_inputs.py
def interpret(
    self,
    concepts_indices: int | list[int],
    inputs: list[str] | None = None,
    latent_activations: LatentActivations | None = None,
    concepts_activations: ConceptsActivations | None = None,
) -> Mapping[int, Any]:
    """
    Give the interpretation of the concept dimensions in the latent space in a human-readable format.
    The interpretation is a mapping between concept indices and a list of inputs that help interpret them.
    The granularity of input examples is determined by the `activation_granularity` class attribute.

    The returned inputs are the most activating inputs for the concepts.

    The required arguments depend on the `source` class attribute.

    If all activations are zero, the corresponding concept interpretation is set to `None`.

    Args:
        concepts_indices (int | list[int]): The indices of the concepts to interpret.
        inputs (list[str] | None): The inputs to use for the interpretation.
            Necessary if the source is not `VOCABULARY`, as examples are extracted from the inputs.
        latent_activations (Float[torch.Tensor, "nl d"] | None): The latent activations to use for the interpretation.
            Necessary if the source is `LATENT_ACTIVATIONS`.
            Otherwise, it is computed from the inputs or ignored if the source is `CONCEPT_ACTIVATIONS`.
        concepts_activations (Float[torch.Tensor, "nl cpt"] | None): The concepts activations to use for the interpretation.
            Necessary if the source is `CONCEPT_ACTIVATIONS`. Otherwise, they are computed from the latent activations.

    Returns:
        Mapping[int, Any]: The interpretation of the concepts indices.

    Raises:
        ValueError: If the arguments do not correspond to the specified source.
    """
    # compute the concepts activations from the provided source, can also create inputs from the vocabulary
    sure_inputs: list[str]  # Verified by concepts_activations_from_source
    sure_concepts_activations: Float[torch.Tensor, "ng cpt"]  # Verified by concepts_activations_from_source
    sure_inputs, sure_concepts_activations = self._concepts_activations_from_source(
        inputs, latent_activations, concepts_activations
    )

    concepts_indices = self._verify_concepts_indices(
        concepts_activations=sure_concepts_activations, concepts_indices=concepts_indices
    )

    granular_inputs: list[str]  # len: ng, inputs becomes a list of elements extracted from the examples
    granular_inputs = self._get_granular_inputs(sure_inputs)
    if len(granular_inputs) != len(sure_concepts_activations):
        if latent_activations is not None and len(granular_inputs) != len(latent_activations):
            raise ValueError(
                f"The length of the granular inputs does not match the number of provided latent activations: "
                f"{len(granular_inputs)} != {len(latent_activations)}. "
                "If you provide latent activations, make sure they have the same granularity as the inputs."
            )
        if concepts_activations is not None and len(granular_inputs) != len(concepts_activations):
            raise ValueError(
                f"The length of the granular inputs does not match the number of provided concepts activations: "
                f"{len(granular_inputs)} != {len(concepts_activations)}. "
                "If you provide concepts activations, make sure they have the same granularity as the inputs."
            )
        raise ValueError(
            f"The length of the granular inputs does not match the number of concepts activations: "
            f"{len(granular_inputs)} != {len(sure_concepts_activations)}"
        )

    return self._topk_inputs_from_concepts_activations(
        inputs=granular_inputs,
        concepts_activations=sure_concepts_activations,
        concepts_indices=concepts_indices,
    )