I’m pleased to announce the 1.0 release of spaCy, the fastest NLP library in the world. By far the best part of the 1.0 release is a new system for integrating custom models into spaCy. This post introduces you to the changes, and shows you how to use the new custom pipeline functionality to add a Keras-powered LSTM sentiment analysis model into a spaCy pipeline.
The spaCy user survey has been full of great feedback about the library. The clearest finding has been the need for more tutorials. We’re currently working on a new and improved tutorials section for the site. We’re also prioritising tutorials for the new 1.0 functionality – like the new rule-based, entity-aware matcher, the model training APIs and the custom pipelines.
The custom pipelines are particularly exciting, because they let you hook your own deep learning models into spaCy. So, without further ado, here’s how to use Keras to train an LSTM sentiment analysis model and use the resulting annotations with spaCy.
How to add sentiment analysis to spaCy with an LSTM model using Keras
There are lots of great open-source libraries for researching, training and evaluating neural networks. However, the concerns of these libraries usually end at the point where you have an evaluation score and a model file. spaCy has always been designed to orchestrate multiple textual annotation models and help you use them together in your application. spaCy 1.0 now makes it much easier to calculate those annotations using your own custom models.
In this tutorial, we’ll be using Keras, as it’s the most popular deep learning library for Python. Let’s assume you’ve written a custom sentiment analysis model that predicts whether a document is positive or negative. Now you want to find which entities are commonly associated with positive or negative documents. Here’s a quick example of how that can look at runtime.
Runtime usage
```python
import collections

import spacy


def count_entity_sentiment(nlp, texts):
    '''Compute the net document sentiment for each entity in the texts.'''
    entity_sentiments = collections.Counter()
    for doc in nlp.pipe(texts, batch_size=1000, n_threads=4):
        for ent in doc.ents:
            entity_sentiments[ent.text] += doc.sentiment
    return entity_sentiments


def load_nlp(lstm_path, lang_id='en'):
    def create_pipeline(nlp):
        return [nlp.tagger, nlp.entity, SentimentAnalyser.load(lstm_path, nlp)]
    return spacy.load(lang_id, create_pipeline=create_pipeline)
```
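For instance, assuming you've already trained and saved the LSTM to a directory (the path and texts below are just placeholders), usage could look roughly like this:

```python
from pathlib import Path

# Hypothetical usage sketch: the model directory and texts are placeholders.
nlp = load_nlp(Path('/tmp/my_lstm_model'))
texts = [u"The pizza at Giuseppe's was amazing.",
         u"The service at Giuseppe's was terrible."]
for entity, sentiment in count_entity_sentiment(nlp, texts).most_common():
    print(entity, sentiment)
```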
All you have to do is pass a `create_pipeline` callback function to `spacy.load()`. The function should take a `spacy.language.Language` object as its only argument, and return a sequence of callables. Each callable should accept a `Doc` object, modify it in place, and return `None`.
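As a minimal illustration of that contract, a pipeline component can be any callable that annotates the `Doc` in place. The component below is a hypothetical toy, not part of spaCy:

```python
class ExclamationCounter(object):
    '''Toy component: count exclamation marks and store the result on the Doc.'''
    def __call__(self, doc):
        doc.user_data['n_exclamations'] = sum(token.text == u'!' for token in doc)
        return None  # Annotate in place; don't return a new object.


def create_pipeline(nlp):
    # Run the built-in tagger and entity recogniser, then the custom component.
    return [nlp.tagger, nlp.entity, ExclamationCounter()]
```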
Of course, operating on single documents is inefficient, especially for deep learning models. Usually we want to annotate many texts, and we want to process them in parallel. You should therefore ensure that your model component also supports a `.pipe()` method. The `.pipe()` method should be a well-behaved generator function that operates on arbitrarily large sequences. It should consume a small buffer of documents, work on them in parallel, and yield them one-by-one.
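A sketch of that contract, continuing the hypothetical toy component from above (not spaCy API, just the shape a component's `.pipe()` should take):

```python
import cytoolz


class BatchedComponent(object):
    '''Hypothetical component showing the .pipe() contract.'''
    def __call__(self, doc):
        doc.user_data['n_exclamations'] = sum(token.text == u'!' for token in doc)

    def pipe(self, docs, batch_size=1000):
        # Consume a small buffer of documents at a time.
        for minibatch in cytoolz.partition_all(batch_size, docs):
            # Any batched (e.g. GPU-friendly) work happens here, over the
            # whole minibatch; the documents are then yielded one-by-one.
            for doc in minibatch:
                self(doc)
                yield doc
```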
Custom Annotator Class
```python
import pickle

import cytoolz
import numpy
from keras.models import model_from_json


class SentimentAnalyser(object):
    @classmethod
    def load(cls, path, nlp, max_length=100):
        with (path / 'config.json').open() as file_:
            model = model_from_json(file_.read())
        with (path / 'model').open('rb') as file_:
            lstm_weights = pickle.load(file_)
        embeddings = get_embeddings(nlp.vocab)
        model.set_weights([embeddings] + lstm_weights)
        return cls(model, max_length=max_length)

    def __init__(self, model, max_length=100):
        self._model = model
        self.max_length = max_length

    def __call__(self, doc):
        X = get_features([doc], self.max_length)
        y = self._model.predict(X)
        self.set_sentiment(doc, y)

    def pipe(self, docs, batch_size=1000, n_threads=2):
        for minibatch in cytoolz.partition_all(batch_size, docs):
            Xs = get_features(minibatch, self.max_length)
            ys = self._model.predict(Xs)
            for i, doc in enumerate(minibatch):
                doc.sentiment = float(ys[i])
                yield doc

    def set_sentiment(self, doc, y):
        doc.sentiment = float(y[0])
        # Sentiment has a native slot for a single float.
        # For arbitrary data storage, there's:
        # doc.user_data['my_data'] = y


def get_features(docs, max_length):
    docs = list(docs)
    Xs = numpy.zeros((len(docs), max_length), dtype='int32')
    for i, doc in enumerate(docs):
        for j, token in enumerate(doc[:max_length]):
            Xs[i, j] = token.rank if token.has_vector else 0
    return Xs
```
By default, spaCy 1.0 downloads and uses the 300-dimensional GloVe common crawl vectors. It's also easy to replace these vectors with ones you've trained yourself, or to disable the word vectors entirely. If you've installed your word vectors into spaCy's `Vocab` object, here's how to use them in a Keras model:
Training with Keras
```python
import numpy
import spacy
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM, Dropout, Dense
from keras.optimizers import Adam


def train(train_texts, train_labels, dev_texts, dev_labels,
          lstm_shape, lstm_settings, lstm_optimizer, batch_size=100, nb_epoch=5):
    nlp = spacy.load('en', parser=False, tagger=False, entity=False)
    embeddings = get_embeddings(nlp.vocab)
    model = compile_lstm(embeddings, lstm_shape, lstm_settings)
    train_X = get_features(nlp.pipe(train_texts), lstm_shape['max_length'])
    dev_X = get_features(nlp.pipe(dev_texts), lstm_shape['max_length'])
    model.fit(train_X, train_labels, validation_data=(dev_X, dev_labels),
              nb_epoch=nb_epoch, batch_size=batch_size)
    return model


def compile_lstm(embeddings, shape, settings):
    model = Sequential()
    model.add(Embedding(embeddings.shape[0], embeddings.shape[1],
                        input_length=shape['max_length'],
                        trainable=False, weights=[embeddings]))
    model.add(Bidirectional(LSTM(shape['nr_hidden'])))
    model.add(Dropout(settings['dropout']))
    model.add(Dense(shape['nr_class'], activation='sigmoid'))
    model.compile(optimizer=Adam(lr=settings['lr']), loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model


def get_embeddings(vocab):
    max_rank = max(lex.rank for lex in vocab if lex.has_vector)
    vectors = numpy.zeros((max_rank + 1, vocab.vectors_length), dtype='float32')
    for lex in vocab:
        if lex.has_vector:
            vectors[lex.rank] = lex.vector
    return vectors


def get_features(docs, max_length):
    docs = list(docs)
    Xs = numpy.zeros((len(docs), max_length), dtype='int32')
    for i, doc in enumerate(docs):
        for j, token in enumerate(doc[:max_length]):
            Xs[i, j] = token.rank if token.has_vector else 0
    return Xs
```
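To round the example off, the trained model needs to be persisted in the layout `SentimentAnalyser.load()` expects above: the architecture as `config.json`, and the non-embedding weights pickled to a file named `model`. A sketch (the output directory is whatever path you choose):

```python
import pickle
from pathlib import Path


def save_model(model, model_dir):
    '''Sketch: save the Keras model so SentimentAnalyser.load() can read it.'''
    model_dir = Path(model_dir)
    weights = model.get_weights()
    with (model_dir / 'model').open('wb') as file_:
        # Drop the (frozen) embedding weights -- load() re-adds them from the vocab.
        pickle.dump(weights[1:], file_)
    with (model_dir / 'config.json').open('w') as file_:
        file_.write(model.to_json())
```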
For most applications, I recommend using pre-trained word embeddings without “fine-tuning”. This means that you’ll use the same embeddings across different models, and avoid learning adjustments to them on your training data. The embeddings table is large, and the values provided by the pre-trained vectors are already pretty good. Fine-tuning the embeddings table is therefore a waste of your “parameter budget”. It’s usually better to make your network larger some other way, e.g. by adding another LSTM layer, using an attention mechanism, using character features, etc.
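For instance, under the same assumptions as `compile_lstm()` above, stacking a second bidirectional LSTM layer looks something like this (a sketch of the idea, not a tuned architecture):

```python
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM, Dropout, Dense
from keras.optimizers import Adam


def compile_stacked_lstm(embeddings, shape, settings):
    # Same frozen embeddings as before; the extra capacity goes into the LSTM stack.
    model = Sequential()
    model.add(Embedding(embeddings.shape[0], embeddings.shape[1],
                        input_length=shape['max_length'],
                        trainable=False, weights=[embeddings]))
    # return_sequences=True so the second LSTM sees the full sequence of states.
    model.add(Bidirectional(LSTM(shape['nr_hidden'], return_sequences=True)))
    model.add(Bidirectional(LSTM(shape['nr_hidden'])))
    model.add(Dropout(settings['dropout']))
    model.add(Dense(shape['nr_class'], activation='sigmoid'))
    model.compile(optimizer=Adam(lr=settings['lr']), loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
```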
Attribute hooks (experimental)
Earlier, we saw how to store data in the new generic `user_data` dict. This generalises well, but it's not terribly satisfying. Ideally, we want to let the custom data drive more "native" behaviours. For instance, consider the `.similarity()` methods provided by spaCy's `Doc`, `Token` and `Span` objects:
Polymorphic similarity example
```python
span.similarity(doc)
token.similarity(span)
doc1.similarity(doc2)
```
By default, this just averages the vectors for each document, and computes their cosine (sketched below, after the table). Obviously, spaCy should make it easy for you to install your own similarity model. This introduces a tricky design challenge. The current solution is to add three more dicts to the `Doc` object:
| Name | Description |
| --- | --- |
| `user_hooks` | Customise behaviour of `doc.vector`, `doc.has_vector`, `doc.vector_norm` or `doc.sents` |
| `user_token_hooks` | Customise behaviour of `token.similarity`, `token.vector`, `token.has_vector`, `token.vector_norm` or `token.conjuncts` |
| `user_span_hooks` | Customise behaviour of `span.similarity`, `span.vector`, `span.has_vector`, `span.vector_norm` or `span.root` |
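For reference, the default similarity mentioned above is roughly equivalent to the following (a simplified sketch, ignoring zero-vector edge cases):

```python
import numpy


def default_similarity(obj1, obj2):
    # Each object's .vector is the average of its token vectors;
    # similarity is the cosine of the two averaged vectors.
    v1, v2 = obj1.vector, obj2.vector
    return numpy.dot(v1, v2) / (numpy.linalg.norm(v1) * numpy.linalg.norm(v2))
```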
To sum up, here's an example of hooking in custom `.similarity()` methods:
Add custom similarity hooks
```python
class SimilarityModel(object):
    def __init__(self, model):
        self._model = model

    def __call__(self, doc):
        doc.user_hooks['similarity'] = self.similarity
        doc.user_span_hooks['similarity'] = self.similarity
        doc.user_token_hooks['similarity'] = self.similarity

    def similarity(self, obj1, obj2):
        y = self._model([obj1.vector, obj2.vector])
        return float(y[0])
```
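And, just like the sentiment component, the hook installer can be dropped into the pipeline via `create_pipeline`. A hypothetical sketch, where `load_similarity_model()` stands in for however you load your own model:

```python
import spacy


def create_pipeline(nlp):
    # Hypothetical: load_similarity_model() is a placeholder for your own loader.
    similarity = SimilarityModel(load_similarity_model())
    return [nlp.tagger, nlp.entity, similarity]


nlp = spacy.load('en', create_pipeline=create_pipeline)
doc1 = nlp(u'The quick brown fox jumped.')
doc2 = nlp(u'A fast auburn fox leapt.')
print(doc1.similarity(doc2))  # Now routed through SimilarityModel.similarity
```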
What’s next?
The attribute hooks are likely to evolve slightly, and will certainly need a little bit of tweaking to get fully consistent. I’m also looking forward to shipping improved models for the tagger, parser and entity recogniser. Over the last twelve months, research has shown that bidirectional LSTM models are a simple and effective approach for these tasks. The resulting models should also be significantly smaller in memory.