An informal summary of a recent project I had some involvement in.

Motivation: Why care about Metric Learning?

Many machine learning algorithms, such as k-means, Support Vector Machines, k-Nearest Neighbour classification, kernel regression and methods based on Gaussian Processes, rely fundamentally on being able to measure dissimilarity between two examples. Usually this is done with the Euclidean distance between points (i.e. points that are closer in this sense are considered more similar), which is often suboptimal in a sense that will be explained below. Being able to compare examples and decide whether they are similar or dissimilar, or to return a measure of similarity, is one of the most fundamental problems in machine learning. Of course, a related question is: what do we mean by “similar” after all?

To illustrate the above, let us work with k-Nearest Neighbour (kNN) classification. First, an example of the really simple idea behind it: consider the following points in \mathbb{R}^2, with the classes marked by different colours.

2DPoints

Now suppose we have a new point – marked with black – whose class is unknown. We assign it a class by looking at the nearest neighbors and taking the majority vote:

kNN
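For concreteness, here is a minimal sketch of the majority-vote rule using NumPy and plain Euclidean distance. The function name `knn_predict` and the toy data are made up for illustration, and ties are broken arbitrarily by `Counter`.

```python
from collections import Counter

import numpy as np


def knn_predict(X, y, query, k=3):
    """Classify `query` by majority vote among its k nearest training points."""
    # Squared Euclidean distance from the query to every training point.
    dists = np.sum((X - query) ** 2, axis=1)
    # Indices of the k smallest distances.
    nearest = np.argsort(dists)[:k]
    # Majority vote over the labels of those neighbours.
    votes = Counter(int(y[i]) for i in nearest)
    return votes.most_common(1)[0][0]


# Toy data (hypothetical): two small clusters in R^2.
X = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
              [1.0, 1.0], [1.1, 0.9], [0.9, 1.2]])
y = np.array([0, 0, 0, 1, 1, 1])

pred_a = knn_predict(X, y, np.array([0.15, 0.15]), k=3)
pred_b = knn_predict(X, y, np.array([1.0, 1.1]), k=3)
```

The only place the notion of “nearest” enters is the line computing `dists`, which is exactly the part that metric learning replaces.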

Some notes on kNN:

A brief digression first, before moving on to the problem raised above (what is nearest?). kNN classifiers are very simple and yet in many cases they can give excellent performance. For example, on the MNIST dataset kNN gives performance that is competitive with other, more complicated models.

MNIST

Moreover, they are simple to implement, use local information and hence are inherently nonlinear. The biggest advantage in my opinion is that it is easy to add new classes (since no retraining from scratch is required) and since we average across points, kNN is also relatively robust to label noise. It also has some attractive theoretical properties: for example kNN is universally consistent (as the number of points approaches infinity, with appropriate choice of k, the kNN error will approach the Bayes Risk).

Notion of “Nearest”:

At the same time, kNN classifiers also have their disadvantages. One is related to the notion of “nearest” (which falls back on what was talked about at the start), i.e. how does one decide which points are “nearest”? Usually such points are chosen on the basis of the Euclidean distance in the native feature space, which has shortfalls. Why? Because nearness in the Euclidean space may not correspond to nearness in the label space, i.e. points that are far apart in the Euclidean space may have similar labels. In such cases, the notion of “near” given by the Euclidean distance is clearly suboptimal. This is illustrated by a set of figures below (adapted from slides by Kilian Weinberger):

An Illustration:

Consider the image of this lady – now how do we decide what is more similar to it?

 Lady-Who

Someone might mean similar on the basis of the gender:

Lady-Gender

Or on the basis of age:

Lady-Age

Or on the basis of the hairstyle!

Lady-Hair

Similarity depends on the context, which is something that the Euclidean distance in the native feature space would fail to capture. This context is provided by labels.

Distance Metric Learning:

The goal of Metric Learning is to learn a distance metric, so that the above label information is incorporated in the notion of distance i.e. points that are semantically similar are now closer in the new space. The idea is to take the original or native feature space, use the label information and then amplify directions that are more informative and squish directions that are not. This is illustrated in this figure – notice that the point marked in black would be incorrectly classified in the native feature space, however under the learnt metric it would be correctly classified.

MetricLearningAmp

It is worthwhile to have a brief look at what this means. The Euclidean distance (with x_i \in \mathbb{R}^d) is defined by

\sqrt{(x_i - x_j)^T (x_i - x_j)}

As was also evident in the figure above, this corresponds to the following Euclidean ball in 2-D:

EucBall

A family of distance measures may be defined using an inner product matrix. These are called Mahalanobis metrics:

\sqrt{(x_i - x_j)^T \mathbf{W}(x_i - x_j)}

Mahal-Ball

The learnt metric effects a rescaling and rotation of the original space. The goal is now to learn this \mathbf{W} \succeq 0 using the label information so that the new distances correspond better to the semantic context. It is easy to see that when \mathbf{W} \succeq 0, the above is still a distance metric (strictly speaking a pseudometric, since distinct points can be at distance zero when \mathbf{W} is singular).
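One way to see the “rescaling and rotation” claim: any positive semidefinite \mathbf{W} can be factored as \mathbf{A}^T \mathbf{A}, and the Mahalanobis distance under \mathbf{W} is then just the Euclidean distance after mapping points through \mathbf{A}. The sketch below checks this numerically. The matrix `A` and the two points are arbitrary values chosen for illustration.

```python
import numpy as np


def mahalanobis(x_i, x_j, W):
    """sqrt((x_i - x_j)^T W (x_i - x_j)) for a positive semidefinite W."""
    diff = x_i - x_j
    return float(np.sqrt(diff @ W @ diff))


# Hypothetical linear map: stretch the first axis, squish the second.
A = np.array([[2.0, 0.0],
              [0.0, 0.5]])
W = A.T @ A  # positive semidefinite by construction

x_i = np.array([1.0, 1.0])
x_j = np.array([0.0, 0.0])

d_W = mahalanobis(x_i, x_j, W)                        # distance under W
d_mapped = float(np.linalg.norm(A @ x_i - A @ x_j))   # Euclidean distance after mapping
d_euclid = mahalanobis(x_i, x_j, np.eye(2))           # W = I recovers the Euclidean distance
```

Here `d_W` and `d_mapped` agree, while `d_euclid` differs from both, which is the sense in which learning \mathbf{W} amounts to learning a linear change of coordinates.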

Learning \mathbf{W}:

Usually the real motivation for metric learning is to optimize for the kNN objective i.e. learn the matrix \mathbf{W} \succeq 0 so that the kNN error is reduced. But note that directly optimizing for the kNN loss is intractable because of the combinatorial nature of the optimization (we’ll see this in a bit), so instead, \mathbf{W} is learnt as follows:

1. Define a set of “good” neighbors for each point. The definition of “good” is usually some combination of proximity to the query point and label agreement between the points.

2. Define a set of “bad” neighbours for each point. This might be a set of points that are “close” to the query point but disagree on the label (i.e. in spite of being close to the query point, they might give a wrong classification if they were chosen to classify the query point).

3. Set up the optimization problem for \mathbf{W} such that for each query point, “good” neighbours are pulled closer to it while “bad” neighbours are pushed farther away, and thus learn \mathbf{W} so as to minimize the leave one out kNN error.
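As a rough illustration of steps 1 and 2, the sketch below uses one possible choice: “good” neighbours are the k closest points with the same label, and “bad” neighbours are differently labelled points that lie no farther than the farthest good neighbour. This particular rule and the name `good_and_bad_neighbours` are assumptions made for the example, not the definition used by any specific method. It also assumes that at least one other point shares the label of point i.

```python
import numpy as np


def good_and_bad_neighbours(X, y, i, k=2):
    """Return (good, bad) index arrays for training point i under Euclidean distance."""
    dists = np.sum((X - X[i]) ** 2, axis=1)
    others = np.arange(len(X)) != i
    same = np.where(others & (y == y[i]))[0]
    diff = np.where(y != y[i])[0]
    # "Good": the k closest points sharing the label of point i.
    good = same[np.argsort(dists[same])[:k]]
    # "Bad": differently labelled points at least as close as the farthest good neighbour.
    radius = dists[good].max()
    bad = diff[dists[diff] <= radius]
    return good, bad


# Toy data (hypothetical): point 3 has the wrong label but sits close to point 0.
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [0.5, 0.5], [5.0, 5.0]])
y = np.array([0, 0, 0, 1, 1])

good, bad = good_and_bad_neighbours(X, y, i=0, k=2)
```

Step 3 would then adjust \mathbf{W} so that the points in `good` move closer to the query and the points in `bad` move away.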

The exact formulation of “good” and “bad” varies from method to method. Here are some examples:

In one of the earliest papers on distance metric learning by Xing, Ng, Jordan and Russell (2002) – good neighbors are similarly labeled k points. The optimization is done so that each class is mapped into a ball of fixed radius. However no separation is enforced between the classes. This is illustrated in the following figure (the query point is marked with an X, similarly labeled k points are moved into a ball of a fixed radius):

XingNgJordan

One problem with the above is that the kNN objective does not really require that similarly labeled points are clustered together, hence in a way it optimizes for a harder objective. This is remedied by the LMNN described briefly below.

One of the more famous Metric Learning papers is Large Margin Nearest Neighbors (LMNN) by Weinberger and Saul (2006). Here good neighbors are similarly labeled k points (and the circle around x is the distance of the farthest of the good neighbours) and “worst offenders” or “bad” neighbours are points that are of a different class but still in the nearest neighbors of the query point. The optimization is basically a semidefinite program that works to pull the good neighbours towards the query point, and a margin is enforced by pushing the offending points out of this circle. Thus in a way, the goal in LMNN is to deform the metric in such a way that the neighbourhood for each point is “pure”.

LMNN

There are many approaches to the metric learning problem, however a few more notable ones are:

1. Neighbourhood Components Analysis (Goldberger, Roweis, Hinton and Salakhutdinov, 2004): Here the piecewise constant error of the kNN rule is replaced by a soft version. This leads to a non-convex objective that can be optimized by gradient descent. Basically, NCA tries to optimize for the choice of neighbour at the price of losing convexity.

2. Collapsing Classes (Globerson and Roweis, 2006): This method attempts to remedy the non-convexity above by optimizing a similar stochastic rule while attempting to collapse each class to one point, making the problem convex.

3. Metric Learning to Rank (McFee and Lanckriet, 2010): This paper takes a different take on metric learning, treating it as a ranking problem. Note that given a fixed p.s.d. matrix \mathbf{W}, a query point induces a permutation on the training set (in order of increasing distance). The idea thus is to optimize the metric for some ranking measure (such as precision@k). But note that this is not necessarily the same as requiring correct classification.
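To make the “soft version” in item 1 a little more concrete, here is a minimal sketch of the NCA objective as I understand it: each point picks a neighbour with probability given by a softmax over negative squared distances in the projected space, and the objective is the expected number of points that pick a neighbour of their own class. The function name `nca_objective` and the toy data are made up for illustration, and no optimization is performed.

```python
import numpy as np


def nca_objective(A, X, y):
    """Expected number of points that softly select a same-class neighbour."""
    Z = X @ A.T  # project the points with the linear map A
    sq = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=2)
    np.fill_diagonal(sq, np.inf)  # a point never selects itself
    logits = -sq
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    P = np.exp(logits)
    P = P / P.sum(axis=1, keepdims=True)  # row i holds p_ij over j
    same = y[:, None] == y[None, :]
    return float(np.sum(P * same))


# Toy data (hypothetical): two tight, well separated clusters.
X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
y = np.array([0, 0, 1, 1])

obj_identity = nca_objective(np.eye(2), X, y)        # clusters kept apart
obj_collapsed = nca_objective(np.zeros((2, 2)), X, y)  # all points mapped to one spot
```

With the identity map almost every point selects its same-class partner, whereas collapsing everything to a single point makes the selection uniform. The objective is smooth in `A`, which is what allows gradient-based optimization.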

Neighbourhood Gerrymandering:

As a motivation we can look at the cartoon above for LMNN. Since we are looking to optimize for the kNN objective, the only requirement on the learnt metric should be correct classification, so we should only need to move points as far as is necessary to ensure that. Thus we can take the circle around x to have radius equal to the distance of the farthest point among the k nearest neighbours (irrespective of class). Now, we would like to deform the metric such that enough points are pulled into and pushed out of this circle so as to ensure correct classification. This is illustrated below.

MLNG

This method is akin to the common practice of gerrymandering: drawing up the borders of election districts so as to provide advantages to desired political parties. This is done by concentrating voters from a particular party and/or by spreading out voters from other parties. In the above, the “districts” are cells in the Voronoi diagram defined by the Mahalanobis metric and the “parties” are the class labels voted for by each neighbour.

 Motivations and Intuition:

Now we can step back a little from the survey above, and think a bit about the kNN problem in somewhat more precise terms so that the above approach can be motivated better.

For kNN, given a query point and a fixed metric, there is an implicit latent variable: The choice of the k “neighbours”.

Given this latent variable, inference of the label for the query point is trivial, since it is just the majority vote. But notice that for any given query point there can exist a very large number of choices of k points that correspond to correct classification (basically any set of points with a majority of the correct class will work). We want to learn a metric under which one of these sets is preferred over any set of k neighbours that would vote for a wrong class. In particular, from the sets that yield correct classification we would like to pick the one that is on average most similar to the query point.

We can write kNN prediction as an inference problem with a structured latent variable being the choice of k neighbours.

The learning then corresponds to minimizing a sum of a structured latent hinge loss and a regularizer. Computing the latent hinge loss involves loss-augmented inference, which is basically looking for the worst offending k points (points that have high average similarity with the query point, yet correspond to a high loss). Given the combinatorial nature of the problem, efficient inference and loss-augmented inference are key. Optimization can basically be just gradient descent on the surrogate loss. To make this a bit more clear, the setup is described below:

Problem Setup:

Suppose we are given N training examples that are represented by a “native” feature map, \mathbf{X} = \{x_1, \dots, x_N\} with x_i \in \mathbb{R}^d with class labels \mathbf{y} = [y_1, \dots, y_N]^T with y_i \in [\mathbf{R}], where [\mathbf{R}] stands for the set \{1, \dots, \mathbf{R}\}.

Suppose we are also provided with a loss matrix \Lambda with \Lambda(r,r') being the loss incurred by predicting r' when the correct class is r. We assume that \Lambda(r,r) = 0 and \forall (r,r'), \Lambda(r,r') \geq 0.

Now let h \subset \mathbf{X} be a set of examples in \mathbf{X}.

As stated earlier, we are interested in the Mahalanobis metrics, used here in squared form (i.e. without the square root):

D_W(x,x_i) = (x-x_i)^T W (x-x_i)

For a fixed W we may define the distance score of h with respect to a point x as the negative of the total distance from x to the points in h:

\displaystyle S_W(x,h) = - \sum_{x_j \in h} D_W(x, x_j)

Therefore, the set of k-Nearest Neighbours of x in \mathbf{X} is:

h_W(x ) = \arg\max_{|h|=k} S_W(x,h)

For any set h of k examples from \mathbf{X} we can predict the label of x by a simple majority vote.

\hat{y}(h) = majority\{y_j: x_j \in h\}

The kNN classifier therefore predicts \hat{y}(h_W(x)).

Thus, the classification loss incurred using the set h can be defined as:

\Delta(y,h) = \Lambda(y,\hat{y}(h))
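The definitions above translate almost line by line into code. The sketch below is only meant to make the notation concrete: sets h are represented as tuples of indices into \mathbf{X}, labels are 0-based rather than 1-based, \Lambda is taken to be the 0-1 loss, and ties in the majority vote are broken arbitrarily. All of these, along with the toy data, are choices made for the example.

```python
from collections import Counter

import numpy as np


def D(W, x, x_i):
    """Squared Mahalanobis distance D_W(x, x_i)."""
    diff = x - x_i
    return float(diff @ W @ diff)


def S(W, x, X, h):
    """Score S_W(x, h): negative total distance from x to the points indexed by h."""
    return -sum(D(W, x, X[j]) for j in h)


def h_W(W, x, X, k):
    """The k nearest neighbours of x, i.e. the size-k set maximizing S_W(x, h)."""
    dists = np.array([D(W, x, x_j) for x_j in X])
    return tuple(sorted(np.argsort(dists)[:k].tolist()))


def y_hat(y, h):
    """Majority vote over the labels of the points indexed by h."""
    return Counter(int(y[j]) for j in h).most_common(1)[0][0]


def delta(Lambda, y_true, y, h):
    """Classification loss incurred by using the set h for a point with label y_true."""
    return float(Lambda[y_true, y_hat(y, h)])


# Toy data (hypothetical): the query x has true label 1.
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [0.0, 4.0]])
y = np.array([0, 0, 1, 1])
x = np.array([0.0, 1.8])
Lambda = 1.0 - np.eye(2)  # 0-1 loss

h_euclid = h_W(np.eye(2), x, X, k=3)              # neighbours under W = I
h_scaled = h_W(np.diag([10.0, 1.0]), x, X, k=3)   # neighbours after stretching axis 0
```

In this toy case the Euclidean neighbours vote for class 0 and incur a loss, while stretching the first axis swaps one neighbour and the vote flips to class 1.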

Learning and Inference:

One might want to learn W so as to minimize the training loss:

\displaystyle \sum_i \Delta(y_i, h_W(x_i))

However, as mentioned in passing above, this fails because of the intractable nature of the classification loss \Delta. Thus we’d have to resort to the usual remedy: define a tractable surrogate loss.

It must be stressed again that the output of prediction is a structured object h_W. The loss in structured prediction penalizes the gap between score of the correct structured output and the score of the “worst offending” incorrect output. This leads to the following definition of the surrogate:

L(x,y,W) = \max_h [S_W(x,h) + \Delta(y,h)] - \max_{h: \Delta(y,h) = 0} S_W(x,h)

This corresponds to our earlier intuition on wanting to learn W such that the gap between the “good neighbours” and “worst offenders” is increased.
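On a tiny problem the surrogate can be evaluated by brute force, simply enumerating every set of k points. The sketch below does exactly that. It is exponential in general and is not the efficient inference procedure alluded to in this post. The helper names, the 0-1 loss matrix and the toy data are assumptions for illustration, and the code assumes at least one set h with \Delta(y,h) = 0 exists.

```python
from collections import Counter
from itertools import combinations

import numpy as np


def score(W, x, X, h):
    """S_W(x, h): negative sum of squared Mahalanobis distances to the points in h."""
    diffs = X[list(h)] - x
    return -float(np.einsum("nd,de,ne->", diffs, W, diffs))


def vote(y, h):
    """Majority label among the points indexed by h."""
    return Counter(int(y[j]) for j in h).most_common(1)[0][0]


def surrogate_loss(W, x, y_true, X, y, k, Lambda):
    """L(x, y, W) computed by enumerating all size-k subsets of the training set."""
    subsets = list(combinations(range(len(X)), k))
    losses = {h: float(Lambda[y_true, vote(y, h)]) for h in subsets}
    # Loss-augmented inference: the worst offending set.
    worst = max(score(W, x, X, h) + losses[h] for h in subsets)
    # Inference restricted to sets that classify x correctly.
    best_correct = max(score(W, x, X, h) for h in subsets if losses[h] == 0.0)
    return worst - best_correct


# Toy data (hypothetical): the query x has true label 1.
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [0.0, 4.0]])
y = np.array([0, 0, 1, 1])
x = np.array([0.0, 1.8])
Lambda = 1.0 - np.eye(2)  # 0-1 loss

loss_euclid = surrogate_loss(np.eye(2), x, 1, X, y, 3, Lambda)
loss_scaled = surrogate_loss(np.diag([10.0, 1.0]), x, 1, X, y, 3, Lambda)
```

The loss is never negative, because the first maximum ranges over all sets, including the correct ones. Here it is positive under the Euclidean metric and drops to zero once the first axis is stretched enough that a correctly voting set becomes the best scoring one by a sufficient margin.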

So, although the loss above was arrived at by intuitive arguments, it turns out that our problem is an instance of a familiar type of problem: Latent Structured Prediction and hence the machinery for optimization there can be used here as well. The objective for us corresponds to:

\displaystyle \min_W \| W\|^2_{F} + C \sum_i (L(x_i, y_i,W))

Where \| \cdot \|_F is the Frobenius norm.

Note that the regularizer is convex, but the loss is not convex due to the subtraction of the max term, i.e. it is now a difference of convex functions, which means the concave-convex procedure may be used for optimization (although we just use stochastic gradient descent). Also note that the optimization at each step needs an efficient subroutine to determine the correct structured output (inference of the best set of neighbours) and the worst offending incorrect structured output (loss-augmented inference, i.e. finding the worst set of neighbors). It turns out that for this problem this is possible (although not presented here).
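A single stochastic (sub)gradient step might look roughly like the sketch below. Since S_W(x,h) is linear in W, with gradient equal to minus the sum of outer products (x - x_j)(x - x_j)^T over x_j in h, a subgradient of the loss is the outer-product sum over the best correct set minus the one over the worst offending set. The two sets are taken as given here (they would come from the inference routines), and projecting back onto the p.s.d. cone by clipping eigenvalues is my own assumption about how W \succeq 0 could be maintained. The step size and the value of C are arbitrary.

```python
import numpy as np


def outer_sum(x, X, h):
    """Sum over j in h of (x - x_j)(x - x_j)^T."""
    diffs = X[list(h)] - x
    return diffs.T @ diffs


def project_psd(W):
    """Project a matrix onto the p.s.d. cone by clipping negative eigenvalues."""
    W = (W + W.T) / 2.0
    vals, vecs = np.linalg.eigh(W)
    return (vecs * np.clip(vals, 0.0, None)) @ vecs.T


def sgd_step(W, x, X, h_hat, h_star, C=1.0, lr=0.1):
    """One step on ||W||_F^2 + C * L(x, y, W), given the two inferred sets.

    h_hat is the worst offending set, h_star the best set that classifies x correctly.
    """
    grad_loss = outer_sum(x, X, h_star) - outer_sum(x, X, h_hat)
    grad = 2.0 * W + C * grad_loss
    return project_psd(W - lr * grad)


# Toy data (hypothetical); the two sets are supplied by hand for this example.
X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [0.0, 4.0]])
x = np.array([0.0, 1.8])
h_hat, h_star = (0, 1, 2), (0, 2, 3)

W_new = sgd_step(np.eye(2), x, X, h_hat, h_star)

# Squared distances from x to points 1 and 3 under the updated metric.
d1 = float((x - X[1]) @ W_new @ (x - X[1]))
d3 = float((x - X[3]) @ W_new @ (x - X[3]))
```

The two sets differ only in points 1 and 3, so the update moves point 3 (in the correct set) closer to x than point 1 (in the offending set), while the regularizer shrinks W overall.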

It is interesting to think about how this approach extends to regression and to see how it works when the embeddings learnt are not linear.
