Probes (Supervised)¶
Probes are supervised concept methods: they require labeled data (concept annotations) to learn a mapping from activations to concept scores. This contrasts with the unsupervised Concept Spaces (ICA, NMF, SAEs, …), which discover concepts from unlabeled activations alone.
The probe workflow uses the same ModelWithSplitPoints activation extraction as unsupervised
methods, but the fit step requires both activations and binary concept labels.
Usage Guide¶
Classification Model (SplitterForClassification)¶
The simplest setup for classification models. SplitterForClassification automatically
detects the classification head and uses CLS-token activations.
from interpreto import SplitterForClassification
from interpreto.concepts import ProbeExplainer
from interpreto.concepts.probes import LinearRegressionProbe
# 1. Wrap your classification model
model = SplitterForClassification("textattack/bert-base-uncased-imdb")
# 2. Extract CLS-token activations — shape (n, d)
activations, predictions = model.get_activations(texts)
# 3. Instantiate probe and explainer
probe = LinearRegressionProbe()
explainer = ProbeExplainer(model, concept_model=probe)
# 4. Fit on activations + binary concept labels — labels shape (n, c)
explainer.fit(activations, labels)
# 5. Score new inputs
concept_scores = explainer.activations_to_concepts(new_activations)
Generation Model (SplitterForGeneration or ModelWithSplitPoints)¶
For generation models, you must choose how to aggregate the sequence of token-level activations into a fixed-size representation. Two common strategies:
Strategy A: Aggregate to one vector per sample¶
Use activation_granularity=SAMPLE to pool all tokens into one activation vector.
This is appropriate when concepts are global properties of the input (e.g., topic, style).
from interpreto import ModelWithSplitPoints
from interpreto.concepts import ProbeExplainer
from interpreto.concepts.probes import CosineCentroidProbe
model = ModelWithSplitPoints(
"gpt2",
split_point="transformer.h.6",
device_map="cuda",
)
# Aggregate all tokens into one vector per sample — shape (n, d)
activations, _ = model.get_activations(
texts,
activation_granularity=model.activation_granularities.SAMPLE,
aggregation_strategy=model.aggregation_strategies.MEAN, # MAX and LAST are also often compared in the literature
)
probe = CosineCentroidProbe()
explainer = ProbeExplainer(model, concept_model=probe)
explainer.fit(activations, labels)
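To make the aggregation choice concrete, here is a minimal pure-Python sketch of what MEAN, MAX and LAST pooling do to the token activations of one sample. The helper `pool_tokens` is hypothetical and not part of interpreto; it only illustrates the arithmetic on toy data.

```python
def pool_tokens(token_vectors, strategy="mean"):
    """Collapse a list of l token vectors (each of length d) into one vector."""
    columns = list(zip(*token_vectors))  # d tuples holding l values each
    if strategy == "mean":
        return [sum(col) / len(col) for col in columns]
    if strategy == "max":
        return [max(col) for col in columns]
    if strategy == "last":
        return list(token_vectors[-1])
    raise ValueError(f"unknown strategy: {strategy}")


# Toy sample with l = 3 tokens and d = 2 features (made-up values).
tokens = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]
mean_vec = pool_tokens(tokens, "mean")
max_vec = pool_tokens(tokens, "max")
last_vec = pool_tokens(tokens, "last")
```

Whichever strategy is used, every sample ends up with a single vector of length d, which is why the labels can stay at shape (n, c).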
Strategy B: Per-token activations (flattened)¶
Use activation_granularity=TOKEN to get one activation per token (special tokens
removed, then flattened). This is appropriate when concepts are local properties
(e.g., named-entity type, part-of-speech) and labels are provided per-token.
This is the behavior of SplitterForGeneration, which can be seen as a special case of ModelWithSplitPoints.
from interpreto.concepts.probes import LinearRegressionProbe

# Reuses `model` from Strategy A.
# One vector per token, flattened across all samples: shape (n*l, d)
activations, _ = model.get_activations(
texts,
activation_granularity=model.activation_granularities.TOKEN,
)
# labels must also be flattened to match: shape (n*l, c)
probe = LinearRegressionProbe()
explainer = ProbeExplainer(model, concept_model=probe)
explainer.fit(activations, token_labels)
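The label flattening mentioned above can be sketched as follows. This assumes per-token labels are stored as one list per sample, in the same token order as the extracted activations and with special tokens already excluded; `flatten_token_labels` is a hypothetical helper, not an interpreto function.

```python
def flatten_token_labels(per_sample_labels):
    """Concatenate per-sample token labels into one (total_tokens, c) list."""
    return [token for sample in per_sample_labels for token in sample]


# Two samples with 2 and 3 tokens, c = 2 concepts per token (made-up data).
per_sample_labels = [
    [[1, 0], [0, 0]],
    [[0, 1], [0, 1], [1, 0]],
]
token_labels_flat = flatten_token_labels(per_sample_labels)
```

The first dimension of the flattened labels must equal the number of rows in the flattened activations, so it is worth asserting that before calling fit.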
Choosing a granularity
- SAMPLE / CLS_TOKEN: one score per input — good for document-level concepts.
- TOKEN / WORD / SENTENCE: one score per unit — good for local/fine-grained concepts.
Using Normalizations¶
Normalizations standardize or decorrelate activations before the probe sees them.
They are fitted jointly during probe.fit() and applied automatically at encode time.
from interpreto.concepts.probes import (
LinearRegressionProbe,
CosineCentroidProbe,
Standardization,
Whitening,
)
# Zero-mean, unit-variance per feature
probe = LinearRegressionProbe(normalization=Standardization())
# SVD-based whitening (full rank)
probe = CosineCentroidProbe(normalization=Whitening())
# Low-rank whitening — projects to top-128 principal components
probe = CosineCentroidProbe(normalization=Whitening(rank=128))
Using Bias Calibrators¶
Bias calibrators set the additive bias of a probe after fitting the weights/centroids.
They control the decision threshold: a sample is considered positive for concept j
when score_j + bias_j > 0.
from interpreto.concepts.probes import (
LinearRegressionProbe,
DotProductCentroidProbe,
prevalence_bias,
fpr_bias,
midpoint_bias,
bce_bias,
lda_shared_var_bias,
)
# Set threshold based on class prevalence (logit of prior)
probe = DotProductCentroidProbe(bias_calibrator=prevalence_bias)
# Control false-positive rate at 1%
probe = LinearRegressionProbe(bias_calibrator=fpr_bias)
# Midpoint between positive and negative score means
probe = LinearRegressionProbe(bias_calibrator=midpoint_bias)
# Optimize BCE loss on the intercept only
probe = LinearRegressionProbe(bias_calibrator=bce_bias)
# Bayes-optimal threshold assuming Gaussian class-conditionals
probe = DotProductCentroidProbe(bias_calibrator=lda_shared_var_bias)
Combining Normalization + Bias Calibration¶
Both options compose naturally:
from interpreto.concepts.probes import (
CosineCentroidProbe,
Standardization,
fpr_bias,
)
probe = CosineCentroidProbe(
normalization=Standardization(),
bias_calibrator=fpr_bias,
)
# The pipeline at fit/encode time is:
# 1. Standardize activations (fitted during probe.fit)
# 2. Compute cosine similarity to centroids
# 3. Add calibrated bias (threshold at 1% FPR)
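The three steps listed above can be traced by hand on a toy example. The sketch below is pure Python and independent of interpreto; the statistics, centroid and bias are made-up values standing in for what fitting would produce, and it assumes the centroid lives in the standardized space.

```python
import math


def standardize(x, mean, std):
    return [(xi - m) / s for xi, m, s in zip(x, mean, std)]


def cosine(u, v, eps=1e-8):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / max(norm_u * norm_v, eps)


def score(x, mean, std, centroid, bias):
    z = standardize(x, mean, std)       # step 1: standardize
    return cosine(z, centroid) + bias   # steps 2 and 3: similarity plus bias


mean, std = [1.0, 1.0], [2.0, 2.0]      # hypothetical fitted statistics
centroid = [1.0, 0.0]                   # hypothetical concept centroid
bias = -0.5                             # hypothetical calibrated bias
s_aligned = score([5.0, 1.0], mean, std, centroid, bias)     # z = [2, 0]
s_orthogonal = score([1.0, 5.0], mean, std, centroid, bias)  # z = [0, 2]
```

With this bias, only inputs whose cosine similarity to the centroid exceeds 0.5 get a positive final score.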
ProbeExplainer¶
interpreto.concepts.ProbeExplainer
ProbeExplainer(splitter, concept_model)
Bases: ConceptEncoderExplainer[Probe]
Concept explainer backed by a Probe (interpreto.concepts.probes.base.Probe).
Integrates any pre-instantiated torch probe into the concept explainer pipeline, connecting it to a BaseSplitter (interpreto.concepts.splitters.base_splitter.BaseSplitter) for activation extraction.
The probe is provided already instantiated (unfitted or pre-fitted).
Calling fit delegates to the probe's own fit method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| splitter | BaseSplitter | Wrapped transformer model. | required |
| concept_model | Probe | An instantiated torch probe. | required |
|  | str \| None | Layer name to extract activations from. | required |
Example:
from interpreto.concepts import ProbeExplainer
from interpreto.concepts.probes import LinearRegressionProbe
probe = LinearRegressionProbe()
explainer = ProbeExplainer(splitter, probe)
explainer.fit(activations, labels)
concepts = explainer.activations_to_concepts(activations)
Source code in interpreto/concepts/probes/base.py
fit
fit(activations, labels)
Fit the probe on activations and multi-label targets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| activations | LatentActivations | Latent activations (2D tensor). | required |
| labels | Float[Tensor, 'n c'] | Binary multi-label targets of shape (n, c). | required |
Source code in interpreto/concepts/probes/base.py
activations_to_concepts
activations_to_concepts(activations)
Encode activations into concept scores using the fitted probe.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| activations | LatentActivations | Latent activations. | required |
Returns:
| Type | Description |
|---|---|
| ConceptsActivations | Concept scores. |
Source code in interpreto/concepts/probes/base.py
get_inputs_to_concepts_model
Returns a model that maps raw inputs to concept activations.
The model can be passed to an attribution method to obtain input-to-concept attributions, which are one way to interpret the concepts.
Returns:
| Name | Type | Description |
|---|---|---|
| ModelForInputsToConcepts | ModelForInputsToConcepts | A model that maps raw inputs to concept activations. |
Source code in interpreto/concepts/base.py
Probe Models¶
All probes follow the Probe interface and can be passed as concept_model
to ProbeExplainer.
Linear Probes¶
interpreto.concepts.probes.LinearRegressionProbe
LinearRegressionProbe(l2=0.0, bias_calibrator=None, normalization=None)
Bases: BaseLinearProbe
Code: concepts/probes/linear.py
Multi-output linear regression probe with intercept.
This probe fits concept scores using ordinary least squares or ridge regression, with an optional unpenalized intercept term [1].
Fits the linear model in closed form:
- l2 == 0: OLS via pseudo-inverse.
- l2 > 0: Ridge regression (intercept is not penalized).

[1] Hastie, T., Tibshirani, R., Friedman, J., The Elements of Statistical Learning. Springer, 2nd edition, 2009.
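For a single feature and a single concept, the closed form can be written without any linear algebra. The sketch below assumes the objective sum((y - w*x - b)^2) + l2 * w^2; the exact scaling of l2 used by the library is not stated here, so treat that as an assumption. `fit_ridge_1d` is a hypothetical helper for illustration only.

```python
def fit_ridge_1d(xs, ys, l2=0.0):
    """Closed-form 1-D least squares with an unpenalized intercept."""
    n = len(xs)
    x_bar = sum(xs) / n
    y_bar = sum(ys) / n
    sxx = sum((x - x_bar) ** 2 for x in xs)
    sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
    w = sxy / (sxx + l2)    # l2 shrinks the slope only
    b = y_bar - w * x_bar   # the intercept is recovered from the means
    return w, b


xs = [0.0, 1.0, 2.0, 3.0]
ys = [0.0, 0.0, 1.0, 1.0]   # binary concept labels used as regression targets
w_ols, b_ols = fit_ridge_1d(xs, ys)
w_ridge, b_ridge = fit_ridge_1d(xs, ys, l2=5.0)
```

Because the intercept is computed from the means after the slope is fixed, increasing l2 flattens the fit without pulling the predictions toward zero.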
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| l2 | float | L2 regularization strength (0 for OLS). | 0.0 |
| bias_calibrator | BiasCalibrator \| None | If provided, overrides the regression intercept with a calibrated bias. | None |
| normalization | NormalizationBase \| None | Optional input normalization. | None |
interpreto.concepts.probes.LogisticRegressionProbe
LogisticRegressionProbe(lr=0.01, max_iter=20, l2=0.0, init_from_means_diff=True, init_bias_calibrator=prevalence_bias, normalization=None)
Bases: _GDLinearProbe
Code: concepts/probes/linear.py
Multi-label logistic regression probe (BCE loss, Adam optimizer).
Minimizes binary cross-entropy with logits, optionally with L2 weight regularization. Initialized from MeansDiffProbe by default.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lr | float | Adam learning rate. | 0.01 |
| max_iter | int | Number of optimization steps. | 20 |
| l2 | float | L2 regularization on weight (bias not penalized). | 0.0 |
| init_from_means_diff | bool | If … | True |
| init_bias_calibrator | BiasCalibrator \| None | Bias calibrator for MeansDiff initialization. | prevalence_bias |
| normalization | NormalizationBase \| None | Optional input normalization. | None |
interpreto.concepts.probes.LinearSVMProbe
LinearSVMProbe(lr=0.01, max_iter=20, l2=0.0, init_from_means_diff=True, init_bias_calibrator=prevalence_bias, normalization=None)
Bases: _GDLinearProbe
Multi-label linear SVM probe (hinge loss, Adam optimizer).
This is the linear model used in CAV [1].
Targets are mapped to {-1, +1} and the loss is the mean of
max(0, 1 - y * logits). Optionally with L2 weight regularization.
Initialized from MeansDiffProbe by default.

[1] Kim, B. et al., Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV). Proceedings of the 35th International Conference on Machine Learning, 2018.
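The hinge loss described above can be checked on a few numbers. This is a plain-Python sketch for a single concept, without the optimizer or the L2 term; `hinge_loss` is a hypothetical helper, not the library's implementation.

```python
def hinge_loss(logits, targets01):
    """Mean hinge loss after mapping {0, 1} targets to {-1, +1}."""
    signed = [2 * t - 1 for t in targets01]
    losses = [max(0.0, 1.0 - y * z) for y, z in zip(signed, logits)]
    return sum(losses) / len(losses)


logits = [2.0, 0.5, -3.0, 0.25]
targets = [1, 1, 0, 0]
loss = hinge_loss(logits, targets)
```

Only samples inside the margin or on the wrong side contribute: here the second sample (correct but with margin below 1) and the fourth (wrong side) carry all of the loss.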
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lr | float | Adam learning rate. | 0.01 |
| max_iter | int | Number of optimization steps. | 20 |
| l2 | float | L2 regularization on weight (bias not penalized). | 0.0 |
| init_from_means_diff | bool | If … | True |
| init_bias_calibrator | BiasCalibrator \| None | Bias calibrator for MeansDiff initialization. | prevalence_bias |
| normalization | NormalizationBase \| None | Optional input normalization. | None |
interpreto.concepts.probes.MeansDiffProbe
MeansDiffProbe(bias_calibrator=None, normalization=None, eps=1e-08)
Bases: BaseLinearProbe
Code: concepts/probes/linear.py
Means-difference probe (multi-label, multi-output).
For each concept j, the weight vector is the difference between the mean activation of positive and negative samples:
$$w_j = \operatorname{mean}(x \mid y_j = 1) - \operatorname{mean}(x \mid y_j = 0)$$
This is equivalent to Fisher's Linear Discriminant under a shared identity covariance assumption [1][2].

[1] Fisher, R. A., The Use of Multiple Measurements in Taxonomic Problems. Annals of Eugenics, 7(2), 1936, pp. 179-188.
[2] Hastie, T., Tibshirani, R., Friedman, J., The Elements of Statistical Learning. Springer, 2nd edition, 2009.
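The weight formula is simple enough to compute by hand. The following pure-Python sketch handles a single concept; `means_diff_weights` is a hypothetical helper, and the `eps` floor mirrors the role of the eps parameter only as an assumption about how empty classes are guarded.

```python
def means_diff_weights(X, y, eps=1e-8):
    """w = mean(x | y = 1) - mean(x | y = 0) for one concept."""
    d = len(X[0])
    pos = [x for x, label in zip(X, y) if label == 1]
    neg = [x for x, label in zip(X, y) if label == 0]
    n_pos = max(len(pos), eps)  # floor the denominators to avoid 0-division
    n_neg = max(len(neg), eps)
    mean_pos = [sum(x[k] for x in pos) / n_pos for k in range(d)]
    mean_neg = [sum(x[k] for x in neg) / n_neg for k in range(d)]
    return [p - q for p, q in zip(mean_pos, mean_neg)]


X = [[2.0, 0.0], [4.0, 0.0], [0.0, 1.0], [0.0, 3.0]]
y = [1, 1, 0, 0]
w = means_diff_weights(X, y)
```

The resulting direction points from the negative centroid to the positive one, so projecting an activation onto it orders samples by how far they lie toward the positive class.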
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| bias_calibrator | BiasCalibrator \| None | Post-hoc bias calibration. If … | None |
| normalization | NormalizationBase \| None | Optional input normalization. | None |
| eps | float | Floor for sample count denominators. | 1e-08 |
Centroid Probes¶
interpreto.concepts.probes.CosineCentroidProbe
CosineCentroidProbe(normalization=None, bias_calibrator=None, eps=1e-08)
Bases: BaseCentroidProbe
Code: concepts/probes/centroid.py
Centroid probe using cosine similarity [1].
score_ij = cosine(x_i, centroid_j) + bias_j
Centroids are L2-normalized after fitting, and inputs are normalized before scoring.

[1] Manning, C. D., Raghavan, P., Schütze, H., Introduction to Information Retrieval. Cambridge University Press, 2008.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| normalization | NormalizationBase \| None | Optional input normalization layer. | None |
| bias_calibrator | BiasCalibrator \| None | Optional post-hoc bias calibration function. | None |
| eps | float | Numerical stability floor. | 1e-08 |
interpreto.concepts.probes.DotProductCentroidProbe
DotProductCentroidProbe(normalization=None, bias_calibrator=None, eps=1e-08)
Bases: BaseCentroidProbe
Code: concepts/probes/centroid.py
Centroid probe using dot-product similarity.
score_ij = x_i · centroid_j + bias_j
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| normalization | NormalizationBase \| None | Optional input normalization layer. | None |
| bias_calibrator | BiasCalibrator \| None | Optional post-hoc bias calibration function. | None |
| eps | float | Numerical stability floor. | 1e-08 |
interpreto.concepts.probes.SqL2CentroidProbe
SqL2CentroidProbe(normalization=None, bias_calibrator=None, eps=1e-08)
Bases: BaseCentroidProbe
Code: concepts/probes/centroid.py
Centroid probe using negative squared Euclidean distance.
score_ij = -||x_i - centroid_j||² + bias_j
Computed efficiently via the expansion:
dist² = (x·x) + (c·c) - 2(x·c)
Recommended normalization: Standardization.
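A quick numeric check of the expansion above, as a pure-Python sketch with made-up vectors. Both helper functions are hypothetical and exist only to compare the direct and expanded forms.

```python
def neg_sq_dist_direct(x, c):
    """-||x - c||^2 computed term by term."""
    return -sum((a - b) ** 2 for a, b in zip(x, c))


def neg_sq_dist_expanded(x, c):
    """-||x - c||^2 via (x.x) + (c.c) - 2(x.c)."""
    xx = sum(a * a for a in x)
    cc = sum(b * b for b in c)
    xc = sum(a * b for a, b in zip(x, c))
    return -(xx + cc - 2.0 * xc)


x, c = [1.0, 2.0, 3.0], [0.0, 2.0, 5.0]
direct = neg_sq_dist_direct(x, c)
expanded = neg_sq_dist_expanded(x, c)
```

The expanded form is attractive in batched code because the x·c term for all sample and centroid pairs is a single matrix product, while x·x and c·c are computed once per row.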
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| normalization | NormalizationBase \| None | Optional input normalization layer. | None |
| bias_calibrator | BiasCalibrator \| None | Optional post-hoc bias calibration function. | None |
| eps | float | Numerical stability floor. | 1e-08 |
interpreto.concepts.probes.SVDDCentroidProbe
Bases: BaseCentroidProbe
Code: concepts/probes/centroid.py
Multi-label SVDD (Support Vector Data Description) probe.
For each concept, this probe fits a hypersphere around positive examples and scores activations by the margin between the learned radius and the squared distance to the learned center, following Support Vector Data Description [1].
For each concept j, jointly optimizes a center a_j and squared radius r²_j on positive samples by minimizing:
L_j = r²_j + C * mean_pos( relu(||x - a_j||² - r²_j) ) + 0.5*l2*||a_j||²
Encoding returns a margin score:
score_ij = r²_j - ||x_i - a_j||²
Positive scores indicate the sample is inside the concept sphere.

[1] Tax, D. M. J., Duin, R. P. W., Support Vector Data Description. Machine Learning, 54, 2004, pp. 45-66.
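To see how the loss and the margin score relate, the sketch below evaluates both for a fixed center and radius on three toy positives. It is plain Python, covers a single concept, and performs no optimization; `svdd_loss_and_scores` is a hypothetical helper.

```python
def svdd_loss_and_scores(X_pos, center, r2, C=1.0, l2=0.0):
    """Evaluate the SVDD objective and margin scores for fixed (center, r2)."""
    sq_dists = [sum((a - b) ** 2 for a, b in zip(x, center)) for x in X_pos]
    slack = [max(0.0, d2 - r2) for d2 in sq_dists]  # relu(||x - a||^2 - r^2)
    reg = 0.5 * l2 * sum(a * a for a in center)
    loss = r2 + C * sum(slack) / len(slack) + reg
    scores = [r2 - d2 for d2 in sq_dists]           # positive means inside
    return loss, scores


X_pos = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]]
loss, scores = svdd_loss_and_scores(X_pos, center=[0.0, 0.0], r2=4.0)
```

The first two points fall inside the sphere and add no slack; the third lies outside, so it gets a negative score and is the only one that pays the C-weighted penalty.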
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|
float
|
Adam learning rate for center/radius optimization. |
0.05
|
|
int
|
Number of optimization steps. |
20
|
|
float
|
Hinge loss penalty weight on outliers. |
1.0
|
|
float
|
L2 regularization on centroids. |
0.0
|
|
NormalizationBase | None
|
Optional input normalization. |
None
|
|
float
|
Numerical stability floor. |
1e-08
|
interpreto.concepts.probes.DiagonalMahalanobisCentroidProbe
DiagonalMahalanobisCentroidProbe(normalization=None, bias_calibrator=None, eps=1e-08, shrinkage=1.0)
Bases: BaseCentroidProbe
Centroid probe using diagonal Mahalanobis distance.
This probe scores activations by a variance-weighted squared distance to each concept centroid. With pooled variance, this corresponds to a diagonal-covariance form of Gaussian discriminant analysis [1].
score_ij = -(x_i - c_j)^T diag(1/var) (x_i - c_j) + bias_j
The variance matrix is controlled by the shrinkage parameter:
- shrinkage = 0: class-wise diagonal variance (c, d) from positives only.
- 0 < shrinkage < 1: convex combination of class-wise and pooled variance.
- shrinkage = 1: pooled diagonal variance (d,) from all samples.
Default normalization is Standardization if none is provided.

[1] Bishop, C. M., Pattern Recognition and Machine Learning. Springer, 2006.
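The role of shrinkage can be illustrated with a small pure-Python sketch for one concept. It assumes the blend is (1 - shrinkage) * class-wise + shrinkage * pooled, which matches the endpoints described above but is otherwise an assumption; both helpers are hypothetical and the variances are made-up numbers.

```python
def blended_variance(var_class, var_pooled, shrinkage):
    """Convex combination of class-wise and pooled per-feature variances."""
    return [(1.0 - shrinkage) * vc + shrinkage * vp
            for vc, vp in zip(var_class, var_pooled)]


def mahalanobis_score(x, centroid, var, bias=0.0, eps=1e-8):
    """-(x - c)^T diag(1/var) (x - c) + bias, with a variance floor."""
    return -sum((a - c) ** 2 / max(v, eps)
                for a, c, v in zip(x, centroid, var)) + bias


var_class = [1.0, 4.0]    # hypothetical variance of the positive samples
var_pooled = [3.0, 4.0]   # hypothetical variance over all samples
var_half = blended_variance(var_class, var_pooled, shrinkage=0.5)
score_value = mahalanobis_score([2.0, 2.0], [0.0, 0.0], var_half)
```

Features with larger variance contribute less to the distance, so a deviation along a noisy feature is penalized less than the same deviation along a stable one.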
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| normalization | NormalizationBase \| None | Input normalization (defaults to Standardization). | None |
| bias_calibrator | BiasCalibrator \| None | Post-hoc bias calibration function. | None |
| eps | float | Numerical stability floor. | 1e-08 |
| shrinkage | float | Shrinkage coefficient in [0, 1] for the variance estimate. | 1.0 |
Normalizations¶
Normalizations can be composed with any probe to standardize or whiten the input activations before probing.
interpreto.concepts.probes.Standardization
Bases: NormalizationBase
Code: concepts/probes/normalizations.py
Per-feature zero-mean, unit-variance normalization.
z = (x - mean) / std
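A minimal pure-Python sketch of the fit and apply steps, assuming population variance (division by n) and an eps clamp on the standard deviation. The library's exact variance convention is not stated here, and both function names are hypothetical.

```python
import math


def fit_standardization(X, eps=1e-8):
    """Return per-feature mean and clamped standard deviation."""
    n, d = len(X), len(X[0])
    mean = [sum(row[k] for row in X) / n for k in range(d)]
    # Population variance (divide by n); the library's convention may differ.
    var = [sum((row[k] - mean[k]) ** 2 for row in X) / n for k in range(d)]
    std = [max(math.sqrt(v), eps) for v in var]
    return mean, std


def apply_standardization(X, mean, std):
    return [[(v - m) / s for v, m, s in zip(row, mean, std)] for row in X]


X = [[1.0, 5.0], [3.0, 5.0]]   # second feature is constant
mean, std = fit_standardization(X)
Z = apply_standardization(X, mean, std)
```

The clamp matters for the constant second feature: its standard deviation is zero, and flooring it at eps keeps the division finite and maps the feature to zero.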
Attributes:
| Name | Type | Description |
|---|---|---|
| mean | Tensor | Feature means from training data. |
| std | Tensor | Feature standard deviations (clamped to eps). |
interpreto.concepts.probes.Whitening
Bases: NormalizationBase
Code: concepts/probes/normalizations.py
SVD-based whitening normalization.
Whitening projects centered activations onto singular-vector directions and rescales them by the inverse singular values, as in PCA whitening [1].
z = (X - mean) @ V_r * (sqrt(n) / S_r)
This produces decorrelated, unit-variance features in the rotated space.

[1] Murphy, K. P., Machine Learning: A Probabilistic Perspective. MIT Press, 2012.

Args:
- rank (int | None): If None (default), full whitening (r = min(n, d)). If int, low-rank whitening keeping the top-r singular components.
- eps (float): Numerical stability floor.
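The whitening formula can be verified numerically. The sketch below uses NumPy rather than torch and is not the library's implementation; `fit_whitening` and `apply_whitening` are hypothetical names, and the data is random.

```python
import numpy as np


def fit_whitening(X, rank=None, eps=1e-8):
    """Return mean, top-r right singular vectors and scaling sqrt(n) / S_r."""
    n = X.shape[0]
    mean = X.mean(axis=0)
    _, S, Vt = np.linalg.svd(X - mean, full_matrices=False)
    r = len(S) if rank is None else rank
    V_r = Vt[:r].T                               # shape (d, r)
    inv_s = np.sqrt(n) / np.maximum(S[:r], eps)  # shape (r,)
    return mean, V_r, inv_s


def apply_whitening(X, mean, V_r, inv_s):
    return (X - mean) @ V_r * inv_s


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5)) @ rng.normal(size=(5, 5))  # correlated features
mean, V_r, inv_s = fit_whitening(X, rank=3)
Z = apply_whitening(X, mean, V_r, inv_s)
cov = Z.T @ Z / X.shape[0]  # close to the 3x3 identity on the training data
```

On the data used for fitting, the whitened features have identity covariance under the 1/n convention, which is what "decorrelated, unit-variance" means in the rotated space.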
Attributes:
| Name | Type | Description |
|---|---|---|
| mean | Tensor | Feature means. |
| V | Tensor | Right singular vectors. |
| inv_s | Tensor | Scaling factors. |
Bias Calibrators¶
Post-hoc functions to set the bias of a fitted probe based on different criteria.
interpreto.concepts.probes.bce_bias
BCE-optimal bias via L-BFGS.
Fits a per-class intercept b to minimize
BCEWithLogitsLoss(scores + b, y) with scores held fixed.
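The intercept-only fit can be sketched for one concept in pure Python. Note that this sketch swaps L-BFGS for Newton's method, which is enough for a one-dimensional convex problem; `fit_bce_bias` is a hypothetical helper and the scores are made up.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fit_bce_bias(scores, y, max_iter=50, tol=1e-10):
    """Minimize mean BCE(sigmoid(score + b), y) over b with scores fixed."""
    b = 0.0
    for _ in range(max_iter):
        probs = [sigmoid(s + b) for s in scores]
        grad = sum(p - t for p, t in zip(probs, y)) / len(y)
        hess = sum(p * (1.0 - p) for p in probs) / len(y)
        if abs(grad) < tol or hess == 0.0:
            break
        b -= grad / hess  # Newton step on the scalar intercept
    return b


scores = [-2.0, -1.0, 0.5, 1.0, 3.0, 4.0]
y = [0, 0, 0, 1, 1, 1]
b = fit_bce_bias(scores, y)
mean_prob = sum(sigmoid(s + b) for s in scores) / len(scores)
```

At the optimum the gradient vanishes, which means the average predicted probability equals the label prevalence; that property is a convenient sanity check for any intercept-only BCE fit.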
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|
Float[Tensor, 'n c']
|
Raw concept scores (treated as fixed). |
required |
|
Float[Tensor, 'n c']
|
Binary multi-label targets. |
required |
|
int
|
Maximum L-BFGS iterations. |
50
|
|
float
|
Clamping for initial prevalence estimate. |
1e-06
|
Returns:
| Type | Description |
|---|---|
| Float[Tensor, c] | Per-concept bias. |
interpreto.concepts.probes.fpr_bias
fpr_bias(scores, y, target_fpr=0.01)
False-positive-rate controlled bias.
Sets the threshold at the (1 - target_fpr) quantile of the negative
score distribution, giving approximately target_fpr false positive rate.
threshold_j = quantile_{1 - target_fpr}(scores | y=0)
bias_j = -threshold_j
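The quantile rule can be sketched for one concept in pure Python. The interpolation scheme (linear between order statistics) is an assumption; the library may compute the quantile differently. `linear_quantile` and `fpr_bias_1d` are hypothetical helpers.

```python
def linear_quantile(values, q):
    """Quantile with linear interpolation between order statistics."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    frac = pos - lo
    return ordered[lo] + frac * (ordered[hi] - ordered[lo])


def fpr_bias_1d(scores, y, target_fpr=0.01):
    """bias = -quantile_{1 - target_fpr}(scores | y = 0)."""
    negatives = [s for s, t in zip(scores, y) if t == 0]
    return -linear_quantile(negatives, 1.0 - target_fpr)


# 100 negatives with scores 0..99 and two positives (made-up data).
scores = [float(i) for i in range(100)] + [150.0, 200.0]
y = [0] * 100 + [1, 1]
bias = fpr_bias_1d(scores, y, target_fpr=0.01)
false_positives = sum(
    1 for s, t in zip(scores, y) if t == 0 and s + bias > 0
)
```

With 100 negatives and a 1% target, exactly one negative stays above the threshold. On small negative sets the achieved rate can only be approximate, since it moves in steps of 1/n.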
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| scores | Float[Tensor, 'n c'] | Raw concept scores. | required |
| y | Float[Tensor, 'n c'] | Binary multi-label targets. | required |
| target_fpr | float | Desired false positive rate (default 1%). | 0.01 |
Returns:
| Type | Description |
|---|---|
| Float[Tensor, c] | Per-concept bias. |
interpreto.concepts.probes.prevalence_bias
Prevalence-based bias: bias_j = logit(mean(y_j)).
Sets the decision threshold at the class prior, which is optimal under a uniform score distribution. Ignores the actual scores.
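For one concept, the formula reduces to a few lines of pure Python. The clamping shown here is an assumption about how the stability parameter is applied, and `prevalence_bias_1d` is a hypothetical helper.

```python
import math


def prevalence_bias_1d(y, eps=1e-6):
    """bias = logit(mean(y)), with the prevalence clamped to [eps, 1 - eps]."""
    p = sum(y) / len(y)
    p = min(max(p, eps), 1.0 - eps)  # keeps the logit finite
    return math.log(p / (1.0 - p))


bias_rare = prevalence_bias_1d([1, 0, 0, 0])          # prevalence 0.25
bias_balanced = prevalence_bias_1d([1, 0, 1, 0])      # prevalence 0.5
bias_all_negative = prevalence_bias_1d([0, 0, 0])     # clamped prevalence
```

A rare concept gets a negative bias, a balanced one gets zero, and the clamp prevents an infinite bias when a concept never (or always) occurs in the training labels.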
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|
Float[Tensor, 'n c']
|
Raw concept scores (unused, kept for API consistency). |
required |
|
Float[Tensor, 'n c']
|
Binary multi-label targets. |
required |
|
float
|
Clamping value for numerical stability. |
1e-06
|
Returns:
| Type | Description |
|---|---|
| Float[Tensor, c] | Per-concept bias. |
interpreto.concepts.probes.lda_shared_var_bias
Closed-form 1-D LDA threshold with shared variance and empirical priors.
Computes the Bayes-optimal threshold assuming Gaussian class-conditional distributions with a shared (pooled) variance:
t = 0.5*(mu0 + mu1) + (var / (mu1 - mu0)) * log(pi0 / pi1)
bias = -t
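The threshold formula can be traced for one concept in pure Python. The pooled variance below uses the population convention (within-class sum of squares divided by n), which is an assumption; `lda_shared_var_bias_1d` is a hypothetical helper and the scores are made up.

```python
import math


def lda_shared_var_bias_1d(scores, y, eps=1e-12, min_var=1e-6):
    """bias = -t, t = 0.5*(mu0 + mu1) + var/(mu1 - mu0) * log(pi0/pi1)."""
    pos = [s for s, t in zip(scores, y) if t == 1]
    neg = [s for s, t in zip(scores, y) if t == 0]
    n1, n0 = max(len(pos), eps), max(len(neg), eps)
    mu1, mu0 = sum(pos) / n1, sum(neg) / n0
    # Pooled within-class variance (population convention, an assumption).
    ss = sum((s - mu1) ** 2 for s in pos) + sum((s - mu0) ** 2 for s in neg)
    var = max(ss / (n1 + n0), min_var)
    pi1, pi0 = n1 / (n1 + n0), n0 / (n1 + n0)
    delta = mu1 - mu0
    if abs(delta) < eps:
        delta = eps  # guard against identical class means
    t = 0.5 * (mu0 + mu1) + (var / delta) * math.log(pi0 / pi1)
    return -t


balanced_bias = lda_shared_var_bias_1d([0.0, 2.0, 4.0, 6.0], [0, 0, 1, 1])
rare_bias = lda_shared_var_bias_1d([0.0, 1.0, 2.0, 6.0], [0, 0, 0, 1])
```

With equal priors the log term vanishes and the threshold is the midpoint of the two class means. When positives are rarer, the threshold shifts toward the positive mean, making the probe more conservative.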
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|
Float[Tensor, 'n c']
|
Raw concept scores. |
required |
|
Float[Tensor, 'n c']
|
Binary multi-label targets. |
required |
|
float
|
Floor for count and denominator stability. |
1e-12
|
|
float
|
Minimum variance to avoid division by zero. |
1e-06
|
Returns:
| Type | Description |
|---|---|
| Float[Tensor, c] | Per-concept bias. |
interpreto.concepts.probes.midpoint_bias
Midpoint bias between positive and negative score means.
threshold_j = 0.5 * (mean(score|y=1) + mean(score|y=0))
bias_j = -threshold_j
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|
Float[Tensor, 'n c']
|
Raw concept scores. |
required |
|
Float[Tensor, 'n c']
|
Binary multi-label targets. |
required |
|
float
|
Floor for count denominators. |
1e-12
|
Returns:
| Type | Description |
|---|---|
| Float[Tensor, c] | Per-concept bias. |
Using other models with sklearn¶
interpreto.concepts.probes.sklearn.SklearnProbeExplainer
SklearnProbeExplainer(splitter, sklearn_class=SVC, sklearn_kwargs={})
Bases: ConceptEncoderExplainer[SklearnProbe]
Concept explainer using a scikit-learn probe.
Integrates SklearnProbe into the concept explainer pipeline, connecting it to a BaseSplitter (interpreto.concepts.splitters.base_splitter.BaseSplitter) for activation extraction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| splitter | BaseSplitter | Wrapped transformer model. | required |
| sklearn_class | Any | Scikit-learn estimator class (default: SVC). | SVC |
| sklearn_kwargs | dict[str, Any] | Arguments forwarded to the sklearn estimator. | {} |
Source code in interpreto/concepts/probes/sklearn.py
fit
Fit the concept model.
Source code in interpreto/concepts/probes/sklearn.py
interpreto.concepts.probes.sklearn.SklearnProbe
SklearnProbe(sklearn_class, sklearn_kwargs)
Probe wrapping a scikit-learn classifier with decision_function.
Satisfies ConceptModelProtocol (interpreto.typing.ConceptModelProtocol) structurally. Currently limited to a single binary concept.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sklearn_class | Any | Scikit-learn estimator class. | required |
| sklearn_kwargs | dict[str, Any] | Keyword arguments forwarded to the estimator constructor. | required |
Source code in interpreto/concepts/probes/sklearn.py
encode
Encode the given activations using the concept model.