| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349 |
- # observable.py/Open GoPro, Version 2.0 (C) Copyright 2021 GoPro, Inc. (http://gopro.com/OpenGoPro).
- # This copyright was auto-generated on Mon May 12 23:03:50 UTC 2025
- """Observable / observer async generators to process asynchronous data stream.
- An Observable is a source of asynchronous data that can be observed by multiple observers. Observers can
- subscribe to the observable to retrieve an async generator (an observer) and asynchronously receive updates when new data is
- emitted. The observer can also perform actions on the data as it is emitted, such as filtering or transforming the data
- using the [asyncstdlib](https://pypi.org/project/asyncstdlib) library to manipulate the async generator.
- """
- # pylint: disable=redefined-builtin
- from __future__ import annotations
- import asyncio
- import logging
- from dataclasses import dataclass, field
- from inspect import iscoroutinefunction
- from typing import (
- Any,
- AsyncGenerator,
- Callable,
- Coroutine,
- Final,
- Generic,
- NoReturn,
- Self,
- TypeAlias,
- TypeVar,
- )
- from uuid import UUID, uuid1
- from asyncstdlib import anext, filter
- O = TypeVar("O")
- T = TypeVar("T")
- SyncAction: TypeAlias = Callable[[T], None]
- AsyncAction: TypeAlias = Callable[[T], Coroutine[Any, Any, None]]
- SyncFilter: TypeAlias = Callable[[T], bool]
- AsyncFilter: TypeAlias = Callable[[T], Coroutine[Any, Any, bool]]
- logger = logging.getLogger(__name__)
- T_I = TypeVar("T_I")
class Observer(AsyncGenerator[T, None]):
    """Async generator wrapper with added control methods

    The observer registers itself with its observable lazily, on the first request for a value. It unregisters
    when it is closed, when an exception is thrown into it, or when retrieving a value raises an exception.

    Args:
        observable (Observable[T]): source to retrieve values from
        uuid (UUID): identifier used to register this observer with the observable
        replay (int): how many cached values to replay upon registration
        debug_id (str | None): Identifier to log for debugging. Defaults to None (will use the UUID).
    """

    def __init__(self, observable: Observable[T], uuid: UUID, replay: int, debug_id: str | None = None) -> None:
        self._observable = observable
        self._uuid = uuid
        self._replay = replay
        self._debug_id = debug_id or str(uuid)
        # True while this observer is registered (or registering) with the observable
        self._is_active = False

    def __aiter__(self) -> Observer[T]:
        return self

    async def __anext__(self) -> T:
        """Retrieve the next value from the observable, registering with it on first use

        Raises:
            Exception: any exception raised while retrieving the value is logged and re-raised after the
                observer has been unregistered

        Returns:
            T: next value for this observer
        """
        if not self._is_active:
            # Set the flag before awaiting so registration is only attempted once
            self._is_active = True
            await self._observable._add_observer(self._uuid, replay=self._replay)
        try:
            logger.trace(f"Observer ({self._debug_id}) waiting for next value")  # type: ignore
            value = await self._observable._get_next(self._uuid)
            logger.trace(f"Observer ({self._debug_id}) received value: {value}")  # type: ignore
            return value
        except Exception as e:
            logger.error(f"Error in observer {self._debug_id}: {repr(e)}")
            await self._cleanup()
            # Bare raise re-raises the active exception with its original traceback
            raise

    async def first(self, predicate: SyncFilter) -> T:
        """Get the first value that matches the predicate

        Args:
            predicate (SyncFilter): Predicate to match

        Returns:
            T: First value that matches the predicate
        """
        return await anext(filter(predicate, self))

    async def _cleanup(self) -> None:
        """Clean up resources when generator is done

        Unregisters from the observable. Does nothing if the observer is not currently active.
        """
        if self._is_active:
            self._is_active = False
            await self._observable._remove_observer(self._uuid)

    async def aclose(self) -> None:
        """Properly close the generator and clean up resources"""
        await self._cleanup()

    async def athrow(self, typ: Any, val: Any = None, tb: Any = None) -> NoReturn:
        """Throw an exception into the generator

        The observer is unregistered first and the exception is then raised to the caller. Both the
        single-argument form (an exception instance) and the legacy (type, value, traceback) form are accepted.

        Args:
            typ (Any): exception instance, or exception type
            val (Any): exception instance to raise when typ is a type. Defaults to None (instantiate typ).
            tb (Any): traceback to attach to the raised exception. Defaults to None.

        Raises:
            StopAsyncIteration: the observer is not active
            BaseException: the requested exception
        """
        if not self._is_active:
            raise StopAsyncIteration
        # Cleanup first
        await self._cleanup()
        # Then resolve which exception object to raise
        if isinstance(typ, BaseException):
            # Single-argument form: typ already is the instance, so it must not be called
            exc = typ
        elif val is None:
            exc = typ()
        else:
            exc = val
        if tb is not None:
            raise exc.with_traceback(tb)
        raise exc

    async def asend(self, value: None) -> T:
        """Send a value into the generator (required by protocol)

        Args:
            value (None): ignored

        Raises:
            StopAsyncIteration: the observer is not active

        Returns:
            T: next value for this observer
        """
        if not self._is_active:
            raise StopAsyncIteration
        # We don't really use the sent value, so just advance to next item
        return await self.__anext__()
