"""Run the center-out reach task GUI."""
import argparse
import logging
from pathlib import Path
import re
from typing import cast, Tuple
import numpy as np
from pydantic import Extra
from pydantic_yaml import VersionedYamlModel
from rich.pretty import pprint
import yaml
from neural_data_simulator.core import inputs
from neural_data_simulator.core import outputs
from neural_data_simulator.core.outputs import StreamConfig
from neural_data_simulator.core.settings import LogLevel
from neural_data_simulator.tasks.center_out_reach.input_events import InputHandler
from neural_data_simulator.tasks.center_out_reach.metrics import MetricsCollector
from neural_data_simulator.tasks.center_out_reach.scalers import PixelsToMetersConverter
from neural_data_simulator.tasks.center_out_reach.scalers import StandardVelocityScaler
from neural_data_simulator.tasks.center_out_reach.settings import CenterOutReach
from neural_data_simulator.tasks.center_out_reach.task_runner import TaskRunner
from neural_data_simulator.tasks.center_out_reach.task_state import StateParams
from neural_data_simulator.tasks.center_out_reach.task_state import TaskState
from neural_data_simulator.tasks.center_out_reach.task_window import TaskWindow
from neural_data_simulator.util.runtime import configure_logger
from neural_data_simulator.util.runtime import get_configs_dir
from neural_data_simulator.util.runtime import initialize_logger
from neural_data_simulator.util.runtime import open_connection
from neural_data_simulator.util.runtime import unwrap
from neural_data_simulator.util.settings_loader import check_config_override_str
from neural_data_simulator.util.settings_loader import load_settings
# Name passed to `initialize_logger` and `configure_logger` in `run()`.
SCRIPT_NAME = "nds-center-out-reach"
# Module-level logger used for this script's log messages.
logger = logging.getLogger(__name__)
class _Settings(VersionedYamlModel):
    """Center-out reach app settings.

    Defines the schema of a `settings_center_out_reach.yaml` file.
    """

    # Log level handed to `configure_logger` by `run()`.
    log_level: LogLevel
    # Task-specific settings; `run()` reads its `input`, `output`, `window`,
    # `task`, `standard_scaler`, `sampling_rate`, `with_metrics` and
    # `task_window_output` members.
    center_out_reach: CenterOutReach

    class Config:
        # Pydantic model configuration: `Extra.forbid` makes validation fail
        # on keys that are not declared on the model.
        extra = Extra.forbid
def _get_task_window_params(
    task_settings: CenterOutReach.Task,
    window_settings: CenterOutReach.Window,
    unit_converter: PixelsToMetersConverter,
):
    """Build the `TaskWindow.Params` describing how the task is drawn.

    Args:
        task_settings: Task settings providing the target radius, cursor
            radius, radius to target (all converted from meters to pixels
            here) and the number of targets.
        window_settings: Window settings providing the color palette.
        unit_converter: Converter used to turn meters into pixels.

    Returns:
        A `TaskWindow.Params` instance with all sizes expressed as whole
        pixels.
    """

    def to_px(meters):
        # Sizes are truncated to whole pixels.
        return int(unit_converter.meters_to_pixels(meters))

    palette = window_settings.colors
    return TaskWindow.Params(
        to_px(task_settings.target_radius),
        to_px(task_settings.cursor_radius),
        to_px(task_settings.radius_to_target),
        task_settings.number_of_targets,
        palette.background,
        palette.decoded_cursor,
        palette.decoded_cursor_on_target,
        palette.actual_cursor,
        palette.target,
        palette.target_waiting_for_cue,
        # GUI element sizes are specified in meters so that they are scaled
        # correctly on every device.
        font_size=to_px(0.006),
        button_size=(to_px(0.04), to_px(0.01)),
        button_spacing=to_px(0.005),
        button_offset_top=to_px(0.03),
    )
