
Pseudo-rehearsal: A simple solution to catastrophic forgetting for NLP

by Matthew Honnibal on

Sometimes you want to fine-tune a pre-trained model to add a new label or correct some specific errors. This can introduce the "catastrophic forgetting" problem. Pseudo-rehearsal is a good solution: use the original model to label examples, and mix them through your fine-tuning updates.

The catastrophic forgetting problem occurs when you optimise two learning problems in succession, with the weights from the first problem used as part of the initialisation for the weights of the second problem. A lot of work has gone into designing optimisation algorithms that are less sensitive to initialisation. Ideally, our optimisers would be so good that they'd always find the same – optimal – solution to a given problem, no matter how the weights are initialised. This isn't true, but it's something we're aiming for. This means that if you optimise for two problems in succession, catastrophic forgetting is what should happen.
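A toy illustration of the point above, using a convex problem rather than a neural network: linear least squares has a unique optimum, so plain gradient descent on the second problem reaches the same weights whether it starts from the first problem's solution or from zero. The helper names and data are hypothetical; this is a sketch of the argument, not anything from spaCy.

```python
import numpy as np

def fit_least_squares(X, y, w_init, lr=0.1, steps=2000):
    """Plain gradient descent on mean squared error, starting from w_init."""
    w = w_init.copy()
    for _ in range(steps):
        grad = 2.0 * X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w

def mse(X, y, w):
    return float(np.mean((X @ w - y) ** 2))

rng = np.random.default_rng(0)
X_a, X_b = rng.normal(size=(50, 3)), rng.normal(size=(50, 3))
y_a = X_a @ np.array([1.0, -2.0, 0.5])   # first problem
y_b = X_b @ np.array([-1.0, 0.0, 3.0])   # second problem

w_a = fit_least_squares(X_a, y_a, np.zeros(3))
# Second problem, initialised from the first problem's solution...
w_b_from_a = fit_least_squares(X_b, y_b, w_a)
# ...and from scratch, for comparison.
w_b_from_zero = fit_least_squares(X_b, y_b, np.zeros(3))

print(mse(X_a, y_a, w_a))         # near zero: problem A is solved
print(mse(X_a, y_a, w_b_from_a))  # large: nothing of problem A survives
print(np.allclose(w_b_from_a, w_b_from_zero, atol=1e-6))
```

A good optimiser makes the initialisation irrelevant, which is exactly why the first problem is "forgotten". Deep networks are not convex, so the effect is less clean in practice, but the direction is the same.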

This point has been well made by Hal Daumé in a blog post, and reiterated more recently on Twitter by Jason Eisner. Yoav Goldberg also discusses the problem in his book, with more detail about smarter techniques for using pre-trained vectors.

Catastrophic forgetting problems have become more relevant for spaCy users lately, because spaCy v2's part-of-speech, named entity, syntactic dependency and sentence segmentation models all share an input representation, produced by a convolutional neural network. This lets the various models share most of their weights, making the total model very small – the latest release is only 18MB, while the previous linear model was almost 1GB. The multi-task input representation can also be used for other tasks, such as text classification and semantic similarity, via the doc.tensor attribute.

spaCy v2.0.0a10: To help you avoid the catastrophic forgetting problem, the latest spaCy v2.0 alpha model mixes the multi-task CNN with local CNNs, specific to each task. This lets you update a task in isolation, without writing to the shared component.
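To see why a shared input representation couples the tasks, here is a small numpy sketch, not spaCy's actual architecture: one shared layer (a stand-in for the CNN) feeds two task-specific heads. Only the tagger head receives a loss. The `update_shared` argument is a hypothetical mirror of the flag discussed in this post; when it is off, the NER outputs cannot move, and when it is on, they drift even though the NER head is never touched.

```python
import numpy as np

rng = np.random.default_rng(0)
W_shared = rng.normal(scale=0.5, size=(4, 8))  # shared representation (toy)
W_tag = rng.normal(scale=0.5, size=(8, 3))     # tagger-specific head
W_ner = rng.normal(scale=0.5, size=(8, 2))     # NER-specific head

def forward(x, W_s, W_head):
    h = np.tanh(x @ W_s)
    return h, h @ W_head

def update_tagger(x, target, W_s, W_t, lr=0.05, update_shared=False):
    """One squared-error gradient step for the tagger only.
    Returns new (W_s, W_t); the NER head is never modified."""
    h, out = forward(x, W_s, W_t)
    d_out = out - target
    grad_t = h.T @ d_out
    grad_s = x.T @ ((d_out @ W_t.T) * (1.0 - h ** 2))  # backprop through tanh
    new_t = W_t - lr * grad_t
    new_s = W_s - lr * grad_s if update_shared else W_s.copy()
    return new_s, new_t

x = rng.normal(size=(5, 4))
target = rng.normal(size=(5, 3))
_, ner_before = forward(x, W_shared, W_ner)

drift = {}
for update_shared in (False, True):
    W_s, W_t = W_shared, W_tag
    for _ in range(20):
        W_s, W_t = update_tagger(x, target, W_s, W_t, update_shared=update_shared)
    _, ner_after = forward(x, W_s, W_ner)
    drift[update_shared] = float(np.abs(ner_after - ner_before).max())
    print("update_shared =", update_shared, "max NER output change:", drift[update_shared])
```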

However, sharing the weights between all these models sets a subtle trap. Let's say you're parsing short commands, so you have a lot of examples where you know the first word is an imperative verb. The default spaCy model performs poorly on this type of input, so we want to update the model on some examples of the type of user-command text we'll be processing.

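import spacy

nlp = spacy.load('en_core_web_sm')
doc = nlp(u'search for pictures of playful rodents')
# Inspect the predicted tag, dependency label and head index for each token
print([(w.text, w.tag_, w.dep_, w.head.i) for w in doc])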

This parse is wrong – it's analysed "search" as a noun, where it should be a verb. If all you know is that the first word of the sentence should be a verb, you can still use that to update spaCy's model. To update the model, we pass a Doc instance and a GoldParse instance to the nlp.update() method:

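from spacy.gold import GoldParse

# Supervise only the first token's tag; None means "no information here"
new_tags = [None] * len(doc)
new_tags[0] = 'VBP'
gold = GoldParse(doc, tags=new_tags)
# update_shared=True (spaCy v2.0.0a10) also writes to the shared CNN
nlp.update(doc, gold, update_shared=True)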

The None values indicate there's no supervision on those tags, so gradients will be 0 for those predictions. There are also no labels for the dependency parse or the entity recognizer, so the weights for those models won't be updated. However, all models share the same input representation, so if that representation is updated, all models are potentially affected. To address that problem, spaCy v2.0.0a10 introduces a new flag, update_shared, which is set to False by default. The example above sets it to True, so the shared representation is updated as well.
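The "None means zero gradient" behaviour can be illustrated with a masked softmax cross-entropy. This is a generic numpy sketch, not spaCy's internal code; the three-tag tag set and the logit values are made up for the example.

```python
import numpy as np

def masked_softmax_xent_grad(logits, gold):
    """Gradient of softmax cross-entropy w.r.t. the logits, one row per token.
    gold[i] is a class index, or None for "no supervision on this token"."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    grad = np.zeros_like(probs)
    for i, label in enumerate(gold):
        if label is None:
            continue  # unsupervised token: its gradient row stays zero
        grad[i] = probs[i]
        grad[i, label] -= 1.0
    return grad

TAGS = ["NN", "VBP", "IN"]           # hypothetical tiny tag set
logits = np.array([[2.0, 0.5, 0.1],  # "search" (model currently prefers NN)
                   [0.3, 0.2, 1.5],  # "for"
                   [1.0, 0.9, 0.2]]) # "pictures"
gold = [TAGS.index("VBP"), None, None]  # only the first tag is known
grad = masked_softmax_xent_grad(logits, gold)
print(grad)
```

Only the first row is non-zero, and its negative entry in the VBP column means a gradient step raises that logit. The other tokens contribute nothing to the update.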

If we make a few updates on this single example, we'll get a model that tags it correctly. However, from a single example, there's no way for the model to guess what level of generality it should learn at. Are all words now tagged VBP? All first-words of the sentence? All instances of search? We need to give the model more information about the solution we're looking for, or the learning problem will be too unconstrained, and we'll be unlikely to get the solution we want.

To make the "forgetting" metaphor explicit here, we could say that the overall multi-task model started out "knowing" how to tag entities and produce dependency parses for a range of genres of written English. Then we focussed on a few more specific corrections, but this caused the model to lose the more general capabilities. This metaphor makes the problem seem surprising: why should our AI be so stupid and brittle? This is the point at which the metaphor has outlived its usefulness, and we need to think more precisely about what's going on.

When we call nlp.update(), we ask the model to produce an analysis given its current weights. An error gradient is then calculated for each subtask, and the weights are updated via backpropagation. Essentially, we push the weights around until we get a set of weights that produce analyses where our error gradients are near zero. Any set of weights that produces zero loss is stable.

It's not necessarily useful to think in terms of the model "remembering" or "forgetting" things. It's simply optimising the function you tell it to optimise – sometimes well, sometimes poorly. Sometimes we have reason to believe that the solution that optimises one objective will be quite good at another objective. But if we don't encode this restriction explicitly, it's difficult to guarantee.

One way to preserve the previous behaviour is to encode a bias against changing the parameters too much. However, this type of regularisation penalty isn't always a good approximation of what we want. In a deep neural network, the relationship between the model's weights and its prediction behaviours is non-linear. Very deep networks may be downright chaotic. What we actually care about are the outputs, not the parameter values – so that's how we should frame our objective. As models become more complex and less linear, it's better to avoid trying to guess what the parameters ought to look like.

Regularization is still good for embeddings: Embedding tables define a vector space, so there's a linear relationship between changes to the parameter values and changes to the solution. In this situation, it makes sense to penalize the L2 norm of the divergence from the initial values.
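A minimal numpy sketch of that embedding penalty, assuming a plain gradient-descent loop and a toy quadratic task loss (the `target` table stands in for wherever the new task alone would pull the vectors). With loss ||E - target||² + λ||E - E_init||², the optimum is a weighted average of the two tables, which the loop recovers.

```python
import numpy as np

def l2_to_init_penalty(E, E_init, strength):
    """strength * ||E - E_init||^2 and its gradient w.r.t. E."""
    diff = E - E_init
    return strength * float(np.sum(diff ** 2)), 2.0 * strength * diff

rng = np.random.default_rng(0)
E_init = rng.normal(size=(5, 4))  # pre-trained embedding table (toy)
target = rng.normal(size=(5, 4))  # where the new task alone would pull it (toy)
strength = 0.5

E = E_init.copy()
for _ in range(500):
    task_grad = 2.0 * (E - target)  # gradient of ||E - target||^2
    _, reg_grad = l2_to_init_penalty(E, E_init, strength)
    E -= 0.1 * (task_grad + reg_grad)

# Closed-form optimum: a blend of the new target and the initial vectors
closed_form = (target + strength * E_init) / (1.0 + strength)
print(np.abs(E - closed_form).max())
```

Because the embedding space is linear, "stay close in parameter space" really does mean "stay close in output space" here, which is why the penalty is sensible for embeddings but a poor proxy for deep layers.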

All this leads to a very simple recommendation to address the "catastrophic forgetting" problem. When we start fine-tuning the model, we're hoping to get a solution that's correct on the new training examples, while producing output that's similar to the original. This is easy: we can generate as much of the original output as we want. We just have to create some mixture of that original output and the new examples. Unsurprisingly, this is not a new suggestion.

Pseudo-rehearsal

import random
from cytoolz import partition_all
from spacy.gold import GoldParse

# revision_texts (raw strings) and fine_tune_data (a list of (doc, gold)
# pairs) are assumed to be defined already.
revision_data = []

# Apply the initial model to raw examples. You'll want to experiment
# with finding a good number of revision texts. It can also help to
# filter out some data.
for doc in nlp.pipe(revision_texts):
    tags = [w.tag_ for w in doc]
    heads = [w.head.i for w in doc]
    deps = [w.dep_ for w in doc]
    entities = [(e.start_char, e.end_char, e.label_) for e in doc.ents]
    revision_data.append((doc, GoldParse(doc, tags=tags, heads=heads,
                                         deps=deps, entities=entities)))

# Now shuffle the previous behaviour into the new fine-tuning data, and
# update with them together. You might want to upsample the fine-tuning
# examples (e.g. include 5 copies of it). This lets you use a better
# variety of revision data without changing the ratio of revision : tuning
# data.
n_epoch = 5
batch_size = 32
for i in range(n_epoch):
    examples = revision_data + fine_tune_data
    random.shuffle(examples)  # mix revision and fine-tuning examples
    losses = {}
    for batch in partition_all(batch_size, examples):
        docs, golds = zip(*batch)
        nlp.update(docs, golds, losses=losses)

A crucial detail in this process is that the "revision exercises" that you're mixing into the new material must not be produced by the weights you're currently optimising. You should keep the model that generates the revision material static. Otherwise, the model can stabilise on trivial solutions. If you're streaming the examples, you'll need to hold two copies of the model in memory. Alternatively, you can pre-parse a batch of text, and then use the annotations to stabilise your fine-tuning.
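The whole recipe can be sketched end to end on a toy problem in numpy rather than spaCy. Everything here is a hypothetical stand-in: a hand-set 3-class linear "teacher" over features (x, y, flag), a tight cluster of corrections that the teacher calls class 0 but we relabel as class 1, and a flag feature marking the new kind of input (playing the role of "short commands"). The revision labels are computed once from the frozen teacher before fine-tuning starts, as the paragraph above requires.

```python
import copy
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def predict(params, X):
    return softmax(X @ params["W"] + params["b"])

def train(params, X, y, lr=0.5, steps=500):
    """Full-batch gradient descent on mean softmax cross-entropy (on a copy)."""
    params = copy.deepcopy(params)
    onehot = np.eye(params["W"].shape[1])[y]
    for _ in range(steps):
        d = (predict(params, X) - onehot) / len(y)
        params["W"] -= lr * X.T @ d
        params["b"] -= lr * d.sum(axis=0)
    return params

rng = np.random.default_rng(0)
# Toy pre-trained teacher: class = angular sector of (x, y); ignores the flag.
angles = np.deg2rad([0.0, 120.0, 240.0])
teacher = {"W": np.vstack([np.cos(angles), np.sin(angles), np.zeros(3)]),
           "b": np.zeros(3)}

def sample(n, flag):
    pts = rng.normal(scale=1.5, size=(n, 2))
    return np.hstack([pts, np.full((n, 1), float(flag))])

# Fine-tuning data: corrections the teacher gets "wrong" (it says class 0).
X_ft = np.hstack([rng.normal(loc=(2.0, 0.5), scale=0.1, size=(20, 2)),
                  np.ones((20, 1))])
y_ft = np.ones(20, dtype=int)

# Revision data: labelled ONCE by the frozen teacher, before any updates.
X_rev = sample(200, flag=0)
y_rev = predict(teacher, X_rev).argmax(axis=1)

plain = train(teacher, X_ft, y_ft)                 # fine-tuning alone
X_mix = np.vstack([X_rev] + [X_ft] * 5)            # upsample corrections 5x
y_mix = np.concatenate([y_rev] + [y_ft] * 5)
rehearsed = train(teacher, X_mix, y_mix)           # pseudo-rehearsal

X_test = sample(1000, flag=0)
y_test = predict(teacher, X_test).argmax(axis=1)
results = {}
for name, model in [("plain", plain), ("rehearsal", rehearsed)]:
    agree = float(np.mean(predict(model, X_test).argmax(axis=1) == y_test))
    fit = float(np.mean(predict(model, X_ft).argmax(axis=1) == y_ft))
    results[name] = (agree, fit)
    print(f"{name}: agreement with teacher {agree:.2f}, "
          f"fine-tune accuracy {fit:.2f}")
```

Both models learn the corrections, but only the rehearsed one keeps agreeing with the teacher on ordinary inputs. In this toy the flag feature gives the model an easy way to separate old and new behaviour; real data is rarely that clean, so treat it as an illustration of the objective rather than a benchmark.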

There's one improvement to this recipe that's still pending. At the moment spaCy treats the analyses provided by the teaching model the same as any other type of gold-standard data. This is less than ideal, because the models use log-loss. For the part-of-speech tagger, this means that an original prediction of "80% confidence tag is 'NN'" gets converted into "100% confidence tag is 'NN'". It would be better to either supervise with the distribution returned by the teaching model, or to use a log-loss.
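The difference between hard and soft supervision shows up directly in the log-loss gradient. In this numpy sketch (generic cross-entropy, not spaCy code; the teacher distribution is the 80% 'NN' example from above), a student that already matches the teacher still gets pushed towards 100% confidence by a one-hot target, but receives no gradient from the teacher's full distribution. Training on soft targets like this is the idea behind knowledge distillation.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def xent_grad(student_logits, target_dist):
    """Gradient of cross-entropy H(target, softmax(logits)) w.r.t. the logits."""
    return softmax(student_logits) - target_dist

teacher_probs = np.array([0.80, 0.15, 0.05])      # e.g. 80% confidence in 'NN'
hard_target = np.eye(3)[teacher_probs.argmax()]  # one-hot: "100% 'NN'"
student_logits = np.log(teacher_probs)           # student matches the teacher

hard_grad = xent_grad(student_logits, hard_target)    # still non-zero
soft_grad = xent_grad(student_logits, teacher_probs)  # effectively zero
print(hard_grad, soft_grad)
```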

It's common to use pre-trained models for computer vision and natural language processing. Image, video, text and audio inputs have rich internal structure that can be learned from large training samples and generalised across tasks. These pre-trained models are particularly useful if they can be "fine-tuned" on the specific problem of interest. However, the fine-tuning process can introduce the problem of "catastrophic forgetting": a solution is found that optimises the specific fine-tuning data, and the generalisation is lost.

Some people suggest regularisation penalties to address this problem. However, this encodes a preference for solutions that are close to the previous model in parameter-space, when what we really want are solutions that are close to the previous model in output space. Pseudo-rehearsal is a good way to achieve that: predict a number of examples with the initial model, and mix them through the fine-tuning data. This represents an objective for a model that behaves similarly to the pre-trained one, except on the fine-tuning data.

Matthew Honnibal
About the Author

Matthew is a leading expert in AI technology, known for his research, software and writings. He completed his PhD in 2009, and spent a further 5 years publishing research on state-of-the-art natural language understanding systems. Anticipating the AI boom, he left academia in 2014 to develop spaCy, an open-source library for industrial-strength NLP.