class Observable(Generic[T]):
    """The source of asynchronous data

    Values passed to emit() are cached for replay and queued individually for every registered observer.

    Attributes:
        REPLAY_ALL (Final[int]): Special integer value to indicate all values should be replayed
        OBS_IDX (int): counter of observable instantiations used for debugging

    Args:
        capacity (int): Maximum values to store for replay. Defaults to 100.
        debug_id (str | None): Identifier to log for debugging. Defaults to None (will use the instantiation
            counter).
    """

    REPLAY_ALL: Final[int] = -1
    OBS_IDX: int = 0

    @dataclass
    class _SharedData(Generic[T_I]):
        """Common data used for internal management that should be accessed in critical sections

        Acts as an async context manager: entering acquires the internal condition, exiting releases it.
        """

        # Most recently emitted value
        current: T_I | None = None
        # Emitted values retained for replay to newly added observers
        cache: list[T_I] = field(default_factory=list)
        # Per-observer queue of pending values, keyed by observer identifier
        q_dict: dict[UUID, asyncio.Queue[T_I]] = field(default_factory=dict)

        def __post_init__(self) -> None:  # noqa
            self._condition = asyncio.Condition()

        async def __aenter__(self) -> Observable._SharedData[T_I]:  # noqa
            await self._condition.acquire()
            return self

        async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:  # noqa
            self._condition.release()

    def __init__(self, capacity: int = 100, debug_id: str | None = None) -> None:
        self._lock = asyncio.Condition()
        # Total number of values retrieved via _get_next, across all observers
        self._count = 0
        self._capacity = capacity
        self._debug_id = debug_id or str(Observable.OBS_IDX)
        Observable.OBS_IDX += 1
        self._on_start_actions: list[SyncAction[T] | AsyncAction[T]] = []
        self._on_subscribe_actions: list[Callable[[], None] | Callable[[], Coroutine[Any, Any, None]]] = []
        self._per_value_actions: list[SyncAction[T] | AsyncAction[T]] = []
        self._shared_data = Observable._SharedData[T]()
        # Strong references to fire-and-forget tasks so they are not garbage collected before they finish
        self._background_tasks: set[asyncio.Task[Any]] = set()

    # TODO handle cleanup

    async def _add_observer(self, uuid: UUID, replay: int) -> None:
        """Add an observer to receive collected values

        Does nothing if the observer is already registered.

        Args:
            uuid (UUID): observer identifier
            replay (int): how many values to replay from cache
        """
        async with self._shared_data:
            if uuid not in self._shared_data.q_dict:
                self._shared_data.q_dict[uuid] = asyncio.Queue()
                if replay == Observable.REPLAY_ALL:
                    replay = len(self._shared_data.cache)
                # Replay at most the requested number of the newest cached values
                head = max(len(self._shared_data.cache) - replay, 0)
                for value in self._shared_data.cache[head:]:
                    self._shared_data.q_dict[uuid].put_nowait(value)

    async def _remove_observer(self, uuid: UUID) -> None:
        """Remove an observer from receiving collected values

        Does nothing if the observer is not registered.

        Args:
            uuid (UUID): observer identifier
        """
        async with self._shared_data:
            self._shared_data.q_dict.pop(uuid, None)

    async def emit(self, value: T) -> None:
        """Receive a value and queue it for per-observer retrieval

        The value becomes the current value, is appended to the replay cache (evicting the oldest value once
        capacity is exceeded), and is queued for every registered observer.

        Not intended to be used by the observer.

        Args:
            value (T): Value to queue
        """
        async with self._shared_data:
            self._shared_data.current = value
            self._shared_data.cache.append(value)
            if len(self._shared_data.cache) > self._capacity:
                self._shared_data.cache.pop(0)
            for uuid, q in self._shared_data.q_dict.items():
                logger.trace(f"Observable {self._debug_id} emitting {value} to observer {uuid}")  # type: ignore
                # Observer queues are created unbounded, so this can not fail with a full queue
                q.put_nowait(value)

    def _mux_action(
        self,
        action: SyncAction[T] | AsyncAction[T],
        value: T,
        tg: asyncio.TaskGroup | None = None,
    ) -> None:
        """Execute an action that is either synchronous or asynchronous

        If tg is passed, the async action will be added to the task group. Otherwise a background task will be
        created and referenced until it completes. In both cases, this function will return without awaiting the
        created task.

        Note! If action is synchronous, this will block until the action returns.

        Args:
            action (SyncAction[T] | AsyncAction[T]): action to execute
            value (T): value to pass to action
            tg (asyncio.TaskGroup | None, optional): Task group to add async action. Defaults to None (don't add to any
                task group).
        """
        if not iscoroutinefunction(action):
            action(value)
            return
        if tg:
            tg.create_task(action(value))
            return
        # The event loop only keeps weak references to tasks, so hold a strong one until the task is done
        task = asyncio.create_task(action(value))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _mux_filter_blocking(self, predicate: SyncFilter | AsyncFilter, value: T) -> bool:
        """Execute a filter that is either synchronous or asynchronous

        Note! This will await / block until the predicate completes.

        Args:
            predicate (SyncFilter | AsyncFilter): Filter to execute
            value (T): value to analyze with predicates

        Returns:
            bool: result of applying the predicate to the value
        """
        if iscoroutinefunction(predicate):
            return await predicate(value)
        return predicate(value)  # type: ignore

    @property
    def current(self) -> T | None:
        """Get the most recently collected value of the observable.

        Note that this does not indicate the value in real-time. It is the most recent value that was collected
        from a receiver.

        Returns:
            T | None: Most recently collected value, or None if no values were collected yet
        """
        return self._shared_data.current

    # TODO what is the difference between this and on_start?
    def on_subscribe(
        self,
        action: Callable[[], None] | Callable[[], Coroutine[Any, Any, None]],
    ) -> Self:
        """Register a callback to be called when an observer requests a value before any value was retrieved

        The callback takes no arguments. It is invoked on every such request until the first value has been
        retrieved by any observer.

        Args:
            action (Callable[[], None] | Callable[[], Coroutine[Any, Any, None]]): Callback

        Returns:
            Self: modified observable
        """
        self._on_subscribe_actions.append(action)
        return self

    def on_start(self, action: SyncAction[T] | AsyncAction[T]) -> Self:
        """Register a callback action to be called with the first value retrieved from the observable

        Args:
            action (SyncAction[T] | AsyncAction[T]): Callback action

        Returns:
            Self: modified observable
        """
        self._on_start_actions.append(action)
        return self

    ####################################################################################################################
    ##### Terminal Operators
    ####################################################################################################################

    def observe(self, replay: int = 1, debug_id: str | None = None) -> Observer[T]:
        """Get an async generator to yield values from the observable

        Args:
            replay (int): how many values to replay from cache. Defaults to 1.
            debug_id (str | None): Identifier for debug logging. Defaults to None (will use generated UUID).

        Returns:
            Observer[T]: async generator to yield values from the observable
        """
        # Create the async generator with a unique ID
        return Observer(self, uuid1(), replay, debug_id=debug_id)

    async def _get_next(self, uuid: UUID) -> T:
        """Get the next per-observer value

        Args:
            uuid (UUID): observer identifier

        Raises:
            RuntimeError: the observer has not been added to this observable

        Returns:
            T: Next value queued for the observer
        """
        # Until the first value has been retrieved, notify all on-subscribe listeners
        if self._count == 0:
            for subscribe_action in self._on_subscribe_actions:
                # Exactly one invocation per action: await coroutine functions, call plain callables
                if iscoroutinefunction(subscribe_action):
                    await subscribe_action()
                else:
                    subscribe_action()
        # Acquire the condition and look up the per-observer queue
        async with self._shared_data:
            if uuid not in self._shared_data.q_dict:
                logger.error("Attempted to get value from a non-registered observer.")
                raise RuntimeError("Observer has not been added!")
            q = self._shared_data.q_dict[uuid]
        # Note! This can't be called inside shared data context as it will cause a deadlock. We've already retrieved
        # the q here which is itself coroutine-safe so just await it.
        value = await q.get()
        self._count += 1
        # If this is the first value, notify on start listeners
        if self._count == 1:
            for action in self._on_start_actions:  # type: ignore
                self._mux_action(action, value)  # type: ignore
        # Notify per-value actions, scheduling asynchronous ones rather than discarding their coroutines
        for action in self._per_value_actions:  # type: ignore
            self._mux_action(action, value)  # type: ignore
        # We've made it! Return the continuing value
        return value
|