Training Image Classification in PyTorch
Deeplake makes it easy to train image classification models by streaming data directly from managed tables into a PyTorch training loop.
Objective
Ingest the Fashion MNIST dataset from HuggingFace into a Deeplake managed table, then train a ResNet18 model by streaming data with a PyTorch DataLoader.
Prerequisites
- Deeplake SDK: pip install deeplake
- PyTorch and torchvision: pip install torch torchvision
- A Deeplake API token (set credentials before running)
Complete Code
import io
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models, transforms
from PIL import Image
from deeplake import Client

# --- Configuration ---
TABLE_NAME = "fashion_mnist"
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# --- 1. Ingest the Dataset from HuggingFace ---
client = Client()
client.ingest(TABLE_NAME, {"_huggingface": "fashion_mnist"})

# --- 2. Create the DataLoader ---
ds = client.open_table(TABLE_NAME)

tform = transforms.Compose([
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

def apply_transform(sample):
    d = sample.to_dict()
    d["image"] = tform(Image.open(io.BytesIO(d["image"])))
    return d

train_loader = DataLoader(
    ds.pytorch(transform=apply_transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

# --- 3. Define the Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18(weights="DEFAULT")
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)  # 10 Fashion MNIST classes
model = model.to(device)

# --- 4. Training Loop ---
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i, batch in enumerate(train_loader):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if i % 100 == 99:
            print(f"Epoch {epoch+1}, Batch {i+1}: "
                  f"Loss={running_loss/100:.3f}, "
                  f"Acc={100.*correct/total:.1f}%")
            running_loss = 0.0

    print(f"Epoch {epoch+1} complete - Accuracy: {100.*correct/total:.1f}%")

# --- 5. Evaluate ---
model.eval()
test_correct = 0
test_total = 0

with torch.no_grad():
    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

print(f"Final Accuracy: {100.*test_correct/test_total:.1f}%")
Step-by-Step Breakdown
1. Ingest the Dataset
Deeplake can ingest datasets directly from HuggingFace with a single call. The _huggingface key tells the platform to pull the dataset by name, automatically mapping its columns (image, label) into a managed table.
If the table already exists, you can skip this step and go straight to open_table.
2. Create the DataLoader
Open the managed table and wrap it in a standard PyTorch DataLoader. The ds.pytorch() method returns a map-style dataset that streams data directly from Deeplake's storage engine. Pass a transform function that receives each sample as a Row object. Call .to_dict() to get a mutable dict, then decode image bytes with Image.open(io.BytesIO(...)).
ds = client.open_table(TABLE_NAME)

tform = transforms.Compose([
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

def apply_transform(sample):
    d = sample.to_dict()
    d["image"] = tform(Image.open(io.BytesIO(d["image"])))
    return d

train_loader = DataLoader(
    ds.pytorch(transform=apply_transform),
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
Each batch is a dictionary where keys match the table columns: batch["image"] contains the image tensors and batch["label"] contains the class indices.
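Since the batching itself is handled entirely by PyTorch, this behavior can be illustrated without a Deeplake connection. The sketch below uses a hypothetical in-memory stand-in dataset whose samples mimic the dict rows streamed from the table (a 1x28x28 grayscale image plus an integer label), and shows how the default collate function stacks each dict field into a batched tensor:

```python
import torch
from torch.utils.data import DataLoader, Dataset

# Stand-in for the streamed table: yields dict samples shaped like
# the Fashion MNIST rows (one grayscale channel, class index 0-9).
class FakeFashionMNIST(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {
            "image": torch.zeros(1, 28, 28),
            "label": idx % 10,
        }

loader = DataLoader(FakeFashionMNIST(), batch_size=32, shuffle=True)
batch = next(iter(loader))

# Default collation stacks dict fields into batched tensors.
print(batch["image"].shape)  # torch.Size([32, 1, 28, 28])
print(batch["label"].shape)  # torch.Size([32])
```

Because the samples are dicts, no custom collate_fn is needed; keys in each batch match the table columns exactly as described above.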
3. Define the Model
We use a pretrained ResNet18 and adapt it for Fashion MNIST. Two modifications are needed: the first convolutional layer is changed from 3-channel RGB to 1-channel grayscale, and the final fully connected layer is resized to output 10 classes.
model = models.resnet18(weights="DEFAULT")
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)
4. Training Loop
A standard PyTorch training loop with CrossEntropyLoss and SGD. Data streams from the managed table through the DataLoader exactly as it would from a local dataset. No special handling is required.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    model.train()
    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
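The zero_grad / forward / backward / step mechanics are identical regardless of where the data comes from. As a minimal self-contained check, the same loop body applied to a toy linear classifier on random data should lower the loss on the batch it was computed from after a single optimizer step:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in: a linear classifier over flattened 28x28 inputs.
model = nn.Linear(28 * 28, 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

images = torch.randn(32, 28 * 28)
labels = torch.randint(0, 10, (32,))

loss_before = criterion(model(images), labels)
optimizer.zero_grad()
loss_before.backward()
optimizer.step()

# Re-evaluating on the same batch: the loss should have decreased.
loss_after = criterion(model(images), labels)
print(loss_before.item(), "->", loss_after.item())
```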
5. Evaluate
Switch the model to evaluation mode and run a pass over the data to measure accuracy. For a production setup, you would ingest a separate test split and evaluate against that.
model.eval()
with torch.no_grad():
    for batch in train_loader:
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
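The outputs.max(1) idiom returns a (values, indices) pair along the class dimension; the indices are the predicted labels, and comparing them to the ground truth with eq gives the correct count. A small worked example with hand-picked logits:

```python
import torch

# Fake logits for a batch of three samples over four classes.
outputs = torch.tensor([
    [0.1, 2.0, 0.3, 0.0],  # argmax is class 1
    [3.0, 0.2, 0.1, 0.0],  # argmax is class 0
    [0.0, 0.1, 0.2, 5.0],  # argmax is class 3
])
_, predicted = outputs.max(1)

labels = torch.tensor([1, 0, 2])
correct = predicted.eq(labels).sum().item()
accuracy = 100.0 * correct / labels.size(0)
print(predicted.tolist(), accuracy)  # [1, 0, 3], 2 of 3 correct
```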
Why no REST API?
Streaming high-throughput tensor data over standard REST endpoints adds significant latency and CPU overhead from HTTP framing and JSON serialization. For training workloads, the Python SDK is the only supported method: it streams data through optimized C++ kernels instead of serializing tensors to JSON.
What to try next
- GPU-Streaming Pipeline: learn more about direct-to-GPU data streaming.
- Massive Ingestion: prepare large-scale datasets for training.
- Semantic Search: search your dataset by content similarity.
- Reference: Querying: details on open_table().