
TopKInputs

TopKInputs identifies, for each concept, the inputs (tokens, words, sentences, or samples) that activate it the most. It provides a global interpretation by finding which elements of your dataset best characterize each concept direction.

Quick Example

from interpreto.concepts.interpretations import TopKInputs
from interpreto.concepts.splitters.model_with_split_points import ActivationGranularity

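# `concept_explainer` is assumed to be a concept explainer and `dataset` a list of strings
# (see the examples below for how to build them)
topk = TopKInputs(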
    concept_explainer=concept_explainer,
    k=5,
    activation_granularity=ActivationGranularity.CLS_TOKEN,
    use_unique_words=3,  # consider all 3-grams as unique words
)

topk_words = topk.interpret(inputs=dataset, concepts_indices="all")

API Reference

interpreto.concepts.interpretations.TopKInputs

Bases: BaseConceptInterpretationMethod

Code concepts/interpretations/topk_inputs.py

Implementation of the Top-K Inputs concept interpretation method, also called MaxAct or CMAW. It associates with each concept the inputs that activate it the most. Because this is arguably the most natural way to interpret a concept, several papers use it without describing it explicitly. A reference description can nonetheless be found in Bricken et al. (2023) [1], the Anthropic post on Transformer Circuits.


  1. Trenton Bricken, Adly Templeton, Joshua Batson, Brian Chen, Adam Jermyn, Tom Conerly, Nicholas L Turner, Cem Anil, Carson Denison, Amanda Askell, Robert Lasenby, Yifan Wu, Shauna Kravec, Nicholas Schiefer, Tim Maxwell, Nicholas Joseph, Alex Tamkin, Karina Nguyen, Brayden McLean, Josiah E Burke, Tristan Hume, Shan Carter, Tom Henighan, Chris Olah. Towards Monosemanticity: Decomposing Language Models With Dictionary Learning. Transformer Circuits, 2023.

Parameters:

concept_explainer (ConceptEncoderExplainer, required):
    The concept explainer built on top of a ModelWithSplitPoints.

activation_granularity (ActivationGranularity, default None):
    The granularity of the activations used for the interpretation. See interpreto.concepts.splitters.model_with_split_points.ModelWithSplitPoints.get_activations for more details.

aggregation_strategy (GranularityAggregationStrategy, default MEAN):
    The aggregation strategy applied to the activations. See interpreto.concepts.splitters.model_with_split_points.ModelWithSplitPoints.get_activations for more details.

concept_encoding_batch_size (int, default 1024):
    The batch size used to encode activations into concepts.

k (int, default 5):
    The number of top inputs kept for each concept.

use_vocab (bool, default False):
    If True, the interpretation is computed from the model's vocabulary, testing every token, instead of from provided inputs.

use_unique_words (bool | int, default 0):
    If True, the interpretation is computed from the unique words of the inputs. Incompatible with use_vocab=True. By default, every distinct word of the inputs is used; this selection can be tuned through unique_words_kwargs. The signature also accepts an integer: the quick example above passes 3 to consider 3-grams.

unique_words_kwargs (dict, default {}):
    Keyword arguments passed to extract_ngrams (see extract_ngrams for details). Supported keys are count_min_threshold, lemmatize, and words_to_ignore.
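When activation_granularity is coarser than the model's tokens (for example WORD), token-level activations have to be combined. The sketch below illustrates what a MEAN aggregation presumably amounts to: each word's activation vector is the average of the vectors of its tokens. The helper `mean_aggregate` and its `word_ids` argument are hypothetical and written in plain Python; the actual behavior is defined by ModelWithSplitPoints.get_activations.

```python
def mean_aggregate(
    token_activations: list[list[float]],  # one activation vector per token, shape (n_tokens, d)
    word_ids: list[int],                   # hypothetical: index of the word each token belongs to
) -> list[list[float]]:
    """Average token activations per word (illustrative MEAN aggregation)."""
    n_words = max(word_ids) + 1            # assumes word ids are contiguous and start at 0
    dim = len(token_activations[0])
    sums = [[0.0] * dim for _ in range(n_words)]
    counts = [0] * n_words
    for activation, word in zip(token_activations, word_ids):
        counts[word] += 1
        for j, value in enumerate(activation):
            sums[word][j] += value
    return [[total / counts[w] for total in sums[w]] for w in range(n_words)]


# "unbelievable movie" split into the tokens "un", "believ", "able", "movie"
token_acts = [[1.0, 0.0], [2.0, 3.0], [3.0, 0.0], [4.0, 4.0]]
word_acts = mean_aggregate(token_acts, word_ids=[0, 0, 0, 1])
print(word_acts)  # [[2.0, 1.0], [4.0, 4.0]]
```

Averaging keeps word activations on the same scale as token activations regardless of how many tokens a word is split into, which is one reason a mean is a sensible default.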

Examples:

Minimal example: finding the top-k tokens that activate each neuron:

>>> from transformers import AutoModelForCausalLM
>>>
>>> from interpreto import ModelWithSplitPoints
>>> from interpreto.concepts import NeuronsAsConcepts, TopKInputs
>>>
>>> # load and split the GPT2 model
>>> mwsp = ModelWithSplitPoints(
...     "gpt2",
...     split_point=11,           # split at the 12th layer
...     automodel=AutoModelForCausalLM,
...     device_map="auto",
...     batch_size=2048,
... )
>>>
>>> # Use `NeuronsAsConcepts` to use the concept-based pipeline with neurons
>>> concept_explainer = NeuronsAsConcepts(mwsp)
>>>
>>> method = TopKInputs(
...     concept_explainer=concept_explainer,
...     use_vocab=True,             # use the vocabulary of the model and test all tokens (50257 with GPT2)
...     k=10,                       # get the top 10 tokens for each neuron
... )
>>>
>>> topk_tokens = method.interpret(
...     concepts_indices="all",     # interpret all neurons of the split layer
... )
>>>
>>> print(list(topk_tokens[1].keys()))
['hostages', 'choke', 'infring', 'herpes', 'nuns', 'phylogen', 'watched', 'alitarian', 'tattoos', 'fisher']
>>> # Individual neurons are hard to interpret, largely because of superposition.
>>> # This is why dictionary learning is used to find concept directions!

Classification example: fit the concepts on the [CLS] token activations, then use TopKInputs with use_unique_words=True and activation_granularity=CLS_TOKEN:

>>> from datasets import load_dataset
>>> from transformers import AutoModelForSequenceClassification
>>>
>>> from interpreto import ModelWithSplitPoints
>>> from interpreto.concepts import ICAConcepts, TopKInputs
>>>
>>> CLS_TOKEN = ModelWithSplitPoints.activation_granularities.CLS_TOKEN
>>>
>>> # load and split an IMDB classification model
>>> mwsp = ModelWithSplitPoints(
...     "textattack/bert-base-uncased-imdb",
...     split_point=11,              # split at the last layer
...     automodel=AutoModelForSequenceClassification,
...     device_map="cuda",
...     batch_size=64,
... )
>>>
>>> # load the IMDB dataset and compute a dataset of [CLS] token activations
>>> imdb = load_dataset("stanfordnlp/imdb", split="train")["text"][:1000]
>>> activations, _ = mwsp.get_activations(imdb, activation_granularity=CLS_TOKEN)
>>>
>>> # Load and fit a concept-based explainer
>>> concept_explainer = ICAConcepts(mwsp, nb_concepts=20)
>>> concept_explainer.fit(activations)
>>>
>>> method = TopKInputs(
...     concept_explainer=concept_explainer,
...     activation_granularity=CLS_TOKEN,
...     k=5,                            # get the top 5 words for each concept
...     use_unique_words=True,          # necessary to get topk words on the [CLS] token
...     unique_words_kwargs={
...         "count_min_threshold": 5,   # only consider words that appear at least 5 times in the dataset
...         "lemmatize": True,          # lowercase words and group them by lemma (e.g., "movies" -> "movie")
...     }
... )
>>>
>>> topk_words = method.interpret(
...     inputs=imdb,
...     concepts_indices="all",     # interpret all concepts
... )
>>>
>>> print(list(topk_words[1].keys()))
['bad', 'bad.', 'hackneyed', 'clichéd', 'cannibal']

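Generation example: use either TOKEN or WORD granularity for the activations. With WORD granularity, the activations used to fit the concepts can be passed directly to interpret, so the top-k words for each concept are selected without recomputing the activations.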

>>> from datasets import load_dataset
>>> from transformers import AutoModelForCausalLM
>>>
>>> from interpreto import ModelWithSplitPoints
>>> from interpreto.concepts import ICAConcepts, TopKInputs
>>>
>>> WORD = ModelWithSplitPoints.activation_granularities.WORD
>>>
>>> # load and split the Qwen3-0.6B model
>>> mwsp = ModelWithSplitPoints(
...     "Qwen/Qwen3-0.6B",
...     split_point=9,              # split at the 10th layer
...     automodel=AutoModelForCausalLM,
...     device_map="auto",
...     batch_size=16,
... )
>>>
>>> # load the IMDB dataset and compute a dataset of words activations
>>> imdb = load_dataset("stanfordnlp/imdb", split="train")["text"][:1000]
>>> activations, _ = mwsp.get_activations(imdb, activation_granularity=WORD)
>>>
>>> # Load and fit a concept-based explainer
>>> concept_explainer = ICAConcepts(mwsp, nb_concepts=10)
>>> concept_explainer.fit(activations)
>>>
>>> method = TopKInputs(
...     concept_explainer=concept_explainer,
...     activation_granularity=WORD,    # we want the topk words for each concept
...     k=10,                           # get the top 10 words for each concept
... )
>>>
>>> topk_words = method.interpret(
...     concepts_indices="all",         # interpret all concepts
...     inputs=imdb,
...     latent_activations=activations, # reuse the activations computed above (same granularity)
... )

Source code in interpreto/concepts/interpretations/topk_inputs.py
def __init__(
    self,
    *,
    concept_explainer: ConceptEncoderExplainer,
    activation_granularity: ActivationGranularity | None = None,
    aggregation_strategy: GranularityAggregationStrategy = GranularityAggregationStrategy.MEAN,
    concept_encoding_batch_size: int = 1024,
    k: int = 5,
    use_vocab: bool = False,
    use_unique_words: bool | int = 0,
    unique_words_kwargs: dict = {},
):
    super().__init__(
        concept_explainer=concept_explainer,
        activation_granularity=activation_granularity,
        aggregation_strategy=aggregation_strategy,
        concept_encoding_batch_size=concept_encoding_batch_size,
        use_vocab=use_vocab,
        use_unique_words=use_unique_words,
        unique_words_kwargs=unique_words_kwargs,
    )

    self.k = k

interpret

Interpret the concept dimensions of the latent space in a human-readable format. The interpretation maps each concept index to the inputs that best illustrate it. The granularity of these input examples is determined by the activation_granularity attribute.

The returned inputs are the most activating inputs for each concept.

If all activations of a concept are zero, its interpretation is set to None.
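The selection itself can be sketched in a few lines. The helper `topk_inputs` below is hypothetical and is not interpreto's implementation: it takes one activation score per input and per concept, keeps the k highest-scoring inputs for each requested concept, and returns None for a concept whose activations are all zero, following the rule stated above. Returning a dictionary keyed by input mirrors the examples above, which call .keys() on each interpretation; using the activation scores as dictionary values is an assumption.

```python
from __future__ import annotations


def topk_inputs(
    inputs: list[str],                        # candidate inputs (e.g. unique words)
    concepts_activations: list[list[float]],  # shape (n_inputs, n_concepts)
    concepts_indices: list[int],
    k: int = 5,
) -> dict[int, dict[str, float] | None]:
    """Hypothetical sketch: map each concept index to its k most activating inputs."""
    if len(inputs) != len(concepts_activations):
        raise ValueError("inputs and concepts_activations must have the same length")
    interpretation: dict[int, dict[str, float] | None] = {}
    for concept in concepts_indices:
        scores = [row[concept] for row in concepts_activations]
        if all(score == 0 for score in scores):
            interpretation[concept] = None  # an all-zero concept has nothing to show
            continue
        # input indices sorted by decreasing activation, truncated to k
        best = sorted(range(len(inputs)), key=scores.__getitem__, reverse=True)[:k]
        # duplicate inputs would collapse into one key; unique words avoid this
        interpretation[concept] = {inputs[i]: scores[i] for i in best}
    return interpretation


words = ["bad", "great", "boring", "fun"]
activations = [[0.9, 0.0], [0.1, 0.0], [0.7, 0.0], [0.3, 0.0]]
print(topk_inputs(words, activations, concepts_indices=[0, 1], k=2))
# {0: {'bad': 0.9, 'boring': 0.7}, 1: None}
```

Because Python dictionaries preserve insertion order, the inputs of each interpretation come out sorted from most to least activating.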

Parameters:

concepts_indices (int | list[int] | Literal['all'], default 'all'):
    The indices of the concepts to interpret. If "all", all concepts are interpreted.

inputs (list[str] | None, default None):
    The inputs to use for the interpretation. Required unless use_vocab is True, as the examples are extracted from the inputs.

latent_activations (Float[Tensor, 'nl d'] | None, default None):
    The latent activations matching the inputs. If not provided, they are computed from the inputs.

concepts_activations (Float[Tensor, 'nl cpt'] | None, default None):
    The concept activations matching the inputs. If not provided, they are computed from the inputs or the latent activations.

Returns:

Mapping[int, Any]:
    The interpretation of each requested concept index.
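interpret accepts its data at three levels: raw inputs, latent activations, or concept activations. The sketch below shows one plausible way to resolve them, starting from the most precomputed level available and deriving the rest. The callables `get_latent` and `encode` are hypothetical stand-ins for the model's activation extraction and the concept explainer's encoder; they are not interpreto functions, and the precedence order is an assumption based on the parameter descriptions above.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence

Matrix = Sequence[Sequence[float]]


def resolve_concepts_activations(
    inputs: list[str] | None,
    latent_activations: Matrix | None,
    concepts_activations: Matrix | None,
    get_latent: Callable[[list[str]], Matrix],  # hypothetical: inputs -> (nl, d)
    encode: Callable[[Matrix], Matrix],          # hypothetical: (nl, d) -> (nl, cpt)
) -> Matrix:
    """Return concept activations, computing only what was not provided."""
    if concepts_activations is not None:
        return concepts_activations  # already at the concept level
    if latent_activations is None:
        if inputs is None:
            raise ValueError("provide inputs, latent_activations or concepts_activations")
        latent_activations = get_latent(inputs)
    return encode(latent_activations)


def latent_from_length(texts: list[str]) -> list[list[float]]:
    """Toy stand-in: a single latent dimension holding each text's length."""
    return [[float(len(text))] for text in texts]


def double_encode(latent: Matrix) -> list[list[float]]:
    """Toy stand-in: a single concept equal to twice the latent value."""
    return [[2.0 * row[0]] for row in latent]


print(resolve_concepts_activations(["ab", "abcd"], None, None, latent_from_length, double_encode))
# [[4.0], [8.0]]
```

Passing precomputed activations, as the generation example does with latent_activations, skips the most expensive step: running the model over the inputs again.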

Source code in interpreto/concepts/interpretations/topk_inputs.py
def interpret(
    self,
    concepts_indices: int | list[int] | Literal["all"] = "all",
    inputs: list[str] | None = None,
    latent_activations: LatentActivations | None = None,
    concepts_activations: ConceptsActivations | None = None,
) -> Mapping[int, Any]:
    """
    Give the interpretation of the concepts dimensions in the latent space into a human-readable format.
    The interpretation is a mapping between the concepts indices and a list of inputs allowing to interpret them.
    The granularity of input examples is determined by the `activation_granularity` class attribute.

    The returned inputs are the most activating inputs for the concepts.

    If all activations are zero, the corresponding concept interpretation is set to `None`.

    Args:
        concepts_indices (int | list[int] | Literal["all"]):
            The indices of the concepts to interpret. If "all", all concepts are interpreted.

        inputs (list[str] | None):
            The inputs to use for the interpretation.
            Necessary if not `use_vocab`, as examples are extracted from the inputs.

        latent_activations (Float[torch.Tensor, "nl d"] | None):
            The latent activations matching the inputs. If not provided,
            it is computed from the inputs.

        concepts_activations (Float[torch.Tensor, "nl cpt"] | None):
            The concepts activations matching the inputs. If not provided,
            it is computed from the inputs or latent activations.

    Returns:
        Mapping[int, Any]: The interpretation of the concepts indices.

    """
    sure_concepts_indices, granular_inputs, sure_concepts_activations, _ = (
        self.get_granular_inputs_and_concept_activations(
            concepts_indices=concepts_indices,
            inputs=inputs,
            latent_activations=latent_activations,
            concepts_activations=concepts_activations,
        )
    )
    sure_concepts_indices: list[int]
    granular_inputs: list[str]
    sure_concepts_activations: torch.Tensor

    return self._topk_inputs_from_concepts_activations(
        inputs=granular_inputs,
        concepts_activations=sure_concepts_activations,
        concepts_indices=sure_concepts_indices,
    )

interpreto.concepts.interpretations.extract_ngrams

extract_ngrams(inputs, n=1, count_min_threshold=1, return_counts=False, lemmatize=False, words_to_ignore=None)

Extract n-grams (from 1-gram up to n-gram of words) from a list of texts.

If n=3, it extracts 1-grams, 2-grams, and 3-grams.

Parameters:

inputs (Iterable[str], required):
    The texts to extract n-grams from.

n (int, default 1):
    The maximum n-gram size. All sizes from 1 to n are extracted.

count_min_threshold (int, default 1):
    The minimum total number of occurrences of an n-gram across all inputs. Rarer n-grams are discarded.

return_counts (bool, default False):
    Whether to return the count of each n-gram instead of the list of n-grams.

lemmatize (bool, default False):
    Whether to lowercase and lemmatize words (with NLTK's WordNetLemmatizer) before counting.

words_to_ignore (list[str] | None, default None):
    Words to ignore, applied to individual tokens (after lemmatization, if enabled) before forming n-grams, so n-grams are built from the remaining tokens.

Returns:

list[str] | Counter[str]:
    The list of unique n-grams, or a Counter with the count of each n-gram if return_counts is True.
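To make the counting and filtering concrete, here is a simplified re-implementation of the same logic. It is only an illustration: it splits on whitespace instead of using NLTK's word_tokenize and omits lemmatization and words_to_ignore, so its output can differ from extract_ngrams on real text (for example in how punctuation is handled). The name `simple_ngrams` is hypothetical.

```python
from collections import Counter


def simple_ngrams(texts: list[str], n: int = 1, count_min_threshold: int = 1) -> list[str]:
    """Simplified sketch of extract_ngrams: whitespace tokens, no lemmatization."""
    counts: Counter[str] = Counter()
    for text in texts:
        tokens = text.split()
        for size in range(1, n + 1):  # all n-gram sizes from 1 to n
            for i in range(len(tokens) - size + 1):
                counts[" ".join(tokens[i : i + size])] += 1
    # keep the n-grams that occur at least count_min_threshold times overall
    return [ngram for ngram, count in counts.items() if count >= count_min_threshold]


reviews = ["the movie was bad", "bad movie"]
print(simple_ngrams(reviews, n=2))                         # every 1-gram and 2-gram
print(simple_ngrams(reviews, n=2, count_min_threshold=2))  # ['movie', 'bad']
```

Raising count_min_threshold is a cheap way to shrink the candidate set when use_unique_words is enabled, since every surviving n-gram becomes an input to score.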

Source code in interpreto/concepts/interpretations/base.py
@jaxtyped(typechecker=beartype)
def extract_ngrams(
    inputs: Iterable[str],
    n: int = 1,
    count_min_threshold: int = 1,
    return_counts: bool = False,
    lemmatize: bool = False,
    words_to_ignore: list[str] | None = None,
) -> list[str] | Counter[str]:
    """
    Extract n-grams (from 1-gram up to n-gram of words) from a list of texts.

    If n=3, it extracts 1-grams, 2-grams, and 3-grams.

    Args:
        inputs (Iterable[str]):
            The texts to extract n-grams from.

        n (int):
            The maximum n-gram size. All sizes from 1 to n are extracted.

        count_min_threshold (int, optional):
            The minimum total number of occurrences of an n-gram in the whole `inputs`.

        return_counts (bool, optional):
            Whether to return the counts of each n-gram.
            Defaults to False.

        lemmatize (bool, optional):
            Whether to lemmatize words before counting.

        words_to_ignore (list[str] | None, optional):
            A list of words to ignore (applied to individual tokens before forming n-grams).

    Returns:
        list[str] | Counter[str]:
            The list of unique n-grams or the counts of each n-gram.
    """
    _ensure_nltk_resources(lemmatize=lemmatize)

    if lemmatize:
        lemmatizer = WordNetLemmatizer()

    tuple_ngram_counts: Counter[tuple[str]] = Counter()

    for text in inputs:
        tokens = word_tokenize(text)

        # preprocess tokens
        processed = []
        for word in tokens:
            if lemmatize:
                word = lemmatizer.lemmatize(word.lower())  # noqa: PLW2901  # type: ignore  (ignore possibly unbound)
            if words_to_ignore is not None and word in words_to_ignore:
                continue
            processed.append(word)
            tuple_ngram_counts[(word,)] += 1  # unigram tuple

        for size in range(2, n + 1):  # skips size 1 as covered over
            for i in range(len(processed) - size + 1):
                tuple_ngram_counts[tuple(processed[i : i + size])] += 1  # >1-gram tuples

    str_ngram_counts: Counter[str] = Counter(
        {
            " ".join(key): count  # convert ngram tuples to strings
            for key, count in tuple_ngram_counts.items()
            if count >= count_min_threshold  # filter too rare n-grams
        }
    )

    if return_counts:
        return str_ngram_counts

    return list(str_ngram_counts.keys())