def _get_state_machine_params(task_settings):
    """Build the `StateParams` for the task state machine.

    Args:
        task_settings: Task settings providing `delay_to_begin`,
            `delay_waiting_for_cue`, `target_holding_time` and
            `max_trial_time`.

    Returns:
        A `StateParams` instance built from those four values, passed
        positionally in the order listed above.
    """
    ordered_fields = (
        "delay_to_begin",
        "delay_waiting_for_cue",
        "target_holding_time",
        "max_trial_time",
    )
    values = [getattr(task_settings, field) for field in ordered_fields]
    return StateParams(*values)
def _get_window_rect(unit_converter, window_settings, task_settings) -> Tuple[int, int]:
    """Return the task window size as a `(width, height)` pair.

    Args:
        unit_converter: Object exposing `meters_to_pixels`.
        window_settings: Window settings; `width` and `height` are used
            as-is when both are set (truthy).
        task_settings: Task settings; `radius_to_target` (in meters) is used
            to derive a size when no explicit size is configured.

    Returns:
        The configured `(width, height)` when both are set. Otherwise a
        square window whose side is 1.2 times the diameter of the circle on
        which the targets lie, truncated to whole pixels.
    """
    if window_settings.width and window_settings.height:
        return (
            window_settings.width,
            window_settings.height,
        )

    radius_to_target_pixels = unit_converter.meters_to_pixels(
        task_settings.radius_to_target
    )
    target_diameter_pixels = radius_to_target_pixels * 2
    # Leave a 20% margin around the target circle. The result is truncated
    # with int() so the return value honours the Tuple[int, int] annotation,
    # consistent with the other pixel conversions in this module.
    side_pixels = int(target_diameter_pixels * 1.2)
    return side_pixels, side_pixels
def _parse_rich_text(text: str, default_text_color: str):
    """Split annotated text into a list of colored text fragments.

    Fragments written as `[some text](color)` are emitted with the given
    color; everything else is emitted with `default_text_color`.

    Args:
        text: The text to parse, optionally containing `[text](color)`
            annotations.
        default_text_color: Color assigned to un-annotated fragments.

    Returns:
        A list of `{"color": ..., "text": ...}` dictionaries in the order the
        fragments appear in `text`. Empty fragments between or around
        annotations are kept.
    """
    # The outer capture group makes `split` keep the annotations in the
    # result alongside the plain text between them.
    splitter = re.compile(r"(\[[^\]]+\]\([^)]+\))")
    annotation = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")

    fragments = []
    for chunk in splitter.split(text):
        found = annotation.search(chunk)
        if found is None:
            fragments.append({"color": default_text_color, "text": chunk})
            continue
        label, color = found.groups()
        fragments.append({"color": color, "text": label})
    return fragments
def _get_menu_text(
    default_text_color,
    actual_cursor_color,
    decoded_cursor_color,
    target_color,
    input_device,
):
    """Build the menu screen text shown when a decoded cursor is used.

    Args:
        default_text_color: Color for text without an explicit annotation.
        actual_cursor_color: Color used for the "input cursor" mentions.
        decoded_cursor_color: Color used for the "decoded cursor" mentions.
        target_color: Color used for the "target" mentions.
        input_device: Name of the input device, inserted into the text.

    Returns:
        The text parsed by `_parse_rich_text` into colored fragments.
    """
    # Colored mentions, in the `[text](color)` form understood by
    # `_parse_rich_text`.
    input_cursor = f"[input cursor]({actual_cursor_color})"
    decoded_cursor = f"[decoded cursor]({decoded_cursor_color})"
    target = f"[target]({target_color})"

    segments = [
        "\nWelcome to the Center Out Reaching Task!\n\n",
        "Press the <Start> button to begin.\n\n",
        f"Two cursors will be presented: the {input_cursor}",
        f" that\n directly follows your {input_device} movements",
        " (and can be toggled\n",
        " on and off by pressing `c` on the keyboard), and the\n",
        f" {decoded_cursor} that is the decoded from the",
        f" simulated\n spikes from the {input_cursor}",
        " movement.\n\nYour goal is to reach and stay within the",
        f" {target} using the\n",
        f" {decoded_cursor}",
        f" until the next {target} is presented.\n",
        "There is no time or target limit, you can press the\n",
        " `Escape` key to finish the task at any time.\n",
    ]
    return _parse_rich_text("".join(segments), default_text_color)
