Mixed Precision Training

By Arpit Tripathi, Founder

Mixed precision training is a technique that performs most neural network computations in 16-bit floating point (FP16 or BF16) while keeping a high-precision (FP32) copy of the weights, cutting memory use and speeding up training on modern hardware, typically without sacrificing accuracy.

What is Mixed Precision Training?

Mixed precision training is a technique for training neural networks that runs most operations in a 16-bit floating point format, such as FP16 or BF16, while keeping certain values in 32-bit floating point (FP32) for numerical stability. The approach was introduced by Micikevicius and colleagues at NVIDIA and Baidu in the 2017 paper 'Mixed Precision Training.' By using fewer bits for the bulk of the math, it reduces memory consumption and increases throughput on hardware with dedicated 16-bit units, such as NVIDIA Tensor Cores.

The word mixed is central: the method does not run everything in low precision. It keeps an FP32 master copy of the weights and accumulates certain reductions, such as large sums in normalization or softmax, in FP32. Forward and backward passes use the 16-bit format for matrix multiplications and convolutions, which are where most of the compute and memory bandwidth go. This split preserves the accuracy of full-precision training while capturing most of the speed and memory benefits.
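
A small sketch can show why large sums are accumulated in FP32. It uses only the Python standard library and emulates 16-bit and 32-bit rounding by round-tripping values through `struct`; the helper names `to_fp16`, `to_fp32`, and the two sum functions are hypothetical, chosen for illustration rather than taken from any framework.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest FP16 value (emulated via struct's 'e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round a Python float to the nearest FP32 value (emulated via struct's 'f' format)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def sum_fp16_accumulator(values):
    """Sum FP16 inputs with an accumulator rounded to FP16 after every addition."""
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + to_fp16(v))
    return acc

def sum_fp32_accumulator(values):
    """Sum the same FP16 inputs, but keep the running total in FP32."""
    acc = 0.0
    for v in values:
        acc = to_fp32(acc + to_fp16(v))
    return acc

ones = [1.0] * 4096
print(sum_fp16_accumulator(ones))  # 2048.0: above 2048 the FP16 spacing is 2, so adding 1 is lost
print(sum_fp32_accumulator(ones))  # 4096.0
```

The FP16 accumulator stalls once the running total is large relative to each addend, which is the failure that FP32 accumulation avoids.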

Mixed precision is standard practice for training large models, including transformers and large language models, because storing activations and working weights in 16 bits roughly halves their memory footprint and can substantially increase training speed. The total saving is smaller when an FP32 master copy and FP32 optimizer state are kept. Both PyTorch and TensorFlow provide automatic mixed precision (AMP) utilities that apply it with minimal code changes.

  • Runs most math in FP16 or BF16, keeping select values in FP32.
  • Introduced by Micikevicius et al. (NVIDIA, Baidu) in 2017.
  • Targets matrix multiplications and convolutions for speed and memory savings.
  • Roughly halves memory for 16-bit weights and activations; an FP32 master copy and optimizer state add to the total.
  • Standard for transformer and large language model training.
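
How much memory is actually saved depends on what is kept in FP32. The sketch below is a rough per-parameter accounting under stated assumptions: an Adam-style optimizer with two FP32 moment buffers, 16-bit working weights and gradients in the mixed case, and an FP32 master copy. The function names are hypothetical, and real frameworks may lay out memory differently.

```python
def model_state_bytes_per_param(mixed: bool) -> int:
    """Bytes per parameter for weights, gradients, and Adam state (assumed layout)."""
    if mixed:
        working_weights = 2  # 16-bit copy used in the forward and backward pass
        gradients = 2        # 16-bit gradients
        master_weights = 4   # FP32 master copy
    else:
        working_weights = 4  # FP32 weights are the only copy
        gradients = 4
        master_weights = 0
    adam_moments = 4 + 4     # FP32 first and second moment estimates
    return working_weights + gradients + master_weights + adam_moments

def activation_bytes(num_values: int, mixed: bool) -> int:
    """Bytes needed to store `num_values` activations in 16-bit or FP32."""
    return num_values * (2 if mixed else 4)

print(model_state_bytes_per_param(mixed=False))   # 16
print(model_state_bytes_per_param(mixed=True))    # 16
print(activation_bytes(1_000_000, mixed=False))   # 4000000
print(activation_bytes(1_000_000, mixed=True))    # 2000000
```

Under these assumptions the weight, gradient, and optimizer memory is unchanged, and the saving comes from activations, which is often what limits batch size.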

FP16 vs BF16: the two 16-bit formats

FP16 (half precision) uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. Its 5-bit exponent gives it a narrow dynamic range: the largest finite value is 65,504 and the smallest positive (subnormal) value is about 6 × 10^-8, so very small gradient values can underflow to zero and very large values can overflow to infinity. FP16 has more mantissa bits than BF16, so it represents numbers within its range more precisely, but its range is the constraint that requires extra care during training.

BF16 (bfloat16) uses 1 sign bit, 8 exponent bits, and 7 mantissa bits. The 8-bit exponent matches FP32's dynamic range, so BF16 can represent the same magnitude of numbers as full precision while using half the storage. The trade-off is fewer mantissa bits and therefore lower precision within that range. Because BF16 shares FP32's range, it usually avoids the underflow problems that FP16 faces, which simplifies training.

The practical consequence is that FP16 training typically needs loss scaling to keep small gradients representable, while BF16 training often does not. BF16 is widely available on recent accelerators and has become the preferred format for training large models, whereas FP16 remains common where BF16 hardware support is limited.

  • FP16: 5 exponent bits, 10 mantissa bits; narrow range, higher in-range precision.
  • BF16: 8 exponent bits, 7 mantissa bits; FP32-equivalent range, lower precision.
  • FP16's narrow range causes underflow, which loss scaling addresses.
  • BF16 shares FP32's dynamic range, so it usually avoids loss scaling.
  • BF16 is the preferred format for large-model training where supported.
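
The range and precision trade-off can be checked with a short standard-library sketch. The helpers `to_fp16` and `to_bf16` are hypothetical emulations: FP16 uses `struct`'s half-precision format, and BF16 is approximated by rounding the FP32 encoding to its top 16 bits. NaN inputs and values beyond the FP32 range are not handled.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest FP16 value; values beyond the FP16 range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def to_bf16(x: float) -> float:
    """Round to the nearest BF16 value by keeping the top 16 bits of the FP32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Range: a large value overflows in FP16 but stays finite in BF16.
print(to_fp16(70000.0))  # inf
print(to_bf16(70000.0))  # 70144.0 (finite, coarsely rounded)

# Underflow: a tiny gradient vanishes in FP16 but survives in BF16.
print(to_fp16(1e-8))     # 0.0
print(to_bf16(1e-8) > 0) # True

# Precision: FP16 resolves a small offset from 1.0 that BF16 rounds away.
print(to_fp16(1.0 + 2**-10))  # 1.0009765625
print(to_bf16(1.0 + 2**-10))  # 1.0
```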

Master weights and loss scaling

Two techniques from the original paper make FP16 training stable. The first is keeping an FP32 master copy of the weights. Each optimizer step computes updates in FP32 and applies them to the master copy, then casts those weights down to FP16 for the next forward and backward pass. This prevents the situation where a small update is lost because it is tiny relative to the weight when both are stored in FP16.
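
The lost-update problem can be reproduced with a minimal sketch, again emulating FP16 and FP32 rounding through the standard library's `struct` module. The helper names and the `train` function are hypothetical, and the fixed update of 1e-4 stands in for a learning rate times a gradient.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest FP16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round a Python float to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def train(steps: int, update: float, use_master: bool) -> float:
    """Apply the same small update `steps` times and return the final FP16 weight."""
    weight_fp16 = 1.0
    master_fp32 = 1.0
    for _ in range(steps):
        if use_master:
            master_fp32 = to_fp32(master_fp32 + update)  # update applied in FP32
            weight_fp16 = to_fp16(master_fp32)           # cast down for the next pass
        else:
            weight_fp16 = to_fp16(weight_fp16 + update)  # update applied directly in FP16
    return weight_fp16

# Near 1.0 the FP16 spacing is about 0.001, so an update of 0.0001 rounds away every time.
print(train(1000, 1e-4, use_master=False))  # 1.0
print(train(1000, 1e-4, use_master=True))   # close to 1.1
```

With the master copy, the updates accumulate in FP32 and become visible in the FP16 weight once their sum exceeds the FP16 spacing.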

The second technique is loss scaling, which addresses gradient underflow. Many gradient values in FP16 are so small that they round to zero, losing information. Loss scaling multiplies the loss by a large factor S before the backward pass, which shifts all gradients up into the representable FP16 range, and then divides the gradients by S before the optimizer applies them. Modern AMP implementations use dynamic loss scaling, which automatically raises S when training is stable and lowers it when overflow is detected.
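
Because backpropagation is linear in the loss, multiplying the loss by S multiplies every gradient by S. The sketch below illustrates the effect on a single gradient value with a fixed scale; it emulates FP16 storage with the standard library, and the names `to_fp16`, `true_grad`, and `S` are illustrative assumptions rather than part of any framework.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest FP16 value (tiny values flush to zero)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

true_grad = 1e-8  # a gradient below FP16's smallest representable magnitude
S = 65536.0       # loss scale (2**16), chosen here for illustration

# Without scaling, the gradient underflows when stored in FP16.
unscaled_fp16 = to_fp16(true_grad)
print(unscaled_fp16)  # 0.0

# With scaling, the stored gradient is S times larger and lands in FP16's range.
scaled_fp16 = to_fp16(true_grad * S)

# Unscale in higher precision before the optimizer uses the value.
recovered = scaled_fp16 / S
print(recovered)      # approximately 1e-08
```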

BF16 changes the picture. Because BF16 shares FP32's dynamic range, gradient underflow is rarely an issue, so BF16 training usually does not need loss scaling. An FP32 master copy of weights is still commonly used, however, to preserve the precision of small parameter updates accumulated over many steps.

  • Keep an FP32 master copy of weights; apply updates in FP32, then cast to 16-bit.
  • Loss scaling multiplies the loss by S to lift small gradients into range.
  • Gradients are divided by S before the optimizer step.
  • Dynamic loss scaling adjusts S automatically based on overflow.
  • BF16 usually skips loss scaling but still benefits from FP32 master weights.
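
Dynamic loss scaling can be sketched as a small state machine: shrink S and skip the step when the scaled gradients contain inf or NaN, and grow S after a run of clean steps. The class below is a hypothetical, framework-free illustration. Its default values mirror those documented for PyTorch's `GradScaler`, but it is not that implementation.

```python
import math

class DynamicLossScaler:
    """Hypothetical minimal scaler: back off on overflow, grow after a stable streak."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without overflow

    def update(self, scaled_grads) -> bool:
        """Adjust the scale; return True if the optimizer step should be applied."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in scaled_grads)
        if overflow:
            self.scale *= self.backoff_factor  # lower S and skip this step
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= self.growth_factor   # training is stable, so raise S
            self._good_steps = 0
        return True

scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=2)
print(scaler.update([0.5, math.inf]), scaler.scale)  # False 512.0
print(scaler.update([0.5, 0.25]), scaler.scale)      # True 512.0
print(scaler.update([0.5, 0.25]), scaler.scale)      # True 1024.0
```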

Mixed precision in PyTorch

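PyTorch's automatic mixed precision (AMP) applies the technique with two components: autocast, which automatically chooses the precision for each operation, and a gradient scaler, which implements dynamic loss scaling for FP16. Model parameters stay in FP32 and act as the master weights; autocast casts them to 16-bit for the operations that run in low precision. The example below shows a standard FP16 AMP training step on a CUDA device.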

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic regression data so the example is self-contained (illustrative only)
dataset = TensorDataset(torch.randn(256, 1024), torch.randn(256, 1024))
dataloader = DataLoader(dataset, batch_size=32)

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda")  # dynamic loss scaling for FP16

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()

    # autocast runs eligible ops (such as matmuls) in FP16 and keeps
    # precision-sensitive ops in FP32
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = torch.nn.functional.mse_loss(outputs, targets)

    scaler.scale(loss).backward()  # scale loss, then backprop
    scaler.step(optimizer)         # unscale grads; skip the step on inf/NaN
    scaler.update()                # adjust the scale factor

# For BF16, use dtype=torch.bfloat16; the GradScaler is typically unnecessary.
```
FP16 mixed precision training in PyTorch using autocast and a dynamic GradScaler.
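
For comparison, a BF16 step without a gradient scaler might look like the sketch below. It assumes a PyTorch build with CPU autocast support so that it runs without a GPU; the layer sizes and random data are illustrative, and on a GPU the `device_type` would be `"cuda"`.

```python
import torch

torch.manual_seed(0)

model = torch.nn.Linear(64, 64)  # parameters are stored in FP32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
inputs, targets = torch.randn(32, 64), torch.randn(32, 64)

optimizer.zero_grad()

# autocast runs the linear layer in BF16; no loss scaling is applied
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    outputs = model(inputs)
    # compute the loss on an FP32 view of the outputs
    loss = torch.nn.functional.mse_loss(outputs.float(), targets)

loss.backward()   # gradients land in FP32, matching the parameters
optimizer.step()  # the update is applied to the FP32 weights

print(outputs.dtype, model.weight.dtype)
```

The parameters and their gradients stay in FP32 throughout, so the optimizer update keeps full precision even though the matrix multiplication ran in BF16.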

Key takeaways

  • Mixed precision training runs most math in 16-bit (FP16 or BF16) while keeping select values in FP32, cutting memory and speeding up training.
  • FP16 has a narrow dynamic range and needs loss scaling; BF16 matches FP32's range and usually does not.
  • An FP32 master copy of weights preserves small parameter updates that would be lost in pure 16-bit storage.
  • Loss scaling multiplies the loss before backprop to keep tiny gradients representable, then divides them back out before the optimizer step.
  • PyTorch AMP (autocast plus GradScaler) and TensorFlow's mixed precision API make the technique easy to apply.

Frequently asked questions

What is mixed precision training?
Mixed precision training runs most neural network operations in a 16-bit floating point format such as FP16 or BF16 while keeping certain values, like a master copy of weights, in 32-bit FP32. This reduces memory use and increases training throughput on hardware with 16-bit units, while preserving the accuracy of full-precision training.

What is the difference between FP16 and BF16?
FP16 uses 5 exponent bits and 10 mantissa bits, giving high in-range precision but a narrow dynamic range that can cause underflow. BF16 uses 8 exponent bits and 7 mantissa bits, matching FP32's range at lower precision. Because BF16 shares FP32's range, it usually avoids the loss scaling that FP16 needs.

What is loss scaling and why is it needed?
In FP16, many gradient values are too small to represent and round to zero, losing information. Loss scaling multiplies the loss by a large factor before the backward pass, shifting gradients into the representable range, then divides them by the same factor before the optimizer step. BF16 rarely needs it because of its wider dynamic range.

Does mixed precision training reduce accuracy?
When done correctly, no. Keeping an FP32 master copy of weights and using loss scaling (for FP16) lets mixed precision match full FP32 accuracy on many models, as shown in the original 2017 paper. The technique is widely used to train large models with no meaningful loss in final quality.

How much memory does mixed precision training save?
Using 16-bit instead of 32-bit roughly halves the memory needed to store weights and activations, since each value occupies two bytes instead of four. The exact savings depend on the optimizer state and whether an FP32 master copy is kept, but the reduction in activation and weight memory is substantial and often enables larger batch sizes.
