Source code for foundry_dev_tools.cached_foundry_client

"""Cached Foundry Client, high level client to work with Foundry."""

from __future__ import annotations

import logging
import os
import pickle
import tempfile
import time
from pathlib import Path
from typing import TYPE_CHECKING

from foundry_dev_tools.config.context import FoundryContext
from foundry_dev_tools.errors.dataset import (
    BranchNotFoundError,
    DatasetHasNoSchemaError,
    DatasetNotFoundError,
)
from foundry_dev_tools.foundry_api_client import FoundryRestClient
from foundry_dev_tools.utils.caches.spark_caches import DiskPersistenceBackedSparkCache
from foundry_dev_tools.utils.converter.foundry_spark import (
    infer_dataset_format_from_foundry_schema,
)
from foundry_dev_tools.utils.misc import is_dataset_a_view

LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
    import pandas as pd
    import pyspark.sql

    from foundry_dev_tools.utils import api_types


class CachedFoundryClient:
    """A Foundry Client that offers a high level API to Foundry.

    Methods to save and load datasets are implemented.
    """

    def __init__(self, config: dict | None = None, ctx: FoundryContext | FoundryRestClient | None = None):
        """Initialize `CachedFoundryClient`.

        A `config` dict can be passed to overwrite single values; `Configuration`
        is used to read the remaining settings from the ~/.foundry-dev-tools/config file.

        Args:
            config: config dict to overwrite values from the config file
            ctx: FoundryContext or FoundryRestClient to use; if supplied, the `config` parameter will be ignored

        """
        if ctx:
            if isinstance(ctx, FoundryContext):
                self.api = FoundryRestClient(ctx=ctx)
                self.cache = DiskPersistenceBackedSparkCache(ctx=ctx)
            else:
                self.api = ctx
                self.cache = DiskPersistenceBackedSparkCache(ctx=ctx.ctx)
        else:
            config = config or {}
            self.api = FoundryRestClient(config)
            self.cache = DiskPersistenceBackedSparkCache(ctx=self.api.ctx)

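    # Example (a minimal sketch, assuming credentials are already configured in
    # ~/.foundry-dev-tools/config or available through a FoundryContext):
    #
    #     client = CachedFoundryClient()                      # read settings from the config file
    #     client = CachedFoundryClient(ctx=FoundryContext())  # or reuse an existing context
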
    def load_dataset(self, dataset_path_or_rid: str, branch: str = "master") -> pyspark.sql.DataFrame:
        """Loads complete dataset from Foundry and stores in cache.

        Cache is invalidated once new transaction is present in Foundry.
        Last 2 transactions are kept in cache and older transactions are cleaned up.

        Args:
            dataset_path_or_rid (str): Path to dataset or the rid of the dataset
            branch (str): The branch of the dataset

        Returns:
            :external+spark:py:class:`~pyspark.sql.DataFrame`

        """
        _, dataset_identity = self.fetch_dataset(dataset_path_or_rid, branch)
        return self.cache[dataset_identity]

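    # Example (hypothetical dataset path; a minimal sketch, given a
    # CachedFoundryClient instance `client`):
    #
    #     df = client.load_dataset("/Global/example/my_dataset", branch="master")
    #     df.printSchema()
    #     df.show(10)
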
    def fetch_dataset(self, dataset_path_or_rid: str, branch: str = "master") -> tuple[str, api_types.DatasetIdentity]:
        """Downloads complete dataset from Foundry and stores in cache.

        Returns local path to dataset.

        Args:
            dataset_path_or_rid: Path to dataset or the rid of the dataset
            branch: The branch of the dataset

        Returns:
            `Tuple[str, dict]`: local path to the dataset, dataset_identity

        """
        dataset_identity = self._get_dataset_identity(dataset_path_or_rid, branch)
        return os.fspath(self._fetch_dataset(dataset_identity, branch=branch)), dataset_identity

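    # Example (hypothetical dataset path; a minimal sketch). fetch_dataset only
    # downloads the files and returns the local cache location plus the dataset identity:
    #
    #     local_path, identity = client.fetch_dataset("/Global/example/my_dataset")
    #     print(local_path)               # <cache_dir>/<dataset_rid>/<transaction_rid>.<format>
    #     print(identity["dataset_rid"])  # ri.foundry.main.dataset.<uuid>
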
    def _fetch_dataset(self, dataset_identity: api_types.DatasetIdentity, branch: str = "master") -> Path:
        last_transaction = dataset_identity["last_transaction"]
        if dataset_identity in list(self.cache.keys()):
            return self._return_local_path_of_cached_dataset(dataset_identity, branch)
        try:
            foundry_schema = self.api.get_dataset_schema(
                dataset_identity["dataset_rid"],
                dataset_identity["last_transaction"]["rid"],
                branch=branch,
            )
        except DatasetHasNoSchemaError:
            # Binary datasets or no schema
            foundry_schema = None
        if is_dataset_a_view(last_transaction["transaction"]):
            self.cache[dataset_identity] = self.api.query_foundry_sql(
                f'SELECT * FROM `{dataset_identity["dataset_rid"]}`',  # noqa: S608
                branch=branch,
                return_type="spark",
            )
            return self._return_local_path_of_cached_dataset(dataset_identity, branch)
        return self._download_dataset_and_return_local_path(dataset_identity, branch, foundry_schema)

    def _get_dataset_identity(
        self,
        dataset_path_or_rid: api_types.Rid,
        branch: api_types.DatasetBranch,
    ) -> api_types.DatasetIdentity:
        if self.api.ctx.config.transforms_freeze_cache is False:
            return self._get_dataset_identity_online(dataset_path_or_rid, branch)
        return self._get_dataset_identity_offline(dataset_path_or_rid)

    def _get_dataset_identity_online(
        self,
        dataset_path_or_rid: api_types.Rid,
        branch: api_types.DatasetBranch,
    ) -> api_types.DatasetIdentity:
        return self.api.get_dataset_identity(dataset_path_or_rid, branch)

    def _get_dataset_identity_offline(self, dataset_path_or_rid: api_types.Rid) -> api_types.DatasetIdentity:
        # Note: this is not branch aware, so it will return a dataset from the cache,
        # even though the branch might be different to that requested.
        return self.cache.get_dataset_identity_not_branch_aware(dataset_path_or_rid)

    def _return_local_path_of_cached_dataset(
        self,
        dataset_identity: api_types.DatasetIdentity,
        branch: api_types.DatasetBranch,
    ) -> Path:
        LOGGER.debug("Returning data for %s on branch %s from cache", dataset_identity, branch)
        return self.cache.get_path_to_local_dataset(dataset_identity)

    def _download_dataset_and_return_local_path(
        self,
        dataset_identity: api_types.DatasetIdentity,
        branch: api_types.DatasetBranch,
        foundry_schema: api_types.FoundrySchema | None,
    ) -> Path:
        LOGGER.debug("Caching data for %s on branch %s", dataset_identity, branch)
        self._download_dataset_to_cache_dir(dataset_identity, branch, foundry_schema)
        return self.cache.get_path_to_local_dataset(dataset_identity)

    def _download_dataset_to_cache_dir(
        self,
        dataset_identity: api_types.DatasetIdentity,
        branch: api_types.DatasetBranch,
        foundry_schema: api_types.FoundrySchema | None,
    ):
        list_of_files = self.api.list_dataset_files(
            dataset_identity["dataset_rid"],
            exclude_hidden_files=True,
            view=branch,
        )
        suffix = "." + infer_dataset_format_from_foundry_schema(foundry_schema, list_of_files)
        path = self.cache.get_cache_dir().joinpath(
            dataset_identity["dataset_rid"],
            dataset_identity["last_transaction"]["rid"] + suffix,
        )
        self.api.download_dataset_files(
            dataset_rid=dataset_identity["dataset_rid"],
            output_directory=path,
            files=list_of_files,
            view=branch,
        )
        self.cache.set_item_metadata(path, dataset_identity, foundry_schema)

    def save_dataset(
        self,
        df: pd.DataFrame | pyspark.sql.DataFrame,
        dataset_path_or_rid: str,
        branch: str = "master",
        exists_ok: bool = False,
        mode: str = "SNAPSHOT",
    ) -> tuple[str, str]:
        """Saves a dataframe to Foundry. If the dataset in Foundry does not exist it is created.

        If the branch does not exist, it is created. If the dataset exists, an exception is thrown.
        If exists_ok=True is passed, the dataset is overwritten.
        Creates SNAPSHOT transactions by default.

        Args:
            df (:external+pandas:py:class:`pandas.DataFrame` | :external+spark:py:class:`pyspark.sql.DataFrame`):
                A pyspark or pandas DataFrame to upload
            dataset_path_or_rid (str): Path or Rid of the dataset in which the object should be stored.
            branch (str): Branch of the dataset in which the object should be stored
            exists_ok (bool): By default, this method creates a new dataset.
                Pass exists_ok=True to overwrite according to strategy from parameter 'mode'
            mode (str): Foundry Transaction type: SNAPSHOT (only new files are present after transaction),
                UPDATE (replace files with same filename, keep present files),
                APPEND (add files that are not present yet)

        Returns:
            :py:class:`Tuple`: tuple with (dataset_rid, transaction_rid)

        Raises:
            ValueError: when dataframe is None
            ValueError: when branch is None

        """
        if df is None:
            msg = "Please provide a spark or pandas dataframe object with parameter 'df'"
            raise ValueError(msg)
        if branch is None:
            msg = "Please provide a dataset branch with parameter 'branch'"
            raise ValueError(msg)
        with tempfile.TemporaryDirectory() as path:
            from foundry_dev_tools._optional.pandas import pd

            if not pd.__fake__ and isinstance(df, pd.DataFrame):
                df.to_parquet(
                    os.sep.join([path + "/dataset.parquet"]),  # noqa: PTH118
                    engine="pyarrow",
                    compression="snappy",
                    flavor="spark",
                )
            else:
                df.write.format("parquet").option("compression", "snappy").save(path=path, mode="overwrite")
            filenames = list(filter(lambda file: not file.endswith(".crc"), os.listdir(path)))
            filepaths = [Path(path).joinpath(file) for file in filenames]
            # To be backwards compatible with most readers, which expect files
            # to be under spark/
            folder = round(time.time() * 1000) if mode == "APPEND" else "spark"
            dataset_paths_in_foundry = [f"{folder}/" + file for file in filenames]
            path_file_dict = dict(zip(dataset_paths_in_foundry, filepaths, strict=False))
            dataset_rid, transaction_id = self._save_objects(
                path_file_dict,
                dataset_path_or_rid,
                branch,
                exists_ok,
                mode,
            )
            foundry_schema = self.api.infer_dataset_schema(dataset_rid, branch)
            self.api.upload_dataset_schema(dataset_rid, transaction_id, foundry_schema, branch)
        return dataset_rid, transaction_id

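    # Example (hypothetical dataset path; a minimal sketch, assuming the caller
    # may create or overwrite the target dataset):
    #
    #     import pandas as pd
    #     df = pd.DataFrame({"a": [1, 2, 3]})
    #     dataset_rid, transaction_rid = client.save_dataset(
    #         df,
    #         "/Global/example/output_dataset",
    #         branch="master",
    #         exists_ok=True,  # overwrite existing data according to mode (SNAPSHOT by default)
    #     )
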
    def _save_objects(
        self,
        path_file_dict: dict,
        dataset_path_or_rid: str,
        branch: str,
        exists_ok: bool = False,
        mode: str = "SNAPSHOT",
    ) -> tuple[str, str]:
        if path_file_dict is None or len(path_file_dict) == 0:
            msg = "Please provide at least one file-like object in dict 'path_file_dict'"
            raise ValueError(msg)
        if branch is None:
            msg = "Please provide a dataset branch with parameter 'branch'"
            raise ValueError(msg)
        try:
            identity = self.api.get_dataset_identity(dataset_path_or_rid)
            dataset_rid = identity["dataset_rid"]
            dataset_path = identity["dataset_path"]
            # Check that dataset and branch exist and the dataset is not in the trash
            self.api.get_dataset(dataset_rid)
            self.api.get_branch(dataset_rid, branch)
            if self.api.is_dataset_in_trash(dataset_path):
                msg = f"Dataset '{dataset_path}' is in trash."
                raise ValueError(msg)
        except DatasetNotFoundError:
            dataset_rid = self.api.create_dataset(dataset_path_or_rid)["rid"]
            self.api.create_branch(dataset_rid, branch)
            exists_ok = True
        except BranchNotFoundError:
            self.api.create_branch(dataset_rid, branch)
            exists_ok = True
        if not exists_ok:
            msg = (
                f"Dataset '{dataset_path_or_rid}' already exists. If you are sure to overwrite"
                " / modify the existing dataset, call this method with parameter exists_ok=True"
            )
            raise ValueError(
                msg,
            )
        transaction_id = self.api.open_transaction(dataset_rid, mode, branch)
        try:
            self.api.upload_dataset_files(dataset_rid, transaction_id, path_file_dict)
        except Exception as e:
            self.api.abort_transaction(dataset_rid, transaction_id)
            msg = (
                "There was an issue while uploading the dataset files."
                f" Transaction {transaction_id} on dataset {dataset_rid} has been aborted."
            )
            raise ValueError(
                msg,
            ) from e
        else:
            self.api.commit_transaction(dataset_rid, transaction_id)
        return dataset_rid, transaction_id

    def save_model(
        self,
        model_obj: object,
        dataset_path_or_rid: str,
        branch: str = "master",
        exists_ok: bool = False,
        mode: str = "SNAPSHOT",
    ) -> tuple[str, str]:
        """Saves a python object to a foundry dataset.

        The python object is pickled and uploaded to path model.pickle.
        The uploaded model can be loaded for performing predictions inside foundry pipelines.

        Args:
            model_obj (object): Any python object that can be pickled
            dataset_path_or_rid (str): Path or Rid of the dataset in which the object should be stored.
            branch (str): Branch of the dataset in which the object should be stored
            exists_ok (bool): By default, this method creates a new dataset.
                Pass exists_ok=True to overwrite according to strategy from parameter 'mode'
            mode (str): Foundry Transaction type: SNAPSHOT (only new files are present after transaction),
                UPDATE (replace files with same filename, keep present files),
                APPEND (add files that are not present yet)

        Raises:
            ValueError: When model_obj or branch is None

        Returns:
            :py:class:`Tuple`: Tuple with (dataset_rid, transaction_rid)

        """
        if model_obj is None:
            msg = "Please provide a model object with parameter 'model_obj'"
            raise ValueError(msg)
        if branch is None:
            msg = "Please provide a dataset branch with parameter 'branch'"
            raise ValueError(msg)
        with tempfile.TemporaryDirectory() as path:
            model_path = Path(path).joinpath("model.pickle")
            with model_path.open(mode="wb") as file:
                pickle.dump(model_obj, file)
            return self._save_objects(
                {"model.pickle": os.fspath(model_path)},
                dataset_path_or_rid,
                branch,
                exists_ok,
                mode,
            )

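    # Example (hypothetical dataset path; a minimal sketch using a scikit-learn
    # estimator purely as an illustration, any picklable object works):
    #
    #     from sklearn.linear_model import LinearRegression
    #     model = LinearRegression().fit(X_train, y_train)  # X_train/y_train prepared elsewhere
    #     dataset_rid, transaction_rid = client.save_model(
    #         model,
    #         "/Global/example/model_dataset",
    #         exists_ok=True,
    #     )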