Skip to content

interpreto_banner

Classification Demonstration

This notebook showcases what can be done with interpreto for classification.

author: Antonin Poché

0. Imports and model loading

import os

import datasets
import torch
import transformers

import interpreto
from interpreto import Lime, plot_attributions, plot_concepts
from interpreto.attributions.metrics import Insertion
from interpreto.concepts import LLMLabels, SemiNMFConcepts
from interpreto.concepts.interpretations import TopKInputs
from interpreto.concepts.metrics import FID, SparsityRatio
from interpreto.model_wrapping.llm_interface import OpenAILLM
# take the first 1000 training texts of the emotion dataset as examples
dataset = datasets.load_dataset("dair-ai/emotion", "split")["train"]["text"][:1000]

# human-readable names for the six emotion labels (index == class id)
classes_names = ["sadness", "joy", "love", "anger", "fear", "surprise"]

# load the fine-tuned emotion classifier and its tokenizer, then move the model to GPU
model_id = "nateraw/bert-base-uncased-emotion"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=True)
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_id).cuda()

1. Attribution Demonstration

1.1 Obtain the attributions

# build a LIME explainer around the classifier
attribution_explainer = Lime(model, tokenizer)

# explain one sentence, requesting attributions for every class at once
sentence = "Love and hate are two sides of the same coin."
all_classes = torch.tensor([[0, 1, 2, 3, 4, 5]])
attributions = attribution_explainer(model_inputs=sentence, targets=all_classes)

