
Generation Concept-based Explanation Tutorial¶
Welcome to this tutorial. Our goal is to obtain concept-based explanations, starting from the beginning.
We will start with a minimal example in section 1 and then go through the key steps of the concept-based pipeline:
- ➗ Split your model in two parts
- ⏩ Minimal example: Top-k tokens for neurons
- 🚦 Compute a dataset of activations
- 🏋️♂️ Fit a concept model on activations
- 🏷️ Interpret the concept dimensions
To which we add two bonus steps present in most papers:
- 📍 Local concept analysis
- Concept evaluation (are the concepts pertinent?)
Author: Antonin Poché
0. ➗ Split your model in two parts ¶
Let's take Qwen3-0.6B for both the minimal example and the detailed pipeline, but you can naturally use larger models.
Here we split after the 6th of the 28 layers, but you can also specify the path of the module to split at.
To split the model, we use interpreto.ModelWithSplitPoints, which wraps the transformers model and allows computing activations at the specified split_points.
> ➡️ Note
>
> Interpreto's splitting is based on nnsight, so depending on your use case, you might want to use nnsight directly.
from transformers import AutoModelForCausalLM
from interpreto import ModelWithSplitPoints
# 1. load and split the generation model
mwsp = ModelWithSplitPoints(
"Qwen/Qwen3-0.6B", # "gpt2", # for a less compute heavy notebook
split_points=[5], # split at the 6th layer
automodel=AutoModelForCausalLM,
device_map="cuda",
batch_size=2048, # high for the minimal example where samples are a single token
)
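As an aside, the split point can also be given as a module path rather than a layer index. The cell below is a minimal sketch of that variant, assuming split_points accepts module-path strings and that the layer is exposed as model.layers.5 (check the exact names with print(mwsp.model)); there is no need to run it in addition to the cell above.
# alternative to the cell above: split by module path instead of layer index
# the path below is an assumption about the Qwen3 module naming, verify it with `print(mwsp.model)`
mwsp_by_path = ModelWithSplitPoints(
    "Qwen/Qwen3-0.6B",
    split_points=["model.layers.5"],  # assumed equivalent to `split_points=[5]`
    automodel=AutoModelForCausalLM,
    device_map="cuda",
    batch_size=2048,
)
del mwsp_by_path  # we keep using `mwsp` from the cell above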
1. ⏩ Minimal example: Top-k tokens for neurons ¶
from interpreto.concepts import NeuronsAsConcepts, TopKInputs
# 2. No dataset of activations is needed, as we consider the latent space as the concept space
# 3. No training either
# Use `NeuronsAsConcepts` to use the concept-based pipeline with neurons
concept_explainer = NeuronsAsConcepts(mwsp)
# 4. Use `TopKInputs` to get the top-k tokens that maximally activate each neuron
method = TopKInputs(
concept_explainer=concept_explainer,
use_vocab=True, # use the vocabulary of the model and test all tokens (~150k with Qwen-0.6B)
k=10, # get the top 10 tokens for each neuron
)
topk_tokens = method.interpret(
    concepts_indices=list(range(5)),  # interpret the first five neurons
)
# show some neurons' interpretations
for concept_idx, tokens in topk_tokens.items():
    print(f"Concept {concept_idx}: {list(tokens.keys())}")
del concept_explainer, method, topk_tokens
> ❓ The concepts are not interpretable?
>
> Well, this is most likely due to superposition, which is why we use dictionary learning.
>
> Let's explore and solve this problem in the next sections.
2. 🚦 Compute a dataset of activations ¶
We will use the IMDB dataset to build a dataset of activations.
> ⚠️ Warning
>
> The dataset used has a considerable impact on the obtained concept space. Hence, in practice, we recommend using a subset of the model's training set.
>
> The concept model we train, be it an SAE or another method, finds patterns in the dataset of activations, which explains why the concepts found depend on the activations dataset.
> 🔥 Tip
>
> The larger the dataset, the better, but at the cost of the computation time needed to get the activations.
>
> Hence, the best dataset for SAEs would be one where each token is seen only once. But this is harder to build in practice, and such a pipeline is not covered in this tutorial.
from datasets import load_dataset
# take a subset of the dataset; more samples lead to better results, but some methods do not scale to very large datasets
imdb = load_dataset("stanfordnlp/imdb")["train"]["text"][:5000]
# ignore special token activations
TOKEN = ModelWithSplitPoints.activation_granularities.TOKEN
# compute the activations of the IMDB subset
# activations are flattened between the n_sample and seq_len dimensions,
# which gives a single large token-level dataset (more than 6 million tokens with the full IMDB train set)
# (n * l, d)
mwsp.batch_size = 8
activations_dict = mwsp.get_activations(
inputs=imdb,
activation_granularity=TOKEN,
tqdm_bar=True,
)
# it is possible to compute activations for several split points
# hence, we need to extract the activations for the split point we are interested in
activations = mwsp.get_split_activations(activations_dict)
print(f"{activations.shape = }")
3. 🏋️♂️ Fit a concept model on activations ¶
Now we can fit a concept model on the activations. There exist more or less complex concept models; here we use an SAE, which is on the complex side and has many more parameters than a simpler concept model.
In particular, we use BatchTopKSAEConcepts from interpreto.concepts.methods.overcomplete.
The concept_explainer wraps around both mwsp, the model wrapper, and the concept_model.
> ➡️ Note
>
> Most of our concept models and all the SAE implementations depend on the Overcomplete library. It provides many concept models and optimization methods for concept extraction. So if you want to go deeper on these, we suggest digging there.
import torch
from interpreto.concepts.methods.overcomplete import BatchTopKSAEConcepts, DeadNeuronsReanimationLoss
top_k_individual = 10
concept_model_batch_size = 2048
epochs = 20
# instantiate the concept explainer with the split model
concept_explainer = BatchTopKSAEConcepts(
mwsp,
nb_concepts=1000,
device="cuda",
top_k=top_k_individual * concept_model_batch_size,
)
# train the SAE on the activations
log = concept_explainer.fit(
activations=activations,
criterion=DeadNeuronsReanimationLoss, # set an MSE loss with dead neurons reanimation
optimizer_class=torch.optim.Adam,
scheduler_class=torch.optim.lr_scheduler.CosineAnnealingLR,
scheduler_kwargs={"T_max": epochs, "eta_min": 1e-6},
lr=1e-3,
nb_epochs=epochs,
batch_size=concept_model_batch_size,
monitoring=1,
)
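Once fitted, a quick sanity check is to look at the sparsity the SAE enforces: with BatchTopK, roughly top_k_individual concepts should be active per token on average. The sketch below reuses encode_activations (the same call as in section 5.2) on an arbitrary slice of the activations; the exact average may differ slightly depending on how the batch-level top-k is applied at inference.
import torch
# sanity check (a sketch): measure how many concepts are active per token after training
with torch.no_grad():
    sample_concepts = concept_explainer.encode_activations(activations[:1024])
active_per_token = (sample_concepts != 0).float().sum(dim=-1).mean().item()
print(f"average active concepts per token: {active_per_token:.1f} (expected around {top_k_individual})")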
4. 🏷️ Interpret the concept dimensions ¶
Once the concept directions are set, it is time to interpret them. We explore two possibilities here, as the second one requires an OpenAI API key.
4.1 Interpret concepts via top-k words¶
Via interpreto.concepts.TopKInputs, it is possible to extract the top-k words that most activate a concept. These could also be tokens, sentences, or samples.
from interpreto.concepts import TopKInputs
WORD = ModelWithSplitPoints.activation_granularities.WORD
interpretation_method = TopKInputs(
    concept_explainer=concept_explainer,
    activation_granularity=WORD,
    concept_encoding_batch_size=concept_model_batch_size,
    k=10,  # get the top 10 words for each concept
    concept_model_device="cuda",
)
interpretations = interpretation_method.interpret(
    inputs=imdb[:100],  # larger datasets lead to better interpretations
    latent_activations=None,  # we use a different granularity, hence we cannot reuse the activations
    concepts_indices="all",  # interpret all concepts
)
for concept_idx, words in list(interpretations.items())[:10]:
    print(f"Concept {concept_idx}: {list(words.keys()) if words else None}")
> ➡️ Note
>
> While these are not the most interpretable concepts, this is still better than the minimal example.
>
> Let's improve the interpretations and see what happens.
4.2 Interpret concepts with LLM labels¶
Here we use interpreto.concepts.LLMLabels, which uses an LLM to label the concepts based on examples activating them.
> ⚠️ Warning
>
> The following cell is skipped if you do not set an OpenAI API key, as LLMLabels requires an LLMInterface, and here we chose the OpenAILLM one.
>
> What you can do to make it work:
> - Get an OpenAI API key and set it as an environment variable OPENAI_API_KEY.
> - Plug in your own LLMInterface and use it instead of OpenAILLM. See the last section for an example.
import os
from interpreto.concepts import LLMLabels
from interpreto.model_wrapping.llm_interface import OpenAILLM
api_key = os.getenv("OPENAI_API_KEY")
if api_key is not None:
    # set the LLM interface used to generate labels based on the constructed prompts
    llm_interface = OpenAILLM(api_key=api_key, model="gpt-4.1-mini")
    mwsp.to("cpu")
    interpretation_method = LLMLabels(
        concept_explainer=concept_explainer,
        activation_granularity=TOKEN,  # we suggest using the same granularity as for the activations
        llm_interface=llm_interface,
        k_examples=20,  # number of examples given for each concept
        k_context=5,  # number of tokens before and after the maximally activating one to give context
        concept_encoding_batch_size=concept_model_batch_size,
    )
    # compute the labels for an arbitrary subset of the concepts
    interpretations = interpretation_method.interpret(
        inputs=imdb,
        latent_activations=activations,  # as the inputs and granularity are the same, we can reuse the activations
        concepts_indices=list(range(10)),  # we could pass "all", but that would take long and cost a bit through the API
    )
    for concept_id, label in interpretations.items():
        print(f"Concept {concept_id}: {label if label is not None else 'None'}")
> ➡️ Note 1
>
> The labels highly depend on the system prompt provided to the LLM interface.
>
> For the label formulation to better align with what you expect, you can set the system_prompt argument of LLMLabels.
>
> Here is the default system prompt:
SYSTEM_PROMPT_WITH_CONTEXT = """Your role is to label the concepts/patterns present in the different examples.
You will be given a list of text examples on which special tokens are selected and between delimiters like <<this>>.
How important each token is for the behavior is listed after each example in parentheses, with importance from 0 to 10.
Hard rules:
- The label should summarize the concept linking the examples together. Give a single label describing all examples highlighted tokens.
- The label should be between 1 and 5 words long.
- The shorter the label the better. The best is a word.
- Do not make a sentence.
- The label should be the most precise possible. The goal is to be able to differentiate between concepts.
- The label should encompass most examples. But you can ignore the non-informative ones.
- Do not mention the marker tokens (<< >>) in your explanation. Nor refer to the importance.
- Only focus on the content and the label.
- Never ever give labels with more than 5 words, they would be cut out.
Some examples: 'blue', 'positive sentiment and enthusiasm', 'legal entities', 'medical places', 'hate', 'noun phrase', 'ion or iou sounds', 'questions' final words'...
"""
> ➡️ Note 2
>
> Having our concepts and their interpretations is great. But what is really useful is to know:
> - When is each concept used? (see next section)
> - Whether the concepts are pertinent. (see final section)
5. 📍 Local concept analysis ¶
To see the important concepts in a sample, there are two complementary steps:
- Find the most important concepts for the predicted tokens.
- Highlight the input tokens activating these concepts.
5.0 Create a random sample and decompose it into tokens¶
# create a sample, let's take a review about "Avatar: the Way of Water"
sample = ["Visually beautiful but excessively long and boring. Three hours is too long."]
# from text to ids back to tokens
sample_token_ids = mwsp.tokenizer(sample, return_tensors="pt")
sample_tokens = TOKEN.value.get_decomposition(sample_token_ids, tokenizer=mwsp.tokenizer, return_text=True)[0]
print(f"The sample is: {sample[0]}\n\nIt is decomposed in {len(sample_tokens)} tokens:\n{sample_tokens}")
5.1 Important concepts for text (with respect to the output)¶
Here we use the gradient of the concept-to-output function, via interpreto.concepts.ConceptAutoEncoderExplainer.concept_output_gradient(), to estimate the importance of each concept for the model's outputs.
mwsp.to("cuda")
# (seq_len (out), seq_len (in), nb_concepts)
local_importances = concept_explainer.concept_output_gradient(
    inputs=sample,
    targets=None,  # explain all predicted tokens
    activation_granularity=TOKEN,  # same granularity as throughout the notebook
    concepts_x_gradients=False,  # use pure gradients, not gradients times concept activations
    normalization=False,  # if True, the absolute importances would sum to 1 for each output token
    batch_size=64,
)[0]  # only one sample
mwsp.to("cpu")
# we take the sum over the input sequence dimension, as we focus on the concept-output relationship
# (seq_len (out), nb_concepts)
local_importances = local_importances.abs().sum(dim=1)
print(f"{local_importances.shape = }")
5.2 Input tokens activating the concepts¶
This step simply looks at the concept activations on the input tokens. It is an easy way to see which input tokens activate each concept.
# compute the latent activations
# (seq_len, d_model)
local_activations = mwsp.get_split_activations(mwsp.get_activations(sample, TOKEN))
# and the concepts activations
# (nb_concepts, seq_len)
concepts_activations = concept_explainer.encode_activations(local_activations).T
print(f"{concepts_activations.shape = }")
5.3 Visualization¶
There are several steps to visualize local concept-based explanations for generation.
- Select the output token, as we only explain one prediction. Here we take the one at the position of "long" (second to last).
- Extract the most important concepts for the output token.
- Visualize the concepts in a heatmap.
> ⚠️ Warning
>
> The following cell is a workaround using existing features to visualize the explanation.
>
> We plan to add a better visualization in the future. This cell will then be part of the visualization itself, and the choices made here will become clicks on the interactive explanation.
from interpreto import AttributionVisualization
from interpreto.attributions.base import AttributionOutput, ModelTask
# select the output token at the position of "long" (second-to-last position)
last_token_concepts = local_importances[-2]
# TODO: consider taking the mean/max over output tokens rather than a single one, and update the description accordingly
# top 5 concepts
top5_concepts = last_token_concepts.argsort(descending=True)[:5]
top5_concepts_activations = concepts_activations[top5_concepts]
# get additional interpretations
missing = [c.item() for c in top5_concepts if c.item() not in interpretations.keys()]
if len(missing) > 0:
    interpretations.update(interpretation_method.interpret(missing, imdb, activations))
# create an attribution output [temporary workaround]
attr = AttributionOutput(
elements=sample_tokens,
attributions=top5_concepts_activations,
model_task=ModelTask.MULTI_CLASS_CLASSIFICATION,
model_inputs_to_explain=None,
targets=None,
)
# visualize the concepts
AttributionVisualization(
attr, class_names={i: interpretations[c.item()] for i, c in enumerate(top5_concepts)}
).display()