# Source code for sphinxcontrib.autodoc_pydantic.inspection

"""This module contains the inspection functionality for pydantic models. It
is used to retrieve relevant information about fields, validators, config and
schema of pydantic models.

"""

from __future__ import annotations

import inspect
import itertools
import pydoc
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar

try:
    from typing import TypeGuard
except ImportError:
    from typing_extensions import TypeGuard

from pydantic import BaseModel, ConfigDict, PydanticInvalidForJsonSchema, create_model
from pydantic_settings import BaseSettings

ASTERISK_FIELD_NAME = 'all fields'

if TYPE_CHECKING:
    from pydantic.fields import FieldInfo
    from sphinx.addnodes import desc_signature


class ValidatorAdapter(NamedTuple):
    """Provide standardized interface to pydantic's validator objects with
    additional metadata (e.g. root validator) for internal usage in
    autodoc_pydantic.

    """

    func: Callable

    @property
    def name(self) -> str:
        """Return the validators function name."""
        return self.func.__name__

    @property
    def class_name(self) -> str | None:
        """Return the validators class name. It might be None if validator
        is not bound to a class (i.e. its qualified name has no dotted
        prefix).

        """
        qualname = self.func.__qualname__.split('.')
        if len(qualname) > 1:
            return qualname[-2]
        return None

    @property
    def module(self) -> str:
        """Return the validators module name."""
        return self.func.__module__

    @property
    def object_path(self) -> str:
        """Return the fully qualified object path of the validators function."""
        return f'{self.func.__module__}.{self.func.__qualname__}'

    def __hash__(self) -> int:
        """Hash the adapter via its textual representation.

        The repr embeds ``repr(func)``, so adapters that compare equal (tuple
        equality on ``func``) hash equally for the plain functions/methods
        stored here. Hashing ``id()`` of a freshly built string - as done
        before - is not stable between calls and breaks set deduplication.
        """
        return hash(f'{self}')
class ValidatorFieldMap(NamedTuple):
    """Single link between a pydantic validator and one field it applies to."""

    #: Name of the field.
    field_name: str
    #: Name of the validator.
    validator_name: str
    #: Reference to the field.
    field_ref: str
    #: Reference to the validator.
    validator_ref: str
class BaseInspectionComposite:
    """Common base for the inspector composites attached to a
    `ModelInspector`.

    Every composite covers one aspect of a pydantic model (for example its
    fields or its validators), giving that aspect its own namespace.

    """

    def __init__(self, parent: ModelInspector) -> None:
        self._parent: ModelInspector = parent
        # shortcut to the inspected pydantic model class of the parent
        self.model = parent.model
