Skip to content

Embeddings: 2D and Latent

Atlas stores, manages and generates embeddings for your unstructured data.

You can access Atlas latent (i.e. high-dimensional) embeddings or their two-dimensional projected representations.

from nomic import AtlasProject

map = AtlasProject(name='My Project').maps[0]

projected_embeddings = map.embeddings.projected

latent_embeddings = map.embeddings.latent

print(f"The datapoint with id {projected_embeddings['id'][0]} is located at ({projected_embeddings['x'][0]}, {projected_embeddings['y'][0]}) with latent embedding {latent_embeddings[0]}")

AtlasMapEmbeddings

Atlas Embeddings State

Access latent (high-dimensional) and projected (two-dimensional) embeddings of your datapoints.

Two-dimensional projected embeddings
from nomic import AtlasProject

project = AtlasProject(name='My Project')
map = project.maps[0]
print(map.embeddings)
      id_          x          y
0      0A  -6.164423  21.517719
1      0g  -6.606402  -5.601104
2      0Q  -9.206946   7.448542
...   ...        ...        ...
9998  JZQ   2.110881 -12.937058
9999  JZU   7.865006  -6.876243
High dimensional latent embeddings
from nomic import AtlasProject

project = AtlasProject(name='My Project')
map = project.maps[0]
embeddings = map.embeddings.latent
print(embeddings.shape)
(10000, 384)

High dimensional embeddings

High dimensional embeddings are not downloaded when you access the embeddings attribute; they are only downloaded when you explicitly access map.embeddings.latent. Once downloaded, subsequent accesses will reference your downloaded local copy.

