
SquareGrad

Bases: MultitaskExplainerMixin, AttributionExplainer

SquareGrad is a gradient-based attribution method that averages the squares of input gradients computed under random perturbations. Unlike methods that average the raw gradients (e.g., SmoothGrad), SquareGrad squares each gradient before averaging, so gradients of opposite sign do not cancel out.
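In symbols, the difference from SmoothGrad can be written as follows. This is a sketch assuming a scalar model output $f$ (e.g., the score of the explained class), input embeddings $x$, $N$ = n_perturbations, $\sigma$ = noise_std, and $\odot$ for the elementwise product:

```latex
\varepsilon_i \sim \mathcal{N}(0, \sigma^2 I), \qquad
g_i = \nabla_x f(x + \varepsilon_i), \qquad i = 1, \dots, N

\mathrm{SmoothGrad}(x) = \frac{1}{N} \sum_{i=1}^{N} g_i,
\qquad
\mathrm{SquareGrad}(x) = \frac{1}{N} \sum_{i=1}^{N} g_i \odot g_i
```

Subtracting $\mathrm{SmoothGrad}(x) \odot \mathrm{SmoothGrad}(x)$ from $\mathrm{SquareGrad}(x)$ would give the elementwise variance of the gradients (with a 1/N normalization), which Hooker et al. (2019) evaluate separately as VarGrad.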

Procedure:

  • Generate multiple perturbed versions of the input by adding Gaussian noise to the input embeddings.
  • For each noisy input, compute the gradient of the output with respect to the embeddings.
  • Average the square of the gradients across all samples.
  • Aggregate the result for each token (e.g., by taking the norm over the embedding dimension, optionally after multiplying by the input embeddings) to obtain the final attribution scores.
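The four steps can be sketched from scratch in PyTorch. This is a minimal illustration under stated assumptions, not interpreto's implementation: square_grad and forward_fn are hypothetical names, the model is replaced by any differentiable function mapping a batch of embeddings to one score per sample, and the per-token reduction in step 4 (optional multiplication by the input, then an L2 norm over the embedding dimension) is an assumed choice.

```python
import torch


def square_grad(
    forward_fn,  # hypothetical: embeddings (batch, seq, dim) -> scores (batch,)
    embeddings: torch.Tensor,  # (seq, dim) embeddings of a single input
    n_perturbations: int = 10,
    noise_std: float = 0.1,
    input_x_gradient: bool = True,
) -> torch.Tensor:
    """Return one attribution score per token, shape (seq,)."""
    # Step 1: build n_perturbations noisy copies of the embeddings.
    noise = torch.randn(n_perturbations, *embeddings.shape) * noise_std
    noisy = (embeddings.unsqueeze(0) + noise).detach().requires_grad_(True)

    # Step 2: gradient of the score w.r.t. each noisy copy. Summing over the
    # batch is safe because each score depends only on its own copy.
    scores = forward_fn(noisy)
    (grads,) = torch.autograd.grad(scores.sum(), noisy)

    # Step 3: average the squared gradients over the perturbations.
    mean_sq = grads.pow(2).mean(dim=0)  # (seq, dim)

    # Step 4 (assumed reduction): optionally multiply by the input,
    # then collapse the embedding dimension with an L2 norm.
    if input_x_gradient:
        mean_sq = mean_sq * embeddings
    return mean_sq.norm(dim=-1)  # (seq,)


# Toy usage: 5 tokens with hidden size 8 and a small nonlinear "model".
torch.manual_seed(0)
emb = torch.randn(5, 8)
w = torch.randn(8)


def toy_forward(e: torch.Tensor) -> torch.Tensor:
    return torch.tanh(e @ w).sum(dim=-1)


attributions = square_grad(toy_forward, emb, n_perturbations=50, noise_std=0.01)
print(attributions.shape)  # torch.Size([5])
```

Setting noise_std to 0 reduces the sketch to the squared plain gradient, which is a convenient sanity check.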

Reference: Hooker et al. (2019). A Benchmark for Interpretability Methods in Deep Neural Networks.

Examples:

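>>> from transformers import AutoModelForSequenceClassification, AutoTokenizer
>>> from interpreto import SquareGrad
>>> name = "distilbert-base-uncased-finetuned-sst-2-english"  # any Hugging Face classifier
>>> model = AutoModelForSequenceClassification.from_pretrained(name)
>>> tokenizer = AutoTokenizer.from_pretrained(name)
>>> method = SquareGrad(model, tokenizer, batch_size=4,
...                     n_perturbations=50, noise_std=0.01)
>>> explanations = method.explain("A sharp, funny and moving film.")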

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | PreTrainedModel | Model to explain. | required |
| tokenizer | PreTrainedTokenizer | Hugging Face tokenizer associated with the model. | required |
| batch_size | int | Batch size for the attribution method. | 4 |
| granularity | Granularity | Level of granularity of the explanation: ALL_TOKENS, TOKEN, WORD, or SENTENCE. Import it with `from interpreto import Granularity`, then pass e.g. Granularity.WORD. | Granularity.WORD |
| granularity_aggregation_strategy | GranularityAggregationStrategy | How token-level attributions are aggregated into scores at the chosen granularity: MEAN, MAX, MIN, SUM, or SIGNED_MAX. Ignored when granularity is ALL_TOKENS or TOKEN. | MEAN |
| device | device | Device on which the attribution method is run. | None |
| inference_mode | Callable[[Tensor], Tensor] | Transformation applied to the model outputs: LOGITS, SOFTMAX, or LOG_SOFTMAX. Choose it from InferenceModes. | LOGITS |
| input_x_gradient | bool | If True, multiplies the input embeddings by their gradients before aggregation. | True |
| n_perturbations | int | Number of noisy copies of the input to generate. | 10 |
| noise_std | float | Standard deviation of the Gaussian noise added to the input embeddings. | 0.1 |
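The options of granularity_aggregation_strategy can be illustrated on plain Python lists. This is a sketch only: aggregate_words is a hypothetical helper, word_ids follows the convention of a fast tokenizer's word_ids() (with special tokens, which map to None, assumed already removed), and SIGNED_MAX is assumed here to return the score with the largest absolute value, sign included.

```python
from statistics import mean


def signed_max(scores: list[float]) -> float:
    # Assumed semantics: the entry with the largest magnitude, sign preserved.
    return max(scores, key=abs)


STRATEGIES = {
    "MEAN": mean,
    "MAX": max,
    "MIN": min,
    "SUM": sum,
    "SIGNED_MAX": signed_max,
}


def aggregate_words(token_scores: list[float], word_ids: list[int], strategy: str = "MEAN") -> list[float]:
    """Merge token-level scores into one score per word.

    word_ids[i] is the index of the word that token i belongs to.
    """
    reduce = STRATEGIES[strategy]
    groups: dict[int, list[float]] = {}
    for score, word in zip(token_scores, word_ids):
        groups.setdefault(word, []).append(score)
    # Return scores in word order.
    return [reduce(groups[w]) for w in sorted(groups)]


# Hypothetical split: one word tokenized into three sub-words, then a one-token word.
token_scores = [0.2, -0.9, 0.4, 0.1]
word_ids = [0, 0, 0, 1]
print(aggregate_words(token_scores, word_ids, "MEAN"))
print(aggregate_words(token_scores, word_ids, "SIGNED_MAX"))  # [-0.9, 0.1]
```

MAX and SIGNED_MAX differ only when a word's most influential token has a negative score, as with the first word above.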