def _get_training_text(default_text_color, target_color, input_device):
    """Build the menu screen text shown in open loop mode.

    Args:
        default_text_color: Color for text without an explicit annotation.
        target_color: Color used for the "target" mentions.
        input_device: Name of the input device, inserted into the text.

    Returns:
        The text parsed by `_parse_rich_text` into colored fragments.
    """
    # Colored mention, in the `[text](color)` form understood by
    # `_parse_rich_text`.
    target = f"[target]({target_color})"

    segments = [
        "\nWelcome to the Center Out Reaching Task!\n\n",
        "Press the <Start> button to begin.\n\n",
        "In open loop mode, the cursor follows your\n",
        f"{input_device} movements.\n\n",
        "Your goal is to reach and stay within the",
        f" {target} using the\n",
        " cursor",
        f" until the next {target} is presented.\n\n",
        "There is no time or target limit, you can press the\n",
        " `Escape` key to finish the task at any time.\n",
    ]
    return _parse_rich_text("".join(segments), default_text_color)
def _metrics_enabled(settings: _Settings) -> bool:
    """Tell whether metrics should be collected for this run.

    Metrics are enabled only when both the task input is enabled and the
    `with_metrics` setting is set.
    """
    task_config = settings.center_out_reach
    return task_config.input.enabled and task_config.with_metrics
def _parse_args():
    """Parse the command line arguments of the script.

    Recognized options:
        --settings-path: Path to the settings file; defaults to
            `settings_center_out_reach.yaml` in the configs directory.
        --overrides / -o: Zero or more settings overrides, each validated by
            `check_config_override_str`.
        --print-settings-only / -p: Flag requesting that the settings are
            printed and the program exits.
        --control-file: Optional path to a control file.

    Returns:
        The `argparse.Namespace` with the parsed arguments.
    """
    default_settings_path = (
        Path(get_configs_dir()) / "settings_center_out_reach.yaml"
    )
    overrides_help = (
        "Specify settings overrides as key-value pairs, separated by spaces. "
        "For example: -o log_level=DEBUG center_out_reach.task.target_radius=0.03"
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Run the center-out reach task GUI..",
    )
    cli.add_argument(
        "--settings-path",
        default=default_settings_path,
        type=Path,
        help="Path to the settings_center_out_reach.yaml file.",
    )
    cli.add_argument(
        "--overrides",
        "-o",
        type=check_config_override_str,
        nargs="*",
        help=overrides_help,
    )
    cli.add_argument(
        "--print-settings-only",
        "-p",
        help="Parse/print the settings and exit.",
        action="store_true",
    )
    cli.add_argument(
        "--control-file",
        help="Path to the control file that will receive control messages.",
        type=Path,
    )
    return cli.parse_args()
