Monitoring Chatbots with Atlas, Langchain and Cohere¶
One of the hardest things about building a chatbot is monitoring how users are actually interacting with the model. Are users sending toxic messages? Is your bot producing toxic outputs?
In this tutorial, we'll show you how to use Langchain, AtlasDB, and Cohere to monitor your chatbot for toxic content.
Installing Dependencies¶
!pip install langchain cohere datasets
!pip install --upgrade nomic
COHERE_API_KEY = ''
ATLAS_API_KEY = ''
Data Setup¶
To mock a query stream that gets progressively more unsafe, we are going to use two datasets from Allen AI.
The first dataset, soda, contains several "normal" chatbot transcripts. The second dataset, prosocial-dialog, contains several "unsafe" transcripts.
We are going to build a data generator that randomly samples a datapoint from one of these datasets. As time goes on, the generator becomes more likely to select an "unsafe" datum. This mocks the real-world situation where your query stream becomes more unsafe over time.
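Before wiring up the real datasets, we can sanity-check the ramp in isolation. A minimal sketch (pure Python, no datasets needed): with an unsafe probability of t/(3*T), the chance climbs linearly from 0 to 1/3, so roughly one sixth of the full stream should be unsafe on average.

```python
import random

random.seed(0)  # fixed seed for a reproducible estimate
T = 100_000
# Count how often a uniform draw falls under the linearly growing threshold
unsafe = sum(random.uniform(0, 1) <= t / (3 * T) for t in range(T))
frac = unsafe / T
print(round(frac, 3))  # lands near 1/6, the mean of a ramp from 0 to 1/3
```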
import random
from datasets import load_dataset

def stream_data(T=5000):
    safe = load_dataset('allenai/soda')['train']
    unsafe = iter([e for e in load_dataset('allenai/prosocial-dialog')['train']
                   if e['safety_label'] == '__needs_intervention__'])
    for t in range(T):
        p_unsafe = t / (3 * T)  # slowly increase toxic content to a 33% chance
        is_unsafe = random.uniform(0, 1) <= p_unsafe
        if is_unsafe:
            x = next(unsafe)
            yield 'unsafe', t, x['context'] + ' ' + x['response']
        else:
            x = safe[t]
            yield 'safe', t, x['dialogue'][0]
data = [e for e in stream_data(6000)]
train_data = data[:5000]
test_data = data[5000:]
Model Setup¶
We are going to use the AtlasDB langchain integration to store our chat data, and CohereEmbeddings to embed it.
import time
import numpy as np
from tqdm import tqdm
from pprint import pprint
from langchain.vectorstores import AtlasDB
from langchain.embeddings import CohereEmbeddings
# Create a CohereEmbeddings object to vectorize our text
embedding = CohereEmbeddings(cohere_api_key=COHERE_API_KEY, model='small')
# Create an AtlasDB object to store our vectors
db = AtlasDB(name='Observability Demo',
             api_key=ATLAS_API_KEY,
             embedding_function=embedding)
Add data to AtlasDB with the add_texts function. Each datapoint has text (the contents of the chat), a label, and a timestamp.
batch_size = 1000
batched_texts = []
batched_metadatas = []
for datum in tqdm(train_data):
    label, timestamp, text = datum
    # Batch data for faster adds. You can also add data one at a time if you like
    batched_texts.append(text)
    batched_metadatas.append({'label': label, 'timestamp': timestamp})
    if len(batched_texts) >= batch_size:
        # Add data to the database as it streams
        db.add_texts(texts=batched_texts,
                     metadatas=batched_metadatas,
                     refresh=False)  # refresh=False tells Atlas not to rebuild the index on every add
        batched_texts = []
        batched_metadatas = []
Once we've added all of our text, we can create an index over our data.
If you add more text after creating your index, you can refresh with db.project.rebuild_maps()
For a full list of parameters to create_index, see: https://docs.nomic.ai/atlas_api.html#nomic.project.AtlasProject.create_index
# Index our data in Atlas
db.create_index(name='Observability Demo',
                colorable_fields=['timestamp', 'label'],
                build_topic_model=True,
                topic_label_field='text')
2023-03-13 12:48:24.578 | INFO | nomic.project:wait_for_project_lock:1057 - Observability Demo: Project lock is released.
2023-03-13 12:48:26.518 | INFO | nomic.project:create_index:1256 - Created map `Observability Demo` in project `Observability Demo`: https://atlas.nomic.ai/map/545dd10e-65a9-4095-976d-2b16abd889e2/b3d4ea89-680a-4c75-833b-595cd80fedf3
Visualize your chats¶
Once the index completes, we can view it inline in our notebook!
# Wait for the index to build
with db.project.wait_for_project_lock():
    time.sleep(1)
db.project
A description for your project
5000 datums inserted.
1 index built.
Projections
- Observability Demo. Status: Completed.
  Projection ID: b3d4ea89-680a-4c75-833b-595cd80fedf3
Atlas automatically builds a topic model on your data, allowing you to understand what people are talking about with your bot at a glance.
Check out the bottom left corner of the map. The topics there (e.g. domestic violence, jokes, etc.) indicate that there is unsafe content in your query stream!
You can click on a topic label to filter the map to only points belonging to that topic.
The slider tool can be used to visualize data within certain time ranges. Using it, we can see that our query stream becomes more toxic over time.
The aesthetic tool can be used to change the colors on the map. Coloring by label confirms that most of the toxic content lands in the same region of the map.
The pencil tool can be used to select regions of the map, which can then be bulk labeled.
Getting Topic Information¶
We can also view the topics that AtlasDB extracted programmatically in our code.
map = db.project.maps[0]
topic_data = map.topics.metadata #this is how we view all of the metadata associated with our topics
pprint(topic_data[:10])
[{'_topic_depth_1': 'Relationships', 'depth': 1, 'topic': 1, 'topic_description': 'her/they/is/should/she/him/would/them/their/will/not/If/his/You/joke', 'topic_short_description': 'Relationships'},
 {'_topic_depth_1': 'math anxiety', 'depth': 1, 'topic': 2, 'topic_description': "don't/know/sorry/feel/job/just/can't/what/didn't/do/math/like/lately/anymore/test", 'topic_short_description': 'math anxiety'},
 {'_topic_depth_1': "How's Sarah doing today?", 'depth': 1, 'topic': 3, 'topic_description': "Sarah/Hey/What's//How/doing/up/are/How's/today/Sarah!/going/what's/Hi/you", 'topic_short_description': "How's Sarah doing today?"},
 {'_topic_depth_1': 'Thank you for the amazing party!', 'depth': 1, 'topic': 4, 'topic_description': 'Wow/really/so/Thank/much/excited/party/book/amazing/beautiful/looking/appreciate/car/new/amazing!', 'topic_short_description': 'Thank you for the amazing party!'},
 {'_topic_depth_1': 'Talking to your boss', 'depth': 1, 'topic': 5, 'topic_description': 'talk/boss/Hey/Hi//something/glad/wanted/meet/minute/to/you/ask/here/Good', 'topic_short_description': 'Talking to your boss'},
 {'_topic_depth_1': 'Family', 'depth': 1, 'topic': 6, 'topic_description': 'Mom/Hey/Dad/dad//talk/Mom!/something/college/dishes/father/home/Do/minute/Can', 'topic_short_description': 'Family'},
 {'_topic_depth_1': 'Feeling sick', 'depth': 1, 'topic': 7, 'topic_description': 'okay/Doctor/feeling/Are/stomach/headache/throat/pain/well/hurts/Hi/nauseous/doctor/body/sick', 'topic_short_description': 'Feeling sick'},
 {'_topic_depth_1': 'Hey, How, are, doing,', 'depth': 1, 'topic': 8, 'topic_description': "Hey/How//up/are/doing/today/what's/What's/how's/How's/coach/seen/Hi/how", 'topic_short_description': 'Hey, How, are, doing,'},
 {'_topic_depth_1': 'Relationships', '_topic_depth_2': 'Parenting', 'depth': 2, 'topic': 9, 'topic_description': "her/she/baby/year/old/children/child/she's/sleep/adult/suicidal/age/drinking/drugs/daughter", 'topic_short_description': 'Parenting'},
 {'_topic_depth_1': 'math anxiety', '_topic_depth_2': 'My new shirt', 'depth': 2, 'topic': 10, 'topic_description': 'car/my/new/shirt/wear/it/was/the/clean/clothes/old/driving/on/floor/dress', 'topic_short_description': 'My new shirt'}]
Based on this topic information, we can see several clusters of unsafe data we want to flag - things like "stealing", "violence", "addiction".
Furthermore, we see that each entry may carry several _topic_depth_* keys. The Atlas topic model is hierarchical, and entries with depth greater than 1 trace a path down the topic tree. For instance,
{'_topic_depth_1': 'You Should Not Do That If You',
 '_topic_depth_2': 'Addiction',
 'depth': 2,
 'topic': 17,
 'topic_description': 'drugs/smoking/doctor/medication/addicted/pills/meds/cocaine/medicine/heroin/sleeping/sugar/quit/take/suddenly',
 'topic_short_description': 'Addiction'}
is a depth-2 topic whose parent topic is named "You Should Not Do That If You", and whose distinguishing words are "drugs, smoking, doctor, medication, addicted," and so on.
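To make the hierarchy explicit, we can group depth-2 topics under their depth-1 parents. A minimal sketch, using a few abbreviated entries shaped like the metadata above (the full topic_description fields are omitted):

```python
from collections import defaultdict

# Abbreviated entries shaped like map.topics.metadata above
topics = [
    {'_topic_depth_1': 'Relationships', 'depth': 1, 'topic': 1,
     'topic_short_description': 'Relationships'},
    {'_topic_depth_1': 'Relationships', '_topic_depth_2': 'Parenting',
     'depth': 2, 'topic': 9, 'topic_short_description': 'Parenting'},
    {'_topic_depth_1': 'You Should Not Do That If You',
     '_topic_depth_2': 'Addiction', 'depth': 2, 'topic': 17,
     'topic_short_description': 'Addiction'},
]

tree = defaultdict(list)
for t in topics:
    if t['depth'] == 2:
        # The _topic_depth_* keys trace the path from the root down
        tree[t['_topic_depth_1']].append(t['topic_short_description'])

for parent, children in sorted(tree.items()):
    print(f'{parent} -> {children}')
```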
Using Topics to Filter the Query Stream¶
By exploring the map, we can see that many of the unsafe topics share the top-level label "Don't hurt people." Let's use AtlasDB to build a greylist so we can detect utterances that may be unsafe and flag them for review.
# Prep the data from our test set
test_labels = [e[0] for e in test_data]
test_text = [e[2] for e in test_data]
test_embeddings = np.stack(embedding.embed_documents(test_text))

unsafe_topic = '1'  # The topic id of our unsafe topic, inferred from topic_data[0]['topic']
batch_size = 100
topic_posteriors = []
for i in range(0, len(test_embeddings), batch_size):
    batch = test_embeddings[i:i+batch_size]
    # Use AtlasDB to infer the topics of new data
    cur_posteriors = map.topics.vector_search_topics(batch, k=32, depth=1)['topics']
    topic_posteriors.extend(cur_posteriors)

predicted_is_unsafe = [max(posterior, key=lambda topic: posterior[topic]) == unsafe_topic
                       for posterior in topic_posteriors]
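The argmax step above picks, for each utterance, the topic id with the highest posterior mass and compares it to the unsafe topic. With a toy posterior dict (the values are made up for illustration):

```python
posterior = {'1': 0.72, '2': 0.18, '3': 0.10}  # hypothetical topic posterior
# max over the dict's keys, ranked by their posterior probability
predicted = max(posterior, key=lambda topic: posterior[topic])
print(predicted)  # '1', the highest-probability topic id
```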
# Analyze the results of our classifier on the test data
true_positive = 0
false_positive = 0
false_negative = 0
for i, label in enumerate(test_labels):
    if predicted_is_unsafe[i]:
        if label == 'unsafe':
            true_positive += 1
        else:
            false_positive += 1
    else:
        if label == 'unsafe':
            false_negative += 1

print('Precision: ', true_positive/(true_positive+false_positive))
print('Recall: ', true_positive/(true_positive+false_negative))
Precision: 0.9893617021276596
Recall: 0.5923566878980892
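As a quick check of the arithmetic, counts of 93 true positives, 1 false positive, and 64 false negatives (hypothetical counts, chosen because they are consistent with the numbers printed above) reproduce these scores:

```python
# Hypothetical counts consistent with the printout above
tp, fp, fn = 93, 1, 64

precision = tp / (tp + fp)  # of everything we flagged, how much was truly unsafe
recall = tp / (tp + fn)     # of everything truly unsafe, how much did we flag

print(precision, recall)
```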
Using just the topics that AtlasDB automatically generated for us, we were able to identify unseen harmful data with high precision!