Changed code to support older Python versions
This commit is contained in:
parent
eb92d2d36f
commit
582458cdd0
5027 changed files with 794942 additions and 4 deletions
|
|
@ -0,0 +1,65 @@
|
|||
from .exceptions import SettingsError
|
||||
from .main import BaseSettings, CliApp, SettingsConfigDict
|
||||
from .sources import (
|
||||
CLI_SUPPRESS,
|
||||
AWSSecretsManagerSettingsSource,
|
||||
AzureKeyVaultSettingsSource,
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
CliUnknownArgs,
|
||||
DotEnvSettingsSource,
|
||||
EnvSettingsSource,
|
||||
ForceDecode,
|
||||
GoogleSecretManagerSettingsSource,
|
||||
InitSettingsSource,
|
||||
JsonConfigSettingsSource,
|
||||
NestedSecretsSettingsSource,
|
||||
NoDecode,
|
||||
PydanticBaseSettingsSource,
|
||||
PyprojectTomlConfigSettingsSource,
|
||||
SecretsSettingsSource,
|
||||
TomlConfigSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
get_subcommand,
|
||||
)
|
||||
from .version import VERSION
|
||||
|
||||
__all__ = (
|
||||
'CLI_SUPPRESS',
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'BaseSettings',
|
||||
'CliApp',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'CliUnknownArgs',
|
||||
'DotEnvSettingsSource',
|
||||
'EnvSettingsSource',
|
||||
'ForceDecode',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'InitSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'NestedSecretsSettingsSource',
|
||||
'NoDecode',
|
||||
'PydanticBaseSettingsSource',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'SettingsConfigDict',
|
||||
'SettingsError',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
'__version__',
|
||||
'get_subcommand',
|
||||
)
|
||||
|
||||
__version__ = VERSION
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,4 @@
|
|||
class SettingsError(ValueError):
|
||||
"""Base exception for settings-related errors."""
|
||||
|
||||
pass
|
||||
717
venv/lib/python3.11/site-packages/pydantic_settings/main.py
Normal file
717
venv/lib/python3.11/site-packages/pydantic_settings/main.py
Normal file
|
|
@ -0,0 +1,717 @@
|
|||
from __future__ import annotations as _annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import threading
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from collections.abc import Mapping
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, ClassVar, Literal, TypeVar
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic._internal._config import config_keys
|
||||
from pydantic._internal._signature import _field_name_for_signature
|
||||
from pydantic._internal._utils import deep_update, is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from .exceptions import SettingsError
|
||||
from .sources import (
|
||||
ENV_FILE_SENTINEL,
|
||||
CliSettingsSource,
|
||||
DefaultSettingsSource,
|
||||
DotEnvSettingsSource,
|
||||
DotenvType,
|
||||
EnvSettingsSource,
|
||||
InitSettingsSource,
|
||||
JsonConfigSettingsSource,
|
||||
PathType,
|
||||
PydanticBaseSettingsSource,
|
||||
PydanticModel,
|
||||
PyprojectTomlConfigSettingsSource,
|
||||
SecretsSettingsSource,
|
||||
TomlConfigSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
get_subcommand,
|
||||
)
|
||||
from .sources.utils import _get_alias_names
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class SettingsConfigDict(ConfigDict, total=False):
|
||||
case_sensitive: bool
|
||||
nested_model_default_partial_update: bool | None
|
||||
env_prefix: str
|
||||
env_file: DotenvType | None
|
||||
env_file_encoding: str | None
|
||||
env_ignore_empty: bool
|
||||
env_nested_delimiter: str | None
|
||||
env_nested_max_split: int | None
|
||||
env_parse_none_str: str | None
|
||||
env_parse_enums: bool | None
|
||||
cli_prog_name: str | None
|
||||
cli_parse_args: bool | list[str] | tuple[str, ...] | None
|
||||
cli_parse_none_str: str | None
|
||||
cli_hide_none_type: bool
|
||||
cli_avoid_json: bool
|
||||
cli_enforce_required: bool
|
||||
cli_use_class_docs_for_groups: bool
|
||||
cli_exit_on_error: bool
|
||||
cli_prefix: str
|
||||
cli_flag_prefix_char: str
|
||||
cli_implicit_flags: bool | None
|
||||
cli_ignore_unknown_args: bool | None
|
||||
cli_kebab_case: bool | Literal['all', 'no_enums'] | None
|
||||
cli_shortcuts: Mapping[str, str | list[str]] | None
|
||||
secrets_dir: PathType | None
|
||||
json_file: PathType | None
|
||||
json_file_encoding: str | None
|
||||
yaml_file: PathType | None
|
||||
yaml_file_encoding: str | None
|
||||
yaml_config_section: str | None
|
||||
"""
|
||||
Specifies the top-level key in a YAML file from which to load the settings.
|
||||
If provided, the settings will be loaded from the nested section under this key.
|
||||
This is useful when the YAML file contains multiple configuration sections
|
||||
and you only want to load a specific subset into your settings model.
|
||||
"""
|
||||
|
||||
pyproject_toml_depth: int
|
||||
"""
|
||||
Number of levels **up** from the current working directory to attempt to find a pyproject.toml
|
||||
file.
|
||||
|
||||
This is only used when a pyproject.toml file is not found in the current working directory.
|
||||
"""
|
||||
|
||||
pyproject_toml_table_header: tuple[str, ...]
|
||||
"""
|
||||
Header of the TOML table within a pyproject.toml file to use when filling variables.
|
||||
This is supplied as a `tuple[str, ...]` instead of a `str` to accommodate for headers
|
||||
containing a `.`.
|
||||
|
||||
For example, `toml_table_header = ("tool", "my.tool", "foo")` can be used to fill variable
|
||||
values from a table with header `[tool."my.tool".foo]`.
|
||||
|
||||
To use the root table, exclude this config setting or provide an empty tuple.
|
||||
"""
|
||||
|
||||
toml_file: PathType | None
|
||||
enable_decoding: bool
|
||||
|
||||
|
||||
# Extend `config_keys` by pydantic settings config keys to
|
||||
# support setting config through class kwargs.
|
||||
# Pydantic uses `config_keys` in `pydantic._internal._config.ConfigWrapper.for_model`
|
||||
# to extract config keys from model kwargs, So, by adding pydantic settings keys to
|
||||
# `config_keys`, they will be considered as valid config keys and will be collected
|
||||
# by Pydantic.
|
||||
config_keys |= set(SettingsConfigDict.__annotations__.keys())
|
||||
|
||||
|
||||
class BaseSettings(BaseModel):
|
||||
"""
|
||||
Base class for settings, allowing values to be overridden by environment variables.
|
||||
|
||||
This is useful in production for secrets you do not wish to save in code, it plays nicely with docker(-compose),
|
||||
Heroku and any 12 factor app design.
|
||||
|
||||
All the below attributes can be set via `model_config`.
|
||||
|
||||
Args:
|
||||
_case_sensitive: Whether environment and CLI variable names should be read with case-sensitivity.
|
||||
Defaults to `None`.
|
||||
_nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
|
||||
Defaults to `False`.
|
||||
_env_prefix: Prefix for all environment variables. Defaults to `None`.
|
||||
_env_file: The env file(s) to load settings values from. Defaults to `Path('')`, which
|
||||
means that the value from `model_config['env_file']` should be used. You can also pass
|
||||
`None` to indicate that environment variables should not be loaded from an env file.
|
||||
_env_file_encoding: The env file encoding, e.g. `'latin-1'`. Defaults to `None`.
|
||||
_env_ignore_empty: Ignore environment variables where the value is an empty string. Default to `False`.
|
||||
_env_nested_delimiter: The nested env values delimiter. Defaults to `None`.
|
||||
_env_nested_max_split: The nested env values maximum nesting. Defaults to `None`, which means no limit.
|
||||
_env_parse_none_str: The env string value that should be parsed (e.g. "null", "void", "None", etc.)
|
||||
into `None` type(None). Defaults to `None` type(None), which means no parsing should occur.
|
||||
_env_parse_enums: Parse enum field names to values. Defaults to `None.`, which means no parsing should occur.
|
||||
_cli_prog_name: The CLI program name to display in help text. Defaults to `None` if _cli_parse_args is `None`.
|
||||
Otherwise, defaults to sys.argv[0].
|
||||
_cli_parse_args: The list of CLI arguments to parse. Defaults to None.
|
||||
If set to `True`, defaults to sys.argv[1:].
|
||||
_cli_settings_source: Override the default CLI settings source with a user defined instance. Defaults to None.
|
||||
_cli_parse_none_str: The CLI string value that should be parsed (e.g. "null", "void", "None", etc.) into
|
||||
`None` type(None). Defaults to _env_parse_none_str value if set. Otherwise, defaults to "null" if
|
||||
_cli_avoid_json is `False`, and "None" if _cli_avoid_json is `True`.
|
||||
_cli_hide_none_type: Hide `None` values in CLI help text. Defaults to `False`.
|
||||
_cli_avoid_json: Avoid complex JSON objects in CLI help text. Defaults to `False`.
|
||||
_cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`.
|
||||
_cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions.
|
||||
Defaults to `False`.
|
||||
_cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
|
||||
Defaults to `True`.
|
||||
_cli_prefix: The root parser command line arguments prefix. Defaults to "".
|
||||
_cli_flag_prefix_char: The flag prefix character to use for CLI optional arguments. Defaults to '-'.
|
||||
_cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags.
|
||||
(e.g. --flag, --no-flag). Defaults to `False`.
|
||||
_cli_ignore_unknown_args: Whether to ignore unknown CLI args and parse only known ones. Defaults to `False`.
|
||||
_cli_kebab_case: CLI args use kebab case. Defaults to `False`.
|
||||
_cli_shortcuts: Mapping of target field name to alias names. Defaults to `None`.
|
||||
_secrets_dir: The secret files directory or a sequence of directories. Defaults to `None`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
__pydantic_self__,
|
||||
_case_sensitive: bool | None = None,
|
||||
_nested_model_default_partial_update: bool | None = None,
|
||||
_env_prefix: str | None = None,
|
||||
_env_file: DotenvType | None = ENV_FILE_SENTINEL,
|
||||
_env_file_encoding: str | None = None,
|
||||
_env_ignore_empty: bool | None = None,
|
||||
_env_nested_delimiter: str | None = None,
|
||||
_env_nested_max_split: int | None = None,
|
||||
_env_parse_none_str: str | None = None,
|
||||
_env_parse_enums: bool | None = None,
|
||||
_cli_prog_name: str | None = None,
|
||||
_cli_parse_args: bool | list[str] | tuple[str, ...] | None = None,
|
||||
_cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
_cli_parse_none_str: str | None = None,
|
||||
_cli_hide_none_type: bool | None = None,
|
||||
_cli_avoid_json: bool | None = None,
|
||||
_cli_enforce_required: bool | None = None,
|
||||
_cli_use_class_docs_for_groups: bool | None = None,
|
||||
_cli_exit_on_error: bool | None = None,
|
||||
_cli_prefix: str | None = None,
|
||||
_cli_flag_prefix_char: str | None = None,
|
||||
_cli_implicit_flags: bool | None = None,
|
||||
_cli_ignore_unknown_args: bool | None = None,
|
||||
_cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
|
||||
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
|
||||
_secrets_dir: PathType | None = None,
|
||||
**values: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
**__pydantic_self__._settings_build_values(
|
||||
values,
|
||||
_case_sensitive=_case_sensitive,
|
||||
_nested_model_default_partial_update=_nested_model_default_partial_update,
|
||||
_env_prefix=_env_prefix,
|
||||
_env_file=_env_file,
|
||||
_env_file_encoding=_env_file_encoding,
|
||||
_env_ignore_empty=_env_ignore_empty,
|
||||
_env_nested_delimiter=_env_nested_delimiter,
|
||||
_env_nested_max_split=_env_nested_max_split,
|
||||
_env_parse_none_str=_env_parse_none_str,
|
||||
_env_parse_enums=_env_parse_enums,
|
||||
_cli_prog_name=_cli_prog_name,
|
||||
_cli_parse_args=_cli_parse_args,
|
||||
_cli_settings_source=_cli_settings_source,
|
||||
_cli_parse_none_str=_cli_parse_none_str,
|
||||
_cli_hide_none_type=_cli_hide_none_type,
|
||||
_cli_avoid_json=_cli_avoid_json,
|
||||
_cli_enforce_required=_cli_enforce_required,
|
||||
_cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups,
|
||||
_cli_exit_on_error=_cli_exit_on_error,
|
||||
_cli_prefix=_cli_prefix,
|
||||
_cli_flag_prefix_char=_cli_flag_prefix_char,
|
||||
_cli_implicit_flags=_cli_implicit_flags,
|
||||
_cli_ignore_unknown_args=_cli_ignore_unknown_args,
|
||||
_cli_kebab_case=_cli_kebab_case,
|
||||
_cli_shortcuts=_cli_shortcuts,
|
||||
_secrets_dir=_secrets_dir,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
"""
|
||||
Define the sources and their order for loading the settings values.
|
||||
|
||||
Args:
|
||||
settings_cls: The Settings class.
|
||||
init_settings: The `InitSettingsSource` instance.
|
||||
env_settings: The `EnvSettingsSource` instance.
|
||||
dotenv_settings: The `DotEnvSettingsSource` instance.
|
||||
file_secret_settings: The `SecretsSettingsSource` instance.
|
||||
|
||||
Returns:
|
||||
A tuple containing the sources and their order for loading the settings values.
|
||||
"""
|
||||
return init_settings, env_settings, dotenv_settings, file_secret_settings
|
||||
|
||||
def _settings_build_values(
|
||||
self,
|
||||
init_kwargs: dict[str, Any],
|
||||
_case_sensitive: bool | None = None,
|
||||
_nested_model_default_partial_update: bool | None = None,
|
||||
_env_prefix: str | None = None,
|
||||
_env_file: DotenvType | None = None,
|
||||
_env_file_encoding: str | None = None,
|
||||
_env_ignore_empty: bool | None = None,
|
||||
_env_nested_delimiter: str | None = None,
|
||||
_env_nested_max_split: int | None = None,
|
||||
_env_parse_none_str: str | None = None,
|
||||
_env_parse_enums: bool | None = None,
|
||||
_cli_prog_name: str | None = None,
|
||||
_cli_parse_args: bool | list[str] | tuple[str, ...] | None = None,
|
||||
_cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
_cli_parse_none_str: str | None = None,
|
||||
_cli_hide_none_type: bool | None = None,
|
||||
_cli_avoid_json: bool | None = None,
|
||||
_cli_enforce_required: bool | None = None,
|
||||
_cli_use_class_docs_for_groups: bool | None = None,
|
||||
_cli_exit_on_error: bool | None = None,
|
||||
_cli_prefix: str | None = None,
|
||||
_cli_flag_prefix_char: str | None = None,
|
||||
_cli_implicit_flags: bool | None = None,
|
||||
_cli_ignore_unknown_args: bool | None = None,
|
||||
_cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
|
||||
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
|
||||
_secrets_dir: PathType | None = None,
|
||||
) -> dict[str, Any]:
|
||||
# Determine settings config values
|
||||
case_sensitive = _case_sensitive if _case_sensitive is not None else self.model_config.get('case_sensitive')
|
||||
env_prefix = _env_prefix if _env_prefix is not None else self.model_config.get('env_prefix')
|
||||
nested_model_default_partial_update = (
|
||||
_nested_model_default_partial_update
|
||||
if _nested_model_default_partial_update is not None
|
||||
else self.model_config.get('nested_model_default_partial_update')
|
||||
)
|
||||
env_file = _env_file if _env_file != ENV_FILE_SENTINEL else self.model_config.get('env_file')
|
||||
env_file_encoding = (
|
||||
_env_file_encoding if _env_file_encoding is not None else self.model_config.get('env_file_encoding')
|
||||
)
|
||||
env_ignore_empty = (
|
||||
_env_ignore_empty if _env_ignore_empty is not None else self.model_config.get('env_ignore_empty')
|
||||
)
|
||||
env_nested_delimiter = (
|
||||
_env_nested_delimiter
|
||||
if _env_nested_delimiter is not None
|
||||
else self.model_config.get('env_nested_delimiter')
|
||||
)
|
||||
env_nested_max_split = (
|
||||
_env_nested_max_split
|
||||
if _env_nested_max_split is not None
|
||||
else self.model_config.get('env_nested_max_split')
|
||||
)
|
||||
env_parse_none_str = (
|
||||
_env_parse_none_str if _env_parse_none_str is not None else self.model_config.get('env_parse_none_str')
|
||||
)
|
||||
env_parse_enums = _env_parse_enums if _env_parse_enums is not None else self.model_config.get('env_parse_enums')
|
||||
|
||||
cli_prog_name = _cli_prog_name if _cli_prog_name is not None else self.model_config.get('cli_prog_name')
|
||||
cli_parse_args = _cli_parse_args if _cli_parse_args is not None else self.model_config.get('cli_parse_args')
|
||||
cli_settings_source = (
|
||||
_cli_settings_source if _cli_settings_source is not None else self.model_config.get('cli_settings_source')
|
||||
)
|
||||
cli_parse_none_str = (
|
||||
_cli_parse_none_str if _cli_parse_none_str is not None else self.model_config.get('cli_parse_none_str')
|
||||
)
|
||||
cli_parse_none_str = cli_parse_none_str if not env_parse_none_str else env_parse_none_str
|
||||
cli_hide_none_type = (
|
||||
_cli_hide_none_type if _cli_hide_none_type is not None else self.model_config.get('cli_hide_none_type')
|
||||
)
|
||||
cli_avoid_json = _cli_avoid_json if _cli_avoid_json is not None else self.model_config.get('cli_avoid_json')
|
||||
cli_enforce_required = (
|
||||
_cli_enforce_required
|
||||
if _cli_enforce_required is not None
|
||||
else self.model_config.get('cli_enforce_required')
|
||||
)
|
||||
cli_use_class_docs_for_groups = (
|
||||
_cli_use_class_docs_for_groups
|
||||
if _cli_use_class_docs_for_groups is not None
|
||||
else self.model_config.get('cli_use_class_docs_for_groups')
|
||||
)
|
||||
cli_exit_on_error = (
|
||||
_cli_exit_on_error if _cli_exit_on_error is not None else self.model_config.get('cli_exit_on_error')
|
||||
)
|
||||
cli_prefix = _cli_prefix if _cli_prefix is not None else self.model_config.get('cli_prefix')
|
||||
cli_flag_prefix_char = (
|
||||
_cli_flag_prefix_char
|
||||
if _cli_flag_prefix_char is not None
|
||||
else self.model_config.get('cli_flag_prefix_char')
|
||||
)
|
||||
cli_implicit_flags = (
|
||||
_cli_implicit_flags if _cli_implicit_flags is not None else self.model_config.get('cli_implicit_flags')
|
||||
)
|
||||
cli_ignore_unknown_args = (
|
||||
_cli_ignore_unknown_args
|
||||
if _cli_ignore_unknown_args is not None
|
||||
else self.model_config.get('cli_ignore_unknown_args')
|
||||
)
|
||||
cli_kebab_case = _cli_kebab_case if _cli_kebab_case is not None else self.model_config.get('cli_kebab_case')
|
||||
cli_shortcuts = _cli_shortcuts if _cli_shortcuts is not None else self.model_config.get('cli_shortcuts')
|
||||
|
||||
secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir')
|
||||
|
||||
# Configure built-in sources
|
||||
default_settings = DefaultSettingsSource(
|
||||
self.__class__, nested_model_default_partial_update=nested_model_default_partial_update
|
||||
)
|
||||
init_settings = InitSettingsSource(
|
||||
self.__class__,
|
||||
init_kwargs=init_kwargs,
|
||||
nested_model_default_partial_update=nested_model_default_partial_update,
|
||||
)
|
||||
env_settings = EnvSettingsSource(
|
||||
self.__class__,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
env_nested_max_split=env_nested_max_split,
|
||||
env_ignore_empty=env_ignore_empty,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
dotenv_settings = DotEnvSettingsSource(
|
||||
self.__class__,
|
||||
env_file=env_file,
|
||||
env_file_encoding=env_file_encoding,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
env_nested_max_split=env_nested_max_split,
|
||||
env_ignore_empty=env_ignore_empty,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
file_secret_settings = SecretsSettingsSource(
|
||||
self.__class__, secrets_dir=secrets_dir, case_sensitive=case_sensitive, env_prefix=env_prefix
|
||||
)
|
||||
# Provide a hook to set built-in sources priority and add / remove sources
|
||||
sources = self.settings_customise_sources(
|
||||
self.__class__,
|
||||
init_settings=init_settings,
|
||||
env_settings=env_settings,
|
||||
dotenv_settings=dotenv_settings,
|
||||
file_secret_settings=file_secret_settings,
|
||||
) + (default_settings,)
|
||||
custom_cli_sources = [source for source in sources if isinstance(source, CliSettingsSource)]
|
||||
if not any(custom_cli_sources):
|
||||
if isinstance(cli_settings_source, CliSettingsSource):
|
||||
sources = (cli_settings_source,) + sources
|
||||
elif cli_parse_args is not None:
|
||||
cli_settings = CliSettingsSource[Any](
|
||||
self.__class__,
|
||||
cli_prog_name=cli_prog_name,
|
||||
cli_parse_args=cli_parse_args,
|
||||
cli_parse_none_str=cli_parse_none_str,
|
||||
cli_hide_none_type=cli_hide_none_type,
|
||||
cli_avoid_json=cli_avoid_json,
|
||||
cli_enforce_required=cli_enforce_required,
|
||||
cli_use_class_docs_for_groups=cli_use_class_docs_for_groups,
|
||||
cli_exit_on_error=cli_exit_on_error,
|
||||
cli_prefix=cli_prefix,
|
||||
cli_flag_prefix_char=cli_flag_prefix_char,
|
||||
cli_implicit_flags=cli_implicit_flags,
|
||||
cli_ignore_unknown_args=cli_ignore_unknown_args,
|
||||
cli_kebab_case=cli_kebab_case,
|
||||
cli_shortcuts=cli_shortcuts,
|
||||
case_sensitive=case_sensitive,
|
||||
)
|
||||
sources = (cli_settings,) + sources
|
||||
# We ensure that if command line arguments haven't been parsed yet, we do so.
|
||||
elif cli_parse_args not in (None, False) and not custom_cli_sources[0].env_vars:
|
||||
custom_cli_sources[0](args=cli_parse_args) # type: ignore
|
||||
|
||||
self._settings_warn_unused_config_keys(sources, self.model_config)
|
||||
|
||||
if sources:
|
||||
state: dict[str, Any] = {}
|
||||
defaults: dict[str, Any] = {}
|
||||
states: dict[str, dict[str, Any]] = {}
|
||||
for source in sources:
|
||||
if isinstance(source, PydanticBaseSettingsSource):
|
||||
source._set_current_state(state)
|
||||
source._set_settings_sources_data(states)
|
||||
|
||||
source_name = source.__name__ if hasattr(source, '__name__') else type(source).__name__
|
||||
source_state = source()
|
||||
|
||||
if isinstance(source, DefaultSettingsSource):
|
||||
defaults = source_state
|
||||
|
||||
states[source_name] = source_state
|
||||
state = deep_update(source_state, state)
|
||||
|
||||
# Strip any default values not explicity set before returning final state
|
||||
state = {key: val for key, val in state.items() if key not in defaults or defaults[key] != val}
|
||||
self._settings_restore_init_kwarg_names(self.__class__, init_kwargs, state)
|
||||
|
||||
return state
|
||||
else:
|
||||
# no one should mean to do this, but I think returning an empty dict is marginally preferable
|
||||
# to an informative error and much better than a confusing error
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _settings_restore_init_kwarg_names(
|
||||
settings_cls: type[BaseSettings], init_kwargs: dict[str, Any], state: dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Restore the init_kwarg key names to the final merged state dictionary.
|
||||
"""
|
||||
if init_kwargs and state:
|
||||
state_kwarg_names = set(state.keys())
|
||||
init_kwarg_names = set(init_kwargs.keys())
|
||||
for field_name, field_info in settings_cls.model_fields.items():
|
||||
alias_names, *_ = _get_alias_names(field_name, field_info)
|
||||
matchable_names = set(alias_names)
|
||||
include_name = settings_cls.model_config.get('populate_by_name', False)
|
||||
if include_name:
|
||||
matchable_names.add(field_name)
|
||||
init_kwarg_name = init_kwarg_names & matchable_names
|
||||
state_kwarg_name = state_kwarg_names & matchable_names
|
||||
if init_kwarg_name and state_kwarg_name:
|
||||
state[init_kwarg_name.pop()] = state.pop(state_kwarg_name.pop())
|
||||
|
||||
@staticmethod
|
||||
def _settings_warn_unused_config_keys(sources: tuple[object, ...], model_config: SettingsConfigDict) -> None:
|
||||
"""
|
||||
Warns if any values in model_config were set but the corresponding settings source has not been initialised.
|
||||
|
||||
The list alternative sources and their config keys can be found here:
|
||||
https://docs.pydantic.dev/latest/concepts/pydantic_settings/#other-settings-source
|
||||
|
||||
Args:
|
||||
sources: The tuple of configured sources
|
||||
model_config: The model config to check for unused config keys
|
||||
"""
|
||||
|
||||
def warn_if_not_used(source_type: type[PydanticBaseSettingsSource], keys: tuple[str, ...]) -> None:
|
||||
if not any(isinstance(source, source_type) for source in sources):
|
||||
for key in keys:
|
||||
if model_config.get(key) is not None:
|
||||
warnings.warn(
|
||||
f'Config key `{key}` is set in model_config but will be ignored because no '
|
||||
f'{source_type.__name__} source is configured. To use this config key, add a '
|
||||
f'{source_type.__name__} source to the settings sources via the '
|
||||
'settings_customise_sources hook.',
|
||||
UserWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
warn_if_not_used(JsonConfigSettingsSource, ('json_file', 'json_file_encoding'))
|
||||
warn_if_not_used(PyprojectTomlConfigSettingsSource, ('pyproject_toml_depth', 'pyproject_toml_table_header'))
|
||||
warn_if_not_used(TomlConfigSettingsSource, ('toml_file',))
|
||||
warn_if_not_used(YamlConfigSettingsSource, ('yaml_file', 'yaml_file_encoding', 'yaml_config_section'))
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
extra='forbid',
|
||||
arbitrary_types_allowed=True,
|
||||
validate_default=True,
|
||||
case_sensitive=False,
|
||||
env_prefix='',
|
||||
nested_model_default_partial_update=False,
|
||||
env_file=None,
|
||||
env_file_encoding=None,
|
||||
env_ignore_empty=False,
|
||||
env_nested_delimiter=None,
|
||||
env_nested_max_split=None,
|
||||
env_parse_none_str=None,
|
||||
env_parse_enums=None,
|
||||
cli_prog_name=None,
|
||||
cli_parse_args=None,
|
||||
cli_parse_none_str=None,
|
||||
cli_hide_none_type=False,
|
||||
cli_avoid_json=False,
|
||||
cli_enforce_required=False,
|
||||
cli_use_class_docs_for_groups=False,
|
||||
cli_exit_on_error=True,
|
||||
cli_prefix='',
|
||||
cli_flag_prefix_char='-',
|
||||
cli_implicit_flags=False,
|
||||
cli_ignore_unknown_args=False,
|
||||
cli_kebab_case=False,
|
||||
cli_shortcuts=None,
|
||||
json_file=None,
|
||||
json_file_encoding=None,
|
||||
yaml_file=None,
|
||||
yaml_file_encoding=None,
|
||||
yaml_config_section=None,
|
||||
toml_file=None,
|
||||
secrets_dir=None,
|
||||
protected_namespaces=('model_validate', 'model_dump', 'settings_customise_sources'),
|
||||
enable_decoding=True,
|
||||
)
|
||||
|
||||
|
||||
class CliApp:
|
||||
"""
|
||||
A utility class for running Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as
|
||||
CLI applications.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_base_settings_cls(model_cls: type[Any]) -> type[BaseSettings]:
|
||||
if issubclass(model_cls, BaseSettings):
|
||||
return model_cls
|
||||
|
||||
class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore
|
||||
__doc__ = model_cls.__doc__
|
||||
model_config = SettingsConfigDict(
|
||||
nested_model_default_partial_update=True,
|
||||
case_sensitive=True,
|
||||
cli_hide_none_type=True,
|
||||
cli_avoid_json=True,
|
||||
cli_enforce_required=True,
|
||||
cli_implicit_flags=True,
|
||||
cli_kebab_case=True,
|
||||
)
|
||||
|
||||
return CliAppBaseSettings
|
||||
|
||||
@staticmethod
|
||||
def _run_cli_cmd(model: Any, cli_cmd_method_name: str, is_required: bool) -> Any:
|
||||
command = getattr(type(model), cli_cmd_method_name, None)
|
||||
if command is None:
|
||||
if is_required:
|
||||
raise SettingsError(f'Error: {type(model).__name__} class is missing {cli_cmd_method_name} entrypoint')
|
||||
return model
|
||||
|
||||
# If the method is asynchronous, we handle its execution based on the current event loop status.
|
||||
if inspect.iscoroutinefunction(command):
|
||||
# For asynchronous methods, we have two execution scenarios:
|
||||
# 1. If no event loop is running in the current thread, run the coroutine directly with asyncio.run().
|
||||
# 2. If an event loop is already running in the current thread, run the coroutine in a separate thread to avoid conflicts.
|
||||
try:
|
||||
# Check if an event loop is currently running in this thread.
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# We're in a context with an active event loop (e.g., Jupyter Notebook).
|
||||
# Running asyncio.run() here would cause conflicts, so we use a separate thread.
|
||||
exception_container = []
|
||||
|
||||
def run_coro() -> None:
|
||||
try:
|
||||
# Execute the coroutine in a new event loop in this separate thread.
|
||||
asyncio.run(command(model))
|
||||
except Exception as e:
|
||||
exception_container.append(e)
|
||||
|
||||
thread = threading.Thread(target=run_coro)
|
||||
thread.start()
|
||||
thread.join()
|
||||
if exception_container:
|
||||
# Propagate exceptions from the separate thread.
|
||||
raise exception_container[0]
|
||||
else:
|
||||
# No event loop is running; safe to run the coroutine directly.
|
||||
asyncio.run(command(model))
|
||||
else:
|
||||
# For synchronous methods, call them directly.
|
||||
command(model)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def run(
|
||||
model_cls: type[T],
|
||||
cli_args: list[str] | Namespace | SimpleNamespace | dict[str, Any] | None = None,
|
||||
cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
cli_exit_on_error: bool | None = None,
|
||||
cli_cmd_method_name: str = 'cli_cmd',
|
||||
**model_init_data: Any,
|
||||
) -> T:
|
||||
"""
|
||||
Runs a Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as a CLI application.
|
||||
Running a model as a CLI application requires the `cli_cmd` method to be defined in the model class.
|
||||
|
||||
Args:
|
||||
model_cls: The model class to run as a CLI application.
|
||||
cli_args: The list of CLI arguments to parse. If `cli_settings_source` is specified, this may
|
||||
also be a namespace or dictionary of pre-parsed CLI arguments. Defaults to `sys.argv[1:]`.
|
||||
cli_settings_source: Override the default CLI settings source with a user defined instance.
|
||||
Defaults to `None`.
|
||||
cli_exit_on_error: Determines whether this function exits on error. If model is subclass of
|
||||
`BaseSettings`, defaults to BaseSettings `cli_exit_on_error` value. Otherwise, defaults to
|
||||
`True`.
|
||||
cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd".
|
||||
model_init_data: The model init data.
|
||||
|
||||
Returns:
|
||||
The ran instance of model.
|
||||
|
||||
Raises:
|
||||
SettingsError: If model_cls is not subclass of `BaseModel` or `pydantic.dataclasses.dataclass`.
|
||||
SettingsError: If model_cls does not have a `cli_cmd` entrypoint defined.
|
||||
"""
|
||||
|
||||
if not (is_pydantic_dataclass(model_cls) or is_model_class(model_cls)):
|
||||
raise SettingsError(
|
||||
f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass'
|
||||
)
|
||||
|
||||
cli_settings = None
|
||||
cli_parse_args = True if cli_args is None else cli_args
|
||||
if cli_settings_source is not None:
|
||||
if isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)):
|
||||
cli_settings = cli_settings_source(parsed_args=cli_parse_args)
|
||||
else:
|
||||
cli_settings = cli_settings_source(args=cli_parse_args)
|
||||
elif isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)):
|
||||
raise SettingsError('Error: `cli_args` must be list[str] or None when `cli_settings_source` is not used')
|
||||
|
||||
model_init_data['_cli_parse_args'] = cli_parse_args
|
||||
model_init_data['_cli_exit_on_error'] = cli_exit_on_error
|
||||
model_init_data['_cli_settings_source'] = cli_settings
|
||||
if not issubclass(model_cls, BaseSettings):
|
||||
base_settings_cls = CliApp._get_base_settings_cls(model_cls)
|
||||
model = base_settings_cls(**model_init_data)
|
||||
model_init_data = {}
|
||||
for field_name, field_info in base_settings_cls.model_fields.items():
|
||||
model_init_data[_field_name_for_signature(field_name, field_info)] = getattr(model, field_name)
|
||||
|
||||
return CliApp._run_cli_cmd(model_cls(**model_init_data), cli_cmd_method_name, is_required=False)
|
||||
|
||||
@staticmethod
|
||||
def run_subcommand(
|
||||
model: PydanticModel, cli_exit_on_error: bool | None = None, cli_cmd_method_name: str = 'cli_cmd'
|
||||
) -> PydanticModel:
|
||||
"""
|
||||
Runs the model subcommand. Running a model subcommand requires the `cli_cmd` method to be defined in
|
||||
the nested model subcommand class.
|
||||
|
||||
Args:
|
||||
model: The model to run the subcommand from.
|
||||
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
|
||||
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
|
||||
cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd".
|
||||
|
||||
Returns:
|
||||
The ran subcommand model.
|
||||
|
||||
Raises:
|
||||
SystemExit: When no subcommand is found and cli_exit_on_error=`True` (the default).
|
||||
SettingsError: When no subcommand is found and cli_exit_on_error=`False`.
|
||||
"""
|
||||
|
||||
subcommand = get_subcommand(model, is_required=True, cli_exit_on_error=cli_exit_on_error)
|
||||
return CliApp._run_cli_cmd(subcommand, cli_cmd_method_name, is_required=True)
|
||||
|
||||
@staticmethod
|
||||
def serialize(model: PydanticModel) -> list[str]:
|
||||
"""
|
||||
Serializes the CLI arguments for a Pydantic data model.
|
||||
|
||||
Args:
|
||||
model: The data model to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized CLI arguments for the data model.
|
||||
"""
|
||||
|
||||
base_settings_cls = CliApp._get_base_settings_cls(type(model))
|
||||
return CliSettingsSource[Any](base_settings_cls)._serialized_args(model)
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
"""Package for handling configuration sources in pydantic-settings."""
|
||||
|
||||
from .base import (
|
||||
ConfigFileSourceMixin,
|
||||
DefaultSettingsSource,
|
||||
InitSettingsSource,
|
||||
PydanticBaseEnvSettingsSource,
|
||||
PydanticBaseSettingsSource,
|
||||
get_subcommand,
|
||||
)
|
||||
from .providers.aws import AWSSecretsManagerSettingsSource
|
||||
from .providers.azure import AzureKeyVaultSettingsSource
|
||||
from .providers.cli import (
|
||||
CLI_SUPPRESS,
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
CliUnknownArgs,
|
||||
)
|
||||
from .providers.dotenv import DotEnvSettingsSource, read_env_file
|
||||
from .providers.env import EnvSettingsSource
|
||||
from .providers.gcp import GoogleSecretManagerSettingsSource
|
||||
from .providers.json import JsonConfigSettingsSource
|
||||
from .providers.nested_secrets import NestedSecretsSettingsSource
|
||||
from .providers.pyproject import PyprojectTomlConfigSettingsSource
|
||||
from .providers.secrets import SecretsSettingsSource
|
||||
from .providers.toml import TomlConfigSettingsSource
|
||||
from .providers.yaml import YamlConfigSettingsSource
|
||||
from .types import DEFAULT_PATH, ENV_FILE_SENTINEL, DotenvType, ForceDecode, NoDecode, PathType, PydanticModel
|
||||
|
||||
__all__ = [
|
||||
'CLI_SUPPRESS',
|
||||
'ENV_FILE_SENTINEL',
|
||||
'DEFAULT_PATH',
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'CliUnknownArgs',
|
||||
'DefaultSettingsSource',
|
||||
'DotEnvSettingsSource',
|
||||
'DotenvType',
|
||||
'EnvSettingsSource',
|
||||
'ForceDecode',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'InitSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'NestedSecretsSettingsSource',
|
||||
'NoDecode',
|
||||
'PathType',
|
||||
'PydanticBaseEnvSettingsSource',
|
||||
'PydanticBaseSettingsSource',
|
||||
'ConfigFileSourceMixin',
|
||||
'PydanticModel',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
'get_subcommand',
|
||||
'read_env_file',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,541 @@
|
|||
"""Base classes and core functionality for pydantic-settings sources."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, cast, get_args
|
||||
|
||||
from pydantic import AliasChoices, AliasPath, BaseModel, TypeAdapter
|
||||
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
|
||||
get_origin,
|
||||
)
|
||||
from pydantic._internal._utils import is_model_class
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ..exceptions import SettingsError
|
||||
from ..utils import _lenient_issubclass
|
||||
from .types import EnvNoneType, ForceDecode, NoDecode, PathType, PydanticModel, _CliSubCommand
|
||||
from .utils import (
|
||||
_annotation_is_complex,
|
||||
_get_alias_names,
|
||||
_get_model_fields,
|
||||
_strip_annotated,
|
||||
_union_is_complex,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
def get_subcommand(
|
||||
model: PydanticModel, is_required: bool = True, cli_exit_on_error: bool | None = None
|
||||
) -> PydanticModel | None:
|
||||
"""
|
||||
Get the subcommand from a model.
|
||||
|
||||
Args:
|
||||
model: The model to get the subcommand from.
|
||||
is_required: Determines whether a model must have subcommand set and raises error if not
|
||||
found. Defaults to `True`.
|
||||
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
|
||||
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
|
||||
|
||||
Returns:
|
||||
The subcommand model if found, otherwise `None`.
|
||||
|
||||
Raises:
|
||||
SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True`
|
||||
(the default).
|
||||
SettingsError: When no subcommand is found and is_required=`True` and
|
||||
cli_exit_on_error=`False`.
|
||||
"""
|
||||
|
||||
model_cls = type(model)
|
||||
if cli_exit_on_error is None and is_model_class(model_cls):
|
||||
model_default = model_cls.model_config.get('cli_exit_on_error')
|
||||
if isinstance(model_default, bool):
|
||||
cli_exit_on_error = model_default
|
||||
if cli_exit_on_error is None:
|
||||
cli_exit_on_error = True
|
||||
|
||||
subcommands: list[str] = []
|
||||
for field_name, field_info in _get_model_fields(model_cls).items():
|
||||
if _CliSubCommand in field_info.metadata:
|
||||
if getattr(model, field_name) is not None:
|
||||
return getattr(model, field_name)
|
||||
subcommands.append(field_name)
|
||||
|
||||
if is_required:
|
||||
error_message = (
|
||||
f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}'
|
||||
if subcommands
|
||||
else 'Error: CLI subcommand is required but no subcommands were found.'
|
||||
)
|
||||
raise SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class PydanticBaseSettingsSource(ABC):
|
||||
"""
|
||||
Abstract base class for settings sources, every settings source classes should inherit from it.
|
||||
"""
|
||||
|
||||
def __init__(self, settings_cls: type[BaseSettings]):
|
||||
self.settings_cls = settings_cls
|
||||
self.config = settings_cls.model_config
|
||||
self._current_state: dict[str, Any] = {}
|
||||
self._settings_sources_data: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def _set_current_state(self, state: dict[str, Any]) -> None:
|
||||
"""
|
||||
Record the state of settings from the previous settings sources. This should
|
||||
be called right before __call__.
|
||||
"""
|
||||
self._current_state = state
|
||||
|
||||
def _set_settings_sources_data(self, states: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Record the state of settings from all previous settings sources. This should
|
||||
be called right before __call__.
|
||||
"""
|
||||
self._settings_sources_data = states
|
||||
|
||||
@property
|
||||
def current_state(self) -> dict[str, Any]:
|
||||
"""
|
||||
The current state of the settings, populated by the previous settings sources.
|
||||
"""
|
||||
return self._current_state
|
||||
|
||||
@property
|
||||
def settings_sources_data(self) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
The state of all previous settings sources.
|
||||
"""
|
||||
return self._settings_sources_data
|
||||
|
||||
@abstractmethod
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value, the key for model creation, and a flag to determine whether value is complex.
|
||||
|
||||
This is an abstract method that should be overridden in every settings source classes.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value, key and a flag to determine whether value is complex.
|
||||
"""
|
||||
pass
|
||||
|
||||
def field_is_complex(self, field: FieldInfo) -> bool:
|
||||
"""
|
||||
Checks whether a field is complex, in which case it will attempt to be parsed as JSON.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
|
||||
Returns:
|
||||
Whether the field is complex.
|
||||
"""
|
||||
return _annotation_is_complex(field.annotation, field.metadata)
|
||||
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||
"""
|
||||
Prepares the value of a field.
|
||||
|
||||
Args:
|
||||
field_name: The field name.
|
||||
field: The field.
|
||||
value: The value of the field that has to be prepared.
|
||||
value_is_complex: A flag to determine whether value is complex.
|
||||
|
||||
Returns:
|
||||
The prepared value.
|
||||
"""
|
||||
if value is not None and (self.field_is_complex(field) or value_is_complex):
|
||||
return self.decode_complex_value(field_name, field, value)
|
||||
return value
|
||||
|
||||
def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any:
|
||||
"""
|
||||
Decode the value for a complex field
|
||||
|
||||
Args:
|
||||
field_name: The field name.
|
||||
field: The field.
|
||||
value: The value of the field that has to be prepared.
|
||||
|
||||
Returns:
|
||||
The decoded value for further preparation
|
||||
"""
|
||||
if field and (
|
||||
NoDecode in field.metadata
|
||||
or (self.config.get('enable_decoding') is False and ForceDecode not in field.metadata)
|
||||
):
|
||||
return value
|
||||
|
||||
return json.loads(value)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
class ConfigFileSourceMixin(ABC):
|
||||
def _read_files(self, files: PathType | None) -> dict[str, Any]:
|
||||
if files is None:
|
||||
return {}
|
||||
if isinstance(files, (str, os.PathLike)):
|
||||
files = [files]
|
||||
vars: dict[str, Any] = {}
|
||||
for file in files:
|
||||
file_path = Path(file).expanduser()
|
||||
if file_path.is_file():
|
||||
vars.update(self._read_file(file_path))
|
||||
return vars
|
||||
|
||||
@abstractmethod
|
||||
def _read_file(self, path: Path) -> dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultSettingsSource(PydanticBaseSettingsSource):
|
||||
"""
|
||||
Source class for loading default object values.
|
||||
|
||||
Args:
|
||||
settings_cls: The Settings class.
|
||||
nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
|
||||
Defaults to `False`.
|
||||
"""
|
||||
|
||||
def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partial_update: bool | None = None):
|
||||
super().__init__(settings_cls)
|
||||
self.defaults: dict[str, Any] = {}
|
||||
self.nested_model_default_partial_update = (
|
||||
nested_model_default_partial_update
|
||||
if nested_model_default_partial_update is not None
|
||||
else self.config.get('nested_model_default_partial_update', False)
|
||||
)
|
||||
if self.nested_model_default_partial_update:
|
||||
for field_name, field_info in settings_cls.model_fields.items():
|
||||
alias_names, *_ = _get_alias_names(field_name, field_info)
|
||||
preferred_alias = alias_names[0]
|
||||
if is_dataclass(type(field_info.default)):
|
||||
self.defaults[preferred_alias] = asdict(field_info.default)
|
||||
elif is_model_class(type(field_info.default)):
|
||||
self.defaults[preferred_alias] = field_info.default.model_dump()
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
# Nothing to do here. Only implement the return statement to make mypy happy
|
||||
return None, '', False
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
return self.defaults
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(nested_model_default_partial_update={self.nested_model_default_partial_update})'
|
||||
)
|
||||
|
||||
|
||||
class InitSettingsSource(PydanticBaseSettingsSource):
|
||||
"""
|
||||
Source class for loading values provided during settings class initialization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_kwargs: dict[str, Any],
|
||||
nested_model_default_partial_update: bool | None = None,
|
||||
):
|
||||
self.init_kwargs = {}
|
||||
init_kwarg_names = set(init_kwargs.keys())
|
||||
for field_name, field_info in settings_cls.model_fields.items():
|
||||
alias_names, *_ = _get_alias_names(field_name, field_info)
|
||||
# When populate_by_name is True, allow using the field name as an input key,
|
||||
# but normalize to the preferred alias to keep keys consistent across sources.
|
||||
matchable_names = set(alias_names)
|
||||
include_name = settings_cls.model_config.get('populate_by_name', False)
|
||||
if include_name:
|
||||
matchable_names.add(field_name)
|
||||
init_kwarg_name = init_kwarg_names & matchable_names
|
||||
if init_kwarg_name:
|
||||
preferred_alias = alias_names[0] if alias_names else field_name
|
||||
# Choose provided key deterministically: prefer the first alias in alias_names order;
|
||||
# fall back to field_name if allowed and provided.
|
||||
provided_key = next((alias for alias in alias_names if alias in init_kwarg_names), None)
|
||||
if provided_key is None and include_name and field_name in init_kwarg_names:
|
||||
provided_key = field_name
|
||||
# provided_key should not be None here because init_kwarg_name is non-empty
|
||||
assert provided_key is not None
|
||||
init_kwarg_names -= init_kwarg_name
|
||||
self.init_kwargs[preferred_alias] = init_kwargs[provided_key]
|
||||
# Include any remaining init kwargs (e.g., extras) unchanged
|
||||
# Note: If populate_by_name is True and the provided key is the field name, but
|
||||
# no alias exists, we keep it as-is so it can be processed as extra if allowed.
|
||||
self.init_kwargs.update({key: val for key, val in init_kwargs.items() if key in init_kwarg_names})
|
||||
|
||||
super().__init__(settings_cls)
|
||||
self.nested_model_default_partial_update = (
|
||||
nested_model_default_partial_update
|
||||
if nested_model_default_partial_update is not None
|
||||
else self.config.get('nested_model_default_partial_update', False)
|
||||
)
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
# Nothing to do here. Only implement the return statement to make mypy happy
|
||||
return None, '', False
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
return (
|
||||
TypeAdapter(dict[str, Any]).dump_python(self.init_kwargs)
|
||||
if self.nested_model_default_partial_update
|
||||
else self.init_kwargs
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(init_kwargs={self.init_kwargs!r})'
|
||||
|
||||
|
||||
class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource):
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__(settings_cls)
|
||||
self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
|
||||
self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '')
|
||||
self.env_ignore_empty = (
|
||||
env_ignore_empty if env_ignore_empty is not None else self.config.get('env_ignore_empty', False)
|
||||
)
|
||||
self.env_parse_none_str = (
|
||||
env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str')
|
||||
)
|
||||
self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums')
|
||||
|
||||
def _apply_case_sensitive(self, value: str) -> str:
|
||||
return value.lower() if not self.case_sensitive else value
|
||||
|
||||
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
|
||||
"""
|
||||
Extracts field info. This info is used to get the value of field from environment variables.
|
||||
|
||||
It returns a list of tuples, each tuple contains:
|
||||
* field_key: The key of field that has to be used in model creation.
|
||||
* env_name: The environment variable name of the field.
|
||||
* value_is_complex: A flag to determine whether the value from environment variable
|
||||
is complex and has to be parsed.
|
||||
|
||||
Args:
|
||||
field (FieldInfo): The field.
|
||||
field_name (str): The field name.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex.
|
||||
"""
|
||||
field_info: list[tuple[str, str, bool]] = []
|
||||
if isinstance(field.validation_alias, (AliasChoices, AliasPath)):
|
||||
v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases()
|
||||
else:
|
||||
v_alias = field.validation_alias
|
||||
|
||||
if v_alias:
|
||||
if isinstance(v_alias, list): # AliasChoices, AliasPath
|
||||
for alias in v_alias:
|
||||
if isinstance(alias, str): # AliasPath
|
||||
field_info.append((alias, self._apply_case_sensitive(alias), True if len(alias) > 1 else False))
|
||||
elif isinstance(alias, list): # AliasChoices
|
||||
first_arg = cast(str, alias[0]) # first item of an AliasChoices must be a str
|
||||
field_info.append(
|
||||
(first_arg, self._apply_case_sensitive(first_arg), True if len(alias) > 1 else False)
|
||||
)
|
||||
else: # string validation alias
|
||||
field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
|
||||
|
||||
if not v_alias or self.config.get('populate_by_name', False):
|
||||
annotation = field.annotation
|
||||
if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)):
|
||||
annotation = _strip_annotated(annotation.__value__) # type: ignore[union-attr]
|
||||
if is_union_origin(get_origin(annotation)) and _union_is_complex(annotation, field.metadata):
|
||||
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True))
|
||||
else:
|
||||
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))
|
||||
|
||||
return field_info
|
||||
|
||||
def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace field names in values dict by looking in models fields insensitively.
|
||||
|
||||
By having the following models:
|
||||
|
||||
```py
|
||||
class SubSubSub(BaseModel):
|
||||
VaL3: str
|
||||
|
||||
class SubSub(BaseModel):
|
||||
Val2: str
|
||||
SUB_sub_SuB: SubSubSub
|
||||
|
||||
class Sub(BaseModel):
|
||||
VAL1: str
|
||||
SUB_sub: SubSub
|
||||
|
||||
class Settings(BaseSettings):
|
||||
nested: Sub
|
||||
|
||||
model_config = SettingsConfigDict(env_nested_delimiter='__')
|
||||
```
|
||||
|
||||
Then:
|
||||
_replace_field_names_case_insensitively(
|
||||
field,
|
||||
{"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}}
|
||||
)
|
||||
Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}}
|
||||
"""
|
||||
values: dict[str, Any] = {}
|
||||
|
||||
for name, value in field_values.items():
|
||||
sub_model_field: FieldInfo | None = None
|
||||
|
||||
annotation = field.annotation
|
||||
|
||||
# If field is Optional, we need to find the actual type
|
||||
if is_union_origin(get_origin(field.annotation)):
|
||||
args = get_args(annotation)
|
||||
if len(args) == 2 and type(None) in args:
|
||||
for arg in args:
|
||||
if arg is not None:
|
||||
annotation = arg
|
||||
break
|
||||
|
||||
# This is here to make mypy happy
|
||||
# Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
|
||||
if not annotation or not hasattr(annotation, 'model_fields'):
|
||||
values[name] = value
|
||||
continue
|
||||
else:
|
||||
model_fields: dict[str, FieldInfo] = annotation.model_fields
|
||||
|
||||
# Find field in sub model by looking in fields case insensitively
|
||||
field_key: str | None = None
|
||||
for sub_model_field_name, sub_model_field in model_fields.items():
|
||||
aliases, _ = _get_alias_names(sub_model_field_name, sub_model_field)
|
||||
_search = (alias for alias in aliases if alias.lower() == name.lower())
|
||||
if field_key := next(_search, None):
|
||||
break
|
||||
|
||||
if not field_key:
|
||||
values[name] = value
|
||||
continue
|
||||
|
||||
if (
|
||||
sub_model_field is not None
|
||||
and _lenient_issubclass(sub_model_field.annotation, BaseModel)
|
||||
and isinstance(value, dict)
|
||||
):
|
||||
values[field_key] = self._replace_field_names_case_insensitively(sub_model_field, value)
|
||||
else:
|
||||
values[field_key] = value
|
||||
|
||||
return values
|
||||
|
||||
def _replace_env_none_type_values(self, field_value: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Recursively parse values that are of "None" type(EnvNoneType) to `None` type(None).
|
||||
"""
|
||||
values: dict[str, Any] = {}
|
||||
|
||||
for key, value in field_value.items():
|
||||
if not isinstance(value, EnvNoneType):
|
||||
values[key] = value if not isinstance(value, dict) else self._replace_env_none_type_values(value)
|
||||
else:
|
||||
values[key] = None
|
||||
|
||||
return values
|
||||
|
||||
def _get_resolved_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value, the preferred alias key for model creation, and a flag to determine whether value
|
||||
is complex.
|
||||
|
||||
Note:
|
||||
In V3, this method should either be made public, or, this method should be removed and the
|
||||
abstract method get_field_value should be updated to include a "use_preferred_alias" flag.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value, preferred key and a flag to determine whether value is complex.
|
||||
"""
|
||||
field_value, field_key, value_is_complex = self.get_field_value(field, field_name)
|
||||
if not (value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name))):
|
||||
field_infos = self._extract_field_info(field, field_name)
|
||||
preferred_key, *_ = field_infos[0]
|
||||
return field_value, preferred_key, value_is_complex
|
||||
return field_value, field_key, value_is_complex
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
data: dict[str, Any] = {}
|
||||
|
||||
for field_name, field in self.settings_cls.model_fields.items():
|
||||
try:
|
||||
field_value, field_key, value_is_complex = self._get_resolved_field_value(field, field_name)
|
||||
except Exception as e:
|
||||
raise SettingsError(
|
||||
f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"'
|
||||
) from e
|
||||
|
||||
try:
|
||||
field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex)
|
||||
except ValueError as e:
|
||||
raise SettingsError(
|
||||
f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"'
|
||||
) from e
|
||||
|
||||
if field_value is not None:
|
||||
if self.env_parse_none_str is not None:
|
||||
if isinstance(field_value, dict):
|
||||
field_value = self._replace_env_none_type_values(field_value)
|
||||
elif isinstance(field_value, EnvNoneType):
|
||||
field_value = None
|
||||
if (
|
||||
not self.case_sensitive
|
||||
# and _lenient_issubclass(field.annotation, BaseModel)
|
||||
and isinstance(field_value, dict)
|
||||
):
|
||||
data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
|
||||
else:
|
||||
data[field_key] = field_value
|
||||
|
||||
return data
|
||||
|
||||
|
||||
__all__ = [
|
||||
'ConfigFileSourceMixin',
|
||||
'DefaultSettingsSource',
|
||||
'InitSettingsSource',
|
||||
'PydanticBaseEnvSettingsSource',
|
||||
'PydanticBaseSettingsSource',
|
||||
'SettingsError',
|
||||
]
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
"""Package containing individual source implementations."""
|
||||
|
||||
from .aws import AWSSecretsManagerSettingsSource
|
||||
from .azure import AzureKeyVaultSettingsSource
|
||||
from .cli import (
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
)
|
||||
from .dotenv import DotEnvSettingsSource
|
||||
from .env import EnvSettingsSource
|
||||
from .gcp import GoogleSecretManagerSettingsSource
|
||||
from .json import JsonConfigSettingsSource
|
||||
from .pyproject import PyprojectTomlConfigSettingsSource
|
||||
from .secrets import SecretsSettingsSource
|
||||
from .toml import TomlConfigSettingsSource
|
||||
from .yaml import YamlConfigSettingsSource
|
||||
|
||||
__all__ = [
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'DotEnvSettingsSource',
|
||||
'EnvSettingsSource',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,79 @@
|
|||
from __future__ import annotations as _annotations # important for BaseSettings import to work
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import parse_env_vars
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
boto3_client = None
|
||||
SecretsManagerClient = None
|
||||
|
||||
|
||||
def import_aws_secrets_manager() -> None:
|
||||
global boto3_client
|
||||
global SecretsManagerClient
|
||||
|
||||
try:
|
||||
from boto3 import client as boto3_client
|
||||
from mypy_boto3_secretsmanager.client import SecretsManagerClient
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
'AWS Secrets Manager dependencies are not installed, run `pip install pydantic-settings[aws-secrets-manager]`'
|
||||
) from e
|
||||
|
||||
|
||||
class AWSSecretsManagerSettingsSource(EnvSettingsSource):
|
||||
_secret_id: str
|
||||
_secretsmanager_client: SecretsManagerClient # type: ignore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
secret_id: str,
|
||||
region_name: str | None = None,
|
||||
endpoint_url: str | None = None,
|
||||
case_sensitive: bool | None = True,
|
||||
env_prefix: str | None = None,
|
||||
env_nested_delimiter: str | None = '--',
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
import_aws_secrets_manager()
|
||||
self._secretsmanager_client = boto3_client('secretsmanager', region_name=region_name, endpoint_url=endpoint_url) # type: ignore
|
||||
self._secret_id = secret_id
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
env_ignore_empty=False,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
response = self._secretsmanager_client.get_secret_value(SecretId=self._secret_id) # type: ignore
|
||||
|
||||
return parse_env_vars(
|
||||
json.loads(response['SecretString']),
|
||||
self.case_sensitive,
|
||||
self.env_ignore_empty,
|
||||
self.env_parse_none_str,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(secret_id={self._secret_id!r}, '
|
||||
f'env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
]
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
"""Azure Key Vault settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic.alias_generators import to_snake
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
TokenCredential = None
|
||||
ResourceNotFoundError = None
|
||||
SecretClient = None
|
||||
|
||||
|
||||
def import_azure_key_vault() -> None:
|
||||
global TokenCredential
|
||||
global SecretClient
|
||||
global ResourceNotFoundError
|
||||
|
||||
try:
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`'
|
||||
) from e
|
||||
|
||||
|
||||
class AzureKeyVaultMapping(Mapping[str, str | None]):
|
||||
_loaded_secrets: dict[str, str | None]
|
||||
_secret_client: SecretClient
|
||||
_secret_names: list[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secret_client: SecretClient,
|
||||
case_sensitive: bool,
|
||||
snake_case_conversion: bool,
|
||||
) -> None:
|
||||
self._loaded_secrets = {}
|
||||
self._secret_client = secret_client
|
||||
self._case_sensitive = case_sensitive
|
||||
self._snake_case_conversion = snake_case_conversion
|
||||
self._secret_map: dict[str, str] = self._load_remote()
|
||||
|
||||
def _load_remote(self) -> dict[str, str]:
|
||||
secret_names: Iterator[str] = (
|
||||
secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
|
||||
)
|
||||
|
||||
if self._snake_case_conversion:
|
||||
return {to_snake(name): name for name in secret_names}
|
||||
|
||||
if self._case_sensitive:
|
||||
return {name: name for name in secret_names}
|
||||
|
||||
return {name.lower(): name for name in secret_names}
|
||||
|
||||
def __getitem__(self, key: str) -> str | None:
|
||||
new_key = key
|
||||
|
||||
if self._snake_case_conversion:
|
||||
new_key = to_snake(key)
|
||||
elif not self._case_sensitive:
|
||||
new_key = key.lower()
|
||||
|
||||
if new_key not in self._loaded_secrets:
|
||||
if new_key in self._secret_map:
|
||||
self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
return self._loaded_secrets[new_key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._secret_map)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._secret_map.keys())
|
||||
|
||||
|
||||
class AzureKeyVaultSettingsSource(EnvSettingsSource):
|
||||
_url: str
|
||||
_credential: TokenCredential
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
url: str,
|
||||
credential: TokenCredential,
|
||||
dash_to_underscore: bool = False,
|
||||
case_sensitive: bool | None = None,
|
||||
snake_case_conversion: bool = False,
|
||||
env_prefix: str | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
import_azure_key_vault()
|
||||
self._url = url
|
||||
self._credential = credential
|
||||
self._dash_to_underscore = dash_to_underscore
|
||||
self._snake_case_conversion = snake_case_conversion
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=False if snake_case_conversion else case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter='__' if snake_case_conversion else '--',
|
||||
env_ignore_empty=False,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
secret_client = SecretClient(vault_url=self._url, credential=self._credential)
|
||||
return AzureKeyVaultMapping(
|
||||
secret_client=secret_client,
|
||||
case_sensitive=self.case_sensitive,
|
||||
snake_case_conversion=self._snake_case_conversion,
|
||||
)
|
||||
|
||||
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
|
||||
if self._snake_case_conversion:
|
||||
return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name))
|
||||
|
||||
if self._dash_to_underscore:
|
||||
return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name))
|
||||
|
||||
return super()._extract_field_info(field, field_name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
|
||||
|
||||
__all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,168 @@
|
|||
"""Dotenv file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dotenv import dotenv_values
|
||||
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
|
||||
get_origin,
|
||||
)
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ..types import ENV_FILE_SENTINEL, DotenvType
|
||||
from ..utils import (
|
||||
_annotation_is_complex,
|
||||
_union_is_complex,
|
||||
parse_env_vars,
|
||||
)
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class DotEnvSettingsSource(EnvSettingsSource):
|
||||
"""
|
||||
Source class for loading settings values from env files.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
env_file: DotenvType | None = ENV_FILE_SENTINEL,
|
||||
env_file_encoding: str | None = None,
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_nested_delimiter: str | None = None,
|
||||
env_nested_max_split: int | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
|
||||
self.env_file_encoding = (
|
||||
env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
|
||||
)
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive,
|
||||
env_prefix,
|
||||
env_nested_delimiter,
|
||||
env_nested_max_split,
|
||||
env_ignore_empty,
|
||||
env_parse_none_str,
|
||||
env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
return self._read_env_files()
|
||||
|
||||
@staticmethod
|
||||
def _static_read_env_file(
|
||||
file_path: Path,
|
||||
*,
|
||||
encoding: str | None = None,
|
||||
case_sensitive: bool = False,
|
||||
ignore_empty: bool = False,
|
||||
parse_none_str: str | None = None,
|
||||
) -> Mapping[str, str | None]:
|
||||
file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8')
|
||||
return parse_env_vars(file_vars, case_sensitive, ignore_empty, parse_none_str)
|
||||
|
||||
def _read_env_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
) -> Mapping[str, str | None]:
|
||||
return self._static_read_env_file(
|
||||
file_path,
|
||||
encoding=self.env_file_encoding,
|
||||
case_sensitive=self.case_sensitive,
|
||||
ignore_empty=self.env_ignore_empty,
|
||||
parse_none_str=self.env_parse_none_str,
|
||||
)
|
||||
|
||||
def _read_env_files(self) -> Mapping[str, str | None]:
|
||||
env_files = self.env_file
|
||||
if env_files is None:
|
||||
return {}
|
||||
|
||||
if isinstance(env_files, (str, os.PathLike)):
|
||||
env_files = [env_files]
|
||||
|
||||
dotenv_vars: dict[str, str | None] = {}
|
||||
for env_file in env_files:
|
||||
env_path = Path(env_file).expanduser()
|
||||
if env_path.is_file():
|
||||
dotenv_vars.update(self._read_env_file(env_path))
|
||||
|
||||
return dotenv_vars
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
data: dict[str, Any] = super().__call__()
|
||||
is_extra_allowed = self.config.get('extra') != 'forbid'
|
||||
|
||||
# As `extra` config is allowed in dotenv settings source, We have to
|
||||
# update data with extra env variables from dotenv file.
|
||||
for env_name, env_value in self.env_vars.items():
|
||||
if not env_value or env_name in data or (self.env_prefix and env_name in self.settings_cls.model_fields):
|
||||
continue
|
||||
env_used = False
|
||||
for field_name, field in self.settings_cls.model_fields.items():
|
||||
for _, field_env_name, _ in self._extract_field_info(field, field_name):
|
||||
if env_name == field_env_name or (
|
||||
(
|
||||
_annotation_is_complex(field.annotation, field.metadata)
|
||||
or (
|
||||
is_union_origin(get_origin(field.annotation))
|
||||
and _union_is_complex(field.annotation, field.metadata)
|
||||
)
|
||||
)
|
||||
and env_name.startswith(field_env_name)
|
||||
):
|
||||
env_used = True
|
||||
break
|
||||
if env_used:
|
||||
break
|
||||
if not env_used:
|
||||
if is_extra_allowed and env_name.startswith(self.env_prefix):
|
||||
# env_prefix should be respected and removed from the env_name
|
||||
normalized_env_name = env_name[len(self.env_prefix) :]
|
||||
data[normalized_env_name] = env_value
|
||||
else:
|
||||
data[env_name] = env_value
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
|
||||
f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})'
|
||||
)
|
||||
|
||||
|
||||
def read_env_file(
|
||||
file_path: Path,
|
||||
*,
|
||||
encoding: str | None = None,
|
||||
case_sensitive: bool = False,
|
||||
ignore_empty: bool = False,
|
||||
parse_none_str: str | None = None,
|
||||
) -> Mapping[str, str | None]:
|
||||
warnings.warn(
|
||||
'read_env_file will be removed in the next version, use DotEnvSettingsSource._static_read_env_file if you must',
|
||||
DeprecationWarning,
|
||||
)
|
||||
return DotEnvSettingsSource._static_read_env_file(
|
||||
file_path,
|
||||
encoding=encoding,
|
||||
case_sensitive=case_sensitive,
|
||||
ignore_empty=ignore_empty,
|
||||
parse_none_str=parse_none_str,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ['DotEnvSettingsSource', 'read_env_file']
|
||||
|
|
@ -0,0 +1,294 @@
|
|||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from pydantic import Json, TypeAdapter, ValidationError
|
||||
from pydantic._internal._utils import deep_update, is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ...utils import _lenient_issubclass
|
||||
from ..base import PydanticBaseEnvSettingsSource
|
||||
from ..types import EnvNoneType
|
||||
from ..utils import (
|
||||
_annotation_contains_types,
|
||||
_annotation_enum_name_to_val,
|
||||
_get_model_fields,
|
||||
_union_is_complex,
|
||||
parse_env_vars,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class EnvSettingsSource(PydanticBaseEnvSettingsSource):
|
||||
"""
|
||||
Source class for loading settings values from environment variables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_nested_delimiter: str | None = None,
|
||||
env_nested_max_split: int | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
|
||||
)
|
||||
self.env_nested_delimiter = (
|
||||
env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
|
||||
)
|
||||
self.env_nested_max_split = (
|
||||
env_nested_max_split if env_nested_max_split is not None else self.config.get('env_nested_max_split')
|
||||
)
|
||||
self.maxsplit = (self.env_nested_max_split or 0) - 1
|
||||
self.env_prefix_len = len(self.env_prefix)
|
||||
|
||||
self.env_vars = self._load_env_vars()
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str)
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value for field from environment variables and a flag to determine whether value is complex.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value (`None` if not found), key, and
|
||||
a flag to determine whether value is complex.
|
||||
"""
|
||||
|
||||
env_val: str | None = None
|
||||
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
|
||||
env_val = self.env_vars.get(env_name)
|
||||
if env_val is not None:
|
||||
break
|
||||
|
||||
return env_val, field_key, value_is_complex
|
||||
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||
"""
|
||||
Prepare value for the field.
|
||||
|
||||
* Extract value for nested field.
|
||||
* Deserialize value to python object for complex field.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple contains prepared value for the field.
|
||||
|
||||
Raises:
|
||||
ValuesError: When There is an error in deserializing value for complex field.
|
||||
"""
|
||||
is_complex, allow_parse_failure = self._field_is_complex(field)
|
||||
if self.env_parse_enums:
|
||||
enum_val = _annotation_enum_name_to_val(field.annotation, value)
|
||||
value = value if enum_val is None else enum_val
|
||||
|
||||
if is_complex or value_is_complex:
|
||||
if isinstance(value, EnvNoneType):
|
||||
return value
|
||||
elif value is None:
|
||||
# field is complex but no value found so far, try explode_env_vars
|
||||
env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
|
||||
if env_val_built:
|
||||
return env_val_built
|
||||
else:
|
||||
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
|
||||
try:
|
||||
value = self.decode_complex_value(field_name, field, value)
|
||||
except ValueError as e:
|
||||
if not allow_parse_failure:
|
||||
raise e
|
||||
|
||||
if isinstance(value, dict):
|
||||
return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
|
||||
else:
|
||||
return value
|
||||
elif value is not None:
|
||||
# simplest case, field is not complex, we only need to add the value if it was found
|
||||
return self._coerce_env_val_strict(field, value)
|
||||
|
||||
def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
|
||||
"""
|
||||
Find out if a field is complex, and if so whether JSON errors should be ignored
|
||||
"""
|
||||
if self.field_is_complex(field):
|
||||
allow_parse_failure = False
|
||||
elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
|
||||
allow_parse_failure = True
|
||||
else:
|
||||
return False, False
|
||||
|
||||
return True, allow_parse_failure
|
||||
|
||||
# Default value of `case_sensitive` is `None`, because we don't want to break existing behavior.
|
||||
# We have to change the method to a non-static method and use
|
||||
# `self.case_sensitive` instead in V3.
|
||||
def next_field(
|
||||
self, field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None
|
||||
) -> FieldInfo | None:
|
||||
"""
|
||||
Find the field in a sub model by key(env name)
|
||||
|
||||
By having the following models:
|
||||
|
||||
```py
|
||||
class SubSubModel(BaseSettings):
|
||||
dvals: Dict
|
||||
|
||||
class SubModel(BaseSettings):
|
||||
vals: list[str]
|
||||
sub_sub_model: SubSubModel
|
||||
|
||||
class Cfg(BaseSettings):
|
||||
sub_model: SubModel
|
||||
```
|
||||
|
||||
Then:
|
||||
next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
|
||||
next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
key: The key (env name).
|
||||
case_sensitive: Whether to search for key case sensitively.
|
||||
|
||||
Returns:
|
||||
Field if it finds the next field otherwise `None`.
|
||||
"""
|
||||
if not field:
|
||||
return None
|
||||
|
||||
annotation = field.annotation if isinstance(field, FieldInfo) else field
|
||||
for type_ in get_args(annotation):
|
||||
type_has_key = self.next_field(type_, key, case_sensitive)
|
||||
if type_has_key:
|
||||
return type_has_key
|
||||
if is_model_class(annotation) or is_pydantic_dataclass(annotation): # type: ignore[arg-type]
|
||||
fields = _get_model_fields(annotation)
|
||||
# `case_sensitive is None` is here to be compatible with the old behavior.
|
||||
# Has to be removed in V3.
|
||||
for field_name, f in fields.items():
|
||||
for _, env_name, _ in self._extract_field_info(f, field_name):
|
||||
if case_sensitive is None or case_sensitive:
|
||||
if field_name == key or env_name == key:
|
||||
return f
|
||||
elif field_name.lower() == key.lower() or env_name.lower() == key.lower():
|
||||
return f
|
||||
return None
|
||||
|
||||
def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
|
||||
"""
|
||||
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
|
||||
|
||||
This is applied to a single field, hence filtering by env_var prefix.
|
||||
|
||||
Args:
|
||||
field_name: The field name.
|
||||
field: The field.
|
||||
env_vars: Environment variables.
|
||||
|
||||
Returns:
|
||||
A dictionary contains extracted values from nested env values.
|
||||
"""
|
||||
if not self.env_nested_delimiter:
|
||||
return {}
|
||||
|
||||
ann = field.annotation
|
||||
is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)
|
||||
|
||||
prefixes = [
|
||||
f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
|
||||
]
|
||||
result: dict[str, Any] = {}
|
||||
for env_name, env_val in env_vars.items():
|
||||
try:
|
||||
prefix = next(prefix for prefix in prefixes if env_name.startswith(prefix))
|
||||
except StopIteration:
|
||||
continue
|
||||
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
|
||||
env_name_without_prefix = env_name[len(prefix) :]
|
||||
*keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter, self.maxsplit)
|
||||
env_var = result
|
||||
target_field: FieldInfo | None = field
|
||||
for key in keys:
|
||||
target_field = self.next_field(target_field, key, self.case_sensitive)
|
||||
if isinstance(env_var, dict):
|
||||
env_var = env_var.setdefault(key, {})
|
||||
|
||||
# get proper field with last_key
|
||||
target_field = self.next_field(target_field, last_key, self.case_sensitive)
|
||||
|
||||
# check if env_val maps to a complex field and if so, parse the env_val
|
||||
if (target_field or is_dict) and env_val:
|
||||
if target_field:
|
||||
is_complex, allow_json_failure = self._field_is_complex(target_field)
|
||||
if self.env_parse_enums:
|
||||
enum_val = _annotation_enum_name_to_val(target_field.annotation, env_val)
|
||||
env_val = env_val if enum_val is None else enum_val
|
||||
else:
|
||||
# nested field type is dict
|
||||
is_complex, allow_json_failure = True, True
|
||||
if is_complex:
|
||||
try:
|
||||
env_val = self.decode_complex_value(last_key, target_field, env_val) # type: ignore
|
||||
except ValueError as e:
|
||||
if not allow_json_failure:
|
||||
raise e
|
||||
if isinstance(env_var, dict):
|
||||
if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
|
||||
env_var[last_key] = self._coerce_env_val_strict(target_field, env_val)
|
||||
return result
|
||||
|
||||
def _coerce_env_val_strict(self, field: FieldInfo | None, value: Any) -> Any:
|
||||
"""
|
||||
Coerce environment string values based on field annotation if model config is `strict=True`.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
value: The value to coerce.
|
||||
|
||||
Returns:
|
||||
The coerced value if successful, otherwise the original value.
|
||||
"""
|
||||
try:
|
||||
if self.config.get('strict') and isinstance(value, str) and field is not None:
|
||||
if value == self.env_parse_none_str:
|
||||
return value
|
||||
if not _annotation_contains_types(field.annotation, (Json,), is_instance=True):
|
||||
return TypeAdapter(field.annotation).validate_python(value)
|
||||
except ValidationError:
|
||||
# Allow validation error to be raised at time of instatiation
|
||||
pass
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '
|
||||
f'env_prefix_len={self.env_prefix_len!r})'
|
||||
)
|
||||
|
||||
|
||||
__all__ = ['EnvSettingsSource']
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.auth import default as google_auth_default
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud.secretmanager import SecretManagerServiceClient
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
Credentials = None
|
||||
SecretManagerServiceClient = None
|
||||
google_auth_default = None
|
||||
|
||||
|
||||
def import_gcp_secret_manager() -> None:
|
||||
global Credentials
|
||||
global SecretManagerServiceClient
|
||||
global google_auth_default
|
||||
|
||||
try:
|
||||
from google.auth import default as google_auth_default
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud.secretmanager import SecretManagerServiceClient
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
'GCP Secret Manager dependencies are not installed, run `pip install pydantic-settings[gcp-secret-manager]`'
|
||||
) from e
|
||||
|
||||
|
||||
class GoogleSecretManagerMapping(Mapping[str, str | None]):
|
||||
_loaded_secrets: dict[str, str | None]
|
||||
_secret_client: SecretManagerServiceClient
|
||||
|
||||
def __init__(self, secret_client: SecretManagerServiceClient, project_id: str, case_sensitive: bool) -> None:
|
||||
self._loaded_secrets = {}
|
||||
self._secret_client = secret_client
|
||||
self._project_id = project_id
|
||||
self._case_sensitive = case_sensitive
|
||||
|
||||
@property
|
||||
def _gcp_project_path(self) -> str:
|
||||
return self._secret_client.common_project_path(self._project_id)
|
||||
|
||||
@cached_property
|
||||
def _secret_names(self) -> list[str]:
|
||||
rv: list[str] = []
|
||||
|
||||
secrets = self._secret_client.list_secrets(parent=self._gcp_project_path)
|
||||
for secret in secrets:
|
||||
name = self._secret_client.parse_secret_path(secret.name).get('secret', '')
|
||||
if not self._case_sensitive:
|
||||
name = name.lower()
|
||||
rv.append(name)
|
||||
return rv
|
||||
|
||||
def _secret_version_path(self, key: str, version: str = 'latest') -> str:
|
||||
return self._secret_client.secret_version_path(self._project_id, key, version)
|
||||
|
||||
def __getitem__(self, key: str) -> str | None:
|
||||
if not self._case_sensitive:
|
||||
key = key.lower()
|
||||
if key not in self._loaded_secrets:
|
||||
# If we know the key isn't available in secret manager, raise a key error
|
||||
if key not in self._secret_names:
|
||||
raise KeyError(key)
|
||||
|
||||
try:
|
||||
self._loaded_secrets[key] = self._secret_client.access_secret_version(
|
||||
name=self._secret_version_path(key)
|
||||
).payload.data.decode('UTF-8')
|
||||
except Exception:
|
||||
raise KeyError(key)
|
||||
|
||||
return self._loaded_secrets[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._secret_names)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._secret_names)
|
||||
|
||||
|
||||
class GoogleSecretManagerSettingsSource(EnvSettingsSource):
|
||||
_credentials: Credentials
|
||||
_secret_client: SecretManagerServiceClient
|
||||
_project_id: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
credentials: Credentials | None = None,
|
||||
project_id: str | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
secret_client: SecretManagerServiceClient | None = None,
|
||||
case_sensitive: bool | None = True,
|
||||
) -> None:
|
||||
# Import Google Packages if they haven't already been imported
|
||||
if SecretManagerServiceClient is None or Credentials is None or google_auth_default is None:
|
||||
import_gcp_secret_manager()
|
||||
|
||||
# If credentials or project_id are not passed, then
|
||||
# try to get them from the default function
|
||||
if not credentials or not project_id:
|
||||
_creds, _project_id = google_auth_default() # type: ignore[no-untyped-call]
|
||||
|
||||
# Set the credentials and/or project id if they weren't specified
|
||||
if credentials is None:
|
||||
credentials = _creds
|
||||
|
||||
if project_id is None:
|
||||
if isinstance(_project_id, str):
|
||||
project_id = _project_id
|
||||
else:
|
||||
raise AttributeError(
|
||||
'project_id is required to be specified either as an argument or from the google.auth.default. See https://google-auth.readthedocs.io/en/master/reference/google.auth.html#google.auth.default'
|
||||
)
|
||||
|
||||
self._credentials: Credentials = credentials
|
||||
self._project_id: str = project_id
|
||||
|
||||
if secret_client:
|
||||
self._secret_client = secret_client
|
||||
else:
|
||||
self._secret_client = SecretManagerServiceClient(credentials=self._credentials)
|
||||
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_ignore_empty=False,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
return GoogleSecretManagerMapping(
|
||||
self._secret_client, project_id=self._project_id, case_sensitive=self.case_sensitive
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(project_id={self._project_id!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
|
||||
|
||||
__all__ = ['GoogleSecretManagerSettingsSource', 'GoogleSecretManagerMapping']
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
"""JSON file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class JsonConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
|
||||
"""
|
||||
A source class that loads variables from a JSON file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
json_file: PathType | None = DEFAULT_PATH,
|
||||
json_file_encoding: str | None = None,
|
||||
):
|
||||
self.json_file_path = json_file if json_file != DEFAULT_PATH else settings_cls.model_config.get('json_file')
|
||||
self.json_file_encoding = (
|
||||
json_file_encoding
|
||||
if json_file_encoding is not None
|
||||
else settings_cls.model_config.get('json_file_encoding')
|
||||
)
|
||||
self.json_data = self._read_files(self.json_file_path)
|
||||
super().__init__(settings_cls, self.json_data)
|
||||
|
||||
def _read_file(self, file_path: Path) -> dict[str, Any]:
|
||||
with open(file_path, encoding=self.json_file_encoding) as json_file:
|
||||
return json.load(json_file)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(json_file={self.json_file_path})'
|
||||
|
||||
|
||||
__all__ = ['JsonConfigSettingsSource']
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
import os
|
||||
import warnings
|
||||
from functools import reduce
|
||||
from glob import iglob
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
from ...exceptions import SettingsError
|
||||
from ...utils import path_type_label
|
||||
from ..base import PydanticBaseSettingsSource
|
||||
from ..utils import parse_env_vars
|
||||
from .env import EnvSettingsSource
|
||||
from .secrets import SecretsSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...main import BaseSettings
|
||||
from ...sources import PathType
|
||||
|
||||
|
||||
SECRETS_DIR_MAX_SIZE = 16 * 2**20 # 16 MiB seems to be a reasonable default
|
||||
|
||||
|
||||
class NestedSecretsSettingsSource(EnvSettingsSource):
|
||||
def __init__(
|
||||
self,
|
||||
file_secret_settings: PydanticBaseSettingsSource | SecretsSettingsSource,
|
||||
secrets_dir: Optional['PathType'] = None,
|
||||
secrets_dir_missing: Literal['ok', 'warn', 'error'] | None = None,
|
||||
secrets_dir_max_size: int | None = None,
|
||||
secrets_case_sensitive: bool | None = None,
|
||||
secrets_prefix: str | None = None,
|
||||
secrets_nested_delimiter: str | None = None,
|
||||
secrets_nested_subdir: bool | None = None,
|
||||
# args for compatibility with SecretsSettingsSource, don't use directly
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
) -> None:
|
||||
# We allow the first argument to be settings_cls like original
|
||||
# SecretsSettingsSource. However, it is recommended to pass
|
||||
# SecretsSettingsSource instance instead (as it is shown in usage examples),
|
||||
# otherwise `_secrets_dir` arg passed to Settings() constructor will be ignored.
|
||||
settings_cls: type[BaseSettings] = getattr(
|
||||
file_secret_settings,
|
||||
'settings_cls',
|
||||
file_secret_settings, # type: ignore[arg-type]
|
||||
)
|
||||
# config options
|
||||
conf = settings_cls.model_config
|
||||
self.secrets_dir: PathType | None = first_not_none(
|
||||
getattr(file_secret_settings, 'secrets_dir', None),
|
||||
secrets_dir,
|
||||
conf.get('secrets_dir'),
|
||||
)
|
||||
self.secrets_dir_missing: Literal['ok', 'warn', 'error'] = first_not_none(
|
||||
secrets_dir_missing,
|
||||
conf.get('secrets_dir_missing'),
|
||||
'warn',
|
||||
)
|
||||
if self.secrets_dir_missing not in ('ok', 'warn', 'error'):
|
||||
raise SettingsError(f'invalid secrets_dir_missing value: {self.secrets_dir_missing}')
|
||||
self.secrets_dir_max_size: int = first_not_none(
|
||||
secrets_dir_max_size,
|
||||
conf.get('secrets_dir_max_size'),
|
||||
SECRETS_DIR_MAX_SIZE,
|
||||
)
|
||||
self.case_sensitive: bool = first_not_none(
|
||||
secrets_case_sensitive,
|
||||
conf.get('secrets_case_sensitive'),
|
||||
case_sensitive,
|
||||
conf.get('case_sensitive'),
|
||||
False,
|
||||
)
|
||||
self.secrets_prefix: str = first_not_none(
|
||||
secrets_prefix,
|
||||
conf.get('secrets_prefix'),
|
||||
env_prefix,
|
||||
conf.get('env_prefix'),
|
||||
'',
|
||||
)
|
||||
|
||||
# nested options
|
||||
self.secrets_nested_delimiter: str | None = first_not_none(
|
||||
secrets_nested_delimiter,
|
||||
conf.get('secrets_nested_delimiter'),
|
||||
conf.get('env_nested_delimiter'),
|
||||
)
|
||||
self.secrets_nested_subdir: bool = first_not_none(
|
||||
secrets_nested_subdir,
|
||||
conf.get('secrets_nested_subdir'),
|
||||
False,
|
||||
)
|
||||
if self.secrets_nested_subdir:
|
||||
if secrets_nested_delimiter or conf.get('secrets_nested_delimiter'):
|
||||
raise SettingsError('Options secrets_nested_delimiter and secrets_nested_subdir are mutually exclusive')
|
||||
else:
|
||||
self.secrets_nested_delimiter = os.sep
|
||||
|
||||
# ensure valid secrets_path
|
||||
if self.secrets_dir is None:
|
||||
paths = []
|
||||
elif isinstance(self.secrets_dir, (Path, str)):
|
||||
paths = [self.secrets_dir]
|
||||
else:
|
||||
paths = list(self.secrets_dir)
|
||||
self.secrets_paths: list[Path] = [Path(p).expanduser().resolve() for p in paths]
|
||||
for path in self.secrets_paths:
|
||||
self.validate_secrets_path(path)
|
||||
|
||||
# construct parent
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=self.case_sensitive,
|
||||
env_prefix=self.secrets_prefix,
|
||||
env_nested_delimiter=self.secrets_nested_delimiter,
|
||||
env_ignore_empty=False, # match SecretsSettingsSource behaviour
|
||||
env_parse_enums=True, # we can pass everything here, it will still behave as "True"
|
||||
env_parse_none_str=None, # match SecretsSettingsSource behaviour
|
||||
)
|
||||
self.env_parse_none_str = None # update manually because of None
|
||||
|
||||
# update parent members
|
||||
if not len(self.secrets_paths):
|
||||
self.env_vars = {}
|
||||
else:
|
||||
secrets = reduce(
|
||||
lambda d1, d2: dict((*d1.items(), *d2.items())),
|
||||
(self.load_secrets(p) for p in self.secrets_paths),
|
||||
)
|
||||
self.env_vars = parse_env_vars(
|
||||
secrets,
|
||||
self.case_sensitive,
|
||||
self.env_ignore_empty,
|
||||
self.env_parse_none_str,
|
||||
)
|
||||
|
||||
def validate_secrets_path(self, path: Path) -> None:
|
||||
if not path.exists():
|
||||
if self.secrets_dir_missing == 'ok':
|
||||
pass
|
||||
elif self.secrets_dir_missing == 'warn':
|
||||
warnings.warn(f'directory "{path}" does not exist', stacklevel=2)
|
||||
elif self.secrets_dir_missing == 'error':
|
||||
raise SettingsError(f'directory "{path}" does not exist')
|
||||
else:
|
||||
raise ValueError # unreachable, checked before
|
||||
else:
|
||||
if not path.is_dir():
|
||||
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(path)}')
|
||||
secrets_dir_size = sum(f.stat().st_size for f in path.glob('**/*') if f.is_file())
|
||||
if secrets_dir_size > self.secrets_dir_max_size:
|
||||
raise SettingsError(f'secrets_dir size is above {self.secrets_dir_max_size} bytes')
|
||||
|
||||
@staticmethod
|
||||
def load_secrets(path: Path) -> dict[str, str]:
|
||||
return {
|
||||
str(p.relative_to(path)): p.read_text().strip()
|
||||
for p in map(Path, iglob(f'{path}/**/*', recursive=True))
|
||||
if p.is_file()
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'NestedSecretsSettingsSource(secrets_dir={self.secrets_dir!r})'
|
||||
|
||||
|
||||
def first_not_none(*objs: Any) -> Any:
|
||||
return next(filter(lambda o: o is not None, objs), None)
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
"""Pyproject TOML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .toml import TomlConfigSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSource):
|
||||
"""
|
||||
A source class that loads variables from a `pyproject.toml` file.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
toml_file: Path | None = None,
|
||||
) -> None:
|
||||
self.toml_file_path = self._pick_pyproject_toml_file(
|
||||
toml_file, settings_cls.model_config.get('pyproject_toml_depth', 0)
|
||||
)
|
||||
self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get(
|
||||
'pyproject_toml_table_header', ('tool', 'pydantic-settings')
|
||||
)
|
||||
self.toml_data = self._read_files(self.toml_file_path)
|
||||
for key in self.toml_table_header:
|
||||
self.toml_data = self.toml_data.get(key, {})
|
||||
super(TomlConfigSettingsSource, self).__init__(settings_cls, self.toml_data)
|
||||
|
||||
@staticmethod
|
||||
def _pick_pyproject_toml_file(provided: Path | None, depth: int) -> Path:
|
||||
"""Pick a `pyproject.toml` file path to use.
|
||||
|
||||
Args:
|
||||
provided: Explicit path provided when instantiating this class.
|
||||
depth: Number of directories up the tree to check of a pyproject.toml.
|
||||
|
||||
"""
|
||||
if provided:
|
||||
return provided.resolve()
|
||||
rv = Path.cwd() / 'pyproject.toml'
|
||||
count = 0
|
||||
if not rv.is_file():
|
||||
child = rv.parent.parent / 'pyproject.toml'
|
||||
while count < depth:
|
||||
if child.is_file():
|
||||
return child
|
||||
if str(child.parent) == rv.root:
|
||||
break # end discovery after checking system root once
|
||||
child = child.parent.parent / 'pyproject.toml'
|
||||
count += 1
|
||||
return rv
|
||||
|
||||
|
||||
__all__ = ['PyprojectTomlConfigSettingsSource']
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
"""Secrets file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from pydantic_settings.utils import path_type_label
|
||||
|
||||
from ...exceptions import SettingsError
|
||||
from ..base import PydanticBaseEnvSettingsSource
|
||||
from ..types import PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class SecretsSettingsSource(PydanticBaseEnvSettingsSource):
|
||||
"""
|
||||
Source class for loading settings values from secret files.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
secrets_dir: PathType | None = None,
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
|
||||
)
|
||||
self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
"""
|
||||
Build fields from "secrets" files.
|
||||
"""
|
||||
secrets: dict[str, str | None] = {}
|
||||
|
||||
if self.secrets_dir is None:
|
||||
return secrets
|
||||
|
||||
secrets_dirs = [self.secrets_dir] if isinstance(self.secrets_dir, (str, os.PathLike)) else self.secrets_dir
|
||||
secrets_paths = [Path(p).expanduser() for p in secrets_dirs]
|
||||
self.secrets_paths = []
|
||||
|
||||
for path in secrets_paths:
|
||||
if not path.exists():
|
||||
warnings.warn(f'directory "{path}" does not exist')
|
||||
else:
|
||||
self.secrets_paths.append(path)
|
||||
|
||||
if not len(self.secrets_paths):
|
||||
return secrets
|
||||
|
||||
for path in self.secrets_paths:
|
||||
if not path.is_dir():
|
||||
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(path)}')
|
||||
|
||||
return super().__call__()
|
||||
|
||||
@classmethod
|
||||
def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None:
|
||||
"""
|
||||
Find a file within path's directory matching filename, optionally ignoring case.
|
||||
|
||||
Args:
|
||||
dir_path: Directory path.
|
||||
file_name: File name.
|
||||
case_sensitive: Whether to search for file name case sensitively.
|
||||
|
||||
Returns:
|
||||
Whether file path or `None` if file does not exist in directory.
|
||||
"""
|
||||
for f in dir_path.iterdir():
|
||||
if f.name == file_name:
|
||||
return f
|
||||
elif not case_sensitive and f.name.lower() == file_name.lower():
|
||||
return f
|
||||
return None
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value for field from secret file and a flag to determine whether value is complex.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value (`None` if the file does not exist), key, and
|
||||
a flag to determine whether value is complex.
|
||||
"""
|
||||
|
||||
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
|
||||
# paths reversed to match the last-wins behaviour of `env_file`
|
||||
for secrets_path in reversed(self.secrets_paths):
|
||||
path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
|
||||
if not path:
|
||||
# path does not exist, we currently don't return a warning for this
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
return path.read_text().strip(), field_key, value_is_complex
|
||||
else:
|
||||
warnings.warn(
|
||||
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
|
||||
stacklevel=4,
|
||||
)
|
||||
|
||||
return None, field_key, value_is_complex
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
"""TOML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
else:
|
||||
tomllib = None
|
||||
import tomli
|
||||
else:
|
||||
tomllib = None
|
||||
tomli = None
|
||||
|
||||
|
||||
def import_toml() -> None:
|
||||
global tomli
|
||||
global tomllib
|
||||
if sys.version_info < (3, 11):
|
||||
if tomli is not None:
|
||||
return
|
||||
try:
|
||||
import tomli
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError('tomli is not installed, run `pip install pydantic-settings[toml]`') from e
|
||||
else:
|
||||
if tomllib is not None:
|
||||
return
|
||||
import tomllib
|
||||
|
||||
|
||||
class TomlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
|
||||
"""
|
||||
A source class that loads variables from a TOML file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
toml_file: PathType | None = DEFAULT_PATH,
|
||||
):
|
||||
self.toml_file_path = toml_file if toml_file != DEFAULT_PATH else settings_cls.model_config.get('toml_file')
|
||||
self.toml_data = self._read_files(self.toml_file_path)
|
||||
super().__init__(settings_cls, self.toml_data)
|
||||
|
||||
def _read_file(self, file_path: Path) -> dict[str, Any]:
|
||||
import_toml()
|
||||
with open(file_path, mode='rb') as toml_file:
|
||||
if sys.version_info < (3, 11):
|
||||
return tomli.load(toml_file)
|
||||
return tomllib.load(toml_file)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(toml_file={self.toml_file_path})'
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
"""YAML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import yaml
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
yaml = None
|
||||
|
||||
|
||||
def import_yaml() -> None:
|
||||
global yaml
|
||||
if yaml is not None:
|
||||
return
|
||||
try:
|
||||
import yaml
|
||||
except ImportError as e:
|
||||
raise ImportError('PyYAML is not installed, run `pip install pydantic-settings[yaml]`') from e
|
||||
|
||||
|
||||
class YamlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
|
||||
"""
|
||||
A source class that loads variables from a yaml file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
yaml_file: PathType | None = DEFAULT_PATH,
|
||||
yaml_file_encoding: str | None = None,
|
||||
yaml_config_section: str | None = None,
|
||||
):
|
||||
self.yaml_file_path = yaml_file if yaml_file != DEFAULT_PATH else settings_cls.model_config.get('yaml_file')
|
||||
self.yaml_file_encoding = (
|
||||
yaml_file_encoding
|
||||
if yaml_file_encoding is not None
|
||||
else settings_cls.model_config.get('yaml_file_encoding')
|
||||
)
|
||||
self.yaml_config_section = (
|
||||
yaml_config_section
|
||||
if yaml_config_section is not None
|
||||
else settings_cls.model_config.get('yaml_config_section')
|
||||
)
|
||||
self.yaml_data = self._read_files(self.yaml_file_path)
|
||||
|
||||
if self.yaml_config_section:
|
||||
try:
|
||||
self.yaml_data = self.yaml_data[self.yaml_config_section]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f'yaml_config_section key "{self.yaml_config_section}" not found in {self.yaml_file_path}'
|
||||
)
|
||||
super().__init__(settings_cls, self.yaml_data)
|
||||
|
||||
def _read_file(self, file_path: Path) -> dict[str, Any]:
|
||||
import_yaml()
|
||||
with open(file_path, encoding=self.yaml_file_encoding) as yaml_file:
|
||||
return yaml.safe_load(yaml_file) or {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(yaml_file={self.yaml_file_path})'
|
||||
|
||||
|
||||
__all__ = ['YamlConfigSettingsSource']
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
"""Type definitions for pydantic-settings sources."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic._internal._dataclasses import PydanticDataclass
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
PydanticModel = PydanticDataclass | BaseModel
|
||||
else:
|
||||
PydanticModel = Any
|
||||
|
||||
|
||||
class EnvNoneType(str):
|
||||
pass
|
||||
|
||||
|
||||
class NoDecode:
|
||||
"""Annotation to prevent decoding of a field value."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ForceDecode:
|
||||
"""Annotation to force decoding of a field value."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
DotenvType = Path | str | Sequence[Path | str]
|
||||
PathType = Path | str | Sequence[Path | str]
|
||||
DEFAULT_PATH: PathType = Path('')
|
||||
|
||||
# This is used as default value for `_env_file` in the `BaseSettings` class and
|
||||
# `env_file` in `DotEnvSettingsSource` so the default can be distinguished from `None`.
|
||||
# See the docstring of `BaseSettings` for more details.
|
||||
ENV_FILE_SENTINEL: DotenvType = Path('')
|
||||
|
||||
|
||||
class _CliSubCommand:
|
||||
pass
|
||||
|
||||
|
||||
class _CliPositionalArg:
|
||||
pass
|
||||
|
||||
|
||||
class _CliImplicitFlag:
|
||||
pass
|
||||
|
||||
|
||||
class _CliExplicitFlag:
|
||||
pass
|
||||
|
||||
|
||||
class _CliUnknownArgs:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
'DEFAULT_PATH',
|
||||
'ENV_FILE_SENTINEL',
|
||||
'DotenvType',
|
||||
'EnvNoneType',
|
||||
'ForceDecode',
|
||||
'NoDecode',
|
||||
'PathType',
|
||||
'PydanticModel',
|
||||
'_CliExplicitFlag',
|
||||
'_CliImplicitFlag',
|
||||
'_CliPositionalArg',
|
||||
'_CliSubCommand',
|
||||
'_CliUnknownArgs',
|
||||
]
|
||||
|
|
@ -0,0 +1,214 @@
|
|||
"""Utility functions for pydantic-settings sources."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, cast, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, Json, RootModel, Secret
|
||||
from pydantic._internal._utils import is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from typing_inspection import typing_objects
|
||||
|
||||
from ..exceptions import SettingsError
|
||||
from ..utils import _lenient_issubclass
|
||||
from .types import EnvNoneType
|
||||
|
||||
|
||||
def _get_env_var_key(key: str, case_sensitive: bool = False) -> str:
|
||||
return key if case_sensitive else key.lower()
|
||||
|
||||
|
||||
def _parse_env_none_str(value: str | None, parse_none_str: str | None = None) -> str | None | EnvNoneType:
|
||||
return value if not (value == parse_none_str and parse_none_str is not None) else EnvNoneType(value)
|
||||
|
||||
|
||||
def parse_env_vars(
|
||||
env_vars: Mapping[str, str | None],
|
||||
case_sensitive: bool = False,
|
||||
ignore_empty: bool = False,
|
||||
parse_none_str: str | None = None,
|
||||
) -> Mapping[str, str | None]:
|
||||
return {
|
||||
_get_env_var_key(k, case_sensitive): _parse_env_none_str(v, parse_none_str)
|
||||
for k, v in env_vars.items()
|
||||
if not (ignore_empty and v == '')
|
||||
}
|
||||
|
||||
|
||||
def _annotation_is_complex(annotation: Any, metadata: list[Any]) -> bool:
|
||||
# If the model is a root model, the root annotation should be used to
|
||||
# evaluate the complexity.
|
||||
if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)):
|
||||
annotation = annotation.__value__
|
||||
if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel:
|
||||
annotation = cast('type[RootModel[Any]]', annotation)
|
||||
root_annotation = annotation.model_fields['root'].annotation
|
||||
if root_annotation is not None: # pragma: no branch
|
||||
annotation = root_annotation
|
||||
|
||||
if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
|
||||
return False
|
||||
|
||||
origin = get_origin(annotation)
|
||||
|
||||
# Check if annotation is of the form Annotated[type, metadata].
|
||||
if typing_objects.is_annotated(origin):
|
||||
# Return result of recursive call on inner type.
|
||||
inner, *meta = get_args(annotation)
|
||||
return _annotation_is_complex(inner, meta)
|
||||
|
||||
if origin is Secret:
|
||||
return False
|
||||
|
||||
return (
|
||||
_annotation_is_complex_inner(annotation)
|
||||
or _annotation_is_complex_inner(origin)
|
||||
or hasattr(origin, '__pydantic_core_schema__')
|
||||
or hasattr(origin, '__get_pydantic_core_schema__')
|
||||
)
|
||||
|
||||
|
||||
def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
|
||||
if _lenient_issubclass(annotation, (str, bytes)):
|
||||
return False
|
||||
|
||||
return _lenient_issubclass(
|
||||
annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)
|
||||
) or is_dataclass(annotation)
|
||||
|
||||
|
||||
def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
|
||||
"""Check if a union type contains any complex types."""
|
||||
return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))
|
||||
|
||||
|
||||
def _annotation_contains_types(
|
||||
annotation: type[Any] | None,
|
||||
types: tuple[Any, ...],
|
||||
is_include_origin: bool = True,
|
||||
is_strip_annotated: bool = False,
|
||||
is_instance: bool = False,
|
||||
) -> bool:
|
||||
"""Check if a type annotation contains any of the specified types."""
|
||||
if is_strip_annotated:
|
||||
annotation = _strip_annotated(annotation)
|
||||
if is_include_origin is True:
|
||||
origin = get_origin(annotation)
|
||||
if origin in types:
|
||||
return True
|
||||
if is_instance and any(isinstance(origin, type_) for type_ in types):
|
||||
return True
|
||||
for type_ in get_args(annotation):
|
||||
if _annotation_contains_types(
|
||||
type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated, is_instance=is_instance
|
||||
):
|
||||
return True
|
||||
if is_instance and any(isinstance(annotation, type_) for type_ in types):
|
||||
return True
|
||||
return annotation in types
|
||||
|
||||
|
||||
def _strip_annotated(annotation: Any) -> Any:
|
||||
if typing_objects.is_annotated(get_origin(annotation)):
|
||||
return annotation.__origin__
|
||||
else:
|
||||
return annotation
|
||||
|
||||
|
||||
def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> str | None:
|
||||
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
|
||||
if _lenient_issubclass(type_, Enum):
|
||||
if value in tuple(val.value for val in type_):
|
||||
return type_(value).name
|
||||
return None
|
||||
|
||||
|
||||
def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any:
|
||||
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
|
||||
if _lenient_issubclass(type_, Enum):
|
||||
if name in tuple(val.name for val in type_):
|
||||
return type_[name]
|
||||
return None
|
||||
|
||||
|
||||
def _get_model_fields(model_cls: type[Any]) -> dict[str, Any]:
|
||||
"""Get fields from a pydantic model or dataclass."""
|
||||
|
||||
if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'):
|
||||
return model_cls.__pydantic_fields__
|
||||
if is_model_class(model_cls):
|
||||
return model_cls.model_fields
|
||||
raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')
|
||||
|
||||
|
||||
def _get_alias_names(
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
alias_path_args: dict[str, int | None] | None = None,
|
||||
case_sensitive: bool = True,
|
||||
) -> tuple[tuple[str, ...], bool]:
|
||||
"""Get alias names for a field, handling alias paths and case sensitivity."""
|
||||
from pydantic import AliasChoices, AliasPath
|
||||
|
||||
alias_names: list[str] = []
|
||||
is_alias_path_only: bool = True
|
||||
if not any((field_info.alias, field_info.validation_alias)):
|
||||
alias_names += [field_name]
|
||||
is_alias_path_only = False
|
||||
else:
|
||||
new_alias_paths: list[AliasPath] = []
|
||||
for alias in (field_info.alias, field_info.validation_alias):
|
||||
if alias is None:
|
||||
continue
|
||||
elif isinstance(alias, str):
|
||||
alias_names.append(alias)
|
||||
is_alias_path_only = False
|
||||
elif isinstance(alias, AliasChoices):
|
||||
for name in alias.choices:
|
||||
if isinstance(name, str):
|
||||
alias_names.append(name)
|
||||
is_alias_path_only = False
|
||||
else:
|
||||
new_alias_paths.append(name)
|
||||
else:
|
||||
new_alias_paths.append(alias)
|
||||
for alias_path in new_alias_paths:
|
||||
name = cast(str, alias_path.path[0])
|
||||
name = name.lower() if not case_sensitive else name
|
||||
if alias_path_args is not None:
|
||||
alias_path_args[name] = (
|
||||
alias_path.path[1] if len(alias_path.path) > 1 and isinstance(alias_path.path[1], int) else None
|
||||
)
|
||||
if not alias_names and is_alias_path_only:
|
||||
alias_names.append(name)
|
||||
if not case_sensitive:
|
||||
alias_names = [alias_name.lower() for alias_name in alias_names]
|
||||
return tuple(dict.fromkeys(alias_names)), is_alias_path_only
|
||||
|
||||
|
||||
def _is_function(obj: Any) -> bool:
|
||||
"""Check if an object is a function."""
|
||||
from types import BuiltinFunctionType, FunctionType
|
||||
|
||||
return isinstance(obj, (FunctionType, BuiltinFunctionType))
|
||||
|
||||
|
||||
__all__ = [
|
||||
'_annotation_contains_types',
|
||||
'_annotation_enum_name_to_val',
|
||||
'_annotation_enum_val_to_name',
|
||||
'_annotation_is_complex',
|
||||
'_annotation_is_complex_inner',
|
||||
'_get_alias_names',
|
||||
'_get_env_var_key',
|
||||
'_get_model_fields',
|
||||
'_is_function',
|
||||
'_parse_env_none_str',
|
||||
'_strip_annotated',
|
||||
'_union_is_complex',
|
||||
'parse_env_vars',
|
||||
]
|
||||
42
venv/lib/python3.11/site-packages/pydantic_settings/utils.py
Normal file
42
venv/lib/python3.11/site-packages/pydantic_settings/utils.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import types
|
||||
from pathlib import Path
|
||||
from typing import Any, _GenericAlias, get_origin # type: ignore [attr-defined]
|
||||
|
||||
_PATH_TYPE_LABELS = {
|
||||
Path.is_dir: 'directory',
|
||||
Path.is_file: 'file',
|
||||
Path.is_mount: 'mount point',
|
||||
Path.is_symlink: 'symlink',
|
||||
Path.is_block_device: 'block device',
|
||||
Path.is_char_device: 'char device',
|
||||
Path.is_fifo: 'FIFO',
|
||||
Path.is_socket: 'socket',
|
||||
}
|
||||
|
||||
|
||||
def path_type_label(p: Path) -> str:
|
||||
"""
|
||||
Find out what sort of thing a path is.
|
||||
"""
|
||||
assert p.exists(), 'path does not exist'
|
||||
for method, name in _PATH_TYPE_LABELS.items():
|
||||
if method(p):
|
||||
return name
|
||||
|
||||
return 'unknown' # pragma: no cover
|
||||
|
||||
|
||||
# TODO remove and replace usage by `isinstance(cls, type) and issubclass(cls, class_or_tuple)`
|
||||
# once we drop support for Python 3.10.
|
||||
def _lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
|
||||
except TypeError:
|
||||
if get_origin(cls) is not None:
|
||||
# Up until Python 3.10, isinstance(<generic_alias>, type) is True
|
||||
# (e.g. list[int])
|
||||
return False
|
||||
raise
|
||||
|
||||
|
||||
_WithArgsTypes = (_GenericAlias, types.GenericAlias, types.UnionType)
|
||||
|
|
@ -0,0 +1 @@
|
|||
VERSION = '2.12.0'
|
||||
Loading…
Add table
Add a link
Reference in a new issue