
Generation Concept-based Explanation Tutorial

Welcome to this tutorial. Our goal is to obtain concept-based explanations, starting from the beginning.

We will start with a minimal example in section 1 and then go through the key steps of the concept-based pipeline:

  0. ➗ Split your model in two parts
  1. ⏩ Minimal example: Top-k tokens for neurons
  2. 🚦 Compute a dataset of activations
  3. 🏋️‍♂️ Fit a concept model on activations
  4. 🏷️ Interpret the concept dimensions

To these, we add two bonus steps present in most papers:

  5. 📍 Local concept analysis
  6. ⚖️ Evaluate concept-based explanations

Author: Antonin Poché

0. ➗ Split your model in two parts

Let's take Qwen3-0.6B for both the minimal example and the detailed pipeline, but you can naturally use larger models.

Here we split at the 6th of the 28 layers, but you can also specify the path of the module to split at.

To split the model, we use the interpreto.ModelWithSplitPoints which wraps around the transformers model and allows the computation of activations at the specified split_points.

> ➡️ Note
>
> Interpreto's splitting is based on nnsight, so depending on your use case, you might want to use nnsight directly.

from transformers import AutoModelForCausalLM

from interpreto import ModelWithSplitPoints

# 1. load and split the generation model
mwsp = ModelWithSplitPoints(
    "Qwen/Qwen3-0.6B",  # "gpt2",  # for a less compute heavy notebook
    split_points=[5],  # split at the 6th layer
    automodel=AutoModelForCausalLM,
    device_map="cuda",
    batch_size=2048,  # high for the minimal example where samples are a single token
)
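
The integer index above is a shortcut: you can also pass the path of the module you want to split at. Below is a minimal sketch of that variant; "model.layers.5" is an assumed path following the usual transformers naming for Qwen3 decoder layers, not something taken from this tutorial.

# hedged sketch: split at a named module path instead of a layer index
# "model.layers.5" is an assumed module path (usual transformers layer naming)
mwsp_by_path = ModelWithSplitPoints(
    "Qwen/Qwen3-0.6B",
    split_points=["model.layers.5"],  # same split as above, addressed by module path
    automodel=AutoModelForCausalLM,
    device_map="cuda",
    batch_size=2048,
)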

1. ⏩ Minimal example: Top-k tokens for neurons

from interpreto.concepts import NeuronsAsConcepts, TopKInputs

# 2. No dataset of activations is needed, as we consider the latent space as the concept space

# 3. No training either
# Use `NeuronsAsConcepts` to run the concept-based pipeline directly on neurons
concept_explainer = NeuronsAsConcepts(mwsp)

# 4. Use `TopKInputs` to get the top-k tokens that maximally activate each neuron
method = TopKInputs(
    concept_explainer=concept_explainer,
    use_vocab=True,  # use the vocabulary of the model and test all tokens (~150k with Qwen3-0.6B)
    k=10,  # get the top 10 tokens for each neuron
)
topk_tokens = method.interpret(
    concepts_indices=list(range(5)),  # interpret the first five neurons
)

# show some neurons' interpretations
for concept_idx, tokens in topk_tokens.items():
    print(f"Concept {concept_idx}: {list(tokens.keys())}")

del concept_explainer, method, topk_tokens
You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.

Concept 0: ['Ġpostal', 'Ġparliamentary', 'éĤ½', 'Ġgüncel', 'éį', 'Ġprosecuting', 'å¥ģ', 'çĴ', 'ä¸Ģç·Ĵ', '-business']
Concept 1: ['<|im_start|>', 'ĠзаÑĢегиÑģÑĤÑĢиÑĢова', 'ÙħÙ쨧ÙĪØ¶', 'ìĨĢ', 'ë¼ĺ', 'åħ·æľīæĪĺ士', 'ë§Ħ', '身åĪĽéĢłçļĦ', 'çķ°ãģªãĤĭ', 'تغÙĬر']
Concept 2: ['ðŁĴ¬', '!!!!!!!!', '^^^^', ';;;;;;;;;;;;;;;;', '··', 'MMMM', 'aaaaaaaa', 'âĶģâĶģ', 'ãĢĢ', 'qq']
Concept 3: ['<|im_start|>', 'åħ³ä¹İ', 'å°±æĦıåij³çĿĢ', 'تÙĥاÙħÙĦ', 'íķľëĭ¤ë©´', 'ת×ķש×ij×Ļ', '身åĪĽéĢłçļĦ', 'Ġunfavor', 'ĠзаÑĢегиÑģÑĤÑĢиÑĢова', 'çķ°ãģªãĤĭ']
Concept 4: ['¿', 'Ġ¿', 'æłĩé¢ĺ', 'æĸij', 'Title', 'åį±', 'Ġ¡', 'ðŁĴ¬', 'é¢Ĩ导人', 'Ĕ']

> ❓ The concepts are not interpretable?
>
> Well, this is surely due to superposition, which is why we use dictionary learning.
>
> Let's explore and solve this problem in the next sections.

2. 🚦 Compute a dataset of activations

We will use the IMDB dataset to build a dataset of activations.

> ⚠️ Warning
>
> The dataset used has a considerable impact on the concept space obtained. Hence, in practice, we recommend using a subset of the model's training set.
>
> The concept model we train, be it an SAE or something else, finds patterns in the dataset of activations, which explains why the concepts found depend on the activation dataset.

> 🔥 Tip
>
> The larger the dataset the better, but at the cost of the computation time needed to get activations.
>
> Hence, the best dataset for SAEs would be one where each token is only seen once. But this is harder to build in practice, and such a pipeline is not covered in this tutorial.

interpreto.ModelWithSplitPoints.get_activations()

interpreto.ModelWithSplitPoints.activation_granularities

from datasets import load_dataset

# take the first 5,000 training reviews; more samples lead to better results, but some methods do not support very large datasets
imdb = load_dataset("stanfordnlp/imdb")["train"]["text"][:5000]

# ignore special token activations
TOKEN = ModelWithSplitPoints.activation_granularities.TOKEN

# compute the activations of the IMDB subset
# activations are flattened over the n_samples and seq_len dimensions,
# which leads us to roughly 1.5 million tokens
# (n * l, d)
mwsp.batch_size = 8
activations_dict = mwsp.get_activations(
    inputs=imdb,
    activation_granularity=TOKEN,
    tqdm_bar=True,
)
# it is possible to compute activations for several split points
# hence, we need to extract the activations for the split point we are interested in
activations = mwsp.get_split_activations(activations_dict)

print(f"{activations.shape = }")
Computing activations: 100%|██████████| 625/625 [03:46<00:00,  2.77batch/s]

activations.shape = torch.Size([1474091, 1024])
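
Computing these activations takes a few minutes, so it can be worth caching the tensor to disk before moving on. A small optional sketch using plain torch serialization; the file name is arbitrary.

import torch

# optional: cache the activations to avoid recomputing them when re-running the notebook
torch.save(activations.cpu(), "imdb_layer5_token_activations.pt")  # arbitrary file name
# later: activations = torch.load("imdb_layer5_token_activations.pt")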

3. 🏋️‍♂️ Fit a concept model on activations

Now we can fit a concept model on the activations. Concept models range from simple to quite complex. Here we use an SAE, which is on the complex side and has many more parameters than a simple concept model.

In particular, we use interpreto.concepts.methods.overcomplete.BatchTopKSAEConcepts.

The concept_explainer wraps around both mwsp, the model wrapper, and the concept_model.

> ➡️ Note
>
> Most of our concept models and all the SAE implementations depend on the Overcomplete library. It provides many concept models and optimization methods for concept extraction. So if you want to go deeper on these, we suggest digging there.

import torch

from interpreto.concepts.methods.overcomplete import BatchTopKSAEConcepts, DeadNeuronsReanimationLoss

top_k_individual = 10
concept_model_batch_size = 2048
epochs = 20

# instantiate the concept explainer with the split model
concept_explainer = BatchTopKSAEConcepts(
    mwsp,
    nb_concepts=1000,
    device="cuda",
    top_k=top_k_individual * concept_model_batch_size,
)

# train the SAE on the activations
log = concept_explainer.fit(
    activations=activations,
    criterion=DeadNeuronsReanimationLoss,  # set an MSE loss with dead neurons reanimation
    optimizer_class=torch.optim.Adam,
    scheduler_class=torch.optim.lr_scheduler.CosineAnnealingLR,
    scheduler_kwargs={"T_max": epochs, "eta_min": 1e-6},
    lr=1e-3,
    nb_epochs=epochs,
    batch_size=concept_model_batch_size,
    monitoring=1,
)
Epoch[1/20], Loss: 5.1928, R2: 0.9625, L0: 10.0041, Dead Features: 0.0%, Time: 6.8131 seconds
Epoch[2/20], Loss: 0.2371, R2: 0.9975, L0: 10.0041, Dead Features: 0.0%, Time: 7.2336 seconds
Epoch[3/20], Loss: 0.2229, R2: 0.9969, L0: 10.0041, Dead Features: 0.0%, Time: 8.4109 seconds
Epoch[4/20], Loss: 0.1567, R2: 0.9987, L0: 10.0041, Dead Features: 0.0%, Time: 5.9831 seconds
Epoch[5/20], Loss: 0.1692, R2: 0.9978, L0: 10.0041, Dead Features: 0.0%, Time: 6.2661 seconds
Epoch[6/20], Loss: 0.1423, R2: 0.9988, L0: 10.0041, Dead Features: 0.0%, Time: 6.5914 seconds
Epoch[7/20], Loss: 0.1761, R2: 0.9982, L0: 10.0041, Dead Features: 0.0%, Time: 6.8673 seconds
Epoch[8/20], Loss: 0.1214, R2: 0.9986, L0: 10.0041, Dead Features: 0.0%, Time: 7.2377 seconds
Epoch[9/20], Loss: 0.1291, R2: 0.9978, L0: 10.0041, Dead Features: 0.0%, Time: 8.4441 seconds
Epoch[10/20], Loss: 0.1210, R2: 0.9986, L0: 10.0041, Dead Features: 0.0%, Time: 5.8799 seconds
Epoch[11/20], Loss: 0.1330, R2: 0.9989, L0: 10.0041, Dead Features: 0.0%, Time: 6.0463 seconds
Epoch[12/20], Loss: 0.1247, R2: 0.9990, L0: 10.0041, Dead Features: 0.0%, Time: 6.4981 seconds
Epoch[13/20], Loss: 0.1147, R2: 0.9990, L0: 10.0041, Dead Features: 0.0%, Time: 6.7863 seconds
Epoch[14/20], Loss: 0.1145, R2: 0.9990, L0: 10.0041, Dead Features: 0.0%, Time: 7.0415 seconds
Epoch[15/20], Loss: 0.1299, R2: 0.9986, L0: 10.0041, Dead Features: 0.0%, Time: 8.2114 seconds
Epoch[16/20], Loss: 0.1117, R2: 0.9991, L0: 10.0041, Dead Features: 0.0%, Time: 5.7240 seconds
Epoch[17/20], Loss: 0.1571, R2: 0.9987, L0: 10.0041, Dead Features: 0.0%, Time: 6.0964 seconds
Epoch[18/20], Loss: 0.1076, R2: 0.9988, L0: 10.0041, Dead Features: 0.0%, Time: 6.4666 seconds
Epoch[19/20], Loss: 0.1072, R2: 0.9991, L0: 10.0041, Dead Features: 0.0%, Time: 6.6090 seconds
Epoch[20/20], Loss: 0.1079, R2: 0.9987, L0: 10.0041, Dead Features: 0.0%, Time: 7.0120 seconds
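
As a quick sanity check, we can encode a batch of activations into the concept space and verify that the average number of active concepts per token is close to the top-k budget of 10. This is a minimal sketch reusing encode_activations (also used in section 5.2); the 4096-token batch is an arbitrary choice.

import torch

# hedged sanity check: average L0 (number of non-zero concepts) per token
with torch.no_grad():
    concept_acts = concept_explainer.encode_activations(activations[:4096])  # (n, nb_concepts)
    l0 = (concept_acts != 0).float().sum(dim=-1).mean().item()
print(f"average active concepts per token: {l0:.2f}")  # should be close to 10, matching the L0 in the training logs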

4. 🏷️ Interpret the concept dimensions

Once the concept directions are set, it is time to interpret them. We explore two possibilities here; note that the second one requires an OpenAI API key.

4.1 Interpret concepts via top-k words

Via interpreto.concepts.TopKInputs, it is possible to extract the top-k words that activate a concept the most. These could also be tokens, sentences, or samples.

from interpreto.concepts import TopKInputs

WORD = ModelWithSplitPoints.activation_granularities.WORD

interpretation_method = TopKInputs(
    concept_explainer=concept_explainer,
    activation_granularity=WORD,
    concept_encoding_batch_size=concept_model_batch_size,
    k=10,  # get the top 10 words for each concept
    concept_model_device="cuda",
)
interpretations = interpretation_method.interpret(
    inputs=imdb[:100],  # larger datasets lead to better interpretations
    latent_activations=None,  # we use a different granularity hence we cannot reuse the activations
    concepts_indices="all",  # interpret all concepts
)

for concept_idx, words in list(interpretations.items())[:10]:
    print(f"Concept {concept_idx}: {list(words.keys()) if words else None}")
Concept 0: [' worry', ' if']
Concept 1: [' so', ' So', 'So', ' such', ' thus', ' therefore', ' much', ' a', 'Such', ' is']
Concept 2: [' was', ' a', 'is', ' horribly', ' meat', ' still']
Concept 3: ['What', 'The', 'Its', 'This', 'THE', 'Three', 'Lifetime', 'Who', 'It', 'this']
Concept 4: [' as', ' far', ' Whatever', ' for', ' whatever', ' the', ' whoever', ' a', 'whatever', ' to']
Concept 5: [' group', ' set', ' cast', ' gang', ' team', ' clan', ' Gang', ' sets', ' of', ' family']
Concept 6: [' King', ' was']
Concept 7: [' atrocity']
Concept 8: [' shocked', ' amazed', ' disappointed', ' angry', ' surprised', ' delighted', ' furious', ' impressed', ' confused', ' bored']
Concept 9: None

> ➡️ Note
>
> While these are not the most interpretable concepts, this is still better than the minimal example.
>
> Let's improve the interpretations and see what happens.

4.2 Interpret concepts with LLM labels

Here we use interpreto.concepts.LLMLabels, which uses an LLM to label the concepts based on examples activating each concept.

> ⚠️ Warning
>
> The following cell is skipped if you do not set an OpenAI API key, as LLMLabels requires an LLMInterface, and we chose the OpenAILLM one.
>
> What you can do to make it work:
> - Get an OpenAI API key and set it as the environment variable OPENAI_API_KEY.
> - Branch your own LLMInterface and use it instead of OpenAILLM. See the last section for an example.

import os

from interpreto.concepts import LLMLabels
from interpreto.model_wrapping.llm_interface import OpenAILLM

api_key = os.getenv("OPENAI_API_KEY")
if api_key is not None:
    # set the LLM interface used to generate labels based on the constructed prompts
    llm_interface = OpenAILLM(api_key=api_key, model="gpt-4.1-mini")

    mwsp.to("cpu")
    interpretation_method = LLMLabels(
        concept_explainer=concept_explainer,
        activation_granularity=TOKEN,  # we suggest using the same granularity as for the activations
        llm_interface=llm_interface,
        k_examples=20,  # number of examples in each concept
        k_context=5,  # number of tokens before and after the maximally activating one to give context
        concept_encoding_batch_size=concept_model_batch_size,
    )

    # compute the labels for an arbitrary subset of the concepts
    interpretations = interpretation_method.interpret(
        inputs=imdb,
        latent_activations=activations,  # as the inputs and granularity are the same, we can reuse the activations
        concepts_indices=list(
            range(10)
        ),  # we could put `"all"` but that would take long and cost a bit through the API
    )

    for concept_id, label in interpretations.items():
        print(f"Concept {concept_id}: {label if label is not None else 'None'}")
Concept 0: function words
Concept 1: intensifier "so"
Concept 2: Tokens
Concept 3: Interjection
Concept 4: Comparative phrases
Concept 5: Groups or collections
Concept 6: prepositions and determiners
Concept 7: Location name "ro"
Concept 8: negative emotions
Concept 9: word fragment

> ➡️ Note 1
>
> The labels highly depend on the system prompt provided to the LLM interface.
>
> For the label formulations to better align with what you expect, you can set the system_prompt argument of LLMLabels.
>
> Here is the default system prompt:

SYSTEM_PROMPT_WITH_CONTEXT = """Your role is to label the concepts/patterns present in the different examples.

You will be given a list of text examples on which special tokens are selected and between delimiters like <<this>>.
How important each token is for the behavior is listed after each example in parentheses, with importance from 0 to 10.

Hard rules:
- The label should summarize the concept linking the examples together. Give a single label describing all examples highlighted tokens.
- The label should be between 1 and 5 words long.
- The shorter the label the better. The best is a word.
- Do not make a sentence.
- The label should be the most precise possible. The goal is to be able to differentiate between concepts.
- The label should encompass most examples. But you can ignore the non-informative ones.
- Do not mention the marker tokens (<< >>) in your explanation. Nor refer to the importance.
- Only focus on the content and the label.
- Never ever give labels with more than 5 words, they would be cut out.

Some examples: 'blue', 'positive sentiment and enthusiasm', 'legal entities', 'medical places', 'hate', 'noun phrase', 'ion or iou sounds', 'questions' final words'...
"""

> ➡️ Note 2
>
> Having our concepts and their interpretations is great. But what is really useful is to know:
> - When is each concept used? (see next section)
> - Whether the concepts are pertinent. (see final section)

5. 📍 Local concept analysis

To see the important concepts in a sample, there are two complementary steps:

  • Find the most important concepts for the predicted tokens.
  • Highlight the input tokens activating these concepts.

5.0 Create a random sample and decompose it into tokens

# create a sample, let's take a review about "Avatar: the Way of Water"
sample = ["Visually beautiful but excessively long and boring. Three hours is too long."]

# from text to ids back to tokens
sample_token_ids = mwsp.tokenizer(sample, return_tensors="pt")
sample_tokens = TOKEN.value.get_decomposition(sample_token_ids, tokenizer=mwsp.tokenizer, return_text=True)[0]

print(f"The sample is: {sample[0]}\n\nIt is decomposed in {len(sample_tokens)} tokens:\n{sample_tokens}")
The sample is: Visually beautiful but excessively long and boring. Three hours is too long.

It is decomposed in 15 tokens:
['Vis', 'ually', ' beautiful', ' but', ' excessively', ' long', ' and', ' boring', '.', ' Three', ' hours', ' is', ' too', ' long', '.']

5.1 Important concepts for text (with respect to the output)

Here we use the gradient of the concept-to-output function to estimate the importance of each concept for the outputs of the model.

interpreto.concepts.ConceptAutoEncoderExplainer.concept_output_gradient()

mwsp.to("cuda")
# (seq_len (out), seq_len (in), nb_concepts)
local_importances = concept_explainer.concept_output_gradient(
    inputs=sample,
    targets=None,  # all predicted tokens
    activation_granularity=TOKEN,  # same throughout the notebook
    concepts_x_gradients=False,  # use pure gradients rather than gradients times concept activations
    normalization=False,  # if True, for each output token, the sum of the absolute importances is normalized to 1
    batch_size=64,
)[0]  # only one sample
mwsp.to("cpu")

# we take the sum over the input sequence dimension, as we focus on the concept-output relationship
# (seq_len (out), nb_concepts)
local_importances = local_importances.abs().sum(dim=1)
print(f"{local_importances.shape = }")
local_importances.shape = torch.Size([15, 1000])

5.2 Input tokens activating the concepts

This step simply looks at the concept activations on the input tokens. It is an easy way to check which input tokens activate each concept.

# compute the latent activations
# (seq_len, d_model)
local_activations = mwsp.get_split_activations(mwsp.get_activations(sample, TOKEN))

# and the concepts activations
# (nb_concepts, seq_len)
concepts_activations = concept_explainer.encode_activations(local_activations).T
print(f"{concepts_activations.shape = }")
concepts_activations.shape = torch.Size([1000, 15])
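
Before building the full visualization, we can already peek at which input tokens most activate a given concept, using only the tensors computed above. A small sketch; concept index 0 is an arbitrary choice for illustration.

# hedged sketch: input tokens that most activate one arbitrarily chosen concept
concept_id = 0
token_scores = concepts_activations[concept_id]  # (seq_len,)
for pos in token_scores.argsort(descending=True)[:3].tolist():
    print(f"{sample_tokens[pos]!r}: {token_scores[pos].item():.3f}")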

5.3 Visualization

There are several steps to visualize local concept-based explanations for generation.

  • Select the output token, as we only explain one prediction. Here we will take the last word, "long".
  • Extract the most important concepts for the output token.
  • Visualize the concepts in a heatmap.

> ⚠️ Warning
>
> The following cell is a workaround using existing features to visualize something.
>
> We plan to add a better visualization in the future. The logic of this cell will then be included in the visualization, and the choices will be made by clicking on the interactive explanation.

from interpreto import AttributionVisualization
from interpreto.attributions.base import AttributionOutput, ModelTask

# select the last output token
last_token_concepts = local_importances[-2]
# TODO: see if I should take the mean/max and not just the last token, but then I need to update the description

# top 5 concepts
top5_concepts = last_token_concepts.argsort(descending=True)[:5]
top5_concepts_activations = concepts_activations[top5_concepts]

# get additional interpretations
missing = [c.item() for c in top5_concepts if c.item() not in interpretations.keys()]
if len(missing) > 0:
    interpretations.update(interpretation_method.interpret(missing, imdb, activations))

# create an attribution output [temporary workaround]
attr = AttributionOutput(
    elements=sample_tokens,
    attributions=top5_concepts_activations,
    model_task=ModelTask.MULTI_CLASS_CLASSIFICATION,
    model_inputs_to_explain=None,
    targets=None,
)

# visualize the concepts
AttributionVisualization(
    attr, class_names={i: interpretations[c.item()] for i, c in enumerate(top5_concepts)}
).display()

(Interactive visualization output: the selected concepts are listed as classes, with their activations highlighted over the input tokens.)

6. ⚖️ Evaluate concept-based explanations

Let's compare the minimal example and the Batch TopK SAE explainers.

test_inputs = load_dataset("stanfordnlp/imdb")["test"]["text"][:100]  # let's take one hundred test samples

# Compute the token activations on the test set
mwsp.to("cuda")
test_activations = mwsp.get_activations(
    inputs=test_inputs,
    activation_granularity=TOKEN,
)
_ = mwsp.to("cpu")

6.1 🌐 Evaluate the concept-space from the third step

> ⚠️ Warning
>
> These metrics should only be used to compare concept spaces trained in similar contexts: same model, split point, activation dataset...

Reconstruction error

Sparsity

Dictionary metrics

> ➡️ Note
>
> We do not apply the stability metric here because it requires retraining the concept space several times and comparing the runs. That would take too long for this tutorial. Nonetheless, SAEs in general are not very stable.

from interpreto.concepts.metrics import FID, MSE, Sparsity, SparsityRatio

# BatchTopKSAEConcepts
mse = MSE(concept_explainer).compute(test_activations)
fid = FID(concept_explainer).compute(test_activations)
sparsity = Sparsity(concept_explainer).compute(test_activations)
ratio = SparsityRatio(concept_explainer).compute(test_activations)

print(f"MSE: {round(mse, 3)}, FID: {round(fid, 3)}, Sparsity: {round(sparsity, 3)}, Sparsity ratio: {round(ratio, 3)}")

# NeuronsAsConcepts
identity_explainer = NeuronsAsConcepts(mwsp)
mse = MSE(identity_explainer).compute(test_activations)
fid = FID(identity_explainer).compute(test_activations)
sparsity = Sparsity(identity_explainer).compute(test_activations)
ratio = SparsityRatio(identity_explainer).compute(test_activations)

print(f"MSE: {round(mse, 3)}, FID: {round(fid, 3)}, Sparsity: {round(sparsity, 3)}, Sparsity ratio: {round(ratio, 3)}")
MSE: 1844.932, FID: 0.104, Sparsity: 0.01, Sparsity ratio: 0.0
MSE: 0.0, FID: 0.0, Sparsity: 1.0, Sparsity ratio: 0.001

6.2 💭 Evaluate the concept interpretations from the fourth step

# Work in progress, coming soon

7. Using your own LLM interface

from interpreto.model_wrapping.llm_interface import LLMInterface, Role


class GeminiLLM(LLMInterface):
    def __init__(self, api_key: str, model: str = "gemini-1.5-flash", num_try: int = 5):
        try:
            import google.generativeai as genai  # noqa: PLC0415  # ruff: disable=import-outside-toplevel
        except ImportError as e:
            raise ImportError("Install google-generativeai to use Google Gemini API.") from e

        self.genai = genai
        self.genai.configure(api_key=api_key)
        self.model = model
        self.num_try = num_try

    def generate(self, prompt: list[tuple[Role, str]]) -> str | None:
        # Build system instruction and chat history for Gemini
        system_messages: list[str] = []
        contents: list[dict] = []

        for role, content in prompt:
            if role == Role.SYSTEM:
                system_messages.append(content)
            elif role == Role.USER:
                contents.append(
                    {
                        "role": "user",
                        "parts": [{"text": content}],
                    }
                )
            elif role == Role.ASSISTANT:
                contents.append(
                    {
                        "role": "model",
                        "parts": [{"text": content}],
                    }
                )
            else:
                raise ValueError(f"Unknown role for google gemini api: {role}")

        system_instruction: str | None = "\n".join(system_messages) if system_messages else None

        label: str | None = None
        for _ in range(self.num_try):
            try:
                model = self.genai.GenerativeModel(
                    model_name=self.model,
                    system_instruction=system_instruction,
                )
                response = model.generate_content(contents)  # type: ignore[arg-type]
                # google-generativeai exposes the main text as .text
                label = response.text
                break
            except Exception as e:  # noqa: BLE001
                print(e)
        return label
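
This custom interface can then be swapped in for OpenAILLM in the section 4.2 cell. A usage sketch; GEMINI_API_KEY is an arbitrary environment variable name, and everything else mirrors the earlier LLMLabels cell.

import os

# hedged usage sketch: plug the custom interface into LLMLabels, mirroring section 4.2
gemini_key = os.getenv("GEMINI_API_KEY")  # arbitrary variable name, set it yourself
if gemini_key is not None:
    interpretation_method = LLMLabels(
        concept_explainer=concept_explainer,
        activation_granularity=TOKEN,
        llm_interface=GeminiLLM(api_key=gemini_key),
        k_examples=20,
        k_context=5,
        concept_encoding_batch_size=concept_model_batch_size,
    )
    interpretations = interpretation_method.interpret(
        inputs=imdb,
        latent_activations=activations,
        concepts_indices=list(range(10)),
    )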