observable.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. # observable.py/Open GoPro, Version 2.0 (C) Copyright 2021 GoPro, Inc. (http://gopro.com/OpenGoPro).
  2. # This copyright was auto-generated on Mon May 12 23:03:50 UTC 2025
"""Observable / observer async generators to process asynchronous data streams.

An Observable is a source of asynchronous data that can be observed by multiple observers. Observers subscribe to the
observable (via `Observable.observe`) to retrieve an async generator that asynchronously receives updates when new data
is emitted. Observers can also act on the data as it is emitted, such as filtering or transforming it, by using the
[asyncstdlib](https://pypi.org/project/asyncstdlib) library to manipulate the async generator.
"""
  9. # pylint: disable=redefined-builtin
  10. from __future__ import annotations
  11. import asyncio
  12. import logging
  13. from dataclasses import dataclass, field
  14. from inspect import iscoroutinefunction
  15. from typing import (
  16. Any,
  17. AsyncGenerator,
  18. Callable,
  19. Coroutine,
  20. Final,
  21. Generic,
  22. NoReturn,
  23. Self,
  24. TypeAlias,
  25. TypeVar,
  26. )
  27. from uuid import UUID, uuid1
  28. from asyncstdlib import anext, filter
  29. O = TypeVar("O")
  30. T = TypeVar("T")
  31. SyncAction: TypeAlias = Callable[[T], None]
  32. AsyncAction: TypeAlias = Callable[[T], Coroutine[Any, Any, None]]
  33. SyncFilter: TypeAlias = Callable[[T], bool]
  34. AsyncFilter: TypeAlias = Callable[[T], Coroutine[Any, Any, bool]]
  35. logger = logging.getLogger(__name__)
  36. T_I = TypeVar("T_I")
  37. class Observer(AsyncGenerator[T, None]):
  38. """Async generator wrapper with added control methods"""
  39. def __init__(self, observable: Observable[T], uuid: UUID, replay: int, debug_id: str | None = None) -> None:
  40. self._observable = observable
  41. self._uuid = uuid
  42. self._replay = replay
  43. self._debug_id = debug_id or str(uuid)
  44. self._is_active = False
  45. def __aiter__(self) -> Observer[T]:
  46. return self
  47. async def __anext__(self) -> T:
  48. if not self._is_active:
  49. self._is_active = True
  50. await self._observable._add_observer(self._uuid, replay=self._replay)
  51. try:
  52. logger.trace(f"Observer ({self._debug_id}) waiting for next value") # type: ignore
  53. value = await self._observable._get_next(self._uuid)
  54. logger.trace(f"Observer ({self._debug_id}) received value: {value}") # type: ignore
  55. return value
  56. except Exception as e:
  57. logger.error(f"Error in observer {self._debug_id}: {repr(e)}")
  58. await self._cleanup()
  59. raise e
  60. async def first(self, predicate: SyncFilter) -> T:
  61. """Get the first value that matches the predicate
  62. Args:
  63. predicate (SyncFilter): Predicate to match
  64. Returns:
  65. T: First value that matches the predicate
  66. """
  67. return await anext(filter(predicate, self))
  68. async def _cleanup(self) -> None:
  69. """Clean up resources when generator is done"""
  70. if self._is_active:
  71. self._is_active = False
  72. await self._observable._remove_observer(self._uuid)
  73. async def aclose(self) -> None:
  74. """Properly close the generator and clean up resources"""
  75. await self._cleanup()
  76. async def athrow(self, typ: Any, val: Any = None, tb: Any = None) -> NoReturn:
  77. """Throw an exception into the generator"""
  78. if not self._is_active:
  79. raise StopAsyncIteration
  80. # Cleanup first
  81. await self._cleanup()
  82. # Then raise the exception
  83. if val is None:
  84. val = typ()
  85. if tb is not None:
  86. raise val.with_traceback(tb)
  87. raise val
  88. async def asend(self, value: None) -> T:
  89. """Send a value into the generator (required by protocol)"""
  90. if not self._is_active:
  91. raise StopAsyncIteration
  92. # We don't really use the sent value, so just advance to next item
  93. return await anext(self)
  94. class Observable(Generic[T]):
  95. """The source of asynchronous data
  96. Attributes:
  97. REPLAY_ALL (Final[int]): Special integer value to indicate all values should be replayed
  98. OBS_IDX (int): counter of observable instantiations used for debugging
  99. Args:
  100. capacity (int): Maximum values to store for replay. Defaults to 100.
  101. debug_id (str | None): Identifier to log for debugging. Defaults to None (will use generated UUID).
  102. """
  103. REPLAY_ALL: Final[int] = -1
  104. OBS_IDX: int = 0
  105. @dataclass
  106. class _SharedData(Generic[T_I]):
  107. """Common data used for internal management that should be accessed in critical sections"""
  108. current: T_I | None = None
  109. cache: list[T_I] = field(default_factory=list)
  110. q_dict: dict[UUID, asyncio.Queue[T_I]] = field(default_factory=dict)
  111. def __post_init__(self) -> None: # noqa
  112. self._condition = asyncio.Condition()
  113. async def __aenter__(self) -> Observable._SharedData[T_I]: # noqa
  114. await self._condition.acquire()
  115. return self
  116. async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # noqa
  117. self._condition.release()
  118. def __init__(self, capacity: int = 100, debug_id: str | None = None) -> None:
  119. self._lock = asyncio.Condition()
  120. self._count = 0
  121. self._capacity = capacity
  122. self._debug_id = debug_id or str(Observable.OBS_IDX)
  123. Observable.OBS_IDX += 1
  124. self._on_start_actions: list[SyncAction[T] | AsyncAction[T]] = []
  125. self._on_subscribe_actions: list[Callable[[], None] | Callable[[], Coroutine[Any, Any, None]]] = []
  126. self._per_value_actions: list[SyncAction[T] | AsyncAction[T]] = []
  127. self._shared_data = Observable._SharedData[T]()
  128. # TODO handle cleanup
  129. async def _add_observer(self, uuid: UUID, replay: int) -> None:
  130. """Add an observer to receive collected values
  131. Args:
  132. uuid (UUID): observer identifier
  133. replay (int): how many values to replay from cache
  134. """
  135. async with self._shared_data:
  136. if uuid not in self._shared_data.q_dict:
  137. self._shared_data.q_dict[uuid] = asyncio.Queue()
  138. if replay == Observable.REPLAY_ALL:
  139. replay = len(self._shared_data.cache)
  140. head = max(len(self._shared_data.cache) - replay, 0)
  141. for value in self._shared_data.cache[head:]:
  142. self._shared_data.q_dict[uuid].put_nowait(value)
  143. async def _remove_observer(self, uuid: UUID) -> None:
  144. """Remove an observer from receiving collected values
  145. Args:
  146. uuid (UUID): observer identifier
  147. """
  148. async with self._shared_data:
  149. if uuid in self._shared_data.q_dict:
  150. del self._shared_data.q_dict[uuid]
  151. async def emit(self, value: T) -> None:
  152. """Receive a value and queue it for per-observer retrieval
  153. Not intended to be used by the observer.
  154. Args:
  155. value (T): Value to queue
  156. """
  157. async with self._shared_data:
  158. self._shared_data.current = value
  159. self._shared_data.cache.append(value)
  160. if len(self._shared_data.cache) > self._capacity:
  161. self._shared_data.cache.pop(0)
  162. for uuid, q in self._shared_data.q_dict.items():
  163. logger.trace(f"Observable {self._debug_id} emitting {value} to observer {uuid}") # type: ignore
  164. await q.put(value)
  165. def _mux_action(
  166. self,
  167. action: SyncAction[T] | AsyncAction[T],
  168. value: T,
  169. tg: asyncio.TaskGroup | None = None,
  170. ) -> None:
  171. """Execute an action that is either synchronous or asynchronous
  172. If tg is passed, the async action will be added to the task group. Otherwise an anonymous async task will be
  173. created. In both cases, this function will return without awaiting the created task.
  174. Note! If action is synchronous, this will block until the action returns.
  175. Args:
  176. action (SyncAction[T] | AsyncAction[T]): action to execute
  177. value (T): value to pass to action
  178. tg (asyncio.TaskGroup | None, optional): Task group to add async action. Defaults to None (don't add to any
  179. task group).
  180. """
  181. if iscoroutinefunction(action):
  182. if tg:
  183. tg.create_task(action(value))
  184. else:
  185. asyncio.create_task(action(value))
  186. else:
  187. action(value)
  188. async def _mux_filter_blocking(self, predicate: SyncFilter | AsyncFilter, value: T) -> bool:
  189. """Execute a filter that is either synchronous or asynchronous
  190. Note! This will await / block until the action completes.
  191. Args:
  192. predicate (SyncFilter | AsyncFilter): Filter to execute
  193. value (T): value to analyze with predicates
  194. Returns:
  195. bool: _description_
  196. """
  197. if iscoroutinefunction(predicate):
  198. return await predicate(value)
  199. return predicate(value) # type: ignore
  200. @property
  201. def current(self) -> T | None:
  202. """Get the most recently collected value of the observable.
  203. Note that this does not indicate the value in real-time. It is the most recent value that was collected
  204. from a receiver.
  205. Returns:
  206. T | None: Most recently collected value, or None if no values were collected yet
  207. """
  208. return self._shared_data.current
  209. # TODO what is the difference betwenn this and on_start?
  210. def on_subscribe(
  211. self,
  212. action: Callable[[], None] | Callable[[], Coroutine[Any, Any, None]],
  213. ) -> Self:
  214. """Register to receive a callback to be called when the observable starts emitting
  215. Args:
  216. action (Callable[[], None] | Callable[[], Coroutine[Any, Any, None]]): Callback
  217. Returns:
  218. Self: modified observable
  219. """
  220. self._on_subscribe_actions.append(action)
  221. return self
  222. def on_start(self, action: SyncAction[T] | AsyncAction[T]) -> Self:
  223. """Register a callback action to be called when the observable starts emitting
  224. Args:
  225. action (SyncAction[T] | AsyncAction[T]): Callback action
  226. Returns:
  227. Self: modified observable
  228. """
  229. self._on_start_actions.append(action)
  230. return self
  231. ####################################################################################################################
  232. ##### Terminal Operators
  233. ####################################################################################################################
  234. def observe(self, replay: int = 1, debug_id: str | None = None) -> Observer[T]:
  235. """Get an async generator to yield values from the observable
  236. Args:
  237. replay (int): how many values to replay from cache. Defaults to 1.
  238. debug_id (str | None): Identifier for debug logging. Defaults to None (will use generated UUID).
  239. Returns:
  240. Observer[T]: async generator to yield values from the observable
  241. """
  242. # Create the async generator with a unique ID
  243. return Observer(self, uuid1(), replay, debug_id=debug_id)
  244. async def _get_next(self, uuid: UUID) -> T:
  245. """Get the next per-observer value
  246. Args:
  247. uuid (UUID): observer identifier
  248. Raises:
  249. RuntimeError: Observer ended without receiving any values
  250. Returns:
  251. T: Latest per-observer value
  252. """
  253. while True:
  254. # If this is the first time entering, notify all on-subscribe listeners
  255. if self._count == 0:
  256. for action in self._on_subscribe_actions:
  257. if iscoroutinefunction(action):
  258. await action()
  259. action()
  260. # Acquire the condition and read the per-collector value
  261. async with self._shared_data:
  262. if uuid not in self._shared_data.q_dict:
  263. logger.error("Attempted to get value from a non-registered observer.")
  264. raise RuntimeError("Observer has not been added!")
  265. q = self._shared_data.q_dict[uuid]
  266. # Note! This can't be called inside shared data context as it will cause a deadlock. We've already retrieved
  267. # the q here which is itself coroutine-safe so just await it.
  268. value = await q.get()
  269. self._count += 1
  270. # If this is the first value, notify on start listeners
  271. if self._count == 1:
  272. for action in self._on_start_actions: # type: ignore
  273. self._mux_action(action, value) # type: ignore
  274. # Notify per-value actions
  275. for action in self._per_value_actions: # type: ignore
  276. action(value) # type: ignore
  277. # We've made it! Return the continuing value
  278. return value