class FieldInspector(BaseInspectionComposite):
    """Provide namespace for inspection methods for fields of pydantic models."""

    def __init__(self, parent: ModelInspector) -> None:
        super().__init__(parent)
        # json schema can reliably be created only at model level
        self.attribute = self.model.model_fields

    @property
    def names(self) -> list[str]:
        """Return field names while keeping ordering."""
        return list(self.attribute.keys())

    def get(self, name: str) -> FieldInfo:
        """Get the instance of `FieldInfo` for given field `name`.

        Raises:
            KeyError: If `name` is not a field of the model.
        """
        return self.attribute[name]

    def get_alias_or_name(self, field_name: str) -> str:
        """Get the alias of a pydantic field if given. Otherwise, return the
        field name. The special `ASTERISK_FIELD_NAME` is returned unchanged.

        """
        if field_name == ASTERISK_FIELD_NAME:
            return field_name

        alias = self.get(field_name).alias
        if alias is not None:
            return alias

        return field_name

    def get_property_from_field_info(self, field_name: str, property_name: str) -> Any:  # noqa: ANN401
        """Get specific property value from pydantic's field info. Returns
        None if the `FieldInfo` has no such attribute.

        """
        field = self.get(field_name)
        return getattr(field, property_name, None)

    @staticmethod
    def _get_meta_items(meta_class: Any) -> dict[str, Any]:  # noqa: ANN401
        """Helper method to extract constraint names from different pydantic
        Metadata objects such as `pydantic.types.Strict`.

        Dataclass-based metadata exposes its attribute names through
        `__dataclass_fields__`; other objects fall back to their `__dict__`.
        Objects providing neither (e.g. arbitrary values such as strings that
        end up in `FieldInfo.metadata` via `Annotated`) yield an empty mapping
        instead of raising `AttributeError`.

        """
        try:
            return meta_class.__dataclass_fields__
        except AttributeError:
            return getattr(meta_class, '__dict__', {})

    def get_constraints(self, field_name: str) -> dict[str, Any]:
        """Get constraints for given `field_name`.

        Constraints are collected from the field's `metadata` entries. Only
        attributes whose value is not None are included; if several metadata
        objects define the same attribute, the later one wins.

        """
        constraints: dict[str, Any] = {}
        for meta in self.get(field_name).metadata:
            if meta is None:
                continue
            for key in self._get_meta_items(meta):
                # default guards against names listed without a readable value
                value = getattr(meta, key, None)
                if value is not None:
                    constraints[key] = value
        return constraints

    def is_required(self, field_name: str) -> bool:
        """Check if a given pydantic field is required/mandatory. Returns True,
        if a value for this field needs to provided upon model creation.

        """
        return self.get(field_name).is_required()

    def has_default_factory(self, field_name: str) -> bool:
        """Check if field has a `default_factory` being set. This information
        is used to determine if a pydantic field is optional or not.

        """
        return self.get(field_name).default_factory is not None

    def is_json_serializable(self, field_name: str) -> bool:
        """Check if given pydantic field is JSON serializable by building a
        throwaway model containing only this field and calling its
        `model_json_schema()` method. Custom objects might not be serializable
        and hence would break JSON schema generation.

        """
        field = self.get(field_name)
        return self._is_json_serializable(field)

    @classmethod
    def _is_json_serializable(cls, field: FieldInfo) -> bool:
        """Ensure JSON serializability for given pydantic `FieldInfo`."""
        # hide user warnings in sphinx output
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            return cls._test_field_serializabiltiy(field)

    @staticmethod
    def _test_field_serializabiltiy(field: FieldInfo) -> bool:
        """Test JSON serializability for given pydantic `FieldInfo`.

        Any exception raised while creating the test model or its JSON schema
        is interpreted as "not serializable".

        """
        model_config = ConfigDict(arbitrary_types_allowed=True)

        try:
            field_args = (field.annotation, field.default)
            model = create_model('_', __config__=model_config, test_field=field_args)
            model.model_json_schema()
        except Exception:  # noqa: BLE001
            return False
        else:
            return True

    @property
    def non_json_serializable(self) -> list[str]:
        """Get all fields that can't be safely JSON serialized."""
        return [name for name in self.names if not self.is_json_serializable(name)]

    def __bool__(self) -> bool:
        """Equals to False if no fields are present."""
        return bool(self.attribute)
class ValidatorInspector(BaseInspectionComposite):
    """Namespace bundling the inspection helpers that deal with the
    validators of a pydantic model.

    """

    @property
    def values(self) -> set[ValidatorAdapter]:
        """Return every validator adapter registered on the parent inspector,
        merged across all fields into one set.

        """
        per_field = self._parent.field_validator_mappings.values()
        return {adapter for adapters in per_field for adapter in adapters}

    def get_reused_validators_names(self) -> list[str]:
        """Return the attribute names of field validators that are reused.

        Detection is implicit: reused validators are registered as plain
        functions, whereas regular ones are bound methods.

        """
        registered = self.model.__pydantic_decorators__.field_validators
        reused: list[str] = []
        for decorator in registered.values():
            if inspect.isfunction(decorator.func):
                reused.append(decorator.cls_var_name)
        return reused

    @property
    def names(self) -> set[str]:
        """Return the function names of all validators of the pydantic model."""
        return {adapter.name for adapter in self.values}

    def __bool__(self) -> bool:
        """Evaluate to False when the model has no validators at all."""
        return len(self.values) > 0
