
Classification Demonstration¶
This notebook showcases what can be done with interpreto for classification.
author: Antonin Poché
0. Imports and model loading¶
import os
import datasets
import torch
import transformers
import interpreto
from interpreto import Lime, plot_attributions, plot_concepts
from interpreto.attributions.metrics import Insertion
from interpreto.concepts import LLMLabels, SemiNMFConcepts
from interpreto.concepts.interpretations import TopKInputs
from interpreto.concepts.metrics import FID, SparsityRatio
from interpreto.model_wrapping.llm_interface import OpenAILLM
# load dataset examples
dataset = datasets.load_dataset("dair-ai/emotion", "split")["train"]["text"][:1000]
classes_names = ["sadness", "joy", "love", "anger", "fear", "surprise"]
# load the model and tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained("nateraw/bert-base-uncased-emotion", use_fast=True)
model = transformers.AutoModelForSequenceClassification.from_pretrained("nateraw/bert-base-uncased-emotion").cuda()
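As a quick sanity check (a minimal sketch using only transformers and torch; the example sentence is ours), verify that the model predicts a sensible emotion:
# run one example through the model and decode the predicted class
inputs = tokenizer("i am so happy today", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
print(classes_names[logits.argmax(dim=-1).item()])  # expected: "joy"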
1. Attribution Demonstration¶
1.1 Obtain the attributions¶
# instantiate Lime
attribution_explainer = Lime(model, tokenizer)
# compute attributions
attributions = attribution_explainer(
    model_inputs="Love and hate are two sides of the same coin.",
    targets=torch.tensor([[0, 1, 2, 3, 4, 5]]),
)
# visualize attributions
plot_attributions(attributions[0], classes_names=classes_names)
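Often only the model's own decision is of interest; the sketch below reuses the explainer with the predicted class as the single target (hedged: we assume a (1, 1) target tensor is accepted, by analogy with the (1, 6) tensor above):
# attribute only the predicted class instead of all six classes
text = "Love and hate are two sides of the same coin."
with torch.no_grad():
    pred_class = model(**tokenizer(text, return_tensors="pt").to(model.device)).logits.argmax(dim=-1)
predicted_attribution = attribution_explainer(model_inputs=text, targets=pred_class.unsqueeze(0))
plot_attributions(predicted_attribution[0], classes_names=classes_names)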
1.2 Evaluate the attributions¶
# use the same model as for the explainer
metric = Insertion(model, tokenizer)
# compute scores on attributions
auc, detailed_scores = metric.evaluate(attributions)
print(f"Insertion AUC: {round(auc, 3)} (lower is better)")
2. Concepts Demonstration¶
2.1 Split model and get a dataset of activations¶
# split the model
split_model = interpreto.ModelWithSplitPoints(model, tokenizer=tokenizer, split_points=[11])
# compute the [CLS] token activations
granularity = split_model.activation_granularities.CLS_TOKEN
activations = split_model.get_activations(
    inputs=dataset,
    activation_granularity=granularity,
    include_predicted_classes=True,
)
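The returned activations behave like a dictionary of tensors (as used in section 2.7 below, with a "predictions" entry since include_predicted_classes=True); a quick shape check, assuming that structure:
# inspect the activation tensors returned for the [CLS] token
for key, value in activations.items():
    print(key, tuple(value.shape))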
2.2 Learn concepts as patterns in the activations¶
# create an explainer around the split model
concept_explainer = SemiNMFConcepts(split_model, nb_concepts=20)
# train the concept model of the explainer
concept_explainer.fit(activations)
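Conceptually, Semi-NMF factorizes the activation matrix A into non-negative concept coefficients U and an unconstrained concept dictionary W so that A ≈ U W; the sketch below only illustrates that objective on random tensors and is not interpreto's optimizer:
# illustrative Semi-NMF reconstruction objective: A ≈ U @ W with U >= 0
A = torch.randn(100, 768)  # stand-in (n_samples, hidden_dim) activation matrix
U = torch.rand(100, 20, requires_grad=True)  # non-negative concept coefficients
W = torch.randn(20, 768, requires_grad=True)  # unconstrained concept directions
loss = ((A - U.clamp(min=0) @ W) ** 2).mean()  # the clamp keeps the coefficients non-negative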
2.3 Interpret the concepts¶
# instantiate the interpretation method with the concept explainer
interpretation_method = TopKInputs(
    concept_explainer=concept_explainer,
    k=5,
    activation_granularity=granularity,
    use_unique_words=True,
    unique_words_kwargs={"count_min_threshold": 20, "lemmatize": True},
)
# interpret the concepts via top-k words
topk_words = interpretation_method.interpret(
    inputs=dataset,
    concepts_indices="all",
)
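topk_words maps each concept index to its top-k words, as the dict accesses below assume; a quick peek:
# print the top words of the first three concepts
for concept_id in sorted(topk_words)[:3]:
    print(concept_id, list(topk_words[concept_id].keys()))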
2.4 Estimate concept importance for each class¶
# estimate the importance of concepts for each class using the gradient
gradients = concept_explainer.concept_output_gradient(
    inputs=dataset,
    activation_granularity=granularity,
    batch_size=32,
)
# stack the per-sample gradients and average over samples
mean_gradients = torch.stack(gradients).abs().squeeze().mean(0)  # (num_classes, num_concepts)
# for each class, sort concepts by decreasing importance
order = torch.argsort(mean_gradients, descending=True)
2.5 Visualize the most important concepts for each class¶
# visualize the top 5 concepts for each class
for target in range(order.shape[0]):
    print(f"\nClass: {classes_names[target]}:")
    for i in range(5):
        concept_id = order[target, i].item()
        importance = mean_gradients[target, concept_id].item()
        words = list(topk_words.get(concept_id, {}).keys())  # default to an empty dict for uninterpreted concepts
        print(f"\tconcept id: {concept_id},\timportance: {round(importance, 3)},\ttopk words: {words}")
plot_concepts(
    classes_names=classes_names,
    concepts_importances=mean_gradients,
    concepts_labels={k: list(v.keys()) for k, v in topk_words.items()},
)
> Tip: The visualization above is interactive: click on the class names to explore each class's concepts.
2.6 Evaluate concepts¶
test_activations = activations  # in practice, use held-out activations (see the sketch below)
fid = FID(concept_explainer).compute(test_activations)
ratio = SparsityRatio(concept_explainer).compute(test_activations)
print(f"FID: {round(fid, 3)}, Sparsity ratio: {round(ratio, 3)}")
2.7 Class-wise concepts and LLM labels¶
# set the LLM interface used to generate labels based on the constructed prompts
llm_interface = OpenAILLM(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4.1-nano")
concepts_interpretations = {}
concepts_importances = {}
# iterate over classes
for target, class_name in enumerate(classes_names):
    # ----------------------------------------------------------------------------------------------
    # 1. construct the dataset of activations (extract the ones related to the class)
    indices = (activations["predictions"] == target).nonzero(as_tuple=True)[0]
    class_wise_inputs = [dataset[i] for i in indices.tolist()]
    class_wise_activations = {k: v[indices] for k, v in activations.items()}
    # ----------------------------------------------------------------------------------------------
    # 2. train the concept model
    concept_explainer = SemiNMFConcepts(split_model, nb_concepts=20, device="cuda")
    concept_explainer.fit(class_wise_activations)
    # ----------------------------------------------------------------------------------------------
    # 3. compute concept importance (before interpretation, so only the most important concepts need to be interpreted)
    gradients = concept_explainer.concept_output_gradient(
        inputs=class_wise_inputs,
        targets=[target],
        activation_granularity=granularity,
        concepts_x_gradients=True,
        batch_size=32,
    )
    # stack the per-sample gradients and average over samples
    concepts_importances[target] = torch.stack(gradients, dim=0).squeeze().abs().mean(dim=0)  # (num_concepts,)
    # sort concepts by decreasing importance for this class
    important_concept_indices = torch.argsort(concepts_importances[target], descending=True).tolist()
    # ----------------------------------------------------------------------------------------------
    # 4. interpret the most important concepts
    llm_labels_method = LLMLabels(
        concept_explainer=concept_explainer,
        activation_granularity=granularity,
        llm_interface=llm_interface,
        k_examples=20,
    )
    concepts_interpretations[target] = llm_labels_method.interpret(
        inputs=class_wise_inputs,
        concepts_indices=important_concept_indices[:5],  # only the top 5 concepts for this class
    )
    print(f"\nClass: {class_name}")
    for concept_id in important_concept_indices[:5]:
        label = concepts_interpretations[target].get(concept_id, None)
        importance = concepts_importances[target][concept_id].item()
        if label is not None:
            print(f"\timportance: {round(importance, 3)},\t{label}")
plot_concepts(
    classes_names=classes_names,
    concepts_importances=concepts_importances,
    concepts_labels=concepts_interpretations,
)
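To reuse the labels elsewhere, the per-class results can be serialized; a plain-Python sketch, assuming the LLM labels are plain strings ("concept_labels.json" is an arbitrary output path):
import json

# persist the per-class labels and importances (tensors converted to plain lists)
summary = {
    classes_names[target]: {
        "labels": concepts_interpretations[target],
        "importances": concepts_importances[target].tolist(),
    }
    for target in concepts_interpretations
}
with open("concept_labels.json", "w") as f:
    json.dump(summary, f, indent=2)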