Source code for neural_data_simulator.tasks.run_closed_loop

"""Run all components of the BCI closed loop."""
import argparse
import logging
import os
from pathlib import Path
import subprocess
import sys
import tempfile
import time

from neural_data_simulator.core.settings import LogLevel
from neural_data_simulator.util.runtime import configure_logger
from neural_data_simulator.util.runtime import initialize_logger
from neural_data_simulator.util.settings_loader import check_config_override_str

SCRIPT_NAME = "nds-run-closed-loop"
logger = logging.getLogger(__name__)

# This is the pipe message indicating that the main task has finished
MAIN_TASK_FINISHED_MSG = "main_task_finished"


def _run_process(args) -> subprocess.Popen:
    logger.info(f"Starting process: {args}")
    return subprocess.Popen(args, shell=sys.platform == "win32")


def _terminate_process(label: str, popen_process: subprocess.Popen, timeout: int = 5):
    logger.info(f"Terminating {label}")
    popen_process.terminate()
    try:
        popen_process.wait(timeout)
    except (subprocess.TimeoutExpired, KeyboardInterrupt):
        logger.warning(f"{label} did not terminate in time. Killing it")
        popen_process.kill()


def _parse_args():
    parser = argparse.ArgumentParser(description="Run closed loop.")
    parser.add_argument(
        "--nds-settings-path",
        type=Path,
        help="Path to the yaml file containing the NDS config.",
    )
    parser.add_argument(
        "--decoder-settings-path",
        type=Path,
        help="Path to the yaml file containing the decoder config.",
    )
    parser.add_argument(
        "--task-settings-path",
        type=Path,
        help="Path to the yaml file containing the task config.",
    )
    parser.add_argument(
        "--nds-overrides",
        "-no",
        nargs="*",
        type=check_config_override_str,
        help=(
            "Specify NDS config overrides as key-value pairs, separated by spaces. "
            "For example: -no log_level=DEBUG streamer.lsl_chunk_frequency=50"
        ),
    )
    parser.add_argument(
        "--decoder-overrides",
        "-do",
        nargs="*",
        type=check_config_override_str,
        help=(
            "Specify decoder config overrides as key-value pairs, separated by spaces. "
            "For example: -do log_level=DEBUG decoder.spike_threshold=-210"
        ),
    )
    parser.add_argument(
        "--task-overrides",
        "-to",
        nargs="*",
        type=check_config_override_str,
        help=(
            "Specify task config overrides as key-value pairs, separated by spaces. "
            "For example: -to log_level=DEBUG center_out_reach.task.target_radius=0.03"
        ),
    )

    parser.add_argument(
        "--remote-task",
        action="store_true",
        help="Run without center_out_reach task, expecting a remote task to be used.",
    )

    args = parser.parse_args()

    return args


def _build_param_from_arg(arg_value, param_name):
    if arg_value is None:
        return []
    if isinstance(arg_value, list):
        # Pass in each arg_value element as its string version
        return [param_name] + [str(val) for val in arg_value]
    return [param_name, str(arg_value)]


def _wait_for_center_out_reach_main_task(
    control_file_path: str, center_out_reach: subprocess.Popen
):
    """Waits for the center_out_reach main task to finish.

    This function is blocking and will only return when the center_out_reach
    has sent a message indicating that his main task has finished or when
    the center_out_reach process has stopped unexpectedly, indicating that the
    main task was also terminated.

    Args:
        control_file_path (str): Path to the control file.
        center_out_reach (subprocess.Popen): The center_out_reach process.
    """
    logger.info("Waiting for center_out_reach main task")
    with open(control_file_path, "r") as file:
        while True:
            line = str(file.readline())
            if MAIN_TASK_FINISHED_MSG in line:
                logger.info("center_out_reach main task finished")
                break
            if center_out_reach.poll() is not None:
                logger.info("center_out_reach stopped unexpectedly.")
                break
            time.sleep(0.5)


[docs]def run(): """Start all components.""" try: import pygame # noqa: F401 import screeninfo # noqa: F401 except ImportError: print( "Not running closed loop because neural-data-simulator " "extras couldn't be imported" ) print("Please reinstall neural-data-simulator with extras by running:") print('pip install "neural-data-simulator[extras]"') return initialize_logger(SCRIPT_NAME) configure_logger(SCRIPT_NAME, LogLevel._INFO) args = _parse_args() SETTINGS_PATH_PARAM = "--settings-path" OVERRIDES_PARAM = "--overrides" nds_params = _build_param_from_arg(args.nds_settings_path, SETTINGS_PATH_PARAM) nds_params += _build_param_from_arg(args.nds_overrides, OVERRIDES_PARAM) decoder_params = _build_param_from_arg( args.decoder_settings_path, SETTINGS_PATH_PARAM ) decoder_params += _build_param_from_arg(args.decoder_overrides, OVERRIDES_PARAM) task_params = _build_param_from_arg(args.task_settings_path, SETTINGS_PATH_PARAM) task_params += _build_param_from_arg(args.task_overrides, OVERRIDES_PARAM) logger.info("Starting modules") encoder = _run_process(["encoder"] + nds_params) ephys = _run_process(["ephys_generator"] + nds_params) decoder = _run_process(["decoder"] + decoder_params) if args.remote_task: logger.info("Running with remote task. Press CTRL+C to exit.") try: while True: time.sleep(1) except KeyboardInterrupt: logger.info("CTRL+C received. Exiting...") else: with tempfile.TemporaryDirectory() as temp_dir: control_file_path = os.path.join(temp_dir, "center_out_reach_control_file") Path(control_file_path).touch(exist_ok=False) center_out_reach = _run_process( ["center_out_reach", "--control-file", control_file_path] + task_params ) logger.info("Modules started") try: _wait_for_center_out_reach_main_task( control_file_path, center_out_reach ) except KeyboardInterrupt: logger.info("CTRL+C received. Exiting...") _terminate_process("center_out_reach", center_out_reach) if center_out_reach.poll() is None: logger.info("Waiting for center_out_reach") center_out_reach.wait() _terminate_process("encoder", encoder) _terminate_process("ephys_generator", ephys) _terminate_process("decoder", decoder) logger.info("Done")
if __name__ == "__main__": run()