
Generative AI data interface

Learn how to use Atlas to evaluate your machine learning systems.

Use cases

  • Visualizing the output of your embedding models
  • Monitoring chatbots
  • Exploring a PyTorch model's latent space


Visualizing unstructured data and high-dimensional embeddings in Atlas helps you understand the hidden layers and opaque outputs of deep learning models. Atlas also gives you visual access to large unstructured data streams, so you can detect anomalous patterns and uncover underlying thematic structure, which makes it easier to refine natural language processing algorithms.

Atlas's semantic visualization supports both the model development process and the interpretability of complex machine learning systems.

Monitoring chatbots

One of the hardest things about building a chatbot is monitoring how users are actually interacting with the model. Are users sending toxic messages? Is your bot producing toxic outputs?

In this tutorial, we'll show you how to use LangChain, AtlasDB, and Cohere to monitor your chatbot for toxic content. Try implementing it yourself or follow along in the 📓 Colab Notebook.

Installing dependencies

!pip install langchain cohere datasets
!pip install --upgrade nomic

Data setup

To mock a query stream that gets progressively more unsafe, we are going to use two datasets from Allen AI.

The first dataset, soda, contains several "normal" chatbot transcripts. The second dataset, prosocial-dialog, contains several "unsafe" transcripts.

We are going to build a data generator that randomly samples a datapoint from one of these datasets. As time goes on, the generator becomes more likely to select an "unsafe" datum. This simulates the real-world situation where your query stream becomes more unsafe over time.

import random
from datasets import load_dataset

def stream_data(T=5000):
    safe = load_dataset('allenai/soda')['train']
    unsafe = iter([e for e in load_dataset('allenai/prosocial-dialog')['train']
                   if e['safety_label'] == '__needs_intervention__'])

    for t in range(T):
        p_unsafe = t / (3 * T)  # slowly increase toxic content to a 33% chance
        is_unsafe = random.uniform(0, 1) <= p_unsafe
        if is_unsafe:
            x = next(unsafe)
            yield 'unsafe', t, x['context'] + ' ' + x['response']
        else:
            x = safe[t]
            yield 'safe', t, x['dialogue'][0]  # first utterance of the dialogue

data = list(stream_data(6000))
train_data = data[:5000]
test_data = data[5000:]
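The ramp in stream_data makes the expected amount of unsafe content easy to work out. Summing t/(3T) over t = 0, ..., T-1 gives (T-1)/6, which is roughly 1,000 unsafe items for T = 6,000. The sketch below checks that schedule with Python's random module alone, without downloading either dataset. It only mirrors the probability ramp, and simulate_labels and expected_unsafe are hypothetical helpers that are not part of the tutorial.

```python
import random


def simulate_labels(T, seed=0):
    """Simulate only the safe/unsafe labels of the stream, without loading any datasets.

    At step t an item is unsafe with probability t / (3 * T), the same ramp as stream_data.
    """
    rng = random.Random(seed)  # seeded so the simulation is reproducible
    labels = []
    for t in range(T):
        p_unsafe = t / (3 * T)
        labels.append('unsafe' if rng.uniform(0, 1) <= p_unsafe else 'safe')
    return labels


def expected_unsafe(T):
    """Expected number of unsafe items: the sum of t / (3T) over t = 0..T-1, i.e. (T - 1) / 6."""
    return sum(t / (3 * T) for t in range(T))


labels = simulate_labels(6000)
print(round(expected_unsafe(6000), 1))  # prints 999.8
print(labels.count('unsafe'))
# The second half of the stream contains more unsafe items than the first half
print(labels[:3000].count('unsafe'), labels[3000:].count('unsafe'))
```

Because unsafe items concentrate late in the stream, the last slice of the stream, which the tutorial uses as its test set, is the most toxic part.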

Model setup

We are going to use the AtlasDB LangChain integration to store our chat data, and CohereEmbeddings to embed it.

import time
import numpy as np
from tqdm import tqdm
from pprint import pprint
from langchain.vectorstores import AtlasDB
from langchain.embeddings import CohereEmbeddings

# Placeholders: replace with your own API keys
COHERE_API_KEY = 'YOUR_COHERE_API_KEY'
ATLAS_API_KEY = 'YOUR_ATLAS_API_KEY'

# Create a CohereEmbeddings object to vectorize our text
embedding = CohereEmbeddings(cohere_api_key=COHERE_API_KEY, model='small')

# Create an AtlasDB object to store our vectors
db = AtlasDB(name='Observability Demo',
             embedding_function=embedding,
             api_key=ATLAS_API_KEY)

Add data to AtlasDB with the add_texts method. Each datapoint has text (the contents of the chat), a label, and a timestamp.

batch_size = 1000
batched_texts = []
batched_metadatas = []

for datum in tqdm(train_data):
    label, timestamp, text = datum[0], datum[1], datum[2]

    # Batch data for faster adds. You can also add data one at a time if you like
    batched_texts.append(text)
    batched_metadatas.append({'label': label, 'timestamp': timestamp})

    if len(batched_texts) >= batch_size:
        # Add data to the database as it streams.
        # refresh=False means we don't rebuild the database every time we add texts
        db.add_texts(texts=batched_texts,
                     metadatas=batched_metadatas,
                     refresh=False)
        batched_texts = []
        batched_metadatas = []

# Add any leftover texts from a final, partial batch
if batched_texts:
    db.add_texts(texts=batched_texts, metadatas=batched_metadatas, refresh=False)
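The manual batching in the loop above can be factored into a small generator that also flushes the final partial batch. This is a minimal sketch, and batched here is a hypothetical helper name. On Python 3.12+, the standard library's itertools.batched behaves similarly but yields tuples.

```python
from typing import Iterable, Iterator, List, TypeVar

Item = TypeVar('Item')


def batched(items: Iterable[Item], batch_size: int) -> Iterator[List[Item]]:
    """Yield consecutive lists of at most batch_size items, ending with any partial batch."""
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')
    batch: List[Item] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # flush the remainder so no data is silently dropped
        yield batch


print(list(batched(range(5), 2)))  # prints [[0, 1], [2, 3], [4]]

# Hypothetical use with the tutorial's loop (db and train_data come from the steps above):
# for chunk in batched(train_data, 1000):
#     db.add_texts(texts=[text for _, _, text in chunk],
#                  metadatas=[{'label': label, 'timestamp': ts} for label, ts, _ in chunk],
#                  refresh=False)
```

Keeping the batching logic in one place makes it harder to forget the leftover items when the dataset size is not a multiple of the batch size.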

Once we've added all of our text, we can create an index over our data. If you add more text after creating your index, you can refresh it with db.project.rebuild_maps().

For a full list of parameters to create_index, see:

# Index our data in Atlas
db.create_index(name='Observability Demo',
                colorable_fields=['timestamp', 'label'])

Visualize your chats

Once the index completes, we can view it inline in our notebook!

# Wait for the index to build
with db.project.wait_for_project_lock():
    time.sleep(1)

Atlas automatically builds a topic model on your data, allowing you to understand what people are talking about with your bot at a glance.

Check out the bottom-left corner of the map. The topics there (e.g., domestic violence and jokes) indicate that there is unsafe content in your query stream!

Getting topic information

We can also access the topics that AtlasDB extracted programmatically.

map = db.project.maps[0]
topic_data = map.topics.metadata #this is how we view all of the metadata associated with our topics

Using topics to filter the query stream

By exploring the map, we can see that many of the unsafe topics share the top-level label "Don't hurt people." Let's use AtlasDB to build a greylist so we can detect and review utterances that may be unsafe.

# Prep the data from our test set
test_labels = [e[0] for e in test_data]
test_text = [e[2] for e in test_data]
test_embeddings = np.stack(embedding.embed_documents(test_text))
unsafe_topic = '1'  # The topic id of our unsafe topic, inferred from topic_data[0]['topic']

batch_size = 100
topic_posteriors = []
for i in range(0, len(test_embeddings), batch_size):
    batch = test_embeddings[i:i + batch_size]
    # Use AtlasDB to infer the topics of new data
    cur_posteriors = map.topics.vector_search_topics(batch, k=32, depth=1)['topics']
    topic_posteriors.extend(cur_posteriors)

predicted_is_unsafe = [max(posterior, key=lambda topic: posterior[topic]) == unsafe_topic
                       for posterior in topic_posteriors]

# Analyze the results of our classifier on the test data
true_positive = 0
false_positive = 0
false_negative = 0
for i, label in enumerate(test_labels):
    if predicted_is_unsafe[i]:
        if label == 'unsafe':
            true_positive += 1
        else:
            false_positive += 1
    elif label == 'unsafe':
        false_negative += 1

print('Precision: ', true_positive / (true_positive + false_positive))
print('Recall: ', true_positive / (true_positive + false_negative))
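The greylist rule (flag a datum when its most probable topic is the unsafe one) and the counting loop above can be packaged as two small functions. This is a sketch under two assumptions. First, each posterior is a dict mapping topic id to probability, which is how the max(...) call above treats it. Second, an empty denominator should yield 0.0 instead of raising ZeroDivisionError. The toy posteriors are made up for illustration, and top_topic_is and precision_recall_f1 are hypothetical helpers.

```python
from typing import Dict, Sequence


def top_topic_is(posterior: Dict[str, float], target_topic: str) -> bool:
    # A datum is greylisted when its most probable topic is the target topic
    return max(posterior, key=posterior.get) == target_topic


def precision_recall_f1(predicted: Sequence[bool], labels: Sequence[str],
                        positive: str = 'unsafe') -> Dict[str, float]:
    """Compute precision, recall, and F1, returning 0.0 instead of dividing by zero."""
    tp = fp = fn = 0
    for pred, label in zip(predicted, labels):
        is_positive = label == positive
        if pred and is_positive:
            tp += 1
        elif pred:
            fp += 1
        elif is_positive:
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {'precision': precision, 'recall': recall, 'f1': f1}


# Made-up posteriors over two topics, where '1' plays the role of the unsafe topic
posteriors = [{'1': 0.7, '2': 0.3}, {'1': 0.2, '2': 0.8},
              {'1': 0.6, '2': 0.4}, {'1': 0.1, '2': 0.9}]
labels = ['unsafe', 'unsafe', 'safe', 'safe']
predicted = [top_topic_is(p, '1') for p in posteriors]
print(precision_recall_f1(predicted, labels))  # prints {'precision': 0.5, 'recall': 0.5, 'f1': 0.5}
```

Reporting F1 alongside precision and recall gives a single number to compare when you try different values of k or depth.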
Exploring a PyTorch model's latent space

Atlas can help you better understand your deep neural network's training and test behavior. By interacting with your model's embeddings (logits) during training and evaluation, you can:

  • Identify which classes, targets, or concepts your model finds easy or difficult to learn.
  • Identify mislabeled datapoints.
  • Spot bugs/errors in your model implementation.
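Atlas surfaces mislabeled points visually. A rough programmatic analog, and a swapped-in technique rather than an Atlas API, is to flag points whose nearest neighbors in embedding space mostly carry a different label. The sketch below assumes a small sample of embeddings (for example, logits) in a NumPy array and uses brute-force Euclidean distances. The data is synthetic, with one label deliberately flipped, and knn_label_disagreement is a hypothetical helper.

```python
import numpy as np


def knn_label_disagreement(embeddings: np.ndarray, labels: np.ndarray, k: int = 5) -> np.ndarray:
    """For each point, return the fraction of its k nearest neighbors with a different label.

    Brute-force O(n^2) distances, so this is only meant for small samples.
    """
    # Pairwise squared Euclidean distances, shape (n, n)
    sq_norms = (embeddings ** 2).sum(axis=1)
    dists = sq_norms[:, None] + sq_norms[None, :] - 2 * embeddings @ embeddings.T
    np.fill_diagonal(dists, np.inf)  # a point is not its own neighbor
    neighbors = np.argsort(dists, axis=1)[:, :k]
    return (labels[neighbors] != labels[:, None]).mean(axis=1)


rng = np.random.default_rng(0)
# Two well-separated synthetic clusters standing in for the embeddings of two classes
emb = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(5, 0.1, (20, 2))])
labels = np.array([0] * 20 + [1] * 20)
labels[3] = 1  # deliberately "mislabel" one point in the first cluster
scores = knn_label_disagreement(emb, labels, k=5)
print(int(np.argmax(scores)))  # prints 3, the flipped point
```

High scores are candidates for manual review, not proof of a labeling error. Genuinely ambiguous examples score high too.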

Atlas has a PyTorch Lightning hook that you can plug straight into your PyTorch Lightning training scripts. This tutorial walks through using it to visualize the training of a two-layer neural network on MNIST.

Follow along in the 📓 Colab Notebook.

!pip install pytorch-lightning torch torchvision torchmetrics
import os
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from nomic.pl_callbacks import AtlasEmbeddingExplorer
import nomic

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# API key for a limited demo account (placeholder shown here). Make your own account at
nomic.login('YOUR_ATLAS_API_KEY')

Lightning Module

class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.l2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.l2(torch.relu(self.l1(x.view(x.size(0), -1)))))

    def training_step(self, batch, batch_nb):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        prediction = torch.argmax(logits, dim=1)

        # An image link for each datapoint, indexed by its position in the test set
        image_links = [f'{label.item()}/{batch_idx * BATCH_SIZE + idx}.jpg'
                       for idx, label in enumerate(y)]
        metadata = {'label': y, 'prediction': prediction, 'url': image_links}
        self.atlas.log(embeddings=logits, metadata=metadata)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
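The image_links formula batch_idx * BATCH_SIZE + idx recovers each image's position in the test set. It does so only if the validation DataLoader iterates in order, which holds for DataLoader's default shuffle=False, and only if every batch except possibly the last has exactly BATCH_SIZE items. The plain-Python sketch below, using a hypothetical global_indices helper, checks that the arithmetic covers every index exactly once, even with a partial final batch.

```python
def global_indices(num_items, batch_size):
    """Reproduce the batch_idx * BATCH_SIZE + idx arithmetic from validation_step."""
    indices = []
    num_batches = (num_items + batch_size - 1) // batch_size  # ceiling division
    for batch_idx in range(num_batches):
        # Every batch is full except possibly the last one
        this_batch = min(batch_size, num_items - batch_idx * batch_size)
        for idx in range(this_batch):
            indices.append(batch_idx * batch_size + idx)
    return indices


# MNIST's test split has 10,000 images; with BATCH_SIZE = 64 the last batch holds only 16
idx = global_indices(10_000, 64)
print(len(idx), idx[:3], idx[-1])  # prints 10000 [0, 1, 2] 9999
```

If you enable shuffling on the validation loader, these indices no longer match dataset positions, and the logged URLs would point at the wrong images.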

Training the model

mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
test_ds = MNIST(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

# Initialize a trainer
max_epochs = 10

# Initialize the Embedding Explorer 🗺️ hook
embedding_explorer = AtlasEmbeddingExplorer(max_points=10_000,
                                            name="MNIST Validation Latent Space",
                                            description="MNIST Validation Latent Space")
trainer = Trainer(
    accelerator="auto",
    devices=1,  # a single GPU if one is available, otherwise the CPU
    max_epochs=max_epochs,
    callbacks=[TQDMProgressBar(refresh_rate=20), embedding_explorer],
)

# Train the model ⚡
trainer.fit(mnist_model, train_dataloaders=train_loader, val_dataloaders=test_loader)

Validate model and log embeddings

trainer.validate(mnist_model, test_loader)

View the map

Debugging the latent space

You can visually inspect your trained model's decision boundaries. Points that are misclassified or hard to classify appear in between embedding clusters. For example, hover over the region between the yellow and blue clusters to find points that are hard to discriminate between zeros and sixes. Try modifying the PyTorch model definition to see how classification errors differ for models with better vision inductive biases, like CNNs.
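Points that sit "in between" clusters can also be ranked numerically. One simple heuristic, swapped in here and not part of Atlas or the AtlasEmbeddingExplorer hook, is a centroid margin: the distance to the nearest other-class centroid minus the distance to the point's own class centroid. Small or negative margins mark boundary points worth hovering over on the map. The sketch assumes embeddings in a NumPy array with at least two classes, uses synthetic data, and relies on centroid_margin, a hypothetical helper.

```python
import numpy as np


def centroid_margin(embeddings: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Distance to the nearest other-class centroid minus distance to the own-class centroid."""
    classes = np.unique(labels)  # sorted class ids
    centroids = np.stack([embeddings[labels == c].mean(axis=0) for c in classes])
    # Distance from every point to every centroid, shape (n, num_classes)
    dists = np.linalg.norm(embeddings[:, None, :] - centroids[None, :, :], axis=2)
    rows = np.arange(len(labels))
    own = np.searchsorted(classes, labels)  # column index of each point's own class
    own_dist = dists[rows, own]
    dists[rows, own] = np.inf  # exclude the own class when looking for the nearest other one
    return dists.min(axis=1) - own_dist


rng = np.random.default_rng(0)
# Two synthetic clusters; one class-0 point is moved to the midpoint between them
emb = np.vstack([rng.normal(0, 0.5, (50, 2)), rng.normal(4, 0.5, (50, 2))])
labels = np.array([0] * 50 + [1] * 50)
emb[0] = [2.0, 2.0]
margins = centroid_margin(emb, labels)
print(np.argsort(margins)[:3])  # indices of the three most ambiguous points
```

Centroids are a crude summary of a cluster, so this works best when classes form roughly compact blobs, as the MNIST logits in the map tend to.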