Classification Concept-based Explanation Tutorial

Welcome to this tutorial. Our goal is to obtain concept-based explanations, starting from scratch.

For more details, please refer to the Interpreto documentation.

There are five key steps for concept-based explanations:

  1. ➗ Split your model in two parts
  2. 🚦 Compute a dataset of activations
  3. 🏋️‍♂️ Fit a concept model on activations
  4. 🏷️ Interpret the concept dimensions
  5. 🌍 Find the globally important concepts

To which we add three bonus steps:

  1. 📚 Class-wise concepts and LLM labels
  2. 📍 Locally important concepts
  3. ⚖️ Evaluate concept-based explanations

Author: Antonin Poché

import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1. ➗ Split your model in two parts

We choose a DistilBERT fine-tuned on the AG-News dataset and split it just before the classification head.

To split the model, we use interpreto.ModelWithSplitPoints, which wraps the transformers model and computes activations at the specified split_points.

from transformers import AutoModelForSequenceClassification

from interpreto import ModelWithSplitPoints

model_with_split_points = ModelWithSplitPoints(
    model_or_repo_id="textattack/distilbert-base-uncased-ag-news",
    automodel=AutoModelForSequenceClassification,
    split_points=[5],  # split at the sixth layer
    device_map="cuda",
    batch_size=1024,
)
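To make the idea of splitting concrete, here is a minimal sketch with a toy stack of functions standing in for the layers. Everything in it (`layers`, `run`, `split_at`) is hypothetical and only illustrates the principle; it assumes the first part ends with the layer at the split index, which is a choice of this sketch rather than a statement about how Interpreto interprets split_points.

```python
from functools import reduce

# Toy "model": a stack of layers, each a plain function on a float.
# These layers are illustrative stand-ins, not DistilBERT layers.
layers = [
    lambda x: x + 1.0,
    lambda x: x * 2.0,
    lambda x: x - 3.0,
    lambda x: x * x,
]


def run(stack, x):
    """Apply every layer of `stack` in order."""
    return reduce(lambda value, layer: layer(value), stack, x)


def split_at(stack, index):
    """Split so that the first part ends with layer `index` (0-based)."""
    return stack[: index + 1], stack[index + 1 :]


first_part, second_part = split_at(layers, 1)

x = 5.0
activation = run(first_part, x)        # intermediate value at the split point
output = run(second_part, activation)  # the rest of the model, up to the output

# Running both parts in sequence gives the same result as the full model.
full_output = run(layers, x)
```

The activation at the split point is what a concept model is later fitted on, while the second part links those activations to the model output.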

2. 🚦 Compute a dataset of activations

We load the first 1000 documents of the AG-News train set.

Then we extract the activations of the [CLS] token of each document.

> ➡️ Common practice
>
> In the literature, concepts for classification are commonly trained on the [CLS] token activations just before the classification head.
>
> In fact, at this layer, it makes no sense to use other tokens.

> ⚠️ Warning
>
> In this notebook, many things are specific to the use of the [CLS] token.

Activations are computed with interpreto.ModelWithSplitPoints.get_activations().

from datasets import load_dataset

# load the AG-News dataset
dataset = load_dataset("fancyzhx/ag_news")
inputs = dataset["train"]["text"][:1000]  # here we use only 1000 examples to go faster, but the more, the better
classes_names = dataset["train"].features["label"].names

# Compute the [CLS] token activations
granularity = ModelWithSplitPoints.activation_granularities.CLS_TOKEN
activations = model_with_split_points.get_activations(
    inputs=inputs,
    activation_granularity=granularity,
    include_predicted_classes=True,
)
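With the [CLS] granularity, each document contributes a single vector. The sketch below shows that selection on plain Python lists, assuming position 0 of each document holds the [CLS] token (as with BERT-style tokenizers). The `hidden_states` values and the `cls_activations` helper are made up for illustration; in the tutorial, get_activations handles this step.

```python
# Hypothetical hidden states for 2 documents, 3 tokens each, hidden size 2.
# Position 0 of each document plays the role of the [CLS] token.
hidden_states = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[1.0, 1.1], [1.2, 1.3], [1.4, 1.5]],
]


def cls_activations(states):
    """Keep only the first-token vector of every document."""
    return [document[0] for document in states]


cls = cls_activations(hidden_states)
n_documents, hidden_size = len(cls), len(cls[0])
```

The result has one row per document, which is the shape a concept model expects when it is fitted on document-level activations.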

3. 🏋️‍♂️ Fit a concept model on activations

With activations, we can train a concept model to find patterns (concepts).

The concept_model is an attribute of our concept explainer, similarly to the model_with_split_points. With these two elements, we can go from inputs to concepts and from concepts to outputs.

In this tutorial, we use interpreto.concepts.ICAConcepts built upon the ICA (Independent Component Analysis) dimension reduction algorithm.

There are at least 15 other concept models available in Interpreto; do not hesitate to explore them.

> 🔥 Tip
>
> ICAConcepts is a good first candidate for classification. It has no requirements, is fast, and provides decent first results on most datasets.
>
> So is the SemiNMFConcepts used in the better concepts section.

from interpreto.concepts import ICAConcepts

# instantiate the concept explainer
concept_explainer = ICAConcepts(model_with_split_points, nb_concepts=50, device="cuda")

# fit the concept explainer on activations
concept_explainer.fit(activations)

4. 🏷️ Interpret the concept dimensions

We have our concepts and the link between concepts and classes. But now, we need to make sense of these concepts.

In this case, we will use interpreto.concepts.interpretations.TopKInputs to find the 5 words that most strongly activate each concept.

> ⚠️ Warning
>
> If the granularity specified to the interpretation method is not the same as the one used for activations, the results will be wrong.

from interpreto.concepts.interpretations import TopKInputs

# instantiate the interpretation method with the concept explainer
topk_inputs_method = TopKInputs(
    concept_explainer=concept_explainer,
    k=5,
    activation_granularity=granularity,
    use_unique_words=True,  # with the [CLS] token granularity, we are forced to use unique words
    unique_words_kwargs={
        "count_min_threshold": round(
            len(inputs) * 0.002
        ),  # appear in at least 0.2% of the samples | increase if random words appear and decrease if some words appear too often
        "lemmatize": True,
        "words_to_ignore": [],  # include noise words and punctuation
    },
)
# call the interpretation methods on the inputs
# we cannot give the previously computed activations because `use_unique_words=True` creates samples with a single word inside
topk_words = topk_inputs_method.interpret(
    inputs=inputs,
    concepts_indices="all",
)
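The count_min_threshold above filters out rare words before they are scored. As a rough illustration, the sketch below keeps only words that occur at least a given number of times in a tiny made-up corpus. The `frequent_unique_words` helper is hypothetical, counts raw occurrences, and skips lemmatization, so it may differ from what Interpreto does internally.

```python
from collections import Counter

# Tiny illustrative corpus; the tutorial uses AG-News documents instead.
documents = [
    "stocks rise as markets rally",
    "markets fall as oil prices rise",
    "team wins final as fans rally",
]


def frequent_unique_words(texts, count_min_threshold):
    """Return the sorted unique words occurring at least `count_min_threshold` times."""
    counts = Counter(word for text in texts for word in text.lower().split())
    return sorted(word for word, count in counts.items() if count >= count_min_threshold)


# Same formula as in the tutorial: 0.2% of the number of samples (1000 here).
tutorial_threshold = round(1000 * 0.002)

vocabulary = frequent_unique_words(documents, count_min_threshold=2)
```

Raising the threshold removes more rare words, at the risk of dropping informative ones; lowering it lets more noise through.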

5. 🌍 Find the globally important concepts

We have concept directions. This means the model has access to them, not that it uses them.

The same happens when you train a model on tabular data: not all features are used.

In this step, we use ConceptAutoEncoderExplainer.concept_output_gradient to evaluate the importance of each concept with respect to the predicted classes.

> ➡️ Note
>
> All unsupervised concept-based explainers in Interpreto inherit from ConceptAutoEncoderExplainer.

> ➡️ Note 2
>
> This step can be done before the interpretation, as the interpretation step can be compute-heavy. It can be wise to interpret only the important concepts, selected with the concepts_indices parameter. (Here we only have 50 concepts, so it does not matter.)

import torch

# estimate the importance of concepts for each class using the gradient
gradients = concept_explainer.concept_output_gradient(
    inputs=inputs,
    targets=None,  # None means all classes
    activation_granularity=granularity,
    concepts_x_gradients=True,  # the concept to output gradients are multiplied by the concepts values, this is common practice in the literature
    batch_size=64,
)

# stack gradients on samples and average them over samples
mean_gradients = torch.stack(gradients).abs().squeeze().mean(0)  # (num_classes, num_concepts)

# for each class, sort the importance scores
order = torch.argsort(mean_gradients, descending=True)

# visualize the top 5 concepts for each class
for target in range(order.shape[0]):
    print(f"\nClass: {classes_names[target]}:")
    for i in range(5):
        concept_id = order[target, i].item()
        importance = mean_gradients[target, concept_id].item()
        words = list(topk_words.get(concept_id, {}).keys())  # empty list if the concept has no interpretation
        print(f"\tconcept id: {concept_id},\timportance: {round(importance, 3)},\ttopk words: {words}")

Class: World:
    concept id: 27, importance: 0.091,  topk words: ['serbia-montenegro', 'nato', 'liechtenstein', 'nikkei', 'ossetia']
    concept id: 45, importance: 0.078,  topk words: ['separatist', 'militia', 'militiaman', 'usatoday.com', 'inquirer']
    concept id: 11, importance: 0.07,   topk words: ['betting', 'saudi', 'gambler', 'shark', 'kidnapper']
    concept id: 31, importance: 0.064,  topk words: ['betting', 'fraud', 'holy', 'kmart', 'p.m.']
    concept id: 48, importance: 0.041,  topk words: ['afp', 'armed', 'hue', 'anarchist', 'naval']

Class: Sports:
    concept id: 7,  importance: 0.11,   topk words: ['phelps', '200-meter', 'batter', 'inning', 'homered']
    concept id: 49, importance: 0.092,  topk words: ['heat', '100-meter', 'fastest', '200-meter', '200m']
    concept id: 30, importance: 0.087,  topk words: ['200-meter', '100-meter', 'breaststroke', '400-meter', 'heat']
    concept id: 33, importance: 0.069,  topk words: ['phillies', 'mets', 'baltimore', 'sox', 'nl']
    concept id: 24, importance: 0.062,  topk words: ['stryker', 'nfl', 'armadillo', 'homer', 'autodesk']

Class: Business:
    concept id: 13, importance: 0.084,  topk words: ['pharmacare', 'procurement', 'costly', 'sector', 'marketer']
    concept id: 28, importance: 0.068,  topk words: ['vodafone', 'ipo.google.com', 'antitrust', 'verizon', 'lenovo']
    concept id: 19, importance: 0.046,  topk words: [';', 'ott', 'atp', 'fcc', '3g']
    concept id: 7,  importance: 0.042,  topk words: ['phelps', '200-meter', 'batter', 'inning', 'homered']
    concept id: 37, importance: 0.04,   topk words: ['eurozone', 'finance', 'imf', 'economy', 'balance']

Class: Sci/Tech:
    concept id: 7,  importance: 0.057,  topk words: ['phelps', '200-meter', 'batter', 'inning', 'homered']
    concept id: 49, importance: 0.053,  topk words: ['heat', '100-meter', 'fastest', '200-meter', '200m']
    concept id: 13, importance: 0.052,  topk words: ['pharmacare', 'procurement', 'costly', 'sector', 'marketer']
    concept id: 27, importance: 0.05,   topk words: ['serbia-montenegro', 'nato', 'liechtenstein', 'nikkei', 'ossetia']
    concept id: 30, importance: 0.046,  topk words: ['200-meter', '100-meter', 'breaststroke', '400-meter', 'heat']
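To see what "gradient times concept value, averaged over samples" computes, here is a small sketch under a simplifying assumption: the part of the model after the concepts is a linear head, so the gradient of a class logit with respect to a concept is just the corresponding weight. The weights, activations, and helper names are all hypothetical; the real second half of the model is not linear, and Interpreto obtains the gradients by automatic differentiation.

```python
# Hypothetical linear head mapping 3 concepts to 2 class logits.
weights = [
    [2.0, 0.0, -1.0],  # class 0
    [0.0, 0.5, 1.0],   # class 1
]

# Hypothetical concept activations for 2 samples.
concept_activations = [
    [1.0, 2.0, 0.0],
    [3.0, 0.0, 1.0],
]


def gradient_x_concepts(weights, concepts):
    """For a linear head, d logit_t / d concept_c = weights[t][c]; multiply by the concept value."""
    return [[w * z for w, z in zip(row, concepts)] for row in weights]


def mean_abs_importance(weights, samples):
    """Average |gradient x concept| over samples -> (num_classes, num_concepts)."""
    per_sample = [gradient_x_concepts(weights, sample) for sample in samples]
    n_classes, n_concepts = len(weights), len(weights[0])
    return [
        [sum(abs(s[t][c]) for s in per_sample) / len(per_sample) for c in range(n_concepts)]
        for t in range(n_classes)
    ]


importance = mean_abs_importance(weights, concept_activations)

# Concept indices sorted from most to least important, per class.
order = [sorted(range(len(row)), key=lambda c: row[c], reverse=True) for row in importance]
```

A concept with a large weight but near-zero activations ends up with low importance, which is the reason for multiplying gradients by concept values.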

from interpreto import plot_concepts

labels = {k: list(v.keys()) for k, v in topk_words.items()}

plot_concepts(
    classes_names=classes_names,
    concepts_importances=mean_gradients,
    concepts_labels=labels,
)

[Plot: concept importances per class]

> ❓ The concepts are not interpretable, what do I do?
>
> - Try to improve the concept-space:
>   - Increase the number of samples. You can do so artificially by splitting them into sentences (not included).
>   - Try different concept models and parameters.
>   - Try to compute concepts class-wise, see next section.
>
> - Improve the interpretation of concepts:
>   - Play with the parameters.
>   - Try LLMLabels, see next section.
>
> - Try to evaluate the concepts to automatically find the best methods. Check this other tutorial: TODO
>
> - Never forget the faithfulness-plausibility trade-off of explanations.
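One cheap way to get more samples is to split each document into sentences. The sketch below uses a naive regular expression for that; `split_into_sentences` is a hypothetical helper, and the regex will mis-split abbreviations such as "U.S.", so a proper sentence tokenizer is preferable on real data.

```python
import re


def split_into_sentences(text):
    """Naive splitter: break after '.', '!' or '?' followed by whitespace."""
    parts = re.split(r"(?<=[.!?])\s+", text.strip())
    return [part for part in parts if part]


documents = [
    "Oil prices rose on Monday. Analysts expect further gains!",
    "The team won the final.",
]

# One sample per sentence instead of one sample per document.
sentences = [sentence for document in documents for sentence in split_into_sentences(document)]
```

The resulting list can be passed as inputs in place of the full documents when computing activations.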

6. 📚 Better concepts with class-wise concepts and LLM labels

This section aims at improving the concepts from the previous steps. The code below changes three things: it learns a separate concept space for each class, replaces ICAConcepts with SemiNMFConcepts, and replaces TopKInputs with LLMLabels.

When a single concept-space is defined for all classes, concepts tend to correspond to the classes themselves, in particular when the concept-space is built on the latent space just before the classification head.

In this section, we will learn a concept space for each class separately. Thus, the class-wise concept explainers will only see examples from a single class (based on the predictions).

> ⚠️ Warning
>
> The following cell and several others will not work without an OpenAI API key, because LLMLabels requires an LLMInterface and we chose OpenAILLM.
>
> What you can do to make it work:
> - Get an OpenAI API key and set it as the environment variable OPENAI_API_KEY.
> - Use TopKInputs to replace LLMLabels.
> - Plug in your own LLMInterface and use it instead of OpenAILLM. See the last section for an example.

import os

from interpreto.concepts import LLMLabels, SemiNMFConcepts
from interpreto.model_wrapping.llm_interface import OpenAILLM

# Load API key from environment variable
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    # adjacent string literals are concatenated into a single error message
    raise ValueError(
        "An API key is required to use `OpenAILLM` interface. "
        "Cannot use LLMLabels without an LLM interface. "
        "See last section for an example of how to plug in your own."
    )

# set the LLM interface used to generate labels based on the constructed prompts
llm_interface = OpenAILLM(api_key=api_key, model="gpt-4.1-nano")

concept_explainers = {}
concept_interpretations = {}
concept_importances = {}

# iterate over classes
for target, class_name in enumerate(classes_names):
    # ----------------------------------------------------------------------------------------------
    # 2. construct the dataset of activations (extract the ones related to the class)
    indices = (activations["predictions"] == target).nonzero(as_tuple=True)[0]
    class_wise_inputs = [inputs[i] for i in indices]
    class_wise_activations = {k: v[indices] for k, v in activations.items()}

    # ----------------------------------------------------------------------------------------------
    # 3. train concept model
    concept_explainers[target] = SemiNMFConcepts(model_with_split_points, nb_concepts=20, device="cuda")
    concept_explainers[target].fit(class_wise_activations)

    # ----------------------------------------------------------------------------------------------
    # 5. compute concepts importance (before interpretations to limit the number of concepts interpreted)
    gradients = concept_explainers[target].concept_output_gradient(
        inputs=class_wise_inputs,
        targets=[target],
        activation_granularity=granularity,
        concepts_x_gradients=True,
        batch_size=64,
    )

    # stack gradients on samples and average them over samples
    concept_importances[target] = torch.stack(gradients, axis=0).squeeze().abs().mean(dim=0)  # (num_concepts,)

    # for each class, sort the importance scores
    important_concept_indices = torch.argsort(concept_importances[target], descending=True).tolist()

    # ----------------------------------------------------------------------------------------------
    # 4. interpret the important concepts
    llm_labels_method = LLMLabels(
        concept_explainer=concept_explainers[target],
        activation_granularity=granularity,
        llm_interface=llm_interface,
        k_examples=20,
    )

    concept_interpretations[target] = llm_labels_method.interpret(
        inputs=class_wise_inputs,
        concepts_indices=important_concept_indices,
    )

    print(f"\nClass: {class_name}")
    for concept_id in important_concept_indices[:5]:
        label = concept_interpretations[target].get(concept_id, None)
        importance = concept_importances[target][concept_id].item()
        if label is not None:
            print(f"\timportance: {round(importance, 3)},\t{label}")

Class: World
    importance: 0.129,  Concise pattern: Formal, informational language with emphasis on factual reporting and descriptive detail.
    importance: 0.111,  Event-focused, concise, multi-word summaries highlighting key entities, actions, or dynamics.
    importance: 0.086,  Event-focused, descriptive summaries emphasizing struggle, health, and rituals.
    importance: 0.081,  Repetitive political and event reporting
    importance: 0.072,  Highly satirical, exaggerated, humorous patterns.

Class: Sports
    importance: 0.094,  Event reporting emphasizes action and outcomes, often highlighting surprises, victories, or controversies, with a focus on dynamics and conflicts.
    importance: 0.088,  Consistent references to named entities, events, and competitions with minimal descriptive language.
    importance: 0.085,  Event summaries with specific details and recent developments.
    importance: 0.084,  Concise pattern: Event descriptions with structured reporting, including participants, outcomes, and sometimes specific details or quotes.
    importance: 0.076,  Event-focused, sports-related, specific terminology.

Class: Business
    importance: 0.117,  Concise pattern: Economic focus on data-driven, factual reporting.
    importance: 0.105,  Diverse topics centered on socio-cultural, economic, and urban themes, with emphasis on local identities and modernization.
    importance: 0.099,  Diverse topics and sources, variable detail, consistent news reporting style.
    importance: 0.094,  Topic-specific keywords.
    importance: 0.085,  Consistent focus on finance, investments, or economic topics.

Class: Sci/Tech
    importance: 0.13,   Data-driven scientific advancements and technical developments
    importance: 0.116,  Technology focus on communication, regulation, security.
    importance: 0.102,  Concise, global phenomena or scientific explanations.
    importance: 0.102,  Diverse technical and scientific descriptions with consistent focus on innovations, mechanisms, and discoveries.
    importance: 0.064,  Concise pattern: Commercial language with factual tone.

plot_concepts(
    classes_names=classes_names,
    concepts_importances=concept_importances,
    concepts_labels=concept_interpretations,
)

[Plot: concept importances per class]

6. 📍 Locally important concepts

We now know which concepts are important for each class globally. However, not all concepts are present in every sample, and the model might rely on a specific concept for a specific sample. Let's look at locally important concepts, meaning the concepts the model used for a specific sample.

> ➡️ Note
>
> Because we use the [CLS] token, we cannot see which word activates which concept without a forward pass for each word individually. To do so, you could reuse the following cell, iterating over words and replacing the example with each word.

example = "Bio-engineered shoes will revolutionized running throughout the world, as they cost only 50 dollars."

activations_dict = model_with_split_points.get_activations(
    inputs=[example],
    activation_granularity=granularity,
    include_predicted_classes=True,
)
pred = activations_dict.pop("predictions").item()
local_activations = next(iter(activations_dict.values()))
concepts_activations = concept_explainers[pred].encode_activations(local_activations)

print(f"Example: {example}")
print(f"Predicted class: {classes_names[pred]}")

# compute local concepts importance for the class
# we use the class-wise concept explainer of the predicted class
local_importance = concept_explainers[pred].concept_output_gradient(
    inputs=[example],
    activation_granularity=granularity,
    concepts_x_gradients=True,
    tqdm_bar=False,
)[0]  # there is only one sample

plot_concepts(
    sample=[example],
    classes_names=classes_names,
    concepts_activations=concepts_activations,
    concepts_importances=local_importance.squeeze(),  # importance of shape (t, g, c) -> (t, c)
    concepts_labels=concept_interpretations,
)
Example: Bio-engineered shoes will revolutionized running throughout the world, as they cost only 50 dollars.
Predicted class: Sci/Tech

[Plot with three panels: Classes, Concepts, Sample]

7. ⚖️ Evaluate concept-based explanations

We return to the ICAConcepts explainer from step 3 and evaluate it on new samples.

test_inputs = dataset["test"]["text"][:1000]  # let's take one thousand test samples

# Compute the [CLS] token activations
test_activations = model_with_split_points.get_activations(
    inputs=test_inputs,
    activation_granularity=granularity,
    include_predicted_classes=True,
)

7.1 🌐 Evaluate the concept-space from the third part

> ⚠️ Warning:
>
> These metrics should only be used to compare concept-spaces trained in similar contexts: same model, split point, activation dataset...

from interpreto.concepts.metrics import FID, MSE

mse = MSE(concept_explainer).compute(test_activations)
fid = FID(concept_explainer).compute(test_activations)

print(f"MSE: {round(mse, 3)}, FID: {round(fid, 3)}")
MSE: 165.919, FID: 0.041

> ➡️ Note
>
> On their own, these values are meaningless; they should be compared across several concept explainers.

from interpreto.concepts.metrics import Sparsity, SparsityRatio

sparsity = Sparsity(concept_explainer).compute(test_activations)
ratio = SparsityRatio(concept_explainer).compute(test_activations)

print(f"Sparsity: {round(sparsity, 3)}, Sparsity ratio: {round(ratio, 3)}")
Sparsity: 1.0, Sparsity ratio: 0.02
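For intuition about what reconstruction and sparsity metrics measure, here is a self-contained sketch with a tiny orthonormal dictionary. It is not Interpreto's implementation: `encode`, `decode`, `reconstruction_mse`, and `nonzero_fraction` are hypothetical helpers, and the library's exact definitions of MSE, Sparsity, and SparsityRatio may differ from these.

```python
# Hypothetical dictionary with 2 concepts in a 3-dimensional activation space.
dictionary = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
]


def encode(activation):
    """Project an activation onto each concept direction (valid here: rows are orthonormal)."""
    return [sum(d * a for d, a in zip(direction, activation)) for direction in dictionary]


def decode(concepts):
    """Rebuild an activation as a weighted sum of concept directions."""
    size = len(dictionary[0])
    return [sum(z * direction[i] for z, direction in zip(concepts, dictionary)) for i in range(size)]


def reconstruction_mse(activations):
    """Mean squared error between activations and their reconstructions."""
    errors = [
        (a - r) ** 2
        for activation in activations
        for a, r in zip(activation, decode(encode(activation)))
    ]
    return sum(errors) / len(errors)


def nonzero_fraction(activations, tolerance=1e-9):
    """Fraction of concept activations whose magnitude exceeds `tolerance`."""
    values = [z for activation in activations for z in encode(activation)]
    return sum(abs(z) > tolerance for z in values) / len(values)


test_activations = [[1.0, 2.0, 3.0], [0.0, 4.0, 0.0]]
mse = reconstruction_mse(test_activations)
sparsity = nonzero_fraction(test_activations)
```

Here the third activation dimension is not covered by any concept, so it is the only source of reconstruction error.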

Dictionary metrics

The Stability metric requires two concept explainers, hence our first step will be to train a new ICAConcepts with the same model, split, dataset, and hyper-parameters as the original ICAConcepts. However, to get a statistically robust score, one should compare more than just two instances of the same explainer.

from interpreto.concepts.metrics import Stability

# instantiate and train a second concept explainer
second_explainer = ICAConcepts(model_with_split_points, nb_concepts=50, device="cuda")
second_explainer.fit(activations)

stability = Stability(concept_explainer, second_explainer).compute()
del second_explainer

print(f"Stability: {round(stability, 3)}")
Stability: 1.0
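One common way to compare two dictionaries is to match their concept directions one-to-one and average the cosine similarity of the matched pairs. The sketch below does this by brute force over permutations, which is only practical for a handful of concepts. It is an assumption about the general idea, not a description of how Interpreto's Stability is computed, and the `cosine` and `stability` helpers are hypothetical.

```python
import math
from itertools import permutations


def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def stability(dict_a, dict_b):
    """Best mean cosine similarity over all one-to-one matchings of concepts.

    Brute force over permutations: only practical for a handful of concepts.
    """
    n = len(dict_a)
    return max(
        sum(cosine(dict_a[i], dict_b[j]) for i, j in enumerate(perm)) / n
        for perm in permutations(range(n))
    )


# Hypothetical dictionaries: the second is the first with its concepts reordered and rescaled.
first = [[1.0, 0.0], [0.0, 1.0]]
second = [[0.0, 2.0], [3.0, 0.0]]
unrelated = [[1.0, 1.0], [1.0, -1.0]]

same_score = stability(first, second)
different_score = stability(first, unrelated)
```

Because the score ignores concept order and scale, two runs that find the same directions in a different order still score 1. For realistic dictionary sizes, an assignment solver such as `scipy.optimize.linear_sum_assignment` would replace the brute-force search.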

7.2 💭 Evaluate the concepts-interpretations from the fourth step

# Work in progress, coming soon

7.3 ↔️ Evaluate the whole concept-based explanations with ConSim

ConSim is a metric that evaluates concept-based explanations end to end. It measures to what extent the provided explanations help a meta-predictor predict what the studied model would have predicted. The idea is that if a meta-predictor understands the model, it can predict the model's outputs on new samples.

> ➡️ Note
>
> For more reliable scores, we iterate over 10 different seeds (5 were used in the paper).

from interpreto.concepts.metrics.consim import ConSim, PromptTypes

# convert step 5 global concept importances to a dictionary
global_importances = {
    class_name: dict(enumerate(importances))
    for class_name, importances in zip(classes_names, mean_gradients, strict=True)
}

# Load API key from environment variable
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
    # adjacent string literals are concatenated into a single error message
    raise ValueError(
        "An API key is required to use `OpenAILLM` interface. "
        "Cannot use ConSim without an LLM interface. "
        "See last section for an example of how to plug in your own."
    )

# set the LLM interface used to generate labels based on the constructed prompts
llm_interface = OpenAILLM(api_key=api_key, model="gpt-4.1-nano")

# Initialize the ConSim  with the model with split points and the user-llm
# Therefore, a given ConSim metric can be used on different explainers for cleaner comparison
con_sim = ConSim(model_with_split_points, llm_interface, classes=classes_names, activation_granularity=granularity)

baseline_list = []
ica_score_list = []
for seed in range(10):
    # Select examples for evaluation
    samples, labels, predictions = con_sim.select_examples(
        inputs=dataset["train"]["text"][:5000],
        labels=torch.tensor(dataset["train"]["label"][:5000]).cuda(),
        seed=seed,
    )

    # Compute a baseline and ConSim score to give sense to the explainer ConSim score
    baseline = con_sim.evaluate(
        interesting_samples=samples, predictions=predictions, prompt_type=PromptTypes.L2_baseline_with_lp
    )

    if baseline is None:
        continue

    # Compute the ConSim score for an explainer
    ica_score = con_sim.evaluate(
        interesting_samples=samples,
        predictions=predictions,
        concept_explainer=concept_explainer,
        concepts_interpretation=topk_words,
        global_importances=global_importances,
        prompt_type=PromptTypes.E2_global_concepts_with_lp,
    )

    if ica_score is None:
        continue

    baseline_list.append(baseline)
    ica_score_list.append(ica_score)

n_runs = max(len(baseline_list), 1)  # seeds skipped above must not count in the average
print(f"Baseline: {round(sum(baseline_list) / n_runs, 2)}, ICA: {round(sum(ica_score_list) / n_runs, 2)}")
Baseline: 0.29, ICA: 0.28

> ➡️ Note
>
> We evaluated the first ICA explainer here. Its concepts were not really interpretable, and ConSim agrees.

8. Using your own LLM interface

from interpreto.model_wrapping.llm_interface import LLMInterface, Role


class GeminiLLM(LLMInterface):
    def __init__(self, api_key: str, model: str = "gemini-1.5-flash", num_try: int = 5):
        try:
            import google.generativeai as genai  # noqa: PLC0415  # ruff: disable=import-outside-toplevel
        except ImportError as e:
            raise ImportError("Install google-generativeai to use Google Gemini API.") from e

        self.genai = genai
        self.genai.configure(api_key=api_key)
        self.model = model
        self.num_try = num_try

    def generate(self, prompt: list[tuple[Role, str]]) -> str | None:
        # Build system instruction and chat history for Gemini
        system_messages: list[str] = []
        contents: list[dict] = []

        for role, content in prompt:
            if role == Role.SYSTEM:
                system_messages.append(content)
            elif role == Role.USER:
                contents.append(
                    {
                        "role": "user",
                        "parts": [{"text": content}],
                    }
                )
            elif role == Role.ASSISTANT:
                contents.append(
                    {
                        "role": "model",
                        "parts": [{"text": content}],
                    }
                )
            else:
                raise ValueError(f"Unknown role for google gemini api: {role}")

        system_instruction: str | None = "\n".join(system_messages) if system_messages else None

        label: str | None = None
        for _ in range(self.num_try):
            try:
                model = self.genai.GenerativeModel(
                    model_name=self.model,
                    system_instruction=system_instruction,
                )
                response = model.generate_content(contents)  # type: ignore[arg-type]
                # google-generativeai exposes the main text as .text
                label = response.text
                break
            except Exception as e:  # noqa: BLE001
                print(e)
        return label
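To test the prompt-handling logic without any API key, a small offline stand-in can help. The sketch below is self-contained: it defines its own `Role` enum as a local substitute for Interpreto's (whose actual values are not shown in this tutorial), and `EchoLLM` is a hypothetical class that mimics the `generate` signature above without subclassing LLMInterface, so it would need adapting before being passed to LLMLabels or ConSim.

```python
from enum import Enum


class Role(Enum):
    """Local stand-in for interpreto's Role, defined here so the sketch runs on its own."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


def split_prompt(prompt):
    """Separate system messages from the user/assistant chat history."""
    system_messages = [content for role, content in prompt if role is Role.SYSTEM]
    history = [(role, content) for role, content in prompt if role is not Role.SYSTEM]
    return "\n".join(system_messages) or None, history


class EchoLLM:
    """Offline stand-in with the same `generate` shape; it returns the last user message."""

    def generate(self, prompt):
        _, history = split_prompt(prompt)
        user_messages = [content for role, content in history if role is Role.USER]
        return user_messages[-1] if user_messages else None


prompt = [
    (Role.SYSTEM, "You label concepts."),
    (Role.USER, "Label: nato, militia, separatist"),
]
system_instruction, history = split_prompt(prompt)
answer = EchoLLM().generate(prompt)
```

Returning None when no answer is available mirrors the `str | None` return type of the Gemini example, so calling code has to handle missing labels in both cases.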