Concept Interpretations

interpreto.concepts.interpretations.TopKInputs

Bases: BaseConceptInterpretationMethod

Code concepts/interpretations/topk_inputs.py

Implementation of the Top-K Inputs concept interpretation method, also called MaxAct or CMAW. It associates each concept with the inputs that activate it the most. It is the most direct way to interpret a concept, which is why several papers use it without describing it. For a reference, see the Transformer Circuits post by Bricken et al. (2023) 1 from Anthropic.


  1. Trenton Bricken, Adly Templeton, Joshua Batson, Brian Chen, Adam Jermyn*, Tom Conerly, Nicholas L Turner, Cem Anil, Carson Denison, Amanda Askell, Robert Lasenby, Yifan Wu, Shauna Kravec, Nicholas Schiefer, Tim Maxwell, Nicholas Joseph, Alex Tamkin, Karina Nguyen, Brayden McLean, Josiah E Burke, Tristan Hume, Shan Carter, Tom Henighan, Chris Olah. Towards Monosemanticity: Decomposing Language Models With Dictionary Learning. Transformer Circuits, 2023.

Parameters:

Name Type Description Default

concept_explainer

ConceptEncoderExplainer

The concept explainer built on top of a ModelWithSplitPoints.

required

activation_granularity

ActivationGranularity

The granularity at which the interpretation is computed. Allowed values are CLS_TOKEN, TOKEN, WORD, SENTENCE, and SAMPLE. Ignored when use_vocab=True.

WORD

concept_encoding_batch_size

int

The batch size to use for the concept encoding.

1024

k

int

The number of inputs to use for the interpretation.

5

use_vocab

bool

If True, the interpretation will be computed from the vocabulary of the model.

False

use_unique_words

bool

If True, the interpretation will be computed from the unique words of the inputs. Incompatible with use_vocab=True. By default, every distinct word in the inputs is selected; this can be tuned through the unique_words_kwargs argument.

False

unique_words_kwargs

dict

The kwargs to pass to the extract_unique_words function. See extract_unique_words for more details. Possible arguments are count_min_threshold, lemmatize, words_to_ignore.

{}

concept_model_device

device | str | None

The device to use for the concept model forward pass. If None, does not change the device.

None

Examples:

Minimal example: finding the top-k tokens that most activate each neuron:

>>> from transformers import AutoModelForCausalLM
>>>
>>> from interpreto import ModelWithSplitPoints
>>> from interpreto.concepts import NeuronsAsConcepts, TopKInputs
>>>
>>> # load and split the GPT2 model
>>> mwsp = ModelWithSplitPoints(
...     "gpt2",
...     split_points=[11],           # split at the 12th layer
...     automodel=AutoModelForCausalLM,
...     device_map="auto",
...     batch_size=2048,
... )
>>>
>>> # Use `NeuronsAsConcepts` to use the concept-based pipeline with neurons
>>> concept_explainer = NeuronsAsConcepts(mwsp)
>>>
>>> method = TopKInputs(
...     concept_explainer=concept_explainer,
...     use_vocab=True,             # use the vocabulary of the model and test all tokens (50257 with GPT2)
...     k=10,                       # get the top 10 tokens for each neuron
... )
>>>
>>> topk_tokens = method.interpret(
...     concepts_indices="all",     # interpret all concepts (here, every neuron at the split point)
... )
>>>
>>> print(list(topk_tokens[1].keys()))
['hostages', 'choke', 'infring', 'herpes', 'nuns', 'phylogen', 'watched', 'alitarian', 'tattoos', 'fisher']
>>> # Results are not interpretable, due to superposition among other effects.
>>> # This is why we use dictionary learning to find concept directions!

Classification example: fit concepts on the [CLS] token activations, then use TopKInputs with use_unique_words=True and activation_granularity=CLS_TOKEN:

>>> from datasets import load_dataset
>>> from transformers import AutoModelForSequenceClassification
>>>
>>> from interpreto import ModelWithSplitPoints
>>> from interpreto.concepts import ICAConcepts, TopKInputs
>>>
>>> CLS_TOKEN = ModelWithSplitPoints.activation_granularities.CLS_TOKEN
>>>
>>> # load and split an IMDB classification model
>>> mwsp = ModelWithSplitPoints(
...     "textattack/bert-base-uncased-imdb",
...     split_points=[11],              # split at the last layer
...     automodel=AutoModelForSequenceClassification,
...     device_map="cuda",
...     batch_size=64,
... )
>>>
>>> # load the IMDB dataset and compute a dataset of [CLS] token activations
>>> imdb = load_dataset("stanfordnlp/imdb", split="train")["text"][:1000]
>>> activations = mwsp.get_activations(imdb, activation_granularity=CLS_TOKEN)
>>>
>>> # Load and fit a concept-based explainer
>>> concept_explainer = ICAConcepts(mwsp, nb_concepts=20)
>>> concept_explainer.fit(activations)
>>>
>>> method = TopKInputs(
...     concept_explainer=concept_explainer,
...     activation_granularity=CLS_TOKEN,
...     k=5,                            # get the top 5 words for each concept
...     use_unique_words=True,          # necessary to get topk words on the [CLS] token
...     unique_words_kwargs={
...         "count_min_threshold": 5,   # only consider words that appear at least 5 times in the dataset
...         "lemmatize": True,          # group words by their lemma (e.g., "bad" and "badly" are grouped together)
...     }
... )
>>>
>>> topk_words = method.interpret(
...     inputs=imdb,
...     concepts_indices="all",     # interpret all 20 fitted concepts
... )
>>>
>>> print(list(topk_words[1].keys()))
['bad', 'bad.', 'hackneyed', 'clichéd', 'cannibal']

Generation example: use either TOKEN or WORD granularity for activations. WORD granularity lets you select the top-k words for each concept without recomputing the activations.

>>> from datasets import load_dataset
>>> from transformers import AutoModelForCausalLM
>>>
>>> from interpreto import ModelWithSplitPoints
>>> from interpreto.concepts import ICAConcepts, TopKInputs
>>>
>>> WORD = ModelWithSplitPoints.activation_granularities.WORD
>>>
>>> # load and split the Qwen3-0.6B model
>>> mwsp = ModelWithSplitPoints(
...     "Qwen/Qwen3-0.6B",
...     split_points=[9],              # split at the 10th layer
...     automodel=AutoModelForCausalLM,
...     device_map="auto",
...     batch_size=16,
... )
>>>
>>> # load the IMDB dataset and compute a dataset of words activations
>>> imdb = load_dataset("stanfordnlp/imdb", split="train")["text"][:1000]
>>> activations = mwsp.get_activations(imdb, activation_granularity=WORD)
>>>
>>> # Load and fit a concept-based explainer
>>> concept_explainer = ICAConcepts(mwsp, nb_concepts=10)
>>> concept_explainer.fit(activations)
>>>
>>> method = TopKInputs(
...     concept_explainer=concept_explainer,
...     activation_granularity=WORD,    # we want the topk words for each concept
...     k=10,                           # get the top 10 words for each concept
...     concept_model_device="cuda",    # device for the concept model forward pass
... )
>>>
>>> topk_tokens = method.interpret(
...     concepts_indices="all",     # interpret all 10 fitted concepts
...     inputs=imdb,
...     latent_activations=activations, # use previously computed activations (same granularity)
... )
Source code in interpreto/concepts/interpretations/topk_inputs.py
def __init__(
    self,
    *,
    concept_explainer: ConceptEncoderExplainer,
    activation_granularity: ActivationGranularity = ActivationGranularity.WORD,
    concept_encoding_batch_size: int = 1024,
    k: int = 5,
    use_vocab: bool = False,
    use_unique_words: bool = False,
    unique_words_kwargs: dict = {},
    concept_model_device: torch.device | str | None = None,
):
    super().__init__(
        concept_explainer=concept_explainer,
        activation_granularity=activation_granularity,
        concept_encoding_batch_size=concept_encoding_batch_size,
        use_vocab=use_vocab,
        use_unique_words=use_unique_words,
        unique_words_kwargs=unique_words_kwargs,
        concept_model_device=concept_model_device,
    )

    self.k = k

interpret

Interpret the concept dimensions of the latent space in a human-readable format. The interpretation is a mapping from concept indices to the inputs that help interpret them. The granularity of input examples is determined by the activation_granularity class attribute.

The returned inputs are the most activating inputs for the concepts.

If all activations are zero, the corresponding concept interpretation is set to None.
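The selection rule described above can be illustrated with a small, dependency-free sketch. It is not the library's implementation: the function name `topk_inputs`, the list-of-lists activation layout, and the returned dictionary shape (input mapped to activation, mirroring the `.keys()` calls in the examples) are assumptions made for illustration.

```python
import heapq


def topk_inputs(inputs, activations, k=5):
    """Map each concept index to its k most activating inputs.

    Hypothetical helper: `activations[i][c]` is the activation of
    concept `c` on `inputs[i]`.
    """
    n_concepts = len(activations[0]) if activations else 0
    result = {}
    for c in range(n_concepts):
        column = [row[c] for row in activations]
        # A concept that never activates has no meaningful examples.
        if all(a == 0 for a in column):
            result[c] = None
            continue
        # Keep the k (activation, input) pairs with the largest activation.
        best = heapq.nlargest(k, zip(column, inputs), key=lambda pair: pair[0])
        result[c] = {text: act for act, text in best}
    return result


inputs = ["great", "awful", "boring", "fun"]
activations = [
    [0.9, 0.0, 0.0],
    [0.1, 0.8, 0.0],
    [0.0, 0.5, 0.0],
    [0.7, 0.1, 0.0],
]
top2 = topk_inputs(inputs, activations, k=2)
print(top2)
```

In this toy data the third concept never activates, so its entry is None, while the first two concepts each keep their two strongest inputs. Duplicate inputs would collapse into a single dictionary key in this sketch.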

Parameters:

Name Type Description Default

concepts_indices

int | list[int] | Literal['all']

The indices of the concepts to interpret. If "all", all concepts are interpreted.

required

inputs

list[str] | None

The inputs to use for the interpretation. Required unless use_vocab is True, as examples are extracted from the inputs.

None

latent_activations

dict[str, Tensor] | Float[Tensor, 'nl d'] | None

The latent activations matching the inputs. If not provided, it is computed from the inputs.

None

concepts_activations

Float[Tensor, 'nl cpt'] | None

The concepts activations matching the inputs. If not provided, it is computed from the inputs or latent activations.

None

Returns:

Type Description
Mapping[int, Any]

Mapping[int, Any]: The interpretation of the concepts indices.

Source code in interpreto/concepts/interpretations/topk_inputs.py
def interpret(
    self,
    concepts_indices: int | list[int] | Literal["all"],
    inputs: list[str] | None = None,
    latent_activations: dict[str, torch.Tensor] | LatentActivations | None = None,
    concepts_activations: ConceptsActivations | None = None,
) -> Mapping[int, Any]:
    """
    Give the interpretation of the concepts dimensions in the latent space into a human-readable format.
    The interpretation is a mapping between the concepts indices and a list of inputs allowing to interpret them.
    The granularity of input examples is determined by the `activation_granularity` class attribute.

    The returned inputs are the most activating inputs for the concepts.

    If all activations are zero, the corresponding concept interpretation is set to `None`.

    Args:
        concepts_indices (int | list[int] | Literal["all"]):
            The indices of the concepts to interpret. If "all", all concepts are interpreted.

        inputs (list[str] | None):
            The inputs to use for the interpretation.
            Necessary if not `use_vocab`, as examples are extracted from the inputs.

        latent_activations (dict[str, torch.Tensor] | Float[torch.Tensor, "nl d"] | None):
            The latent activations matching the inputs. If not provided,
            it is computed from the inputs.

        concepts_activations (Float[torch.Tensor, "nl cpt"] | None):
            The concepts activations matching the inputs. If not provided,
            it is computed from the inputs or latent activations.

    Returns:
        Mapping[int, Any]: The interpretation of the concepts indices.

    """
    sure_concepts_indices, granular_inputs, sure_concepts_activations, _ = (
        self.get_granular_inputs_and_concept_activations(
            concepts_indices=concepts_indices,
            inputs=inputs,
            latent_activations=latent_activations,
            concepts_activations=concepts_activations,
        )
    )
    sure_concepts_indices: list[int]
    granular_inputs: list[str]
    sure_concepts_activations: torch.Tensor

    return self._topk_inputs_from_concepts_activations(
        inputs=granular_inputs,
        concepts_activations=sure_concepts_activations,
        concepts_indices=sure_concepts_indices,
    )

interpreto.concepts.interpretations.LLMLabels

Bases: BaseConceptInterpretationMethod

Code concepts/interpretations/llm_labels.py

Implements the automatic labeling method, which uses a large language model (LLM) to produce a short textual description of a concept from examples of what activates it. This method was first introduced by Bills et al. (2023) 1; only step 1 of their method is implemented here.


  1. Steven Bills, Nick Cammarata, Dan Mossing, Henk Tillman, Leo Gao, Gabriel Goh, Ilya Sutskever, Jan Leike, Jeff Wu, William Saunders*. Language models can explain neurons in language models. 2023.

Parameters:

Name Type Description Default

concept_explainer

ConceptEncoderExplainer

The fitted concept explainer used for encoding activations.

required

activation_granularity

ActivationGranularity

The granularity at which the interpretation is computed. Allowed values are CLS_TOKEN, TOKEN, WORD, SENTENCE, and SAMPLE. Ignored when use_vocab=True.

TOKEN

llm_interface

LLMInterface

The LLM interface to use for the interpretation.

required

concept_encoding_batch_size

int

The batch size to use for the concept encoding.

1024

sampling_method

SamplingMethod

The method to use for sampling the inputs provided to the LLM.

TOP

k_examples

int

The number of inputs to use for the interpretation.

30

k_context

int

The number of context tokens to use around the concept tokens.

0

use_vocab

bool

If True, the interpretation will be computed from the vocabulary of the model.

False

use_unique_words

bool

If True, the interpretation will be computed from the unique words of the inputs. Incompatible with use_vocab=True. By default, every distinct word in the inputs is selected; this can be tuned through the unique_words_kwargs argument.

False

unique_words_kwargs

dict

The kwargs to pass to the extract_unique_words function. See extract_unique_words for more details. Possible arguments are count_min_threshold, lemmatize, words_to_ignore.

{}

k_quantile

int

The number of quantiles to use for sampling the inputs, if sampling_method is QUANTILE.

5

system_prompt

str | None

The system prompt to use for the LLM. If None, a default prompt is used.

None

concept_model_device

device | str | None

The device to use for the concept model forward pass. If None, does not change the device.

None
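To make the difference between the TOP and QUANTILE sampling methods more concrete, here is a rough sketch of what such strategies could look like. The functions `sample_top` and `sample_quantile`, the `seed` argument, and the choice to ignore non-activating inputs are all assumptions for illustration; the library's SamplingMethod may behave differently.

```python
import random


def sample_top(acts, k_examples):
    """Indices of the k_examples highest activations, strongest first."""
    order = sorted(range(len(acts)), key=lambda i: acts[i], reverse=True)
    return order[:k_examples]


def sample_quantile(acts, k_examples, k_quantile, seed=0):
    """Draw about k_examples / k_quantile indices from each activation quantile."""
    rng = random.Random(seed)
    # Assumption: only inputs that activate the concept are candidates.
    order = [i for i in sample_top(acts, len(acts)) if acts[i] > 0]
    per_bin = max(1, k_examples // k_quantile)
    picked = []
    for q in range(k_quantile):
        # Contiguous slice of the ranking that forms the q-th quantile bin.
        start = q * len(order) // k_quantile
        stop = (q + 1) * len(order) // k_quantile
        bin_ids = order[start:stop]
        picked.extend(rng.sample(bin_ids, min(per_bin, len(bin_ids))))
    return picked


acts = [0.0, 0.9, 0.2, 0.5, 0.7, 0.1, 0.3, 0.8, 0.4, 0.6, 1.0]
top3 = sample_top(acts, 3)
picked = sample_quantile(acts, k_examples=5, k_quantile=5)
print(top3)
print(picked)
```

Top sampling shows the LLM only the strongest activations, whereas quantile sampling also exposes it to weaker ones, which can reveal whether a concept behaves differently at low activation levels.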
Source code in interpreto/concepts/interpretations/llm_labels.py
def __init__(
    self,
    *,
    concept_explainer: ConceptEncoderExplainer,
    activation_granularity: ActivationGranularity = ActivationGranularity.TOKEN,
    llm_interface: LLMInterface,
    concept_encoding_batch_size: int = 1024,
    sampling_method: SamplingMethod = SamplingMethod.TOP,
    k_examples: int = 30,
    k_context: int = 0,
    use_vocab: bool = False,
    use_unique_words: bool = False,
    unique_words_kwargs: dict = {},
    k_quantile: int = 5,
    system_prompt: str | None = None,
    concept_model_device: torch.device | str | None = None,
):
    super().__init__(
        concept_explainer=concept_explainer,
        activation_granularity=activation_granularity,
        concept_encoding_batch_size=concept_encoding_batch_size,
        use_vocab=use_vocab,
        use_unique_words=use_unique_words,
        unique_words_kwargs=unique_words_kwargs,
        concept_model_device=concept_model_device,
    )

    self.llm_interface = llm_interface
    self.sampling_method = sampling_method
    self.k_examples = k_examples
    self.k_context = k_context
    self.k_quantile = k_quantile

    if system_prompt is None:
        if self.k_context > 0:
            self.system_prompt = SYSTEM_PROMPT_WITH_CONTEXT
        else:
            self.system_prompt = SYSTEM_PROMPT_WITHOUT_CONTEXT
    else:
        self.system_prompt = system_prompt

interpret

Interpret the concept dimensions of the latent space in a human-readable format. The interpretation is a mapping from concept indices to short textual descriptions. The granularity of input examples is determined by the activation_granularity class attribute.

Parameters:

Name Type Description Default

concepts_indices

int | list[int] | Literal['all']

The indices of the concepts to interpret. If "all", all concepts are interpreted.

required

inputs

list[str] | None

The inputs to use for the interpretation. Required unless use_vocab is True, as examples are extracted from the inputs.

None

latent_activations

dict[str, Tensor] | Float[Tensor, 'nl d'] | None

The latent activations matching the inputs. If not provided, it is computed from the inputs.

None

concepts_activations

Float[Tensor, 'nl cpt'] | None

The concepts activations matching the inputs. If not provided, it is computed from the inputs or latent activations.

None

Returns:

Type Description
Mapping[int, str | None]

Mapping[int, str | None]: The textual labels of the concepts indices.

Source code in interpreto/concepts/interpretations/llm_labels.py
def interpret(
    self,
    concepts_indices: int | list[int] | Literal["all"],
    inputs: list[str] | None = None,
    latent_activations: dict[str, torch.Tensor] | LatentActivations | None = None,
    concepts_activations: ConceptsActivations | None = None,
) -> Mapping[int, str | None]:
    """
    Give the interpretation of the concepts dimensions in the latent space into a human-readable format.
    The interpretation is a mapping between the concepts indices and a short textual description.
    The granularity of input examples is determined by the `activation_granularity` class attribute.


    Args:
        concepts_indices (int | list[int] | Literal["all"]):
            The indices of the concepts to interpret. If "all", all concepts are interpreted.

        inputs (list[str] | None):
            The inputs to use for the interpretation.
            Necessary if not `use_vocab`, as examples are extracted from the inputs.

        latent_activations (dict[str, torch.Tensor] | Float[torch.Tensor, "nl d"] | None):
            The latent activations matching the inputs. If not provided,
            it is computed from the inputs.

        concepts_activations (Float[torch.Tensor, "nl cpt"] | None):
            The concepts activations matching the inputs. If not provided,
            it is computed from the inputs or latent activations.

    Returns:
        Mapping[int, str | None]: The textual labels of the concepts indices.
    """
    sure_concepts_indices: list[int]
    granular_inputs: list[str]
    sure_concepts_activations: Float[torch.Tensor, "nl cpt"]
    granular_sample_ids: list[int]
    sure_concepts_indices, granular_inputs, sure_concepts_activations, granular_sample_ids = (
        self.get_granular_inputs_and_concept_activations(
            concepts_indices=concepts_indices,
            inputs=inputs,
            latent_activations=latent_activations,
            concepts_activations=concepts_activations,
        )
    )

    labels: Mapping[int, str | None] = {}
    for concept_idx in sure_concepts_indices:
        example_idx = self.sampling_method.sample_examples(
            concept_activations=sure_concepts_activations[:, concept_idx],
            k_examples=self.k_examples,
            k_quantile=self.k_quantile,
        )
        examples = _format_examples(
            example_ids=example_idx,
            inputs=granular_inputs,
            concept_activations=sure_concepts_activations[:, concept_idx],
            sample_ids=granular_sample_ids,
            k_context=self.k_context,
        )
        example_prompt = _build_example_prompt(examples)
        prompt: list[tuple[Role, str]] = [
            (Role.SYSTEM, self.system_prompt),
            (Role.USER, example_prompt),
            (Role.ASSISTANT, ""),
        ]
        label = self.llm_interface.generate(prompt)
        labels[concept_idx] = label
    return labels
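The source above passes `k_context` and `sample_ids` to a private formatting helper whose body is not shown. As a rough idea of what surrounding an activating token with context could involve, here is a hypothetical sketch; the function `with_context`, the symmetric window, and the bracket marking are assumptions, not the library's behavior.

```python
def with_context(tokens, sample_ids, index, k_context):
    """Return tokens[index] with up to k_context neighbours on each side,
    never crossing into another sample."""
    lo = max(0, index - k_context)
    hi = min(len(tokens), index + k_context + 1)
    # Keep only neighbours that come from the same input sample.
    same = [j for j in range(lo, hi) if sample_ids[j] == sample_ids[index]]
    # Mark the activating token so it stands out from its context.
    parts = [f"[{tokens[j]}]" if j == index else tokens[j] for j in same]
    return " ".join(parts)


tokens = ["the", "film", "was", "dull", "I", "loved", "it"]
sample_ids = [0, 0, 0, 0, 1, 1, 1]
end_of_first = with_context(tokens, sample_ids, index=3, k_context=2)
start_of_second = with_context(tokens, sample_ids, index=4, k_context=2)
print(end_of_first)
print(start_of_second)
```

Tracking sample ids matters because granular inputs from different samples are stored in one flat list, so a naive window around the last token of one sample would otherwise leak into the next.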

interpreto.model_wrapping.llm_interface.LLMInterface

Bases: ABC

generate abstractmethod

generate(prompt)
Source code in interpreto/model_wrapping/llm_interface.py
@abstractmethod
def generate(self, prompt: list[tuple[Role, str]]) -> str | None:
    pass
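A backend for LLMLabels only has to implement `generate`. The sketch below uses local stand-ins for `Role` and `LLMInterface` so that it runs without the library; the enum values and the toy `KeywordLLM` class are assumptions for illustration, and a real backend would call an actual model or API instead.

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional


class Role(Enum):  # stand-in for the library's Role; values are assumed
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


class LLMInterface(ABC):  # stand-in mirroring the abstract class above
    @abstractmethod
    def generate(self, prompt: list[tuple[Role, str]]) -> Optional[str]:
        pass


class KeywordLLM(LLMInterface):
    """Toy backend: labels a concept with the first word of the user message."""

    def generate(self, prompt: list[tuple[Role, str]]) -> Optional[str]:
        user_text = " ".join(text for role, text in prompt if role is Role.USER)
        words = user_text.split()
        # Returning None signals that no label could be produced.
        return words[0] if words else None


llm = KeywordLLM()
prompt = [
    (Role.SYSTEM, "Label the concept."),
    (Role.USER, "dull boring tedious"),
    (Role.ASSISTANT, ""),
]
label = llm.generate(prompt)
print(label)
```

The prompt shape (system message, user message with examples, empty assistant turn) follows the one built in LLMLabels.interpret.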

interpreto.concepts.interpretations.extract_unique_words

extract_unique_words(inputs, count_min_threshold=1, return_counts=False, lemmatize=False, words_to_ignore=None)

Extract the unique words from a list of texts.

Depending on parameters, it may select a subset of words or return the counts of each word.

Parameters:

Name Type Description Default

inputs

list[str]

The texts to extract words from.

required

count_min_threshold

int

The minimum number of occurrences of a word across the whole inputs for it to be kept.

1

return_counts

bool

Whether to return the counts of each word. Defaults to False.

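False

lemmatize

bool

Whether to lowercase and lemmatize each word (with NLTK's WordNetLemmatizer) before counting.

False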

words_to_ignore

list[str] | None

A list of words to ignore.

None

Examples:

Fastest version as used in TopKInputs.

>>> extract_unique_words(["Interpreto is the latin for 'to interpret'.", "interpreto is magic"])
["interpreto", "is", "the", "latin", "for", "to", "'", "interpret", ".", "magic"]

More complex use:

>>> import nltk
>>> from datasets import load_dataset
>>> from nltk.corpus import stopwords
>>>
>>> from interpreto.concepts.interpretations import extract_unique_words
>>>
>>> nltk.download("stopwords")
>>>
>>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
>>> extract_unique_words(
...     inputs=dataset,
...     count_min_threshold=20,
...     return_counts=True,
...     lemmatize=True,
...     words_to_ignore=stopwords.words("english") + [".", ",", "'s", "n't", "--", "``", "'"],
... )
Counter({'film': 1402,
         'movie': 1243,
         'one': 594,
         'like': 574,
         'ha': 563,
         'make': 437,
         'story': 417,
...
         'pop': 20,
         'college': 20,
         'bear': 20,
         'plain': 20,
         'generic': 20})

Returns:

Type Description
list[str] | Counter[str]

list[str] | Counter[str]: The list of unique words or the counts of each word.

Raises:

Type Description
ValueError

If the input is not a list of strings.

Source code in interpreto/concepts/interpretations/base.py
@jaxtyped(typechecker=beartype)
def extract_unique_words(
    inputs: list[str],
    count_min_threshold: int = 1,
    return_counts: bool = False,
    lemmatize: bool = False,
    words_to_ignore: list[str] | None = None,
) -> list[str] | Counter[str]:
    """
    Extract words from a text.

    Depending on parameters, it may select a subset of words or return the counts of each word.

    Args:
        inputs (list[str]):
            The texts to extract words from.

        count_min_threshold (int, optional):
            The minimum number of occurrences of a word across the whole `inputs` for it to be kept.

        return_counts (bool, optional):
            Whether to return the counts of each word.
            Defaults to False.

        lemmatize (bool, optional):
            Whether to lowercase and lemmatize each word before counting.
            Defaults to False.

        words_to_ignore (list[str], optional):
            A list of words to ignore.

    Examples:
        Fastest version as used in `TopKInputs`.
        >>> extract_unique_words(["Interpreto is the latin for 'to interpret'.", "interpreto is magic"])
        ["interpreto", "is", "the", "latin", "for", "to", "'", "interpret", ".", "magic"]

        More complex use:
        >>> import nltk
        >>> from datasets import load_dataset
        >>> from nltk.corpus import stopwords
        >>>
        >>> from interpreto.concepts.interpretations import extract_unique_words
        >>>
        >>> nltk.download("stopwords")
        >>>
        >>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")["train"]["text"]
        >>> extract_unique_words(
        ...     inputs=dataset,
        ...     count_min_threshold=20,
        ...     return_counts=True,
        ...     lemmatize=True,
        ...     words_to_ignore=stopwords.words("english") + [".", ",", "'s", "n't", "--", "``", "'"],
        ... )
        Counter({'film': 1402,
                 'movie': 1243,
                 'one': 594,
                 'like': 574,
                 'ha': 563,
                 'make': 437,
                 'story': 417,
        ...
                 'pop': 20,
                 'college': 20,
                 'bear': 20,
                 'plain': 20,
                 'generic': 20})

    Returns:
        list[str] | Counter[str]:
            The list of unique words or the counts of each word.

    Raises:
        ValueError:
            If the input is not a list of strings.
    """
    # ensure NLTK resources are downloaded
    _ensure_nltk_resources(lemmatize=lemmatize)

    if lemmatize:
        lemmatizer = WordNetLemmatizer()

    # a Counter holds both the unique words and the count of each word
    words_count = Counter()

    for text in inputs:
        for word in word_tokenize(text):
            # lemmatize words
            if lemmatize:
                word = lemmatizer.lemmatize(word.lower())  # noqa: PLW2901  # type: ignore  (ignore possibly unbound)

            # ignore words
            if words_to_ignore is not None and word in words_to_ignore:
                continue

            # add word to counter
            words_count[word] += 1

    # filter too rare words
    if count_min_threshold > 1:
        words_count = Counter({key: count for key, count in words_count.items() if count >= count_min_threshold})

    if return_counts:
        return words_count

    return list(words_count.keys())
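For readers who want to experiment with the counting and filtering logic without installing NLTK, here is a simplified approximation. The function `unique_words_sketch` and its regex tokenizer are assumptions for illustration: the tokenizer is much cruder than `word_tokenize`, and lemmatization is left out entirely.

```python
import re
from collections import Counter


def unique_words_sketch(inputs, count_min_threshold=1, return_counts=False, words_to_ignore=None):
    """Count words across `inputs`, then filter by frequency."""
    ignore = set(words_to_ignore or [])  # set membership is O(1)
    counts = Counter()
    for text in inputs:
        # Crude tokenizer: runs of word characters, or single punctuation marks.
        for word in re.findall(r"\w+|[^\w\s]", text):
            if word not in ignore:
                counts[word] += 1
    # Drop words that are too rare.
    if count_min_threshold > 1:
        counts = Counter({w: c for w, c in counts.items() if c >= count_min_threshold})
    return counts if return_counts else list(counts)


texts = ["A dull, dull film.", "A fun film!"]
all_words = unique_words_sketch(texts)
frequent = unique_words_sketch(texts, count_min_threshold=2, words_to_ignore=["A"])
print(all_words)
print(frequent)
```

One design difference from the source above: the ignore list is converted to a set once, which avoids a linear scan per token when the list is long (for example, a full stopword list).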