def run():
    """Run the center-out reach task GUI.

    Loads the settings, optionally prints them and exits, then sets up the
    LSL input/output devices, the task window, the velocity scaler, the state
    machine parameters and the optional metrics collector before running the
    task.

    When the task ends (normally, via CTRL+C, or because of an error) and a
    control file was given on the command line, `main_task_finished` is
    written to it so a parent process can detect that the main task is over.
    Metrics are plotted only if the task was not interrupted, was actually
    started (the menu screen is no longer shown) and metrics are enabled.
    The task window, if it was created, is always closed before returning.
    """
    initialize_logger(SCRIPT_NAME)
    args = _parse_args()
    settings: _Settings = cast(
        _Settings,
        load_settings(
            args.settings_path,
            settings_parser=_Settings,
            override_dotlist=args.overrides,
        ),
    )
    if args.print_settings_only:
        pprint(settings)
        return

    configure_logger(SCRIPT_NAME, settings.log_level)
    # Only serialize the settings to YAML when the message will be emitted.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "run_center_out_reach settings:\n%s", yaml.dump(settings.dict())
        )

    # The data input is optional: without it the task runs with no decoded
    # cursor (see `with_decoded_cursor` below).
    if settings.center_out_reach.input.enabled:
        lsl_input_settings = unwrap(settings.center_out_reach.input.lsl)
        data_input = inputs.LSLInput(
            lsl_input_settings.stream_name, lsl_input_settings.connection_timeout
        )
    else:
        data_input = None

    lsl_output_settings = settings.center_out_reach.output.lsl
    sampling_rate = settings.center_out_reach.sampling_rate
    data_output = outputs.LSLOutputDevice(
        stream_config=StreamConfig.from_lsl_settings(
            lsl_output_settings,
            sampling_rate,
            n_channels=2,
        )
    )

    # Set up the output for the task state
    task_window_output = None
    if settings.center_out_reach.task_window_output is not None:
        task_window_output = outputs.LSLOutputDevice.from_lsl_settings(
            settings.center_out_reach.task_window_output.lsl,
            sampling_rate,
            n_channels=4,
        )

    window_settings = settings.center_out_reach.window
    task_settings = settings.center_out_reach.task
    unit_converter = PixelsToMetersConverter(window_settings.ppi)
    window_rect = _get_window_rect(unit_converter, window_settings, task_settings)
    window_params = _get_task_window_params(
        task_settings,
        window_settings,
        unit_converter,
    )

    scaler_settings = settings.center_out_reach.standard_scaler
    velocity_scaler = StandardVelocityScaler(
        np.array(scaler_settings.scale), np.array(scaler_settings.mean), unit_converter
    )
    state_params = _get_state_machine_params(settings.center_out_reach.task)

    actual_cursor_color = window_settings.colors.actual_cursor
    decoded_cursor_color = window_settings.colors.decoded_cursor
    if _metrics_enabled(settings):
        metrics_collector = MetricsCollector(
            window_rect,
            task_settings.target_radius,
            unit_converter,
            actual_cursor_color,
            decoded_cursor_color,
        )
    else:
        metrics_collector = None

    with_decoded_cursor = settings.center_out_reach.input.enabled
    user_input = InputHandler()
    if with_decoded_cursor:
        menu_text = _get_menu_text(
            "black",
            actual_cursor_color,
            decoded_cursor_color,
            window_settings.colors.target,
            user_input.input_device_name,
        )
    else:
        menu_text = _get_training_text(
            "black", window_settings.colors.target, user_input.input_device_name
        )

    interrupted = False
    task_window = None
    try:
        try:
            with open_connection(data_output), open_connection(
                data_input
            ), open_connection(task_window_output):
                task_window = TaskWindow(window_rect, window_params, menu_text)
                task_state = TaskState(task_window, state_params)
                task_runner = TaskRunner(
                    sampling_rate,
                    data_input,
                    data_output,
                    velocity_scaler,
                    with_decoded_cursor,
                    metrics_collector,
                    task_window_output=task_window_output,
                )
                logger.info("Running task")
                task_runner.run(task_state, user_input)
        except KeyboardInterrupt:
            logger.info("CTRL+C received. Exiting...")
            interrupted = True
        finally:
            # This is used as a signal to a parent process that the main task
            # has finished. It is written in a `finally` so the signal is
            # also sent when the task ends with an unexpected error.
            if args.control_file is not None:
                with args.control_file.open("w") as control_file:
                    control_file.write("main_task_finished\n")

        if (
            not interrupted
            and task_window is not None
            and not task_window.show_menu_screen
            and _metrics_enabled(settings)
        ):
            unwrap(metrics_collector).plot_metrics(task_window.target_positions)
    finally:
        # Always close the window, even if the task or the plotting failed.
        if task_window is not None:
            task_window.leave()
# Run the task when this module is executed directly as a script.
if __name__ == "__main__":
    run()