smartknob_io.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. if __name__ == '__main__':
  2. import sys
  3. sys.exit('This is a library file to be imported into your own python scripts. It doesn\'t do anything if run directly')
  4. from cobs import cobs
  5. from collections import (
  6. defaultdict,
  7. )
  8. from contextlib import contextmanager
  9. from enum import Enum
  10. import logging
  11. import os
  12. from queue import (
  13. Empty,
  14. Full,
  15. Queue,
  16. )
  17. from random import randint
  18. import serial
  19. import serial.tools.list_ports
  20. import sys
  21. from threading import (
  22. Thread,
  23. Lock,
  24. )
  25. import time
  26. import zlib
  27. software_root = os.path.dirname(os.path.abspath(__file__))
  28. sys.path.append(os.path.join(software_root, 'proto_gen'))
  29. from proto_gen import smartknob_pb2
  30. SMARTKNOB_BAUD = 921600
  31. class Smartknob(object):
  32. RETRY_TIMEOUT = 0.25
  33. def __init__(self, serial_instance):
  34. self._serial = serial_instance
  35. self._logger = logging.getLogger('smartknob')
  36. self._out_q = Queue()
  37. self._ack_q = Queue()
  38. self._next_nonce = randint(0, 255)
  39. self._run = True
  40. self._lock = Lock()
  41. self._message_handlers = defaultdict(list)
  42. def _read_loop(self):
  43. self._logger.debug('Read loop started')
  44. buffer = b''
  45. while True:
  46. buffer += self._serial.read_until(b'\0')
  47. if not self._run:
  48. return
  49. if not len(buffer):
  50. continue
  51. if not buffer.endswith(b'\0'):
  52. continue
  53. self._process_frame(buffer[:-1])
  54. buffer = b''
  55. def _process_frame(self, frame):
  56. try:
  57. decoded = cobs.decode(frame)
  58. except cobs.DecodeError:
  59. self._logger.debug(f'Failed decode ({len(frame)} bytes)')
  60. self._logger.debug(frame)
  61. return
  62. if len(decoded) < 4:
  63. return
  64. payload = decoded[:-4]
  65. expected_crc = zlib.crc32(payload) & 0xffffffff
  66. provided_crc = (decoded[-1] << 24) \
  67. | (decoded[-2] << 16) \
  68. | (decoded[-3] << 8) \
  69. | decoded[-4]
  70. if expected_crc != provided_crc:
  71. self._logger.debug(f'Bad CRC. expected={hex(expected_crc)}, actual={hex(provided_crc)}')
  72. return
  73. message = smartknob_pb2.FromSmartKnob()
  74. message.ParseFromString(payload)
  75. self._logger.debug(message)
  76. payload_type = message.WhichOneof('payload')
  77. # If this is an ack, notify the write thread
  78. if payload_type == 'ack':
  79. nonce = message.ack.nonce
  80. self._ack_q.put(nonce)
  81. with self._lock:
  82. for handler in self._message_handlers[payload_type] + self._message_handlers[None]:
  83. try:
  84. handler(getattr(message, payload_type))
  85. except:
  86. self._logger.warning(f'Unhandled exception in message handler ({payload_type})', exc_info=True)
  87. def _write_loop(self):
  88. self._logger.debug('Write loop started')
  89. while True:
  90. data = self._out_q.get()
  91. # Check for shutdown
  92. if not self._run:
  93. self._logger.debug('Write loop exiting @ _out_q')
  94. return
  95. (nonce, encoded_message) = data
  96. next_retry = 0
  97. while True:
  98. if time.time() >= next_retry:
  99. if next_retry > 0:
  100. self._logger.debug('Retry write...')
  101. self._serial.write(encoded_message)
  102. self._serial.write(b'\0')
  103. next_retry = time.time() + Smartknob.RETRY_TIMEOUT
  104. try:
  105. latest_ack_nonce = self._ack_q.get(timeout=next_retry - time.time())
  106. except Empty:
  107. latest_ack_nonce = None
  108. # Check for shutdown
  109. if not self._run:
  110. self._logger.debug('Write loop exiting @ _ack_q')
  111. return
  112. if latest_ack_nonce == nonce:
  113. break
  114. else:
  115. self._logger.debug(f'Got unexpected nonce: {latest_ack_nonce}')
  116. def _enqueue_message(self, message):
  117. nonce = self._next_nonce
  118. self._next_nonce += 1
  119. message.nonce = nonce
  120. payload = bytearray(message.SerializeToString())
  121. crc = zlib.crc32(payload) & 0xffffffff
  122. payload.append(crc & 0xff)
  123. payload.append((crc >> 8) & 0xff)
  124. payload.append((crc >> 16) & 0xff)
  125. payload.append((crc >> 24) & 0xff)
  126. encoded_message = cobs.encode(payload)
  127. self._out_q.put((nonce, encoded_message))
  128. approx_q_length = self._out_q.qsize()
  129. self._logger.debug(f'Out q length: {approx_q_length}')
  130. if approx_q_length > 10:
  131. self._logger.warning(f'Output queue length is high! ({approx_q_length}) Is the smartknob still connected and functional?')
  132. def set_config(self, config):
  133. message = smartknob_pb2.ToSmartknob()
  134. message.smartknob_config.CopyFrom(config)
  135. self._enqueue_message(message)
  136. def start(self):
  137. self.read_thread = Thread(target=self._read_loop)
  138. self.write_thread = Thread(target=self._write_loop)
  139. self.read_thread.start()
  140. self.write_thread.start()
  141. def shutdown(self):
  142. self._logger.info('Shutting down...')
  143. self._run = False
  144. self.read_thread.join()
  145. self._logger.debug('Read thread terminated')
  146. self._out_q.put(None)
  147. self._ack_q.put(None)
  148. self.write_thread.join()
  149. self._logger.debug('Write thread terminated')
  150. def add_handler(self, message_type, handler):
  151. with self._lock:
  152. self._message_handlers[message_type].append(handler)
  153. return lambda: self._remove_handler(message_type, handler)
  154. def _remove_handler(self, message_type, handler):
  155. with self._lock:
  156. self._message_handlers[message_type].remove(handler)
  157. def request_state(self):
  158. message = smartknob_pb2.ToSmartknob()
  159. message.request_state.SetInParent()
  160. self._enqueue_message(message)
  161. def hard_reset(self):
  162. self._serial.setRTS(True)
  163. self._serial.setDTR(False)
  164. time.sleep(0.2)
  165. self._serial.setDTR(True)
  166. time.sleep(0.2)
  167. @contextmanager
  168. def smartknob_context(serial_port, default_logging=True, wait_for_comms=True):
  169. with serial.Serial(serial_port, SMARTKNOB_BAUD, timeout=1.0) as ser:
  170. s = Smartknob(ser)
  171. s.start()
  172. if default_logging:
  173. s.add_handler('log', lambda msg: s._logger.info(f'From smartknob: {msg.msg}'))
  174. if wait_for_comms:
  175. s._logger.info('Connecting to smartknob...')
  176. q = Queue(1)
  177. def startup_handler(message):
  178. try:
  179. q.put_nowait(None)
  180. except Full:
  181. pass
  182. unregister = s.add_handler('smartknob_state', startup_handler)
  183. s.request_state()
  184. q.get()
  185. unregister()
  186. s._logger.info('Connected!')
  187. try:
  188. yield s
  189. finally:
  190. s.shutdown()
  191. def ask_for_serial_port():
  192. """
  193. Helper function to ask which port to use via stdin
  194. """
  195. print('Available ports:')
  196. ports = sorted(
  197. filter(
  198. lambda p: p.description != 'n/a',
  199. serial.tools.list_ports.comports(),
  200. ),
  201. key=lambda p: p.device,
  202. )
  203. for i, port in enumerate(ports):
  204. print('[{: 2}] {} - {}'.format(i, port.device, port.description))
  205. print()
  206. value = input('Use which port? ')
  207. port_index = int(value)
  208. assert 0 <= port_index < len(ports)
  209. return ports[port_index].device