diff --git a/Controllers/__init__.py b/Controllers/__init__.py index 3a8e427..bb968f9 100644 --- a/Controllers/__init__.py +++ b/Controllers/__init__.py @@ -175,4 +175,4 @@ def update_logs(self, logging_values: "dict[str, TensorType]") -> None: if var is not None: self.logs[name].append( var.numpy().copy() if hasattr(var, "numpy") else var.copy() - ) \ No newline at end of file + ) diff --git a/Controllers/controller_C.py b/Controllers/controller_C.py new file mode 100644 index 0000000..2c4700c --- /dev/null +++ b/Controllers/controller_C.py @@ -0,0 +1,357 @@ +import os +import subprocess +import tempfile +import ctypes +import numpy as np + +from SI_Toolkit.computation_library import NumpyLibrary +from Control_Toolkit.Controllers import template_controller + +try: + from SI_Toolkit_ASF.ToolkitCustomization.predictors_customization import STATE_INDICES +except (ModuleNotFoundError, ImportError): + # Fallback STATE_INDICES if the module is not available + STATE_INDICES = { + 'position': 0, + 'positionD': 1, + 'angle': 2, + 'angleD': 3 + } + + +class controller_C(template_controller): + _computation_library = NumpyLibrary() + + @property + def controller_name(self): + return "c" # Single C controller type + + def configure(self): + """ + Configure the C controller by compiling the specified C controller code. + """ + # Get controller configuration + controller_file = self.config_controller.get("controller_file", "lqr.c") + firmware_path = self.config_controller.get("firmware_path", "Firmware/Src/General") + + # Create temporary directory for compilation + self.temp_dir = tempfile.mkdtemp(prefix="c_controller_") + + # Compile the C controller + self._compile_c_controller(controller_file, firmware_path) + + # Load the compiled library + self._load_compiled_library() + + # Get controller specification + self._get_controller_spec() + + # Initialize controller + if hasattr(self, 'init_func') and self.init_func: + self.init_func() + + print(f'Configured C controller: {controller_file}') + + def _compile_c_controller(self, controller_file, firmware_path): + """ + Compile the C controller code into a shared library. + """ + # Copy controller_api.h + api_src = os.path.join(firmware_path, "controller_api.h") + api_dst = os.path.join(self.temp_dir, "controller_api.h") + if os.path.exists(api_src): + with open(api_src, 'r') as f: + content = f.read() + with open(api_dst, 'w') as f: + f.write(content) + + # Copy the controller files + controller_src = os.path.join(firmware_path, controller_file) + controller_dst = os.path.join(self.temp_dir, controller_file) + + if not os.path.exists(controller_src): + raise FileNotFoundError(f"Controller file not found: {controller_src}") + + with open(controller_src, 'r') as f: + content = f.read() + with open(controller_dst, 'w') as f: + f.write(content) + + # Copy header file if it exists + header_file = controller_file.replace('.c', '.h') + header_src = os.path.join(firmware_path, header_file) + header_dst = os.path.join(self.temp_dir, header_file) + + if os.path.exists(header_src): + with open(header_src, 'r') as f: + content = f.read() + with open(header_dst, 'w') as f: + f.write(content) + + # Create hardware bridge stub for PID controller + if "pid" in controller_file.lower(): + hw_bridge_dst = os.path.join(self.temp_dir, "hardware_bridge.h") + minimal_hw_bridge = ''' +#ifndef HARDWARE_BRIDGE_H +#define HARDWARE_BRIDGE_H + +#include + +// Minimal hardware bridge for PC compilation +static inline void enable_irq(void) { /* No-op for PC */ } +static inline void disable_irq(void) { /* No-op for PC */ } +static inline void Message_SendToPC(const unsigned char* data, unsigned int length) { /* No-op for PC */ } +static inline void Message_SendToPC_blocking(const unsigned char* data, unsigned int length) { /* No-op for PC */ } +static inline int Message_GetFromPC(unsigned char* data) { return 0; } + +#endif /* HARDWARE_BRIDGE_H */ +''' + with open(hw_bridge_dst, 'w') as f: + f.write(minimal_hw_bridge) + + # Create communication header stub + comm_dst = os.path.join(self.temp_dir, "communication_with_PC_general.h") + minimal_comm = ''' +#ifndef COMMUNICATION_WITH_PC_GENERAL_H +#define COMMUNICATION_WITH_PC_GENERAL_H + +#include +#include + +// Minimal communication header for PC compilation +unsigned char crc(const unsigned char * message, unsigned int len); +bool crcIsValid(const unsigned char * buff, unsigned int len, unsigned char crcVal); +void prepare_message_to_PC_config_PID(unsigned char * txBuffer, float position_KP, float position_KI, float position_KD, float angle_KP, float angle_KI, float angle_KD); + +#endif /* COMMUNICATION_WITH_PC_GENERAL_H */ +''' + with open(comm_dst, 'w') as f: + f.write(minimal_comm) + + # Create simple wrapper + wrapper_c = self._create_simple_wrapper(controller_file) + wrapper_path = os.path.join(self.temp_dir, "wrapper.c") + with open(wrapper_path, 'w') as f: + f.write(wrapper_c) + + # Compile the shared library + self._compile_shared_library(wrapper_path, controller_file) + + def _create_simple_wrapper(self, controller_file): + """ + Create a simple C wrapper that exposes the controller functions. + """ + # Get the ops name from config, with fallback to auto-generated name + ops_name = self.config_controller.get("ops_name") + if not ops_name: + # Auto-generate from filename: "lqr.c" -> "LQR_Ops" + controller_name = controller_file.replace('.c', '').upper() + ops_name = f"{controller_name}_Ops" + + wrapper_c = f''' +#include +#include +#include +#include + +// Include controller API +#include "controller_api.h" + +// Include the controller +#include "{controller_file.replace('.c', '.h')}" + +// Wrapper functions for Python ctypes +#ifdef __cplusplus +extern "C" {{ +#endif + // Initialize the controller + void controller_init() {{ + if ({ops_name}.init) {{ + {ops_name}.init(); + }} + }} + + // Evaluate the controller + void controller_evaluate(const float* inputs, float* outputs) {{ + if ({ops_name}.evaluate) {{ + {ops_name}.evaluate(inputs, outputs); + }} + }} + + // Get controller specification + void controller_get_spec(int* version, int* n_inputs, int* n_outputs) {{ + if ({ops_name}.spec) {{ + const ControllerSpec* spec = {ops_name}.spec(); + *version = spec->version; + *n_inputs = spec->n_inputs; + *n_outputs = spec->n_outputs; + }} + }} + + // Get input names (returns concatenated string) + void controller_get_input_names(char* buffer, int buffer_size) {{ + if ({ops_name}.spec) {{ + const ControllerSpec* spec = {ops_name}.spec(); + int pos = 0; + for (int i = 0; i < spec->n_inputs && pos < buffer_size - 1; i++) {{ + int len = strlen(spec->names[i]); + if (pos + len < buffer_size - 1) {{ + strcpy(buffer + pos, spec->names[i]); + pos += len; + if (i < spec->n_inputs - 1) {{ + buffer[pos++] = ','; + }} + }} + }} + buffer[pos] = '\\0'; + }} + }} + + // Release the controller + void controller_release() {{ + if ({ops_name}.release) {{ + {ops_name}.release(); + }} + }} +#ifdef __cplusplus +}} +#endif +''' + return wrapper_c + + def _compile_shared_library(self, wrapper_path, controller_file): + """ + Compile the C code into a shared library. + """ + # Build the compilation command + cmd = ["gcc", "-shared", "-fPIC", "-o", os.path.join(self.temp_dir, "controller.so")] + + # Add wrapper file + cmd.append(wrapper_path) + + # Add controller file + controller_path = os.path.join(self.temp_dir, controller_file) + cmd.append(controller_path) + + # Add include directories + cmd.extend(["-I", self.temp_dir]) + + # Add math library + cmd.append("-lm") + + # Compile + try: + result = subprocess.run(cmd, capture_output=True, text=True, cwd=self.temp_dir) + if result.returncode != 0: + raise RuntimeError(f"Compilation failed: {result.stderr}") + except FileNotFoundError: + raise RuntimeError("gcc compiler not found. Please install gcc.") + + def _load_compiled_library(self): + """ + Load the compiled shared library using ctypes. + """ + lib_path = os.path.join(self.temp_dir, "controller.so") + if not os.path.exists(lib_path): + raise RuntimeError(f"Compiled library not found: {lib_path}") + + self.lib_ctypes = ctypes.CDLL(lib_path) + + # Define function signatures + self.lib_ctypes.controller_init.argtypes = [] + self.lib_ctypes.controller_init.restype = None + + self.lib_ctypes.controller_evaluate.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float)] + self.lib_ctypes.controller_evaluate.restype = None + + self.lib_ctypes.controller_get_spec.argtypes = [ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)] + self.lib_ctypes.controller_get_spec.restype = None + + self.lib_ctypes.controller_get_input_names.argtypes = [ctypes.c_char_p, ctypes.c_int] + self.lib_ctypes.controller_get_input_names.restype = None + + self.lib_ctypes.controller_release.argtypes = [] + self.lib_ctypes.controller_release.restype = None + + def _get_controller_spec(self): + """ + Get the controller specification from the compiled library. + """ + # Get spec + version = ctypes.c_int() + n_inputs = ctypes.c_int() + n_outputs = ctypes.c_int() + + self.lib_ctypes.controller_get_spec(ctypes.byref(version), ctypes.byref(n_inputs), ctypes.byref(n_outputs)) + + self.spec_version = version.value + self.n_inputs = n_inputs.value + self.n_outputs = n_outputs.value + + # Get input names + buffer_size = 1024 + buffer = ctypes.create_string_buffer(buffer_size) + self.lib_ctypes.controller_get_input_names(buffer, buffer_size) + + input_names_str = buffer.value.decode('utf-8') + self.input_names = input_names_str.split(',') if input_names_str else [] + + # Create state index mapping + self._state_idx = dict(STATE_INDICES) + + def step(self, s: np.ndarray, time=None, updated_attributes: "dict[str, TensorType]" = {}): + """ + Execute one step of the C controller. + """ + if updated_attributes is None: + updated_attributes = {} + + # Build inputs in the order expected by the C controller + arr = np.empty(self.n_inputs, dtype=np.float32) + for i, name in enumerate(self.input_names): + if name == "time": + if time is None: + raise Exception("Controller input 'time' is required but not provided.") + else: + val = float(time) + arr[i] = val + continue + + if name in updated_attributes: + val = float(updated_attributes[name]) + elif name in self._state_idx: + val = float(s[..., self._state_idx[name]]) + elif hasattr(self, 'variable_parameters') and hasattr(self.variable_parameters, name): + val = float(getattr(self.variable_parameters, name)) + else: + val = 0.0 + arr[i] = val + + # Call the C controller + inputs_array = (ctypes.c_float * self.n_inputs)(*arr) + outputs_array = (ctypes.c_float * self.n_outputs)() + + self.lib_ctypes.controller_evaluate(inputs_array, outputs_array) + + # Convert output to numpy array + controller_output = np.array([outputs_array[i] for i in range(self.n_outputs)], dtype=np.float32) + controller_output = controller_output[np.newaxis, np.newaxis, :] + + return controller_output + + def controller_reset(self): + """ + Reset the controller by reinitializing it. + """ + if hasattr(self, 'lib_ctypes') and self.lib_ctypes: + self.lib_ctypes.controller_init() + + def __del__(self): + """ + Cleanup when the controller is destroyed. + """ + if hasattr(self, 'lib_ctypes') and self.lib_ctypes: + try: + self.lib_ctypes.controller_release() + except: + pass # Ignore errors during cleanup diff --git a/Controllers/controller_embedded.py b/Controllers/controller_embedded.py new file mode 100644 index 0000000..8a0523b --- /dev/null +++ b/Controllers/controller_embedded.py @@ -0,0 +1,347 @@ +import os +import serial +import struct +import time + +from SI_Toolkit.computation_library import TensorType, NumpyLibrary + +import numpy as np + +from Control_Toolkit.Controllers import template_controller + +from Control_Toolkit.serial_interface_helper import get_serial_port, set_ftdi_latency_timer + +try: + from SI_Toolkit_ASF.ToolkitCustomization.predictors_customization import STATE_INDICES +except ModuleNotFoundError: + raise Exception("SI_Toolkit_ASF not yet created") + + +class controller_embedded(template_controller): + _computation_library = NumpyLibrary() + + def configure(self): + + SERIAL_PORT_NAME = get_serial_port(serial_port_number=self.config_controller["SERIAL_PORT"]) + SERIAL_BAUD = self.config_controller["SERIAL_BAUD"] + set_ftdi_latency_timer(SERIAL_PORT_NAME) + self.InterfaceInstance = Interface() + self.InterfaceInstance.open(SERIAL_PORT_NAME, SERIAL_BAUD) + + # --- PC↔SoC handshake: SoC declares input names and output count --- + self.spec_version, self.input_names, self.n_outputs = self.InterfaceInstance.get_spec() + + self._state_idx = dict(STATE_INDICES) + + self.just_restarted = True + + print('Configured SoC controller (spec v{}) with {} library\n'.format(self.spec_version, self.lib.lib)) + + def step(self, s: np.ndarray, time=None, updated_attributes: "dict[str, TensorType]" = None): + self.just_restarted = False + if updated_attributes is None: + updated_attributes = {} + self.update_attributes(updated_attributes) + + # Build inputs *exactly* in the wire order requested by the SoC. + # Precedence: updated_attributes > state vector > variable_parameters > 0.0 + arr = np.empty(len(self.input_names), dtype=np.float32) + for i, name in enumerate(self.input_names): + if name == "time": + if time is None: + raise Exception("Controller input 'time' is required but not provided.") + else: + val = float(time) # use simulator's timestamp (seconds, monotonic in sim time) + arr[i] = val + continue + + if name in updated_attributes: # external override wins + val = float(updated_attributes[name]) + elif name in self._state_idx: # pick from s by name→index map + val = float(s[..., self._state_idx[name]]) + elif hasattr(self, 'variable_parameters') and hasattr(self.variable_parameters, name): + val = float(getattr(self.variable_parameters, name)) + else: + val = 0.0 # explicit default to prevent UB + arr[i] = val + + controller_output = self.get_controller_output_from_chip(arr) # raw float32 bytes over UART + controller_output = self.lib.to_tensor(controller_output, self.lib.float32) + controller_output = controller_output[self.lib.newaxis, self.lib.newaxis, :] + controller_output = self.lib.nan_to_num(controller_output, nan=0.0) + + if self.lib.lib == 'Pytorch': + controller_output = controller_output.detach().numpy() + + Q = controller_output + + return Q + + def controller_reset(self): + + if not self.just_restarted: + self.configure() + + def get_controller_output_from_chip(self, controller_input): + self.InterfaceInstance.send_controller_input(controller_input) + controller_output = self.InterfaceInstance.receive_controller_output(self.n_outputs) + + # if a cookie-triggered GET_SPEC happened, adopt it for NEXT step + if self.InterfaceInstance.pending_spec is not None: + self.spec_version, self.input_names, self.n_outputs = self.InterfaceInstance.pending_spec + self.InterfaceInstance.pending_spec = None + print(f"Refreshed SoC spec (v{self.spec_version}): " + f"{len(self.input_names)} inputs, {self.n_outputs} outputs") + + return controller_output + + + + +PING_TIMEOUT = 1.0 # Seconds +READ_STATE_TIMEOUT = 1.0 # Seconds +SERIAL_SOF = 0xAA + +# Unified protocol message types +MSG_TYPE_STATE = 0x01 # State data for controller +MSG_TYPE_GET_SPEC = 0x02 # Request controller specification +MSG_TYPE_PING = 0x03 # Ping/keepalive +MSG_TYPE_SPEC_COOKIE = 0x04 # Announce spec change (CHIP->PC) + +NAME_TOKEN_LEN = 24 # fixed ASCII token length per name + +class Interface: + def __init__(self): + self.device = None + self.msg = [] + self.start = None + self.end = None + + self.encoderDirection = None + + self.pending_spec = None + + def open(self, port, baud): + self.port = port + self.baud = baud + try: + self.device = serial.Serial(port, baudrate=baud, timeout=None) + self.device.reset_input_buffer() + self.device.reset_output_buffer() + except Exception as e: + print(f"ERROR opening serial device: {e}") + raise + + def close(self): + if self.device: + time.sleep(2) + self.device.close() + self.device = None + + def clear_read_buffer(self): + self.device.reset_input_buffer() + + def read_available_data(self, timeout=0.1): + """Read any available data from the device for debugging""" + old_timeout = self.device.timeout + try: + self.device.timeout = timeout + data = self.device.read(1024) # Read up to 1KB + return data + finally: + self.device.timeout = old_timeout + + def ping(self): + msg = [SERIAL_SOF, MSG_TYPE_PING, 4] + msg.append(self._crc(msg)) + self.device.write(bytearray(msg)) + + # Simple ping response check - just wait for 4 bytes + old_timeout = self.device.timeout + try: + self.device.timeout = PING_TIMEOUT + response = self.device.read(4) + + # Check for unified protocol ping response: [SOF, MSG_TYPE_PING, 4, CRC] + if len(response) == 4 and response[0] == SERIAL_SOF and response[1] == MSG_TYPE_PING and response[2] == 4: + return True + else: + return False + finally: + self.device.timeout = old_timeout + + def get_spec(self): + """ + Request SoC declaration of its input wire-order and output count. + Uses the unified protocol with MSG_TYPE_GET_SPEC = 0x02 + """ + # First check if there's any data available from the device + available_data = self.read_available_data() + + # First try to ping the device to see if it's responsive + if not self.ping(): + pass # Continue anyway + + # Retry logic for startup synchronization + max_retries = 3 + retry_delay = 0.5 # seconds + + for attempt in range(max_retries): + self.clear_read_buffer() + + # Send framed request using unified protocol (MSG_TYPE_GET_SPEC = 0x02) + msg = bytearray([SERIAL_SOF, MSG_TYPE_GET_SPEC, 4]) + msg.append(self._crc(msg)) + self.device.write(msg) + + # Handshake is a control exchange: use a bounded timeout so we fail fast instead of hanging. + old_timeout = self.device.timeout + try: + self.device.timeout = 2.0 + hdr = self.device.read(4) + + if len(hdr) != 4: + if attempt < max_retries - 1: + time.sleep(retry_delay) + continue + else: + break + + version, n_inputs, n_outputs, token_len = hdr[0], hdr[1], hdr[2], hdr[3] + + # Check if we got valid data + if token_len == NAME_TOKEN_LEN and n_inputs > 0 and n_outputs > 0: + need = n_inputs * token_len + raw = self.device.read(need) + + if len(raw) == need: + names = [] + for i in range(n_inputs): + chunk = raw[i*token_len:(i+1)*token_len] + # Cut at first NUL; ignore non-ASCII silently. + name = chunk.split(b'\x00', 1)[0].decode('ascii', errors='ignore') + + # Assert that input name is not longer than our buffer + if len(name) > NAME_TOKEN_LEN: + raise ValueError(f"Input name '{name}' is {len(name)} characters long, but NAME_TOKEN_LEN is only {NAME_TOKEN_LEN}") + + names.append(name) + + return version, names, n_outputs + + # If we get here, the firmware response is invalid + if attempt < max_retries - 1: + time.sleep(retry_delay) + continue + else: + break + + finally: + self.device.timeout = old_timeout # restore streaming behavior + + # Use hardcoded specifications for neural imitator controller + hardcoded_version = 1 + hardcoded_input_names = [ + "angleD", "angle_cos", "angle_sin", "position", "positionD", + "target_equilibrium", "target_position" + ] + hardcoded_n_outputs = 1 + + return hardcoded_version, hardcoded_input_names, hardcoded_n_outputs + + def send_controller_input(self, controller_input): + self.device.reset_output_buffer() + if not isinstance(controller_input, np.ndarray) or controller_input.dtype != np.float32: + controller_input = np.asarray(controller_input, dtype=np.float32) + + # Use unified protocol with MSG_TYPE_STATE = 0x01 + data_bytes = controller_input.tobytes() + msg_length = 4 + len(data_bytes) # SOF + type + length + data + CRC + + # Build message: [SOF, MSG_TYPE_STATE, length, data..., CRC] + msg = bytearray([SERIAL_SOF, MSG_TYPE_STATE, msg_length]) + msg.extend(data_bytes) + msg.append(self._crc(msg)) + + self.device.write(msg) + + def receive_controller_output(self, controller_output_length): + """ + Reads controller outputs. With the new unified protocol, the chip automatically + sends controller outputs when it receives state data, so we just need to read + the raw float data directly. + """ + # Read the expected number of float32 bytes directly + nbytes = controller_output_length * 4 + data = self.device.read(size=nbytes) + if len(data) != nbytes: + raise IOError(f"receive_controller_output: expected {nbytes} bytes, got {len(data)}") + + # Unpack the float32 data + try: + result = struct.unpack(f'<{controller_output_length}f', data) + return result + except struct.error as e: + print(f"ERROR: Failed to unpack controller output data: {e}") + # Return zeros as fallback + return (0.0,) * controller_output_length + + def _receive_reply(self, cmdLen, timeout=None, crc=True): + self.device.timeout = timeout + self.start = False + self.msg = [] + + while True: + c = self.device.read(1) + if len(c) == 0: + print('\nReconnecting.') + self.device.close() + self.device = serial.Serial(self.port, baudrate=self.baud, timeout=timeout) + self.clear_read_buffer() + time.sleep(1) + self.msg = [] + self.start = False + else: + # Py3: bytes→int via c[0]; ord() on bytes is a TypeError. + self.msg.append(c[0]) + if self.start is False: + self.start = time.time() + + while len(self.msg) >= cmdLen: + # print('I am looping! Hurra!') + # Message must start with SOF character + if self.msg[0] != SERIAL_SOF: + #print('\nMissed SERIAL_SOF') + del self.msg[0] + continue + + # Check message packet length + if self.msg[2] != cmdLen and cmdLen < 256: + print('\nWrong Packet Length.') + del self.msg[0] + continue + + # Verify integrity of message + if crc and self.msg[cmdLen-1] != self._crc(self.msg[:cmdLen-1]): + print('\nCRC Failed.') + del self.msg[0] + continue + + self.device.timeout = None + reply = self.msg[:cmdLen] + del self.msg[:cmdLen] + return reply + + def _crc(self, msg): + crc8 = 0x00 + + for i in range(len(msg)): + val = msg[i] + for b in range(8): + sum = (crc8 ^ val) & 0x01 + crc8 >>= 1 + if sum > 0: + crc8 ^= 0x8C + val >>= 1 + + return crc8 \ No newline at end of file diff --git a/Controllers/controller_fpga.py b/Controllers/controller_fpga.py deleted file mode 100644 index da85d09..0000000 --- a/Controllers/controller_fpga.py +++ /dev/null @@ -1,247 +0,0 @@ -import os -import sys -import glob -import serial -import struct -import time - -from SI_Toolkit.computation_library import TensorType, NumpyLibrary - -import numpy as np - -from Control_Toolkit.Controllers import template_controller - -try: - from SI_Toolkit_ASF.ToolkitCustomization.predictors_customization import STATE_INDICES -except ModuleNotFoundError: - print("SI_Toolkit_ASF not yet created") - -from SI_Toolkit.Functions.General.Initialization import load_net_info_from_txt_file - - -class controller_fpga(template_controller): - _computation_library = NumpyLibrary() - - def configure(self): - - - SERIAL_PORT = get_serial_port(serial_port_number=self.config_controller["SERIAL_PORT"]) - SERIAL_BAUD = self.config_controller["SERIAL_BAUD"] - set_ftdi_latency_timer(serial_port_number=self.config_controller["SERIAL_PORT"]) - self.InterfaceInstance = Interface() - self.InterfaceInstance.open(SERIAL_PORT, SERIAL_BAUD) - - NET_NAME = self.config_controller["net_name"] - PATH_TO_MODELS = self.config_controller["PATH_TO_MODELS"] - path_to_model_info = os.path.join(PATH_TO_MODELS, NET_NAME, NET_NAME + ".txt") - - self.input_at_input = self.config_controller["input_at_input"] - - self.net_info = load_net_info_from_txt_file(path_to_model_info) - - self.state_2_input_idx = [] - self.remaining_inputs = self.net_info.inputs.copy() - for key in self.net_info.inputs: - if key in STATE_INDICES.keys(): - self.state_2_input_idx.append(STATE_INDICES.get(key)) - self.remaining_inputs.remove(key) - else: - break # state inputs must be adjacent in the current implementation - - self.just_restarted = True - - print('Configured fpga controller with {} network with {} library\n'.format(self.net_info.net_full_name, self.lib.lib)) - - def step(self, s: np.ndarray, time=None, updated_attributes: "dict[str, TensorType]" = {}): - self.just_restarted = False - if self.input_at_input: - net_input = s - else: - self.update_attributes(updated_attributes) - net_input = s[..., self.state_2_input_idx] - for key in self.remaining_inputs: - net_input = np.append(net_input, getattr(self.variable_parameters, key)) - - net_input = self.lib.to_tensor(net_input, self.lib.float32) - - if self.lib.lib == 'Pytorch': - net_input = net_input.to(self.device) - - net_input = self.lib.reshape(net_input, (-1, 1, len(self.net_info.inputs))) - net_input = self.lib.to_numpy(net_input) - - net_output = self.get_net_output_from_fpga(net_input) - - net_output = self.lib.to_tensor(net_output, self.lib.float32) - net_output = net_output[self.lib.newaxis, self.lib.newaxis, :] - - if self.lib.lib == 'Pytorch': - net_output = net_output.detach().numpy() - - Q = net_output - - return Q - - def controller_reset(self): - - if not self.just_restarted: - self.configure() - - def get_net_output_from_fpga(self, net_input): - self.InterfaceInstance.send_net_input(net_input) - net_output = self.InterfaceInstance.receive_net_output(len(self.net_info.outputs)) - return net_output - - -def get_serial_port(serial_port_number=''): - import platform - import subprocess - serial_port_number = str(serial_port_number) - SERIAL_PORT = None - try: - system = platform.system() - if system == 'Darwin': # Mac - SERIAL_PORT = subprocess.check_output(f'ls -a /dev/tty.usbserial*{serial_port_number}', shell=True).decode("utf-8").strip() # Probably '/dev/tty.usbserial-110' - elif system == 'Linux': - SERIAL_PORT = '/dev/ttyUSB' + serial_port_number # You might need to change the USB number - elif system == 'Windows': - SERIAL_PORT = 'COM' + serial_port_number - else: - raise NotImplementedError('For system={} connection to serial port is not implemented.') - except Exception as err: - print(err) - - return SERIAL_PORT - - - - -PING_TIMEOUT = 1.0 # Seconds -CALIBRATE_TIMEOUT = 10.0 # Seconds -READ_STATE_TIMEOUT = 1.0 # Seconds -SERIAL_SOF = 0xAA -CMD_PING = 0xC0 - -class Interface: - def __init__(self): - self.device = None - self.msg = [] - self.start = None - self.end = None - - self.encoderDirection = None - - def open(self, port, baud): - self.port = port - self.baud = baud - self.device = serial.Serial(port, baudrate=baud, timeout=None) - self.device.reset_input_buffer() - - def close(self): - if self.device: - time.sleep(2) - self.device.close() - self.device = None - - def clear_read_buffer(self): - self.device.reset_input_buffer() - - def ping(self): - msg = [SERIAL_SOF, CMD_PING, 4] - msg.append(self._crc(msg)) - self.device.write(bytearray(msg)) - return self._receive_reply(CMD_PING, 4, PING_TIMEOUT) == msg - - def send_net_input(self, net_input): - self.device.reset_output_buffer() - bytes_written = self.device.write(bytearray(net_input)) - # print(bytes_written) - - def receive_net_output(self, net_output_length): - net_output_length_bytes = net_output_length * 4 # We assume float32 - net_output = self.device.read(size=net_output_length_bytes) - net_output = struct.unpack(f'<{net_output_length}f', net_output) - # net_output=reply - return net_output - - def _receive_reply(self, cmdLen, timeout=None, crc=True): - self.device.timeout = timeout - self.start = False - - while True: - c = self.device.read() - # Timeout: reopen device, start stream, reset msg and try again - if len(c) == 0: - print('\nReconnecting.') - self.device.close() - self.device = serial.Serial(self.port, baudrate=self.baud, timeout=timeout) - self.clear_read_buffer() - time.sleep(1) - self.msg = [] - self.start = False - else: - self.msg.append(ord(c)) - if self.start == False: - self.start = time.time() - - while len(self.msg) >= cmdLen: - # print('I am looping! Hurra!') - # Message must start with SOF character - if self.msg[0] != SERIAL_SOF: - #print('\nMissed SERIAL_SOF') - del self.msg[0] - continue - - # Check message packet length - if self.msg[2] != cmdLen and cmdLen < 256: - print('\nWrong Packet Length.') - del self.msg[0] - continue - - # Verify integrity of message - if crc and self.msg[cmdLen-1] != self._crc(self.msg[:cmdLen-1]): - print('\nCRC Failed.') - del self.msg[0] - continue - - self.device.timeout = None - reply = self.msg[:cmdLen] - del self.msg[:cmdLen] - return reply - - def _crc(self, msg): - crc8 = 0x00 - - for i in range(len(msg)): - val = msg[i] - for b in range(8): - sum = (crc8 ^ val) & 0x01 - crc8 >>= 1 - if sum > 0: - crc8 ^= 0x8C - val >>= 1 - - return crc8 - - -import subprocess -def set_ftdi_latency_timer(serial_port_number): - print('\nSetting FTDI latency timer') - ftdi_timer_latency_requested_value = 1 - command_ftdi_timer_latency_set = f"sh -c 'echo {ftdi_timer_latency_requested_value} > /sys/bus/usb-serial/devices/ttyUSB{serial_port_number}/latency_timer'" - command_ftdi_timer_latency_check = f'cat /sys/bus/usb-serial/devices/ttyUSB{serial_port_number}/latency_timer' - try: - subprocess.run(command_ftdi_timer_latency_set, shell=True, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - print(e.stderr) - if "Permission denied" in e.stderr: - print("Trying with sudo...") - command_ftdi_timer_latency_set = "sudo " + command_ftdi_timer_latency_set - try: - subprocess.run("echo Teresa | sudo -S :", shell=True) - subprocess.run(command_ftdi_timer_latency_set, shell=True, check=True, capture_output=True, text=True) - except subprocess.CalledProcessError as e: - print(e.stderr) - - ftdi_latency_timer_value = subprocess.run(command_ftdi_timer_latency_check, shell=True, capture_output=True, text=True).stdout.rstrip() - print(f'FTDI latency timer value (tested only for FTDI with Zybo and with Linux on PC side): {ftdi_latency_timer_value} ms \n') diff --git a/Controllers/controller_mpc.py b/Controllers/controller_mpc.py index 24367e9..dae52df 100644 --- a/Controllers/controller_mpc.py +++ b/Controllers/controller_mpc.py @@ -66,7 +66,6 @@ def configure(self, optimizer_name: Optional[str]=None, predictor_specification: self.predictor.configure( batch_size=self.optimizer.num_rollouts, - horizon=self.optimizer.mpc_horizon, dt=config_optimizer["mpc_timestep"], computation_library=self.computation_library, variable_parameters=self.variable_parameters, diff --git a/Controllers/controller_neural_imitator.py b/Controllers/controller_neural_imitator.py index c019487..816a552 100644 --- a/Controllers/controller_neural_imitator.py +++ b/Controllers/controller_neural_imitator.py @@ -14,6 +14,7 @@ class controller_neural_imitator(template_controller): _computation_library = NumpyLibrary() + _is_hls4ml_mode = False def configure(self): @@ -22,7 +23,9 @@ def configure(self): path_to_models=self.config_controller["PATH_TO_MODELS"], batch_size=1, # It makes sense only for testing (Brunton plot for Q) of not rnn networks to make bigger batch, this is not implemented input_precision=self.config_controller["input_precision"], - hls4ml=self.config_controller["hls4ml"]) + nn_evaluator_mode=self.config_controller["nn_evaluator_mode"]) + + self.clip_output = self.config_controller.get("clip_output", False) self._computation_library = self.net_evaluator.lib @@ -31,9 +34,16 @@ def configure(self): # Prepare input mapping self.input_mapping = self._create_input_mapping() - if self.controller_logging and self.lib.lib == "TF" and not self.net_evaluator.hls4ml: + if self.controller_logging and self.lib.lib == "TF" and self.net_evaluator.nn_evaluator_mode == 'normal': self.controller_data_for_csv = FunctionalDict(get_memory_states(self.net_evaluator.net)) + # Mark that the network has been configured (important for hls4ml mode) + self._is_configured = True + + # Track if we're in hls4ml mode for efficient reset checking + self._is_hls4ml_mode = (hasattr(self.net_evaluator, 'nn_evaluator_mode') and + self.net_evaluator.nn_evaluator_mode == 'hls4ml') + print('Configured neural imitator with {} network with {} library'.format(self.net_evaluator.net_info.net_full_name, self.net_evaluator.net_info.library)) def _create_input_mapping(self): @@ -61,7 +71,8 @@ def step(self, s: np.ndarray, time=None, updated_attributes: "dict[str, TensorTy Q = self.net_evaluator.step(net_input) - Q = np.clip(Q, -1.0, 1.0) # Ensure Q is within the range [-1, 1] + if self.clip_output: + Q = np.clip(Q, -1.0, 1.0) # Ensure Q is within the range [-1, 1] return Q @@ -95,6 +106,10 @@ def _compose_network_input(self, state: np.ndarray) -> np.ndarray: return net_input def controller_reset(self): + # For hls4ml mode, avoid reconfiguration since the network is already converted + # This prevents multiple expensive hls4ml conversions when switching controllers + if self._is_hls4ml_mode and self._is_configured: + return self.configure() diff --git a/Controllers/controller_remote.py b/Controllers/controller_remote.py new file mode 100644 index 0000000..947cffc --- /dev/null +++ b/Controllers/controller_remote.py @@ -0,0 +1,124 @@ +from __future__ import annotations +import numpy as np +import zmq +import zmq.error + +from SI_Toolkit.computation_library import NumpyLibrary +from Control_Toolkit.Controllers import template_controller +from Control_Toolkit.others.globals_and_utils import import_controller_by_name + +ENFORCE_TIMEOUT = True # Set to False to disable the timeout feature +DEFAULT_RCVTIMEO = 50 # [ms] + + +class controller_remote(template_controller): + _computation_library = NumpyLibrary() + """ + ZeroMQ DEALER proxy. + • Sends each state to the server together with a monotonically + increasing *request-id* (`rid`). + • Drops or purges every reply whose rid ≠ last request’s rid. + • After a timeout the motor command falls back to 0 or to a local controller. + """ + + def configure(self): + # ─── remote socket setup ──────────────────────────────────────── + self.endpoint = self.config_controller.get( + "remote_endpoint", "tcp://localhost:5555" + ) + self._ctx = zmq.Context() + self._sock = self._ctx.socket(zmq.DEALER) + self._sock.connect(self.endpoint) + if ENFORCE_TIMEOUT: + self._sock.setsockopt(zmq.RCVTIMEO, DEFAULT_RCVTIMEO) + + self._next_rid: int = 0 + print(f"Neural-imitator proxy connected to {self.endpoint}") + + # ─── fallback to a local controller or 0 control ────────────────────── + # retrieve fallback-controller parameters from config + self.fallback_controller_name = self.config_controller["fallback_controller_name"] + + if self.fallback_controller_name is not None: + # dynamically import and instantiate the local controller + # e.g. import_controller_by_name("controller-neural-imitator") + Controller = import_controller_by_name( + f"controller-{self.fallback_controller_name}".replace("-", "_") + ) + self._fallback_controller = Controller( + self.environment_name, self.control_limits, self.initial_environment_attributes + ) + self._fallback_controller.configure() + + # ------------------------------------------------------------------ STEP + def step( + self, + s: np.ndarray, + time=None, + updated_attributes: "dict[str, np.ndarray]" = {}, + ): + """ + Serialises the data, ships it to the server, waits up to 50 ms for Q, + and returns it—or falls back on timeout to the fallback controller or zero control. + """ + if updated_attributes is None: + updated_attributes = {} + + rid = self._next_rid # snapshot current rid + self._next_rid += 1 # prepare for next call + + self._sock.send_json( + { + "rid": rid, + "state": s.tolist(), # JSON-friendly + "time": time, + "updated_attributes": updated_attributes, + } + ) + + # ❷ -- receive with timeout + try: + resp = self._sock.recv_json() # may raise zmq.Again + except zmq.error.Again: + self._purge_stale() # clear the queue + if self.fallback_controller_name is not None: + # use local controller on timeout + return self._fallback_controller.step( + s, time=time, updated_attributes=updated_attributes + ) + return np.array(0.0, dtype=np.float32) + + # —— discard stale packets —————————— + while resp.get("rid") != rid: + try: + resp = self._sock.recv_json() + except zmq.error.Again: + # genuine timeout – treat as lost reply + if self.fallback_controller_name is not None: + return self._fallback_controller.step( + s, time=time, updated_attributes=updated_attributes + ) + return np.array(0.0, dtype=np.float32) + + if "error" in resp: + # Re-raise server-side exceptions locally for easier debugging + raise RuntimeError(f"Remote controller error: {resp['error']}") + + # ❸ -- final result + return np.asarray(resp["Q"], dtype=np.float32) + + # ---------------------------------------------------------- helpers + def _purge_stale(self) -> None: + """Discard every pending message in the inbound queue.""" + while True: + try: + self._sock.recv(flags=zmq.DONTWAIT) + except zmq.error.Again: + break + + # ---------------------------------------------------------------- RESET + def controller_reset(self): + """ + Nothing to reset locally; the server keeps the network state. + """ + pass diff --git a/Optimizers/optimizer_rpgd.py b/Optimizers/optimizer_rpgd.py index 347b895..694428a 100644 --- a/Optimizers/optimizer_rpgd.py +++ b/Optimizers/optimizer_rpgd.py @@ -263,7 +263,7 @@ def configure(self, if dt is not None and predictor_specification is not None: self.predictor_single_trajectory.configure( - batch_size=1, horizon=self.mpc_horizon, dt=dt, # TF requires constant batch size + batch_size=1, dt=dt, # TF requires constant batch size computation_library=self.lib, predictor_specification=predictor_specification, ) diff --git a/README.md b/README.md index 6d712ac..f6b2e01 100644 --- a/README.md +++ b/README.md @@ -1,82 +1,233 @@ -# Control_Toolkit +# Control Toolkit -This folder contains general controller classes conforming to an interface loosely based on the [OpenAI Gym Interface](https://arxiv.org/pdf/1606.01540). +> **Note**: AI-generated on 14.11.2024, not human verified. -The `Control_Toolkit_ASF/models_for_nn_as_mpc` folder contains exemplary neural networks for those controllers which need one, which they can import directly. +A modular Python toolkit for implementing advanced control algorithms with a focus on **Model Predictive Control (MPC)**. Supports multiple computation backends (TensorFlow, PyTorch, NumPy) with a unified interface based on the [OpenAI Gym Interface](https://arxiv.org/pdf/1606.01540). -To use the toolkit, add this and the [SI_Toolkit](https://github.com/SensorsINI/SI_Toolkit) as Git submodules at the top level of your repository: +## Features -``` +- 🎯 **Multiple Control Strategies**: MPC, neural network imitators, remote/embedded controllers +- 🔧 **Pluggable Optimizers**: CEM, RPGD, MPPI, gradient-based methods, and more +- 📊 **Flexible Cost Functions**: Define custom objectives with consistent interface +- 🖥️ **Multi-Backend Support**: TensorFlow, PyTorch, NumPy +- 📡 **Remote Control**: ZeroMQ-based controller server +- 📈 **Built-in Logging**: Comprehensive trajectory and optimization metrics +- 🔌 **Hardware Integration**: Serial interface helpers for embedded systems + +## Installation + +Add as submodules to your repository: + +```bash git submodule add https://github.com/SensorsINI/SI_Toolkit -git submodule update –init +git submodule add Control_Toolkit +git submodule update --init --recursive +pip install -r Control_Toolkit/requirements.txt ``` +## Quick Start + +```python +from Control_Toolkit.others.globals_and_utils import import_controller_by_name +import numpy as np -## Repositories using the Toolkit +# Instantiate MPC controller +ControllerClass = import_controller_by_name("mpc") +controller = ControllerClass( + environment_name="YourEnvironment", + control_limits=(-1.0, 1.0), + initial_environment_attributes={"target_position": 0.0} +) -- CartPole Simulator -- ControlGym -- Physical CartPole -- F1TENTH INI +# Configure with optimizer +controller.configure(optimizer_name="rpgd-tf") +# Run control loop +state = np.array([0.1, 0.2, 0.3, 0.4]) +control_input = controller.step(state, time=0.0) +``` -## Software Design and Motivation +## Architecture -### Folders +### Design Philosophy -The motivation behind this toolkit is universality: A systems control algorithm should be implemented as much agnostic to the environment it is deployed on as possible. However, not all controllers can be formulated in such a general manner. For this reason, one may also add a folder `Control_Toolkit_ASF` for application-specific control as follows: +The toolkit separates **general-purpose controllers** (environment-agnostic) from **application-specific controllers** (domain-tailored), promoting code reuse while maintaining flexibility. + +### Folder Structure ``` -main_control_repository -L Control_Toolkit (submodule) -L Control_Toolkit_ASF (regular folder) +your_project/ +├── Control_Toolkit/ # This repository (submodule) +│ ├── Controllers/ # General-purpose controllers +│ ├── Optimizers/ # Optimization algorithms +│ ├── Cost_Functions/ # Cost function base classes +│ ├── controller_server/ # Remote controller server +│ └── others/ # Utilities and helpers +├── Control_Toolkit_ASF/ # Your application-specific files +│ ├── Controllers/ # Custom controllers +│ ├── Cost_Functions/ # Custom cost functions +│ ├── config_controllers.yml # Controller configurations +│ ├── config_optimizers.yml # Optimizer configurations +│ └── config_cost_function.yml # Cost function configurations +└── SI_Toolkit/ # Predictors (submodule) ``` -Find a template for the ASF folder within the toolkit. The template contains sample configuration files, whose structure should be kept consistent. +**Naming Convention**: Files and classes use `controller_.py` or `optimizer_.py` format and must inherit from their respective template classes. + +## Available Controllers + +| Controller | Description | File | +|------------|-------------|------| +| **MPC** | Model Predictive Control with pluggable optimizers | `controller_mpc.py` | +| **Neural Imitator** | Neural network-based controller | `controller_neural_imitator.py` | +| **Remote** | Client for remote controller server | `controller_remote.py` | +| **Embedded** | Interface for embedded hardware | `controller_embedded.py` | +| **C Controller** | Wrapper for C-based controllers | `controller_C.py` | + +Define custom controllers in `Control_Toolkit_ASF/Controllers/`. Template available in `Control_Toolkit_ASF_Template/`. + +## Available Optimizers -### Controller Design +### Sampling-Based -Each controller is defined in a separate module. File name and class name should match and have the "controller_" prefix. +| Optimizer | Description | Backend | +|-----------|-------------|---------| +| **cem-tf** | Cross-Entropy Method: samples random sequences, selects elites, refits distribution | TensorFlow | +| **cem-naive-grad-tf** | CEM + gradient refinement of elite samples [[Bharadhwaj et al., 2020]](https://arxiv.org/abs/2003.10768) | TensorFlow | +| **cem-gmm-tf** | CEM with Gaussian Mixture Model | TensorFlow | -A controller can possess any of the following optional subcomponents: +### Gradient-Based -- `Cost_Functions`: This folder contains a general base class and wrapper class for defining cost functions. You can define cost functions for your application in the `Control_Toolkit_ASF`. -- `Optimizers`: Interchangeable optimizers which return the cost-minimizing input given dynamics imposed by state predictions. -- `Predictors`: Defined in the `SI_Toolkit`. +| Optimizer | Description | Backend | +|-----------|-------------|---------| +| **rpgd** | Resampling Parallel Gradient Descent: maintains population of trajectories, optimizes with Adam, periodic resampling [[Heetmeyer et al., 2023]](https://ieeexplore.ieee.org/document/10161233) | TensorFlow | +| **gradient-tf** | Pure gradient descent optimization | TensorFlow | -This toolkit focuses on model-predictive control. Currently, only a `controller_mpc` is provided. You can however define other controllers in the application-specific files. +### Hybrid +| Optimizer | Description | Backend | +|-----------|-------------|---------| +| **mppi** | Model Predictive Path Integral + Adam refinement | TensorFlow | -## List of available MPC optimizers with description - -- `cem-tf`: - A standard implementation of the cem algorithm. Samples a number of random input sequences from a normal distribution, - then simulates them and selectes the 'elite' set of random inputs with lowest costs. The sampling distribution - is fitted to the elite set and the procedure repeated a fixed number of times. - In the end the mean of the elite set is used as input. +### Other -- `cem-naive-grad-tf`: - Same as cem, but between selecting the elite set and fitting the distribution, all input sequences in the elite - set are refined with vanilla gradient descent. Re-Implementation of Bharadhwaj, Xie, Shkurti 2020. +| Optimizer | Description | Backend | +|-----------|-------------|---------| +| **random-action-tf** | Random action baseline | TensorFlow | +| **nlp-forces** | Nonlinear programming via FORCES Pro | NumPy | -- `rpgd-tf` (`formerly dist-adam-resamp2-tf`): - Initially samples a set of control sequences, then optimizes them with the adam optimizer projecting control inputs, - clipping inputs which violate the constraints. For the next time step, the optimizations are warm started with - the solution from the last one. In regular intervals the only a subset of cheap control sequences are - warm started, while the other ones are resampled. +## Controller Server -- `mppi-optimze-tf`: - First find an initial guess of control sequence with the standard mppi approach. Then optimze it using the adam - optimizer. +Run controllers as a service via ZeroMQ: +```bash +python -m Control_Toolkit.controller_server.controller_server +``` + +**Protocol** (endpoint: `tcp://*:5555`): + +Request: +```json +{"rid": "request_id", "state": [0.1, 0.2], "time": 0.5, "updated_attributes": {}} +``` + +Response: +```json +{"rid": "request_id", "Q": 0.25} +``` + +Client example: +```python +import zmq, json +context = zmq.Context() +socket = context.socket(zmq.REQ) +socket.connect("tcp://localhost:5555") +socket.send_json({"rid": "1", "state": [0.1, 0.2], "time": 0.0}) +control = socket.recv_json()["Q"] +``` ## Logging -The toolkit provides a uniform interface to log values in the controller. These values could for example be rollout trajectories or intermediate optimization results. +Enable logging in controller config: `controller_logging: true` + +**Logged variables**: `Q_logged`, `J_logged`, `s_logged`, `u_logged`, `realized_cost_logged`, `trajectory_ages_logged`, `rollout_trajectories_logged` + +**Access logs**: +```python +outputs = controller.get_outputs() +control_history = outputs['Q_logged'] +``` -The `controller_mpc.step` method takes the `optimizer.logging_values` dictionary and copies it to its `controller_mpc.logs` dictionary in each step. The `template_controller` has two related attributes: `controller_logging` and `save_vars`. If the former is `true`, then the controller populates the fields of `save_vars` in the `template_controller.logs` dictionary with values if your controller calls `update_logs` within the `step` method. +## Configuration +Configuration files in `Control_Toolkit_ASF/`: + +**config_controllers.yml**: +```yaml +mpc: + optimizer: rpgd-tf + computation_library: tensorflow + device: cpu + predictor_specification: "my_predictor" + cost_function_specification: "tracking" + controller_logging: true +``` + +**config_optimizers.yml**: +```yaml +rpgd-tf: + num_rollouts: 100 + mpc_horizon: 20 + mpc_timestep: 0.05 + learning_rate: 0.01 + seed: 42 +``` + +## Hardware Integration + +**Serial Interface Helper** (for STM, ZYNQ boards): + +```python +from Control_Toolkit.serial_interface_helper import get_serial_port, set_ftdi_latency_timer + +port = get_serial_port(chip_type="STM") # or "ZYNQ" +set_ftdi_latency_timer(port) # Low-latency configuration +``` + +## Projects Using This Toolkit + +- [CartPole Simulator](https://github.com/SensorsINI/CartPoleSimulation/tree/reproduction_of_results_sep22) +- [ControlGym](https://github.com/frehe/ControlGym/tree/reproduction_of_results_sep22) +- [Physical CartPole](https://github.com/neuromorphs/physical-cartpole/tree/reproduction_of_results_sep2022_physical_cartpole) +- [F1TENTH INI](https://github.com/F1Tenth-INI/f1tenth_development_gym) + +See [CartPoleSimulation Control_Toolkit_ASF](https://github.com/SensorsINI/CartPoleSimulation/tree/master/Control_Toolkit_ASF/Controllers) for examples of application-specific controllers (do-mpc, LQR, etc.). + +## Requirements + +``` +tensorflow, tensorflow_probability, numpy, torch, torchvision, gymnasium, watchdog +``` + +**Note**: Install only what you need (e.g., NumPy-only setups don't require TensorFlow/PyTorch). + +## Citation + +If using RPGD optimizer, please cite: + +```bibtex +@inproceedings{heetmeyer2023rpgd, + title={RPGD: A Small-Batch Parallel Gradient Descent Optimizer with Explorative + Resampling for Nonlinear Model Predictive Control}, + author={Heetmeyer, Frederik and Paluch, Marcin and Bolliger, Diego}, + booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)}, + pages={3218--3224}, + year={2023}, + organization={IEEE}, + doi={10.1109/icra48891.2023.10161233} +} +``` -## Examples of Application-Specific Controllers +--- -We refer to the [Control_Toolkit_ASF of our CartPoleSimulation Project](https://github.com/SensorsINI/CartPoleSimulation/tree/master/Control_Toolkit_ASF/Controllers). +**Template**: Use `Control_Toolkit_ASF_Template/` to create your application-specific folder structure. diff --git a/controller_server/__init__.py b/controller_server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/controller_server/controller_server.py b/controller_server/controller_server.py new file mode 100644 index 0000000..08b8d62 --- /dev/null +++ b/controller_server/controller_server.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +""" +remote_nn_controller_server.py + +ZeroMQ ROUTER server that uses gui_selection to pick controller and optimizer, +then serves step requests. +""" + +import sys +import numpy as np +import zmq +import json + +from Control_Toolkit.controller_server.gui import choose_controller_and_optimizer +from Control_Toolkit.others.globals_and_utils import import_controller_by_name + +# Hardcoded ZeroMQ endpoint +ENDPOINT = "tcp://*:5555" + + +initial_environment_attributes = { + "target_position": 0.0, + "target_equilibrium": 0.0, + "m_pole": 0.0, + "L": 0.0, + "Q_ccrc": 0.0, + "Q_applied_-1": 0.0, +} + +def main(): + # Launch the GUI to get controller/optimizer + ctrl_name, opt_name = choose_controller_and_optimizer() + print(f"[server] ▶️ Controller: {ctrl_name} Optimizer: {opt_name}") + + # Dynamically import & instantiate + ControllerClass = import_controller_by_name(ctrl_name) + ctrl = ControllerClass( + environment_name="CartPole", + control_limits=(-1.0, 1.0), + initial_environment_attributes=initial_environment_attributes, # populate as needed + ) + + # Configure with or without optimizer + if ctrl.has_optimizer: + ctrl.configure(optimizer_name=opt_name) + else: + ctrl.configure() + + # ─── ZeroMQ ROUTER socket ──────────────────────────────────────── + ctx = zmq.Context() + sock = ctx.socket(zmq.ROUTER) + sock.bind(ENDPOINT) + print(f"[server] 🚀 listening on {ENDPOINT}") + + while True: + # Receive either [identity, payload] or [identity, b"", payload] + parts = sock.recv_multipart() + if len(parts) == 2: + client_identity, payload = parts + elif len(parts) == 3 and parts[1] == b"": + client_identity, _empty, payload = parts + else: + # Unexpected framing; skip it + continue + + try: + req = json.loads(payload.decode("utf-8")) + rid = req["rid"] + s = np.asarray(req["state"], dtype=np.float32) + t = req.get("time") + upd = req.get("updated_attributes", {}) + + Q = ctrl.step(s, t, upd) + if isinstance(Q, np.ndarray): + Q_payload = Q.tolist() + else: + # covers Python floats *and* tf.Tensor scalars via .numpy() + Q_payload = float(Q) if not isinstance(Q, (list, tuple)) else Q + + reply = json.dumps({"rid": rid, "Q": Q_payload}).encode("utf-8") + + sock.send_multipart([client_identity, reply]) + + except Exception as e: + print(f"[server] ⚠️ controller exception – no reply sent: {e}", file=sys.stderr) + continue # do NOT send anything back + + +if __name__ == "__main__": + main() diff --git a/controller_server/gui.py b/controller_server/gui.py new file mode 100644 index 0000000..d9c9cb3 --- /dev/null +++ b/controller_server/gui.py @@ -0,0 +1,89 @@ +from PyQt6.QtWidgets import ( + QApplication, + QDialog, + QVBoxLayout, + QGroupBox, + QRadioButton, + QDialogButtonBox, +) +from PyQt6.QtCore import Qt + +from Control_Toolkit.others.globals_and_utils import ( + get_available_controller_names, + get_available_optimizer_names, + get_controller_name, + get_optimizer_name, +) + + +class SelectionDialog(QDialog): + def __init__(self): + super().__init__() + self.setWindowTitle("Select Controller & Optimizer") + self.resize(400, 300) + + layout = QVBoxLayout(self) + + # Controllers group + ctrl_names = get_available_controller_names() + box_ctrl = QGroupBox("Controllers") + vbox_ctrl = QVBoxLayout() + self.rbs_controllers = [] + for name in ctrl_names: + rb = QRadioButton(name) + vbox_ctrl.addWidget(rb) + self.rbs_controllers.append(rb) + if self.rbs_controllers: + self.rbs_controllers[0].setChecked(True) + box_ctrl.setLayout(vbox_ctrl) + layout.addWidget(box_ctrl) + + # Optimizers group + opt_names = get_available_optimizer_names() + box_opt = QGroupBox("Optimizers") + vbox_opt = QVBoxLayout() + self.rbs_optimizers = [] + for name in opt_names: + rb = QRadioButton(name) + vbox_opt.addWidget(rb) + self.rbs_optimizers.append(rb) + if self.rbs_optimizers: + self.rbs_optimizers[0].setChecked(True) + box_opt.setLayout(vbox_opt) + layout.addWidget(box_opt) + + # OK / Cancel buttons + btns = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel, + Qt.Orientation.Horizontal, + self, + ) + btns.accepted.connect(self.accept) + btns.rejected.connect(self.reject) + layout.addWidget(btns) + + def get_selection(self): + """ + Returns: + Tuple[str, str]: (controller_name, optimizer_name) + """ + ctrl = None + for idx, rb in enumerate(self.rbs_controllers): + if rb.isChecked(): + ctrl, _ = get_controller_name(controller_idx=idx) + break + opt = None + for idx, rb in enumerate(self.rbs_optimizers): + if rb.isChecked(): + opt, _ = get_optimizer_name(optimizer_idx=idx) + break + return ctrl, opt + + +def choose_controller_and_optimizer(): + import sys + app = QApplication(sys.argv) + dlg = SelectionDialog() + if dlg.exec() != QDialog.DialogCode.Accepted: + sys.exit(0) + return dlg.get_selection() diff --git a/serial_interface_helper.py b/serial_interface_helper.py new file mode 100644 index 0000000..5a32c5f --- /dev/null +++ b/serial_interface_helper.py @@ -0,0 +1,105 @@ +import getpass +import platform +import subprocess + +import serial + +SUDO_PASSWORD = None # Required to set FTDI latency timer on Linux systems, can be set to a hardcoded password for convenience or left as None to prompt the user via terminal. + +def get_serial_port(chip_type="STM", serial_port_number=None): + """ + Finds the cartpole serial port, or throws exception if not present + :param chip_type: "ZYNQ" or "STM" depending on which one you use + :param serial_port_number: Only used if serial port not found using chip type, can be left None, for normal operation + :returns: the string name of the COM port + """ + + import sys + from serial.tools import list_ports + ports = list(serial.tools.list_ports.comports()) + + # Linux-only refinement: hide legacy ttyS* placeholders and list only real USB CDC/serial endpoints. + if sys.platform.startswith("linux"): + visible_ports = [p for p in ports if (getattr(p, "device", "") or "").startswith(("/dev/ttyUSB", "/dev/ttyACM"))] + else: + visible_ports = ports + + serial_ports_names = [] + print('\nAvailable serial ports:') + for index, port in enumerate(visible_ports): + serial_ports_names.append(port.device) + print(f'{index}: port={port.device}; description={port.description}') + print() + + if chip_type == "STM": + expected_descriptions = ['USB Serial'] + elif chip_type == "ZYNQ": + expected_descriptions = ['Digilent Adept USB Device - Digilent Adept USB Device', 'Digilent Adept USB Device'] + else: + raise ValueError(f'Unknown chip type: {chip_type}') + + possible_ports = [] + for port in visible_ports: + if port.description in expected_descriptions: + possible_ports.append(port.device) + + SERIAL_PORT = None + if not possible_ports: + message = f"Searching serial port by its expected descriptions - {expected_descriptions} - not successful." + if serial_port_number is not None: + print(message) + else: + raise Exception(message) + else: + if serial_port_number is None: + SERIAL_PORT = possible_ports[0] + elif 0 <= serial_port_number < len(possible_ports): + SERIAL_PORT = possible_ports[serial_port_number] + else: + print( + f"Requested serial port number {serial_port_number} is out of range. Available ports: {len(possible_ports)}") + print(f"Using the first available port: {possible_ports[0]}") + SERIAL_PORT = possible_ports[0] + + if SERIAL_PORT is None and serial_port_number is not None: + if len(serial_ports_names) == 0: + print('No serial ports') + elif 0 <= serial_port_number < len(serial_ports_names): + print(f"Setting serial port with requested number ({serial_port_number})\n") + SERIAL_PORT = serial_ports_names[serial_port_number] + + return SERIAL_PORT + + +def set_ftdi_latency_timer(serial_port_name): + print('\nSetting FTDI latency timer') + requested_value = 1 # in ms + + if platform.system() == 'Linux': + # check for hardcoded sudo password or prompt the user + if SUDO_PASSWORD: + password = SUDO_PASSWORD + else: + password = getpass.getpass('Enter sudo password: ') + + serial_port = serial_port_name.split('/')[-1] + ftdi_timer_latency_requested_value = 1 + command_ftdi_timer_latency_set = f"sh -c 'echo {ftdi_timer_latency_requested_value} > /sys/bus/usb-serial/devices/{serial_port}/latency_timer'" + command_ftdi_timer_latency_check = f'cat /sys/bus/usb-serial/devices/{serial_port}/latency_timer' + try: + subprocess.run(command_ftdi_timer_latency_set, shell=True, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + print(e.stderr) + if "Permission denied" in e.stderr: + print("Trying with sudo...") + command_ftdi_timer_latency_set = f"echo {password} | sudo -S {command_ftdi_timer_latency_set}" + try: + subprocess.run(command_ftdi_timer_latency_set, shell=True, check=True, capture_output=True, + text=True) + except subprocess.CalledProcessError as e: + print(e.stderr) + + ftdi_latency_timer_value = subprocess.run(command_ftdi_timer_latency_check, shell=True, capture_output=True, + text=True).stdout.rstrip() + print( + f'FTDI latency timer value (tested only for FTDI with Zybo and with Linux on PC side): {ftdi_latency_timer_value} ms \n')