TopKInputs or MaxAct

Generalization of maximally activating inputs used by Towards Monosemanticity: Decomposing Language Models With Dictionary Learning by Bricken et al. (2023)

interpreto.concepts.interpretations.TopKInputs

TopKInputs(*, model_with_split_points, concept_model, activation_granularity=WORD, split_point=None, k=5, use_vocab=False)

Bases: BaseConceptInterpretationMethod

Code concepts/interpretations/topk_inputs.py

Implementation of the Top-K Inputs concept interpretation method, also called MaxAct. It associates to each concept the inputs that activate it most strongly. This is the most natural way to interpret a concept, and several papers have used it without describing it explicitly. Nonetheless, we can reference Bricken et al. (2023) 1 from Anthropic and their post on Transformer Circuits.


  1. Trenton Bricken, Adly Templeton, Joshua Batson, Brian Chen, Adam Jermyn, Tom Conerly, Nicholas L. Turner, Cem Anil, Carson Denison, Amanda Askell, Robert Lasenby, Yifan Wu, Shauna Kravec, Nicholas Schiefer, Tim Maxwell, Nicholas Joseph, Alex Tamkin, Karina Nguyen, Brayden McLean, Josiah E. Burke, Tristan Hume, Shan Carter, Tom Henighan, and Chris Olah. Towards Monosemanticity: Decomposing Language Models With Dictionary Learning. Transformer Circuits, 2023.
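The core selection step can be sketched independently of the library: given a matrix of concept activations (one row per granular input, one column per concept), MaxAct keeps, for each concept, the k inputs with the highest activation. A minimal NumPy sketch; the names, shapes, and values below are illustrative, not interpreto's internals:

```python
import numpy as np

# toy concept-activation matrix: 6 granular inputs x 3 concepts
inputs = ["great", "awful", "boring", "superb", "dull", "fine"]
activations = np.array([
    [0.9, 0.0, 0.1],
    [0.0, 0.8, 0.0],
    [0.1, 0.7, 0.0],
    [0.8, 0.0, 0.2],
    [0.0, 0.6, 0.0],
    [0.3, 0.1, 0.0],
])

def topk_inputs(inputs, activations, k=2):
    """For each concept (column), return the k most-activating inputs."""
    interpretation = {}
    for concept in range(activations.shape[1]):
        # sort descending by activation, keep the first k indices
        order = np.argsort(-activations[:, concept])[:k]
        interpretation[concept] = [inputs[i] for i in order]
    return interpretation

print(topk_inputs(inputs, activations))
# concept 0 -> ["great", "superb"], concept 1 -> ["awful", "boring"]
```

The real method adds granularity handling (token, word, sentence, or sample) and input/activation verification on top of this selection.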

Attributes:

Name Type Description
model_with_split_points ModelWithSplitPoints

The model with split points to use for the interpretation.

split_point str

The split point to use for the interpretation.

concept_model ConceptModelProtocol

The concept model to use for the interpretation.

activation_granularity ActivationGranularity

The granularity at which the interpretation is computed. Allowed values are TOKEN, WORD, SENTENCE, and SAMPLE. Ignored when use_vocab=True.

k int

The number of most-activating inputs to keep per concept.

use_vocab bool

If True, the interpretation will be computed from the vocabulary of the model.
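The vocabulary mode can be pictured with a small sketch: each vocabulary token is treated as its own input and scored against every concept, so the interpretation becomes a top-k list of tokens per concept. Everything here (the toy vocabulary and the score matrix) is hypothetical; in the real method the scores come from running tokens through the split model and the concept model:

```python
import numpy as np

# hypothetical vocabulary and per-token concept scores
vocab = ["good", "bad", "movie", "plot"]
# rows: vocabulary tokens, columns: concepts (2 concepts here)
concept_scores = np.array([
    [0.9, 0.0],
    [0.0, 0.7],
    [0.2, 0.1],
    [0.1, 0.4],
])

k = 2
# for each concept, keep the k vocabulary tokens with the highest score
top_tokens = {
    c: [vocab[i] for i in np.argsort(-concept_scores[:, c])[:k]]
    for c in range(concept_scores.shape[1])
}
print(top_tokens)
```

Because the vocabulary fixes the candidate set, `activation_granularity` is ignored in this mode, which matches the note in the attributes table above.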

Examples:

>>> from datasets import load_dataset
>>> from transformers import AutoModelForMaskedLM
>>> from interpreto import ModelWithSplitPoints
>>> from interpreto.concepts import NeuronsAsConcepts
>>> from interpreto.concepts.interpretations import TopKInputs
>>> # load and split the model
>>> split = "bert.encoder.layer.1.output"
>>> model_with_split_points = ModelWithSplitPoints(
...     "hf-internal-testing/tiny-random-bert",
...     split_points=[split],
...     model_autoclass=AutoModelForMaskedLM,
...     batch_size=4,
... )
>>> # NeuronsAsConcepts do not need to be fitted
>>> concept_model = NeuronsAsConcepts(model_with_split_points=model_with_split_points, split_point=split)
>>> # extract concept interpretations
>>> # (latent activations are computed from `inputs` when not passed explicitly)
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> all_top_k_words = concept_model.interpret(
...     interpretation_method=TopKInputs,
...     activation_granularity=TopKInputs.activation_granularities.WORD,
...     k=2,
...     concepts_indices="all",
...     inputs=dataset,
... )
Source code in interpreto/concepts/interpretations/topk_inputs.py
def __init__(
    self,
    *,
    model_with_split_points: ModelWithSplitPoints,
    concept_model: ConceptModelProtocol,
    activation_granularity: ActivationGranularity = ActivationGranularity.WORD,
    split_point: str | None = None,
    k: int = 5,
    use_vocab: bool = False,
):
    super().__init__(
        model_with_split_points=model_with_split_points,
        concept_model=concept_model,
        split_point=split_point,
        activation_granularity=activation_granularity,
    )

    if activation_granularity not in (
        ActivationGranularity.TOKEN,
        ActivationGranularity.WORD,
        ActivationGranularity.SENTENCE,
        ActivationGranularity.SAMPLE,
    ):
        raise ValueError(
            f"The granularity {activation_granularity} is not supported. Supported `activation_granularities`: TOKEN, WORD, SENTENCE, and SAMPLE"
        )

    self.k = k
    self.use_vocab = use_vocab

interpret

Gives the interpretation of the concept dimensions in the latent space in a human-readable format. The interpretation is a mapping between concept indices and a list of inputs that illustrate them. The granularity of input examples is determined by the activation_granularity class attribute.

The returned inputs are the most activating inputs for the concepts.

If all activations are zero, the corresponding concept interpretation is set to None.
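The all-zero rule can be illustrated with a short sketch (NumPy, names illustrative): a concept whose activation column is entirely zero is mapped to None rather than a meaningless top-k list:

```python
import numpy as np

inputs = ["a", "b", "c"]
activations = np.array([
    [0.5, 0.0],
    [0.2, 0.0],
    [0.9, 0.0],
])  # concept 1 never activates

def interpret(inputs, activations, k=2):
    out = {}
    for c in range(activations.shape[1]):
        col = activations[:, c]
        if not col.any():  # all-zero concept -> no interpretation
            out[c] = None
            continue
        out[c] = [inputs[i] for i in np.argsort(-col)[:k]]
    return out

print(interpret(inputs, activations))
# {0: ['c', 'a'], 1: None}
```

This mirrors the contract described above: dead concepts are reported explicitly instead of being given arbitrary examples.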

Parameters:

Name Type Description Default

concepts_indices

int | list[int]

The indices of the concepts to interpret.

required

inputs

list[str] | None

The inputs to use for the interpretation. Required when use_vocab is False, as examples are extracted from the inputs.

None

latent_activations

Float[Tensor, 'nl d'] | None

The latent activations matching the inputs. If not provided, it is computed from the inputs.

None

concepts_activations

Float[Tensor, 'nl cpt'] | None

The concepts activations matching the inputs. If not provided, it is computed from the inputs or latent activations.

None

Returns:

Type Description
Mapping[int, Any]

Mapping[int, Any]: The interpretation of the concepts indices.

Source code in interpreto/concepts/interpretations/topk_inputs.py
def interpret(
    self,
    concepts_indices: int | list[int],
    inputs: list[str] | None = None,
    latent_activations: LatentActivations | None = None,
    concepts_activations: ConceptsActivations | None = None,
) -> Mapping[int, Any]:
    """
    Give the interpretation of the concepts dimensions in the latent space into a human-readable format.
    The interpretation is a mapping between the concepts indices and a list of inputs allowing to interpret them.
    The granularity of input examples is determined by the `activation_granularity` class attribute.

    The returned inputs are the most activating inputs for the concepts.

    If all activations are zero, the corresponding concept interpretation is set to `None`.

    Args:
        concepts_indices (int | list[int]): The indices of the concepts to interpret.
        inputs (list[str] | None): The inputs to use for the interpretation.
            Necessary if not `use_vocab`, as examples are extracted from the inputs.
        latent_activations (Float[torch.Tensor, "nl d"] | None): The latent activations matching the inputs. If not provided, it is computed from the inputs.
        concepts_activations (Float[torch.Tensor, "nl cpt"] | None): The concepts activations matching the inputs. If not provided, it is computed from the inputs or latent activations.

    Returns:
        Mapping[int, Any]: The interpretation of the concepts indices.

    """
    # compute the concepts activations from the provided source, can also create inputs from the vocabulary
    if self.use_vocab:
        sure_inputs, sure_concepts_activations = self.concepts_activations_from_vocab()
        granular_inputs = sure_inputs
    else:
        if inputs is None:
            raise ValueError("Inputs must be provided when `use_vocab` is False.")
        sure_inputs = inputs
        sure_concepts_activations = self.concepts_activations_from_source(
            inputs=inputs,
            latent_activations=latent_activations,
            concepts_activations=concepts_activations,
        )
        granular_inputs, _ = self.get_granular_inputs(sure_inputs)

    concepts_indices = verify_concepts_indices(
        concepts_activations=sure_concepts_activations, concepts_indices=concepts_indices
    )
    verify_granular_inputs(
        granular_inputs=granular_inputs,
        sure_concepts_activations=sure_concepts_activations,
        latent_activations=latent_activations,
        concepts_activations=concepts_activations,
    )

    return self._topk_inputs_from_concepts_activations(
        inputs=granular_inputs,
        concepts_activations=sure_concepts_activations,
        concepts_indices=concepts_indices,
    )