smartknob_io.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  if __name__ == "__main__":
      # This module only provides a library API; running it as a script just
      # exits with an explanatory message. It sits above the remaining imports,
      # so the message is shown even if those dependencies are not installed.
      raise SystemExit("This is a library file to be imported into your own python scripts. It doesn't do anything if run directly")
  4. from cobs import cobs
  5. from collections import (
  6. defaultdict,
  7. )
  8. from contextlib import contextmanager
  9. from enum import Enum
  10. import logging
  11. import os
  12. from queue import (
  13. Empty,
  14. Full,
  15. Queue,
  16. )
  17. from random import randint
  18. import serial
  19. import serial.tools.list_ports
  20. import sys
  21. from threading import (
  22. Thread,
  23. Lock,
  24. )
  25. import time
  26. import zlib
  # Absolute directory containing this file; the generated protobuf code is
  # imported from its 'proto_gen' subdirectory below.
  software_root = os.path.dirname(os.path.abspath(__file__))
  # Also put proto_gen itself on sys.path -- presumably so generated *_pb2
  # modules can import each other by bare module name (TODO confirm).
  sys.path.append(os.path.join(software_root, "proto_gen"))
  from proto_gen import smartknob_pb2

  # Baud rate used when opening the knob's serial port.
  SMARTKNOB_BAUD = 921_600
  # Stamped on every outgoing message and compared against incoming ones.
  PROTOBUF_PROTOCOL_VERSION = 1
  class Smartknob(object):
      """Threaded serial link to a SmartKnob exchanging COBS-framed protobuf messages.

      Frame format (both directions): COBS(payload || CRC32(payload) as 4
      little-endian bytes), followed by a single 0x00 terminator byte.

      Outgoing ``ToSmartknob`` messages are queued by the public helpers
      (``set_config``, ``request_state``) and sent by a background write
      thread, which retransmits each one every ``RETRY_TIMEOUT`` seconds until
      an ``ack`` with the same nonce is received. A background read thread
      decodes incoming ``FromSmartKnob`` frames and dispatches their payloads
      to handlers registered with ``add_handler``.
      """

      # Seconds to wait for a matching ack before resending a message.
      RETRY_TIMEOUT = 0.25

      def __init__(self, serial_instance):
          """Wrap an already-open serial port. Call ``start()`` to begin I/O.

          :param serial_instance: pyserial-like object providing
              ``read_until``, ``write`` and (for ``hard_reset``) ``setRTS`` /
              ``setDTR``.
          """
          self._serial = serial_instance
          self._logger = logging.getLogger('smartknob')
          # (nonce, cobs-encoded frame) tuples to send; None is a wake-up sentinel.
          self._out_q = Queue()
          # Nonces taken from received acks, consumed by the write loop.
          self._ack_q = Queue()
          # Random starting nonce -- presumably so a restarted host does not
          # immediately reuse nonces from a previous session (TODO confirm).
          self._next_nonce = randint(0, 255)
          # Cleared by shutdown(); both loops poll it.
          self._run = True
          # Guards _message_handlers.
          self._lock = Lock()
          # payload type name (or None = every payload) -> list of callbacks.
          self._message_handlers = defaultdict(list)

      def _read_loop(self):
          """Read-thread body: split the byte stream on 0x00 and process each frame.

          ``read_until`` may return without a terminator (e.g. when the port's
          read timeout expires), so bytes accumulate until a full frame has
          arrived. ``_run`` is checked after every read so ``shutdown()`` can
          stop the loop.
          """
          self._logger.debug('Read loop started')
          buffer = b''
          while True:
              buffer += self._serial.read_until(b'\0')
              if not self._run:
                  return
              if not buffer.endswith(b'\0'):
                  # Empty or partial read: keep accumulating.
                  continue
              self._process_frame(buffer[:-1])
              buffer = b''

      def _process_frame(self, frame):
          """Decode, CRC-check and parse one frame (without its 0x00 terminator).

          Frames that fail COBS decoding, are too short to hold a CRC, fail
          the CRC check or cannot be parsed are logged and dropped. One bad
          frame must not kill the read thread.
          """
          try:
              decoded = cobs.decode(frame)
          except cobs.DecodeError:
              self._logger.debug(f'Failed decode ({len(frame)} bytes)')
              self._logger.debug(frame)
              return
          if len(decoded) < 4:
              return
          # The trailing 4 bytes are the payload's CRC32, little-endian.
          payload = decoded[:-4]
          expected_crc = zlib.crc32(payload) & 0xffffffff
          provided_crc = int.from_bytes(decoded[-4:], 'little')
          if expected_crc != provided_crc:
              self._logger.debug(f'Bad CRC. expected={hex(expected_crc)}, actual={hex(provided_crc)}')
              return
          message = smartknob_pb2.FromSmartKnob()
          try:
              message.ParseFromString(payload)
          except Exception:
              # If this escaped, the read thread would die and the write loop
              # would keep retransmitting forever without ever seeing an ack.
              self._logger.warning('Failed to parse message from smartknob', exc_info=True)
              return
          self._logger.debug(message)
          if message.protocol_version != PROTOBUF_PROTOCOL_VERSION:
              self._logger.warning(f'Invalid protocol version. Expected {PROTOBUF_PROTOCOL_VERSION}, received {message.protocol_version}')
          self._dispatch_message(message)

      def _dispatch_message(self, message):
          """Route a parsed message's payload to the write loop (acks) and to handlers.

          Handlers registered for the payload type run first, then those
          registered for ``None``. Each receives the payload sub-message.
          Handlers are called outside ``_lock``, so they may add or remove
          handlers themselves without deadlocking.
          """
          payload_type = message.WhichOneof('payload')
          if payload_type is None:
              # No payload set (possibly a payload type this schema version
              # does not know); there is nothing to hand to handlers.
              self._logger.debug('Received message without a payload; ignoring')
              return
          if payload_type == 'ack':
              # Let the write thread stop retransmitting the acked message.
              self._ack_q.put(message.ack.nonce)
          with self._lock:
              # List concatenation yields a private snapshot. .get() avoids
              # inserting empty entries into the defaultdict.
              handlers = self._message_handlers.get(payload_type, []) + self._message_handlers.get(None, [])
          payload = getattr(message, payload_type)
          for handler in handlers:
              try:
                  handler(payload)
              except Exception:
                  self._logger.warning(f'Unhandled exception in message handler ({payload_type})', exc_info=True)

      def _write_loop(self):
          """Write-thread body: send queued messages one at a time until each is acked.

          A message is rewritten every ``RETRY_TIMEOUT`` seconds until an ack
          with its nonce arrives. Only then is the next message taken.
          ``shutdown()`` clears ``_run`` and puts a sentinel on both queues to
          wake this loop.
          """
          self._logger.debug('Write loop started')
          while True:
              data = self._out_q.get()
              # Check for shutdown
              if not self._run:
                  self._logger.debug('Write loop exiting @ _out_q')
                  return
              nonce, encoded_message = data
              # monotonic(): wall-clock adjustments must not stall or spam retries.
              next_retry = None
              while True:
                  if next_retry is None or time.monotonic() >= next_retry:
                      if next_retry is not None:
                          self._logger.debug('Retry write...')
                      self._serial.write(encoded_message)
                      self._serial.write(b'\0')
                      next_retry = time.monotonic() + Smartknob.RETRY_TIMEOUT
                  # Clamp at 0: the deadline can pass between the check above and
                  # this line, and Queue.get() raises ValueError on a negative timeout.
                  timeout = max(0.0, next_retry - time.monotonic())
                  try:
                      latest_ack_nonce = self._ack_q.get(timeout=timeout)
                  except Empty:
                      latest_ack_nonce = None
                  # Check for shutdown
                  if not self._run:
                      self._logger.debug('Write loop exiting @ _ack_q')
                      return
                  if latest_ack_nonce == nonce:
                      break
                  if latest_ack_nonce is not None:
                      self._logger.debug(f'Got unexpected nonce: {latest_ack_nonce}')

      def _enqueue_message(self, message):
          """Stamp ``message`` with the protocol version and a fresh nonce, frame it, and queue it.

          The frame layout matches what ``_process_frame`` expects:
          COBS(payload || CRC32 little-endian).
          """
          nonce = self._next_nonce
          self._next_nonce += 1
          message.protocol_version = PROTOBUF_PROTOCOL_VERSION
          message.nonce = nonce
          payload = bytearray(message.SerializeToString())
          crc = zlib.crc32(payload) & 0xffffffff
          payload += crc.to_bytes(4, 'little')
          encoded_message = cobs.encode(payload)
          self._out_q.put((nonce, encoded_message))
          approx_q_length = self._out_q.qsize()
          self._logger.debug(f'Out q length: {approx_q_length}')
          if approx_q_length > 10:
              self._logger.warning(f'Output queue length is high! ({approx_q_length}) Is the smartknob still connected and functional?')

      def set_config(self, config):
          """Queue ``config`` for sending (copied into ``ToSmartknob.smartknob_config``)."""
          message = smartknob_pb2.ToSmartknob()
          message.smartknob_config.CopyFrom(config)
          self._enqueue_message(message)

      def start(self):
          """Start the background read and write threads."""
          self.read_thread = Thread(target=self._read_loop, name='smartknob-read')
          self.write_thread = Thread(target=self._write_loop, name='smartknob-write')
          self.read_thread.start()
          self.write_thread.start()

      def shutdown(self):
          """Stop both background threads and wait for them to exit.

          Must only be called after ``start()``. The read thread notices the
          stop only once its current ``read_until`` returns (e.g. on the
          port's read timeout).
          """
          self._logger.info('Shutting down...')
          self._run = False
          self.read_thread.join()
          self._logger.debug('Read thread terminated')
          # Sentinels wake the write loop wherever it is blocked.
          self._out_q.put(None)
          self._ack_q.put(None)
          self.write_thread.join()
          self._logger.debug('Write thread terminated')

      def add_handler(self, message_type, handler):
          """Register ``handler`` for incoming payloads of ``message_type``.

          :param message_type: name of the ``FromSmartKnob`` ``payload`` oneof
              field (e.g. ``'ack'``, ``'log'``, ``'smartknob_state'``), or
              ``None`` to receive every payload.
          :param handler: callable taking the payload sub-message. It is
              invoked on the read thread; exceptions are logged and swallowed.
          :return: zero-argument callable that unregisters the handler. A
              message already being dispatched when it is called may still
              reach the handler once.
          """
          with self._lock:
              self._message_handlers[message_type].append(handler)
          return lambda: self._remove_handler(message_type, handler)

      def _remove_handler(self, message_type, handler):
          """Unregister ``handler`` (raises ValueError if it is not registered)."""
          with self._lock:
              self._message_handlers[message_type].remove(handler)

      def request_state(self):
          """Queue a ``request_state`` message.

          ``smartknob_context`` expects the knob to answer with a
          ``smartknob_state`` payload.
          """
          message = smartknob_pb2.ToSmartknob()
          message.request_state.SetInParent()
          self._enqueue_message(message)

      def hard_reset(self):
          """Toggle the port's RTS/DTR lines, sleeping 0.2 s after each DTR change.

          NOTE(review): presumably drives the board's auto-reset circuit via the
          USB-serial adapter -- confirm against the hardware.
          """
          self._serial.setRTS(True)
          self._serial.setDTR(False)
          time.sleep(0.2)
          self._serial.setDTR(True)
          time.sleep(0.2)
  @contextmanager
  def smartknob_context(serial_port, default_logging=True, wait_for_comms=True):
      """Open ``serial_port``, start a :class:`Smartknob` on it and yield it.

      :param serial_port: device name passed to ``serial.Serial``.
      :param default_logging: if true, forward the knob's ``log`` payloads to
          the ``smartknob`` logger at INFO level.
      :param wait_for_comms: if true, block until the knob answers a state
          request before yielding.

      The background threads are shut down on exit in every case. This
      includes the connection wait being interrupted (e.g. by Ctrl-C) and the
      ``with`` body raising.
      """
      with serial.Serial(serial_port, SMARTKNOB_BAUD, timeout=1.0) as ser:
          s = Smartknob(ser)
          s.start()
          # Everything after start() sits inside the try: a failure while
          # connecting must still stop the (non-daemon) threads, otherwise the
          # process cannot exit.
          try:
              if default_logging:
                  s.add_handler('log', lambda msg: s._logger.info(f'From smartknob: {msg.msg}'))
              if wait_for_comms:
                  s._logger.info('Connecting to smartknob...')
                  q = Queue(1)

                  def startup_handler(message):
                      # Only the first state reply matters; extra ones are dropped.
                      try:
                          q.put_nowait(None)
                      except Full:
                          pass

                  unregister = s.add_handler('smartknob_state', startup_handler)
                  try:
                      s.request_state()
                      q.get()
                  finally:
                      unregister()
                  s._logger.info('Connected!')
              yield s
          finally:
              s.shutdown()
  def ask_for_serial_port():
      """
      Helper function to ask which port to use via stdin.

      Prints the available serial ports (skipping those whose description is
      'n/a'), sorted by device name, and asks for an index.

      :return: the device name of the chosen port.
      :raises ValueError: if the answer is not an integer or is not the index
          of a listed port.
      """
      print('Available ports:')
      ports = sorted(
          (p for p in serial.tools.list_ports.comports() if p.description != 'n/a'),
          key=lambda p: p.device,
      )
      for i, port in enumerate(ports):
          print(f'[{i: 2}] {port.device} - {port.description}')
      print()
      value = input('Use which port? ')
      port_index = int(value)
      # Explicit check instead of assert: asserts are stripped under -O, and a
      # negative index would then silently pick a port from the end of the list.
      if not 0 <= port_index < len(ports):
          raise ValueError(f'Invalid port index {port_index}: {len(ports)} port(s) available')
      return ports[port_index].device