
Deep text-pair classification with Quora's 2017 question dataset

Quora recently released the first dataset from their platform: a set of 400,000 question pairs, with annotations indicating whether the questions request the same information. This data set is large, real, and relevant — a rare combination. In this post, I’ll explain how to solve text-pair tasks with deep learning, using both new and established tips and technologies.

The Quora dataset is an example of an important type of Natural Language Processing problem: text-pair classification. This type of problem is challenging because you usually can’t solve it by looking at individual words. No single word is going to tell you whether two questions are duplicates, or whether some headline is a good match for a story, or whether a valid link is probably pointing to the wrong page. You have to look at both items together. That’s hard — but it’s also rewarding. And models that do this are starting to get pretty good.

Update (March 1, 2017)

Updated experiments on this task can be found in our follow-up post.

The Quora dataset

Recent approaches to text-pair classification have mostly been developed on the Stanford Natural Language Inference (SNLI) corpus, prepared by Sam Bowman as part of his graduate research. The corpus provides over 500,000 pairs of short sentences, with human annotations indicating whether an entailment, contradiction or neutral logical relationship holds between the sentences. The SNLI dataset is over 100x larger than previous similar resources, allowing current deep-learning models to be applied to the problem. However, the data is also quite artificial: the texts are unlike any you’re likely to find in your applications.

Examples from the Quora data (label 1: the questions are duplicates; 0: they are not)

Question 1 | Question 2 | Label
Which is the best digital marketing institution in banglore? | Which is the best digital marketing institute in Pune? | 0
What’s causing someone to be jealous? | What can I do to avoid being jealous of someone? | 0
What are some special cares for someone with a nose that gets stuffy during the night? | How can I keep my nose from getting stuffy at night? | 1
How do you get deleted Instagram chats? | How can I view deleted Instagram dms? | 1

Examples from the SNLI corpus (label 1: entailment; 0: neutral; -1: contradiction)

Sentence 1 | Sentence 2 | Label
A person on horse jumps over a broken down airplane. | A person is training his horse for a competition. | 0
People listening to a choir in a catholic church. | Choir singing in church. | 1
A person on a bike is waiting while the light is green. | Bicyclists waiting at an intersection. | 0
Bicyclists waiting at an intersection. | The bicyclists ride through the mall on their bikes. | -1

When I first used the SNLI data, I was concerned that the limited vocabulary and relatively literal sentences made the problem unrealistically easy. The Quora data gives us a fantastic chance to check our progress: are the models developed on the SNLI data really useful on a real-world task, or did the artificial data lead us to draw incorrect conclusions about how to build this type of model?

The question of how idealised NLP experiments should be is not new. However, it’s rare to have such a good opportunity to examine the reliability of our methodologies. Was the SNLI corpus too artificial? If so, it will have misled us about how to solve a real task, such as the one posed by the Quora data. The Quora data is of a similar size, and it comes at just the right time. It will be interesting to see how this looks over the next few months. So far, it seems like the conclusions from the SNLI corpus are holding up quite well.

A neural bag-of-words model for text-pair classification

When designing a neural network for a text-pair task, probably the most important decision is whether you want to represent the meanings of the texts independently, or jointly. An independent representation means that the network can read a text in isolation, and produce a vector representation for it. This is great if you know you’ll need to make lots of comparisons over the same texts, for instance if you want to find their pairwise similarities. However, reading the sentences independently makes the text-pair task more difficult. Models which read the sentences together before reducing them to vectors have an accuracy advantage.
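A minimal NumPy sketch of why independent encodings suit many-to-many comparisons. It assumes a hypothetical random embedding table and uses plain mean pooling as a stand-in encoder (not the model described in this post): each text is encoded once, and every pairwise similarity then reuses those vectors in a single matrix product.

```python
import numpy as np

def encode(token_ids, embeddings):
    """Encode one text in isolation: here, just the mean of its word vectors."""
    return embeddings[token_ids].mean(axis=0)

def pairwise_cosine(vectors):
    """Cosine similarity between every pair of rows, via one matrix product."""
    unit = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    return unit @ unit.T

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(1000, 64))  # hypothetical: 1000 words, 64-d vectors
texts = [rng.integers(0, 1000, size=n) for n in (5, 8, 3, 12)]

# Each text is read once, on its own...
encoded = np.stack([encode(t, embeddings) for t in texts])
# ...and all 4 x 4 comparisons reuse those four vectors.
similarities = pairwise_cosine(encoded)
```

A joint model would instead have to run once per pair, which is what buys it the accuracy advantage at a higher cost.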

I’ve previously described a model that reads sentences jointly: Parikh et al.’s decomposable attention model. In this post I’ll describe a very simple sentence encoding model, using a so-called “neural bag-of-words”. The model is implemented using Thinc, a small library of NLP-optimized machine learning functions being developed for use in spaCy. While Thinc isn’t yet fully stable, I’m already finding it quite productive, especially for small models that should run well on CPU.

Text-pair classification with Thinc

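with Model.define_operators({'>>': chain, '**': clone, '|': concatenate}):
    # Encode one sentence: embed each word, project to width dimensions,
    # then concatenate the mean-pooled and max-pooled vectors.
    sent2vec = (
        flatten_add_lengths
        >> with_getitem(0, StaticEmbed(get_vectors, width))
        >> (mean_pool | max_pool)
    )
    # Encode both sentences with the same sent2vec, concatenate the results,
    # then apply depth Maxout layers and a two-class Softmax.
    model = (
        ((Arg(0) >> sent2vec) | (Arg(1) >> sent2vec))
        >> Maxout(width) ** depth
        >> Softmax(2)
    )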

First, we fetch a pre-trained “word embedding” vector for each word in the sentence. The static embeddings are quite long, and it’s useful to learn to reweight their dimensions, so we learn a projection matrix that maps the embedded vectors down to width dimensions (the width argument in the code above).
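The embedding step can be sketched in NumPy as below. The vocabulary size, the 300-dimensional static vectors and the random initialisation are all hypothetical; the point is that the pre-trained table is only looked up, while the (n_static, width) projection is the part that would be learned.

```python
import numpy as np

rng = np.random.default_rng(0)
n_vocab, n_static, width = 5000, 300, 128  # hypothetical sizes

# Pre-trained vectors: looked up, but never updated during training.
static_vectors = rng.normal(size=(n_vocab, n_static)).astype("float32")
# Projection to be learned: reweights the static dimensions and maps them to width.
projection = rng.normal(scale=n_static ** -0.5, size=(n_static, width)).astype("float32")

def embed(token_ids):
    """Map a sentence of word IDs to an (n_words, width) array."""
    return static_vectors[token_ids] @ projection

sentence = np.array([12, 408, 7, 3051])  # hypothetical word IDs
embedded = embed(sentence)
```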

This gives us two 2D arrays, one per sentence, with one row per word. We want to learn a single categorical label for the pair of questions, so we need to reduce each array to a single vector. There are a variety of pooling operations that people use to do this. I find it works well to use multiple pooling methods and concatenate the results. In the code above, I’m creating vectors for the elementwise averages and maximums (“mean pooling” and “max pooling” respectively), and concatenating the results.
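A minimal sketch of the mean-and-max pooling step, written in plain NumPy rather than with Thinc’s mean_pool and max_pool layers. For a sentence whose word vectors have width M, it returns a single vector of length 2M:

```python
import numpy as np

def mean_max_pool(word_vectors):
    """Reduce an (n_words, width) array to one vector of length 2 * width."""
    return np.concatenate([word_vectors.mean(axis=0), word_vectors.max(axis=0)])

sentence = np.array([[1.0, -2.0],
                     [3.0,  0.0],
                     [2.0,  5.0]])
pooled = mean_max_pool(sentence)  # mean = [2.0, 1.0], max = [3.0, 5.0]
```

The mean summarises the whole sentence, while the max keeps the strongest signal in each dimension, so the two halves carry complementary information.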

We create one such vector for each sentence and concatenate the two. The result is fed forward into a deep Maxout network, and a Softmax layer makes the prediction. The accuracies the neural bag-of-words model achieves on the two datasets are given in the Results section below.
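Putting the pieces together, the model’s forward pass can be approximated in NumPy as follows. This is a sketch under assumptions: the weights are random and untrained, the number of maxout pieces (2) is a guess, and the input sentences are assumed to be already embedded and projected to width. It uses a single Maxout layer, matching the depth of 1 reported in the Results section.

```python
import numpy as np

rng = np.random.default_rng(0)
width, pieces = 128, 2  # width as in the post; 2 maxout pieces is an assumption

def sent2vec(word_vectors):
    # Mean and max pooling over the words, concatenated: length 2 * width.
    return np.concatenate([word_vectors.mean(axis=0), word_vectors.max(axis=0)])

# Maxout layer: input is both sentence vectors (4 * width), output is width.
W_maxout = rng.normal(scale=0.05, size=(width, pieces, 4 * width))
b_maxout = np.zeros((width, pieces))
# Softmax layer over the two classes (duplicate / not duplicate).
W_softmax = rng.normal(scale=0.05, size=(2, width))
b_softmax = np.zeros(2)

def predict(sent1, sent2):
    pair = np.concatenate([sent2vec(sent1), sent2vec(sent2)])  # (4 * width,)
    hidden = (W_maxout @ pair + b_maxout).max(axis=-1)         # (width,)
    scores = W_softmax @ hidden + b_softmax                    # (2,)
    exp = np.exp(scores - scores.max())                        # numerically stable softmax
    return exp / exp.sum()

# Two hypothetical sentences of 6 and 9 words.
probs = predict(rng.normal(size=(6, width)), rng.normal(size=(9, width)))
```

Because the weights are untrained, the probabilities themselves are meaningless; the sketch only shows how the shapes flow through the network.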

Digression: Thinc, spaCy’s machine learning library

Thinc works a little differently from most neural network libraries. There’s no computational graph abstraction: we don’t compile your computations, we just execute them. To compute the backward pass, layers return a callback. To illustrate, imagine we have the following implementation of an affine layer, written as a closure:

Callbacks for backward pass in Thinc

import numpy as np

def Affine(n_out, n_in):
    # Glorot-style uniform initialisation, centred on zero
    scale = np.sqrt(6. / (n_out + n_in))
    W = np.random.uniform(-scale, scale, (n_out, n_in))
    b = np.zeros((n_out,))
    def forward(inputs):
        outputs = W.dot(inputs) + b
        def backward(d_outputs, optimizer):
            # Gradient w.r.t. the inputs, computed before W is updated
            d_inputs = W.T.dot(d_outputs)
            d_W = np.outer(d_outputs, inputs)
            d_b = d_outputs
            optimizer(W, d_W)
            optimizer(b, d_b)
            return d_inputs
        return outputs, backward
    return forward

The weights of the layer, W and b, are private: they’re internal details of the layer that live in the enclosing function’s scope. Affine returns its forward function, which references the enclosed weights. The forward function returns the output together with the callback backward, which can then be used to complete the backward pass:

affine = Affine(10, 5)
X2, bp_X2 = affine(X1)  # X1 has n_in = 5 entries; X2 has n_out = 10
# Later, once we have d_X2, the gradient of the loss with respect to X2:
d_X1 = bp_X2(d_X2, adam_solver)

Because every layer has the same simple signature, it’s easy to write helper functions that compose layers in various ways. This also makes custom data flows straightforward: you can have whatever types you want flowing through the model, as long as you define both the forward and the backward pass.
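To show why the shared signature matters, here is a self-contained NumPy sketch of a hypothetical chain combinator built on the same callback convention, together with an Affine layer and a Relu layer written in the same closure style. It illustrates the pattern and is not Thinc’s actual chain implementation; the rng argument and the frozen optimizer are additions for the demo. A finite-difference check confirms that the composed backward pass returns the right input gradient.

```python
import numpy as np

def Affine(n_out, n_in, rng):
    W = rng.normal(scale=0.1, size=(n_out, n_in))
    b = np.zeros(n_out)
    def forward(x):
        y = W @ x + b
        def backward(d_y, optimizer):
            d_x = W.T @ d_y                 # use W before it is updated
            optimizer(W, np.outer(d_y, x))  # gradient of W
            optimizer(b, d_y)               # gradient of b
            return d_x
        return y, backward
    return forward

def Relu():
    def forward(x):
        mask = x > 0
        def backward(d_y, optimizer):
            return d_y * mask               # no weights, nothing to update
        return x * mask, backward
    return forward

def chain(*layers):
    """Compose layers into one layer with the same (outputs, backward) signature."""
    def forward(x):
        callbacks = []
        for layer in layers:
            x, callback = layer(x)
            callbacks.append(callback)
        def backward(d_y, optimizer):
            for callback in reversed(callbacks):
                d_y = callback(d_y, optimizer)
            return d_y
        return x, backward
    return forward

def frozen(param, grad):
    """Optimizer that leaves the weights alone (handy for gradient checks)."""

rng = np.random.default_rng(0)
model = chain(Affine(8, 5, rng), Relu(), Affine(3, 8, rng))
x = rng.normal(size=5)
y, backprop = model(x)
d_x = backprop(np.ones_like(y), frozen)  # gradient of y.sum() w.r.t. x

# Central finite differences of the same quantity, one input dimension at a time.
eps = 1e-6
numeric = np.array([
    (model(x + eps * e)[0].sum() - model(x - eps * e)[0].sum()) / (2 * eps)
    for e in np.eye(5)
])
```

Note that chain itself never needs to know what the layers compute: it only stores the callbacks and replays them in reverse order.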

Results

The neural bag-of-words isn’t the most satisfying model, but it’s a good baseline to compute — and as always, it’s important to steel-man the baseline, and compute the best version of the idea possible. I recommend always trying the mean and max pooling trick — I’ve yet to find a task where it doesn’t perform at least as well as mean or max pooling alone, and it usually does at least a little better.

Pooling | Mean | Max | Mean and Max
Accuracy, Quora (%) | 80.9 | 82.3 | 82.8
Accuracy, SNLI 2-class (%) | 85.1 | 88.6 | 88.5

Width was set to 128, and depth was set to 1 (i.e. only one Maxout layer was used before the Softmax). I didn’t use dropout because there are so few parameters in the model: the model being trained is less than 1 MB, because we’re not updating the vectors. Batch size was set to 1 initially, and increased by 0.1% each iteration to a maximum of 256. I’m planning to write this trick up in a subsequent post, as it’s been working quite well.
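The batch-size schedule can be written as a small generator. This is a sketch that assumes the 0.1% growth compounds each iteration and that the batch size is the integer part of the running value. Under those assumptions the batch size stays at 1 for roughly the first 700 iterations and reaches the cap of 256 after about 5,500.

```python
import itertools

def batch_sizes(start=1.0, rate=1.001, maximum=256):
    """Yield batch sizes starting at `start`, growing by 0.1% per iteration
    (compounding, an assumption), capped at `maximum`."""
    size = start
    while True:
        yield int(size)
        size = min(size * rate, maximum)

sizes = list(itertools.islice(batch_sizes(), 10000))
```

Tiny early batches give many noisy updates while the model is far from a good solution; larger later batches give smoother gradients and better throughput.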

Negative result: Maxout Window Encoding

Update (March 1, 2017)

The negative result here turned out to be due to a bug. In updated experiments the Maxout Window Encoding helps as expected.

I also tried models which encoded a limited amount of positional information, using a convolutional layer. There have been many proposals for this sort of “poor man’s” BiLSTM lately. My new go-to solution along these lines is a layer I call Maxout Window Encoding (MWE). It’s very simple: for each word i in the sentence, we form a new vector by concatenating the vectors for words i-1, i and i+1. If our sentence were N words long and our vectors were M wide, this step would take in an (N, M) matrix and return an (N, 3*M) matrix. We then use a maxout layer to map the concatenated 3*M-length vectors back down to M-length vectors.
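The first half of an MWE block, building the (i-1, i, i+1) window for every word, takes only a few lines of NumPy. Zero-padding at the sentence boundaries is an assumption here, since the text doesn’t say how the first and last words are handled.

```python
import numpy as np

def window_concat(X):
    """(N, M) -> (N, 3*M): row i becomes [x[i-1], x[i], x[i+1]].
    Neighbours missing at the sentence edges are zero-padded (an assumption)."""
    padded = np.pad(X, ((1, 1), (0, 0)))  # one zero row before and after
    return np.hstack([padded[:-2], padded[1:-1], padded[2:]])

X = np.arange(12, dtype=float).reshape(4, 3)  # 4 words, 3-d vectors
trigrams = window_concat(X)                   # shape (4, 9)
```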

The MWE layer has the same aim as the BiLSTM: extract better word features. Most NLP neural networks start with an embedding layer. After this layer, your word features are position-independent: the vector for the word “duck” is always the same, no matter what words surround it. We know this is bad — we know the meaning of the word “duck” does change depending on its context. There’s clearly an opportunity to improve our features here — to feed better information about the input upwards into the next layer.

The figure above shows how a single MWE block rewrites the vector for each word given evidence for the two words immediately surrounding it. You can think of the output as trigram vectors — they’re built on the information from a three-word window. By simply adding another layer, we’ll get vectors computed from 5-grams — the receptive field widens with each layer we go deeper.
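The widening receptive field can be checked by tracking which input positions each output position depends on. With one window-of-three layer per level of depth, depth d covers 2d + 1 words (ignoring the sentence edges). This sketch tracks positions only; it doesn’t compute any vectors.

```python
def receptive_fields(n_words, depth):
    """Input positions that can influence each output position after
    `depth` stacked window-of-three layers."""
    fields = [{i} for i in range(n_words)]
    for _ in range(depth):
        fields = [
            set().union(*(fields[j] for j in (i - 1, i, i + 1) if 0 <= j < n_words))
            for i in range(n_words)
        ]
    return fields

print(receptive_fields(10, 2)[5])  # {3, 4, 5, 6, 7}: a 5-gram around word 5
```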

For the MWE unit to work, it needs to learn a non-linear mapping from a trigram down to a shorter vector. You could use any non-linearity here, but I’ve found maxout to work quite well. The logic is that adding capacity to the layer by increasing the width M is quite expensive, because the layer’s weight matrix will be (M, 3*M). The maxout unit lets us add capacity along a separate dimension, the number of pieces, instead. I usually use two or three pieces.
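The maxout half of the block can be sketched as below, with weights of shape (M, pieces, 3*M) so that each output unit takes the maximum over its pieces. The sizes (M = 128, three pieces) and the initialisation are illustrative assumptions. The weight counts at the end make the capacity argument concrete: each extra piece adds 3M² weights to this layer, while doubling M quadruples the layer and also widens every layer that consumes its output.

```python
import numpy as np

def maxout(X, W, b):
    """Maxout layer. X: (N, n_in); W: (n_out, pieces, n_in); b: (n_out, pieces).
    Each output unit takes the max over its `pieces` linear functions."""
    scores = np.einsum("ni,opi->nop", X, W) + b  # (N, n_out, pieces)
    return scores.max(axis=-1)                    # (N, n_out)

rng = np.random.default_rng(0)
N, M, pieces = 4, 128, 3                          # hypothetical sizes
trigram_vectors = rng.normal(size=(N, 3 * M))     # output of the window step
W = rng.normal(scale=(3 * M) ** -0.5, size=(M, pieces, 3 * M))
b = np.zeros((M, pieces))
word_vectors = maxout(trigram_vectors, W, b)      # back to (N, M)

# Weight counts (biases ignored): pieces grow the layer linearly,
# while doubling the width M quadruples it.
params = {p: p * M * 3 * M for p in (1, 2, 3)}    # 49152, 98304, 147456
wider = 1 * (2 * M) * 3 * (2 * M)                 # 196608
```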

The CNN tagger example in the Thinc repository provides a simple proof of concept. The example is a straightforward tagging model, trained and evaluated on the Ancora Spanish corpus. The model receives only word IDs as input (no sub-word features), and words with frequency below 10 are labelled unknown. No pre-trained vectors are used.

Depth | 0 | 1 | 2 | 3 | 4 | 5
Accuracy (%) | 83.8 | 93.2 | 93.9 | 94.1 | 93.9 | 93.9
Train (seconds) | 91 | 44 | 60 | 80 | 91 | 118
Run (words/second) | 1,800,000 | 1,300,000 | 900,000 | 720,000 | 650,000 | 330,000

At depth 0, the model can only learn one tag per word type, because it has no contextual information. Each layer of depth makes the model sensitive to a wider field of context: accuracy jumps sharply at depth 1, improves slightly up to depth 3, and then plateaus.

However, what worked for tagging and intent detection proved surprisingly ineffective at text-pair classification. This matches previous reports I’ve heard about BiLSTM being relatively ineffective in various models developed for the SNLI task. I still don’t have a good intuition for why this might be so.

MWE depth | 0 | 1 | 2 | 3
Accuracy, Quora (%) | 82.8 | 82.6 | 82.8 | 82.6
Accuracy, SNLI 2-class (%) | 88.5 | 86.9 | 86.5 | 86.8

Summary

A lot of interesting functionality can be implemented using text-pair classification models. The technology is still quite young, so the applications haven’t been explored well yet. We’ve had good techniques for classifying single texts for some time — but the ability to accurately model the relationships between texts is fairly new. I’m looking forward to seeing what people build with this.

In the meantime, we’re working on an interactive demo to explore different models trained on the Quora data set and the SNLI corpus.