# render the per-token attribution scores for the sentence
plot_attributions(attributions[0], classes_names=classes_names)
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`

Classes

Inputs

1.2 Evaluate the attributions

# evaluate the attributions with the same model the explainer used
metric = Insertion(model, tokenizer)
auc, detailed_scores = metric.evaluate(attributions)

rounded_auc = round(auc, 3)
print(f"Insertion AUC: {rounded_auc} (lower is better)")
Insertion AUC: 0.317 (lower is better)

2. Concepts Demonstration

2.1 Split model and get a dataset of activations

# expose layer 11 of the classifier as a split point
split_model = interpreto.ModelWithSplitPoints(model, tokenizer=tokenizer, split_points=[11])

# gather [CLS]-token activations over the dataset, keeping the predicted class of each sample
granularity = split_model.activation_granularities.CLS_TOKEN
activations = split_model.get_activations(
    activation_granularity=granularity,
    include_predicted_classes=True,
    inputs=dataset,
)

2.2 Learn concepts as patterns in the activations

# learn 20 concepts as Semi-NMF patterns in the split-point activations
concept_explainer = SemiNMFConcepts(split_model, nb_concepts=20)
concept_explainer.fit(activations)

2.3 Interpret the concepts

# describe every concept by the top-5 unique (lemmatized, frequent) words activating it
interpretation_method = TopKInputs(
    concept_explainer=concept_explainer,
    activation_granularity=granularity,
    k=5,
    use_unique_words=True,
    unique_words_kwargs={"count_min_threshold": 20, "lemmatize": True},
)

topk_words = interpretation_method.interpret(inputs=dataset, concepts_indices="all")

2.4 Estimate concepts importance to classes

# per-sample gradients of the model outputs w.r.t. each concept
gradients = concept_explainer.concept_output_gradient(
    inputs=dataset,
    activation_granularity=granularity,
    batch_size=32,
)

# average absolute gradients over samples -> one importance per (class, concept)
stacked_gradients = torch.stack(gradients)
mean_gradients = stacked_gradients.abs().squeeze().mean(0)  # (num_classes, num_concepts)

# concept ids ranked by decreasing importance, separately for every class
order = torch.argsort(mean_gradients, descending=True)

2.5 Visualize the most important concepts for each class

# print the five most important concepts (id, importance, top-k words) for every class
for target in range(order.shape[0]):
    print(f"\nClass: {classes_names[target]}:")
    for i in range(5):
        concept_id = order[target, i].item()
        importance = mean_gradients[target, concept_id].item()
        # default to an empty dict: `.get(concept_id, None).keys()` would raise
        # AttributeError for any concept without an interpretation
        words = list(topk_words.get(concept_id, {}).keys())
        print(f"\tconcept id: {concept_id},\timportance: {round(importance, 3)},\ttopk words: {words}")

Class: sadness:
    concept id: 5,  importance: 0.121,  topk words: ['feel', 'feeling', 'want', 'it', 'how']
    concept id: 10, importance: 0.095,  topk words: ['love', 'do', 'pretty', 'again', 'for']
    concept id: 8,  importance: 0.092,  topk words: ['need', 'feeling', 'how', 'want', 'be']
    concept id: 3,  importance: 0.085,  topk words: ['you', 'little', 'because', 'really', 'been']
    concept id: 6,  importance: 0.084,  topk words: ['feel', 'feeling', 'right', 'one', 'more']

Class: joy:
    concept id: 4,  importance: 0.104,  topk words: ['will', 'help', 'day', 'little', 'need']
    concept id: 14, importance: 0.1,    topk words: ['pretty', 'going', 'really', 'there', 'where']
    concept id: 0,  importance: 0.08,   topk words: ['still', 'around', 'way', 'own', 'always']
    concept id: 18, importance: 0.079,  topk words: ['what', 'when', 'time', 'day', 'year']
    concept id: 9,  importance: 0.076,  topk words: ['really', 'm', 'it', 'she', 'other']

Class: love:
    concept id: 11, importance: 0.147,  topk words: ['feeling', 'love', 'feel', 'been', 'about']
    concept id: 6,  importance: 0.104,  topk words: ['feel', 'feeling', 'right', 'one', 'more']
    concept id: 5,  importance: 0.099,  topk words: ['feel', 'feeling', 'want', 'it', 'how']
    concept id: 19, importance: 0.087,  topk words: ['feeling', 'what', 'how', 'little', 'myself']
    concept id: 7,  importance: 0.072,  topk words: ['could', 'feeling', 'no', 'think', 'been']

Class: anger:
    concept id: 13, importance: 0.115,  topk words: ['good', 'you', 'more', 'no', 'help']
    concept id: 11, importance: 0.091,  topk words: ['feeling', 'love', 'feel', 'been', 'about']
    concept id: 15, importance: 0.084,  topk words: ['pretty', 'our', 'we', 'being', 'very']
    concept id: 5,  importance: 0.08,   topk words: ['feel', 'feeling', 'want', 'it', 'how']
    concept id: 2,  importance: 0.074,  topk words: ['really', 'just', 'm', 'for', 'pretty']

Class: fear:
    concept id: 17, importance: 0.113,  topk words: ['for', 'pretty', 'feel', 'your', 'feeling']
    concept id: 9,  importance: 0.105,  topk words: ['really', 'm', 'it', 'she', 'other']
    concept id: 3,  importance: 0.092,  topk words: ['you', 'little', 'because', 'really', 'been']
    concept id: 13, importance: 0.086,  topk words: ['good', 'you', 'more', 'no', 'help']
    concept id: 18, importance: 0.084,  topk words: ['what', 'when', 'time', 'day', 'year']

Class: surprise:
    concept id: 1,  importance: 0.169,  topk words: ['how', 'about', 'being', 'always', 'some']
    concept id: 0,  importance: 0.113,  topk words: ['still', 'around', 'way', 'own', 'always']
    concept id: 17, importance: 0.105,  topk words: ['for', 'pretty', 'feel', 'your', 'feeling']
    concept id: 13, importance: 0.098,  topk words: ['good', 'you', 'more', 'no', 'help']
    concept id: 4,  importance: 0.095,  topk words: ['will', 'help', 'day', 'little', 'need']

# interactive overview: concept importances per class, concepts labeled by their top-k words
concepts_word_labels = {concept_id: list(words) for concept_id, words in topk_words.items()}
plot_concepts(
    classes_names=classes_names,
    concepts_importances=mean_gradients,
    concepts_labels=concepts_word_labels,
)

Classes

> Tip: The above visualization is interactive, you should click on the classes.

2.6 Evaluate concepts

# NOTE: the training activations are reused here; in practice, these should be different
test_activations = activations

fid = FID(concept_explainer).compute(test_activations)
ratio = SparsityRatio(concept_explainer).compute(test_activations)

print(f"FID: {round(fid, 3)}, Sparsity ratio: {round(ratio, 3)}")
FID: 0.04, Sparsity ratio: 0.05

2.7 Class-wise and LLM-labels

# set the LLM interface used to generate labels based on the constructed prompts
llm_interface = OpenAILLM(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4.1-nano")

# per-class results filled in the loop below: {class id: {concept id: ...}}
concepts_interpretations = {}
concepts_importances = {}
for target, class_name in enumerate(classes_names):
    # ----------------------------------------------------------------------------------------------
    # 1. construct the dataset of activations (extract the samples predicted as this class)
    indices = (activations["predictions"] == target).nonzero(as_tuple=True)[0]
    # convert tensor indices to plain ints before indexing the python list
    class_wise_inputs = [dataset[i] for i in indices.tolist()]
    class_wise_activations = {k: v[indices] for k, v in activations.items()}

    # ----------------------------------------------------------------------------------------------
    # 2. train a class-specific concept model
    concept_explainer = SemiNMFConcepts(split_model, nb_concepts=20, device="cuda")
    concept_explainer.fit(class_wise_activations)

    # ----------------------------------------------------------------------------------------------
    # 3. compute concepts importance (before interpretations to limit the number of concepts interpreted)
    gradients = concept_explainer.concept_output_gradient(
        inputs=class_wise_inputs,
        targets=[target],
        activation_granularity=granularity,
        concepts_x_gradients=True,
        batch_size=32,
    )

    # stack gradients on samples and average them over samples
    concepts_importances[target] = torch.stack(gradients, dim=0).squeeze().abs().mean(dim=0)  # (num_concepts,)

    # sort the concepts by decreasing importance for this class
    important_concept_indices = torch.argsort(concepts_importances[target], descending=True).tolist()

    # ----------------------------------------------------------------------------------------------
    # 4. interpret the concepts, most important first
    llm_labels_method = LLMLabels(
        concept_explainer=concept_explainer,
        activation_granularity=granularity,
        llm_interface=llm_interface,
        k_examples=20,
    )

    concepts_interpretations[target] = llm_labels_method.interpret(
        inputs=class_wise_inputs,
        concepts_indices=important_concept_indices,  # all concepts, ordered by importance
    )

    # display the labels of the five most important concepts
    print(f"\nClass: {class_name}")
    for concept_id in important_concept_indices[:5]:
        label = concepts_interpretations[target].get(concept_id, None)
        importance = concepts_importances[target][concept_id].item()
        if label is not None:
            print(f"\timportance: {round(importance, 3)},\t{label}")

Class: sadness
    importance: 0.089,  Expressions of emotional and physical states related to fatigue and mood.
    importance: 0.085,  Expression of personal emotional and physical states.
    importance: 0.081,  Confessions of social and personal inadequacy.
    importance: 0.081,  Expressing pervasive emotional distress and helplessness.
    importance: 0.071,  Expressions of emotional detachment and internal conflict.

Class: joy
    importance: 0.092,  Expressions of subjective emotional states focused on confidence, reassurance, and self-perception.
    importance: 0.084,  Personal feelings of spirituality, vitality, and positivity.
    importance: 0.074,  Expressing safety and comfort through emotional reassurance
    importance: 0.073,  Expressive first-person emotional references
    importance: 0.069,  Explicit expression of personal sensation or emotional state.

Class: love
    importance: 0.108,  Expressive, tender focus on emotional and physical intimacy
    importance: 0.099,  Expressive first-person declarations of strong or shifting emotions
    importance: 0.096,  Emotional expressions linked to subtle sensations
    importance: 0.086,  Subjective sensing and emotional states linked to physical sensations and nostalgia.
    importance: 0.085,  Focus on subjective expressions of loyalty, tenderness, and romantic feelings.

Class: anger
    importance: 0.124,  Expressive negative emotion words and variations of "feeling" linked with mild to strong anger or frustration.
    importance: 0.103,  Personal emotional expressions centered on conflict, pain, and contradictory feelings.
    importance: 0.081,  Subjective "I" statements expressing personal feelings.
    importance: 0.077,  Expressions of intense emotion and physical sensation.
    importance: 0.071,  Self-directed emotional reflection and judgment.

Class: fear
    importance: 0.098,  Expressive emotional descriptors with emphasis on vulnerability and agitation.
    importance: 0.09,   Uncertainty and discomfort in social and emotional context
    importance: 0.078,  Vivid emotional projection and self-reference
    importance: 0.078,  Subjective, emotional intensity cues
    importance: 0.076,  Expressive emotional states.

Class: surprise
    importance: 0.119,  Expressive first-person emotional and physical states.
    importance: 0.119,  Expressing personal emotional and perceptual states.
    importance: 0.111,  Expressive self-awareness and emotional states
    importance: 0.089,  Expressive emotional awareness
    importance: 0.067,  Expressions of personal emotion and subjective experience.

# interactive per-class view of the class-wise concepts with their LLM-generated labels
plot_concepts(
    concepts_importances=concepts_importances,
    concepts_labels=concepts_interpretations,
    classes_names=classes_names,
)

Classes