
How to Build a Fast PyTorch Mixed Precision Training Loop (Step-by-Step)

Introduction: Why PyTorch Mixed Precision Training Matters

When I first started training larger deep learning models in PyTorch, the bottleneck wasn’t the model architecture; it was GPU time and memory. That’s exactly where PyTorch mixed precision training comes in: it lets me use lower-precision math (usually FP16 or bfloat16) where it’s safe, and full FP32 precision only where it’s really needed.

Mixed precision means a single training step uses multiple numeric precisions. In practice, that usually looks like model weights stored in FP32 for stability, while most matrix multiplications and convolutions run in FP16/bfloat16 on Tensor Cores. PyTorch’s AMP (Automatic Mixed Precision) system automates this, so I don’t have to hand‑tune every operation.

In my experience on modern NVIDIA GPUs, this often delivers:

  • 2x or more speedup for many vision and NLP models
  • Lower GPU memory usage, allowing bigger batch sizes or larger models
  • Better hardware utilization, especially Tensor Cores designed for low‑precision math

Before I adopted mixed precision, I regularly hit out‑of‑memory errors and had to shrink batch sizes. After switching, I could fit deeper models and more data per step without buying new hardware. In the rest of this guide, I’ll walk through how to build a fast, stable PyTorch mixed precision training loop that actually works in day‑to‑day ML engineering.

Prerequisites and Setup for PyTorch Mixed Precision Training

Hardware and GPU Requirements

Before I even think about enabling PyTorch mixed precision training, I double-check the GPU. For FP16 with Tensor Cores, you’ll want a Volta or newer NVIDIA GPU (V100, T4, RTX 20xx/30xx/40xx, A100, etc.). Older cards can still run FP16, but the speedups are usually much smaller and sometimes not worth the complexity.

If I’m using bfloat16 (bf16), I make sure the GPU explicitly supports it: Ampere-class GPUs and newer (A100, RTX 30xx/40xx) handle bf16, while on older consumer cards and laptops FP16 remains the default choice for mixed precision.
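To take the guesswork out of that decision, I run a quick capability check before picking a dtype. Here’s a minimal sketch using two standard PyTorch calls, torch.cuda.get_device_capability and torch.cuda.is_bf16_supported:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"Compute capability: {major}.{minor}")
    # FP16 Tensor Cores arrived with Volta (compute capability 7.0+)
    print("FP16 Tensor Cores (Volta+):", major >= 7)
    # bf16 needs Ampere (compute capability 8.0) or newer
    print("bf16 supported:", torch.cuda.is_bf16_supported())
else:
    print("No CUDA GPU detected")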

PyTorch, CUDA, and Driver Versions

In my own projects, I try to stick to a recent stable PyTorch release, because AMP APIs have matured a lot over time. As a rule of thumb, I:

  • Use a recent PyTorch (e.g., 2.x+) with built-in torch.cuda.amp (in newer 2.x releases the same APIs also live under torch.amp with a device_type argument).
  • Match the CUDA toolkit version that the PyTorch build expects.
  • Keep NVIDIA drivers up to date so the GPU can fully leverage Tensor Cores.

The easiest way I’ve found to avoid version mismatch is to install PyTorch directly from the official selector, then verify the install and CUDA device like this:

import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")

Creating a Clean Environment

Mixed precision bugs can be subtle, so I like to start from a clean, isolated environment. For Python projects, I typically use venv or Conda, then install PyTorch and essential packages there:

# example with conda
conda create -n mp-pytorch python=3.10 -y
conda activate mp-pytorch

# install a GPU-enabled PyTorch build from the official instructions
# (choose the right CUDA version from the PyTorch site)

Once the environment is set, I run a tiny script to do a forward and backward pass under AMP just to confirm everything works before wiring it into a full training loop. That quick smoke test has saved me a lot of time compared to debugging mixed precision issues in a huge codebase (see the official PyTorch Get Started guide for install details).
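Here’s roughly what that smoke test looks like for me. The tiny model and shapes are arbitrary placeholders; the only point is to confirm autocast and GradScaler run end to end on the new environment:

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler(enabled=(device == "cuda"))  # no-op on CPU

x = torch.randn(32, 64, device=device)
y = torch.randint(0, 10, (32,), device=device)

with autocast(enabled=(device == "cuda")):
    loss = nn.functional.cross_entropy(model(x), y)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
print("AMP smoke test OK, loss:", loss.item())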

Understanding Autocast and GradScaler in PyTorch

What autocast Does in Mixed Precision

When I first started using PyTorch mixed precision training, torch.cuda.amp.autocast was the feature that made everything click. Autocast is a context manager that automatically chooses the numeric precision (FP16/bf16 vs FP32) for each operation. Heavy tensor ops like convolutions and matrix multiplies run in lower precision for speed, while numerically sensitive ops (like softmax or reductions) stay in FP32 for stability.

In practice, I just wrap the forward pass in an autocast block and let PyTorch handle the rest:

from torch.cuda.amp import autocast

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()

    optimizer.zero_grad(set_to_none=True)

    with autocast(dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # backward + optimizer step will use scaled gradients (with GradScaler)

That one change is usually enough to tap into Tensor Cores without rewriting the model.

Why GradScaler Is Essential for Stability

The other half of the story is torch.cuda.amp.GradScaler. Early on, I learned the hard way that FP16 gradients can underflow to zero, especially with small losses. GradScaler fixes this by multiplying the loss by a scale factor before backprop, making gradients larger and less likely to vanish in FP16. After the backward pass, it unscales the gradients, checks for NaNs/Infs, and only then applies the optimizer step.

from torch.cuda.amp import GradScaler

scaler = GradScaler()

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad(set_to_none=True)

    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # scale the loss, then backward
    scaler.scale(loss).backward()

    # unscale gradients, check for overflow, then step
    scaler.step(optimizer)
    scaler.update()

This pattern has become my default template because it gives me most of the speed benefits of FP16 while keeping training as stable as full precision.

How autocast and GradScaler Work Together

In a full training loop, autocast and GradScaler complement each other: autocast chooses the right precision for each op, and GradScaler protects the gradients from underflow and overflow. When I compare mixed precision runs with and without GradScaler, the difference in stability is obvious—loss curves are smoother, and I rarely see sudden NaNs derailing a long run.

One habit that has helped me is to first get a model converging in pure FP32, then drop in autocast + GradScaler and confirm the mixed precision run reaches a similar final accuracy. That quick A/B check gives me confidence that the speedup isn’t quietly hurting model quality (see the torch.amp automatic mixed precision package in the PyTorch documentation).
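Because both autocast and GradScaler accept an enabled flag, this A/B check is easy to script. Here’s a minimal sketch of the pattern, assuming the same model, dataloader, criterion, and optimizer as the snippets above; flipping use_amp is the only difference between the two runs:

from torch.cuda.amp import autocast, GradScaler

use_amp = True  # flip to False for the pure FP32 baseline run

scaler = GradScaler(enabled=use_amp)

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad(set_to_none=True)

    # with enabled=False, autocast is a no-op and the run is pure FP32
    with autocast(enabled=use_amp):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()  # identity scaling when disabled
    scaler.step(optimizer)
    scaler.update()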

Step-by-Step: Converting a Standard Training Loop to Mixed Precision

1. Start from a Plain FP32 Training Loop

Whenever I introduce PyTorch mixed precision training into a project, I begin from a clean, working FP32 loop. If full precision doesn’t converge, mixed precision won’t magically fix it. Here’s a minimal baseline I often use:

import torch
from torch import nn, optim

model = MyModel().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train_epoch(dataloader):
    model.train()
    for inputs, targets in dataloader:
        inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad(set_to_none=True)

        outputs = model(inputs)          # FP32 forward
        loss = criterion(outputs, targets)

        loss.backward()                  # FP32 backward
        optimizer.step()                 # FP32 step

I always run at least one epoch with this loop to confirm loss is decreasing and metrics look sane before touching precision.

2. Add autocast Around the Forward and Loss

The first mixed precision change I make is wrapping the forward and loss in autocast. This tells PyTorch which ops can safely run in lower precision for speed:

from torch.cuda.amp import autocast

def train_epoch(dataloader):
    model.train()
    for inputs, targets in dataloader:
        inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad(set_to_none=True)

        # Only the forward + loss go inside autocast
        with autocast(dtype=torch.float16):
            outputs = model(inputs)      # many ops now run in FP16 on Tensor Cores
            loss = criterion(outputs, targets)

        loss.backward()                  # no gradient scaling yet, so FP16 grads can underflow
        optimizer.step()

At this point you’ll see some speedup, but without gradient scaling, FP16 underflow can still bite you, especially on deeper models.
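A tiny experiment makes the underflow risk concrete: FP16 can’t represent values much below its smallest subnormal (about 6e-8), so small gradient components simply flush to zero:

import torch

for v in (1e-3, 1e-5, 1e-7, 1e-8):
    print(v, "->", torch.tensor(v, dtype=torch.float16).item())
# the last line prints 0.0: anything below ~6e-8 underflows to zero in FP16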

3. Introduce GradScaler for Stable Backprop

The next step is bringing in GradScaler. In my experience, this is what makes mixed precision as robust as full precision for most workloads:

from torch.cuda.amp import autocast, GradScaler

model = MyModel().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()   # create once and reuse every epoch

def train_epoch(dataloader):
    model.train()
    for inputs, targets in dataloader:
        inputs, targets = inputs.cuda(), targets.cuda()

        optimizer.zero_grad(set_to_none=True)

        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # scale loss before backward
        scaler.scale(loss).backward()

        # unscale gradients, check for NaNs/Infs, then step
        scaler.step(optimizer)
        scaler.update()

One thing I learned the hard way: call scaler.step(optimizer) instead of optimizer.step() directly. Calling the optimizer yourself skips the overflow checks and loses much of the safety benefit.

4. Complete Mixed Precision Loop with Best Practices

Once this skeleton is working, I tighten it up with a few best practices I now use by default: set_to_none=True for cheaper gradient zeroing, optional gradient clipping, and explicit non_blocking=True on data transfers when the dataloader uses pinned memory. Here’s a more production-ready version:

import torch
from torch import nn, optim
from torch.cuda.amp import autocast, GradScaler

model = MyModel().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
scaler = GradScaler()


def train_epoch(dataloader):
    model.train()
    for inputs, targets in dataloader:
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # mixed precision forward + loss
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # backward with scaled gradients
        scaled_loss = scaler.scale(loss)
        scaled_loss.backward()

        # optional: gradient clipping on unscaled grads
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # optimizer step and scaler update
        scaler.step(optimizer)
        scaler.update()

        # you can still log the original FP32 loss
        print(f"loss: {loss.item():.4f}")

In my own workflows, this pattern has become the default template: start from a stable FP32 loop, wrap the forward in autocast, add GradScaler, then layer in extras like clipping or custom schedulers only after the mixed precision core is solid.


Maximizing GPU Utilization with Mixed Precision Training

Tuning Batch Size and Memory Usage

One of the biggest wins I’ve seen with PyTorch mixed precision training is the ability to push much larger batch sizes before hitting out-of-memory errors. Because activations and many intermediate tensors are stored in FP16/bf16, they use roughly half the memory of FP32. I usually start from the stable FP32 batch size, then increase it step by step under mixed precision until I’m just below the OOM boundary.

A simple pattern I use is to write a tiny script that tries a range of batch sizes and reports the maximum that fits. Once I’ve found that sweet spot, I often get both better GPU utilization (more work per step) and slightly smoother gradients due to larger effective batch sizes. If memory is still tight, gradient accumulation (simulating a larger batch over multiple steps) can be combined with mixed precision for even more flexibility.
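As a sketch of that probing script, something like the following works. The max_batch_size helper and the 10-class placeholder labels are my own, not part of any PyTorch API; it’s brute force, but it only needs to run once per model/GPU combination:

import torch
from torch.cuda.amp import autocast, GradScaler

def max_batch_size(model, criterion, input_shape, candidates=(32, 64, 128, 256, 512)):
    """Return the largest candidate batch size that survives one AMP forward/backward."""
    best = None
    scaler = GradScaler()
    for bs in candidates:
        try:
            x = torch.randn(bs, *input_shape, device="cuda")
            y = torch.randint(0, 10, (bs,), device="cuda")  # placeholder: 10 classes
            with autocast():
                loss = criterion(model(x), y)
            scaler.scale(loss).backward()
            model.zero_grad(set_to_none=True)
            best = bs
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            break
    return best

# e.g. max_batch_size(model, nn.CrossEntropyLoss(), (3, 224, 224))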

Optimizing Data Loading and Host–Device Transfers

Mixed precision won’t help if the GPU is starved for data. In practice, I’ve seen more than one setup where utilization hovered around 40% just because the dataloader was too slow. To avoid that, I:

  • Use multiple workers in DataLoader (e.g., 4–8 as a starting point).
  • Enable pinned memory so host-to-device copies can be asynchronous.
  • Move tensors with non_blocking=True to overlap transfers and compute.

Here’s a configuration that has worked well for me on many projects:

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=bs,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    persistent_workers=True,
)

for inputs, targets in train_loader:
    inputs = inputs.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    # mixed precision forward/backward here

Once I started treating the dataloader as part of the performance-critical path, my GPU utilization graphs in tools like nvidia-smi became much flatter and closer to 90–100%.

Profiling and Identifying Bottlenecks

I learned quickly that intuition about performance is often wrong; profiling is what actually reveals where time is going. With mixed precision enabled, I use torch.profiler or lightweight CUDA timers to check whether my run is compute-bound or input-bound.

import torch
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("train_step"):
        train_epoch(train_loader)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))

If I see a lot of time in data loading or CPU-side augmentation, I move more work onto the GPU or simplify transforms. If most time is in a particular kernel, I might try different batch sizes, sequence lengths, or fused operations (e.g., using torch.nn.functional variants that play nicer with AMP). One habit that’s helped me is to profile a short mixed precision run every time I make a major architectural or input pipeline change; that way I catch regressions early instead of wondering why training became mysteriously slower (see the Performance Tuning Guide in the PyTorch tutorials).
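For the lightweight CUDA timers, torch.cuda.Event is usually all I need. A small sketch, assuming the train_epoch and train_loader defined earlier; the synchronize call matters because CUDA launches kernels asynchronously:

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
train_epoch(train_loader)   # the mixed precision epoch defined earlier
end.record()

torch.cuda.synchronize()    # wait for queued kernels before reading the timer
print(f"epoch time: {start.elapsed_time(end):.2f} ms")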

Avoiding CUDA Out-of-Memory and Stability Issues in Mixed Precision

Handling CUDA Out-of-Memory (OOM) Errors

When I first switched to PyTorch mixed precision training, I expected OOM errors to disappear; instead, I sometimes hit them faster because I cranked the batch size too aggressively. My first step is always to confirm that memory is really the issue by watching nvidia-smi during training and checking that usage climbs to the GPU limit before the crash.

In practice, I use a simple checklist:

  • Gradually reduce batch size until the OOM disappears.
  • Make sure I’m calling optimizer.zero_grad(set_to_none=True) every step.
  • Delete large tensors I no longer need and occasionally call torch.cuda.empty_cache() between runs, especially in notebooks.

On very tight setups, I’ve also used gradient accumulation so I can keep the effective batch size large while fitting smaller micro-batches in memory.
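Gradient accumulation composes cleanly with AMP as long as the scaler only steps on the accumulation boundary. Here’s a minimal sketch, reusing the model, optimizer, and scaler names from the loops above, with accum_steps micro-batches per optimizer step:

accum_steps = 4  # effective batch size = micro-batch size * accum_steps

optimizer.zero_grad(set_to_none=True)
for step, (inputs, targets) in enumerate(dataloader):
    inputs, targets = inputs.cuda(), targets.cuda()

    with autocast():
        outputs = model(inputs)
        # average over accumulation steps so gradient magnitudes match a big batch
        loss = criterion(outputs, targets) / accum_steps

    scaler.scale(loss).backward()  # gradients accumulate across micro-batches

    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)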

Dealing with NaNs, Infs, and Diverging Loss

NaNs and exploding loss were the first real pain points I hit with mixed precision. In my experience, the most common culprits are missing GradScaler, too high a learning rate, or numerically unstable operations in FP16. I usually start by confirming that loss is stable in pure FP32; if it isn’t, I fix that before blaming mixed precision.

Once FP32 is solid, I double-check that I’m following the full AMP pattern:

  • Forward and loss inside autocast only.
  • Use scaler.scale(loss).backward() instead of loss.backward().
  • Call scaler.step(optimizer) then scaler.update().

If NaNs still pop up, I’ll temporarily lower the learning rate (even by 2–4x), and if needed, disable autocast around particularly sensitive layers (e.g., custom normalization) so they run in FP32:

with autocast():
    x = model.backbone(inputs)

# force FP32 for a numerically delicate head
with torch.cuda.amp.autocast(enabled=False):
    x = model.head(x.float())  # cast the FP16 activations up before the FP32 head
    loss = criterion(x, targets)

This targeted override has saved me more than once on models with tricky custom components.

Debugging Mixed Precision Issues Systematically

When things still feel unstable, I fall back on a simple but effective debugging recipe I’ve used in several projects:

  1. Baseline FP32: run a few hundred steps in full precision, log loss, and confirm smooth convergence.
  2. Add autocast only: wrap the forward and loss in autocast but leave gradient scaling out; compare the loss curve to the FP32 baseline.
  3. Add GradScaler: plug in GradScaler and verify that final accuracy and loss are within noise of FP32.

If mixed precision diverges from FP32 at a specific point, I’ll log intermediate activations or disable autocast for suspected layers until I isolate the problem. That methodical comparison against a known-good FP32 run has been the most reliable way I’ve found to make mixed precision both fast and trustworthy in real training pipelines.
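To localize where NaNs first appear, a forward hook is a cheap tool. This is a rough sketch (install_nan_hooks is my own helper, not a PyTorch built-in); the first name printed during a step points at the earliest offending layer:

import torch

def install_nan_hooks(model):
    """Flag modules whose output contains NaN/Inf during a forward pass."""
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                print(f"Non-finite values in output of: {name}")
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

install_nan_hooks(model)  # then run one mixed precision step and read the log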


End-to-End Example: PyTorch Mixed Precision Training Script

Script Overview and Structure

At this point, I like to bring everything together into a single, runnable script. When I’m starting a new project, I often clone a template like this and just swap in my own dataset and model. The goal is a clean, minimal example of PyTorch mixed precision training that still follows good habits: clear setup, data loading, AMP, GradScaler, and basic logging.

The script below trains a simple classifier on an image dataset (you can replace it with your own). It uses torch.cuda.amp.autocast for the forward pass, GradScaler for safe backprop, and a separate evaluation loop so you can confirm mixed precision doesn’t break accuracy.

Full Mixed Precision Training Script

import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torch.cuda.amp import autocast, GradScaler


def get_dataloaders(data_dir, batch_size=128, num_workers=4):
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_ds = datasets.FakeData(
        size=2000,
        image_size=(3, 224, 224),
        num_classes=10,
        transform=transform,
    )
    val_ds = datasets.FakeData(
        size=500,
        image_size=(3, 224, 224),
        num_classes=10,
        transform=transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )

    return train_loader, val_loader


def create_model(num_classes=10):
    # lightweight backbone; swap for your own model as needed
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def train_one_epoch(model, loader, optimizer, criterion, scaler, device, epoch):
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0

    for step, (images, targets) in enumerate(loader):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # autocast for mixed precision forward + loss
        with autocast(dtype=torch.float16):
            outputs = model(images)
            loss = criterion(outputs, targets)

        # backward with scaled loss
        scaler.scale(loss).backward()

        # unscale before optional gradient clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * images.size(0)
        _, preds = outputs.max(1)
        total += targets.size(0)
        correct += preds.eq(targets).sum().item()

        if (step + 1) % 20 == 0:
            avg_loss = running_loss / total  # divide by samples seen: the last batch may be smaller
            acc = 100.0 * correct / total
            print(f"Epoch {epoch} | Step {step + 1}/{len(loader)} | "
                  f"loss: {avg_loss:.4f} | acc: {acc:.2f}%")

    epoch_loss = running_loss / len(loader.dataset)
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc


def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    # In eval I keep autocast for speed; there is no backward pass, so GradScaler isn't needed
    with torch.no_grad():
        with autocast(dtype=torch.float16):
            for images, targets in loader:
                images = images.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)

                outputs = model(images)
                loss = criterion(outputs, targets)

                running_loss += loss.item() * images.size(0)
                _, preds = outputs.max(1)
                total += targets.size(0)
                correct += preds.eq(targets).sum().item()

    epoch_loss = running_loss / len(loader.dataset)
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    batch_size = 128
    num_workers = 4
    epochs = 5
    lr = 3e-4

    train_loader, val_loader = get_dataloaders(
        data_dir="./data", batch_size=batch_size, num_workers=num_workers
    )

    model = create_model(num_classes=10).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scaler = GradScaler()

    best_val_acc = 0.0

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer, criterion, scaler, device, epoch
        )
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        print(
            f"Epoch {epoch} done | "
            f"train_loss: {train_loss:.4f}, train_acc: {train_acc:.2f}% | "
            f"val_loss: {val_loss:.4f}, val_acc: {val_acc:.2f}%"
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            os.makedirs("checkpoints", exist_ok=True)
            torch.save(model.state_dict(), "checkpoints/best_mixed_precision.pt")
            print(f">> Saved new best model with val_acc={val_acc:.2f}%")


if __name__ == "__main__":
    main()

In my own projects, I treat this kind of script as a living template: once it runs cleanly and converges with mixed precision, I swap in a real dataset, a larger model, and more advanced logging or schedulers. Because the AMP and GradScaler pattern is already baked in, I can focus on modeling decisions instead of constantly debugging low-level training boilerplate.


Conclusion and Next Steps for PyTorch Mixed Precision Training

Recap of the Mixed Precision Training Workflow

Looking back at the full loop, the practical recipe for PyTorch mixed precision training is surprisingly compact. I start from a stable FP32 loop, wrap the forward and loss in torch.cuda.amp.autocast, plug in GradScaler around backward and optimizer steps, then tune batch size and data loading so the GPU stays busy. In my experience, once that core is solid and numerically stable, mixed precision feels almost “set-and-forget” compared to the early days of manual FP16.

Where to Go Next: Scaling Up and Advanced Use Cases

The same AMP building blocks extend naturally to bigger setups. The next steps I usually explore are:

  • Distributed training with DistributedDataParallel + AMP to scale across multiple GPUs or nodes.
  • bf16 training on newer hardware, which often gives FP32-like stability with mixed precision speed.
  • Integrating AMP into LLM-scale or vision foundation model training stacks that rely on sharding and activation checkpointing.

Once you’re comfortable with the single-GPU pattern, layering on distributed training, gradient checkpointing, and more sophisticated profilers is a natural progression (see the PyTorch Distributed Overview in the PyTorch tutorials).
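As a taste of the first item, the single-GPU AMP pattern carries over to DistributedDataParallel almost unchanged. This is a per-rank sketch only, assuming a torchrun launch, a dataloader built with a DistributedSampler, and a create_model helper like the one in the script above; it is not a complete distributed recipe:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)

model = create_model().cuda(local_rank)      # create_model as in the script above
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()
scaler = GradScaler()  # one scaler per process; no cross-rank syncing required

for inputs, targets in train_loader:         # built with a DistributedSampler
    inputs = inputs.cuda(local_rank, non_blocking=True)
    targets = targets.cuda(local_rank, non_blocking=True)
    optimizer.zero_grad(set_to_none=True)

    with autocast():
        loss = criterion(model(inputs), targets)

    scaler.scale(loss).backward()   # DDP all-reduces the scaled gradients here
    scaler.step(optimizer)
    scaler.update()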
