Motivation

Graph Neural Networks (GNNs) are powerful machine learning models for graph-structured data. These models have seen success across several domains, from polypharmacy to recommender systems. However, GNNs lack rigorous uncertainty estimates, limiting their deployment in high-stakes settings.

Huang et al.'s Uncertainty Quantification over Conformalized GNNs demonstrates, both theoretically and empirically, the validity of Conformal Prediction in the context of GNNs. We implement conformalized GNNs in the experimental version of the PyG library as a general-purpose algorithm for classification and regression problems, addressing the need for uncertainty estimates.

Conformal Prediction

Conformal prediction is a modern framework designed to quantify uncertainty by constructing prediction sets that contain the true outcome with a specified probability.

For concreteness, let \( (X_{i}, Y_{i}) \sim P \) be a feature and response pair from a distribution \( P \) over the Cartesian product \( \mathcal{X} \times \mathcal{Y} \). Let \( \alpha \in (0, 1) \) denote the miscoverage rate. Our task is to find a prediction set \( \hat{C}_{n} \) such that a new feature and response pair lands in the set with probability at least \( 1 - \alpha \).

\[ \mathbb{P} \left ( Y_{n + 1} \in \hat{C}_{n}(X_{n + 1}) \right ) \ge 1 - \alpha \]

Remarkably, this is achievable under mild conditions! We build intuition in the discussion that follows.

Toy Example

Consider an ordered list of integers spanning \( 1 \) through \( n \). Now pick a new number and insert it into the list. To produce a set of numbers that contains the new number with a specified coverage, we keep the \( \lceil(n + 1)(1 - \alpha) \rceil \) smallest values.

Figure: Prediction Set in Discrete Setting

At this point, you might think we cheated. In particular, we inserted the new \( (n + 1) \)-th point into the list before constructing our set! The insight is that we didn't need to know the new point so long as it was equally likely to be placed at any rank. This condition, fancifully called exchangeability, is the only requirement for Conformal Prediction. In general, we define the quantile below.

\[ \hat{q} = \text{the } \lceil (n + 1)(1 - \alpha) \rceil\text{-th smallest value} \]

With this quantile on hand, we reason that the prediction set meets coverage with finite sample correction.

\[ \mathbb{P} \left ( Y_{n + 1} \le \hat{q} \right ) \in \left [ 1 - \alpha, 1 - \alpha + \frac{1}{n + 1} \right ] \]

While trivial, this example showcases the power of rank (order) statistics. If we can impose an order on our observations, we can construct prediction sets with any desired coverage.
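As a sanity check, the following minimal simulation (a sketch of ours, with uniform draws standing in for the exchangeable observations) verifies that the rank-based set covers a fresh draw at roughly the nominal rate.

    import numpy as np

    # Simulate the toy example: n exchangeable draws form the "list", the
    # ceil((n + 1)(1 - alpha))-th smallest is the cutoff, and we check how
    # often a fresh draw lands at or below it.
    rng = np.random.default_rng(0)
    n, alpha, trials = 100, 0.1, 10_000

    covered = 0
    for _ in range(trials):
        calibration = rng.uniform(size=n)
        new_point = rng.uniform()
        cutoff = np.sort(calibration)[int(np.ceil((n + 1) * (1 - alpha))) - 1]
        covered += new_point <= cutoff

    print(covered / trials)  # close to (and slightly above) 1 - alpha = 0.9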

We extend these ideas to regression. Consider a dataset \( D \) partitioned into a disjoint training set \( D_{1} \) and calibration set \( D_{2} \). Let \( \hat{f}_{n_{1}} \) be a point predictor trained on \( D_{1} \). We compute the residuals on the calibration set.

\[ R_{i} = \lvert \hat{f}_{n_{1}}(x_{i}) - y_{i} \rvert \hspace{0.5em} \text{for} \hspace{0.5em} i \in D_{2} \]

The residuals can be ranked just like the toy example. As before, we compute \( \hat{q}_{n_{2}} \) as the \( \lceil (n_{2} + 1)(1 - \alpha) \rceil \)-th smallest among the \( R_{i} \). The prediction set \( \hat{f}_{n_{1}}(x) \pm \hat{q}_{n_{2}} \) then has guaranteed coverage on new points, even though the model was trained on \( D_{1} \) and calibrated on \( D_{2} \).
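A minimal numpy sketch of this split-conformal recipe, assuming f_hat is an already-trained point predictor and (x_cal, y_cal) is the calibration split \( D_{2} \) (the names are ours):

    import numpy as np

    def conformal_interval(f_hat, x_cal, y_cal, x_new, alpha=0.1):
        # Residual scores on the calibration split D2
        residuals = np.abs(f_hat(x_cal) - y_cal)
        n = len(residuals)
        # Finite-sample corrected quantile of the residuals
        q_hat = np.quantile(residuals, np.ceil((n + 1) * (1 - alpha)) / n, method="higher")
        # Symmetric band around the point prediction
        prediction = f_hat(x_new)
        return prediction - q_hat, prediction + q_hat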

To flesh this out, we construct a simple example. Consider a cubic spline \( f(x) \). For a collection of input-output pairs we add Gaussian noise drawn from \( \mathcal{N}(\mu, \sigma^{2}) \) with \( \mu = 0 \) and \( \sigma = 0.3 \). Running the algorithm above yields the following plot.

Figure: Conformal Prediction on Cubic Spline

Since our predictor is perfect (by construction), we expect the quantiles to capture the Gaussian noise. In particular, to attain coverage \( 1 - \alpha \) we would expect \( \Phi(\hat{q}/\sigma) - \Phi(-\hat{q}/\sigma) \approx 1 - \alpha \), where \( \Phi \) is the standard normal CDF. We tabulate the empirical and expected quantiles below.

Miscoverage \( \alpha \) Empirical \( \hat{q} \) Expected \( \hat{q} \)
\( 0.10 \) \( 0.461 \) \( 0.493 \)
\( 0.20 \) \( 0.366 \) \( 0.384 \)
\( 0.30 \) \( 0.291 \) \( 0.311 \)
\( 0.40 \) \( 0.222 \) \( 0.252 \)
\( 0.50 \) \( 0.163 \) \( 0.202 \)
Comparison of Quantiles

The table confirms our expectation. It is worth pointing out that as the miscoverage decreases, the prediction set grows larger to maintain coverage.

General Approach

A popular tutorial on Conformal Prediction summarizes the steps; a generic sketch follows the list.

  1. Identify an uncertainty heuristic using the pre-trained model.
  2. Define the score function \( s(x, y) \in \mathbb{R} \).
  3. Compute \( \hat{q} \) as the \( \frac{\lceil(n + 1)(1 - \alpha) \rceil}{n} \) quantile of the calibration scores.
  4. Form prediction sets \( \mathcal{C}(x) = \{y : s(x, y) \le \hat{q} \} \).
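The recipe translates almost line for line into code. Below is a generic sketch of ours, where score_fn wraps the heuristic from steps 1 and 2 and candidates enumerates the possible labels (both names are hypothetical):

    import numpy as np

    def conformal_set(score_fn, x_cal, y_cal, x_new, candidates, alpha=0.1):
        # Step 2: score the calibration pairs
        scores = np.array([score_fn(x, y) for x, y in zip(x_cal, y_cal)])
        # Step 3: finite-sample corrected quantile of the scores
        n = len(scores)
        q_hat = np.quantile(scores, np.ceil((n + 1) * (1 - alpha)) / n, method="higher")
        # Step 4: keep every candidate label whose score clears the threshold
        return [y for y in candidates if score_fn(x_new, y) <= q_hat]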

While any score function guarantees coverage, a well-designed one yields shorter prediction sets.

\[ \text{len} = \int \int_{C} \mathrm{d}\mu(y) \, \mathrm{d}P_{X} \hspace{0.5em} \text{and} \hspace{0.5em} \text{cov} = \int \int_{C} \mathrm{d}P_{Y \vert X} \, \mathrm{d}P_{X} \]

If the prediction set concentrates on responses with high conditional probability \( P_{Y \vert X} \), then we can meet coverage without inflating the length. The argument is due to Tibshirani.

In our assessment, the key challenge of conformal prediction is computing a reasonable score function. The next sections overview two approaches.

Prediction Strategy

A prediction strategy leverages the pre-trained heuristic to score uncertainty. The strategy depends on the task at hand.

Adaptive Prediction Sets (APS) is a prediction strategy for classification problems. We define the following score function, where \( \pi \) is the permutation that ranks the classes from most to least likely.

\[ s(x, y) = \sum_{j = 1}^{k} \hat{f}(x)_{\pi_j(x)} \quad \text{where} \quad y = \pi_k(x) \]

Intuitively, we want to include top-ranked classes until the cumulative sum of their probabilities meets the desired coverage. By greedily doing this from most likely to least likely, we keep the prediction set small. Because the number of classes needed varies per input, the prediction sets are feature adaptive.
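To make the score concrete, here is a small worked example (the numbers are ours) that evaluates the APS score for each class of a single input; classes whose score falls at or below \( \hat{q} \) then form the prediction set.

    import numpy as np

    probs = np.array([0.05, 0.65, 0.30])   # softmax output for one input
    order = np.argsort(probs)[::-1]        # permutation pi: most to least likely

    def aps_score(y):
        # Cumulative probability of all classes ranked at or above y
        rank = int(np.where(order == y)[0][0])
        return probs[order][: rank + 1].sum()

    print([float(round(aps_score(y), 2)) for y in range(3)])  # [1.0, 0.65, 0.95]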

Conformal Quantile Regression (CQR) is a prediction strategy for regression problems. We introduce learned functions \( \hat{t}_{\alpha/2}(x) \) and \( \hat{t}_{1 - \alpha/2}(x) \) that estimate the lower and upper quantiles. They are trained on the following objective, known as the Quantile (Pinball) Loss.

\[ L_{\tau}(\hat{t}, y) = \begin{cases} (y - \hat{t}) \tau & \text{if } y > \hat{t}, \\ (\hat{t} - y)(1 - \tau) & \text{if } y \leq \hat{t}. \end{cases} \]
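A minimal PyTorch version of the pinball loss above (a sketch of ours; tau would be \( \alpha/2 \) for the lower quantile and \( 1 - \alpha/2 \) for the upper):

    import torch

    def pinball_loss(t_hat, y, tau):
        # Penalize under-prediction by tau and over-prediction by (1 - tau)
        diff = y - t_hat
        return torch.mean(torch.where(diff > 0, tau * diff, (tau - 1) * diff))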

With these functions on hand, we construct the score function. Intuitively, the score measures how far the label falls outside the learned quantile band; the band is then widened (or shrunk) just enough to meet the desired coverage.

\[ s(x, y) = \max \left( \hat{t}_{\alpha/2}(x) - y, \; y - \hat{t}_{1 - \alpha/2}(x) \right) \]

Computing \( \hat{q} \) as before, we arrive at the following prediction band.

\[ \mathcal{C}(x) = \left[ \hat{t}_{\alpha/2}(x) - \hat{q}, \; \hat{t}_{1 - \alpha/2}(x) + \hat{q} \right]. \]
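Putting the pieces together, a small numpy sketch of ours for the CQR calibration and band construction, assuming t_lo_* and t_hi_* hold the two quantile models' predictions on the calibration and test inputs (names are hypothetical):

    import numpy as np

    def cqr_band(t_lo_cal, t_hi_cal, y_cal, t_lo_new, t_hi_new, alpha=0.1):
        # Score: how far the label falls outside the learned quantile band
        scores = np.maximum(t_lo_cal - y_cal, y_cal - t_hi_cal)
        n = len(scores)
        q_hat = np.quantile(scores, np.ceil((n + 1) * (1 - alpha)) / n, method="higher")
        # Widen (or shrink, if q_hat < 0) the band by q_hat
        return t_lo_new - q_hat, t_hi_new + q_hat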

Calibration Strategy

A calibration strategy modifies the score function by adjusting the heuristic itself. We introduce four calibration strategies.

Temperature Scaling (TS) scales the logits \( z \) by a learned temperature \( T \), chosen to minimize the negative log-likelihood.

\[ \min_T \, -\frac{1}{N} \sum_{i=1}^N \log p_{y_i}^{(T)} \hspace{0.5em} \text{for} \hspace{0.5em} p_{y_i}^{(T)} = \text{softmax} \left( \frac{z_i}{T} \right)_{y_i} \]

There is clear intuition. A temperature greater than \( 1 \) smooths the probabilities, reducing overconfidence. A temperature less than \( 1 \) sharpens the probabilities, reducing underconfidence.
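The temperature can be fit in a few lines. The sketch below is ours, not the library's code: it minimizes the negative log-likelihood on held-out logits and labels with LBFGS, optimizing \( \log T \) so that the temperature stays positive.

    import torch
    import torch.nn.functional as F

    def fit_temperature(logits, labels, max_iter=50):
        log_T = torch.zeros(1, requires_grad=True)            # T starts at 1
        optimizer = torch.optim.LBFGS([log_T], lr=0.1, max_iter=max_iter)

        def closure():
            optimizer.zero_grad()
            # Negative log-likelihood of the temperature-scaled softmax
            loss = F.cross_entropy(logits / log_T.exp(), labels)
            loss.backward()
            return loss

        optimizer.step(closure)
        return log_T.exp().item()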

There are variants. Vector Scaling (VS) generalizes temperature scaling with a temperature for each class and a bias term. Ensemble Temperature Scaling (ETS) solves a constrained optimization problem that weights the contributions of different scales.

While the methods above are useful, they ignore the underlying graph structure. Huang et al. introduce a Calibration Attention Layer to account for this. Building on the setup above and drawing on GATConv, along with some tweaks, we compute a per-node temperature.

\[ T_i = \text{ReLU}(\mathbf{W} \cdot \text{GATConv}_{\theta}(\mathcal{G})_i + b) \]

By incorporating information from neighboring nodes, with attention weights for each contribution, the calibration adapts to the graph structure and improves performance.
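A hedged sketch of the node-wise temperature, following the formula above: a GATConv layer aggregates neighbor information from the base model's logits, and a linear map with a ReLU produces one temperature per node. The small constant keeping \( T_i \) positive is our addition, and the layer sizes are illustrative, not the exact implementation.

    import torch
    from torch_geometric.nn import GATConv

    class NodeTemperature(torch.nn.Module):
        def __init__(self, num_classes, hidden=16, heads=2):
            super().__init__()
            self.gat = GATConv(num_classes, hidden, heads=heads)
            self.lin = torch.nn.Linear(hidden * heads, 1)

        def forward(self, logits, edge_index):
            h = self.gat(logits, edge_index)      # attention-weighted neighborhood aggregation
            T = torch.relu(self.lin(h)) + 1e-8    # one positive temperature per node
            return logits / T                     # node-wise temperature scaling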

Figure: Message Passing under Graph Attention Scheme

Limitations

Despite the ad-hoc fixes listed above, Conformal Prediction has its limitations. For one, Conformal Prediction does not guarantee conditional coverage. On average, you would expect the desired coverage. However, upon conditioning on a particular class, you might find that the coverage is much lower (or higher) than expected.

Figure: Marginal vs Conditional Coverage

The figure above, adapted from a popular tutorial, visually summarizes the salient idea. This bias may be unacceptable in certain settings.

For another, we demand exchangeability. That said, Weighted Conformal Prediction is an active and promising direction that accounts for covariate shift.

Implementation

Our implementation of the ideas above was designed with ease of use and integration with PyG in mind. We introduce an Uncertain layer under torch.nn with three files. Each file maps directly to a step in the Conformal Prediction pipeline.

Figure: Directory Structure of Conformal

We took care to include existing PyG functionality where applicable. Notably, we use GATConv to build the Calibration Attention Layer.

Interface

The interface allows users to rapidly apply Conformal Prediction to an existing model. Borrowing from sklearn, we use a fit and predict pattern for familiarity and convenience.

                
    from uncertain import Conformal

    # Create conformal instance
    conformal = Conformal(
        model=model,
        prediction='APS', # APS, RAPS, TPS
        calibration='TS', # TS, VS, ETS, GATS
    )

    # Call fit and predict on the instance
    conformal.fit()
    conformal.predict()

The fit step sets up initial state and optionally fits the specified calibrator. The predict step runs conformal prediction and returns a ConformalPrediction object. This object contains the prediction sets along with useful statistics that the user can inspect.

There are a couple of "gotchas" that should be mentioned explicitly. First, the task is inferred from the prediction method. Second, the user must supply the train, test, and validation splits. Handling the split properly is essential for the statistical coverage guarantee noted above.

Features

We provide the user with several ready-made prediction and calibration strategies. The full list appears in the interface example above.

The prediction strategies are generally compact and readable. We include a simplified snippet that captures the essence of APS. This is consistent with the description above.

                
    import numpy as np

    def aps(smx, labels, n, alpha):
        # Sort softmax scores in descending order
        pi = smx.argsort(1)[:, ::-1]
        srt = np.take_along_axis(smx, pi, axis=1).cumsum(axis=1)

        # Get cumulative probability up to the true label
        scores = np.take_along_axis(srt, pi.argsort(axis=1), axis=1)
        scores = scores[np.arange(n), labels]

        # Find the threshold for coverage
        qhat = np.quantile(
            scores,
            np.ceil((n + 1) * (1 - alpha)) / n,
            method="higher",
        )
        return qhat

The calibration strategies are standard. We include the forward pass of TS that scales the logits by the temperature. This is consistent with the description above.

                
    def forward(self, x, edge_index):
        # Scale the base model's logits by the learned temperature
        logits = self.model(x, edge_index)
        return logits / self.temperature.expand(logits.size())

These implementations are housed in the prediction and calibration files, respectively. The conformal file defines the wrapper that applies the specified methods.

Future Directions

Future PRs should add the following features: functions that directly optimize desirable properties such as efficiency or conditional coverage, and Weighted Conformal Prediction to handle covariate shift, where exchangeability fails.
