"""Classes and logic for the Configuration."""
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING
from foundry_dev_tools.config.config_types import Host
from foundry_dev_tools.config.token_provider import (
TOKEN_PROVIDER_MAPPING,
AppServiceTokenProvider,
TokenProvider,
)
from foundry_dev_tools.errors.config import (
FoundryConfigError,
MissingCredentialsConfigError,
MissingFoundryHostError,
TokenProviderConfigError,
)
from foundry_dev_tools.utils.config import (
cfg_files,
check_init,
get_environment_variable_config,
merge_dicts,
path_from_path_or_str,
user_cache,
)
if TYPE_CHECKING:
from collections.abc import Iterable
from os import PathLike
# compatibility for python version < 3.11
if sys.version_info < (3, 11):
import tomli as tomllib
else:
import tomllib
class Config:
    """Holds the general (non-credential) configuration options."""

    def __init__(
        self,
        requests_ca_bundle: PathLike[str] | None = None,
        transforms_sql_sample_row_limit: int = 5000,
        transforms_sql_dataset_size_threshold: int = 500,
        transforms_sql_sample_select_random: bool = False,
        transforms_force_full_dataset_download: bool = False,
        cache_dir: PathLike[str] | None = None,
        transforms_freeze_cache: bool = False,
        transforms_output_folder: PathLike[str] | None = None,
        rich_traceback: bool = False,
        debug: bool = False,
    ) -> None:
        """Store the given options, coercing each one to its expected type.

        Args:
            requests_ca_bundle: path to a CA bundle, for the case that :py:mod:`requests` needs custom
                certificates to work, e.g. in a corporate network; a falsy value is stored as None
            transforms_sql_sample_row_limit: number of sql rows to return when the dataset
                is above the `transforms_sql_dataset_size_threshold`
            transforms_sql_dataset_size_threshold: the complete dataset is only downloaded if it does not
                exceed this threshold, otherwise only `transforms_sql_sample_row_limit` rows are returned
            transforms_sql_sample_select_random: set to true if the sample rows should be random
                (query will take more time)
            transforms_force_full_dataset_download: if true, ignores the
                `transforms_sql_dataset_size_threshold` and downloads the full dataset
            cache_dir: path to the cache dir for downloaded datasets,
                default is :py:attr:`~foundry_dev_tools.utils.config.user_cache`
            transforms_freeze_cache: if this setting is enabled, transforms will work offline
                if the datasets are already in cache
            transforms_output_folder: when @transform is used in combination with
                TransformOutput.filesystem(), files are written to this folder
            rich_traceback: enables a prettier traceback provided by the module `rich`,
                see: https://rich.readthedocs.io/en/stable/traceback.html
            debug: enables debug logging
        """
        # NOTE: the assignment order below is the order in which `__repr__` lists the options
        if requests_ca_bundle:
            self.requests_ca_bundle = os.fspath(requests_ca_bundle)
        else:
            self.requests_ca_bundle = None

        if cache_dir is None:
            self.cache_dir = user_cache()
        else:
            self.cache_dir = path_from_path_or_str(cache_dir)

        if transforms_output_folder is None:
            self.transforms_output_folder = None
        else:
            self.transforms_output_folder = path_from_path_or_str(transforms_output_folder)

        # numeric options
        self.transforms_sql_sample_row_limit = int(transforms_sql_sample_row_limit)
        self.transforms_sql_dataset_size_threshold = int(transforms_sql_dataset_size_threshold)

        # boolean switches
        self.transforms_sql_sample_select_random = bool(transforms_sql_sample_select_random)
        self.transforms_force_full_dataset_download = bool(transforms_force_full_dataset_download)
        self.transforms_freeze_cache = bool(transforms_freeze_cache)
        self.rich_traceback = bool(rich_traceback)
        self.debug = bool(debug)

    def __repr__(self) -> str:
        """Return the class name together with all instance attributes."""
        return f"<{self.__class__.__name__}({self.__dict__!s})>"
