# Source code for neural_data_simulator.decoder.run_decoder

"""Script that configures and starts the example :class:`Decoder`."""

import argparse
import logging
from pathlib import Path
from typing import cast

from pydantic import Extra
from pydantic_yaml import VersionedYamlModel
from rich.pretty import pprint
import yaml

from neural_data_simulator.core import inputs
from neural_data_simulator.core import outputs
from neural_data_simulator.core import timing
from neural_data_simulator.core.settings import LogLevel
from neural_data_simulator.core.settings import TimerModel
from neural_data_simulator.decoder.decoders import Decoder
from neural_data_simulator.decoder.decoders import PersistedFileDecoderModel
from neural_data_simulator.decoder.settings import DecoderSettings
from neural_data_simulator.util.runtime import configure_logger
from neural_data_simulator.util.runtime import get_abs_path
from neural_data_simulator.util.runtime import get_configs_dir
from neural_data_simulator.util.runtime import initialize_logger
from neural_data_simulator.util.runtime import open_connection
from neural_data_simulator.util.settings_loader import check_config_override_str
from neural_data_simulator.util.settings_loader import load_settings

# Name of this script; run() passes it to initialize_logger and configure_logger.
SCRIPT_NAME = "nds-decoder"
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class _Settings(VersionedYamlModel):
    """Settings schema for the decoder application.

    Instances are produced by ``load_settings`` in :func:`run`, which is given
    this class as its ``settings_parser``.
    """

    # Log level handed to configure_logger once the settings are loaded.
    log_level: LogLevel
    # Decoder section: input/output streams, model file and spike threshold.
    decoder: DecoderSettings
    # Timer section: provides loop_time and max_cpu_buffer for the run loop.
    timer: TimerModel

    class Config:
        # Forbid extra fields that are not declared on the model.
        extra = Extra.forbid


def _parse_args():
    """Parse the command-line arguments of the decoder script.

    Returns:
        The parsed arguments, exposing ``settings_path``, ``overrides`` and
        ``print_settings_only``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Decode behavior from input neural data.",
    )

    # Location of the settings file; defaults to the file in the configs dir.
    default_settings_path = Path(get_configs_dir()).joinpath("settings_decoder.yaml")
    cli.add_argument(
        "--settings-path",
        default=default_settings_path,
        type=Path,
        help="Path to the settings_decoder.yaml file.",
    )

    # Optional overrides, each validated by check_config_override_str.
    overrides_help = (
        "Specify settings overrides as key-value pairs, separated by spaces. "
        "For example: -o log_level=DEBUG decoder.spike_threshold=-210"
    )
    cli.add_argument(
        "--overrides",
        "-o",
        type=check_config_override_str,
        nargs="*",
        help=overrides_help,
    )

    # Flag to print the resolved settings instead of running the decoder.
    cli.add_argument(
        "--print-settings-only",
        "-p",
        help="Parse/print the settings and exit.",
        action="store_true",
    )

    return cli.parse_args()


def _read_decode_send(
    data_input: inputs.Input, dec: Decoder, data_output: outputs.Output
) -> None:
    """Read input data, decode it, and send to the output stream.

    Args:
        data_input: Input data source.
        dec: Decoder.
        data_output: Output data sink.
    """
    samples = data_input.read()
    if not samples.empty:
        decoded_samples = dec.decode(samples)
        if not decoded_samples.empty:
            data_output.send(decoded_samples)


def run() -> None:
    """Run the decoder loop.

    Loads the settings from the path given on the command line, applying any
    overrides. With ``--print-settings-only`` the settings are printed and the
    function returns. Otherwise it builds the timer, the LSL output and input
    and the decoder, opens both connections, and repeats a
    read -> decode -> send step once per timer period.

    A ``KeyboardInterrupt`` (CTRL+C) is caught and logged, which ends the loop.
    """
    initialize_logger(SCRIPT_NAME)
    args = _parse_args()
    settings: _Settings = cast(
        _Settings,
        load_settings(
            args.settings_path,
            settings_parser=_Settings,
            override_dotlist=args.overrides,
        ),
    )
    if args.print_settings_only:
        pprint(settings)
        return

    configure_logger(SCRIPT_NAME, settings.log_level)
    logger.debug(f"run_decoder settings:\n{yaml.dump(settings.dict())}")

    # Set up timer
    timer_settings = settings.timer
    timer = timing.get_timer(timer_settings.loop_time, timer_settings.max_cpu_buffer)

    # Create LSL input and output objects
    output_settings = settings.decoder.output
    data_output = outputs.LSLOutputDevice.from_lsl_settings(
        lsl_settings=output_settings.lsl,
        sampling_rate=output_settings.sampling_rate,
        n_channels=output_settings.n_channels,
    )
    lsl_input_settings = settings.decoder.input.lsl
    data_input = inputs.LSLInput(
        stream_name=lsl_input_settings.stream_name,
        connection_timeout=lsl_input_settings.connection_timeout,
    )

    logger.debug(f"Querying info from LSL stream: {lsl_input_settings.stream_name}")

    # Set up decoder
    decoder_model = PersistedFileDecoderModel(get_abs_path(settings.decoder.model_file))
    # Query the stream info once so that the sample rate and the channel count
    # come from the same result.
    stream_info = data_input.get_info()
    dec = Decoder(
        model=decoder_model,
        input_sample_rate=stream_info.sample_rate,
        # One decoder output per timer period.
        output_sample_rate=1.0 / timer_settings.loop_time,
        n_channels=stream_info.channel_count,
        threshold=settings.decoder.spike_threshold,
    )

    logger.debug("Attempting to open LSL connections...")

    try:
        with open_connection(data_output), open_connection(data_input):
            timer.start()
            # Run the decoder periodically
            while True:
                _read_decode_send(
                    data_input=data_input, dec=dec, data_output=data_output
                )
                timer.wait()
    except KeyboardInterrupt:
        logger.info("CTRL+C received. Exiting...")
if __name__ == "__main__":
    # Start the decoder only when this module is executed as a script.
    run()