
Generation Demonstration¶
This notebook showcases what can be done with Interpreto for text generation models.
author: Antonin Poché
0. Imports and model loading¶
import os

import datasets
import torch
import transformers

import interpreto
from interpreto import KernelShap, plot_attributions
from interpreto.attributions.metrics import Deletion
from interpreto.concepts import LLMLabels, SemiNMFConcepts
from interpreto.concepts.metrics import MSE, Sparsity
from interpreto.model_wrapping.llm_interface import OpenAILLM
# load dataset examples: the first 1000 training texts of the "split" config
dataset = datasets.load_dataset("dair-ai/emotion", "split")["train"]["text"][:1000]
# load the model and tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
# use the GPU when one is available, but fall back to CPU so the notebook
# also runs on machines without CUDA (the original `.cuda()` would crash there)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2").to(device)
1. Attribution Demonstration¶
1.1 Obtain attribution¶
# instantiate KernelShap, a perturbation-based attribution method
attribution_explainer = KernelShap(model, tokenizer)
# compute attributions: how much each input token contributed to the target
# continuation (one attribution object per input in the batch)
attributions = attribution_explainer(
model_inputs="Alice and Bob enter the bar, ",
targets="then Alice offers a drink to Bob.",
)
# visualize the attributions of the first (and only) input
plot_attributions(attributions[0])
1.2 Evaluate attributions¶
# evaluate faithfulness with the Deletion metric, using the same model
# as the explainer so scores are comparable
metric = Deletion(model, tokenizer)
# compute scores on attributions
auc, detailed_scores = metric.evaluate(attributions)
# for Deletion, removing the most-important tokens first should degrade the
# model's output quickly, so a LOWER area under the curve is better
print(f"Deletion AUC: {round(auc, 3)} (lower is better)")
2. Concepts Demonstration¶
2.1 Split model and get a dataset of activations¶
# split the model at layer 8 so intermediate activations can be captured there
split_model = interpreto.ModelWithSplitPoints(model, tokenizer=tokenizer, split_points=[8])
# compute the activations at token granularity (one activation vector per token)
granularity = split_model.activation_granularities.TOKEN
activations = split_model.get_activations(
inputs=dataset,
activation_granularity=granularity,
)
2.2 Learn concepts as patterns in the activations¶
# create a concept explainer around the split model: Semi-NMF factorizes the
# activations into 20 non-negative concept directions
concept_explainer = SemiNMFConcepts(split_model, nb_concepts=20)
# train the concept model of the explainer on the collected activations
concept_explainer.fit(activations)
2.3 Interpret the concepts¶
# LLM used to produce human-readable labels; the API key is read from the
# environment (OPENAI_API_KEY must be set)
llm_interface = OpenAILLM(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4.1-nano")
# instantiate the interpretation method with the concept explainer:
# for each concept, the k_examples most-activating examples are shown
# to the LLM, which proposes a label
interpretation_method = LLMLabels(
concept_explainer=concept_explainer,
activation_granularity=granularity,
llm_interface=llm_interface,
k_examples=20,
)
# interpret all concepts via LLM-generated labels
# (reusing the precomputed activations avoids a second forward pass)
llm_labels = interpretation_method.interpret(
inputs=dataset,
latent_activations=activations,
concepts_indices="all",
)
print("\n".join(llm_labels.values()))
2.4 Evaluate concepts¶
# reuse the training activations for simplicity; in practice, evaluate on a
# held-out set to avoid an optimistic estimate
test_activations = activations # in practice, these should be different
# reconstruction error of the concept model (lower is better)
mse = MSE(concept_explainer).compute(test_activations)
# sparsity of the concept activations
sparsity = Sparsity(concept_explainer).compute(test_activations)
print(f"MSE: {round(mse, 3)}, Sparsity: {round(sparsity, 3)}")