"""Serialization base models for pyobs (module ``pyobs.utils.serialization``)."""

from __future__ import annotations

import datetime
from abc import ABCMeta
from typing import Any, TypeVar, Self

from astroplan import Observer
from astropy.coordinates import EarthLocation
from pydantic import BaseModel as PydanticBaseModel, model_serializer, model_validator, ConfigDict, PrivateAttr
from pydantic_core.core_schema import SerializerFunctionWrapHandler, ValidatorFunctionWrapHandler, ValidationInfo

from pyobs.comm import Comm
from pyobs.object import PrivateAttrMixin
from pyobs.vfs import VirtualFileSystem

# Type variable representing the class of an object.
ObjectClass = TypeVar("ObjectClass")


class BaseModel(PydanticBaseModel, PrivateAttrMixin):
    """Pydantic base model for pyobs classes that need to be serialized.

    On top of the regular pydantic fields, every instance carries private runtime
    attributes (comm, observer, vfs, timezone, location). All default to None. They are
    filled from the validation context, i.e. the ``context`` argument given to
    ``model_validate`` and its siblings.
    """

    # Private runtime attributes, populated from the validation context after validation.
    _timezone: datetime.tzinfo | None = PrivateAttr(default=None)
    _location: EarthLocation | None = PrivateAttr(default=None)
    _vfs: VirtualFileSystem | None = PrivateAttr(default=None)
    _observer: Observer | None = PrivateAttr(default=None)
    _comm: Comm | None = PrivateAttr(default=None)

    # Allow field types for which pydantic cannot build a schema (arbitrary classes).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @model_validator(mode="after")
    def _inject_context_into_children(self, info: ValidationInfo) -> Self:
        """Copy runtime objects from the validation context onto this instance.

        Without a context, the instance is returned unchanged. With one, each private
        attribute is set from the matching context key. A missing key sets the attribute
        to None (assumes the context is a dict-like mapping).

        Args:
            info: Pydantic validation info carrying the optional context.

        Returns:
            This instance.
        """
        ctx = info.context
        if ctx is None:
            return self

        self._comm = ctx.get("comm")
        self._observer = ctx.get("observer")
        self._vfs = ctx.get("vfs")
        self._timezone = ctx.get("timezone")
        self._location = ctx.get("location")
        return self
class PolymorphicBaseModel(BaseModel, metaclass=ABCMeta):
    """Pydantic base model for pyobs sub classes that need to be serialized.

    Serialized instances carry an extra ``"class"`` entry with the fully qualified name
    (``"<module>.<ClassName>"``) of their concrete class. During validation, a dict that
    contains such an entry is validated with the named class instead. That class must
    be a subclass of the class being validated.
    """

    @model_serializer(mode="wrap")
    def inject_class_on_serialization(self, handler: SerializerFunctionWrapHandler) -> dict[str, Any]:
        """Serialize the fields of the concrete runtime class and add a ``"class"`` entry.

        Args:
            handler: Pydantic's default serializer for the schema this model was reached
                through. It is deliberately not called (see comment below).

        Returns:
            A dict mapping each field name declared on ``type(self)`` to its current
            attribute value, plus ``"class"`` holding ``"<module>.<ClassName>"``.
            Keys are attribute names, not aliases.
        """
        # Collect fields from the concrete runtime type to avoid Pydantic v2
        # resolving field schemas against the abstract base type when nested in a parent model.
        concrete = type(self)
        result: dict[str, Any] = {name: getattr(self, name) for name in concrete.model_fields}
        result["class"] = f"{concrete.__module__}.{concrete.__name__}"
        return result

    @model_validator(mode="wrap")
    @classmethod
    def retrieve_class_on_deserialization(
        cls, value: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
    ) -> Any:
        """Get the correct class for this model and run model_validate on that class with the current context.

        Args:
            value: Raw input. Only dicts are inspected for a ``"class"`` entry.
            handler: Pydantic's default validator for ``cls``.
            info: Validation info. Its context is forwarded to the concrete class.

        Returns:
            An instance of the class named by ``value["class"]`` if that entry is present
            and not None. Otherwise, whatever the default validation of ``value`` returns.

        Raises:
            ValueError: If ``"class"`` names something that is not a subclass of ``cls``
                (pydantic reports this as a ValidationError).
        """
        if isinstance(value, dict):
            from pyobs.object import get_class_from_string

            data = value.copy()
            class_name = data.pop("class", None)
            if class_name is not None:
                klass = get_class_from_string(class_name)
                # Reject classes that do not fit where ``cls`` is expected. Otherwise a
                # mismatching "class" entry would silently produce an object of the wrong
                # type, and union members would accept data meant for other members.
                if not (isinstance(klass, type) and issubclass(klass, cls)):
                    raise ValueError(f"Class {class_name!r} is not a subclass of {cls.__name__}.")
                return klass.model_validate(data, context=info.context)
        return handler(value)
# Public API of this module.
__all__ = [
    "BaseModel",
    "PolymorphicBaseModel",
]