class ConfigInspector(BaseInspectionComposite):
    """Namespace bundling the inspection helpers that deal with the
    configuration (`model_config`) of a pydantic model.

    """

    def __init__(self, parent: ModelInspector) -> None:
        super().__init__(parent)
        self.items = self._get_values_per_type()

    def _get_values_per_type(self) -> dict[str, Any]:
        """Get the configuration values from any pydantic model.

        Behavior of configuration values varies between `BaseModel` and
        `BaseSettings`. For `BaseModel`, if no configs are provided, then
        `model_config` is empty. However, for `BaseSettings`, `model_config`
        contains a predefined set of values, which would otherwise always show
        up as a lot of irrelevant defaults. Hence, entries equal to the
        `BaseSettings` defaults are removed.

        """
        config = self.model.model_config
        if not issubclass(self.model, BaseSettings):
            return dict(config)

        # keep only (key, value) pairs that differ from BaseSettings' defaults
        defaults = tuple(BaseSettings.model_config.items())
        return {
            key: value
            for key, value in config.items()
            if (key, value) not in defaults
        }

    @property
    def is_configured(self) -> bool:
        """Tell whether the pydantic model config was explicitly configured."""
        return bool(self.items)
class ReferenceInspector(BaseInspectionComposite):
    """Provide namespace for inspection methods for creating references
    mainly between pydantic fields and validators.

    Importantly, `mappings` provides the set of all `ValidatorFieldMap`
    instances which contain all references between fields and validators.

    """

    def __init__(self, parent: ModelInspector) -> None:
        super().__init__(parent)
        self.mappings = self._create_mappings()

    @property
    def model_path(self) -> str:
        """Retrieve the full path of the model."""
        return f'{self.model.__module__}.{self.model.__name__}'

    def create_model_reference(self, name: str) -> str:
        """Create reference for given attribute `name` returning full path
        including the model path.

        """
        return f'{self.model_path}.{name}'

    def _create_mappings(self) -> set[ValidatorFieldMap]:
        """Generate reference mappings between validators and corresponding
        fields.

        Validators keyed by '*' are not tied to a single field; they are
        mapped to the human readable `ASTERISK_FIELD_NAME` instead.

        """
        mappings: set[ValidatorFieldMap] = set()
        for field, validators in self._parent.field_validator_mappings.items():
            field_name = ASTERISK_FIELD_NAME if field == '*' else field
            field_ref = self.create_model_reference(field_name)
            for validator in validators:
                mappings.add(
                    ValidatorFieldMap(
                        field_name=field_name,
                        field_ref=field_ref,
                        validator_name=validator.name,
                        validator_ref=validator.object_path,
                    ),
                )

        return mappings

    def filter_by_validator_name(self, name: str) -> list[ValidatorFieldMap]:
        """Return mappings for given validator `name`."""
        return [mapping for mapping in self.mappings if mapping.validator_name == name]

    def filter_by_field_name(self, name: str) -> list[ValidatorFieldMap]:
        """Return mappings for given field `name`, including mappings of
        validators that apply to all fields (`ASTERISK_FIELD_NAME`).

        """
        return [
            mapping
            for mapping in self.mappings
            if mapping.field_name in (name, ASTERISK_FIELD_NAME)
        ]
