from __future__ import annotations
import asyncio
import functools
import inspect
import logging
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING, Any, overload
import pyobs.interfaces
from pyobs.events import Event, LogEvent, ModuleClosedEvent
from pyobs.interfaces import Interface
from pyobs.utils.enums import ModuleState
from .commlogging import CommLoggingHandler
from .proxy import Proxy, ProxyType, _ProxyContext
if TYPE_CHECKING:
    # only needed for annotations here (Module is referenced in annotations only)
    from pyobs.modules import Module

# logger for this module
log: logging.Logger = logging.getLogger(__name__)

# callback receiving a remote module's state object for one interface on each update
StateCallback = Callable[[Any], None]
# callback receiving (ModuleState, error_string) whenever a remote module's presence changes
PresenceCallback = Callable[["ModuleState", str], None]
class Comm:
    """Base class for all Comm modules in pyobs.

    Holds the backend-independent bookkeeping (proxy cache, event handlers, state subscriptions and
    log forwarding) and defines ``_``-prefixed hooks that concrete backends override; most of those
    hooks do nothing in this base class.
    """

    __module__ = "pyobs.comm"

    def __init__(self) -> None:
        """Creates a comm module."""
        # cached proxies to remote clients, keyed by client name
        self._proxies: dict[str, Proxy] = {}
        # state subscriptions per remote module, so they can be torn down when the module disconnects
        self._state_subscriptions: dict[str, list[tuple[type[Interface], StateCallback]]] = {}
        # module this Comm is attached to, see the ``module`` property
        self._module: Module | None = None
        # log entries waiting to be sent by the background logging task
        self._log_queue: asyncio.Queue[LogEvent] = asyncio.Queue()
        self._logging_task: asyncio.Task[Any] | None = None
        # root-logger handler installed by open(), if this instance was the one installing it
        self._logging_handler: logging.Handler | None = None
        # event handlers per (exact) event class
        self._event_handlers: dict[type[Event], list[Callable[[Event, str], Coroutine[Any, Any, bool]]]] = {}
        # strong references to handler tasks started by _send_event_to_module() until they finish
        self._background_tasks: set[asyncio.Task[Any]] = set()
        self._closing = asyncio.Event()

    @property
    def has_module(self) -> bool:
        """Whether a module has been attached to this Comm object."""
        return self._module is not None

    @property
    def module(self) -> Module:
        """The module that this Comm object is attached to.

        Raises:
            ValueError: If no module has been attached yet.
        """
        if self._module is None:
            raise ValueError("No module.")
        return self._module

    @module.setter
    def module(self, module: Module) -> None:
        # let the backend hook see the module before it is stored
        self._set_module(module)
        self._module = module

    def _set_module(self, module: Module) -> None:
        """Backend hook, called by the ``module`` setter before the module is stored. Does nothing here."""

    async def open(self) -> None:
        """Open module.

        Installs a log handler on the root logger (unless a CommLoggingHandler is already there), starts
        the background task that sends queued log entries, and registers for ModuleClosedEvent so that
        proxies and state subscriptions of disconnected clients are cleaned up.
        """
        # add handler to global logger, but never more than one CommLoggingHandler
        root_logger = logging.getLogger()
        if not any(isinstance(h, CommLoggingHandler) for h in root_logger.handlers):
            from pyobs.utils.logging.context import ModuleNameFilter

            handler = CommLoggingHandler(self)
            handler.setLevel(logging.INFO)
            handler.addFilter(ModuleNameFilter())
            root_logger.addHandler(handler)
            # remember it, so that close() can remove it again
            self._logging_handler = handler

        # start logging task, but only once even if open() is called repeatedly
        if self._logging_task is None:
            self._logging_task = asyncio.create_task(self._logging())

        # some events
        await self.register_event(ModuleClosedEvent, self._client_disconnected)

    async def close(self) -> None:
        """Close module.

        Removes the root-logger handler installed by open() (if any) and cancels the logging task.
        """
        self._closing.set()

        # stop feeding log records into our queue, nobody will drain it anymore
        if self._logging_handler is not None:
            logging.getLogger().removeHandler(self._logging_handler)
            self._logging_handler = None

        # stop logging task
        if self._logging_task is not None:
            self._logging_task.cancel()
            self._logging_task = None

    def _get_full_client_name(self, name: str) -> str:
        """Returns full name for given client.

        Some Comm modules may use short names for their clients. This method returns the full name
        for a given short name.

        Args:
            name: Short name to get full name for.

        Returns:
            Full name for given client.
        """
        # this base class doesn't have short names
        return name

    async def _get_client(self, client: str) -> Module | Proxy | None:
        """Get a proxy to the given client.

        Proxies are cached per client name; _client_disconnected() evicts them again.

        Args:
            client: Name of client, or "main" for the module this Comm is attached to.

        Returns:
            The attached module for "main", a (possibly cached) proxy for any other client, or None if
            the client cannot be found.

        Raises:
            ValueError: If "main" is requested but no module is attached.
        """
        if client is None:
            return None
        if client == "main":
            return self.module

        # return cached proxy, if we have one
        cached = self._proxies.get(client)
        if cached is not None:
            return cached

        # get interfaces; get_interfaces() raises IndexError for unknown clients
        try:
            interfaces = await self.get_interfaces(client)
        except IndexError:
            return None

        # collect capabilities (fixed at proxy construction time)
        capabilities: dict[type[Interface], Any] = {}
        for interface in interfaces:
            if interface.capabilities is not None:
                cap = await self._get_capabilities(client, interface)
                if cap is not None:
                    capabilities[interface] = cap

        # create new proxy and subscribe it to the client's published state
        proxy = Proxy(self, client, interfaces, capabilities)
        for interface in interfaces:
            if interface.state is not None:
                await self.subscribe_state(client, interface, functools.partial(proxy.update_state, interface))

        # NOTE(review): because of the awaits above, two concurrent first calls for the same client can both
        # reach this point; the later one overwrites the cache entry, while the earlier proxy's state
        # subscriptions stay registered until the client disconnects.
        self._proxies[client] = proxy
        return proxy

    async def _resolve_proxy(
        self, name_or_object: str | object, obj_type: type[ProxyType] | None = None
    ) -> Any | ProxyType:
        """Returns object directly if it is of given type. Otherwise get proxy of client with given name and check type.

        If name_or_object is an object:
            - If it is of type (or derived), return object.
            - Otherwise raise exception.
        If name_or_object is a string:
            - Create proxy from name and raise exception, if it doesn't exist.
            - Check type and raise exception if wrong.
            - Return object.

        Args:
            name_or_object: Name of object or object itself.
            obj_type: Expected class of object.

        Returns:
            Object or proxy to object.

        Raises:
            ValueError: If proxy does not exist or wrong type.
        """
        if obj_type is not None and isinstance(name_or_object, obj_type):
            # return directly
            return name_or_object

        if not isinstance(name_or_object, str):
            # completely wrong...
            raise ValueError(f'Given parameter is neither a name nor an object of requested type "{obj_type}".')

        # get proxy
        try:
            proxy = await self._get_client(name_or_object)
        except KeyError as e:
            raise ValueError(f"Could not get proxy for {name_or_object}.") from e

        # check it
        if proxy is None:
            raise ValueError(f'Could not create proxy for given name "{name_or_object}".')
        if obj_type is None or isinstance(proxy, obj_type):
            return proxy

        message = f'Proxy obtained from given name "{name_or_object}" is not of requested type "{obj_type}".'
        hint = self._diagnose_missing_interface(name_or_object, obj_type)
        if hint is not None:
            message += f" {hint}"
        raise ValueError(message)

    def _diagnose_missing_interface(self, client: str, obj_type: type[Any]) -> str | None:
        """Backend hook, called when a proxy exists but doesn't implement obj_type.

        Backends that can tell "genuinely missing" apart from "present at an
        incompatible version" using data already in hand return a hint string
        to append to the ValueError; other backends just return None.
        """
        return None

    async def _safe_resolve_proxy(
        self, name_or_object: str | object, obj_type: type[ProxyType] | None = None
    ) -> Any | ProxyType | None:
        """Calls _resolve_proxy() and returns None instead of raising a ValueError."""
        try:
            return await self._resolve_proxy(name_or_object, obj_type)
        except ValueError:
            return None

    @overload
    def proxy(self, name_or_object: str | object, obj_type: type[ProxyType]) -> _ProxyContext[ProxyType]: ...

    @overload
    def proxy(self, name_or_object: str | object, obj_type: None = None) -> _ProxyContext[Any]: ...

    def proxy(self, name_or_object: str | object, obj_type: type[ProxyType] | None = None) -> _ProxyContext[Any]:
        """Returns a context manager; use as `async with self.proxy(...) as x:`."""
        return _ProxyContext(self._resolve_proxy(name_or_object, obj_type))

    @overload
    def safe_proxy(
        self, name_or_object: str | object, obj_type: type[ProxyType]
    ) -> _ProxyContext[ProxyType | None]: ...

    @overload
    def safe_proxy(self, name_or_object: str | object, obj_type: None = None) -> _ProxyContext[Any]: ...

    def safe_proxy(self, name_or_object: str | object, obj_type: type[ProxyType] | None = None) -> _ProxyContext[Any]:
        """Same as proxy(), but yields None inside the block instead of raising."""
        return _ProxyContext(self._safe_resolve_proxy(name_or_object, obj_type))

    async def has_proxy(self, name_or_object: str | object, obj_type: type[Any] | None = None) -> bool:
        """True if a proxy of the given type can currently be resolved. Doesn't keep a reference
        to it, so doesn't need async with the way proxy()/safe_proxy() do."""
        return await self._safe_resolve_proxy(name_or_object, obj_type) is not None

    async def _client_disconnected(self, event: Event, sender: str) -> bool:
        """Called when a client disconnects.

        Clears and evicts the cached proxy for the client and removes all state subscriptions held for it.

        Args:
            event: Disconnect event.
            sender: Name of client that disconnected.

        Returns:
            Always True.
        """
        # clear the proxy's state, then evict it
        if sender in self._proxies:
            self._proxies[sender].clear_state()
            del self._proxies[sender]

        # tear down any state subscriptions held for that client
        for interface, callback in self._state_subscriptions.pop(sender, []):
            await self.unsubscribe_state(sender, interface, callback)
        return True

    @property
    def name(self) -> str | None:
        """Name of this client."""
        raise NotImplementedError

    @property
    def clients(self) -> list[str]:
        """Returns list of currently connected clients.

        Returns:
            (list) list of currently connected clients.
        """
        raise NotImplementedError

    async def clients_with_interface(self, interface: type[Interface]) -> list[str]:
        """Returns list of currently connected clients that implement the given interface.

        Args:
            interface: Interface to search for.

        Returns:
            (list) list of currently connected clients that implement the given interface.
        """
        return [c for c in self.clients if await self._supports_interface(c, interface)]

    async def get_interfaces(self, client: str) -> list[type[Interface]]:
        """Returns list of interfaces for given client.

        Args:
            client: Name of client.

        Returns:
            list of supported interfaces.

        Raises:
            IndexError: If client cannot be found.
        """
        raise NotImplementedError

    async def _supports_interface(self, client: str, interface: type[Interface]) -> bool:
        """Checks, whether the given client supports the given interface.

        Args:
            client: Client to check.
            interface: Interface to check.

        Returns:
            Whether or not interface is supported.
        """
        raise NotImplementedError

    @staticmethod
    def _interface_names_to_classes(interfaces: list[str]) -> list[type[Interface]]:
        """Converts a list of interface names to interface classes.

        Names that don't match an Interface subclass in pyobs.interfaces are logged and skipped.

        Args:
            interfaces: list of interface names.

        Returns:
            list of interface classes, in the order of the given names.
        """
        # map class name -> class for all classes in pyobs.interfaces that implement Interface
        available = {
            cls_name: cls
            for cls_name, cls in inspect.getmembers(pyobs.interfaces, predicate=inspect.isclass)
            if issubclass(cls, Interface)
        }

        interface_classes: list[type[Interface]] = []
        for interface_name in interfaces:
            cls = available.get(interface_name)
            if cls is None:
                log.error('Could not find interface "%s" for client.', interface_name)
            else:
                interface_classes.append(cls)
        return interface_classes

    async def execute(self, client: str, method: str, annotation: dict[str, Any], *args: Any) -> Any:
        """Execute a given method on a remote client.

        Args:
            client (str): ID of client.
            method (str): Method to call.
            annotation: Method annotation.
            *args: list of parameters for given method.

        Returns:
            Passes through return from method call.
        """
        raise NotImplementedError

    async def _logging(self) -> None:
        """Background task that sends queued log entries via send_event(); runs until cancelled by close()."""
        while True:
            try:
                entry = self._log_queue.get_nowait()
            except asyncio.QueueEmpty:
                # nothing to send, poll again in a second
                await asyncio.sleep(1)
                continue

            try:
                await self.send_event(entry)
            except Exception:
                # NOTE(review): this record reaches the root logger, where the CommLoggingHandler installed by
                # open() may queue it for sending again; whether ModuleNameFilter prevents a feedback loop on a
                # persistently failing send_event() cannot be seen from here.
                log.exception("Something went wrong")

    def log_message(self, entry: LogEvent) -> None:
        """Send a log message to other clients.

        Entries are queued for the logging task; once close() has been called they are dropped, since
        nothing sends them anymore.

        Args:
            entry (LogEvent): Log event to send.
        """
        if self._closing.is_set():
            return
        self._log_queue.put_nowait(entry)

    async def send_event(self, event: Event) -> None:
        """Send an event to other clients. Does nothing in this base class.

        Args:
            event (Event): Event to send
        """
        pass

    def _get_derived_events(self, event: type[Event]) -> list[type[Event]]:
        """Return list of given event itself and all events derived from it.

        Only classes found in the pyobs.events module namespace are considered.

        Args:
            event: Event class to check.

        Returns:
            list of event classes.
        """
        import pyobs.events

        return [cls for _, cls in inspect.getmembers(pyobs.events, inspect.isclass) if issubclass(cls, event)]

    async def register_event(
        self, event_class: type[Event], handler: Callable[[Event, str], Coroutine[Any, Any, bool]] | None = None
    ) -> None:
        """Register an event type. If a handler is given, we also receive those events, otherwise we just
        send them.

        The handler is registered for the given class and all event classes derived from it.

        Args:
            event_class: Class of event to register.
            handler: Event handler method.
        """
        # we also want to register all events derived from the given one
        event_classes = self._get_derived_events(event_class)

        if handler is not None:
            for ev in event_classes:
                handlers = self._event_handlers.setdefault(ev, [])
                # avoid registering the same handler twice
                if handler not in handlers:
                    handlers.append(handler)

        # non-local events also need to be registered with the backend
        if not event_class.local:
            await self._register_events(event_classes, handler)

    async def _register_events(
        self, events: list[type[Event]], handler: Callable[[Event, str], Coroutine[Any, Any, bool]] | None = None
    ) -> None:
        """Backend hook for registering non-local events. Does nothing here."""
        pass

    async def set_state(self, interface: type[Interface], state: Any) -> None:
        """Publish state for this module.

        Args:
            interface: Interface type for the state.
            state: State object to publish.
        """
        await self._set_state(interface, state)

    async def _set_state(self, interface: type[Interface], state: Any) -> None:
        """Backend hook for set_state(). Does nothing here."""
        pass

    async def set_capabilities(self, interface: type[Interface], capabilities: Any) -> None:
        """Publish capabilities for this module.

        Called by Module.open() for each interface that defines a capabilities
        dataclass. Not intended to be called directly by module authors after
        that point — capabilities are fixed for the module lifetime.

        Args:
            interface: Interface type the capabilities belong to.
            capabilities: Capabilities dataclass instance.
        """
        await self._set_capabilities(interface, capabilities)

    async def _set_capabilities(self, interface: type[Interface], capabilities: Any) -> None:
        """Backend hook for set_capabilities(). Does nothing here."""
        pass

    async def get_capabilities(self, module: str, interface: type[Interface]) -> Any | None:
        """Fetch and deserialize capabilities for a remote module's interface.

        Args:
            module: Module name (e.g. "camera").
            interface: Interface class whose Capabilities dataclass to fetch.

        Returns:
            Deserialized Capabilities dataclass instance, or None if not published.
        """
        return await self._get_capabilities(module, interface)

    async def _get_capabilities(self, module: str, interface: type[Interface]) -> Any | None:
        """Backend hook for get_capabilities(). Always returns None here."""
        return None

    async def set_presence(self, state: ModuleState, error_string: str = "") -> None:
        """Publish presence for this module (module lifecycle state).

        Called automatically by Module.set_state — not intended to be called
        directly by module authors.

        Args:
            state: Current module lifecycle state.
            error_string: Error message, used when state is ERROR.
        """
        await self._set_presence(state, error_string)

    async def _set_presence(self, state: ModuleState, error_string: str = "") -> None:
        """Backend hook for set_presence(). Does nothing here."""
        pass

    def get_client_state(self, module: str) -> tuple[ModuleState, str] | None:
        """Return the last known presence state of a connected module.

        Returns (ModuleState, error_string) or None if the module is not connected.
        This replaces the old get_state()/get_error_string() RPC pattern.
        """
        return self._get_client_state(module)

    def _get_client_state(self, module: str) -> tuple[ModuleState, str] | None:
        """Backend hook for get_client_state(). Always returns None here."""
        return None

    def get_own_state(self, interface: type[Interface]) -> Any:
        """Return the last state published by this module for the given interface, or None."""
        return self._get_own_state(interface)

    def _get_own_state(self, interface: type[Interface]) -> Any:
        """Backend hook for get_own_state(). Always returns None here."""
        return None

    async def subscribe_state(
        self,
        module: str,
        interface: type[Interface],
        callback: StateCallback,
    ) -> None:
        """Subscribe to state updates for a given module and interface.

        Delivers the current value immediately on subscribe. The subscription is also recorded, so it can
        be removed when the module disconnects.

        Args:
            module: Name of remote module.
            interface: Interface type to subscribe to.
            callback: Called with state object on each update.
        """
        self._state_subscriptions.setdefault(module, []).append((interface, callback))
        await self._subscribe_state(module, interface, callback)

    async def _subscribe_state(
        self,
        module: str,
        interface: type[Interface],
        callback: StateCallback,
    ) -> None:
        """Backend hook for subscribe_state(). Does nothing here."""
        pass

    async def unsubscribe_state(
        self,
        module: str,
        interface: type[Interface],
        callback: StateCallback,
    ) -> None:
        """Unsubscribe from state updates.

        Also drops the matching record made by subscribe_state(), so the subscription isn't
        unsubscribed a second time when the module disconnects.

        Args:
            module: Name of remote module.
            interface: Interface type to unsubscribe from.
            callback: Callback that was registered.
        """
        subscriptions = self._state_subscriptions.get(module)
        if subscriptions is not None:
            try:
                subscriptions.remove((interface, callback))
            except ValueError:
                # not recorded (e.g. already torn down by _client_disconnected)
                pass
            if not subscriptions:
                del self._state_subscriptions[module]
        await self._unsubscribe_state(module, interface, callback)

    async def _unsubscribe_state(
        self,
        module: str,
        interface: type[Interface],
        callback: StateCallback,
    ) -> None:
        """Backend hook for unsubscribe_state(). Does nothing here."""
        pass

    async def subscribe_presence(self, module: str, callback: PresenceCallback) -> None:
        """Subscribe to presence updates for a given module.

        Delivers the current value immediately, then on every change.

        Args:
            module: Name of remote module.
            callback: Called with (ModuleState, error_string) on each update.
        """
        await self._subscribe_presence(module, callback)

    async def _subscribe_presence(self, module: str, callback: PresenceCallback) -> None:
        """Backend hook for subscribe_presence(). Does nothing here."""
        pass

    def _send_event_to_module(self, event: Event, from_client: str) -> None:
        """Pass an event to all handlers registered for its exact class.

        Handlers returning a coroutine are run as background tasks.

        Args:
            event: Event to send.
            from_client: Client that sent the event.
        """
        # iterate over a copy, so handlers registering new handlers can't change the list under us
        for handler in list(self._event_handlers.get(event.__class__, [])):
            ret = handler(event, from_client)
            if asyncio.iscoroutine(ret):
                # keep a strong reference until the task is done, the event loop only keeps weak ones
                task = asyncio.create_task(ret)
                self._background_tasks.add(task)
                task.add_done_callback(self._background_task_done)

    def _background_task_done(self, task: asyncio.Task[Any]) -> None:
        """Forget a finished handler task and log its exception, if it raised one."""
        self._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            log.error("Exception in event handler task.", exc_info=task.exception())
# public API of this module: only the Comm base class is exported
__all__: list[str] = ["Comm"]