Self-assessing neural networks

I invented a new type of layer for neural networks: it gives the network the ability to critically assess the reliability of its own features, enabling more informed decision making.

This invention is a spin-off of a private research project of mine. Preliminary results showed that using this layer led to a minor improvement on the MNIST dataset. However, the improvement was too small to say for certain whether the layer is useful.

Since this project does not go in the same direction as my primary research, I am publishing it here. Maybe someone else wants to have a look at it.

Even if it turns out that the improvement on MNIST is a fluke, there are a large number of future improvements that could make the algorithm more effective.

Also, if you have any practical problems or benchmarking test cases lying around, be sure to give it a try and report whether it worked. It's just one line of code to swap a normal linear layer for a self-assessment layer.

Intuition

The intuition behind this new type of layer is as follows:

When a normal neural network is used to solve a regression problem, it does not provide a measure of how certain it is of the correctness of its output. If it has no idea what the output should actually be, a well-trained network will simply output the value that minimizes the expected loss over all training examples.

There is a difference between saying "I have checked every data source I can find and the answer is 0.3" and saying "I have no idea where to even start, but the average is 0.3 so let's go with that". Contemporary neural networks do not capture that difference.

It seems clear to me that this information ought to be useful to have. The question is: Is it useful enough to be worth the extra effort to calculate? I give a list of possible usecases at the end of this article, but for now let's try a simple implementation.

Idea

I introduce the Self-Assessment layer:

The self-assessment layer is a new type of layer for neural networks. It consists of two parts: The main component (which behaves like a normal layer), and the self-assessment component (which is trained to predict the usefulness of the main component).

There are many ways to implement the details. In the below example, it's implemented like this:

The main component is a linear layer. The self-assessment component operates like a normal linear layer on the forward pass, but uses a modified gradient during backpropagation: when the main layer receives its gradient, the self-assessment layer's own gradient is ignored and overwritten. The self-assessment component's new gradient is computed through mean-squared-error (MSE) loss, with the mean absolute value of the gradient of the main neurons as the target.

The result of this is that the self-assessment component learns to predict the mean absolute value of the main component's gradient. In other words, whenever the main component makes a perfect prediction, the self-assessment component should predict zero, and whenever the main component makes a poor prediction (so that it will receive a large gradient), the self-assessment component should predict a high value. The self-assessment component ends up as a measure of the uncertainty that remains after the other features have been calculated.
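To make the mechanism concrete, here is a minimal stand-alone illustration of that training signal, using the same choices as the implementation below (MSE loss, mean absolute value of the gradient as the target); the tensor shapes are arbitrary:

import torch

# Stand-in values: the gradient arriving at the main component's output, and the sass prediction.
grad_main = torch.randn(4, 8)    # (batch, main_features)
output_sass = torch.randn(4, 1)  # (batch, sass_features): here one sass neuron covers all 8 main neurons

# Target: mean absolute value of the main component's gradient.
target = grad_main.abs().mean(dim=1, keepdim=True)

# The sass component is trained with MSE loss toward that target,
# so the gradient that replaces its ordinary one is simply 2 * (prediction - target).
grad_sass = 2 * (output_sass - target)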

Knowing the reliability of your network is already useful in its own right, for example to inform your decisions when designing a network architecture, but this information can also be used to aid training:

By splitting a network layer into many parallel layers and learning a self-assessment for each of them, the network should gain valuable information to improve its own performance. Intuitively, if you imagine the features of a layer as people giving advice to you, then adding a self-assessment layer would be like knowing which of the advisors are actually confident about their advice, and which ones are not.

Implementation in PyTorch

I have implemented the self-assessment layer in PyTorch:

import math

import torch
from torch import autograd
import torch.nn as nn


class SelfAssessmentFunction(autograd.Function):
    """
    Implements two linear layers, one called 'main' and one called 'sass' (for 'self-assessment').

    The main layer behaves just like a regular linear layer.

    The sass layer behaves like a linear layer on the forward pass, but it uses a custom backward method to calculate its gradient:

    The gradient the sass layer receives from backpropagation is ignored completely. It is replaced with the result of a custom loss function: MSE loss, with the absolute values of the gradient of the main layer as the target. This new loss function is then used to calculate the gradient of the sass layer.

    As a result, the sass layer learns to approximate the average absolute gradient of the main layer.

    Each neuron in the sass layer reacts to a different subset of the neurons in the main layer, which is controlled by output_to_sass_mean_compression.
    """

    @staticmethod
    def forward(ctx, input, weight_main, bias_main, weight_sass, bias_sass, output_to_sass_mean_compression):
        # Both feed-forward portions are just the result of applying the respective layer to the input
        output_main = input.mm(weight_main.t())
        output_main += bias_main.unsqueeze(0).expand_as(output_main)
        output_sass = input.mm(weight_sass.t())
        output_sass += bias_sass.unsqueeze(0).expand_as(output_sass)
        ctx.save_for_backward(input, weight_main, bias_main, weight_sass, bias_sass, output_sass, output_to_sass_mean_compression)
        return output_main, output_sass

    @staticmethod
    def backward(ctx, grad_main, grad_sass):
        input, weight_main, bias_main, weight_sass, bias_sass, output_sass, output_to_sass_mean_compression = ctx.saved_tensors
        grad_input = grad_weight_main = grad_bias_main = grad_weight_sass = grad_bias_sass = grad_output_to_sass_mean_compression = None
        # Perform normal gradient calculations on the main layer
        grad_weight_main = grad_main.t().mm(input)
        grad_bias_main = grad_main.sum(0)
        # For the sass layer, ignore the grad_sass and recompute it:
        # The grad_sass is computed through MSE loss, with the absolute value of the gradient of the main neurons as the target.
        # Each neuron in sass measures a subset of the main neurons. This mapping is done by output_to_sass_mean_compression.
        target = grad_main.abs().mm(output_to_sass_mean_compression)
        grad_sass = (output_sass - target) * 2
        # Apply this new gradient
        grad_weight_sass = grad_sass.t().mm(input)
        grad_bias_sass = grad_sass.sum(0)
        # Calculate the gradient for the input
        grad_input = grad_main.mm(weight_main)
        return grad_input, grad_weight_main, grad_bias_main, grad_weight_sass, grad_bias_sass, grad_output_to_sass_mean_compression


class SelfAssessment(nn.Module):
    """
    Implements a linear layer as well as a self-assessment layer, which is a second linear layer that is trained to predict the gradient of the first linear layer.

    If add_sass_features_to_output=True, the results of both layers are combined into a single tensor. Otherwise they are returned separately.

    The parameter sass_features specifies the number of self-assessment neurons; out_features must be a multiple of sass_features.

    For example:

    If you set sass_features=1 and add_sass_features_to_output=False, this module behaves like a normal linear layer, but you will get a secondary output that predicts the average absolute gradient of the entire layer.

    If you set out_features=200, sass_features=10, add_sass_features_to_output=True, you end up with an output vector of size 210. 200 of those are normal results of the linear layer, while the remaining 10 are predictions of the mean absolute gradient of the 10 blocks of 20 neurons. These 10 additional predictions may improve network performance because subsequent layers can use them as an estimate for how reliable the 10*20 main neurons are for the given example.

    Note for interpretation: A higher value of the self-assessment means that the network is less sure of itself. (The value measures how much the network expects to learn from the new data.)
    """

    def __init__(self, in_features, out_features, sass_features=1, add_sass_features_to_output=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sass_features = sass_features
        self.add_sass_features_to_output = add_sass_features_to_output
        if out_features % sass_features != 0:
            raise ValueError("The number of output features (out_features) must be a multiple of the number of self-assessment features (sass_features).")
        # Create one layer for the calculation itself, and another for the self assessment
        self.weight_main = nn.Parameter(torch.Tensor(out_features, in_features))
        self.bias_main = nn.Parameter(torch.Tensor(out_features))
        self.weight_sass = nn.Parameter(torch.Tensor(sass_features, in_features))
        self.bias_sass = nn.Parameter(torch.Tensor(sass_features))
        self.reset_parameters()
        # Create a mapping that compresses features from the main layer by taking the means of a subset of them.
        # This is needed to calculate the gradient of MSE-loss for all sass features at the same time.
        self.output_to_sass_mean_compression = torch.zeros(out_features, sass_features, requires_grad=False)
        main_per_sass = out_features // sass_features
        for i in range(out_features):
            j = i // main_per_sass
            self.output_to_sass_mean_compression[i,j] = 1.0 / main_per_sass
        # Create mappings that combine output_main and output_sass into a single tensor
        self.output_combiner_main = torch.zeros(out_features, out_features + sass_features, requires_grad=False)
        self.output_combiner_sass = torch.zeros(sass_features, out_features + sass_features, requires_grad=False)
        for i in range(out_features):
            self.output_combiner_main[i,i] = 1.0
        for i in range(sass_features):
            self.output_combiner_sass[i,out_features+i] = 1.0

    def reset_parameters(self):
        # main
        nn.init.kaiming_uniform_(self.weight_main, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_main)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias_main, -bound, bound)
        # sass
        nn.init.kaiming_uniform_(self.weight_sass, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_sass)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias_sass, -bound, bound)

    def forward(self, input):
        output_main, output_sass = SelfAssessmentFunction.apply(input, self.weight_main, self.bias_main, self.weight_sass, self.bias_sass, self.output_to_sass_mean_compression)
        if self.add_sass_features_to_output:
            combined_output = output_main.mm(self.output_combiner_main) + output_sass.mm(self.output_combiner_sass)
            return combined_output
        else:
            return output_main, output_sass

    def extra_repr(self):
        return 'in_features={}, out_features={}, sass_features={}'.format(
            self.in_features, self.out_features, self.sass_features
        )

To get a self-assessment at the end of a regression task, just use a layer like this:

self.final_layer = SelfAssessment(500, num_outputs, sass_features=1)

and call it like this:

output, self_assessment = self.final_layer(x)
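Put together, a minimal regression network using this could look roughly like the following sketch (the hidden size of 500 comes from the snippet above; the other names and sizes are placeholders):

class SmallRegressionNet(nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, 500)
        self.final_layer = SelfAssessment(500, num_outputs, sass_features=1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        # output is the regression result; self_assessment predicts the mean absolute gradient,
        # i.e. a higher value means the network is less sure of this particular output.
        output, self_assessment = self.final_layer(x)
        return output, self_assessment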

And to use the self-assessment to improve training, just replace any code that looks like this:

self.fc1 = nn.Linear(200, 300)

with this:

self.fc1 = SelfAssessment(200, 270, sass_features=30, add_sass_features_to_output=True)

By the way: since the sass component's output has a different scale than the output of the main component, I would recommend using an optimizer that adapts the learning rate for each parameter individually, such as Adam.
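For illustration, a rough sketch of that swap in context (the downstream layer size and the model variable are placeholders, not taken from the experiments). Since 270 + 30 = 300, the layer after it keeps its input size:

# inside the model's __init__ (sizes are only examples):
self.fc1 = SelfAssessment(200, 270, sass_features=30, add_sass_features_to_output=True)
self.fc2 = nn.Linear(300, 10)  # unchanged: it still receives 270 main + 30 sass = 300 features

# later, when setting up training ("model" is a placeholder for your network instance);
# Adam adapts the step size per parameter, which helps with the differing scales:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)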

Experiments

Experiment 1

To test the self-assessment layer, I defined a simple test function: it has three inputs and one output. The output is the sum of the squares of the inputs, plus noise. The magnitude of the noise depends on the absolute value of the first input.

This function is easy to predict when the first input is close to zero, but difficult when the first input is far from zero. Varying the other inputs changes the result, but does not make the result harder to predict.
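A minimal sketch of such a data generator (the input distribution and the noise scale are my assumptions; the description above only fixes the general shape of the function):

def sample_batch(batch_size):
    x = torch.randn(batch_size, 3)
    # Noise scales with the absolute value of the first input,
    # so samples with x[:, 0] near zero are easy to predict.
    noise = torch.randn(batch_size, 1) * x[:, :1].abs()
    y = (x ** 2).sum(dim=1, keepdim=True) + noise
    return x, y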

I built a simple neural network, put a self-assessment layer at the end of the network, and trained it on random samples of this function.

Would it learn the relation between the three inputs and the reliability of the output? That is, would it correctly recognize that its own output cannot be trusted if and only if the first input is far from zero?

Result 1

The self-assessment layer at the end of the network did correctly learn to recognize that only the absolute value of the first input matters for the reliability of the network.

When tracking the change of the weights and outputs of the layers over time, it turns out that the self-assessment component converges only slightly slower than the rest of the network. A slight delay is expected, since the prediction target of the self-assessment layer changes over time as the rest of the network converges. As a result, the self-assessment cannot finish training until the rest of the network has finished training.

Experiment 2

To test whether or not the self assessment layer improves performance, I used a standard MNIST-solving network from the internet and simply replaced one of the linear layers with a self-assessment layer. I deliberately did not change anything else and ran the program only once, so that I wouldn't accidentally bias myself.

Result 2

The self-assessment layer led to a minor improvement.

With the standard linear layer: loss: 0.0381, accuracy: 0.9883

With the self-assessment linear layer: loss: 0.0365, accuracy: 0.9892

While I am happy to see an improvement, the improvement is so small that it might be chance. Additionally, the self-assessment layer required a bit of additional time to train, which may make it prohibitively expensive to use in practice. Further experimentation is required.

I would like to get a second opinion on this idea before I invest more time in it. Any and all feedback is welcome!

Experiment 3

The test network I was using did not use Batch Normalization. Because the sass component's outputs have a very different scale from the main component's, I decided to run the experiment again with a Batch Normalization layer added after the SelfAssessmentLayer.
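The change amounts to something like the following sketch (the feature counts are placeholders carried over from the example above; the point is only that the normalization is applied to the combined main and sass outputs):

self.fc1 = SelfAssessment(200, 270, sass_features=30, add_sass_features_to_output=True)
self.bn1 = nn.BatchNorm1d(270 + 30)  # normalizes main and sass features onto a comparable scale

# in forward():
x = self.bn1(self.fc1(x))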

Result 3

Performance improved further.

With the standard linear layer: loss: 0.0426, accuracy: 0.9869

With the self-assessment linear layer: loss: 0.0337, accuracy: 0.9900

However, while the performance of the SelfAssessmentLayer did indeed improve, adding the Batch Normalization actually decreased the performance of the plain linear layer. Since Batch Normalization is normally a good thing, this indicates that the experiment is not large enough to be reliable. I need to find a larger, more thorough benchmark to run better experiments.

(Any tips for a good benchmarking problem I should use are appreciated)

Planned Improvements

There are a large number of potential improvements to be made to the self-assessment logic described above. Even if it turns out that the improvement on MNIST was a fluke, these improvements may turn it into something useful:

  • Experiment with other types of operations for both parts of the SelfAssessmentLayer (the main component and the sass component). Currently, both parts are purely linear and have no activation function.

  • Experiment with different ways of measuring the self-assessment. Currently, it uses MSE loss against the mean of the absolute values of the gradients. Both the loss function and the aggregation can be varied.

  • Alter how much of the gradient of the self-assessment is backpropagated, relative to the gradient of the main layer. In the above implementation, the self-assessment's gradient is not backpropagated to the input at all.

    Choosing a good factor here is tricky because the main component and the sass component use different units and can vary greatly in magnitude.

  • Related to that last point: Find a good way to normalize the self-assessment relative to the other neurons in the layer. As experiment 3 showed, Batch Normalization helps with this, but some fine tuning of the running averages may increase the performance further because the training of the self-assessment component lags slightly behind the training of the main component.

  • In the SelfAssessmentLayer, the matrix output_to_sass_mean_compression is currently hardcoded, which means that the neurons in the main layer are grouped together arbitrarily. It would likely lead to better performance if there was a way to ensure that the neurons get grouped together in such a way that the neurons in each group tend to be reliable or unreliable in the same situations.

    Note that this only works well if the gradient magnitudes of different neurons correlate consistently throughout the training process and do not change their behavior over time. I have no idea how likely that is, and it sounds like an interesting topic to research in its own right. Intuitively, when gradients stop correlating it should indicate that the neurons have diverged in meaning and now capture different features. Recognizing this could be useful for a variety of purposes.

  • The purpose of the SelfAssessmentLayer is to allow later stages of the network to quickly decide to discard some features that are not reliable in the current context. Given that this is the goal, why not modify the SelfAssessmentLayer to take this into account directly:

    During the forward pass, the activation of each neuron in the main component is dampened depending on the self-assessment component for that neuron.

    This should simulate ensemble learning similar to the way Dropout layers do. Only instead of randomly reducing features to zero, we systematically shrink features that we believe won't be as relevant. (A rough sketch of this dampening appears after this list.)

    Addendum (2019-05-18): I tested this idea, and did not notice any consistent improvements. It sometimes improved and sometimes hurt performance, and I could not detect any pattern in this.

  • Building on the previous point: Modify the learning rate based on the self-assessment.

    When the SelfAssessmentLayer thinks that a feature is unreliable, then there is no point in updating and the learning rate for that particular neuron should be small. When the SelfAssessmentLayer is very sure of its correctness, then any feedback is surprising and the learning rate should be increased.

    Addendum (2019-05-18): I tested this idea, and did not notice any consistent improvements. It sometimes improved and sometimes hurt performance, and I could not detect any pattern in this.
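Here is a rough sketch of what the dampening idea from above could look like in the forward pass. The exp(-x) mapping is my own arbitrary choice for illustration: it maps a self-assessment of zero (maximal confidence) to a factor of 1 and large self-assessments to factors near zero. It relies on the fact that each sass neuron covers a contiguous block of main neurons:

def dampen(output_main, output_sass):
    # One damping factor per sass group, expanded to the neurons of that group.
    group_size = output_main.shape[1] // output_sass.shape[1]
    damping = torch.exp(-output_sass.clamp(min=0))  # higher sass value -> less confidence -> smaller factor
    return output_main * damping.repeat_interleave(group_size, dim=1)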

Possible Usecases

Self-assessment has a number of additional potential usecases that may be worth investigating:

  • With self-assessment, a neural network has the ability to raise a warning when it isn't sure about its estimate. Automated processes can react to this and try to get a prediction from some other source instead.

  • Identify the most surprising inputs: Correct predictions when self-assessment assigns low confidence, or incorrect predictions when self-assessment assigns high confidence.

    By looking for surprising data, one can also detect data entry errors.

  • Search for correlated clusters in the inputs that have good self-assessment but low actual predictive accuracy / high error. Each cluster found represents a qualitatively new concept that the AI hasn't learned yet.

  • Self-assessment can be used to improve ensembles of methods.

    Instead of giving each network a fixed weight, let the weight of each network depend on the input: it scales inversely with the self-assessment the network gives itself for the given input (recall that a higher self-assessment value means less confidence). A rough sketch of this follows below.

    This may not work very well, because building a stacking ensemble would likely perform better than this, but it may be worth investigating to draw a comparison.
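As a rough sketch of that input-dependent weighting, assuming each ensemble member ends in a SelfAssessment layer that returns a (prediction, self_assessment) pair; the softmax weighting is my own choice for illustration:

def ensemble_predict(models, x):
    # Each model returns (prediction, self_assessment); a higher self-assessment means less confidence.
    outputs = [model(x) for model in models]
    preds = torch.stack([pred for pred, _ in outputs], dim=0)  # (n_models, batch, num_outputs)
    sass = torch.stack([s for _, s in outputs], dim=0)         # (n_models, batch, 1)
    weights = torch.softmax(-sass, dim=0)                      # confident models get more weight
    return (weights * preds).sum(dim=0)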