class SchemaInspector(BaseInspectionComposite):
    """Provide namespace for inspection methods for general properties of
    pydantic models.

    """

    @property
    def sanitized(self) -> dict:
        """Get model's `schema` while handling non serializable fields. Such
        fields will be replaced by TypeVars.

        Warnings raised during schema generation are suppressed for both the
        original model and the sanitized fallback model so that they do not
        leak into the sphinx build output. The returned dict lists `title`,
        `description`, `type` and `properties` first (when present), followed
        by the remaining keys in their original order.

        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            try:
                schema = self.model.model_json_schema()
            except (TypeError, ValueError, PydanticInvalidForJsonSchema):
                new_model = self.create_sanitized_model()
                schema = new_model.model_json_schema()

        keys_order = ['title', 'description', 'type', 'properties']
        reordered_schema = {k: schema[k] for k in keys_order if k in schema}
        # update() keeps the positions of keys already inserted above and
        # appends all remaining keys in schema order
        reordered_schema.update(schema)
        return reordered_schema

    def create_sanitized_model(self) -> type[BaseModel]:
        """Generate a new pydantic model from the original one while
        substituting invalid fields with typevars.

        The new model subclasses the original model; every field that is not
        JSON serializable is overridden by a field typed with a `TypeVar` of
        the same name and a default of None.

        """
        invalid_fields = self._parent.fields.non_json_serializable
        new: dict[str, Any] = {name: (TypeVar(name), None) for name in invalid_fields}
        return create_model(self.model.__name__, __base__=self.model, **new)
class StaticInspector:
    """Namespace under `ModelInspector` for static methods."""

    @staticmethod
    def is_pydantic_model(obj: Any) -> TypeGuard[type[BaseModel]]:  # noqa: ANN401
        """Tell whether `obj` is a class deriving from pydantic's `BaseModel`.

        Objects that are not classes make `issubclass` raise `TypeError`;
        they are reported as not being a pydantic model.

        """
        try:
            outcome = issubclass(obj, BaseModel)
        except TypeError:
            outcome = False
        return outcome

    @classmethod
    def is_pydantic_field(cls, parent: Any, field_name: str) -> bool:  # noqa: ANN401
        """Tell whether `field_name` is a declared field of the pydantic model
        `parent`. Non pydantic `parent` objects always yield False.

        """
        return cls.is_pydantic_model(parent) and field_name in parent.model_fields

    @classmethod
    def is_validator_by_name(cls, name: str, obj: Any) -> bool:  # noqa: ANN401
        """Tell whether a validator called `name` exists on the pydantic model
        `obj`. Non pydantic `obj` values always yield False.

        """
        if not cls.is_pydantic_model(obj):
            return False
        return name in ModelInspector(obj).validators.names
class ModelInspector:
    """Provides inspection functionality for pydantic models."""

    static = StaticInspector

    def __init__(self, model: type[BaseModel]) -> None:
        self.model = model
        # computed first: the composites below read it through their parent,
        # ReferenceInspector even during its own construction
        self.field_validator_mappings = self.get_field_validator_mapping()
        self.config = ConfigInspector(self)
        self.schema = SchemaInspector(self)
        self.fields = FieldInspector(self)
        self.validators = ValidatorInspector(self)
        self.references = ReferenceInspector(self)

    def get_field_validator_mapping(self) -> dict[str, list[ValidatorAdapter]]:
        """Collect all available validators keyed by their corresponding fields
        including model validators, which are collected under the '*' key.

        Validators are wrapped into `ValidatorAdapters` to provide uniform
        interface within autodoc_pydantic.

        """
        mapping: defaultdict[str, list[ValidatorAdapter]] = defaultdict(list)
        decorators = self.model.__pydantic_decorators__

        # field validators: one adapter per field the validator targets
        for field_validator in decorators.field_validators.values():
            for field in field_validator.info.fields:
                mapping[field].append(ValidatorAdapter(func=field_validator.func))

        # model validators are not tied to a specific field
        for model_validator in decorators.model_validators.values():
            mapping['*'].append(ValidatorAdapter(func=model_validator.func))

        return mapping

    @classmethod
    def from_child_signode(cls, signode: desc_signature) -> ModelInspector:
        """Create instance from a child `signode` as used within sphinx
        directives.

        The last dotted component of the signode's `fullname` (the child
        member) is stripped; the remainder is resolved relative to the
        signode's `module`.

        Raises:
            ValueError: If the resolved object is not a pydantic model.

        """
        model_path = signode['fullname'].rpartition('.')[0]
        model = pydoc.locate(f"{signode['module']}.{model_path}")

        if not cls.static.is_pydantic_model(model):
            err = (
                f"Signode with full name {signode['fullname']} and extracted "
                f"model path {model_path} does not reference a pydantic model."
            )
            raise ValueError(err)

        return cls(model)