util.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. # util.py/Open GoPro, Version 2.0 (C) Copyright 2021 GoPro, Inc. (http://gopro.com/OpenGoPro).
  2. # This copyright was auto-generated on Wed, Sep 1, 2021 5:05:50 PM
  3. """Miscellaneous utilities for the GoPro package."""
  4. from __future__ import annotations
  5. import argparse
  6. import asyncio
  7. import enum
  8. import logging
  9. import subprocess
  10. import sys
  11. from dataclasses import is_dataclass
  12. from datetime import datetime
  13. from pathlib import Path
  14. from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
  15. import pytz
  16. from construct import Container
  17. from pydantic import BaseModel
  18. from typing_extensions import TypeIs
  19. from tzlocal import get_localzone
  20. if TYPE_CHECKING:
  21. from _typeshed import DataclassInstance
  22. util_logger = logging.getLogger(__name__)
  23. class Singleton:
  24. """To be subclassed to create a singleton class."""
  25. _instances: dict[type[Singleton], Singleton] = {}
  26. def __new__(cls, *_: Any) -> Any: # noqa https://github.com/PyCQA/pydocstyle/issues/515
  27. if cls not in cls._instances:
  28. cls._instances[cls] = object.__new__(cls)
  29. return cls._instances[cls]
  30. def map_keys(obj: Any, key: str, func: Callable) -> None:
  31. """Map all matching keys (deeply searched) using the input function
  32. Args:
  33. obj (Any): object to modify in place
  34. key (str): key to search for to modify
  35. func (Callable): mapping function
  36. """
  37. if isinstance(obj, dict):
  38. for k in obj.keys():
  39. if k == key:
  40. obj[k] = func(obj[k])
  41. else:
  42. map_keys(obj[k], key, func)
  43. elif isinstance(obj, list):
  44. for i in obj:
  45. map_keys(i, key, func)
  46. else:
  47. # neither a dict nor a list, do nothing
  48. pass
  49. def scrub(obj: Any, bad_keys: list | None = None, bad_values: list | None = None) -> None:
  50. """Recursively scrub a collection (dict / list) of bad keys and / or bad values
  51. Args:
  52. obj (Any): collection to scrub
  53. bad_keys (list | None): Keys to remove. Defaults to None.
  54. bad_values (list | None): Values to remove. Defaults to None.
  55. Raises:
  56. ValueError: Missing bad keys / values
  57. """
  58. bad_keys = bad_keys or []
  59. bad_values = bad_values or []
  60. if not (bad_values or bad_keys):
  61. raise ValueError("Must pass either / or bad_keys or bad_values")
  62. def recurse(obj: Any) -> None:
  63. if isinstance(obj, dict):
  64. for key, value in {**obj}.items():
  65. if key in bad_keys or value in bad_values:
  66. del obj[key]
  67. else:
  68. recurse(obj[key])
  69. elif isinstance(obj, list):
  70. for i, value in enumerate(list(obj)):
  71. if value in bad_values:
  72. del obj[i]
  73. else:
  74. recurse(obj[i])
  75. else:
  76. # neither a dict nor a list, do nothing
  77. pass
  78. recurse(obj)
  79. def pretty_print(obj: Any, stringify_all: bool = True, should_quote: bool = True) -> str:
  80. """Recursively iterate through object and turn elements into strings
  81. Args:
  82. obj (Any): object to recurse through
  83. stringify_all (bool): At the end of each recursion, should the element be turned into a string?
  84. For example, should an int be turned into a str? Defaults to True.
  85. should_quote (bool): Should each element be surrounded in quotes?. Defaults to True.
  86. Returns:
  87. str: pretty-printed string
  88. """
  89. output = ""
  90. nest_level = 0
  91. def sanitize(e: Any) -> str:
  92. """Get the value part and replace any underscored with spaces
  93. Args:
  94. e (Any): argument to sanitize
  95. Returns:
  96. str: sanitized string
  97. """
  98. value_part = str(e).lower().split(".")[1]
  99. value_part = value_part.replace("_", " ").title()
  100. return value_part
  101. def stringify(elem: Any) -> Any:
  102. """Get the string value of an element if it is not a number (int, float, etc.)
  103. Args:
  104. elem (Any): element to potentially stringify
  105. Returns:
  106. Any: string representation or original object
  107. """
  108. def quote(elem: Any) -> Any:
  109. return f'"{elem}"' if should_quote else elem
  110. ret: str
  111. if isinstance(elem, (bytes, bytearray)):
  112. ret = quote(elem.hex(":"))
  113. if isinstance(elem, enum.Enum) and isinstance(elem, int):
  114. ret = quote(str(elem) if not stringify_all else sanitize(elem))
  115. if isinstance(elem, (bool, int, float)):
  116. ret = quote(elem) if stringify_all else elem # type: ignore
  117. ret = str(elem)
  118. return quote(ret)
  119. def recurse(elem: Any) -> None:
  120. """Recursion function
  121. Args:
  122. elem (Any): current element to work on
  123. """
  124. nonlocal output
  125. nonlocal nest_level
  126. indent_size = 4
  127. # Convert to dict if possible
  128. if isinstance(elem, BaseModel):
  129. elem = dict(elem)
  130. scrub(elem, bad_values=[None])
  131. if isinstance(elem, dict):
  132. # nested dictionary
  133. nest_level += 1
  134. output += "{"
  135. for k, v in elem.items():
  136. output += f"\n{' ' * (indent_size * nest_level)}"
  137. # Add key
  138. recurse(k)
  139. output += " : "
  140. # Add value
  141. if isinstance(v, (dict, list, BaseModel)):
  142. recurse(v)
  143. else:
  144. output += stringify(v)
  145. output += ","
  146. nest_level -= 1
  147. output += f"\n{' '* (indent_size * nest_level)}}}"
  148. elif isinstance(elem, list):
  149. # nested list
  150. nest_level += 1
  151. output += f"[\n{' '* (indent_size * nest_level)}"
  152. if len(elem):
  153. for item in elem[:-1]:
  154. recurse(item)
  155. output += ", "
  156. recurse(elem[-1])
  157. nest_level -= 1
  158. output += f"\n{' '* (indent_size * nest_level)}]"
  159. else:
  160. output += stringify(elem)
  161. recurse(obj)
  162. return output
  163. def cmd(command: str) -> str:
  164. """Send a command to the shell and return the result.
  165. Args:
  166. command (str): command to send
  167. Returns:
  168. str: response returned from shell
  169. """
  170. # We don't want password showing in the log
  171. if "sudo" in command:
  172. logged_command = command[: command.find('"') + 1] + "********" + command[command.find(" | sudo") - 1 :]
  173. else:
  174. logged_command = command
  175. util_logger.debug(f"Send cmd --> {logged_command}")
  176. response = (
  177. subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) # type: ignore
  178. .stdout.read()
  179. .decode(errors="ignore")
  180. )
  181. util_logger.debug(f"Receive response --> {response}")
  182. return response
  183. T = TypeVar("T")
  184. class SnapshotQueue(asyncio.Queue, Generic[T]):
  185. """A subclass of the default queue module to safely take a snapshot of the queue
  186. This is so we can access the elements (in a thread safe manner) without dequeuing them.
  187. """
  188. def __init__(self, maxsize: int = 0) -> None:
  189. self._lock = asyncio.Lock()
  190. super().__init__(maxsize)
  191. async def get(self) -> T:
  192. """Wrapper for passing generic type through to subclass
  193. Returns:
  194. T: type of this Snapshot queue
  195. """
  196. return await super().get()
  197. async def peek_front(self) -> T | None:
  198. """Get the first element without dequeueing it
  199. Returns:
  200. T | None: First element of None if the queue is empty
  201. """
  202. async with self._lock:
  203. return None if self.empty() else self._queue[0] # type: ignore
  204. def add_cli_args_and_parse(
  205. parser: argparse.ArgumentParser,
  206. bluetooth: bool = True,
  207. wifi: bool = True,
  208. ) -> argparse.Namespace:
  209. """Append common argparse arguments to an argument parser
  210. WARNING!! This will also parse the arguments (i.e. call parser.parse_args) so ensure to add any additional
  211. arguments to the parser before passing it to this function.
  212. Args:
  213. parser (argparse.ArgumentParser): input parser to modify
  214. bluetooth (bool): Add bluetooth args?. Defaults to True.
  215. wifi (bool): Add WiFi args?. Defaults to True.
  216. Returns:
  217. argparse.Namespace: modified argument parser
  218. """
  219. # Common args
  220. parser.add_argument(
  221. "--log",
  222. type=Path,
  223. help="Location to store detailed log. Defaults to gopro_demo.log",
  224. default="gopro_demo.log",
  225. )
  226. if bluetooth:
  227. parser.add_argument(
  228. "--identifier",
  229. type=str,
  230. help="Last 4 digits of GoPro serial number, which is the last 4 digits of the default camera SSID. \
  231. If not used, first discovered GoPro will be connected to",
  232. default=None,
  233. )
  234. if wifi:
  235. parser.add_argument(
  236. "--wifi_interface",
  237. type=str,
  238. help="System Wifi Interface. If not set, first discovered interface will be used.",
  239. default=None,
  240. )
  241. parser.add_argument(
  242. "--password",
  243. action="store_true",
  244. help="Set to read sudo password from stdin. If not set, you will be prompted for password if needed",
  245. )
  246. parser.epilog = "Note that a minimal log is written to stdout. An extremely detailed log is written to the path set by the --log argument."
  247. args = parser.parse_args()
  248. if wifi:
  249. args.password = sys.stdin.readline() if args.password else None
  250. return args
  251. async def ainput(string: str, printer: Callable | None = None) -> str:
  252. """Async version of input
  253. Raises:
  254. ValueError: Can not access default sys.stdout.write
  255. Args:
  256. string (str): prompt string
  257. printer (Callable | None): Printer used to display prompt. Defaults to None in which case sys.stdout.write
  258. will attempt to be used.
  259. Returns:
  260. str: Input read from console
  261. """
  262. if not printer:
  263. try:
  264. printer = sys.stdout.write
  265. except AttributeError as e:
  266. raise ValueError("No printer was passed and default standard out writer does not exist.") from e
  267. await asyncio.get_event_loop().run_in_executor(None, lambda s=string: printer(s + " ")) # type: ignore
  268. return await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
  269. def get_current_dst_aware_time() -> tuple[datetime, int, bool]:
  270. """Get the current time, utc offset in minutes, and daylight savings time
  271. Returns:
  272. tuple[datetime, int, bool]: [time, utc_offset in minutes, is_dst?]
  273. """
  274. tz = pytz.timezone(get_localzone().key) # type: ignore
  275. now = tz.localize(datetime.now(), is_dst=None)
  276. try:
  277. is_dst = now.tzinfo._dst.seconds != 0 # type: ignore
  278. offset = (now.utcoffset().total_seconds() - now.tzinfo._dst.seconds) / 60 # type: ignore
  279. except AttributeError:
  280. is_dst = False
  281. offset = now.utcoffset().total_seconds() / 60 # type: ignore
  282. if is_dst:
  283. offset += 60
  284. return (now, int(offset), is_dst)
  285. def deeply_update_dict(d: dict, u: dict) -> dict:
  286. """Recursively update a dict
  287. Args:
  288. d (dict): original dict
  289. u (dict): dict to apply updates from
  290. Returns:
  291. dict: updated original dict
  292. """
  293. for k, v in u.items():
  294. if isinstance(v, dict):
  295. d[k] = deeply_update_dict(d.get(k, {}), v)
  296. else:
  297. d[k] = v
  298. return d
  299. def to_dict(container: Container) -> dict:
  300. """Convert a parsed construct container to a dict, removing any internal Construct fields
  301. This is needed because annoyingly all construct's contain an "_io" field.
  302. See https://github.com/construct/construct/issues/1055
  303. Args:
  304. container (Container): container to convert
  305. Returns:
  306. dict: converted dict with any construct internal properties removed
  307. """
  308. d = dict(container)
  309. d.pop("_io", None)
  310. return d
  311. def is_dataclass_instance(obj: Any) -> TypeIs[DataclassInstance | type[DataclassInstance]] | bool:
  312. """Check if a given object is a dataclass instance
  313. Args:
  314. obj (Any): object to analyze
  315. Returns:
  316. TypeIs[DataclassInstance | type[DataclassInstance]] | bool: TypeIs from analysis
  317. """
  318. return is_dataclass(obj) and not isinstance(obj, type)