test_observables.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. # test_observables.py/Open GoPro, Version 2.0 (C) Copyright 2021 GoPro, Inc. (http://gopro.com/OpenGoPro).
  2. # This copyright was auto-generated on Mon May 12 23:03:50 UTC 2025
  3. import asyncio
  4. import re
  5. from calendar import c
  6. import pytest
  7. from asyncstdlib import anext, dropwhile, enumerate, filter, islice, map, takewhile
  8. from open_gopro.domain.gopro_observable import (
  9. GoProObservable,
  10. GoproObserverDistinctInitial,
  11. )
  12. from open_gopro.domain.observable import Observable
  13. from open_gopro.models.constants.statuses import StatusId
  14. from tests.mocks import MockWirelessGoPro
  15. async def test_base_observable():
  16. # GIVEN
  17. complete = asyncio.Event()
  18. started = asyncio.Event()
  19. observable = Observable[int]().on_subscribe(lambda: started.set())
  20. # WHEN
  21. async def emit_values():
  22. await started.wait()
  23. await observable.emit(0)
  24. async def single_get_values():
  25. observer = observable.observe()
  26. assert await anext(observer) == 0
  27. assert observable.current == 0
  28. complete.set()
  29. async with asyncio.TaskGroup() as tg:
  30. tg.create_task(emit_values())
  31. tg.create_task(single_get_values())
  32. await asyncio.wait_for(complete.wait(), 2)
  33. async def test_observable_with_default_replay():
  34. # GIVEN
  35. observable = Observable[int]()
  36. # WHEN
  37. await observable.emit(0)
  38. await observable.emit(1)
  39. result = await anext(observable.observe())
  40. # THEN
  41. assert result == 1
  42. async def test_observable_with_max_replay():
  43. # GIVEN
  44. observable = Observable[int]()
  45. results: list[int] = []
  46. observer = observable.observe(replay=Observable.REPLAY_ALL)
  47. # WHEN
  48. await observable.emit(0)
  49. await observable.emit(1)
  50. await observable.emit(2)
  51. results.append(await anext(observer))
  52. results.append(await anext(observer))
  53. results.append(await anext(observer))
  54. # THEN
  55. assert results == [0, 1, 2]
  56. async def test_observable_with_no_replay_times_out():
  57. # GIVEN
  58. observable = Observable[int]()
  59. results: list[int] = []
  60. observer = observable.observe(replay=0)
  61. # WHEN
  62. await observable.emit(0)
  63. await observable.emit(1)
  64. await observable.emit(2)
  65. # THEN
  66. with pytest.raises(asyncio.TimeoutError):
  67. await asyncio.wait_for(anext(observer), timeout=0.5)
  68. async def test_observable_on_start_sync_action():
  69. # GIVEN
  70. on_start = asyncio.Event()
  71. started = asyncio.Event()
  72. observable = Observable().on_subscribe(lambda: started.set())
  73. # WHEN
  74. async def emit_values():
  75. await started.wait()
  76. await observable.emit(0)
  77. observable = observable.on_start(lambda _: on_start.set())
  78. observer = observable.observe()
  79. async with asyncio.TaskGroup() as tg:
  80. tg.create_task(emit_values())
  81. single = tg.create_task(anext(observer))
  82. # THEN
  83. assert single.result() == 0
  84. assert await asyncio.wait_for(on_start.wait(), 1)
  85. assert observable.current == 0
  86. async def test_observable_on_start_async_action():
  87. # GIVEN
  88. on_start = asyncio.Event()
  89. started = asyncio.Event()
  90. observable = Observable().on_subscribe(lambda: started.set())
  91. # WHEN
  92. async def emit_values():
  93. await started.wait()
  94. await observable.emit(0)
  95. async def set_event_on_start(value: int) -> None:
  96. on_start.set()
  97. observable = observable.on_start(set_event_on_start)
  98. observer = observable.observe()
  99. async with asyncio.TaskGroup() as tg:
  100. tg.create_task(emit_values())
  101. single = tg.create_task(anext(observer))
  102. # THEN
  103. assert single.result() == 0
  104. assert await asyncio.wait_for(on_start.wait(), 1)
  105. assert observable.current == 0
  106. async def test_observable_take_while():
  107. # GIVEN
  108. started = asyncio.Event()
  109. observable = Observable[int]().on_subscribe(lambda: started.set())
  110. observer = observable.observe()
  111. received: list[int] = []
  112. # WHEN
  113. async def emit_values():
  114. await started.wait()
  115. await observable.emit(0)
  116. await observable.emit(1)
  117. await observable.emit(2)
  118. async def collector() -> None:
  119. async for value in takewhile(lambda x: x != 2, observer):
  120. received.append(value)
  121. async with asyncio.TaskGroup() as tg:
  122. tg.create_task(collector())
  123. tg.create_task(emit_values())
  124. # THEN
  125. assert len(received) == 2
  126. assert received[0] == 0
  127. assert received[1] == 1
  128. async def test_observable_drop_while():
  129. # GIVEN
  130. started = asyncio.Event()
  131. observable = Observable[int]().on_subscribe(lambda: started.set())
  132. observer = observable.observe()
  133. received: list[int] = []
  134. # WHEN
  135. async def emit_values():
  136. await started.wait()
  137. await observable.emit(0)
  138. await observable.emit(1)
  139. await observable.emit(2)
  140. await observable.emit(3)
  141. async def collector() -> None:
  142. async for value in dropwhile(lambda x: x < 2, observer):
  143. received.append(value)
  144. if value == 3:
  145. break
  146. async with asyncio.TaskGroup() as tg:
  147. tg.create_task(collector())
  148. tg.create_task(emit_values())
  149. # THEN
  150. assert len(received) == 2
  151. assert received[0] == 2
  152. assert received[1] == 3
  153. async def test_observable_first_matching():
  154. # GIVEN
  155. started = asyncio.Event()
  156. observable = Observable[int]().on_subscribe(lambda: started.set())
  157. observer = observable.observe()
  158. # WHEN
  159. async def emit_values():
  160. await started.wait()
  161. await observable.emit(0)
  162. await observable.emit(1)
  163. async with asyncio.TaskGroup() as tg:
  164. tg.create_task(emit_values())
  165. matched = tg.create_task(observer.first(lambda x: x == 1))
  166. # THEN
  167. assert matched.result() == 1
  168. async def test_observable_slice():
  169. # GIVEN
  170. started = asyncio.Event()
  171. observable = Observable[int]().on_subscribe(lambda: started.set())
  172. observer = observable.observe()
  173. collected: list[int] = []
  174. # WHEN
  175. async def emit_values():
  176. await started.wait()
  177. await observable.emit(0)
  178. await observable.emit(1)
  179. await observable.emit(2)
  180. await observable.emit(3)
  181. await observable.emit(4)
  182. async def collect():
  183. async for value in islice(observer, 1, 3):
  184. collected.append(value)
  185. async with asyncio.TaskGroup() as tg:
  186. tg.create_task(emit_values())
  187. tg.create_task(collect())
  188. # THEN
  189. assert collected == [1, 2]
  190. async def test_observable_map():
  191. # GIVEN
  192. started = asyncio.Event()
  193. observable = Observable[int]().on_subscribe(lambda: started.set())
  194. observer = observable.observe()
  195. collected: list[str] = []
  196. # WHEN
  197. async def emit_values():
  198. await started.wait()
  199. await observable.emit(0)
  200. await observable.emit(1)
  201. async def collect():
  202. async for idx, value in enumerate(map(lambda x: str(x), observer)):
  203. collected.append(value)
  204. if idx == 1:
  205. break
  206. async with asyncio.TaskGroup() as tg:
  207. tg.create_task(emit_values())
  208. tg.create_task(collect())
  209. # THEN
  210. assert collected == ["0", "1"]
  211. async def test_observable_filter():
  212. # GIVEN
  213. started = asyncio.Event()
  214. observable = Observable[int]().on_subscribe(lambda: started.set())
  215. observer = observable.observe()
  216. collected: list[int] = []
  217. # WHEN
  218. async def emit_values():
  219. await started.wait()
  220. await observable.emit(0)
  221. await observable.emit(1)
  222. await observable.emit(2)
  223. await observable.emit(3)
  224. await observable.emit(4)
  225. await observable.emit(5)
  226. async def collect():
  227. async for idx, value in enumerate(filter(lambda x: x % 2 == 0, observer)):
  228. collected.append(value)
  229. if idx == 2:
  230. break
  231. async with asyncio.TaskGroup() as tg:
  232. tg.create_task(emit_values())
  233. tg.create_task(collect())
  234. # THEN
  235. assert collected == [0, 2, 4]
  236. async def test_observable_filter_then_take_2():
  237. # GIVEN
  238. started = asyncio.Event()
  239. observable = Observable[int]().on_subscribe(lambda: started.set())
  240. observer = observable.observe()
  241. collected: list[int] = []
  242. # WHEN
  243. async def emit_values():
  244. await started.wait()
  245. await observable.emit(0)
  246. await observable.emit(1)
  247. await observable.emit(2)
  248. await observable.emit(3)
  249. await observable.emit(4)
  250. await observable.emit(5)
  251. async def collect():
  252. async for value in islice(filter(lambda x: x % 2 == 0, observer), 2):
  253. collected.append(value)
  254. async with asyncio.TaskGroup() as tg:
  255. tg.create_task(emit_values())
  256. tg.create_task(collect())
  257. # THEN
  258. assert collected == [0, 2]
  259. async def test_observable_map_then_filter():
  260. # GIVEN
  261. started = asyncio.Event()
  262. observable = Observable[int]().on_subscribe(lambda: started.set())
  263. observer = observable.observe()
  264. collected: list[int] = []
  265. # WHEN
  266. async def emit_values():
  267. await started.wait()
  268. await observable.emit(0)
  269. await observable.emit(1)
  270. await observable.emit(2)
  271. await observable.emit(3)
  272. async def collect():
  273. async for value in islice(filter(lambda x: x % 2 == 0, map(lambda x: x + 2, observer)), 2):
  274. collected.append(value)
  275. async with asyncio.TaskGroup() as tg:
  276. tg.create_task(emit_values())
  277. tg.create_task(collect())
  278. # THEN
  279. assert collected == [2, 4]
  280. async def test_observable_filter_then_map():
  281. # GIVEN
  282. started = asyncio.Event()
  283. observable = Observable[int]().on_subscribe(lambda: started.set())
  284. observer = observable.observe()
  285. collected: list[int] = []
  286. # WHEN
  287. async def emit_values():
  288. await started.wait()
  289. await observable.emit(0)
  290. await observable.emit(1)
  291. await observable.emit(2)
  292. await observable.emit(3)
  293. async def collect():
  294. async for value in islice(map(lambda x: x * 100, filter(lambda x: x >= 2, observer)), 2):
  295. collected.append(value)
  296. async with asyncio.TaskGroup() as tg:
  297. tg.create_task(emit_values())
  298. tg.create_task(collect())
  299. # THEN
  300. assert collected == [200, 300]
  301. async def test_take_then_take_only_takes_second():
  302. # GIVEN
  303. started = asyncio.Event()
  304. observable = Observable[int]().on_subscribe(lambda: started.set())
  305. observer = observable.observe()
  306. collected: list[int] = []
  307. # WHEN
  308. async def emit_values():
  309. await started.wait()
  310. await observable.emit(0)
  311. await observable.emit(1)
  312. await observable.emit(2)
  313. await observable.emit(3)
  314. async def collect():
  315. async for value in islice(islice(observer, 4), 2):
  316. collected.append(value)
  317. async with asyncio.TaskGroup() as tg:
  318. tg.create_task(emit_values())
  319. tg.create_task(collect())
  320. # THEN
  321. assert collected == [0, 1]
  322. async def test_simultaneous_collect():
  323. # GIVEN
  324. started = asyncio.Event()
  325. observable = Observable().on_subscribe(lambda: started.set())
  326. observer1 = observable.observe(replay=Observable.REPLAY_ALL)
  327. observer2 = observable.observe(replay=Observable.REPLAY_ALL)
  328. collected1: list[int] = []
  329. collected2: list[int] = []
  330. # WHEN
  331. async def emit_values():
  332. await started.wait()
  333. await observable.emit(0)
  334. await observable.emit(1)
  335. await observable.emit(2)
  336. await observable.emit(3)
  337. await observable.emit(4)
  338. async def collect_observer1():
  339. async for value in islice(observer1, 5):
  340. collected1.append(value)
  341. async def collect_observer2():
  342. async for value in islice(observer2, 5):
  343. collected2.append(value)
  344. async with asyncio.TaskGroup() as tg:
  345. tg.create_task(emit_values())
  346. tg.create_task(collect_observer1())
  347. tg.create_task(collect_observer2())
  348. # THEN
  349. assert collected1 == [0, 1, 2, 3, 4]
  350. assert collected2 == [0, 1, 2, 3, 4]
  351. async def test_status_observable_basic(mock_wireless_gopro_basic: MockWirelessGoPro):
  352. # GIVEN
  353. mock_wireless_gopro_basic._loop = asyncio.get_running_loop()
  354. started = asyncio.Event()
  355. observable = (
  356. await GoProObservable(
  357. gopro=mock_wireless_gopro_basic,
  358. update=StatusId.ENCODING,
  359. register_command=mock_wireless_gopro_basic.mock_gopro_resp(True),
  360. )
  361. .on_subscribe(lambda: started.set())
  362. .start()
  363. )
  364. observer = observable.observe()
  365. values: list[bool] = []
  366. def emit_status(encoding: bool):
  367. if encoding:
  368. payload = bytearray([0x05, 0x93, 0x00, StatusId.ENCODING.value, 0x01, 0x01])
  369. else:
  370. payload = bytearray([0x05, 0x93, 0x00, StatusId.ENCODING.value, 0x01, 0x00])
  371. mock_wireless_gopro_basic._notification_handler(0xFF, payload)
  372. # WHEN
  373. async def emit_statuses():
  374. await started.wait()
  375. emit_status(False)
  376. emit_status(True)
  377. emit_status(False)
  378. async def collect():
  379. async for value in islice(observer, 4):
  380. values.append(value)
  381. async with asyncio.TaskGroup() as tg:
  382. tg.create_task(emit_statuses())
  383. tg.create_task(collect())
  384. # THEN
  385. assert values == [True, False, True, False]
  386. async def test_status_observable_different_initial_response(mock_wireless_gopro_basic: MockWirelessGoPro):
  387. # GIVEN
  388. mock_wireless_gopro_basic._loop = asyncio.get_running_loop()
  389. started = asyncio.Event()
  390. observable = (
  391. await GoproObserverDistinctInitial(
  392. gopro=mock_wireless_gopro_basic,
  393. update=StatusId.ENCODING,
  394. register_command=mock_wireless_gopro_basic.ble_command.get_open_gopro_api_version(),
  395. )
  396. .on_subscribe(lambda: started.set())
  397. .start()
  398. )
  399. observer = observable.observe()
  400. def emit_status(encoding: bool):
  401. if encoding:
  402. payload = bytearray([0x05, 0x93, 0x00, StatusId.ENCODING.value, 0x01, 0x01])
  403. else:
  404. payload = bytearray([0x05, 0x93, 0x00, StatusId.ENCODING.value, 0x01, 0x00])
  405. mock_wireless_gopro_basic._notification_handler(0xFF, payload)
  406. # WHEN
  407. values: list[str | bool] = []
  408. async def emit_values():
  409. await started.wait()
  410. emit_status(True)
  411. emit_status(False)
  412. emit_status(True)
  413. emit_status(False)
  414. async def receive_values():
  415. async with observable:
  416. values.append(observable.initial_response)
  417. async for value in islice(observer, 4):
  418. values.append(value)
  419. async with asyncio.TaskGroup() as tg:
  420. tg.create_task(emit_values())
  421. tg.create_task(receive_values())
  422. # THEN
  423. assert values == ["2.0", True, False, True, False]