
Classification Concept-based Explanation Tutorial
Welcome to this tutorial. Our goal will be to obtain concept-based explanations, starting from the beginning.
For further details, please refer to the Interpreto documentation.
There are five key steps for concept-based explanations:
- Split your model in two parts
- Compute a dataset of activations
- Fit a concept model on activations
- Interpret the concept dimensions
- Find the globally important concepts
To which we add three bonus steps:
- Class-wise concepts and LLM labels
- Locally important concepts
- Evaluate concept-based explanations
Author: Antonin Poché
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1. Split your model in two parts
We choose a DistilBERT fine-tuned on the AG-News dataset and split it just before the classification head.
To split the model, we use the interpreto.ModelWithSplitPoints which wraps around the transformers model and allows the computation of activations at the specified split_points.
from transformers import AutoModelForSequenceClassification
from interpreto import ModelWithSplitPoints
model_with_split_points = ModelWithSplitPoints(
model_or_repo_id="textattack/distilbert-base-uncased-ag-news",
automodel=AutoModelForSequenceClassification,
split_points=[5], # split at the sixth layer
device_map="cuda",
batch_size=1024,
)
2. Compute a dataset of activations
We load the first 1,000 documents of the AG-News train set.
Then we extract the activations of the [CLS] token of each document.
> Common practice
>
> In the literature, to train concepts for classification, it is common to use the [CLS] token just before the classification head.
>
> In fact, at this layer, it makes no sense to use other elements.
> Warning
>
> In this notebook, many things are specific to the use of the [CLS] token.
from datasets import load_dataset
# load the AG-News dataset
dataset = load_dataset("fancyzhx/ag_news")
inputs = dataset["train"]["text"][:1000] # here we use only 1000 examples to go faster, but the more, the better
classes_names = dataset["train"].features["label"].names
# Compute the [CLS] token activations
granularity = ModelWithSplitPoints.activation_granularities.CLS_TOKEN
activations = model_with_split_points.get_activations(
inputs=inputs,
activation_granularity=granularity,
include_predicted_classes=True,
)
3. Fit a concept model on activations
With activations, we can train a concept model to find patterns (concepts).
The concept_model is an attribute of our concept explainer, just like the model_with_split_points. With these two elements, we can go from inputs to concepts and from concepts to outputs.
In this tutorial, we use interpreto.concepts.ICAConcepts built upon the ICA (Independent Component Analysis) dimension reduction algorithm.
There are at least 15 other concept models available in Interpreto; do not hesitate to explore them.
> Tip
>
> ICAConcepts is a good first candidate for classification. It has no requirements, is fast, and provides decent first results on most datasets.
>
> The SemiNMFConcepts model used in the better-concepts section below is also a good choice.
from interpreto.concepts import ICAConcepts
# instantiate the concept explainer
concept_explainer = ICAConcepts(model_with_split_points, nb_concepts=50, device="cuda")
# fit the concept explainer on activations
concept_explainer.fit(activations)
4. Interpret the concept dimensions
We have our concepts and the link between concepts and classes, but we still need to make sense of these concepts.
Here, we use interpreto.concepts.interpretations.TopKInputs to find the words that activate each concept the most (k=5 in the code below).
> Warning
>
> If the granularity specified to the interpretation method is not the same as the one used for activations, the results will be wrong.
from interpreto.concepts.interpretations import TopKInputs
# instantiate the interpretation method with the concept explainer
topk_inputs_method = TopKInputs(
concept_explainer=concept_explainer,
k=5,
activation_granularity=granularity,
use_unique_words=True, # with the [CLS] token granularity, we are forced to use unique words
unique_words_kwargs={
"count_min_threshold": round(
len(inputs) * 0.002
), # appear in at least 0.2% of the samples | increase if random words appear and decrease if some words appear too often
"lemmatize": True,
"words_to_ignore": [], # include noise words and punctuation
},
)
# call the interpretation methods on the inputs
# we cannot give the previously computed activations because `use_unique_words=True` creates samples with a single word inside
topk_words = topk_inputs_method.interpret(
inputs=inputs,
concepts_indices="all",
)
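The result maps each concept index to its top words. A quick peek at the first few concepts (the exact words depend on your run):
# topk_words: {concept_index: {word: activation_score}}
for concept_id in list(topk_words)[:3]:
    print(f"concept {concept_id}: {list(topk_words[concept_id].keys())}")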
5. Find the globally important concepts
We have concept directions; this means that our model has access to them, but not that it necessarily uses them.
It is the same as when you train a model on tabular data: not all features are used.
In this step, we use ConceptAutoEncoderExplainer.concept_output_gradient to evaluate the importance of each concept with respect to the predicted classes.
> Note
>
> All unsupervised concept-based explainers in Interpreto inherit from ConceptAutoEncoderExplainer.
> Note 2
>
> This step can be done before the interpretation step, which can be compute heavy. You can then restrict interpretation to the important concepts with the concepts_indices parameter (see the sketch below).
> Only interpreting the important concepts can be wise. (Here we only have 50 concepts, so it does not matter much.)
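As a sketch of that workflow, here are the same calls as in the next cells, just reordered: estimate concept importance first, then interpret only the top concepts by passing their indices to concepts_indices. This cell is optional and duplicates the computation done below.
# optional sketch: importance first, interpretation second (duplicates the cells below)
gradients_first = concept_explainer.concept_output_gradient(
    inputs=inputs,
    targets=None,
    activation_granularity=granularity,
    concepts_x_gradients=True,
    batch_size=64,
)
# average over samples, then keep the highest importance across classes -> (num_concepts,)
importances_all = torch.stack(gradients_first).abs().squeeze().mean(0).amax(0)
top_concepts = torch.argsort(importances_all, descending=True)[:10].tolist()
# interpret only the 10 most important concepts
topk_words_subset = topk_inputs_method.interpret(inputs=inputs, concepts_indices=top_concepts)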
import torch
# estimate the importance of concepts for each class using the gradient
gradients = concept_explainer.concept_output_gradient(
inputs=inputs,
targets=None, # None means all classes
activation_granularity=granularity,
concepts_x_gradients=True, # the concept to output gradients are multiplied by the concepts values, this is common practice in the literature
batch_size=64,
)
# stack gradients on samples and average them over samples
mean_gradients = torch.stack(gradients).abs().squeeze().mean(0) # (num_classes, num_concepts)
# for each class, sort the importance scores
order = torch.argsort(mean_gradients, descending=True)
# visualize the top 5 concepts for each class
for target in range(order.shape[0]):
print(f"\nClass: {classes_names[target]}:")
for i in range(5):
concept_id = order[target, i].item()
importance = mean_gradients[target, concept_id].item()
        words = list(topk_words.get(concept_id, {}).keys())
print(f"\tconcept id: {concept_id},\timportance: {round(importance, 3)},\ttopk words: {words}")
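If you prefer a visual overview, you can also plot the full class-by-concept importance matrix. A minimal sketch, assuming matplotlib is installed (it is not an Interpreto requirement):
import matplotlib.pyplot as plt

# heatmap of the (num_classes, num_concepts) importance matrix computed above
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(mean_gradients.cpu().numpy(), aspect="auto", cmap="viridis")
ax.set_yticks(range(len(classes_names)))
ax.set_yticklabels(classes_names)
ax.set_xlabel("concept index")
fig.colorbar(im, ax=ax, label="mean |concepts x gradients|")
plt.show()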
> The concepts are not interpretable, what do I do?
>
> - Try to improve the concept-space:
>   - Increase the number of samples. You can artificially do so by splitting documents into sentences (see the sketch after this list)
>   - Try different concept models and parameters
>   - Try computing concepts class-wise (see next section)
>
> - Improve the interpretation of concepts:
>   - Play with the parameters
>   - Try LLMLabels (see next section)
>
> - Try to evaluate the concepts, to automatically find the best methods. Check this other tutorial: TODO
>
> - Never forget the faithfulness-plausibility trade-off of explanations
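For the first tip, here is a minimal sketch of how you could split the documents into sentences to artificially increase the number of samples. It uses a naive regex split; a proper sentence tokenizer (e.g. NLTK) would be more robust.
import re

# naive sentence splitting to increase the number of training samples for the concept model
sentence_inputs = [
    sentence.strip()
    for document in inputs
    for sentence in re.split(r"(?<=[.!?])\s+", document)
    if len(sentence.strip()) > 10  # drop very short fragments
]
print(f"{len(inputs)} documents -> {len(sentence_inputs)} sentences")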
6. Better concepts with class-wise concepts and LLM labels
This section aims at improving the concepts learned by the model. We try three different approaches:
- Training class-wise concepts
- Using another concept model: SemiNMFConcepts
- Using interpreto.concepts.LLMLabels to interpret the concepts
When a single concept-space is defined for all classes, concepts tend to correspond to the classes themselves, in particular when the concept-space is built on the latent space just before the classification head.
In this section, we will learn a concept space for each class separately. Thus, the class-wise concept explainers will only see examples from a single class (based on the predictions).
> Warning
>
> The following cell and several others will not work without an OpenAI API key, as LLMLabels requires an LLMInterface, and we chose the OpenAILLM one.
>
> What you can do to make it work:
> - Get an OpenAI API key and set it as an environment variable OPENAI_API_KEY.
> - Use TopKInputs to replace LLMLabels.
> - Branch your own LLMInterface and use it instead of OpenAILLM. See last section for an example.
import os
from interpreto.concepts import LLMLabels, SemiNMFConcepts
from interpreto.model_wrapping.llm_interface import OpenAILLM
# Load API key from environment variable
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    raise ValueError(
        "An API key is required to use the `OpenAILLM` interface. "
        "Cannot use LLMLabels without an LLM interface. "
        "See the last section for an example of how to branch one."
    )
# set the LLM interface used to generate labels based on the constructed prompts
llm_interface = OpenAILLM(api_key=api_key, model="gpt-4.1-nano")
concept_explainers = {}
concept_interpretations = {}
concept_importances = {}
# iterate over classes
for target, class_name in enumerate(classes_names):
# ----------------------------------------------------------------------------------------------
# 2. construct the dataset of activations (extract the ones related to the class)
indices = (activations["predictions"] == target).nonzero(as_tuple=True)[0]
class_wise_inputs = [inputs[i] for i in indices]
class_wise_activations = {k: v[indices] for k, v in activations.items()}
# ----------------------------------------------------------------------------------------------
# 3. train concept model
concept_explainers[target] = SemiNMFConcepts(model_with_split_points, nb_concepts=20, device="cuda")
concept_explainers[target].fit(class_wise_activations)
# ----------------------------------------------------------------------------------------------
# 5. compute concepts importance (before interpretations to limit the number of concepts interpreted)
gradients = concept_explainers[target].concept_output_gradient(
inputs=class_wise_inputs,
targets=[target],
activation_granularity=granularity,
concepts_x_gradients=True,
batch_size=64,
)
# stack gradients on samples and average them over samples
    concept_importances[target] = torch.stack(gradients, dim=0).squeeze().abs().mean(dim=0)  # (num_concepts,)
# for each class, sort the importance scores
important_concept_indices = torch.argsort(concept_importances[target], descending=True).tolist()
# ----------------------------------------------------------------------------------------------
    # 4. interpret the important concepts
llm_labels_method = LLMLabels(
concept_explainer=concept_explainers[target],
activation_granularity=granularity,
llm_interface=llm_interface,
k_examples=20,
)
concept_interpretations[target] = llm_labels_method.interpret(
inputs=class_wise_inputs,
        concepts_indices=important_concept_indices,  # all concepts, ordered by importance for this class
)
print(f"\nClass: {class_name}")
for concept_id in important_concept_indices[:5]:
label = concept_interpretations[target].get(concept_id, None)
importance = concept_importances[target][concept_id].item()
if label is not None:
print(f"\timportance: {round(importance, 3)},\t{label}")
7. Locally important concepts
We now know which concepts are important for the classes globally. However, not all concepts are present in every sample, and the model might rely on a specific concept for a specific sample. Let's look at locally important concepts, meaning the concepts the model used on a specific sample.
> Note
>
> We cannot look at which word activates which concept without doing a forward pass for each word individually, because we use the [CLS] token. You could adapt the following cell by iterating over the words and replacing the example with each word (see the sketch at the end of this section).
example = "Bio-engineered shoes will revolutionize running throughout the world, as they cost only 50 dollars."
pred = model_with_split_points.get_activations(
inputs=[example],
activation_granularity=granularity,
include_predicted_classes=True,
)["predictions"].item()
print(f"Example: {example}")
print(f"Predicted class: {classes_names[pred]}")
# compute local concepts importance for the class
# we use the class-wise concept explainer of the predicted class
local_importance = concept_explainers[pred].concept_output_gradient(
inputs=[example],
targets=[pred],
activation_granularity=granularity,
concepts_x_gradients=True,
tqdm_bar=False,
)[0] # there is only one sample
# importance of shape (t, g, c) -> (c,)
# with t the target, here 1 as we have a single target
# g the number of activation granularity elements in one input, here 1 as we focus on the CLS_TOKEN
# c the number of concepts
local_importance = local_importance.squeeze()
# order the concepts by importance
ordered_indices = torch.argsort(local_importance, descending=True)
# print top 5 concepts
print("\nlocal importance\t| concept label", "------------------------+--------------", sep="\n")
for concept_id in ordered_indices[:5]:
importance = local_importance[concept_id]
    label = concept_interpretations[pred][concept_id.item()]
print(f"{round(importance.item(), 3)}\t\t\t| {label}")
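As mentioned in the note above, with the [CLS] granularity you can approximate word-level attribution by running one forward pass per word. A minimal sketch, reusing the class-wise explainer of the predicted class; note that isolated words lose their context, so these scores are only a rough proxy.
# one forward pass per word: local importance of the most important concept for each word
top_concept = ordered_indices[0].item()
words = example.split()
word_scores = []
for word in words:
    word_gradient = concept_explainers[pred].concept_output_gradient(
        inputs=[word],
        targets=[pred],
        activation_granularity=granularity,
        concepts_x_gradients=True,
        tqdm_bar=False,
    )[0].squeeze()  # (num_concepts,)
    word_scores.append(word_gradient[top_concept].item())

# words driving the most important concept
for word, score in sorted(zip(words, word_scores), key=lambda pair: -abs(pair[1]))[:5]:
    print(f"{round(score, 3)}\t{word}")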
8. Evaluate concept-based explanations
We go back to the ICAConcepts explainer from step 3 and evaluate it on new samples.
test_inputs = dataset["test"]["text"][:1000] # let's take one thousand test samples
# Compute the [CLS] token activations
test_activations = model_with_split_points.get_activations(
inputs=test_inputs,
activation_granularity=granularity,
include_predicted_classes=True,
)
8.1 Evaluate the concept-space from step 3
> Warning
>
> These metrics should only be used to compare concept-spaces trained in similar contexts: same model, split point, activation dataset, etc.
Reconstruction error
from interpreto.concepts.metrics import FID, MSE
mse = MSE(concept_explainer).compute(test_activations)
fid = FID(concept_explainer).compute(test_activations)
print(f"MSE: {round(mse, 3)}, FID: {round(fid, 3)}")
> Note
>
> On their own, these values are meaningless; they should be compared across several concept explainers.
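For instance, you could fit a second concept model on the same activations and compare the reconstruction scores (lower is better for both MSE and FID). A minimal sketch using the SemiNMFConcepts class imported earlier:
# fit an alternative concept model on the same activations and compare reconstruction metrics
semi_nmf_explainer = SemiNMFConcepts(model_with_split_points, nb_concepts=50, device="cuda")
semi_nmf_explainer.fit(activations)
semi_nmf_mse = MSE(semi_nmf_explainer).compute(test_activations)
semi_nmf_fid = FID(semi_nmf_explainer).compute(test_activations)
print(f"ICA     -> MSE: {round(mse, 3)}, FID: {round(fid, 3)}")
print(f"SemiNMF -> MSE: {round(semi_nmf_mse, 3)}, FID: {round(semi_nmf_fid, 3)}")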
from interpreto.concepts.metrics import Sparsity, SparsityRatio
sparsity = Sparsity(concept_explainer).compute(test_activations)
ratio = SparsityRatio(concept_explainer).compute(test_activations)
print(f"Sparsity: {round(sparsity, 3)}, Sparsity ratio: {round(ratio, 3)}")
Dictionary metrics
The Stability metric requires two concept explainers, hence our first step will be to train a new ICAConcepts explainer with the same model, split point, dataset, and hyper-parameters as the original one. However, to get a statistically robust metric score, one should compare more than just two instances of the same explainer.
from interpreto.concepts.metrics import Stability
# instantiate and train a second concept explainer
second_explainer = ICAConcepts(model_with_split_points, nb_concepts=50, device="cuda")
second_explainer.fit(activations)
stability = Stability(concept_explainer, second_explainer).compute()
del second_explainer
print(f"Stability: {round(stability, 3)}")
8.2 Evaluate the concept interpretations from step 4
# Work in progress, coming soon
8.3 Evaluate whole concept-based explanations with ConSim
ConSim is a metric that evaluates whole concept-based explanations in an end-to-end manner. It measures to which extent the provided concept-based explanations help a meta-predictor to predict what the studied model would have predicted. The idea is that if a meta-predictor understands the model, it is able to predict what the model would have predicted on new samples.
> Note
>
> For significant scores, we iterate over 10 different seeds (5 were used in the paper).
from interpreto.concepts.metrics.consim import ConSim, PromptTypes
# convert step 5 global concept importances to a dictionary
global_importances = {
class_name: dict(enumerate(importances))
for class_name, importances in zip(classes_names, mean_gradients, strict=True)
}
# Load API key from environment variable
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    raise ValueError(
        "An API key is required to use the `OpenAILLM` interface. "
        "Cannot use ConSim without an LLM interface. "
        "See the last section for an example of how to branch one."
    )
# set the LLM interface used to generate labels based on the constructed prompts
llm_interface = OpenAILLM(api_key=api_key, model="gpt-4.1-nano")
# Initialize ConSim with the model with split points and the user LLM
# A given ConSim metric can then be used on different explainers for a cleaner comparison
con_sim = ConSim(model_with_split_points, llm_interface, classes=classes_names, activation_granularity=granularity)
baseline_list = []
ica_score_list = []
for seed in range(10):
# Select examples for evaluation
samples, labels, predictions = con_sim.select_examples(
inputs=dataset["train"]["text"][:5000],
labels=torch.tensor(dataset["train"]["label"][:5000]).cuda(),
seed=seed,
)
# Compute a baseline and ConSim score to give sense to the explainer ConSim score
baseline = con_sim.evaluate(
interesting_samples=samples, predictions=predictions, prompt_type=PromptTypes.L2_baseline_with_lp
)
if baseline is None:
continue
# Compute the ConSim score for an explainer
ica_score = con_sim.evaluate(
interesting_samples=samples,
predictions=predictions,
concept_explainer=concept_explainer,
concepts_interpretation=topk_words,
global_importances=global_importances,
prompt_type=PromptTypes.E2_global_concepts_with_lp,
)
if ica_score is None:
continue
baseline_list.append(baseline)
ica_score_list.append(ica_score)
print(f"Baseline: {round(sum(baseline_list) / len(baseline_list), 2)}, ICA: {round(sum(ica_score_list) / len(ica_score_list), 2)}")
> Note
>
> We evaluated the first ICA explainer here. Its concepts were not really interpretable, and ConSim agrees.
9. Using your own LLM interface
from interpreto.model_wrapping.llm_interface import LLMInterface, Role
class GeminiLLM(LLMInterface):
def __init__(self, api_key: str, model: str = "gemini-1.5-flash", num_try: int = 5):
try:
            import google.generativeai as genai  # noqa: PLC0415 - imported lazily so the dependency stays optional
except ImportError as e:
raise ImportError("Install google-generativeai to use Google Gemini API.") from e
self.genai = genai
self.genai.configure(api_key=api_key)
self.model = model
self.num_try = num_try
def generate(self, prompt: list[tuple[Role, str]]) -> str | None:
# Build system instruction and chat history for Gemini
system_messages: list[str] = []
contents: list[dict] = []
for role, content in prompt:
if role == Role.SYSTEM:
system_messages.append(content)
elif role == Role.USER:
contents.append(
{
"role": "user",
"parts": [{"text": content}],
}
)
elif role == Role.ASSISTANT:
contents.append(
{
"role": "model",
"parts": [{"text": content}],
}
)
else:
raise ValueError(f"Unknown role for google gemini api: {role}")
system_instruction: str | None = "\n".join(system_messages) if system_messages else None
label: str | None = None
for _ in range(self.num_try):
try:
model = self.genai.GenerativeModel(
model_name=self.model,
system_instruction=system_instruction,
)
response = model.generate_content(contents) # type: ignore[arg-type]
# google-generativeai exposes the main text as .text
label = response.text
break
except Exception as e: # noqa: BLE001
print(e)
return label
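Once defined, such an interface plugs into LLMLabels exactly like the OpenAILLM one used earlier. A minimal usage sketch; the GEMINI_API_KEY environment variable is an assumption of this example:
# hypothetical usage: swap the OpenAI interface for the custom Gemini one
gemini_api_key = os.getenv("GEMINI_API_KEY")  # assumed environment variable
if gemini_api_key is not None:
    gemini_interface = GeminiLLM(api_key=gemini_api_key)
    gemini_labels_method = LLMLabels(
        concept_explainer=concept_explainer,
        activation_granularity=granularity,
        llm_interface=gemini_interface,
        k_examples=20,
    )
    gemini_labels = gemini_labels_method.interpret(inputs=inputs, concepts_indices="all")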