Exploring a PyTorch Model's Latent Space
Atlas can be used to better understand your deep neural net's training and test characteristics. By interacting with your model's embeddings (logits) during training and evaluation, you can:
- Identify which classes, targets, or concepts your model learns easily and which it struggles with.
- Identify mislabeled datapoints.
- Spot bugs/errors in your model implementation.
Atlas has a PyTorch Lightning hook that you can plug straight into your PyTorch Lightning training scripts. This tutorial walks you through using it to visualize the training of a two-layer neural network on MNIST.
!pip install pytorch-lightning torch torchvision torchmetrics nomic
import os
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from nomic.pl_callbacks import AtlasEmbeddingExplorer
import nomic
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
torch.manual_seed(0)
# API key for a limited demo account. Make your own account at atlas.nomic.ai
nomic.login('7xDPkYXSYDc1_ErdTPIcoAR9RNd8YDlkS3nVNXcVoIMZ6')
The Lightning Module
class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.l2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        # Flatten each image to a 784-dimensional vector before the linear layers.
        return torch.relu(self.l2(torch.relu(self.l1(x.view(x.size(0), -1)))))

    def training_step(self, batch, batch_nb):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        prediction = torch.argmax(logits, dim=1)
        # One image URL per validation example, keyed by label and dataset index.
        image_links = [f'https://s3.amazonaws.com/static.nomic.ai/mnist/eval/{label}/{batch_idx*BATCH_SIZE+idx}.jpg'
                       for idx, label in enumerate(y)]
        metadata = {'label': y, 'prediction': prediction, 'url': image_links}
        # Log the logits as embeddings, together with their metadata, to Atlas.
        self.atlas.log(embeddings=logits, metadata=metadata)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
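The `url` metadata in `validation_step` rebuilds each image's position in the evaluation set as `batch_idx * BATCH_SIZE + idx`. That arithmetic only lines up with the hosted images if the validation `DataLoader` is not shuffled and every batch except possibly the last holds exactly `BATCH_SIZE` items. Below is a minimal sketch of the same mapping in plain Python; the helper name `eval_image_urls` is hypothetical and not part of Atlas.

```python
# Hypothetical helper mirroring the URL construction in validation_step.
BASE_URL = "https://s3.amazonaws.com/static.nomic.ai/mnist/eval"


def eval_image_urls(labels, batch_idx, batch_size):
    """Return one image URL per label in a batch.

    Assumes the loader is unshuffled and uses a fixed batch size, so the
    item at position `idx` of batch `batch_idx` has global index
    batch_idx * batch_size + idx.
    """
    return [
        f"{BASE_URL}/{int(label)}/{batch_idx * batch_size + idx}.jpg"
        for idx, label in enumerate(labels)
    ]


# Third batch (batch_idx=2) of size 64: global indices start at 128.
urls = eval_image_urls([7, 2, 1], batch_idx=2, batch_size=64)
print(urls[0])
```

If you shuffle the validation loader, the computed indices would no longer match the images, so the links in the map would point at the wrong digits.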
Training the model
mnist_model = MNISTModel()
# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
test_ds = MNIST(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)
# Initialize a trainer
max_epochs = 10
# Initialize the Embedding Explorer 🗺️ hook
embedding_explorer = AtlasEmbeddingExplorer(
    max_points=10_000,
    name="MNIST Validation Latent Space",
    description="MNIST Validation Latent Space",
    overwrite_on_validation=True,
)
trainer = Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,
    max_epochs=max_epochs,
    check_val_every_n_epoch=10,
    callbacks=[TQDMProgressBar(refresh_rate=20), embedding_explorer],
)
# Train the model ⚡
trainer.fit(mnist_model, train_dataloaders=train_loader, val_dataloaders=test_loader)
Validate the model and log the embeddings
trainer.validate(mnist_model, test_loader)
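For a numeric counterpart to the map, the validation logits can also be summarized directly. The sketch below is independent of Atlas: it assumes you have collected an `(N, C)` logits tensor and an `(N,)` label tensor yourself, and the function names are illustrative. It computes per-class accuracy and the gap between the two largest logits of each sample; a small gap suggests a point the model finds hard to assign to a single class.

```python
import torch


def per_class_accuracy(logits, labels, num_classes=10):
    """Fraction of correctly classified samples for each true class."""
    predictions = logits.argmax(dim=1)
    # Count all samples, and correctly predicted samples, per true class.
    totals = torch.bincount(labels, minlength=num_classes)
    hits = torch.bincount(labels[predictions == labels], minlength=num_classes)
    # clamp avoids division by zero for classes with no samples.
    return hits.float() / totals.clamp(min=1)


def top2_margin(logits):
    """Gap between the largest and second-largest logit of each sample."""
    top2 = torch.topk(logits, k=2, dim=1).values
    return top2[:, 0] - top2[:, 1]


# Toy data with three classes, standing in for real validation outputs.
logits = torch.tensor([
    [2.0, 0.1, 0.0],
    [0.2, 1.5, 0.0],
    [1.0, 0.9, 0.0],  # true class 1, predicted 0, with a narrow margin
    [0.0, 0.3, 3.0],
])
labels = torch.tensor([0, 1, 1, 2])

accuracy = per_class_accuracy(logits, labels, num_classes=3)
margins = top2_margin(logits)
hardest = int(margins.argmin())
print(accuracy.tolist(), hardest)
```

Sorting samples by margin gives a shortlist of ambiguous points to look up in the map.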
View the map
embedding_explorer.map
Project: MNIST Validation Latent Space
Debugging the Latent Space
You can visually inspect your trained model's decision boundaries. Points that are misclassified or hard to classify appear in between embedding clusters. For example, hover over the region between the yellow and blue clusters to find points where the model struggles to discriminate between zeros and sixes. Try modifying the PyTorch model definition to see how classification errors differ for models with better vision inductive biases, such as CNNs.
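One way to try this is to swap the two linear layers for a small convolutional network. The following is a rough sketch, not a tuned architecture: the class name `SmallCNN`, the channel counts, and the kernel sizes are all assumptions chosen for illustration. It maps a batch of 1×28×28 MNIST images to 10 logits, so its output can be passed to `F.cross_entropy` and logged as embeddings in the same way as before.

```python
import torch
from torch import nn


class SmallCNN(nn.Module):
    """Illustrative CNN for 1x28x28 inputs; layer sizes are arbitrary choices."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),   # -> 8 x 28 x 28
            nn.ReLU(),
            nn.MaxPool2d(2),                             # -> 8 x 14 x 14
            nn.Conv2d(8, 16, kernel_size=3, padding=1),  # -> 16 x 14 x 14
            nn.ReLU(),
            nn.MaxPool2d(2),                             # -> 16 x 7 x 7
        )
        self.classifier = nn.Linear(16 * 7 * 7, num_classes)

    def forward(self, x):
        x = self.features(x)
        # Flatten everything except the batch dimension, then emit raw logits.
        return self.classifier(x.flatten(start_dim=1))


# Shape check on a dummy batch of four images.
logits = SmallCNN()(torch.zeros(4, 1, 28, 28))
print(logits.shape)
```

Inside `MNISTModel`, an instance of such a network could take the place of `self.l1` and `self.l2`, with `forward` returning its output unchanged.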