interpreto_banner

Classification Concept-based Explanation Tutorial

Welcome to this tutorial. Our goal will be to obtain concept-based explanations, starting from the beginning.

For further details, please refer to the Interpreto documentation.

There are five key steps for concept-based explanations:

  1. βž— Split your model in two parts
  2. 🚦 Compute a dataset of activations
  3. πŸ‹οΈβ€β™‚οΈ Fit a concept model on activations
  4. 🏷️ Interpret the concept dimensions
  5. 🌍 Find the globally important concepts

To which we add three bonus steps:

  1. πŸ“š Class-wise concepts and LLM label
  2. πŸ“ Locally important concepts
  3. βš–οΈ Evaluate concept-based explanations

Author: Antonin PochΓ©

import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1. βž— Split your model in two parts

We choose a DistilBERT fine-tuned on the AG-News dataset and split it just before the classification head.

To split the model, we use the interpreto.ModelWithSplitPoints which wraps around the transformers model and allows the computation of activations at the specified split_points.

from transformers import AutoModelForSequenceClassification

from interpreto import ModelWithSplitPoints

model_with_split_points = ModelWithSplitPoints(
    model_or_repo_id="textattack/distilbert-base-uncased-ag-news",
    automodel=AutoModelForSequenceClassification,
    split_points=[5],  # split at the sixth (last) transformer layer, just before the classification head
    device_map="cuda",
    batch_size=1024,
)

2. 🚦 Compute a dataset of activations

We load the first 1000 documents of the AG-News train set.

Then we extract the activations of the [CLS] token of each document.

> ➑️ Common practice
>
> In the literature, to train concepts for classification, it is common to use the [CLS] token just before the classification head.
>
> In fact, at this layer, it makes no sense to use the other tokens.

> ⚠️ Warning
>
> In this notebook, many things are specific to the use of the [CLS] token.

interpreto.ModelWithSplitPoints.get_activations()

from datasets import load_dataset

# load the AG-News dataset
dataset = load_dataset("fancyzhx/ag_news")
inputs = dataset["train"]["text"][:1000]  # here we use only 1000 examples to go faster, but the more, the better
classes_names = dataset["train"].features["label"].names

# Compute the [CLS] token activations
granularity = ModelWithSplitPoints.activation_granularities.CLS_TOKEN
activations = model_with_split_points.get_activations(
    inputs=inputs,
    activation_granularity=granularity,
    include_predicted_classes=True,
)
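
We can quickly check what we just computed. A minimal sketch, assuming (as used in the class-wise section below) that get_activations returns a dictionary of tensors with one entry per split point plus a "predictions" entry, since we set include_predicted_classes=True:

# quick sanity check on the returned activations
# (assumption: a dict of tensors, one per split point, plus "predictions")
for key, value in activations.items():
    print(key, tuple(value.shape))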

3. πŸ‹οΈβ€β™‚οΈ Fit a concept model on activations

With activations, we can train a concept model to find patterns (concepts).

The concept_model is an attribute of our concept explainer, similarly to the model_with_split_points. With these two elements, we can go from inputs to concepts and from concepts to outputs.

In this tutorial, we use interpreto.concepts.ICAConcepts built upon the ICA (Independent Component Analysis) dimension reduction algorithm.

There are at least 15 other concept models available in Interpreto; do not hesitate to explore them.

> πŸ”₯ Tip
>
> ICAConcepts is a good first candidate for classification. It has no requirements, is fast, and provides correct first results on most datasets.
>
> The SemiNMFConcepts used in the better-concepts section below is also a good choice.

from interpreto.concepts import ICAConcepts

# instantiate the concept explainer
concept_explainer = ICAConcepts(model_with_split_points, nb_concepts=50, device="cuda")

# fit the concept explainer on activations
concept_explainer.fit(activations)

4. 🏷️ Interpret the concept dimensions

We now have our concepts and the link between concepts and outputs, but we still need to make sense of these concepts.

In this case, we will use interpreto.concepts.interpretations.TopKInputs to find the 5 words that activate each of our concepts the most.

> ⚠️ Warning
>
> If the granularity specified to the interpretation method is not the same as the one used for activations, the results will be wrong.

from interpreto.concepts.interpretations import TopKInputs

# instantiate the interpretation method with the concept explainer
topk_inputs_method = TopKInputs(
    concept_explainer=concept_explainer,
    k=5,
    activation_granularity=granularity,
    use_unique_words=True,  # with the [CLS] token granularity, we are forced to use unique words
    unique_words_kwargs={
        "count_min_threshold": round(
            len(inputs) * 0.002
        ),  # appear in at least 0.2% of the samples | increase if random words appear and decrease if some words appear too often
        "lemmatize": True,
        "words_to_ignore": [],  # include noise words and punctuation
    },
)
# call the interpretation methods on the inputs
# we cannot give the previously computed activations because `use_unique_words=True` creates samples with a single word inside
topk_words = topk_inputs_method.interpret(
    inputs=inputs,
    concepts_indices="all",
)
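
We can already peek at the words found for a few concepts. A minimal sketch, assuming topk_words maps each concept index to a word-to-score dictionary (the structure used in the next step):

# print the words associated with the first three concepts
# (assumption: topk_words is {concept index -> {word: activation score}})
for concept_id in list(topk_words.keys())[:3]:
    print(f"concept {concept_id}: {list(topk_words[concept_id].keys())}")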

5. 🌍 Find the globally important concepts

We have concept directions; this means that our model has access to them, but not necessarily that it uses them.

It is the same as when you train a model on tabular data: not all features are used.

In this step, we use ConceptAutoEncoderExplainer.concept_output_gradient to evaluate the importance of each concept with respect to the predicted classes.

> ➑️ Note
>
> All unsupervised concept-based explainers in Interpreto inherit from ConceptAutoEncoderExplainer.

> ➑️ Note 2
>
> This step can be done prior to the interpretation, as the interpretation step can be compute-heavy. In that case, pass the important concepts through the concepts_indices parameter.
> Only interpreting the important concepts can be wise. (Here we only have 50 concepts, so it does not matter.)

import torch

# estimate the importance of concepts for each class using the gradient
gradients = concept_explainer.concept_output_gradient(
    inputs=inputs,
    targets=None,  # None means all classes
    activation_granularity=granularity,
    concepts_x_gradients=True,  # the concept to output gradients are multiplied by the concepts values, this is common practice in the literature
    batch_size=64,
)

# stack gradients on samples and average them over samples
mean_gradients = torch.stack(gradients).abs().squeeze().mean(0)  # (num_classes, num_concepts)

# for each class, sort the importance scores
order = torch.argsort(mean_gradients, descending=True)

# visualize the top 5 concepts for each class
for target in range(order.shape[0]):
    print(f"\nClass: {classes_names[target]}:")
    for i in range(5):
        concept_id = order[target, i].item()
        importance = mean_gradients[target, concept_id].item()
        words = list(topk_words.get(concept_id, {}).keys())
        print(f"\tconcept id: {concept_id},\timportance: {round(importance, 3)},\ttopk words: {words}")

Class: World:
    concept id: 27, importance: 0.091,  topk words: ['serbia-montenegro', 'nato', 'liechtenstein', 'nikkei', 'ossetia']
    concept id: 45, importance: 0.078,  topk words: ['separatist', 'militia', 'militiaman', 'usatoday.com', 'inquirer']
    concept id: 11, importance: 0.07,   topk words: ['betting', 'saudi', 'gambler', 'shark', 'kidnapper']
    concept id: 31, importance: 0.064,  topk words: ['betting', 'fraud', 'holy', 'kmart', 'p.m.']
    concept id: 48, importance: 0.041,  topk words: ['afp', 'armed', 'hue', 'anarchist', 'naval']

Class: Sports:
    concept id: 7,  importance: 0.111,  topk words: ['phelps', '200-meter', 'batter', 'inning', 'homered']
    concept id: 49, importance: 0.092,  topk words: ['heat', '100-meter', 'fastest', '200-meter', '200m']
    concept id: 30, importance: 0.087,  topk words: ['200-meter', '100-meter', 'breaststroke', '400-meter', 'heat']
    concept id: 33, importance: 0.069,  topk words: ['phillies', 'mets', 'baltimore', 'sox', 'nl']
    concept id: 24, importance: 0.062,  topk words: ['stryker', 'nfl', 'armadillo', 'homer', 'autodesk']

Class: Business:
    concept id: 13, importance: 0.084,  topk words: ['pharmacare', 'procurement', 'costly', 'sector', 'marketer']
    concept id: 28, importance: 0.067,  topk words: ['vodafone', 'antitrust', 'ipo.google.com', 'verizon', 'lenovo']
    concept id: 19, importance: 0.046,  topk words: [';', 'ott', 'atp', 'fcc', '3g']
    concept id: 7,  importance: 0.042,  topk words: ['phelps', '200-meter', 'batter', 'inning', 'homered']
    concept id: 37, importance: 0.04,   topk words: ['eurozone', 'finance', 'imf', 'economy', 'balance']

Class: Sci/Tech:
    concept id: 7,  importance: 0.057,  topk words: ['phelps', '200-meter', 'batter', 'inning', 'homered']
    concept id: 49, importance: 0.053,  topk words: ['heat', '100-meter', 'fastest', '200-meter', '200m']
    concept id: 13, importance: 0.052,  topk words: ['pharmacare', 'procurement', 'costly', 'sector', 'marketer']
    concept id: 27, importance: 0.05,   topk words: ['serbia-montenegro', 'nato', 'liechtenstein', 'nikkei', 'ossetia']
    concept id: 30, importance: 0.046,  topk words: ['200-meter', '100-meter', 'breaststroke', '400-meter', 'heat']

> ❓ The concepts are not interpretable, what do I do?
>
> - Try to improve the concept space:
>   - Increase the number of samples. You can artificially do so by splitting documents into sentences (a naive sketch is shown below).
>   - Try different concept models and parameters.
>   - Try to compute concepts class-wise, see the next section.
> - Improve the interpretation of concepts:
>   - Play with the parameters.
>   - Try LLMLabels, see the next section.
> - Evaluate the concepts, to automatically find the best methods. Check this other tutorial: TODO
> - Never forget the faithfulness-plausibility trade-off of explanations.
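
For instance, here is a naive sketch of the sentence-splitting trick mentioned above (a simple period-based split; a proper sentence tokenizer would do better). The resulting sentence_inputs could then replace inputs in step 2.

# naive sketch: split each document into sentences to multiply the number of samples
sentence_inputs = [
    sentence.strip()
    for document in inputs
    for sentence in document.split(". ")
    if len(sentence.strip()) > 10  # drop very short fragments
]
print(f"{len(inputs)} documents -> {len(sentence_inputs)} sentence-level samples")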

6. πŸ“š Better concepts with class-wise concepts and LLM labels

This section aims at improving the learned concepts. We combine three ingredients: a different concept model (SemiNMFConcepts), class-wise concept spaces, and LLM-generated labels.

When a single concept space is defined for all classes, concepts tend to correspond to the classes themselves, especially when the concept space is built on the latent space just before the classification head.

In this section, we will therefore learn a concept space for each class separately. Each class-wise concept explainer will only see examples from a single class (based on the model's predictions).

> ⚠️ Warning
>
> The following cell and several others will not work without an OpenAI API key, as LLMLabels requires an LLMInterface and we chose the OpenAILLM one.
>
> What you can do to make it work:
> - Get an OpenAI API key and set it as the environment variable OPENAI_API_KEY.
> - Use TopKInputs instead of LLMLabels.
> - Branch your own LLMInterface and use it instead of OpenAILLM. See the last section for an example.

import os

from interpreto.concepts import LLMLabels, SemiNMFConcepts
from interpreto.model_wrapping.llm_interface import OpenAILLM

# Load API key from environment variable
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    raise ValueError(
        "An API key is required to use the `OpenAILLM` interface. "
        "Cannot use LLMLabels without an LLM interface. "
        "See the last section for an example of how to branch one."
    )

# set the LLM interface used to generate labels based on the constructed prompts
llm_interface = OpenAILLM(api_key=api_key, model="gpt-4.1-nano")

concept_explainers = {}
concept_interpretations = {}
concept_importances = {}

# iterate over classes
for target, class_name in enumerate(classes_names):
    # ----------------------------------------------------------------------------------------------
    # 2. construct the dataset of activations (extract the ones related to the class)
    indices = (activations["predictions"] == target).nonzero(as_tuple=True)[0]
    class_wise_inputs = [inputs[i] for i in indices]
    class_wise_activations = {k: v[indices] for k, v in activations.items()}

    # ----------------------------------------------------------------------------------------------
    # 3. train concept model
    concept_explainers[target] = SemiNMFConcepts(model_with_split_points, nb_concepts=20, device="cuda")
    concept_explainers[target].fit(class_wise_activations)

    # ----------------------------------------------------------------------------------------------
    # 5. compute concepts importance (before interpretations to limit the number of concepts interpreted)
    gradients = concept_explainers[target].concept_output_gradient(
        inputs=class_wise_inputs,
        targets=[target],
        activation_granularity=granularity,
        concepts_x_gradients=True,
        batch_size=64,
    )

    # stack gradients on samples and average them over samples
    concept_importances[target] = torch.stack(gradients, dim=0).squeeze().abs().mean(dim=0)  # (num_concepts,)

    # for each class, sort the importance scores
    important_concept_indices = torch.argsort(concept_importances[target], descending=True).tolist()

    # ----------------------------------------------------------------------------------------------
    # 4. interpret the important concepts
    llm_labels_method = LLMLabels(
        concept_explainer=concept_explainers[target],
        activation_granularity=granularity,
        llm_interface=llm_interface,
        k_examples=20,
    )

    concept_interpretations[target] = llm_labels_method.interpret(
        inputs=class_wise_inputs,
        concepts_indices=important_concept_indices,  # all concepts, ordered by importance for this class
    )

    print(f"\nClass: {class_name}")
    for concept_id in important_concept_indices[:5]:
        label = concept_interpretations[target].get(concept_id, None)
        importance = concept_importances[target][concept_id].item()
        if label is not None:
            print(f"\timportance: {round(importance, 3)},\t{label}")

Class: World
    importance: 0.088,  Diverse topics unified by factual reporting and event-focused language
    importance: 0.086,  Consistent focus on named entities, events, and quotations.
    importance: 0.084,  Consistently presents factual info with emphasis on notable events or figures.
    importance: 0.082,  Recurring topics and formal reporting language.
    importance: 0.078,  Concise reporting style with specific names, events, and dates.

Class: Sports
    importance: 0.112,  Patterns involve sports, achievements, and event summaries with emphasis on names, scores, and outcomes.
    importance: 0.105,  Structured sports and news summaries emphasize game events, scores, and highlights, often including player performances and outcomes in concise, formulaic language. The pattern involves brief, factual narration focusing on key actions and results.
    importance: 0.082,  Focus on key entities, events, and record-breaking achievements.
    importance: 0.076,  Consistent references to sports events, players, and results.
    importance: 0.07,   Concise patterns: factual summaries with focus on specific event outcomes.

Class: Business
    importance: 0.092,  Concise patterns involve financial, infrastructural, and cultural references.
    importance: 0.083,  Structured information emphasis, financial and infrastructural keywords.
    importance: 0.081,  Topical focus, political or economic narratives, and authoritative tone.
    importance: 0.074,  Concise patterns include focus on current events, economic issues, and industry developments.
    importance: 0.072,  Focus on comparative categories and rankings across regions.

Class: Sci/Tech
    importance: 0.121,  Language features that explicitly reference entities, events, or dates.
    importance: 0.101,  Consistent focus on environmental impacts, species, and conservation initiatives.
    importance: 0.095,  Scientific observations of natural phenomena and technological descriptions.
    importance: 0.093,  Factual reporting with scientific, environmental, and technological focus, structured as objective summaries.
    importance: 0.081,  Factual, technical, and headline-like language with specific details

6. πŸ“ Locally important concepts

We now know which concepts are globally important for each class. However, not all concepts are present in every sample, and the model might rely on a specific concept for a specific sample. Let's look at locally important concepts, i.e. the concepts the model used on a specific sample.

> ➑️ Note
>
> We cannot look at which word activates which concept without doing a forward pass for each word individually, because we use the [CLS] token. You could reuse the following cell, iterating over the words and replacing the example with each word (a sketch is given at the end of this section).

example = "Bio-engineered shoes will revolutionized running throughout the world, as they cost only 50 dollars."

pred = model_with_split_points.get_activations(
    inputs=[example],
    activation_granularity=granularity,
    include_predicted_classes=True,
)["predictions"].item()

print(f"Example: {example}")
print(f"Predicted class: {classes_names[pred]}")

# compute local concepts importance for the class
# we use the class-wise explainer of the predicted class
local_importance = concept_explainers[pred].concept_output_gradient(
    inputs=[example],
    targets=[pred],
    activation_granularity=granularity,
    concepts_x_gradients=True,
    tqdm_bar=False,
)[0]  # there is only one sample

# importance of shape (t, g, c) -> (c,)
# with t the target, here 1 as we have a single target
# g the number of activation granularity elements in one input, here 1 as we focus on the CLS_TOKEN
# c the number of concepts
local_importance = local_importance.squeeze()

# order the concepts by importance
ordered_indices = torch.argsort(local_importance, descending=True)

# print top 5 concepts
print("\nlocal importance\t| concept label", "------------------------+--------------", sep="\n")
for concept_id in ordered_indices[:5]:
    importance = local_importance[concept_id]
    label = concept_interpretations[pred][concept_id.item()]
    print(f"{round(importance.item(), 3)}\t\t\t| {label}")
Example: Bio-engineered shoes will revolutionized running throughout the world, as they cost only 50 dollars.
Predicted class: Sci/Tech

local importance    | concept label
------------------------+--------------
0.123           | Factual, technical, and headline-like language with specific details
0.115           | Structured informational content with consistent sentence patterns and varied topics.
0.076           | Scientific observations of natural phenomena and technological descriptions.
0.064           | Factual reporting with scientific, environmental, and technological focus, structured as objective summaries.
0.049           | Structured, factual reporting with explicit event details and dates.
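
Following the note above, here is a minimal sketch of the word-level variant: each word of the example is passed through the model on its own. It is slow, and it assumes concept_output_gradient accepts such single-word inputs, but it only reuses calls already shown in this tutorial.

# sketch: one forward pass per word to see which word pushes which concept
# (assumption: single-word strings are valid inputs for concept_output_gradient)
for word in example.rstrip(".").split():
    word_importance = concept_explainers[pred].concept_output_gradient(
        inputs=[word],
        targets=[pred],
        activation_granularity=granularity,
        concepts_x_gradients=True,
        tqdm_bar=False,
    )[0].squeeze()  # (c,)
    top_concept = torch.argmax(word_importance.abs()).item()
    print(f"{word}:\tmost important concept -> {top_concept}")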

7. βš–οΈ Evaluate concept-based explanations

We go back to the ICAConcepts explainer from step 3 and evaluate it on new samples.

test_inputs = dataset["test"]["text"][:1000]  # let's take one thousand test samples

# Compute the [CLS] token activations
test_activations = model_with_split_points.get_activations(
    inputs=test_inputs,
    activation_granularity=granularity,
    include_predicted_classes=True,
)

8.1 🌐 Evaluate the concept-space from the third step

> ⚠️ Warning
>
> These metrics should only be used to compare concept-spaces trained in similar contexts: same model, split point, activation dataset, etc.

from interpreto.concepts.metrics import FID, MSE

mse = MSE(concept_explainer).compute(test_activations)
fid = FID(concept_explainer).compute(test_activations)

print(f"MSE: {round(mse, 3)}, FID: {round(fid, 3)}")
MSE: 165.919, FID: 0.041

> ➑️ Note
>
> On their own, these values are meaningless; they should be compared across several concept explainers.
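
For instance, a minimal sketch comparing against a second ICAConcepts with a larger concept space (same model, split point, and activation dataset, as required by the warning above):

# sketch: compare reconstruction metrics with a wider concept space
wider_explainer = ICAConcepts(model_with_split_points, nb_concepts=100, device="cuda")
wider_explainer.fit(activations)

wider_mse = MSE(wider_explainer).compute(test_activations)
wider_fid = FID(wider_explainer).compute(test_activations)
print(f"50 concepts  -> MSE: {round(mse, 3)}, FID: {round(fid, 3)}")
print(f"100 concepts -> MSE: {round(wider_mse, 3)}, FID: {round(wider_fid, 3)}")
del wider_explainer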

from interpreto.concepts.metrics import Sparsity, SparsityRatio

sparsity = Sparsity(concept_explainer).compute(test_activations)
ratio = SparsityRatio(concept_explainer).compute(test_activations)

print(f"Sparsity: {round(sparsity, 3)}, Sparsity ratio: {round(ratio, 3)}")
Sparsity: 1.0, Sparsity ratio: 0.02

Dictionary metrics

The Stability metric requires two concept explainers, hence our first step will be to train a new ICAConcepts with the same model, split point, dataset, and hyper-parameters as the original ICAConcepts. However, to get a statistically robust metric score, one should compare more than just two instances of the same explainer (see the sketch after the next cell).

from interpreto.concepts.metrics import Stability

# instantiate and train a second concept explainer
second_explainer = ICAConcepts(model_with_split_points, nb_concepts=50, device="cuda")
second_explainer.fit(activations)

stability = Stability(concept_explainer, second_explainer).compute()
del second_explainer

print(f"Stability: {round(stability, 3)}")
Stability: 1.0
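
As mentioned above, a single pair of explainers is not very robust. A minimal sketch averaging Stability over a few re-fits of the same explainer:

# sketch: average Stability over several re-trained explainers
stability_scores = []
for _ in range(3):
    retrained_explainer = ICAConcepts(model_with_split_points, nb_concepts=50, device="cuda")
    retrained_explainer.fit(activations)
    stability_scores.append(Stability(concept_explainer, retrained_explainer).compute())
    del retrained_explainer

print(f"Mean stability: {round(sum(stability_scores) / len(stability_scores), 3)}")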

8.2 πŸ’­ Evaluate the concept interpretations from the fourth step

# Work in progress, coming soon

8.3 ↔️ Evaluate the whole concept-based explanation with ConSim

ConSim is a metric that evaluates the whole concept-based explanation in an end-to-end manner. It measures to which extent the provided concept-based explanations help a meta-predictor to predict what the studied model would have predicted. The idea is that if a meta-predictor understands the model, it is able to predict what the model would have predicted on new samples.

> ➑️ Note
>
> For significant scores, we iterate over 10 different seeds (5 were used in the paper).

from interpreto.concepts.metrics.consim import ConSim, PromptTypes

# convert step 5 global concept importances to a dictionary
global_importances = {
    class_name: dict(enumerate(importances))
    for class_name, importances in zip(classes_names, mean_gradients, strict=True)
}

# Load API key from environment variable
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    raise ValueError(
        "An API key is required to use the `OpenAILLM` interface. "
        "Cannot use ConSim without an LLM interface. "
        "See the last section for an example of how to branch one."
    )

# set the LLM interface used to generate labels based on the constructed prompts
llm_interface = OpenAILLM(api_key=api_key, model="gpt-4.1-nano")

# Initialize ConSim with the model with split points and the user LLM
# This way, a given ConSim metric can be reused on different explainers for a cleaner comparison
con_sim = ConSim(model_with_split_points, llm_interface, classes=classes_names, activation_granularity=granularity)

baseline_list = []
ica_score_list = []
for seed in range(10):
    # Select examples for evaluation
    samples, labels, predictions = con_sim.select_examples(
        inputs=dataset["train"]["text"][:5000],
        labels=torch.tensor(dataset["train"]["label"][:5000]).cuda(),
        seed=seed,
    )

    # Compute a baseline and ConSim score to give sense to the explainer ConSim score
    baseline = con_sim.evaluate(
        interesting_samples=samples, predictions=predictions, prompt_type=PromptTypes.L2_baseline_with_lp
    )

    if baseline is None:
        continue

    # Compute the ConSim score for an explainer
    ica_score = con_sim.evaluate(
        interesting_samples=samples,
        predictions=predictions,
        concept_explainer=concept_explainer,
        concepts_interpretation=topk_words,
        global_importances=global_importances,
        prompt_type=PromptTypes.E2_global_concepts_with_lp,
    )

    if ica_score is None:
        continue

    baseline_list.append(baseline)
    ica_score_list.append(ica_score)

print(f"Baseline: {round(sum(baseline_list) / 10, 2)}, ICA: {round(sum(ica_score_list) / 10, 2)}")
Baseline: 0.28, ICA: 0.29

> ➑️ Note
>
> We evaluated the first ICA explainer here. Its concepts were not really interpretable, and ConSim agrees.

9. Using your own LLM interface

As mentioned in the warning of the class-wise section, you can branch your own LLMInterface and use it instead of OpenAILLM. Below is an example wrapping the Google Gemini API.

from interpreto.model_wrapping.llm_interface import LLMInterface, Role


class GeminiLLM(LLMInterface):
    def __init__(self, api_key: str, model: str = "gemini-1.5-flash", num_try: int = 5):
        try:
            import google.generativeai as genai  # noqa: PLC0415  # ruff: disable=import-outside-toplevel
        except ImportError as e:
            raise ImportError("Install google-generativeai to use Google Gemini API.") from e

        self.genai = genai
        self.genai.configure(api_key=api_key)
        self.model = model
        self.num_try = num_try

    def generate(self, prompt: list[tuple[Role, str]]) -> str | None:
        # Build system instruction and chat history for Gemini
        system_messages: list[str] = []
        contents: list[dict] = []

        for role, content in prompt:
            if role == Role.SYSTEM:
                system_messages.append(content)
            elif role == Role.USER:
                contents.append(
                    {
                        "role": "user",
                        "parts": [{"text": content}],
                    }
                )
            elif role == Role.ASSISTANT:
                contents.append(
                    {
                        "role": "model",
                        "parts": [{"text": content}],
                    }
                )
            else:
                raise ValueError(f"Unknown role for google gemini api: {role}")

        system_instruction: str | None = "\n".join(system_messages) if system_messages else None

        label: str | None = None
        for _ in range(self.num_try):
            try:
                model = self.genai.GenerativeModel(
                    model_name=self.model,
                    system_instruction=system_instruction,
                )
                response = model.generate_content(contents)  # type: ignore[arg-type]
                # google-generativeai exposes the main text as .text
                label = response.text
                break
            except Exception as e:  # noqa: BLE001
                print(e)
        return label
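
A usage sketch: the custom interface simply replaces OpenAILLM wherever an LLMInterface is expected, for example in LLMLabels. The GEMINI_API_KEY environment variable name below is an assumption, not something defined by Interpreto.

# sketch: plug the custom interface into LLMLabels (same call as in the class-wise section)
gemini_interface = GeminiLLM(api_key=os.environ["GEMINI_API_KEY"])  # hypothetical env variable name

gemini_labels_method = LLMLabels(
    concept_explainer=concept_explainer,
    activation_granularity=granularity,
    llm_interface=gemini_interface,
    k_examples=20,
)
gemini_labels = gemini_labels_method.interpret(inputs=inputs, concepts_indices="all")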