Skip to content

Module vexpresso.daft.utils

View Source
from types import MethodType

from typing import Any, Iterable, List, Optional

import daft

from daft import col

from vexpresso.retrievers import Retriever

from vexpresso.utils import DataType, ResourceRequest, get_batch_size

@daft.udf(return_dtype=DataType.int64())
def indices(columnn):
    """Return the row positions ``0 .. n-1`` for a column batch of ``n`` rows.

    Only the length of ``columnn`` is used; its values are ignored.
    """
    row_count = len(columnn.to_pylist())
    return [*range(row_count)]

@daft.udf(return_dtype=DataType.python())
def add_column(columnn, values):
    """Return ``values`` unchanged as the UDF result.

    ``columnn`` is accepted but never read.
    NOTE(review): presumably it only anchors the UDF to an existing column so
    daft knows which rows to evaluate over, and ``values`` is expected to hold
    one entry per row — confirm against callers.
    """
    passthrough = values
    return passthrough

@daft.udf(return_dtype=DataType.python())
def retrieve_udf(embedding_col, query_embeddings, retriever, k):
    """Score every row of ``embedding_col`` against each retrieval output.

    ``retriever.retrieve`` is called once for the whole batch. For each row
    position ``i`` the UDF emits a list holding one dict per retrieval output::

        {"retrieve_index": i or None, "retrieve_score": score}

    ``retrieve_index`` is ``i`` when ``i`` appears in that output's ``indices``
    and ``None`` otherwise. ``retrieve_score`` is ``scores[i]``, or
    ``-(10**10)`` when ``scores`` has fewer than ``i + 1`` entries.

    NOTE(review): ``scores`` is indexed by row position, which assumes the
    retriever returns scores aligned with ``embeddings`` rather than with
    ``indices`` — confirm against the Retriever implementation.
    """
    rows = embedding_col.to_pylist()
    outputs = retriever.retrieve(
        query_embeddings=query_embeddings, embeddings=rows, k=k
    )

    def _entry(row_pos, output):
        hit_indices, hit_scores = output.indices, output.scores
        score = hit_scores[row_pos] if row_pos < len(hit_scores) else -(10**10)
        return {
            "retrieve_index": row_pos if row_pos in hit_indices else None,
            "retrieve_score": score,
        }

    return [[_entry(pos, output) for output in outputs] for pos in range(len(rows))]

def retrieve(
    batch_size: int,
    df: daft.DataFrame,
    embedding_column_name: str,
    query_embeddings: Iterable[Any],
    retriever: Retriever,
    k: Optional[int] = None,
    sort: bool = True,
    return_scores: bool = False,
    score_column_name: Optional[str] = None,
    resource_request: ResourceRequest = ResourceRequest(),
) -> List[daft.DataFrame]:
    """Run ``retriever`` over ``df`` and return one filtered frame per query.

    A temporary ``"retrieve_output"`` column is added by running
    ``retrieve_udf`` over ``embedding_column_name``. Then, for each query
    position counted by ``get_batch_size(query_embeddings)``, that query's
    index and score are extracted, rows the retriever did not return are
    dropped, and the helper columns are removed.

    Args:
        batch_size: Currently unused; the number of output frames is always
            recomputed as ``get_batch_size(query_embeddings)``.
        df: DataFrame holding the embeddings to search.
        embedding_column_name: Name of the embedding column in ``df``.
        query_embeddings: Query embeddings forwarded to ``retriever``.
        retriever: Object whose ``retrieve`` method is invoked by
            ``retrieve_udf``.
        k: Forwarded unchanged to ``retriever.retrieve``.
        sort: If true, each result frame is sorted by score, descending.
        return_scores: If true, the score column is kept in the results.
        score_column_name: Name of the score column; defaults to
            ``f"{embedding_column_name}_score"``.
        resource_request: Resource request for the retrieval UDF column.

    Returns:
        A list of DataFrames, one per query position.
    """
    if score_column_name is None:
        score_column_name = f"{embedding_column_name}_score"

    df = df.with_column(
        "retrieve_output",
        retrieve_udf(
            col(embedding_column_name),
            query_embeddings=query_embeddings,
            k=k,
            retriever=retriever,
        ),
        resource_request=resource_request,
    )

    # daft builds a lazy plan, so functions passed to ``apply`` run only after
    # this loop finishes. A lambda closing over the loop variable would see its
    # final value for every frame; these factories bind the query index now.
    def _index_getter(query_idx: int):
        def _get(row_output):
            index = row_output[query_idx]["retrieve_index"]
            # retrieve_udf marks rows not retrieved with None; map them to the
            # -1 sentinel so the filter below removes them explicitly instead of
            # relying on null-comparison semantics.
            return -1 if index is None else index

        return _get

    def _score_getter(query_idx: int):
        def _get(row_output):
            return row_output[query_idx]["retrieve_score"]

        return _get

    # NOTE: the ``batch_size`` argument is intentionally not consulted here, to
    # preserve existing behaviour; the query count comes from the embeddings.
    num_queries = get_batch_size(query_embeddings)

    dfs = []
    for query_idx in range(num_queries):
        _df = (
            df.with_column(
                "retrieve_index",
                col("retrieve_output").apply(
                    _index_getter(query_idx), return_dtype=DataType.int64()
                ),
            )
            .with_column(
                score_column_name,
                col("retrieve_output").apply(
                    _score_getter(query_idx), return_dtype=DataType.float64()
                ),
            )
            .exclude("retrieve_output")
            .where(col("retrieve_index") != -1)
            .exclude("retrieve_index")
        )
        if sort:
            _df = _df.sort(col(score_column_name), desc=True)
        if not return_scores:
            _df = _df.exclude(score_column_name)
        dfs.append(_df)
    return dfs