Source code in nomic/data_operations.py
class AtlasMapEmbeddings:
    """
    Atlas Embeddings State

    Access latent (high-dimensional) and projected (two-dimensional) embeddings of your datapoints.

    ## Two-dimensional projected embeddings

    === "Accessing 2D Embeddings Example"
        ``` py
        from nomic import AtlasProject

        project = AtlasProject(name='My Project')
        map = project.maps[0]
        print(map.embeddings)
        ```
    === "Output"
        ```
              id_          x          y
        0      0A  -6.164423  21.517719
        1      0g  -6.606402  -5.601104
        2      0Q  -9.206946   7.448542
        ...   ...        ...        ...
        9998  JZQ   2.110881 -12.937058
        9999  JZU   7.865006  -6.876243
        ```

    ## High dimensional latent embeddings


    === "Accessing Latent Embeddings Example"
        ``` py
        from nomic import AtlasProject

        project = AtlasProject(name='My Project')
        map = project.maps[0]
        embeddings = map.embeddings.latent
        print(embeddings.shape)
        ```
    === "Output"
        ```
        (10000, 384)
        ```


    !!! warning "High dimensional embeddings"
        High dimensional embeddings are not immediately downloaded when you access the embeddings attribute - you must explicitly access `map.embeddings.latent`. Once downloaded, subsequent accesses will reference your downloaded local copy.

    """

    def __init__(self, projection: "AtlasProjection"):
        """
        Bind this embeddings view to a projection and load its 2D coordinates.

        Args:
            projection: The projection whose tiles provide the embeddings.
        """
        self.projection = projection
        self.project = projection.project
        self.id_field = self.project.id_field
        # Keep only the id column and the x/y coordinates of the tile data.
        self._tb: pa.Table = projection._fetch_tiles().select([self.id_field, 'x', 'y'])
        # In-memory cache of the stacked latent matrix; filled on first access to `latent`.
        self._latent = None

    @property
    def df(self):
        """
        Pandas dataframe containing information about embeddings of your datapoints.

        Includes only the two-dimensional embeddings
        """
        return self.tb.to_pandas()

    @property
    def tb(self) -> pa.Table:
        """
        Pyarrow table containing two-dimensional embeddings of each of your data points.
        This table is memmapped from the underlying files and is the most efficient way to
        access embedding information.

        Does not include high-dimensional embeddings.
        """
        return self._tb

    @property
    def projected(self) -> pd.DataFrame:
        """
        Two-dimensional embeddings.

        These are the points you see in your web browser.

        Returns:
            Pandas dataframe mapping your datapoints to their two-dimensional embeddings.
        """
        return self.df

    @property
    def latent(self) -> np.ndarray:
        """
        High dimensional embeddings.

        On first access the embedding files are downloaded if the root tile's
        embedding file is missing locally; every tile's embedding files are then
        read and stacked into a single array. The stacked array is cached on this
        object, so later accesses return the same array without touching the files.

        Returns:
            A numpy array where each row contains the latent embedding of the corresponding datapoint in the same order as `map.embeddings.projected`.

        Raises:
            FileNotFoundError: If a tile has no embedding files on disk.
        """
        if self._latent is not None:
            return self._latent

        root_embedding = self.projection.tile_destination / "0/0/0-0.embeddings.feather"
        # Not the most complete check, hence the error raised below when a tile is missing.
        if not root_embedding.exists():
            self._download_latent()

        all_embeddings = []
        for path in self.projection._tiles_in_order():
            # double with-suffix to remove '.embeddings.feather'
            files = path.parent.glob(path.with_suffix("").stem + "-*.embeddings.feather")
            # Should there be more than 10, we need to sort by int values, not string values
            sortable = sorted(files, key=lambda x: int(x.with_suffix("").stem.split("-")[-1]))
            if not sortable:
                raise FileNotFoundError("Could not find any embeddings for tile {}".format(path) +
                " If you possibly downloaded only some of the embeddings, run '[map_name].download_latent()'.")
            for file in sortable:
                tb = feather.read_table(file, memory_map=True)
                # '_embeddings' is a fixed-size list column; its list size is the embedding width.
                dims = tb['_embeddings'].type.list_size
                all_embeddings.append(pc.list_flatten(tb['_embeddings']).to_numpy().reshape(-1, dims))

        # np.vstack copies the per-file chunks into one new array; remember it so
        # the files are read and stacked only once per object.
        self._latent = np.vstack(all_embeddings)
        return self._latent

    def _download_latent(self):
        """
        Downloads the latent embeddings one file at a time.

        Pages through the embedding endpoint, passing the name of the last tile
        received, until the server answers 204. Each page is written under the
        projection's tile destination as `<tile>.embeddings.feather`.

        Raises:
            Exception: If the server answers with an error status (>= 400).
        """
        logger.warning("Downloading latent embeddings of all datapoints.")
        limit = 10_000
        route = self.projection.project.atlas_api_path + '/v1/project/data/get/embedding/paged'
        last = None

        with tqdm(total=self.project.total_datums // limit) as pbar:
            while True:
                params = {
                    'projection_id': self.projection.id,
                    "last_file": last,
                    "page_size": limit
                }
                r = requests.post(route, headers=self.projection.project.header, json=params)
                if r.status_code == 204:
                    # Download complete!
                    break
                if r.status_code >= 400:
                    # Surface the server's message instead of failing later while
                    # trying to parse an error body as a feather file.
                    raise Exception(r.text)
                fin = BytesIO(r.content)
                tb = feather.read_table(fin, memory_map=True)

                # The tile name is carried in the table's schema metadata.
                tilename = tb.schema.metadata[b'tile'].decode("utf-8")
                dest = (self.projection.tile_destination / tilename).with_suffix(".embeddings.feather")
                dest.parent.mkdir(parents=True, exist_ok=True)
                feather.write_feather(tb, dest)
                last = tilename
                pbar.update(1)

    def vector_search(self, queries: np.ndarray = None, ids: List[str] = None, k: int = 5) -> Tuple[List, List]:
        '''
        Performs semantic vector search over data points on your map.
        If ids is specified, receive back the most similar data ids in latent vector space to your input ids.
        If queries is specified, receive back the data ids with representations most similar to the query vectors.

        You should not specify both queries and ids. If both are given, `queries` is used for the search.

        Args:
            queries: a 2d numpy array where each row corresponds to a query vector
            ids: a list of ids
            k: the number of closest data points (neighbors) to return for each input query/data id
        Returns:
            A tuple with two elements containing the following information:
                neighbors: The ids of the nearest neighbors of each query
                distances: The distances between each query and its neighbors
        Raises:
            ValueError: If neither `queries` nor `ids` is given, or `queries` is not 2-dimensional.
            Exception: If `k`, the number of ids or the number of query vectors exceeds the
                allowed maximum, if `queries` is not a numpy array, or if the request fails.
        '''

        if queries is None and ids is None:
            raise ValueError('You must specify either a list of datum `ids` or numpy array of `queries` but not both.')

        max_k = 128
        max_queries = 256
        if k > max_k:
            raise Exception(f"Cannot query for more than {max_k} nearest neighbors. Set `k` to {max_k} or lower")

        if ids is not None and len(ids) > max_queries:
            raise Exception(f"Max ids per query is {max_queries}. You sent {len(ids)}.")

        if queries is not None:
            if not isinstance(queries, np.ndarray):
                raise Exception("`queries` must be an instance of np.array.")
            # Check the rank before touching shape[0]: a 0-d array has an empty
            # shape, so indexing it would raise an unhelpful IndexError.
            if queries.ndim != 2:
                raise ValueError(
                    'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).'
                )
            if queries.shape[0] > max_queries:
                raise Exception(f"Max vectors per query is {max_queries}. You sent {queries.shape[0]}.")

            # The query matrix travels as a base64-encoded .npy payload.
            bytesio = io.BytesIO()
            np.save(bytesio, queries)
            response = requests.post(
                self.projection.project.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_embedding",
                headers=self.projection.project.header,
                json={
                    'atlas_index_id': self.projection.atlas_index_id,
                    'queries': base64.b64encode(bytesio.getvalue()).decode('utf-8'),
                    'k': k,
                },
            )
        else:
            response = requests.post(
                self.projection.project.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_id",
                headers=self.projection.project.header,
                json={'atlas_index_id': self.projection.atlas_index_id, 'datum_ids': ids, 'k': k},
            )

        if response.status_code == 500:
            raise Exception('Cannot perform vector search on your map at this time. Try again later.')

        if response.status_code != 200:
            raise Exception(response.text)

        response = response.json()

        return response['neighbors'], response['distances']

    def _get_embedding_iterator(self) -> Iterable[Tuple[str, str]]:
        '''
        Deprecated in favor of `map.embeddings.latent`.

        This method performs no work; it fails as soon as it is called.

        Raises:
            DeprecationWarning: Always.
        '''
        raise DeprecationWarning("Deprecated as of June 2023. Iterate `map.embeddings.latent`.")

    def _download_embeddings(self, save_directory: str, num_workers: int = 10) -> bool:
        '''
        Deprecated in favor of `map.embeddings.latent`.

        This method performs no work; it fails as soon as it is called.

        Args:
            save_directory: Unused; kept so existing call sites keep their signature.
            num_workers: Unused; kept so existing call sites keep their signature.
        Raises:
            DeprecationWarning: Always.
        '''
        raise DeprecationWarning("Deprecated as of June 2023. Use `map.embeddings.latent`.")

    def __repr__(self) -> str:
        """Render the two-dimensional embeddings dataframe as text."""
        return str(self.df)
df property

Pandas dataframe containing information about embeddings of your datapoints.

Includes only the two-dimensional embeddings

latent: np.array property

High dimensional embeddings.

Returns:

  • array

    A numpy array where each row contains the latent embedding of the corresponding datapoint in the same order as map.embeddings.projected.

projected: pd.DataFrame property

Two-dimensional embeddings.

These are the points you see in your web browser.

Returns:

  • DataFrame

    Pandas dataframe mapping your datapoints to their two-dimensional embeddings.

tb: pa.Table property

Pyarrow table containing two-dimensional embeddings of each of your data points. This table is memmapped from the underlying files and is the most efficient way to access embedding information.

Does not include high-dimensional embeddings.

Performs semantic vector search over data points on your map. If ids is specified, receive back the most similar data ids in latent vector space to your input ids. If queries is specified, receive back the data ids with representations most similar to the query vectors.

You should not specify both queries and ids.

Parameters:

  • queries (array, default: None ) –

    a 2d numpy array where each row corresponds to a query vector

  • ids (List[str], default: None ) –

    a list of ids

  • k (int, default: 5 ) –

    the number of closest data points (neighbors) to return for each input query/data id

Returns: A tuple with two elements. neighbors: the ids of the nearest neighbors of each query. distances: the distances between each query and its neighbors.

Source code in nomic/data_operations.py
def vector_search(self, queries: np.ndarray = None, ids: List[str] = None, k: int = 5) -> Tuple[List, List]:
    '''
    Performs semantic vector search over data points on your map.
    If ids is specified, receive back the most similar data ids in latent vector space to your input ids.
    If queries is specified, receive back the data ids with representations most similar to the query vectors.

    You should not specify both queries and ids. If both are given, `queries` is used for the search.

    Args:
        queries: a 2d numpy array where each row corresponds to a query vector
        ids: a list of ids
        k: the number of closest data points (neighbors) to return for each input query/data id
    Returns:
        A tuple with two elements containing the following information:
            neighbors: The ids of the nearest neighbors of each query
            distances: The distances between each query and its neighbors
    Raises:
        ValueError: If neither `queries` nor `ids` is given, or `queries` is not 2-dimensional.
        Exception: If `k`, the number of ids or the number of query vectors exceeds the
            allowed maximum, if `queries` is not a numpy array, or if the request fails.
    '''

    if queries is None and ids is None:
        raise ValueError('You must specify either a list of datum `ids` or numpy array of `queries` but not both.')

    max_k = 128
    max_queries = 256
    if k > max_k:
        raise Exception(f"Cannot query for more than {max_k} nearest neighbors. Set `k` to {max_k} or lower")

    if ids is not None and len(ids) > max_queries:
        raise Exception(f"Max ids per query is {max_queries}. You sent {len(ids)}.")

    if queries is not None:
        if not isinstance(queries, np.ndarray):
            raise Exception("`queries` must be an instance of np.array.")
        # Check the rank before touching shape[0]: a 0-d array has an empty
        # shape, so indexing it would raise an unhelpful IndexError.
        if queries.ndim != 2:
            raise ValueError(
                'Expected a 2 dimensional array. If you have a single query, we expect an array of shape (1, d).'
            )
        if queries.shape[0] > max_queries:
            raise Exception(f"Max vectors per query is {max_queries}. You sent {queries.shape[0]}.")

        # The query matrix travels as a base64-encoded .npy payload.
        bytesio = io.BytesIO()
        np.save(bytesio, queries)
        response = requests.post(
            self.projection.project.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_embedding",
            headers=self.projection.project.header,
            json={
                'atlas_index_id': self.projection.atlas_index_id,
                'queries': base64.b64encode(bytesio.getvalue()).decode('utf-8'),
                'k': k,
            },
        )
    else:
        response = requests.post(
            self.projection.project.atlas_api_path + "/v1/project/data/get/nearest_neighbors/by_id",
            headers=self.projection.project.header,
            json={'atlas_index_id': self.projection.atlas_index_id, 'datum_ids': ids, 'k': k},
        )

    if response.status_code == 500:
        raise Exception('Cannot perform vector search on your map at this time. Try again later.')

    if response.status_code != 200:
        raise Exception(response.text)

    response = response.json()

    return response['neighbors'], response['distances']