# Source code for pyobs.comm.comm

from __future__ import annotations

import asyncio
import functools
import inspect
import logging
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, overload

import pyobs.interfaces
from pyobs.events import Event, LogEvent, ModuleClosedEvent
from pyobs.interfaces import Interface
from pyobs.utils.enums import ModuleState

from .commlogging import CommLoggingHandler
from .proxy import Proxy, ProxyType, _ProxyContext

if TYPE_CHECKING:
    from pyobs.modules import Module

log = logging.getLogger(__name__)

# Callback signature for state subscriptions: called with the published state object.
StateCallback = Callable[[Any], None]
# Callback signature for presence subscriptions: called with (module lifecycle state, error string).
PresenceCallback = Callable[["ModuleState", str], None]


class Comm:
    """Base class for all Comm modules in pyobs.

    Implements the backend-independent parts of a comm: proxy caching, local event dispatch, forwarding of
    log records, and bookkeeping of state subscriptions. Concrete backends override the ``_``-prefixed hooks
    and the members that raise ``NotImplementedError``.
    """

    __module__ = "pyobs.comm"

    def __init__(self) -> None:
        """Creates a comm module."""

        self._proxies: dict[str, Proxy] = {}
        self._state_subscriptions: dict[str, list[tuple[type[Interface], StateCallback]]] = {}
        self._module: Module | None = None
        self._log_queue: asyncio.Queue[LogEvent] = asyncio.Queue()
        self._logging_task: asyncio.Task[Any] | None = None
        self._event_handlers: dict[type[Event], list[Callable[[Event, str], Coroutine[Any, Any, bool]]]] = {}
        self._closing = asyncio.Event()
        # root-logger handler installed by open() (None if open() did not install one), removed again by close()
        self._log_handler: logging.Handler | None = None
        # strong references to running event-handler tasks; the event loop itself only keeps weak references
        self._background_tasks: set[asyncio.Task[Any]] = set()

    @property
    def has_module(self) -> bool:
        """Whether a module has been attached to this comm."""
        return self._module is not None

    @property
    def module(self) -> Module:
        """The module that this Comm object is attached to."""
        if self._module is None:
            raise ValueError("No module.")
        return self._module

    @module.setter
    def module(self, module: Module) -> None:
        """The module that this Comm object is attached to."""
        # give backends a chance to react before the module is stored
        self._set_module(module)

        # store module
        self._module = module

    def _set_module(self, module: Module) -> None: ...

    async def open(self) -> None:
        """Open module.

        Installs a CommLoggingHandler on the root logger (unless one is already installed), starts the task that
        forwards queued log events, and registers for ModuleClosedEvent so that proxies of disconnected clients
        are evicted.
        """

        # add handler to global logger, but only if no comm has installed one yet
        root_logger = logging.getLogger()
        if not any(isinstance(h, CommLoggingHandler) for h in root_logger.handlers):
            from pyobs.utils.logging.context import ModuleNameFilter

            handler = CommLoggingHandler(self)
            handler.setLevel(logging.INFO)
            handler.addFilter(ModuleNameFilter())
            root_logger.addHandler(handler)
            # remember it, so that close() can detach it again
            self._log_handler = handler

        # start logging task
        self._logging_task = asyncio.create_task(self._logging())

        # some events
        await self.register_event(ModuleClosedEvent, self._client_disconnected)

    async def close(self) -> None:
        """Close module.

        Sets the closing flag, stops the log-forwarding task and removes the root-logger handler installed by
        open(), so a closed comm no longer collects log events it will never send.
        """
        self._closing.set()

        # stop logging task
        if self._logging_task:
            self._logging_task.cancel()
            self._logging_task = None

        # detach our logging handler, so that a later comm can install its own
        if self._log_handler is not None:
            logging.getLogger().removeHandler(self._log_handler)
            self._log_handler = None

    def _get_full_client_name(self, name: str) -> str:
        """Returns full name for given client.

        Some Comm modules may use short names for their clients. This method returns the full name for a given
        short name.

        Args:
            name: Short name to get full name for.

        Returns:
            Full name for given client.
        """
        # this base class doesn't have short names
        return name

    async def _get_client(self, client: str) -> Module | Proxy | None:
        """Get a proxy to the given client.

        Proxies are cached per client name until the client disconnects.

        Args:
            client: Name of client.

        Returns:
            The attached module for "main", a Proxy for any other client, or None if the client is unknown.
        """

        # return module, if "main" is requested
        if client == "main":
            return self.module
        if client is None:
            return None

        # if client isn't cached yet, build a new proxy
        if client not in self._proxies:
            # get interfaces
            try:
                interfaces = await self.get_interfaces(client)
            except IndexError:
                return None

            # collect capabilities (fixed at proxy construction time)
            capabilities: dict[type[Interface], Any] = {}
            for interface in interfaces:
                if interface.capabilities is not None:
                    cap = await self._get_capabilities(client, interface)
                    if cap is not None:
                        capabilities[interface] = cap

            # create new proxy
            proxy = Proxy(self, client, interfaces, capabilities)

            # subscribe to state, so the proxy's cached state stays current
            for interface in interfaces:
                if interface.state is not None:
                    await self.subscribe_state(client, interface, functools.partial(proxy.update_state, interface))

            self._proxies[client] = proxy

        # return proxy
        return self._proxies[client]

    async def _resolve_proxy(
        self, name_or_object: str | object, obj_type: type[ProxyType] | None = None
    ) -> Any | ProxyType:
        """Returns object directly if it is of given type. Otherwise get proxy of client with given name and
        check type.

        If name_or_object is an object:
            - If it is of type (or derived), return object.
            - Otherwise raise exception.
        If name_or_object is a string:
            - Create proxy from name and raise exception, if it doesn't exist.
            - Check type and raise exception if wrong.
            - Return object.

        Args:
            name_or_object: Name of object or object itself.
            obj_type: Expected class of object.

        Returns:
            Object or proxy to object.

        Raises:
            ValueError: If proxy does not exist or wrong type.
        """

        if obj_type is not None and isinstance(name_or_object, obj_type):
            # return directly
            return name_or_object

        if not isinstance(name_or_object, str):
            # completely wrong...
            raise ValueError(f'Given parameter is neither a name nor an object of requested type "{obj_type}".')

        # get proxy
        try:
            proxy = await self._get_client(name_or_object)
        except KeyError as e:
            raise ValueError(f"Could not get proxy for {name_or_object}.") from e

        # check it
        if proxy is None:
            raise ValueError(f'Could not create proxy for given name "{name_or_object}".')
        if obj_type is None or isinstance(proxy, obj_type):
            return proxy

        message = f'Proxy obtained from given name "{name_or_object}" is not of requested type "{obj_type}".'
        hint = self._diagnose_missing_interface(name_or_object, obj_type)
        if hint is not None:
            message += f" {hint}"
        raise ValueError(message)

    def _diagnose_missing_interface(self, client: str, obj_type: type[Any]) -> str | None:
        """Backend hook, called when a proxy exists but doesn't implement obj_type.

        Backends that can tell "genuinely missing" apart from "present at an incompatible version" using data
        already in hand return a hint string to append to the ValueError; other backends just return None.
        """
        return None

    async def _safe_resolve_proxy(
        self, name_or_object: str | object, obj_type: type[ProxyType] | None = None
    ) -> Any | ProxyType | None:
        """Calls _resolve_proxy() and returns None instead of raising a ValueError."""
        try:
            return await self._resolve_proxy(name_or_object, obj_type)
        except ValueError:
            return None

    @overload
    def proxy(self, name_or_object: str | object, obj_type: type[ProxyType]) -> _ProxyContext[ProxyType]: ...

    @overload
    def proxy(self, name_or_object: str | object, obj_type: None = None) -> _ProxyContext[Any]: ...

    def proxy(self, name_or_object: str | object, obj_type: type[ProxyType] | None = None) -> _ProxyContext[Any]:
        """Returns a context manager; use as `async with self.proxy(...) as x:`."""
        return _ProxyContext(self._resolve_proxy(name_or_object, obj_type))

    @overload
    def safe_proxy(
        self, name_or_object: str | object, obj_type: type[ProxyType]
    ) -> _ProxyContext[ProxyType | None]: ...

    @overload
    def safe_proxy(self, name_or_object: str | object, obj_type: None = None) -> _ProxyContext[Any]: ...

    def safe_proxy(self, name_or_object: str | object, obj_type: type[ProxyType] | None = None) -> _ProxyContext[Any]:
        """Same as proxy(), but yields None inside the block instead of raising."""
        return _ProxyContext(self._safe_resolve_proxy(name_or_object, obj_type))

    async def has_proxy(self, name_or_object: str | object, obj_type: type[Any] | None = None) -> bool:
        """True if a proxy of the given type can currently be resolved.

        Doesn't keep a reference to it, so doesn't need async with the way proxy()/safe_proxy() do."""
        return await self._safe_resolve_proxy(name_or_object, obj_type) is not None

    async def _client_disconnected(self, event: Event, sender: str) -> bool:
        """Called when a client disconnects.

        Args:
            event: Disconnect event.
            sender: Name of client that disconnected.
        """

        # if a client disconnects, clear its proxy state then evict it
        if sender in self._proxies:
            self._proxies[sender].clear_state()
            del self._proxies[sender]

        # tear down any state subscriptions held for that client
        for interface, callback in self._state_subscriptions.pop(sender, []):
            await self.unsubscribe_state(sender, interface, callback)
        return True

    @property
    def name(self) -> str | None:
        """Name of this client."""
        raise NotImplementedError

    @property
    def clients(self) -> list[str]:
        """Returns list of currently connected clients.

        Returns:
            (list) list of currently connected clients.
        """
        raise NotImplementedError

    async def clients_with_interface(self, interface: type[Interface]) -> list[str]:
        """Returns list of currently connected clients that implement the given interface.

        Args:
            interface: Interface to search for.

        Returns:
            (list) list of currently connected clients that implement the given interface.
        """
        return [c for c in self.clients if await self._supports_interface(c, interface)]

    async def get_interfaces(self, client: str) -> list[type[Interface]]:
        """Returns list of interfaces for given client.

        Args:
            client: Name of client.

        Returns:
            list of supported interfaces.

        Raises:
            IndexError: If client cannot be found.
        """
        raise NotImplementedError

    async def _supports_interface(self, client: str, interface: type[Interface]) -> bool:
        """Checks, whether the given client supports the given interface.

        Args:
            client: Client to check.
            interface: Interface to check.

        Returns:
            Whether or not interface is supported.
        """
        raise NotImplementedError

    @staticmethod
    def _interface_names_to_classes(interfaces: list[str]) -> list[type[Interface]]:
        """Converts a list of interface names to interface classes.

        Names that do not resolve to an Interface subclass in pyobs.interfaces are logged and skipped.

        Args:
            interfaces: list of interface names.

        Returns:
            list of interface classes.
        """

        # all classes in pyobs.interfaces, by name (getmembers yields each name once)
        available = dict(inspect.getmembers(pyobs.interfaces, predicate=inspect.isclass))

        interface_classes: list[type[Interface]] = []
        for interface_name in interfaces:
            cls = available.get(interface_name)
            # class needs to have the same name and implement Interface
            if cls is not None and issubclass(cls, Interface):
                interface_classes.append(cls)
            else:
                log.error('Could not find interface "%s" for client.', interface_name)
        return interface_classes

    async def execute(self, client: str, method: str, annotation: dict[str, Any], *args: Any) -> Any:
        """Execute a given method on a remote client.

        Args:
            client (str): ID of client.
            method (str): Method to call.
            annotation: Method annotation.
            *args: list of parameters for given method.

        Returns:
            Passes through return from method call.
        """
        raise NotImplementedError

    async def _logging(self) -> None:
        """Background task that forwards queued log events via send_event() until cancelled."""
        try:
            while True:
                # wait for the next entry instead of polling the queue
                entry = await self._log_queue.get()
                try:
                    await self.send_event(entry)
                except Exception:
                    log.exception("Something went wrong")
                    # back off briefly, so a persistently failing backend cannot spin this loop
                    await asyncio.sleep(1)
        except asyncio.CancelledError:
            return

    def log_message(self, entry: LogEvent) -> None:
        """Send a log message to other clients.

        Args:
            entry (LogEvent): Log event to send.
        """
        self._log_queue.put_nowait(entry)

    async def send_event(self, event: Event) -> None:
        """Send an event to other clients.

        Args:
            event (Event): Event to send
        """
        pass

    def _get_derived_events(self, event: type[Event]) -> list[type[Event]]:
        """Return list of given event itself and all events derived from it.

        Args:
            event: Event class to check.

        Returns:
            list of event classes.
        """
        import pyobs.events

        return [cls for _, cls in inspect.getmembers(pyobs.events, inspect.isclass) if issubclass(cls, event)]

    async def register_event(
        self, event_class: type[Event], handler: Callable[[Event, str], Coroutine[Any, Any, bool]] | None = None
    ) -> None:
        """Register an event type. If a handler is given, we also receive those events, otherwise we just
        send them.

        Args:
            event_class: Class of event to register.
            handler: Event handler method.
        """

        # we also want to register all events derived from the given one
        event_classes = self._get_derived_events(event_class)

        # do we have a handler?
        if handler:
            for ev in event_classes:
                handlers = self._event_handlers.setdefault(ev, [])
                # avoid duplicates
                if handler not in handlers:
                    handlers.append(handler)

        # if event is not a local one, the backend needs to register it as well
        if not event_class.local:
            await self._register_events(event_classes, handler)

    async def _register_events(
        self, events: list[type[Event]], handler: Callable[[Event, str], Coroutine[Any, Any, bool]] | None = None
    ) -> None:
        pass

    async def set_state(self, interface: type[Interface], state: Any) -> None:
        """Publish state for this module.

        Args:
            interface: Interface type for the state.
            state: State object to publish.
        """
        await self._set_state(interface, state)

    async def _set_state(self, interface: type[Interface], state: Any) -> None:
        pass

    async def set_capabilities(self, interface: type[Interface], capabilities: Any) -> None:
        """Publish capabilities for this module.

        Called by Module.open() for each interface that defines a capabilities dataclass. Not intended to be
        called directly by module authors after that point — capabilities are fixed for the module lifetime.

        Args:
            interface: Interface type the capabilities belong to.
            capabilities: Capabilities dataclass instance.
        """
        await self._set_capabilities(interface, capabilities)

    async def _set_capabilities(self, interface: type[Interface], capabilities: Any) -> None:
        pass

    async def get_capabilities(self, module: str, interface: type[Interface]) -> Any | None:
        """Fetch and deserialize capabilities for a remote module's interface.

        Args:
            module: Module name (e.g. "camera").
            interface: Interface class whose Capabilities dataclass to fetch.

        Returns:
            Deserialized Capabilities dataclass instance, or None if not published.
        """
        return await self._get_capabilities(module, interface)

    async def _get_capabilities(self, module: str, interface: type[Interface]) -> Any | None:
        return None

    async def set_presence(self, state: ModuleState, error_string: str = "") -> None:
        """Publish presence for this module (module lifecycle state).

        Called automatically by Module.set_state — not intended to be called directly by module authors.

        Args:
            state: Current module lifecycle state.
            error_string: Error message, used when state is ERROR.
        """
        await self._set_presence(state, error_string)

    async def _set_presence(self, state: ModuleState, error_string: str = "") -> None:
        pass

    def get_client_state(self, module: str) -> tuple[ModuleState, str] | None:
        """Return the last known presence state of a connected module.

        Returns (ModuleState, error_string) or None if the module is not connected. This replaces the old
        get_state()/get_error_string() RPC pattern.
        """
        return self._get_client_state(module)

    def _get_client_state(self, module: str) -> tuple[ModuleState, str] | None:
        return None

    def get_own_state(self, interface: type[Interface]) -> Any:
        """Return the last state published by this module for the given interface, or None."""
        return self._get_own_state(interface)

    def _get_own_state(self, interface: type[Interface]) -> Any:
        return None

    async def subscribe_state(
        self,
        module: str,
        interface: type[Interface],
        callback: StateCallback,
    ) -> None:
        """Subscribe to state updates for a given module and interface.

        Delivers the current value immediately on subscribe.

        Args:
            module: Name of remote module.
            interface: Interface type to subscribe to.
            callback: Called with state object on each update.
        """
        # remember the subscription, so it can be torn down when the module disconnects
        self._state_subscriptions.setdefault(module, []).append((interface, callback))
        await self._subscribe_state(module, interface, callback)

    async def _subscribe_state(
        self,
        module: str,
        interface: type[Interface],
        callback: StateCallback,
    ) -> None:
        pass

    async def unsubscribe_state(
        self,
        module: str,
        interface: type[Interface],
        callback: StateCallback,
    ) -> None:
        """Unsubscribe from state updates.

        Args:
            module: Name of remote module.
            interface: Interface type to unsubscribe from.
            callback: Callback that was registered.
        """
        # drop the bookkeeping entry added by subscribe_state(), so it neither accumulates nor gets
        # unsubscribed a second time by _client_disconnected()
        subscriptions = self._state_subscriptions.get(module)
        if subscriptions and (interface, callback) in subscriptions:
            subscriptions.remove((interface, callback))
            if not subscriptions:
                del self._state_subscriptions[module]

        await self._unsubscribe_state(module, interface, callback)

    async def _unsubscribe_state(
        self,
        module: str,
        interface: type[Interface],
        callback: StateCallback,
    ) -> None:
        pass

    async def subscribe_presence(self, module: str, callback: PresenceCallback) -> None:
        """Subscribe to presence updates for a given module.

        Delivers the current value immediately, then on every change.

        Args:
            module: Name of remote module.
            callback: Called with (ModuleState, error_string) on each update.
        """
        await self._subscribe_presence(module, callback)

    async def _subscribe_presence(self, module: str, callback: PresenceCallback) -> None:
        pass

    def _send_event_to_module(self, event: Event, from_client: str) -> None:
        """Dispatch a received event to all handlers registered for its exact class.

        Coroutine results are scheduled as tasks. A strong reference to each task is kept until it finishes, so
        it cannot be garbage-collected mid-execution, and exceptions raised by handlers are logged.

        Args:
            event: Event to send.
            from_client: Client that sent the event.
        """
        for handler in self._event_handlers.get(event.__class__, []):
            # handle it
            ret = handler(event, from_client)
            if asyncio.iscoroutine(ret):
                task = asyncio.create_task(ret)
                self._background_tasks.add(task)
                task.add_done_callback(self._event_handler_done)

    def _event_handler_done(self, task: asyncio.Task[Any]) -> None:
        """Done-callback for event-handler tasks: release our reference and log any exception."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            log.error("Event handler raised an exception.", exc_info=exc)


__all__ = ["Comm"]