def _load_config_file(config_file: Path) -> dict | None:
try:
with config_file.open("rb") as config_file_fd:
return tomllib.load(config_file_fd)
except OSError:
return None
def _load_config_files(config_files: Iterable[Path]) -> dict:
    """Load all existing files from `config_files` and merge them into one dictionary.

    The files are merged in iteration order; the last file wins.

    Args:
        config_files: candidate config file paths, paths that do not exist are skipped

    Returns:
        The merged configuration, empty if nothing could be loaded.

    Raises:
        FoundryConfigError: if nothing was loaded and `~/.foundry-dev-tools/config` exists and is not a
            directory, which indicates a v1 config that has to be migrated
    """
    merged = {}
    for candidate in config_files:
        if not candidate.exists():
            continue
        loaded = _load_config_file(candidate)
        # unreadable or empty files contribute nothing
        merged = merge_dicts(merged, loaded or {})

    if merged:
        return merged

    # nothing was loaded: check whether an old (v1) config file is still lying around
    legacy_config = Path("~/.foundry-dev-tools/config").expanduser()
    if legacy_config.exists() and not legacy_config.is_dir():
        msg = "Please use the `fdt config migrate` command to migrate the v1 config to the v2 config format."
        raise FoundryConfigError(msg)
    return merged
def _pure_config_dict(env: bool = True) -> dict:
    """Return the raw configuration, without any profile or token provider resolution.

    Args:
        env: if true, the config from the environment variables is merged on top of the config files

    Returns:
        The config loaded from the files returned by `cfg_files()`, optionally merged
        with the result of `get_environment_variable_config()`.
    """
    file_config = _load_config_files(cfg_files())
    if not env:
        return file_config
    return merge_dicts(file_config, get_environment_variable_config())
def _find_token_provider(credentials: dict) -> str | None:
    """Return the name of the token provider configured in `credentials`.

    Args:
        credentials: a credentials config section

    Returns:
        The last key of `credentials` that is also a key of `TOKEN_PROVIDER_MAPPING`,
        or None if there is no such key.
    """
    # iterate from the back, so that the token provider defined last in the config is picked
    known_providers = (key for key in reversed(credentials) if key in TOKEN_PROVIDER_MAPPING)
    return next(known_providers, None)
# Shared error instance, raised by `get_config_dict` and `parse_credentials_config`
# when no token provider could be determined from the configuration.
MISSING_TP_ERROR = TokenProviderConfigError(
"To authenticate with Foundry you need a TokenProvider. The token provider can be configured either via the"
" configuration file or the token_provider FoundryContext parameter."
)
def get_config_dict(profile: str | None = None, env: bool = True) -> dict | None:
    """Loads config from the config files and environment variables.

    Profiles make configs like this possible:

    .. code-block:: toml

        [config]
        example_option = 1
        [integration.config]
        example_option = 2

    Where the options of 'integration.config' are merged over the options of 'config'
    if the profile is set to integration.

    Args:
        profile: The profile to use, if None the `profile` key of the loaded config is used (if present)
        env: Whether to load the environment variables

    Returns:
        None if no configuration was found at all. Otherwise a dictionary with a `credentials` entry
        (containing `domain`, the config of the selected token provider and `scheme` if it was set)
        and, if any general options are configured, a `config` entry.

    Raises:
        AttributeError: if the profile is named `config` or `credentials`
        MissingFoundryHostError: if neither the profile nor the default credentials contain a domain
        TokenProviderConfigError: if no token provider is configured and the
            `APP_SERVICE_TS` environment variable is not set
    """
    config = _pure_config_dict(env=env)
    if not config:
        return None

    if profile is None:
        profile = config.get("profile")
    # these names are the top-level sections of the default profile
    if profile in ("config", "credentials"):
        msg = f"Profile name can't be {profile}"
        raise AttributeError(msg)

    # start with an empty profile, so the code below does not need to distinguish
    # between "profile used" and "no profile used"
    profile_credentials = {}
    profile_config = {}
    if profile and isinstance(profile_section := config.get(profile), dict):
        profile_credentials = profile_section.get("credentials", {})
        profile_config = profile_section.get("config", {})

    default_credentials = config.get("credentials", {})

    # domain and scheme are removed from both credential sections, the profile value takes precedence
    default_domain = default_credentials.pop("domain", None)
    domain = profile_credentials.pop("domain", default_domain)
    # a domain must always be provided
    if not domain:
        raise MissingFoundryHostError
    default_scheme = default_credentials.pop("scheme", None)
    scheme = profile_credentials.pop("scheme", default_scheme)

    # decide which will be the token provider used, the profile takes precedence
    token_provider_name = _find_token_provider(profile_credentials) or _find_token_provider(default_credentials)
    if not token_provider_name:
        if "APP_SERVICE_TS" not in os.environ:
            raise MISSING_TP_ERROR
        token_provider_name = "app_service"  # noqa: S105

    return_config = {}

    # merge the profile config with the non-prefixed config;
    # `profile_config` already is the 'config' table of the profile, it must not be looked up a second time
    merged_config = merge_dicts(config.get("config", {}), profile_config)
    if merged_config:
        return_config["config"] = merged_config

    return_config["credentials"] = {
        "domain": domain,
        token_provider_name: merge_dicts(
            default_credentials.get(token_provider_name),
            profile_credentials.get(token_provider_name),
        ),
    }
    if scheme is not None:
        return_config["credentials"]["scheme"] = scheme
    return return_config
