Fast transformer inference with Metal Performance Shaders

We are happy to introduce support for Metal Performance Shaders in Thinc PyTorch layers. This makes it possible to run spaCy transformer-based pipelines on the GPU of Apple Silicon Macs and improves inference speed by up to 5.5 times in our benchmarks.

In this post, we will discuss the hardware acceleration facilities of Apple Silicon Macs and how spaCy can use them to accelerate transformer models. We will wrap up the post with benchmarks that show what kind of acceleration you can expect on various Apple Silicon Mac models.

Large transformer models are well known to be computationally expensive. This is due to the quadratic computational complexity of the self-attention mechanism, as well as the sheer size of many transformer models. For example, the widely-used BERT, RoBERTa and XLM-R base models use 12 hidden layers, hidden representations of 768 dimensions, and representations of 3072 dimensions in their feed-forward blocks.
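To make these sizes concrete, the following sketch counts the multiply-accumulate operations in the matrix multiplications of a single encoder layer, using the base-model dimensions mentioned above. It is a back-of-the-envelope estimate that ignores biases, layer normalization and the non-linearities, and the function names are purely illustrative.

```python
# Rough count of multiply-accumulate operations (MACs) in the matrix
# multiplications of one transformer encoder layer. Biases, layer
# normalization and non-linearities are ignored.

HIDDEN = 768   # width of the hidden representations (base models)
FFN = 3072     # width inside the feed-forward block


def projection_macs(n_tokens: int) -> int:
    # Query, key, value and output projections: four (N x d) @ (d x d) products.
    return 4 * n_tokens * HIDDEN * HIDDEN


def attention_macs(n_tokens: int) -> int:
    # Q @ K^T and scores @ V, summed over all heads: two products that each
    # cost N * N * d. This is where the quadratic term comes from.
    return 2 * n_tokens * n_tokens * HIDDEN


def feed_forward_macs(n_tokens: int) -> int:
    # Two linear layers: (N x d) @ (d x d_ff) followed by (N x d_ff) @ (d_ff x d).
    return 2 * n_tokens * HIDDEN * FFN


for n in (128, 512, 2048):
    print(n, projection_macs(n), attention_macs(n), feed_forward_macs(n))
```

Doubling the number of tokens doubles the projection and feed-forward cost but quadruples the attention cost, so the attention term only starts to dominate for long inputs.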

The next figure shows the five largest cost centers when annotating German text with the de_dep_news_trf spaCy transformer model, using a specially compiled version of PyTorch that runs on the CPU cores of an M1 Max with generic ARM64 NEON-optimized kernels:

Transformer CPU profile

The runtime of the transformer model is dominated by matrix multiplication: bli_sgemm_armv8a_asm_8x12 is a single-precision matrix-multiplication kernel, and sgemm is a standardized matrix-multiplication function provided by linear algebra libraries that implement the BLAS interface. This should not be surprising, since matrix multiplication is one of the main operations used by transformer models, e.g. to compute the pairwise attention scores in the attention blocks and the linear projections in the feed-forward blocks.

Besides matrix multiplication, the GELU and Softmax non-linearities, which are used in the feed-forward and attention blocks respectively, are also visible in the profile through their use of the erff, expf and exp elementary functions. Together, these non-linearities account for 19% of the runtime.
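The link between these non-linearities and the elementary functions in the profile is easy to see when they are written out. The sketch below is a plain-Python illustration, assuming the exact erf-based definition of GELU (some models use a tanh approximation instead); it is not how PyTorch implements these kernels.

```python
import math
from typing import List


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF,
    # expressed through the error function (erf).
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def softmax(scores: List[float]) -> List[float]:
    # Subtract the maximum before exponentiating for numerical stability.
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


print(gelu(1.0))
print(softmax([1.0, 2.0, 3.0]))
```

GELU needs one erf evaluation per feed-forward activation and Softmax one exp evaluation per attention score, which is why both show up once matrix multiplication is no longer the only significant cost.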

To speed up transformer inference, we can take three different approaches:

  1. Replace the self-attention mechanism with something that has a better time complexity than O(N^2). For example, the Longformer’s attention mechanism has a time complexity of O(N). If we would like to continue using existing pre-trained transformer models, we can also place an upper bound on N.
  2. Speed up matrix multiplication.
  3. Speed up the calculation of non-linearities.

The impact of the quadratic attention mechanism is already limited in spaCy transformer pipelines by placing an upper bound on N. Each document is processed in strides. By default, spaCy transformers process 96 tokens at a time, using windows of 128 tokens to create overlapping contextual representations. This puts the upper bound of the time complexity of the attention mechanism at N=128.
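The striding scheme can be sketched in a few lines. This is a minimal illustration of overlapping windows with the default sizes mentioned above, not the spacy-transformers implementation; the function name strided_windows is hypothetical.

```python
from typing import List, Tuple


def strided_windows(
    n_tokens: int, window: int = 128, stride: int = 96
) -> List[Tuple[int, int]]:
    """Return (start, end) token offsets of overlapping windows."""
    spans = []
    start = 0
    while start < n_tokens:
        end = min(start + window, n_tokens)
        spans.append((start, end))
        if end == n_tokens:
            break  # the last window reached the end of the document
        start += stride
    return spans


print(strided_windows(300))
# [(0, 128), (96, 224), (192, 300)]
```

No window is longer than 128 tokens regardless of document length, so the cost of attention grows linearly with the number of windows rather than quadratically with the length of the document.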

In this post, we will focus on the two other approaches — speeding up matrix multiplications and the calculation of non-linearities by using the specialized hardware of Apple Silicon Macs. We will first look at the AMX matrix multiplication blocks that are part of Apple Silicon CPUs. After that, we will explore running compute kernels on the GPU.

AMX matrix multiplication blocks

All Apple M CPUs have at least one matrix-multiplication co-processor called an ‘AMX block’. AMX is largely undocumented. For instance, it is unknown if the energy-efficient core clusters of Apple M CPUs have their own AMX block. However, we can infer various properties of the AMX blocks by benchmarking them. The table below lists the matrix multiplication performance of 768x768 matrices in TFLOPS (trillion floating point operations per second) on various CPUs, measured with gemm-benchmark:

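| Threads | M1 | M2 | M1 Pro/Max | M1 Ultra | Ryzen 5950X |
|---|---|---|---|---|---|
| 1 | 1.3 | 1.5 | 2.1 | 2.2 | 0.1 |
| 2 | 1.2 | 1.6 | 2.6 | 3.4 | 0.3 |
| 4 | 1.0 | 1.7 | 2.7 | 3.8 | 0.6 |
| 8 | 1.3 | 1.6 | 2.5 | 4.3 | 1.0 |
| 12 | 1.2 | 1.5 | 2.4 | 4.3 | 1.6 |
| 16 | 1.2 | 1.4 | 2.4 | 4.4 | 1.9 |
| Largest speedup compared to M1 | 1.0 | 1.3 | 2.1 | 3.4 | 1.5 |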

We can glean various bits of interesting information from these numbers:

  • The performance does not increase with the number of threads. So, the AMX blocks are not part of the individual CPU cores.
  • The M1, M1 Pro, and M1 Ultra have 1, 2, and 4 performance core clusters respectively. The matrix multiplication performance increases with the number of performance core clusters (see the Largest speedup compared to M1 row). This suggests that each performance cluster has an AMX block.
  • AMX blocks are fast. A single AMX block has the same matrix multiplication performance as 9 Ryzen 5950X cores.
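The arithmetic behind these observations can be reproduced from the table. The sketch below copies the measurements and assumes the usual convention that an n x n matrix multiplication costs 2·n³ floating point operations; whether gemm-benchmark counts operations in exactly this way is an assumption.

```python
# TFLOPS for 768x768 matrix multiplication, copied from the table above.
# Each list holds the measurements for 1, 2, 4, 8, 12 and 16 threads.
TFLOPS = {
    "M1": [1.3, 1.2, 1.0, 1.3, 1.2, 1.2],
    "M2": [1.5, 1.6, 1.7, 1.6, 1.5, 1.4],
    "M1 Pro/Max": [2.1, 2.6, 2.7, 2.5, 2.4, 2.4],
    "M1 Ultra": [2.2, 3.4, 3.8, 4.3, 4.3, 4.4],
    "Ryzen 5950X": [0.1, 0.3, 0.6, 1.0, 1.6, 1.9],
}


def gemm_flops(n: int) -> int:
    # An (n x n) @ (n x n) product needs n^3 multiplications and roughly
    # as many additions (assumed 2 * n^3 convention).
    return 2 * n ** 3


def seconds_per_gemm(n: int, tflops: float) -> float:
    # Time for one multiplication at the given sustained throughput.
    return gemm_flops(n) / (tflops * 1e12)


# Best result per CPU, and the ratio to the best M1 result.
best = {cpu: max(values) for cpu, values in TFLOPS.items()}
speedup_vs_m1 = {cpu: round(value / best["M1"], 1) for cpu, value in best.items()}

print(speedup_vs_m1)
print(f"{seconds_per_gemm(768, best['M1']) * 1e3:.2f} ms per 768x768 product on M1")
```

Taking the maximum over the thread counts for each CPU reproduces the last row of the table.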

Even though the instructions to dispatch calculations to the AMX blocks are not documented by Apple, third-party applications can use the AMX blocks through Apple’s Accelerate framework, which implements the industry-standard BLAS interface. So, BLAS matrix-multiplication functions, such as the sgemm function that we saw in the profile, are automatically accelerated.

Since the transformer uses matrix multiplication as its primary operation, AMX units provide a considerable speedup to transformers. PyTorch uses Accelerate on Apple platforms for matrix multiplication, so PyTorch uses the AMX blocks by default.

Metal Performance Shaders

The AMX blocks deliver impressive matrix multiplication throughput, but Apple Silicon Macs have two other subsystems for compute: the Apple Neural Engine (ANE) and the GPU. The ANE is fairly limited in that it needs to run computation graphs that are defined through Core ML, but the GPU can run user-defined compute kernels (so-called ‘shaders’). This makes the GPU flexible enough to run a large variety of machine learning models.

The 8-core GPU of the M1 has a compute performance of 2.6 TFLOPS, roughly double the performance of an M1 AMX block. Furthermore, the GPU scales all the way up to 64 cores in the M1 Ultra, giving a theoretical peak performance of 20.8 TFLOPS. Thus, the Apple Silicon GPUs could push transformer performance beyond what is offered by the AMX blocks.

PyTorch recently introduced support for Apple M GPUs through Apple’s Metal API. Various PyTorch operations have been implemented either as custom Metal shaders or with the shaders that Apple ships in the Metal Performance Shaders framework. For supported operations, using Apple Silicon GPUs in PyTorch is as simple as placing tensors or modules on the new mps device. For example, matrix multiplication can be done on the GPU cores in the following manner:

>>> import torch
>>> u = torch.rand((10, 20), dtype=torch.float,
...                device=torch.device("mps"))
>>> v = torch.rand((20, 10), dtype=torch.float,
...                device=torch.device("mps"))
>>> torch.matmul(u, v).device
device(type='mps', index=0)

Some operations have not been implemented yet at the time of writing, but in such cases PyTorch will fall back to CPU kernels when the environment variable PYTORCH_ENABLE_MPS_FALLBACK is set to 1.
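A script that should run both with and without an Apple Silicon GPU could combine the fallback variable with a device check. The following is a minimal sketch: pick_device is a hypothetical helper, and setting the variable before importing torch is a conservative choice rather than a documented requirement.

```python
import os

# Ask PyTorch to fall back to CPU kernels for operations that have no MPS
# implementation yet. Set before importing torch to be on the safe side.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


def pick_device() -> str:
    """Return "mps" when PyTorch reports a usable MPS backend, else "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch is not installed at all
    mps = getattr(torch.backends, "mps", None)  # missing in older releases
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


print(pick_device())
```

The returned string can be passed to torch.device, as in the matrix multiplication example above.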

Metal Performance Shaders in spaCy and Thinc

spaCy uses Thinc as its machine learning library. Thinc is a lightweight deep learning library that also supports layers defined in other frameworks such as PyTorch and TensorFlow. The spacy-transformers package uses this interoperability of Thinc to make Hugging Face PyTorch transformer models usable in spaCy pipelines. Now that PyTorch supports Apple Silicon GPUs, the transformer model in a transformer-based spaCy pipeline could in principle be executed on the GPU cores of Apple Silicon machines.

Unfortunately, Thinc versions prior to 8.1 used a deprecated PyTorch device management feature that made it impossible to support new Torch devices like mps. Thinc implements its own operations in various device-specific Ops classes. Prior to Thinc 8.1, the following Ops implementations were available:

  • NumpyOps: executes operations on the CPU. Uses NumPy and additional C++ kernels.
  • CupyOps: executes operations on a CUDA-capable GPU. Uses CuPy and additional CUDA C++ kernels.
  • AppleOps: inherits from NumpyOps, overriding matrix multiplication to run on the AMX blocks by leveraging Apple’s Accelerate framework.
  • BigEndianOps: inherits from NumpyOps, overriding specific operations to support big-endian platforms.

Each Thinc layer is associated with an instance of one of the Ops classes. The layer uses the Ops instance to allocate parameters, perform computations, etc. PyTorch Thinc layers differ from regular Thinc layers in that they use PyTorch’s own operations instead of the associated Ops instance. However, when we wrap a PyTorch layer while using CupyOps, we want the PyTorch layer to run on a CUDA device rather than the default cpu device. Thinc used to accomplish this by using the now-deprecated torch.set_default_tensor_type function to set the default tensor type to torch.cuda.FloatTensor or torch.FloatTensor, depending on the active Ops instance.

However, the set_default_tensor_type function does not allow us to set the default device to mps. So, for this reason (amongst others) we had to replace this mechanism with something that uses Torch device identifiers like in the matrix multiplication example above. Starting with Thinc 8.1, the PyTorch wrapper adds a keyword argument to specify the Torch device that the layer should be placed on. If this argument is not specified, Thinc will use the appropriate device for the currently active Ops.

To support Apple Silicon GPUs, we added a new Ops implementation, MPSOps, which defaults to the mps device for Torch layers. MPSOps is automatically used when you install Thinc 8.1 and ask Thinc or spaCy to use a GPU.
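The device selection described above boils down to a small lookup with an optional override. The sketch below only illustrates that logic: the mapping and the resolve_torch_device function are hypothetical and are not Thinc’s actual code or API.

```python
from typing import Optional

# Hypothetical mapping from the name of the active Ops class to the Torch
# device used by default. It mirrors the behavior described in the text.
DEFAULT_DEVICE = {
    "NumpyOps": "cpu",
    "AppleOps": "cpu",
    "BigEndianOps": "cpu",
    "CupyOps": "cuda",
    "MPSOps": "mps",
}


def resolve_torch_device(active_ops: str, device: Optional[str] = None) -> str:
    # An explicitly requested device always wins over the Ops default.
    if device is not None:
        return device
    # Unknown Ops implementations fall back to the CPU in this sketch.
    return DEFAULT_DEVICE.get(active_ops, "cpu")


print(resolve_torch_device("MPSOps"))
print(resolve_torch_device("CupyOps", device="cpu"))
```

Unlike a global default tensor type, a device identifier per layer extends naturally to new backends such as mps.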

How fast is it? ⏱️

With Thinc 8.1 and PyTorch 1.13 all the pieces fall into place, and we can perform transformer inference on Apple Silicon GPUs. The following table shows the speed in words per second (WPS) of annotating German text with the de_dep_news_trf transformer model on various Apple Silicon Macs:

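| Machine | CPU cores | GPU cores | AMX (WPS) | GPU (WPS) | Speedup |
|---|---|---|---|---|---|
| Mac Mini M1 | 4P/4E | 8 | 1180 | 2202 | 1.9 |
| MacBook Air M2 | 4P/4E | 10 | 1242 | 3362 | 2.7 |
| MacBook Pro 14” M1 Pro | 6P/2E | 14 | 1631 | 4661 | 2.9 |
| MacBook Pro 14” M1 Max | 8P/2E | 32 | 1821 | 8648 | 4.7 |
| Mac Studio M1 Ultra | 16P/4E | 48 | 2197 | 12073 | 5.5 |
| Ryzen 5950X + RTX 3090 | 16 | 328 (Tensor cores) | 1879 (CPU) | 18845 | 10.0 |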

The benchmark shows a marked speedup when using Apple Silicon GPUs compared to AMX blocks. On an M1 Max, for example, the GPU reaches 8648 words per second compared to 1821 words per second on the AMX blocks. The GPU inference performance of the M1 Max is almost half that of an NVIDIA RTX 3090.
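A words-per-second figure can be measured with a small timing harness. The sketch below is not the benchmark script behind the table, which may differ in warm-up and in how words are counted; words_per_second and the whitespace stand-in annotator are hypothetical.

```python
import time
from typing import Callable, Iterable, List, Sequence


def words_per_second(
    annotate: Callable[[List[str]], Iterable[Sequence]], texts: List[str]
) -> float:
    """Annotate `texts` once and return the number of tokens per second."""
    start = time.perf_counter()
    n_words = sum(len(doc) for doc in annotate(texts))
    elapsed = time.perf_counter() - start
    return n_words / max(elapsed, 1e-9)  # guard against a zero-length interval


# Stand-in annotator that only splits on whitespace, so that the sketch
# runs without a trained pipeline.
def whitespace_annotate(texts: List[str]) -> List[List[str]]:
    return [text.split() for text in texts]


texts = ["Marbach am Neckar ist eine Stadt ."] * 1000
wps = words_per_second(whitespace_annotate, texts)
print(f"{wps:.0f} words per second")
```

With spaCy, nlp.pipe could be passed as the annotate argument, since it yields Doc objects whose length is their number of tokens.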

The compute performance of 8 M1 GPU cores is estimated to be approximately twice that of an AMX block. Based on matrix multiplication throughput alone, we would therefore expect the GPU of the M1 Pro, which has two performance clusters with AMX blocks and only 14 GPU cores, to be less than twice as fast as its AMX blocks. Yet inference turns out to be almost three times as fast on its GPU. The reason is that AMX only accelerates matrix multiplication, while the GPU accelerates other kernels as well, including the GELU and Softmax non-linearities. The following image shows the five largest cost centers when accelerating inference with the AMX blocks:

Transformer AMX profile

Since AMX only accelerates matrix multiplication, the computation of non-linearities has become the largest cost center. This is not an issue with GPU inference since the non-linearities are computed in parallel on the GPU.
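The gap between the expected and the measured speedup on the M1 Pro can be made explicit with some back-of-the-envelope arithmetic. The sketch below uses only figures quoted in this post and assumes that throughput scales linearly with the number of GPU cores and AMX blocks, which is a simplification.

```python
# Figures quoted in this post.
AMX_BLOCK_TFLOPS = 1.3    # single AMX block (M1 matrix multiplication result)
GPU_8_CORE_TFLOPS = 2.6   # 8-core M1 GPU

# M1 Pro configuration from the benchmark table.
amx_blocks = 2
gpu_cores = 14

# Assumption: linear scaling with the number of GPU cores and AMX blocks.
gpu_estimate = GPU_8_CORE_TFLOPS * gpu_cores / 8
amx_estimate = AMX_BLOCK_TFLOPS * amx_blocks
expected_speedup = gpu_estimate / amx_estimate

# GPU words per second divided by AMX words per second on the M1 Pro.
measured_speedup = 4661 / 1631

print(f"expected from matmul throughput: {expected_speedup:.2f}x")
print(f"measured:                        {measured_speedup:.2f}x")
```

The measured speedup is well above what matrix multiplication throughput alone would predict, which is consistent with the GPU also accelerating the non-linearities.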

Another interesting question is whether the improved throughput of the Apple Silicon GPUs comes at the cost of more power use. The following table shows the average power use in watts during the benchmarks. Running spaCy transformers on the GPU provides much more performance per watt.

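| Machine | CPU cores | GPU cores | AMX (W) | GPU (W) |
|---|---|---|---|---|
| Mac Mini M1 | 4P/4E | 8 | 11 | 10 |
| MacBook Air M2 | 4P/4E | 10 | 13 | 9 |
| MacBook Pro 14” M1 Pro | 6P/2E | 14 | 16 | 17 |
| MacBook Pro 14” M1 Max | 8P/2E | 32 | 17 | 31 |
| Mac Studio M1 Ultra | 16P/4E | 48 | 34 | 70 |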
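Performance per watt follows directly from dividing the throughput in the speed table by the power use in the table above. The sketch below simply copies the numbers from both tables; the variable names are illustrative.

```python
# (AMX WPS, GPU WPS, AMX watts, GPU watts), copied from the two tables.
RESULTS = {
    'Mac Mini M1': (1180, 2202, 11, 10),
    'MacBook Air M2': (1242, 3362, 13, 9),
    'MacBook Pro 14" M1 Pro': (1631, 4661, 16, 17),
    'MacBook Pro 14" M1 Max': (1821, 8648, 17, 31),
    'Mac Studio M1 Ultra': (2197, 12073, 34, 70),
}

# Words per second per watt for the AMX blocks and the GPU.
efficiency = {}
for machine, (amx_wps, gpu_wps, amx_w, gpu_w) in RESULTS.items():
    efficiency[machine] = (amx_wps / amx_w, gpu_wps / gpu_w)
    amx_eff, gpu_eff = efficiency[machine]
    print(f"{machine}: AMX {amx_eff:.0f} WPS/W, GPU {gpu_eff:.0f} WPS/W")
```

On every machine in the table the GPU processes roughly two to four times as many words per watt as the AMX blocks.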

Trying out spaCy transformer pipelines on Apple Silicon GPUs

Support for Apple Silicon GPUs is available in Thinc 8.1.0, spaCy 3.4.2 and spacy-transformers 1.1.8 or later versions. To use the support for Apple Silicon GPUs, first make sure that you have PyTorch 1.13 or later installed:

pip install spacy "torch>=1.13.0"

You can then install the transformer model that you want to use. This will also install the spacy-transformers package:

spacy download de_dep_news_trf

You can then use spaCy as you would normally after switching to use the GPU with the require_gpu function:

>>> import spacy
>>> spacy.require_gpu()
>>> nlp = spacy.load('de_dep_news_trf')
>>> docs = list(nlp.pipe(["Marbach am Neckar ist eine Stadt etwa 20 Kilometer nördlich von Stuttgart."]))
>>> [(t.text, t.pos_) for t in docs[0]]
[('Marbach', 'PROPN'), ('am', 'ADP'), ('Neckar', 'PROPN'), ('ist', 'AUX'), ('eine', 'DET'), ('Stadt', 'NOUN'), ('etwa', 'ADV'), ('20', 'NUM'), ('Kilometer', 'NOUN'), ('nördlich', 'ADV'), ('von', 'ADP'), ('Stuttgart', 'PROPN'), ('.', 'PUNCT')]

If you would like to verify that the GPU is indeed used, you can check that the currently active Ops is MPSOps:

>>> from thinc.api import get_current_ops
>>> get_current_ops()
<thinc.backends.mps_ops.MPSOps object at 0x1010b6e90>

To track updates to the support for Apple Silicon GPUs, you can follow our tracking issue in the Thinc repository.