class Wrapper:
    """Proxy that forwards attribute access to ``collection.daft_df``.

    Every attribute found on the wrapped collection's ``daft_df`` is returned
    as a callable. Calling it invokes that attribute: bound methods are called
    as ``method(*args, **kwargs)``; any other attribute is called as
    ``attr(daft_df, *args, **kwargs)``. A ``None`` result returns the
    collection itself; any other result is re-wrapped with
    ``collection.from_daft_df``.
    """

    def __init__(self, collection):
        self.collection = collection

    def __repr__(self) -> str:
        # Delegate to the wrapped daft DataFrame. Wrapper itself has no
        # ``daft_df`` attribute, so ``self.daft_df`` would raise AttributeError.
        return repr(self.collection.daft_df)

    def __getattr__(self, name):
        # __getattr__ runs only when normal lookup fails. If ``collection``
        # itself is missing (e.g. an instance created via __new__ during
        # copy/unpickle), fail fast instead of recursing through
        # ``self.collection`` forever.
        if name == "collection":
            raise AttributeError(name)
        daft_df = self.collection.daft_df
        try:
            func = getattr(daft_df, name)
        except AttributeError:
            raise AttributeError(name) from None
        return lambda *args, **kwargs: self._wrap(func, args, kwargs)

    def _wrap(self, func, args, kwargs):
        """Call ``func`` and convert its result back into a collection."""
        if isinstance(func, MethodType):
            # Already bound to the daft DataFrame.
            daft_df = func(*args, **kwargs)
        else:
            # Unbound callable (e.g. a staticmethod): pass the frame explicitly.
            daft_df = func(self.collection.daft_df, *args, **kwargs)
        if daft_df is None:
            return self.collection
        return self.collection.from_daft_df(daft_df)

Variables

add_column
indices
retrieve_udf

Functions

retrieve