def parse_credentials_config(config_dict: dict | None) -> TokenProvider:
    """Parses the credentials config dictionary and returns a TokenProvider object.

    The passed dictionary is not modified.

    Args:
        config_dict: the config dictionary, its `credentials` entry needs to contain a `domain`,
            optionally a `scheme` and the config of the token provider

    Returns:
        An instance of the class that `TOKEN_PROVIDER_MAPPING` maps to the configured token provider name,
        or an `AppServiceTokenProvider` if no token provider is configured
        and the `APP_SERVICE_TS` environment variable is set.

    Raises:
        MissingCredentialsConfigError: if `config_dict` is None or has no (non-empty) `credentials` entry
        MissingFoundryHostError: if the credentials do not contain a `domain`
        TokenProviderConfigError: if the configured token provider does not exist,
            or if none is configured and `APP_SERVICE_TS` is not set
    """
    # check if there is a credentials config present
    credentials_config = config_dict.get("credentials") if config_dict is not None else None
    if not credentials_config:
        raise MissingCredentialsConfigError

    # a domain must always be provided
    if "domain" not in credentials_config:
        raise MissingFoundryHostError

    # work on a shallow copy, otherwise the pops below would strip the caller's dictionary
    # and parsing the same config a second time would fail
    credentials_config = dict(credentials_config)

    # create a host object with the domain and the optional scheme setting
    host = Host(credentials_config.pop("domain"), credentials_config.pop("scheme", None))

    # whatever is left (if anything) is the token provider config, the last entry is used
    tp_name, tp_config = credentials_config.popitem() if credentials_config else (None, None)

    if tp_name:
        if tp_config is None or (hasattr(tp_config, "__len__") and len(tp_config) == 0):
            # no settings given for the token provider
            tp_config = {}
        elif not isinstance(tp_config, dict):
            # make it possible to do jwt = "eyJ" instead of jwt = {jwt="eyJ"}
            # (this also covers values without a length, e.g. numbers)
            tp_config = {tp_name: tp_config}

        if mapped_class := TOKEN_PROVIDER_MAPPING.get(tp_name):
            # check the config kwargs and pass the valid kwargs to the mapped class
            return mapped_class(**check_init(mapped_class, "credentials", {"host": host, **tp_config}))

        # the token_provider name was set but is not present in the mapping
        msg = f"The token provider implementation {tp_name} does not exist."
        raise TokenProviderConfigError(msg)

    # use flask/dash/streamlit provider when used in the app service
    if "APP_SERVICE_TS" in os.environ:
        return AppServiceTokenProvider(host=host)
    raise MISSING_TP_ERROR
def parse_general_config(config_dict: dict | None = None) -> Config:
    """Parses the config dictionary and returns a Config object.

    Args:
        config_dict: the config dictionary, only its `config` entry is used

    Returns:
        A `Config` built from the `config` entry after it was passed through `check_init`,
        or a `Config` with default values if `config_dict` is None or has no (non-empty) `config` entry.
    """
    general_config = config_dict.get("config") if config_dict is not None else None
    if not general_config:
        return Config()
    return Config(**check_init(Config, "config", general_config))