class AtlasMapEmbeddings:
    """
    Atlas Embeddings State

    Access latent (high-dimensional) and projected (two-dimensional) embeddings of your datapoints.

    ## Two-dimensional projected embeddings

    === "Accessing 2D Embeddings Example"
        ``` py
        from nomic import AtlasProject

        project = AtlasProject(name='My Project')
        map = project.maps[0]
        print(map.embeddings)
        ```
    === "Output"
        ```
              id_          x          y
        0      0A  -6.164423  21.517719
        1      0g  -6.606402  -5.601104
        2      0Q  -9.206946   7.448542
        ...   ...        ...        ...
        9998  JZQ   2.110881 -12.937058
        9999  JZU   7.865006  -6.876243
        ```

    ## High dimensional latent embeddings

    === "Accessing Latent Embeddings Example"
        ``` py
        from nomic import AtlasProject

        project = AtlasProject(name='My Project')
        map = project.maps[0]
        embeddings = map.embeddings.latent
        print(embeddings.shape)
        ```
    === "Output"
        ```
        (10000, 384)
        ```

    !!! warning "High dimensional embeddings"
        High dimensional embeddings are not downloaded when you access the `embeddings` attribute -
        they are fetched the first time you access `map.embeddings.latent`. Once downloaded,
        subsequent accesses reuse your local copy.
    """

    def __init__(self, projection: "AtlasProjection"):
        self.projection = projection
        self.id_field = self.projection.project.id_field
        # Keep only the datum id and the 2D coordinates from the fetched tiles.
        self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, 'x', 'y'])
        self.project = projection.project
        # Cache for the stacked latent embeddings; filled on first access of `latent`.
        self._latent = None

    @property
    def df(self) -> pd.DataFrame:
        """
        Pandas dataframe containing information about embeddings of your datapoints.

        Includes only the two-dimensional embeddings (id, x, y).
        """
        return self.tb.to_pandas()

    @property
    def tb(self) -> "pa.Table":
        """
        Pyarrow table containing two-dimensional embeddings of each of your data points.

        This table is memmapped from the underlying files and is the most efficient way to
        access embedding information.

        Does not include high-dimensional embeddings.
        """
        return self._tb

    @property
    def projected(self) -> pd.DataFrame:
        """
        Two-dimensional embeddings.

        These are the points you see in your web browser.

        Returns:
            Pandas dataframe mapping your datapoints to their two-dimensional embeddings.
        """
        return self.df

    @property
    def latent(self) -> np.ndarray:
        """
        High dimensional embeddings.

        On first access the embeddings are downloaded (if the root tile's embedding file is not
        already on disk), read from the per-tile feather files, and stacked into a single array.
        The result is cached on this object, so later accesses return the same array without
        touching disk again.

        Returns:
            An in-memory 2D numpy array with one row per datapoint and one column per embedding
            dimension, where each row contains the latent embedding of the corresponding
            datapoint in the same order as `map.embeddings.projected`.

        Raises:
            FileNotFoundError: if a tile has no embedding files on disk.
        """
        if self._latent is not None:
            return self._latent

        root_embedding = self.projection.tile_destination / "0/0/0-0.embeddings.feather"
        # Only the root tile is checked, so a partial download is not detected here;
        # the per-tile check below reports it instead.
        if not root_embedding.exists():
            self._download_latent()

        all_embeddings = []
        for path in self.projection._tiles_in_order():
            # Strip two suffixes from the tile file name to get its base stem, then find every
            # '<stem>-<n>.embeddings.feather' file written for that tile.
            files = path.parent.glob(path.with_suffix("").stem + "-*.embeddings.feather")
            # Sort by the integer part number; a string sort would put '10' before '2'.
            sortable = sorted(files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1]))
            if not sortable:
                raise FileNotFoundError(
                    f"Could not find any embeddings for tile {path}. "
                    "If you possibly downloaded only some of the embeddings, "
                    "run `map.embeddings._download_latent()`."
                )
            for file in sortable:
                tb = feather.read_table(file, memory_map=True)
                dims = tb['_embeddings'].type.list_size
                all_embeddings.append(pc.list_flatten(tb['_embeddings']).to_numpy().reshape(-1, dims))

        # Cache the result so subsequent accesses reuse the local copy.
        self._latent = np.vstack(all_embeddings)
        return self._latent

    def _download_latent(self):
        """
        Download the latent embeddings of all datapoints, one file per request.

        Each non-empty response is a Feather table whose schema metadata key `tile` names the
        tile it belongs to. The table is written under `tile_destination` at that tile path,
        with its suffix replaced by `.embeddings.feather`. Paging continues from the last tile
        received and stops when the server answers 204 (No Content).

        Raises:
            Exception: with the server's response text if a request returns a status other than
                200 or 204.
        """
        logger.warning("Downloading latent embeddings of all datapoints.")
        limit = 10_000
        route = self.projection.project.atlas_api_path + '/v1/project/data/get/embedding/paged'
        last = None
        # Ceiling division so a final partial page is counted in the progress total.
        with tqdm(total=(self.project.total_datums + limit - 1) // limit) as pbar:
            while True:
                params = {
                    'projection_id': self.projection.id,
                    "last_file": last,
                    "page_size": limit,
                }
                r = requests.post(route, headers=self.projection.project.header, json=params)
                if r.status_code == 204:
                    # Download complete!
                    break
                if r.status_code != 200:
                    # Surface the server's error instead of failing while parsing it as Feather.
                    raise Exception(r.text)
                fin = BytesIO(r.content)
                tb = feather.read_table(fin, memory_map=True)
                tilename = tb.schema.metadata[b'tile'].decode("utf-8")
                dest = (self.projection.tile_destination / tilename).with_suffix(".embeddings.feather")
                dest.parent.mkdir(parents=True, exist_ok=True)
                feather.write_feather(tb, dest)
                last = tilename
                pbar.update(1)

    def vector_search(
        self, queries: np.ndarray = None, ids: List[str] = None, k: int = 5
    ) -> Tuple[List, List]:
        '''
        Performs semantic vector search over data points on your map.

        If ids is specified, receive back the most similar data ids in latent vector space to your input ids.

        If queries is specified, receive back the data ids with representations most similar to the query vectors.

        Specify exactly one of queries and ids.

        Args:
            queries: a 2d numpy array where each row corresponds to a query vector
            ids: a list of ids
            k: the number of closest data points (neighbors) to return for each input query/data id

        Returns:
            A tuple with two elements containing the following information:
                neighbors: the ids of the nearest neighbors of each query
                distances: the distances between each query and its neighbors

        Raises:
            ValueError: if neither or both of `queries` and `ids` are given, or if `queries`
                is not two dimensional.
            Exception: if `k` or the number of queries/ids exceeds the server limits, if
                `queries` is not a numpy array, or if the request fails.
        '''
        # Exactly one of the two inputs must be provided.
        if (queries is None) == (ids is None):
            raise ValueError('You must specify either a list of datum `ids` or numpy array of `queries` but not both.')

        max_k = 128
        max_queries = 256

        if k > max_k:
            raise Exception(f"Cannot query for more than {max_k} nearest neighbors. Set `k` to {max_k} or lower")

        if ids is not None:
            if len(ids) > max_queries:
                raise Exception(f"Max ids per query is {max_queries}. You sent {len(ids)}.")
            route = "/v1/project/data/get/nearest_neighbors/by_id"
            payload = {'atlas_index_id': self.projection.atlas_index_id, 'datum_ids': ids, 'k': k}
        else:
            if not isinstance(queries, np.ndarray):
                raise Exception("`queries` must be an instance of np.array.")
            # Check dimensionality first: for a 1D array shape[0] is the vector length, not a query count.
            if queries.ndim != 2:
                raise ValueError(
                    'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).'
                )
            if queries.shape[0] > max_queries:
                raise Exception(f"Max vectors per query is {max_queries}. You sent {queries.shape[0]}.")
            # Queries are sent as a base64-encoded .npy payload.
            bytesio = io.BytesIO()
            np.save(bytesio, queries)
            route = "/v1/project/data/get/nearest_neighbors/by_embedding"
            payload = {
                'atlas_index_id': self.projection.atlas_index_id,
                'queries': base64.b64encode(bytesio.getvalue()).decode('utf-8'),
                'k': k,
            }

        response = requests.post(
            self.projection.project.atlas_api_path + route,
            headers=self.projection.project.header,
            json=payload,
        )

        if response.status_code == 500:
            raise Exception('Cannot perform vector search on your map at this time. Try again later.')

        if response.status_code != 200:
            raise Exception(response.text)

        response = response.json()

        return response['neighbors'], response['distances']

    def _get_embedding_iterator(self) -> Iterable[Tuple[str, str]]:
        '''
        Deprecated in favor of `map.embeddings.latent`.

        This is a plain function rather than a generator, so the deprecation error is raised as
        soon as it is called instead of being deferred until the first iteration.

        Raises:
            DeprecationWarning: always.
        '''
        raise DeprecationWarning("Deprecated as of June 2023. Iterate `map.embeddings.latent`.")

    def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bool:
        '''
        Deprecated in favor of `map.embeddings.latent`.

        Kept only so existing callers receive a clear error.

        Args:
            save_directory: Unused.
            num_workers: Unused.

        Raises:
            DeprecationWarning: always.
        '''
        raise DeprecationWarning("Deprecated as of June 2023. Use `map.embeddings.latent`.")

    def __repr__(self) -> str:
        return str(self.df)