def retrieve(
    batch_size: int,
    df: daft.dataframe.dataframe.DataFrame,
    embedding_column_name: str,
    query_embeddings: Iterable[Any],
    retriever: vexpresso.retrievers.np.Retriever,
    k: int = None,
    sort: bool = True,
    return_scores: bool = False,
    score_column_name: Optional[str] = None,
    resource_request: daft.resource_request.ResourceRequest = ResourceRequest(num_cpus=None, num_gpus=None, memory_bytes=None)
) -> List[daft.dataframe.dataframe.DataFrame]
View Source
def retrieve(
    batch_size: int,
    df: daft.DataFrame,
    embedding_column_name: str,
    query_embeddings: Iterable[Any],
    retriever: Retriever,
    k: Optional[int] = None,
    sort: bool = True,
    return_scores: bool = False,
    score_column_name: Optional[str] = None,
    resource_request: ResourceRequest = ResourceRequest(),
) -> List[daft.DataFrame]:
    """Run ``retriever`` over ``df`` and return one filtered frame per query.

    A temporary ``"retrieve_output"`` column is added by running
    ``retrieve_udf`` over ``embedding_column_name``. Then, for each query
    position counted by ``get_batch_size(query_embeddings)``, that query's
    index and score are extracted, rows the retriever did not return are
    dropped, and the helper columns are removed.

    Args:
        batch_size: Currently unused; the number of output frames is always
            recomputed as ``get_batch_size(query_embeddings)``.
        df: DataFrame holding the embeddings to search.
        embedding_column_name: Name of the embedding column in ``df``.
        query_embeddings: Query embeddings forwarded to ``retriever``.
        retriever: Object whose ``retrieve`` method is invoked by
            ``retrieve_udf``.
        k: Forwarded unchanged to ``retriever.retrieve``.
        sort: If true, each result frame is sorted by score, descending.
        return_scores: If true, the score column is kept in the results.
        score_column_name: Name of the score column; defaults to
            ``f"{embedding_column_name}_score"``.
        resource_request: Resource request for the retrieval UDF column.

    Returns:
        A list of DataFrames, one per query position.
    """
    if score_column_name is None:
        score_column_name = f"{embedding_column_name}_score"

    df = df.with_column(
        "retrieve_output",
        retrieve_udf(
            col(embedding_column_name),
            query_embeddings=query_embeddings,
            k=k,
            retriever=retriever,
        ),
        resource_request=resource_request,
    )

    # daft builds a lazy plan, so functions passed to ``apply`` run only after
    # this loop finishes. A lambda closing over the loop variable would see its
    # final value for every frame; these factories bind the query index now.
    def _index_getter(query_idx: int):
        def _get(row_output):
            index = row_output[query_idx]["retrieve_index"]
            # retrieve_udf marks rows not retrieved with None; map them to the
            # -1 sentinel so the filter below removes them explicitly instead of
            # relying on null-comparison semantics.
            return -1 if index is None else index

        return _get

    def _score_getter(query_idx: int):
        def _get(row_output):
            return row_output[query_idx]["retrieve_score"]

        return _get

    # NOTE: the ``batch_size`` argument is intentionally not consulted here, to
    # preserve existing behaviour; the query count comes from the embeddings.
    num_queries = get_batch_size(query_embeddings)

    dfs = []
    for query_idx in range(num_queries):
        _df = (
            df.with_column(
                "retrieve_index",
                col("retrieve_output").apply(
                    _index_getter(query_idx), return_dtype=DataType.int64()
                ),
            )
            .with_column(
                score_column_name,
                col("retrieve_output").apply(
                    _score_getter(query_idx), return_dtype=DataType.float64()
                ),
            )
            .exclude("retrieve_output")
            .where(col("retrieve_index") != -1)
            .exclude("retrieve_index")
        )
        if sort:
            _df = _df.sort(col(score_column_name), desc=True)
        if not return_scores:
            _df = _df.exclude(score_column_name)
        dfs.append(_df)
    return dfs

Classes

Wrapper

class Wrapper(
    collection
)
View Source
class Wrapper:
    """Proxy that forwards attribute access to ``collection.daft_df``.

    Every attribute found on the wrapped collection's ``daft_df`` is returned
    as a callable. Calling it invokes that attribute: bound methods are called
    as ``method(*args, **kwargs)``; any other attribute is called as
    ``attr(daft_df, *args, **kwargs)``. A ``None`` result returns the
    collection itself; any other result is re-wrapped with
    ``collection.from_daft_df``.
    """

    def __init__(self, collection):
        self.collection = collection

    def __repr__(self) -> str:
        # Delegate to the wrapped daft DataFrame. Wrapper itself has no
        # ``daft_df`` attribute, so ``self.daft_df`` would raise AttributeError.
        return repr(self.collection.daft_df)

    def __getattr__(self, name):
        # __getattr__ runs only when normal lookup fails. If ``collection``
        # itself is missing (e.g. an instance created via __new__ during
        # copy/unpickle), fail fast instead of recursing through
        # ``self.collection`` forever.
        if name == "collection":
            raise AttributeError(name)
        daft_df = self.collection.daft_df
        try:
            func = getattr(daft_df, name)
        except AttributeError:
            raise AttributeError(name) from None
        return lambda *args, **kwargs: self._wrap(func, args, kwargs)

    def _wrap(self, func, args, kwargs):
        """Call ``func`` and convert its result back into a collection."""
        if isinstance(func, MethodType):
            # Already bound to the daft DataFrame.
            daft_df = func(*args, **kwargs)
        else:
            # Unbound callable (e.g. a staticmethod): pass the frame explicitly.
            daft_df = func(self.collection.daft_df, *args, **kwargs)
        if daft_df is None:
            return self.collection
        return self.collection.from_daft_df(daft_df)