Sometimes you want to fine-tune a pre-trained model to add a new label or correct some specific errors. This can introduce the “catastrophic forgetting” problem. Pseudo-rehearsal is a good solution: use the original model to label examples, and mix them through your fine-tuning updates.
The catastrophic forgetting problem occurs when you optimise two learning problems in succession, with the weights from the first problem used as part of the initialisation for the weights of the second problem. A lot of work has gone into designing optimisation algorithms that are less sensitive to initialisation. Ideally, our optimisers would be so good that they’d always find the same – optimal – solution to a given problem, no matter how the weights are initialised. This isn’t true, but it’s something we’re aiming for. And it means that if you optimise for two problems in succession, catastrophic forgetting is what should happen: the first problem’s solution survives only as the initialisation of the second, and the better the optimiser, the less that initialisation matters to the final weights.
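As a toy illustration of that point (nothing to do with spaCy – just plain gradient descent on two made-up quadratic “tasks” sharing one weight vector), notice how the weights the optimiser ends up with after the second task carry no trace of the first:

import numpy as np

# Toy illustration: run gradient descent on "task A", then on "task B",
# starting task B from task A's solution. Because the optimiser reaches
# task B's optimum regardless of where it starts, task A is overwritten.
def optimise(weights, target, learn_rate=0.1, steps=200):
    for _ in range(steps):
        gradient = 2 * (weights - target)   # gradient of ||weights - target||^2
        weights = weights - learn_rate * gradient
    return weights

w = np.zeros(3)
w = optimise(w, target=np.array([1.0, 2.0, 3.0]))    # "task A"
w = optimise(w, target=np.array([-1.0, 0.0, 5.0]))   # "task B"
print(w)   # ~[-1.  0.  5.]: nothing of task A's solution remains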
This point has been well made by Hal Daumé in a blog post, and reiterated more recently on Twitter by Jason Eisner. Yoav Goldberg also discusses the problem in his book, with more detail about smarter techniques for using pre-trained vectors.
Multi-task learning in spaCy
Catastrophic forgetting problems have become more relevant for spaCy users lately, because spaCy v2’s part-of-speech, named entity, syntactic dependency and sentence segmentation models all share an input representation, produced by a convolutional neural network. This lets the various models share most of their weights, making the total model very small – the latest release is only 18MB, while the previous linear model was almost 1GB. The multi-task input representation can also be used for other tasks, such as text classification and semantic similarity, via the doc.tensor attribute.
However, sharing the weights between all these models sets a subtle trap. Let’s say you’re parsing short commands, so you have a lot of examples where you know the first word is an imperative verb. The default spaCy model performs poorly on this type of input, so we want to update the model on some examples of the type of user-command text we’ll be processing.
import spacy

nlp = spacy.load("en_core_web_sm")
doc = nlp(u"search for pictures of playful rodents")
spacy.displacy.serve(doc)
This parse is wrong – it’s analysed “search” as a noun, where it should be a verb. If all you know is that the first word of the sentence should be a verb, you can still use that to update spaCy’s model. To update the model, we pass a Doc instance and a GoldParse instance to the nlp.update() method:
from spacy.gold import GoldParse

new_tags = [None] * len(doc)
new_tags[0] = "VBP"
gold = GoldParse(doc, tags=new_tags)
nlp.update(doc, gold, update_shared=True)
The None values indicate there’s no supervision on those tags, so gradients will be 0 for those predictions. There are also no labels for the dependency parse or the entity recognizer, so the weights for those models won’t be updated. However, all models share the same input representation, so if that representation is updated, all models are potentially affected. To address that problem, spaCy v2.0.0a10 introduces a new flag, update_shared. This flag is set to False by default.
If we make a few updates on this single example, we’ll get a model that tags it correctly. However, from a single example, there’s no way for the model to guess what level of generality it should learn at. Are all words now tagged VBP? All first words of the sentence? All instances of “search”? We need to give the model more information about the solution we’re looking for, or the learning problem will be too unconstrained, and we’ll be unlikely to get the solution we want.
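For reference, here’s a minimal sketch of those repeated updates, reusing the nlp, doc and gold objects from the snippets above. The number of iterations is just a guess at “a few”:

# A minimal sketch, reusing `nlp`, `doc` and `gold` from above. A handful of
# repeated updates on the single example is typically enough to change the tag.
for i in range(10):
    nlp.update(doc, gold, update_shared=True)

doc = nlp(u"search for pictures of playful rodents")
print([(w.text, w.tag_) for w in doc])   # 'search' should now come out as VBP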
Beyond the metaphor
To make the “forgetting” metaphor explicit here, we could say that the overall multi-task model started out “knowing” how to tag entities and produce dependency parses for a range of genres of written English. Then we focussed on a few more specific corrections, but this caused the model to lose the more general capabilities. This metaphor makes the problem seem surprising: why should our AI be so stupid and brittle? This is the point at which the metaphor has outlived its usefulness, and we need to think more precisely about what’s going on.
When we call nlp.update(), we ask the model to produce an analysis given its current weights. An error gradient is then calculated for each subtask, and the weights are updated via backpropagation. Essentially, we push the weights around until we get a set of weights that produces analyses where our error gradients are near zero. Any set of weights that produces zero loss is stable.
It’s not necessarily useful to think in terms of the model “remembering” or “forgetting” things. It’s simply optimising the function you tell it to optimise – sometimes well, sometimes poorly. Sometimes we have reason to believe that the solution that optimises one objective will be quite good at another objective. But if we don’t encode this restriction explicitly, it’s difficult to guarantee.
One way to preserve the previous behaviour is to encode a bias against changing the parameters too much. However, this type of regularisation penalty isn’t always a good approximation of what we want. In a deep neural network, the relationship between the model’s weights and its prediction behaviours is non-linear. Very deep networks may be downright chaotic. What we actually care about are the outputs, not the parameter values – so that’s how we should frame our objective. As models become more complex and less linear, it’s better to avoid trying to guess what the parameters ought to look like.
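For concreteness, here’s a minimal sketch of that kind of penalty, with hypothetical names (this is not spaCy’s API, just the idea in the paragraph above): it biases the fine-tuned weights to stay close to the pre-trained weights in parameter space, which is exactly the approximation being questioned.

import numpy as np

# Minimal sketch of a "stay close to the pre-trained weights" penalty.
# Hypothetical names; not spaCy's API.
def loss_with_anchor(task_loss, weights, pretrained_weights, strength=0.01):
    # Penalise squared distance from the *pre-trained* weights, not from zero.
    penalty = strength * np.sum((weights - pretrained_weights) ** 2)
    return task_loss + penalty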
Pseudo-rehearsal
All this leads to a very simple recommendation to address the “catastrophic forgetting” problem. When we start fine-tuning the model, we’re hoping to get a solution that’s correct on the new training examples, while producing output that’s similar to the original. This is easy: we can generate as much of the original output as we want. We just have to create some mixture of that original output and the new examples. Unsurprisingly, this is not a new suggestion.
import random

from spacy.gold import GoldParse
from toolz import partition_all  # cytoolz also provides partition_all

revision_data = []
# Apply the initial model to raw examples. You'll want to experiment
# with finding a good number of revision texts. It can also help to
# filter out some data.
for doc in nlp.pipe(revision_texts):
    tags = [w.tag_ for w in doc]
    heads = [w.head.i for w in doc]
    deps = [w.dep_ for w in doc]
    entities = [(e.start_char, e.end_char, e.label_) for e in doc.ents]
    revision_data.append((doc, GoldParse(doc, tags=tags, heads=heads,
                                         deps=deps, entities=entities)))

# Now shuffle the previous behaviour into the new fine-tuning data, and
# update with them together. You might want to upsample the fine-tuning
# examples (e.g. include 5 copies of each). This lets you use a better
# variety of revision data without changing the ratio of revision : tuning
# data.
n_epoch = 5
batch_size = 32
for i in range(n_epoch):
    examples = revision_data + fine_tune_data
    losses = {}
    random.shuffle(examples)
    for batch in partition_all(batch_size, examples):
        docs, golds = zip(*batch)
        nlp.update(docs, golds, losses=losses)
A crucial detail in this process is that the “revision exercises” that you’re mixing into the new material must not be produced by the weights you’re currently optimising. You should keep the model that generates the revision material static. Otherwise, the model can stabilise on trivial solutions. If you’re streaming the examples, you’ll need to hold two copies of the model in memory. Alternatively, you can pre-parse a batch of text, and then use the annotations to stabilise your fine-tuning.
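Here’s a minimal sketch of that pre-parse variant, assuming you load a separate, static copy of the original model as the “teacher”; the helper name and model name are just placeholders:

import spacy
from spacy.gold import GoldParse

def make_revision_data(revision_texts, model_name="en_core_web_sm"):
    # Load a separate, static copy of the original model. This copy is only
    # used to generate the revision annotations and is never updated.
    teacher = spacy.load(model_name)
    data = []
    for doc in teacher.pipe(revision_texts):
        gold = GoldParse(doc,
                         tags=[w.tag_ for w in doc],
                         heads=[w.head.i for w in doc],
                         deps=[w.dep_ for w in doc],
                         entities=[(e.start_char, e.end_char, e.label_)
                                   for e in doc.ents])
        data.append((doc, gold))
    return data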
There’s one improvement to this recipe that’s still pending. At the moment spaCy treats the analyses provided by the teaching model the same as any other type of gold-standard data. This seems unideal, because the models use log-loss. For the part-of-speech tagger, this means that an original prediction of “80% confidence tag is ‘NN’” gets converted into “100% confidence tag is ‘NN’”. It would be better to supervise with the full distribution returned by the teaching model, rather than collapsing its best guess into a hard, 100%-confidence label.
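To make the difference concrete, here’s a small sketch with made-up numbers, comparing the cross-entropy against a hard (one-hot) target with the cross-entropy against the teaching model’s full distribution:

import numpy as np

# Made-up numbers: the teaching model's distribution over three tags for one
# token, versus the one-hot target the current recipe trains on.
teacher = np.array([0.80, 0.15, 0.05])   # soft target: the full distribution
hard = np.array([1.0, 0.0, 0.0])         # hard target: argmax only

student = np.array([0.70, 0.20, 0.10])   # some student prediction

def cross_entropy(target, pred):
    return -np.sum(target * np.log(pred))

print(cross_entropy(hard, student))      # loss if we treat the argmax as gold
print(cross_entropy(teacher, student))   # loss against the soft distribution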
Summary
It’s common to use pre-trained models for computer vision and natural language processing. Image, video, text and audio inputs have rich internal structure that can be learned from large training samples and generalised across tasks. These pre-trained models are particularly useful if they can be “fine-tuned” on the specific problem of interest. However, the fine-tuning process can introduce the problem of “catastrophic forgetting”: a solution is found that optimises the specific fine-tuning data, and the model’s more general capabilities are lost.
Some people suggest regularisation penalties to address this problem. However, this encodes a preference for solutions that are close to the previous model in parameter-space, when what we really want are solutions that are close to the previous model in output space. Pseudo-rehearsal is a good way to achieve that: predict a number of examples with the initial model, and mix them through the fine-tuning data. This represents an objective for a model that behaves similarly to the pre-trained one, except on the fine-tuning data.