Probes (Supervised)

Probes are supervised concept methods: they require labeled data (concept annotations) to learn a mapping from activations to concept scores. This contrasts with the unsupervised Concept Spaces (ICA, NMF, SAEs, …) which discover concepts from unlabeled activations alone.

The probe workflow uses the same ModelWithSplitPoints activation extraction as unsupervised methods, but the fit step requires both activations and binary concept labels.

Usage Guide

Classification Model (SplitterForClassification)

The simplest setup for classification models. SplitterForClassification automatically detects the classification head and uses CLS-token activations.

from interpreto import SplitterForClassification
from interpreto.concepts import ProbeExplainer
from interpreto.concepts.probes import LinearRegressionProbe

# 1. Wrap your classification model
model = SplitterForClassification("textattack/bert-base-uncased-imdb")

# 2. Extract CLS-token activations — shape (n, d)
activations, predictions = model.get_activations(texts)

# 3. Instantiate probe and explainer
probe = LinearRegressionProbe()
explainer = ProbeExplainer(model, concept_model=probe)

# 4. Fit on activations + binary concept labels — labels shape (n, c)
explainer.fit(activations, labels)

# 5. Score new inputs
concept_scores = explainer.activations_to_concepts(new_activations)
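The `labels` argument in step 4 is a binary matrix with one row per text and one column per concept. As a minimal sketch in plain Python, such a matrix could be built from per-sample annotations as follows. The concept names and the helper `to_multi_hot` are hypothetical, and the result would still need to be converted to a float tensor (for example with `torch.tensor`) before being passed to `fit`.

```python
# Hypothetical concept vocabulary and per-sample annotations (illustration only).
concept_names = ["positive_tone", "mentions_acting", "mentions_plot"]
annotations = [
    {"positive_tone", "mentions_acting"},
    set(),
    {"mentions_plot"},
]


def to_multi_hot(annotations, concept_names):
    """Return an (n, c) list of lists with 1.0 where the concept is present."""
    return [
        [1.0 if name in sample else 0.0 for name in concept_names]
        for sample in annotations
    ]


label_rows = to_multi_hot(annotations, concept_names)
# label_rows has n = 3 rows (samples) and c = 3 columns (concepts).
```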

Generation Model (SplitterForGeneration or ModelWithSplitPoints)

For generation models, you must choose how to aggregate the sequence of token-level activations into a fixed-size representation. Two common strategies:

Strategy A: Aggregate to one vector per sample

Use activation_granularity=SAMPLE to pool all tokens into one activation vector. This is appropriate when concepts are global properties of the input (e.g., topic, style).

from interpreto import ModelWithSplitPoints
from interpreto.concepts import ProbeExplainer
from interpreto.concepts.probes import CosineCentroidProbe

model = ModelWithSplitPoints(
    "gpt2",
    split_point="transformer.h.6",
    device_map="cuda",
)

# Aggregate all tokens into one vector per sample — shape (n, d)
activations, _ = model.get_activations(
    texts,
    activation_granularity=model.activation_granularities.SAMPLE,
    aggregation_strategy=model.aggregation_strategies.MEAN, # MAX and LAST are also often compared in the literature
)

probe = CosineCentroidProbe()
explainer = ProbeExplainer(model, concept_model=probe)
explainer.fit(activations, labels)
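To make the three aggregation strategies concrete, here is a plain-Python sketch of what MEAN, MAX and LAST pooling do to one sample's token vectors. The helper `pool_tokens` is hypothetical and only illustrates the arithmetic; it is not the library's implementation.

```python
def pool_tokens(token_vectors, strategy="mean"):
    """Pool a list of l token vectors (each of size d) into one vector of size d."""
    if strategy == "mean":
        return [sum(col) / len(col) for col in zip(*token_vectors)]
    if strategy == "max":
        return [max(col) for col in zip(*token_vectors)]
    if strategy == "last":
        return list(token_vectors[-1])
    raise ValueError(f"unknown strategy: {strategy}")


tokens = [[1.0, 4.0], [3.0, 0.0], [2.0, 2.0]]  # l = 3 tokens, d = 2
mean_vec = pool_tokens(tokens, "mean")  # per-feature average over tokens
max_vec = pool_tokens(tokens, "max")    # per-feature maximum over tokens
last_vec = pool_tokens(tokens, "last")  # vector of the final token
```

Mean pooling spreads the signal over the whole input, max pooling keeps the strongest per-feature response, and last-token pooling relies on the causal model having accumulated context in its final position.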

Strategy B: Per-token activations (flattened)

Use activation_granularity=TOKEN to get one activation per token (special tokens removed, then flattened). This is appropriate when concepts are local properties (e.g., named-entity type, part-of-speech) and labels are provided per-token.

This is the behavior of SplitterForGeneration, which can be seen as a special case of ModelWithSplitPoints.

from interpreto.concepts.probes import LinearRegressionProbe

# One vector per token, flattened across all samples: shape (n*l, d)
activations, _ = model.get_activations(
    texts,
    activation_granularity=model.activation_granularities.TOKEN,
)

# labels must also be flattened to match: shape (n*l, c)
probe = LinearRegressionProbe()
explainer = ProbeExplainer(model, concept_model=probe)
explainer.fit(activations, token_labels)
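The flattening requirement on `token_labels` can be illustrated in plain Python. This sketch assumes that activations are flattened sample by sample, in input order, with special tokens already excluded; it is worth checking the row count against the activations returned by `get_activations`. The helper name `flatten_token_labels` is hypothetical.

```python
# Per-sample token labels: sample 0 has 2 tokens, sample 1 has 3 (c = 2 concepts).
per_sample_token_labels = [
    [[1.0, 0.0], [0.0, 0.0]],
    [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]],
]


def flatten_token_labels(per_sample):
    """Concatenate per-sample (l_i, c) label lists into one (sum of l_i, c) list."""
    return [token_row for sample in per_sample for token_row in sample]


flat_labels = flatten_token_labels(per_sample_token_labels)
# flat_labels now has one row per token across all samples.
```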

Choosing a granularity

  • SAMPLE / CLS_TOKEN: one score per input — good for document-level concepts.
  • TOKEN / WORD / SENTENCE: one score per unit — good for local/fine-grained concepts.

Using Normalizations

Normalizations standardize or decorrelate activations before the probe sees them. They are fitted jointly during probe.fit() and applied automatically at encode time.

from interpreto.concepts.probes import (
    LinearRegressionProbe,
    CosineCentroidProbe,
    Standardization,
    Whitening,
)

# Zero-mean, unit-variance per feature
probe = LinearRegressionProbe(normalization=Standardization())

# SVD-based whitening (full rank)
probe = CosineCentroidProbe(normalization=Whitening())

# Low-rank whitening — projects to top-128 principal components
probe = CosineCentroidProbe(normalization=Whitening(rank=128))

Using Bias Calibrators

Bias calibrators set the additive bias of a probe after fitting the weights/centroids. They control the decision threshold: a sample is considered positive for concept j when score_j + bias_j > 0.

from interpreto.concepts.probes import (
    LinearRegressionProbe,
    DotProductCentroidProbe,
    prevalence_bias,
    fpr_bias,
    midpoint_bias,
    bce_bias,
    lda_shared_var_bias,
)

# Set threshold based on class prevalence (logit of prior)
probe = DotProductCentroidProbe(bias_calibrator=prevalence_bias)

# Control false-positive rate at 1%
probe = LinearRegressionProbe(bias_calibrator=fpr_bias)

# Midpoint between positive and negative score means
probe = LinearRegressionProbe(bias_calibrator=midpoint_bias)

# Optimize BCE loss on the intercept only
probe = LinearRegressionProbe(bias_calibrator=bce_bias)

# Bayes-optimal threshold assuming Gaussian class-conditionals
probe = DotProductCentroidProbe(bias_calibrator=lda_shared_var_bias)

Combining Normalization + Bias Calibration

Both options compose naturally:

from interpreto.concepts.probes import (
    CosineCentroidProbe,
    Standardization,
    fpr_bias,
)

probe = CosineCentroidProbe(
    normalization=Standardization(),
    bias_calibrator=fpr_bias,
)
# The pipeline at fit/encode time is:
#   1. Standardize activations (fitted during probe.fit)
#   2. Compute cosine similarity to centroids
#   3. Add calibrated bias (threshold at 1% FPR)
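The three pipeline steps can be traced by hand for a single activation and a single concept. The sketch below uses plain Python and made-up fitted values (mean, std, centroid, bias); it mirrors the order of operations described above rather than the library's actual tensor code.

```python
import math


def standardize(x, mean, std):
    """Step 1: per-feature (x - mean) / std."""
    return [(xi - m) / s for xi, m, s in zip(x, mean, std)]


def cosine(u, v, eps=1e-8):
    """Step 2: cosine similarity with a small floor on the denominator."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / max(norm_u * norm_v, eps)


def concept_score(x, mean, std, centroid, bias):
    """Steps 1 to 3: standardize, compare to the centroid, add the bias."""
    z = standardize(x, mean, std)
    return cosine(z, centroid) + bias


mean, std = [1.0, 2.0], [2.0, 1.0]  # fitted statistics (made-up values)
centroid = [1.0, 0.0]               # concept centroid in standardized space
bias = -0.5                         # calibrated bias (made-up value)

s = concept_score([5.0, 2.0], mean, std, centroid, bias)  # z = [2.0, 0.0]
is_positive = s > 0  # decision rule: score + bias > 0
```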

ProbeExplainer

interpreto.concepts.ProbeExplainer

ProbeExplainer(splitter, concept_model)

Bases: ConceptEncoderExplainer[Probe]

Concept explainer backed by a Probe.

Integrates any pre-instantiated torch probe into the concept explainer pipeline, connecting it to a BaseSplitter for activation extraction.

The probe is provided already instantiated (unfitted or pre-fitted). Calling fit delegates to the probe's own fit method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| splitter | BaseSplitter | Wrapped transformer model. | required |
| concept_model | Probe | An instantiated torch probe. | required |

Example:

from interpreto.concepts import ProbeExplainer
from interpreto.concepts.probes import LinearRegressionProbe

probe = LinearRegressionProbe()
explainer = ProbeExplainer(splitter, probe)
explainer.fit(activations, labels)
concepts = explainer.activations_to_concepts(activations)

Source code in interpreto/concepts/probes/base.py
def __init__(
    self,
    splitter: BaseSplitter,
    concept_model: Probe,
):
    if not isinstance(concept_model, Probe):
        raise TypeError(f"concept_model must be a Probe instance, got {type(concept_model).__name__}.")
    super().__init__(
        splitter=splitter,
        concept_model=concept_model,
    )

fit

Fit the probe on activations and multi-label targets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| activations | LatentActivations | Latent activations (2D tensor). | required |
| labels | Float[Tensor, 'n c'] | Binary multi-label targets of shape (n, c). This must be a matrix; for a single concept, pass a column vector of shape (n, 1). Fitting several concepts at once is supported and recommended. | required |

Source code in interpreto/concepts/probes/base.py
def fit(
    self,
    activations: LatentActivations,
    labels: Float[torch.Tensor, "n c"],
):
    """Fit the probe on activations and multi-label targets.

    Args:
        activations: Latent activations (2D tensor).
        labels: Binary multi-label targets of shape `(n, c)`.
            This must be a matrix; for a single concept, pass a column vector of shape `(n, 1)`.
            Fitting several concepts at once is supported and recommended.
    """
    if len(activations.shape) != 2:
        raise ValueError(f"Expected activations to be a 2D array, (n, d), got shape {activations.shape}")
    if activations.shape[0] != labels.shape[0]:
        raise ValueError(
            "Activations and labels must have the same number of samples, "
            f"got {activations.shape[0]} and {labels.shape[0]}."
        )

    self.concept_model.fit(activations, labels)

activations_to_concepts

activations_to_concepts(activations)

Encode activations into concept scores using the fitted probe.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| activations | LatentActivations | Latent activations of shape (n, d). | required |

Returns:

| Type | Description |
| --- | --- |
| ConceptsActivations | Concept scores of shape (n, c). |

Source code in interpreto/concepts/probes/base.py
@check_fitted
def activations_to_concepts(self, activations: LatentActivations) -> ConceptsActivations:
    """Encode activations into concept scores using the fitted probe.

    Args:
        activations: Latent activations of shape `(n, d)`.

    Returns:
        Concept scores of shape `(n, c)`.
    """
    # Use the _fitted_flag buffer (always present) to infer the probe's device.
    probe_device = self.concept_model._fitted_flag.device  # type: ignore
    if activations.device != probe_device:
        activations = activations.to(probe_device)  # type: ignore
    return self.concept_model.encode(activations)

get_inputs_to_concepts_model

get_inputs_to_concepts_model()

Returns a model that maps raw inputs to concept activations.

The model can be passed to an attribution method to obtain input-to-concept attributions, which are one way to interpret the concepts.

Returns:

| Type | Description |
| --- | --- |
| ModelForInputsToConcepts | A model that maps raw inputs to concept activations. |

Source code in interpreto/concepts/base.py
def get_inputs_to_concepts_model(self) -> ModelForInputsToConcepts:
    """Returns a model that maps raw inputs to concept activations.

    The model can be passed to an attribution method to obtain
    input-to-concept attributions, which are one way to interpret the concepts.

    Returns:
        ModelForInputsToConcepts: A model that maps raw inputs to concept activations.
    """
    return ModelForInputsToConcepts(self)

Probe Models

All probes follow the Probe interface and can be passed as concept_model to ProbeExplainer.

Linear Probes

interpreto.concepts.probes.LinearRegressionProbe

LinearRegressionProbe(l2=0.0, bias_calibrator=None, normalization=None)

Bases: BaseLinearProbe

Code: concepts/probes/linear.py

Multi-output linear regression probe with intercept.

This probe fits concept scores using ordinary least squares or ridge regression, with an optional unpenalized intercept term [1].

Fits the linear model in closed form:

  • l2 == 0: OLS via pseudo-inverse.
  • l2 > 0: Ridge regression (intercept is not penalized).

  1. Hastie, T., Tibshirani, R., Friedman, J., The Elements of Statistical Learning. Springer, 2nd edition, 2009. 

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| l2 | float | L2 regularization strength (0 for OLS). | 0.0 |
| bias_calibrator | BiasCalibrator \| None | If provided, overrides the regression intercept with a calibrated bias. | None |
| normalization | NormalizationBase \| None | Optional input normalization. | None |

interpreto.concepts.probes.LogisticRegressionProbe

LogisticRegressionProbe(lr=0.01, max_iter=20, l2=0.0, init_from_means_diff=True, init_bias_calibrator=prevalence_bias, normalization=None)

Bases: _GDLinearProbe

Code: concepts/probes/linear.py

Multi-label logistic regression probe (BCE loss, Adam optimizer).

Minimizes binary cross-entropy with logits, optionally with L2 weight regularization. Initialized from MeansDiffProbe by default.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| lr | float | Adam learning rate. | 0.01 |
| max_iter | int | Number of optimization steps. | 20 |
| l2 | float | L2 regularization on weight (bias not penalized). | 0.0 |
| init_from_means_diff | bool | If True, initialize weight/bias from MeansDiffProbe. | True |
| init_bias_calibrator | BiasCalibrator \| None | Bias calibrator for MeansDiff initialization. | prevalence_bias |
| normalization | NormalizationBase \| None | Optional input normalization. | None |

interpreto.concepts.probes.LinearSVMProbe

Bases: _GDLinearProbe

Multi-label linear SVM probe (hinge loss, Adam optimizer).

This is the linear model used in CAV [1].

Targets are mapped to {-1, +1} and the loss is the mean of max(0, 1 - y * logits). Optionally with L2 weight regularization. Initialized from MeansDiffProbe by default.


  1. Kim, B. et al., Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV). Proceedings of the 35th International Conference on Machine Learning, 2018. 
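The hinge objective described above is short enough to spell out. This plain-Python sketch computes the loss for one concept, mapping {0, 1} targets to {-1, +1} first; the L2 term and the Adam updates are left out, and the function name `hinge_loss` is illustrative.

```python
def hinge_loss(logits, targets):
    """Mean of max(0, 1 - y * logit) for binary targets given in {0, 1}."""
    signed = [2.0 * t - 1.0 for t in targets]  # map {0, 1} -> {-1, +1}
    losses = [max(0.0, 1.0 - y * z) for y, z in zip(signed, logits)]
    return sum(losses) / len(losses)


logits = [2.0, 0.5, -3.0, 0.25]
targets = [1.0, 1.0, 0.0, 0.0]
loss = hinge_loss(logits, targets)
# Per-sample losses are 0.0, 0.5, 0.0 and 1.25: only samples inside the
# margin (or on the wrong side) contribute.
```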

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| lr | float | Adam learning rate. | 0.01 |
| max_iter | int | Number of optimization steps. | 20 |
| l2 | float | L2 regularization on weight (bias not penalized). | 0.0 |
| init_from_means_diff | bool | If True, initialize weight/bias from MeansDiffProbe. | True |
| init_bias_calibrator | BiasCalibrator \| None | Bias calibrator for MeansDiff initialization. | prevalence_bias |
| normalization | NormalizationBase \| None | Optional input normalization. | None |

interpreto.concepts.probes.MeansDiffProbe

MeansDiffProbe(bias_calibrator=None, normalization=None, eps=1e-08)

Bases: BaseLinearProbe

Code: concepts/probes/linear.py

Means-difference probe (multi-label, multi-output).

For each concept j, the weight vector is the difference between the mean activation of positive and negative samples:

$$w_j = \operatorname{mean}(x \mid y_j = 1) - \operatorname{mean}(x \mid y_j = 0)$$

This is equivalent to Fisher’s Linear Discriminant under a shared identity covariance assumption [1, 2].


  1. Fisher, R. A., The Use of Multiple Measurements in Taxonomic Problems. Annals of Eugenics, 7(2), 1936, pp. 179-188. 

  2. Hastie, T., Tibshirani, R., Friedman, J., The Elements of Statistical Learning. Springer, 2nd edition, 2009. 
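As a worked instance of the means-difference formula, the following plain-Python sketch computes the weight vector for a single concept. It assumes both classes are non-empty (the probe's `eps` floor on sample counts is omitted), and the helper name `means_diff` is illustrative.

```python
def means_diff(xs, ys):
    """w = mean(x | y = 1) - mean(x | y = 0) for one concept."""
    pos = [x for x, y in zip(xs, ys) if y == 1]
    neg = [x for x, y in zip(xs, ys) if y == 0]
    mean_pos = [sum(col) / len(pos) for col in zip(*pos)]
    mean_neg = [sum(col) / len(neg) for col in zip(*neg)]
    return [p - q for p, q in zip(mean_pos, mean_neg)]


xs = [[2.0, 0.0], [4.0, 2.0], [0.0, 1.0], [0.0, 3.0]]  # n = 4, d = 2
ys = [1, 1, 0, 0]
w = means_diff(xs, ys)
# mean_pos = [3.0, 1.0], mean_neg = [0.0, 2.0]
```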

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bias_calibrator | BiasCalibrator \| None | Post-hoc bias calibration. If None, bias is zero. | None |
| normalization | NormalizationBase \| None | Optional input normalization. | None |
| eps | float | Floor for sample count denominators. | 1e-08 |

Centroid Probes

interpreto.concepts.probes.CosineCentroidProbe

CosineCentroidProbe(normalization=None, bias_calibrator=None, eps=1e-08)

Bases: BaseCentroidProbe

Code: concepts/probes/centroid.py

Centroid probe using cosine similarity [1].

score_ij = cosine(x_i, centroid_j) + bias_j

Centroids are L2-normalized after fitting, and inputs are normalized before scoring.


  1. Manning, C. D., Raghavan, P., Schütze, H., Introduction to Information Retrieval. Cambridge University Press, 2008. 

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| normalization | NormalizationBase \| None | Optional input normalization layer. | None |
| bias_calibrator | BiasCalibrator \| None | Optional post-hoc bias calibration function. | None |
| eps | float | Numerical stability floor. | 1e-08 |

interpreto.concepts.probes.DotProductCentroidProbe

DotProductCentroidProbe(normalization=None, bias_calibrator=None, eps=1e-08)

Bases: BaseCentroidProbe

Code: concepts/probes/centroid.py

Centroid probe using dot-product similarity.

score_ij = x_i · centroid_j + bias_j

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| normalization | NormalizationBase \| None | Optional input normalization layer. | None |
| bias_calibrator | BiasCalibrator \| None | Optional post-hoc bias calibration function. | None |
| eps | float | Numerical stability floor. | 1e-08 |

interpreto.concepts.probes.SqL2CentroidProbe

SqL2CentroidProbe(normalization=None, bias_calibrator=None, eps=1e-08)

Bases: BaseCentroidProbe

Code: concepts/probes/centroid.py

Centroid probe using negative squared Euclidean distance.

score_ij = -||x_i - centroid_j||² + bias_j

Computed efficiently via the expansion: dist² = (x·x) + (c·c) - 2(x·c)

Recommended normalization: Standardization.
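The expansion can be checked numerically. This plain-Python sketch compares the direct squared distance with the expanded form for one activation and one centroid; the function names are illustrative. In a batched setting the cross term x·c becomes a single matrix product, which is presumably where the efficiency comes from.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sq_dist_direct(x, c):
    """||x - c||^2 computed coordinate by coordinate."""
    return sum((a - b) ** 2 for a, b in zip(x, c))


def sq_dist_expanded(x, c):
    """Same quantity via (x.x) + (c.c) - 2 (x.c)."""
    return dot(x, x) + dot(c, c) - 2.0 * dot(x, c)


x, c = [1.0, 2.0, 3.0], [0.0, 2.0, 5.0]
direct = sq_dist_direct(x, c)
expanded = sq_dist_expanded(x, c)
neg_sq_dist_score = -expanded  # score before adding the bias
```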

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| normalization | NormalizationBase \| None | Optional input normalization layer. | None |
| bias_calibrator | BiasCalibrator \| None | Optional post-hoc bias calibration function. | None |
| eps | float | Numerical stability floor. | 1e-08 |

interpreto.concepts.probes.SVDDCentroidProbe

SVDDCentroidProbe(lr=0.05, max_iter=20, C=1.0, l2=0.0, normalization=None, eps=1e-08)

Bases: BaseCentroidProbe

Code: concepts/probes/centroid.py

Multi-label SVDD (Support Vector Data Description) probe.

For each concept, this probe fits a hypersphere around positive examples and scores activations by the margin between the learned squared radius and the squared distance to the learned center, following Support Vector Data Description [1].

For each concept j, jointly optimizes a center a_j and squared radius r²_j on positive samples by minimizing:

L_j = r²_j + C * mean_pos( relu(||x - a_j||² - r²_j) ) + 0.5*l2*||a_j||²

Encoding returns a margin score:

score_ij = r²_j - ||x_i - a_j||²

Positive scores indicate the sample is inside the concept sphere.


  1. Tax, D. M. J., Duin, R. P. W., Support Vector Data Description. Machine Learning, 54, 2004, pp. 45-66. 
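To see how the loss and the margin score relate, here is a plain-Python evaluation of both for one concept with a fixed center and squared radius. It only evaluates the formulas above; the Adam optimization of the center and radius is not shown, and the function names are illustrative.

```python
def sq_dist(x, a):
    return sum((xi - ai) ** 2 for xi, ai in zip(x, a))


def svdd_loss(positives, center, r2, C=1.0, l2=0.0):
    """L = r2 + C * mean(relu(||x - a||^2 - r2)) + 0.5 * l2 * ||a||^2."""
    slack = [max(0.0, sq_dist(x, center) - r2) for x in positives]
    reg = 0.5 * l2 * sum(a * a for a in center)
    return r2 + C * sum(slack) / len(slack) + reg


def svdd_score(x, center, r2):
    """Margin score: positive inside the sphere, negative outside."""
    return r2 - sq_dist(x, center)


positives = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]]
center, r2 = [0.0, 0.0], 4.0
loss = svdd_loss(positives, center, r2)          # only the third point is outside
inside = svdd_score([1.0, 0.0], center, r2)
outside = svdd_score([3.0, 0.0], center, r2)
```

A larger C penalizes positives left outside the sphere more heavily, which pushes the fitted radius up; a smaller C tolerates more outliers in exchange for a tighter sphere.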

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| lr | float | Adam learning rate for center/radius optimization. | 0.05 |
| max_iter | int | Number of optimization steps. | 20 |
| C | float | Hinge loss penalty weight on outliers. | 1.0 |
| l2 | float | L2 regularization on centroids. | 0.0 |
| normalization | NormalizationBase \| None | Optional input normalization. | None |
| eps | float | Numerical stability floor. | 1e-08 |

fit

fit(x, y)

Fit SVDD centers and radii via Adam optimization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Float[Tensor, 'n d'] | Raw activations. | required |
| y | Float[Tensor, 'n c'] | Binary multi-label targets. | required |

interpreto.concepts.probes.DiagonalMahalanobisCentroidProbe

DiagonalMahalanobisCentroidProbe(normalization=None, bias_calibrator=None, eps=1e-08, shrinkage=1.0)

Bases: BaseCentroidProbe

Centroid probe using diagonal Mahalanobis distance.

This probe scores activations by a variance-weighted squared distance to each concept centroid. With pooled variance, this corresponds to a diagonal-covariance form of Gaussian discriminant analysis [1].

score_ij = -(x_i - c_j)^T diag(1/var) (x_i - c_j) + bias_j

The variance matrix is controlled by the shrinkage parameter:

  • shrinkage = 0: class-wise diagonal variance (c, d) from positives only.
  • 0 < shrinkage < 1: convex combination of class-wise and pooled variance.
  • shrinkage = 1: pooled diagonal variance (d,) from all samples.

Default normalization is Standardization if none is provided.


  1. Bishop, C. M., Pattern Recognition and Machine Learning. Springer, 2006. 
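The effect of `shrinkage` on the score can be sketched in plain Python. The blend below, (1 - shrinkage) * class-wise + shrinkage * pooled, is an assumed reading of "convex combination" that matches the two endpoints listed above; the library's exact estimator may differ, and both function names are illustrative.

```python
def blend_variance(var_class, var_pooled, shrinkage):
    """Convex combination: shrinkage = 0 -> class-wise, shrinkage = 1 -> pooled."""
    return [
        (1.0 - shrinkage) * vc + shrinkage * vp
        for vc, vp in zip(var_class, var_pooled)
    ]


def diag_mahalanobis_score(x, centroid, var, bias=0.0, eps=1e-8):
    """score = -(x - c)^T diag(1 / var) (x - c) + bias."""
    dist = sum((xi - ci) ** 2 / max(v, eps) for xi, ci, v in zip(x, centroid, var))
    return -dist + bias


var_class, var_pooled = [1.0, 4.0], [3.0, 2.0]  # made-up per-feature variances
var = blend_variance(var_class, var_pooled, shrinkage=0.5)
s = diag_mahalanobis_score([2.0, 3.0], [0.0, 0.0], var)
```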

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| normalization | NormalizationBase \| None | Input normalization (defaults to Standardization). | None |
| bias_calibrator | BiasCalibrator \| None | Post-hoc bias calibration function. | None |
| eps | float | Numerical stability floor. | 1e-08 |
| shrinkage | float | Shrinkage coefficient in [0, 1] for the variance estimate. | 1.0 |

Normalizations

Normalizations can be composed with any probe to standardize or whiten the input activations before probing.

interpreto.concepts.probes.Standardization

Standardization(eps=1e-08)

Bases: NormalizationBase

Code: concepts/probes/normalizations.py

Per-feature zero-mean, unit-variance normalization.

z = (x - mean) / std

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| mean | Tensor | Feature means from training data, shape (d,). |
| std | Tensor | Feature standard deviations (clamped to eps), shape (d,). |

interpreto.concepts.probes.Whitening

Whitening(rank=None, eps=1e-08)

Bases: NormalizationBase

Code: concepts/probes/normalizations.py

SVD-based whitening normalization.

Whitening projects centered activations onto singular-vector directions and rescales them by the inverse singular values, as in PCA whitening [1].

z = (X - mean) @ V_r * (sqrt(n) / S_r)

This produces decorrelated, unit-variance features in the rotated space.


  1. Murphy, K. P., Machine Learning: A Probabilistic Perspective. MIT Press, 2012.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| rank | int \| None | If None (default), full whitening (r = min(n, d)). If int, low-rank whitening keeping the top-r singular components. | None |
| eps | float | Numerical stability floor. | 1e-08 |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| mean | Tensor | Feature means, shape (d,). |
| V | Tensor | Right singular vectors, shape (d, r). |
| inv_s | Tensor | Scaling factors sqrt(n) / s_i, shape (r,). |

Bias Calibrators

Post-hoc functions to set the bias of a fitted probe based on different criteria.

interpreto.concepts.probes.bce_bias

bce_bias(scores, y, max_iter=50, eps=1e-06)

BCE-optimal bias via L-BFGS.

Fits a per-class intercept b to minimize BCEWithLogitsLoss(scores + b, y) with scores held fixed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scores | Float[Tensor, 'n c'] | Raw concept scores (treated as fixed). | required |
| y | Float[Tensor, 'n c'] | Binary multi-label targets. | required |
| max_iter | int | Maximum L-BFGS iterations. | 50 |
| eps | float | Clamping for initial prevalence estimate. | 1e-06 |

Returns:

| Type | Description |
| --- | --- |
| Float[Tensor, c] | Per-concept bias. |

interpreto.concepts.probes.fpr_bias

fpr_bias(scores, y, target_fpr=0.01)

False-positive-rate controlled bias.

Sets the threshold at the (1 - target_fpr) quantile of the negative score distribution, giving approximately target_fpr false positive rate.

threshold_j = quantile_{1 - target_fpr}(scores | y=0)
bias_j = -threshold_j
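A one-concept version of this rule can be written in plain Python. The sketch assumes a linearly interpolated quantile, which may not match the library's choice of interpolation; `quantile` and `fpr_bias_1d` are illustrative helper names.

```python
def quantile(values, q):
    """Linearly interpolated quantile of a non-empty list, with q in [0, 1]."""
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    frac = pos - lo
    return ordered[lo] * (1.0 - frac) + ordered[hi] * frac


def fpr_bias_1d(scores, y, target_fpr=0.01):
    """bias = -quantile_{1 - target_fpr}(scores of the negative samples)."""
    negatives = [s for s, label in zip(scores, y) if label == 0]
    return -quantile(negatives, 1.0 - target_fpr)


scores = [0.0, 1.0, 2.0, 3.0, 4.0, 10.0]
y = [0, 0, 0, 0, 0, 1]
bias = fpr_bias_1d(scores, y, target_fpr=0.25)
# Negatives with score + bias > 0 are the false positives at this threshold.
false_positives = sum(1 for s, label in zip(scores, y) if label == 0 and s + bias > 0)
```

With only five negatives, the realized false positive rate (one in five here) can only approximate the 25% target, which is why the description says "approximately".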

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scores | Float[Tensor, 'n c'] | Raw concept scores. | required |
| y | Float[Tensor, 'n c'] | Binary multi-label targets. | required |
| target_fpr | float | Desired false positive rate (default 1%). | 0.01 |

Returns:

| Type | Description |
| --- | --- |
| Float[Tensor, c] | Per-concept bias. |

interpreto.concepts.probes.prevalence_bias

prevalence_bias(scores, y, eps=1e-06)

Prevalence-based bias: bias_j = logit(mean(y_j)).

Sets the decision threshold at the class prior, which is optimal under a uniform score distribution. Ignores the actual scores.
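For a single concept, the prevalence rule reduces to a logit of the label mean. The plain-Python sketch below assumes that `eps` clamps the prevalence away from 0 and 1 before the logit, as the parameter description suggests; the function name `prevalence_bias_1d` is illustrative.

```python
import math


def prevalence_bias_1d(y, eps=1e-6):
    """bias = logit(p), with p = mean(y) clamped to [eps, 1 - eps]."""
    p = sum(y) / len(y)
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


bias_rare = prevalence_bias_1d([1, 0, 0, 0])      # p = 0.25, negative bias
bias_balanced = prevalence_bias_1d([1, 0, 1, 0])  # p = 0.5, zero bias
```

A rare concept therefore gets a negative bias, so a sample needs a higher raw score to be counted as positive.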

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scores | Float[Tensor, 'n c'] | Raw concept scores (unused, kept for API consistency). | required |
| y | Float[Tensor, 'n c'] | Binary multi-label targets. | required |
| eps | float | Clamping value for numerical stability. | 1e-06 |

Returns:

| Type | Description |
| --- | --- |
| Float[Tensor, c] | Per-concept bias. |

interpreto.concepts.probes.lda_shared_var_bias

lda_shared_var_bias(scores, y, eps=1e-12, var_floor=1e-06)

Closed-form 1-D LDA threshold with shared variance and empirical priors.

Computes the Bayes-optimal threshold assuming Gaussian class-conditional distributions with a shared (pooled) variance:

t = 0.5*(mu0 + mu1) + (var / (mu1 - mu0)) * log(pi0 / pi1)
bias = -t
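The threshold formula can be evaluated by hand for one concept. In this plain-Python sketch the pooled variance is taken as the total within-class squared deviation divided by n, which is an assumption about the estimator; the function name `lda_shared_var_bias_1d` is illustrative.

```python
import math


def lda_shared_var_bias_1d(scores, y, var_floor=1e-6):
    """bias = -t with t = 0.5*(mu0 + mu1) + (var / (mu1 - mu0)) * log(pi0 / pi1)."""
    pos = [s for s, label in zip(scores, y) if label == 1]
    neg = [s for s, label in zip(scores, y) if label == 0]
    mu1, mu0 = sum(pos) / len(pos), sum(neg) / len(neg)
    pi1, pi0 = len(pos) / len(y), len(neg) / len(y)
    # Pooled within-class variance (assumed estimator: squared deviations / n).
    ss = sum((s - mu1) ** 2 for s in pos) + sum((s - mu0) ** 2 for s in neg)
    var = max(ss / len(y), var_floor)
    t = 0.5 * (mu0 + mu1) + (var / (mu1 - mu0)) * math.log(pi0 / pi1)
    return -t


scores = [3.0, 5.0, -1.0, 1.0]
y = [1, 1, 0, 0]
bias = lda_shared_var_bias_1d(scores, y)  # mu1 = 4, mu0 = 0, equal priors
```

With equal priors the log term vanishes and the threshold is the midpoint of the two class means; when positives are rarer than negatives, the log term moves the threshold toward the positive mean.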

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scores | Float[Tensor, 'n c'] | Raw concept scores. | required |
| y | Float[Tensor, 'n c'] | Binary multi-label targets. | required |
| eps | float | Floor for count and denominator stability. | 1e-12 |
| var_floor | float | Minimum variance to avoid division by zero. | 1e-06 |

Returns:

| Type | Description |
| --- | --- |
| Float[Tensor, c] | Per-concept bias. |

interpreto.concepts.probes.midpoint_bias

midpoint_bias(scores, y, eps=1e-12)

Midpoint bias between positive and negative score means.

threshold_j = 0.5 * (mean(score|y=1) + mean(score|y=0))
bias_j = -threshold_j

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| scores | Float[Tensor, 'n c'] | Raw concept scores. | required |
| y | Float[Tensor, 'n c'] | Binary multi-label targets. | required |
| eps | float | Floor for count denominators. | 1e-12 |

Returns:

| Type | Description |
| --- | --- |
| Float[Tensor, c] | Per-concept bias. |

Using other models with sklearn

interpreto.concepts.probes.sklearn.SklearnProbeExplainer

SklearnProbeExplainer(splitter, sklearn_class=SVC, sklearn_kwargs={})

Bases: ConceptEncoderExplainer[SklearnProbe]

Concept explainer using a scikit-learn probe.

Integrates SklearnProbe into the concept explainer pipeline, connecting it to a BaseSplitter for activation extraction.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| splitter | BaseSplitter | Wrapped transformer model. | required |
| sklearn_class | Any | Scikit-learn estimator class (default: SVC). | SVC |
| sklearn_kwargs | dict[str, Any] | Arguments forwarded to the sklearn estimator. | {} |

Source code in interpreto/concepts/probes/sklearn.py
def __init__(
    self,
    splitter: BaseSplitter,
    sklearn_class: Any = SVC,
    sklearn_kwargs: dict[str, Any] = {},
):
    self.concept_model: SklearnProbe
    concept_model = SklearnProbe(sklearn_class, sklearn_kwargs)
    super().__init__(
        splitter=splitter,
        concept_model=concept_model,
    )

is_fitted property

is_fitted

Delegates to the probe's fitted flag.

fit

fit(activations, labels)

Fit the concept model.

Source code in interpreto/concepts/probes/sklearn.py
def fit(
    self,
    activations: LatentActivations,
    labels: torch.Tensor,
):
    """Fit the concept model."""
    if len(activations.shape) != 2:
        raise ValueError(f"Expected activations to be a 2D array, (n, d), got shape {activations.shape}")
    if activations.shape[0] != labels.shape[0]:
        raise ValueError(
            "Expected activations and labels to have the same number of rows, "
            f"got {activations.shape[0]} and {labels.shape[0]}"
        )

    self.concept_model.fit(activations, labels)

interpreto.concepts.probes.sklearn.SklearnProbe

Probe wrapping a scikit-learn classifier with decision_function.

Satisfies ConceptModelProtocol structurally. Currently limited to a single binary concept.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sklearn_class | Any | Scikit-learn estimator class (e.g. SVC). | required |
| sklearn_kwargs | dict[str, Any] | Keyword arguments forwarded to the estimator constructor. | required |

Source code in interpreto/concepts/probes/sklearn.py
def __init__(self, sklearn_class: Any, sklearn_kwargs: dict[str, Any]):
    self.model = sklearn_class(**sklearn_kwargs)
    self.fitted = False

encode

encode(X)

Encode the given activations using the concept model.

Source code in interpreto/concepts/probes/sklearn.py
@assert_fitted
def encode(self, X: Float[torch.Tensor, "n d"]) -> Float[torch.Tensor, "n 1"]:
    """Encode the given activations using the concept model."""
    np_X = np.array(X)
    np_y = self.model.decision_function(np_X)
    return torch.from_numpy(np_y).unsqueeze(1)

fit

fit(X, y)

Fit the concept model.

Source code in interpreto/concepts/probes/sklearn.py
def fit(self, X: Float[torch.Tensor, "n d"], y: Float[torch.Tensor, "n"]):
    """Fit the concept model."""
    np_X, np_y = np.array(X), np.array(y)
    self.model.fit(np_X, np_y)
    self.fitted = True