r"""Script that starts the streamer.
The streamer default configuration is located in `NDS_HOME/settings_streamer.yaml`
(see :mod:`neural_data_simulator.scripts.post_install_config`). The script can use
a different config file specified via the `\--settings-path` argument.
Upon start, the streamer expects to read data from a file and output to an LSL
outlet. By default, a sample behavior data file will be downloaded by the
:mod:`neural_data_simulator.scripts.post_install_config` script, so the streamer should
be able to run without any additional configuration. If the input file cannot be found,
the streamer will not be able to start.
"""
import argparse
import contextlib
import logging
from pathlib import Path
from typing import cast, Dict, Iterator, List, Optional, Tuple
from neo.rawio.blackrockrawio import BlackrockRawIO
import numpy as np
from pydantic import Extra
from pydantic_yaml import VersionedYamlModel
from rich.pretty import pprint
import yaml
from neural_data_simulator.core.outputs import LSLOutputDevice
from neural_data_simulator.core.outputs import StreamConfig
from neural_data_simulator.core.samples import Samples
from neural_data_simulator.core.settings import LogLevel
from neural_data_simulator.streamer import settings
from neural_data_simulator.streamer import streamers
from neural_data_simulator.util.runtime import configure_logger
from neural_data_simulator.util.runtime import get_abs_path
from neural_data_simulator.util.runtime import get_configs_dir
from neural_data_simulator.util.runtime import initialize_logger
from neural_data_simulator.util.runtime import unwrap
from neural_data_simulator.util.settings_loader import check_config_override_str
from neural_data_simulator.util.settings_loader import load_settings
# Name passed to `initialize_logger` and `configure_logger` when the script runs.
SCRIPT_NAME = "nds-streamer"
# Module-level logger used for this script's own messages.
logger = logging.getLogger(__name__)
class _Settings(VersionedYamlModel):
    """Top-level settings model parsed from the streamer settings file."""

    # Log level handed to `configure_logger` by the script.
    log_level: LogLevel
    # Streamer section: input type, per-input settings and streaming options.
    streamer: settings.Streamer

    class Config:
        # Keys that are not declared on the model are rejected.
        extra = Extra.forbid
class StreamGroup:
    """Utility class for managing a list of streams.

    One `LSLOutputDevice` is created per stream configuration when the group is
    connected. The connected devices are exposed through `lsl_outputs`.
    """

    def __init__(self, streams_configs: List[StreamConfig]):
        """Create a new instance.

        Args:
            streams_configs: One configuration per output device to create when
                the group is connected.
        """
        self.streams_configs = streams_configs
        self.lsl_outputs: List[LSLOutputDevice] = []

    def connect(self):
        """Create and connect one output device per stream configuration."""
        for stream_config in self.streams_configs:
            lsl_output = LSLOutputDevice(stream_config)
            lsl_output.connect()
            self.lsl_outputs.append(lsl_output)

    def disconnect(self):
        """Disconnect all streams and forget the disconnected devices."""
        for lsl_output in self.lsl_outputs:
            lsl_output.disconnect()
        # Start from an empty list on the next connect; otherwise reconnecting
        # would keep the stale devices and disconnect them a second time later.
        # Rebinding (instead of clearing in place) leaves any list reference a
        # caller obtained while connected untouched.
        self.lsl_outputs = []

    @contextlib.contextmanager
    def open_connection(self) -> Iterator[None]:
        """Open a managed connection.

        All streams are connected on entry and disconnected on exit, including
        when connecting fails part-way or the managed body raises.
        """
        try:
            self.connect()
            yield
        finally:
            self.disconnect()
def load_blackrock_file(
    filepath: Path, output_settings: settings.LSLSimplifiedOutputModel
) -> Tuple[List[StreamConfig], List[Samples]]:
    """Read the streams stored in a Blackrock Neurotech file.

    Args:
        filepath: Location of the file, handed to `BlackrockRawIO`.
        output_settings: Output settings forwarded to the stream config builders.

    Returns:
        Two parallel lists holding the stream configurations and their samples.
        The analog streams come first; the spikes stream is appended last when
        one could be built.
    """
    reader = BlackrockRawIO(filename=filepath)
    reader.parse_header()

    configs, samples_per_stream = _get_analog_streams(reader, output_settings)
    spikes_config, spikes_samples = _get_spikes_stream(reader, output_settings)

    # Nothing to add when the spikes stream could not be built.
    if spikes_config is None or spikes_samples is None:
        return configs, samples_per_stream

    configs.append(spikes_config)
    samples_per_stream.append(spikes_samples)
    return configs, samples_per_stream
def _get_analog_streams(
    neo_io: BlackrockRawIO, output_settings: settings.LSLSimplifiedOutputModel
) -> Tuple[List[StreamConfig], List[Samples]]:
    """Build one regular stream per signal stream found in the file.

    Args:
        neo_io: Reader to pull the header and the analog signals from.
        output_settings: Output settings forwarded to the stream config builder.

    Returns:
        Two parallel lists holding the stream configurations and their samples.
        The timestamps of every stream are shifted to start at zero.
    """
    configs = []
    samples_per_stream = []

    # One stream per sampling group.
    for stream_index in range(neo_io.signal_streams_count()):
        stream_id = neo_io.header["signal_streams"][stream_index]["id"]
        all_channels = neo_io.header["signal_channels"]
        stream_channels = all_channels[all_channels["stream_id"] == stream_id]
        # The rate of the first channel is used for the whole stream.
        sample_rate = stream_channels[0]["sampling_rate"]

        # Samples x channels
        signals = neo_io.get_analogsignal_chunk(stream_index=stream_index)

        # Sample counter starting at the stream start time expressed in samples.
        start_time = neo_io.get_signal_t_start(
            block_index=0, seg_index=0, stream_index=stream_index
        )
        sample_counter = round(start_time * sample_rate) + np.arange(
            signals.shape[0], dtype=np.int64
        )
        timestamps = sample_counter / sample_rate
        timestamps = timestamps - timestamps[0]
        samples_per_stream.append(Samples(timestamps, signals))

        configs.append(
            _get_regular_stream_config(
                stream_id, sample_rate, stream_channels, output_settings
            )
        )

    return configs, samples_per_stream
def _get_regular_stream_config(
    stream_id: str,
    sample_rate: float,
    channels: Dict,
    output_settings: settings.LSLSimplifiedOutputModel,
) -> StreamConfig:
    """Describe a regularly sampled Blackrock stream.

    Args:
        stream_id: Identifier of the stream, embedded in its name and source id.
        sample_rate: Sampling rate reported for the stream.
        channels: Channel records; their ``"name"`` entry provides the labels.
        output_settings: Provides the instrument details and the channel format.

    Returns:
        The configuration of the stream.
    """
    instrument = output_settings.instrument
    acquisition = {
        "manufacturer": instrument.manufacturer,
        "model": instrument.model,
        "instrument_id": instrument.id,
    }
    return StreamConfig(
        name=f"Blackrock-Group{stream_id}-Inst{instrument.id}",
        type="Blackrock SMP",
        source_id=f"playback-SMP{stream_id}-Inst{instrument.id}",
        acquisition=acquisition,
        sample_rate=sample_rate,
        channel_format=output_settings.channel_format,
        channel_labels=channels["name"],
    )
def _get_spikes_stream(
    neo_io: BlackrockRawIO, output_settings: settings.LSLSimplifiedOutputModel
) -> Tuple[Optional[StreamConfig], Optional[Samples]]:
    """Merge the spikes of all units into a single irregular stream.

    Every output row describes one spike: channel id, unit id, then the raw
    waveform. Rows are ordered by spike time and the timestamps are shifted so
    that the first spike is at zero.

    Args:
        neo_io: Reader to pull the spike data from (block 0, segment 0).
        output_settings: Output settings forwarded to the stream config builder.

    Returns:
        The stream configuration and its samples, or ``(None, None)`` when the
        file has no spike channels or the spike channels contain no spikes.
    """
    n_spike_channels = neo_io.spike_channels_count()
    if n_spike_channels == 0:
        return None, None

    # Collect the per-unit pieces and join them once after the loop; growing
    # the arrays inside the loop would re-copy all earlier spikes every time.
    ch_id_parts = []
    un_id_parts = []
    time_parts = []
    waveform_parts = []
    for glbl_u_idx in range(n_spike_channels):
        ch_id, un_id = neo_io.internal_unit_ids[glbl_u_idx]
        n_spikes = neo_io.spike_count(0, 0, glbl_u_idx)
        # Explicit integer dtype: a unit without spikes must not turn the id
        # columns into floats (NumPy treats an empty Python list as float64).
        ch_id_parts.append(np.full(n_spikes, ch_id, dtype=int))
        un_id_parts.append(np.full(n_spikes, un_id, dtype=int))
        time_parts.append(
            neo_io.get_spike_timestamps(0, 0, glbl_u_idx).astype(np.int64)
        )
        new_waveforms = neo_io.get_spike_raw_waveforms(
            block_index=0, seg_index=0, spike_channel_index=glbl_u_idx
        )
        # Drop the singleton middle axis: one waveform row per spike.
        waveform_parts.append(new_waveforms[:, 0, :])

    proc_times = np.concatenate(time_parts)
    if proc_times.size == 0:
        # Spike channels exist but hold no spikes: there is no first timestamp
        # to align to and nothing to stream.
        return None, None

    ch_ids = np.concatenate(ch_id_parts)
    un_ids = np.concatenate(un_id_parts)
    # Seed with an empty int16 block so the stacked dtype is never narrower
    # than int16.
    spike_waveforms = np.vstack(
        [
            np.zeros((0, waveform_parts[0].shape[-1]), dtype=np.int16),
            *waveform_parts,
        ]
    )

    # Units were visited one after another; interleave their spikes by time.
    re_ix = np.argsort(proc_times)
    ch_ids = ch_ids[re_ix]
    un_ids = un_ids[re_ix]
    proc_times = proc_times[re_ix]
    spike_waveforms = spike_waveforms[re_ix]

    spike_times = neo_io.rescale_spike_timestamp(proc_times, dtype="float64")
    spike_times = spike_times - spike_times[0]

    samples = Samples(
        spike_times, np.hstack((ch_ids[:, None], un_ids[:, None], spike_waveforms))
    )
    stream_config = _get_irregular_stream_config(
        spike_waveforms.shape[1], output_settings
    )
    return stream_config, samples
def _get_irregular_stream_config(
    n_waveforms: int, output_settings: settings.LSLSimplifiedOutputModel
) -> StreamConfig:
    """Describe the irregularly sampled Blackrock spikes stream.

    Args:
        n_waveforms: Number of waveform columns that follow the two id columns.
        output_settings: Provides the instrument details and the channel format.

    Returns:
        The configuration of the stream. Its labels are ``ch_id``, ``unit_id``
        and one ``wf_<k>`` label per waveform column, with ``k`` starting at -10.
    """
    instrument = output_settings.instrument
    labels = ["ch_id", "unit_id"]
    labels += [f"wf_{column - 10}" for column in range(n_waveforms)]
    acquisition = {
        "manufacturer": instrument.manufacturer,
        "model": instrument.model,
        "instrument_id": instrument.id,
    }
    return StreamConfig(
        name=f"Blackrock-SPK-Inst{instrument.id}",
        type="Blackrock SPK",
        source_id=f"playback-SPK-Inst{instrument.id}",
        acquisition=acquisition,
        # Sample rate 0.0 is used for this irregularly sampled stream.
        sample_rate=0.0,
        channel_format=output_settings.channel_format,
        channel_labels=labels,
    )
def _parse_args():
    """Parse the command line arguments of the streamer script.

    Returns:
        The parsed namespace with ``settings_path``, ``overrides`` and
        ``print_settings_only``.
    """
    cli = argparse.ArgumentParser(
        description="Stream file neurodata to LSL.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    default_settings_path = Path(get_configs_dir()).joinpath(
        "settings_streamer.yaml"
    )
    cli.add_argument(
        "--settings-path",
        type=Path,
        default=default_settings_path,
        help="Path to the settings_streamer.yaml file.",
    )

    overrides_help = (
        "Specify settings overrides as key-value pairs, separated by spaces. "
        "For example: -o log_level=DEBUG streamer.lsl_chunk_frequency=50"
    )
    cli.add_argument(
        "--overrides",
        "-o",
        nargs="*",
        type=check_config_override_str,
        help=overrides_help,
    )

    cli.add_argument(
        "--print-settings-only",
        "-p",
        action="store_true",
        help="Parse/print the settings and exit.",
    )

    return cli.parse_args()
def run():
    """Load the configuration and start the streamer.

    Parses the command line and loads the settings. With
    ``--print-settings-only`` the settings are printed and the function returns.
    Otherwise the configured input (NPZ or Blackrock file) is loaded, the LSL
    outputs are connected and the samples are streamed until the streamer
    returns or CTRL+C is received.

    Raises:
        ValueError: If the configured input type is not supported.
    """
    initialize_logger(SCRIPT_NAME)
    args = _parse_args()
    run_settings: _Settings = cast(
        _Settings,
        load_settings(
            args.settings_path,
            settings_parser=_Settings,
            override_dotlist=args.overrides,
        ),
    )
    if args.print_settings_only:
        pprint(run_settings)
        return

    configure_logger(SCRIPT_NAME, run_settings.log_level)
    # Only serialize the settings when the message will actually be emitted.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "%s settings:\n%s", SCRIPT_NAME, yaml.dump(run_settings.dict())
        )

    streamer_settings = run_settings.streamer
    input_type = streamer_settings.input_type
    if input_type == settings.StreamerInputType.NPZ:
        npz_settings = unwrap(streamer_settings.npz)
        input_settings = npz_settings.input
        output_settings = npz_settings.output
        samples = [
            Samples.load_from_npz(
                get_abs_path(input_settings.file),
                timestamps_array_name=input_settings.timestamps_array_name,
                data_array_name=input_settings.data_array_name,
            )
        ]
        stream_config = StreamConfig.from_lsl_settings(
            output_settings.lsl,
            output_settings.sampling_rate,
            output_settings.n_channels,
        )
        stream_group = StreamGroup([stream_config])
    elif input_type == settings.StreamerInputType.Blackrock:
        blackrock_settings = unwrap(streamer_settings.blackrock)
        input_settings = blackrock_settings.input
        output_settings = blackrock_settings.output
        # The file decides how many streams exist; the settings only describe
        # how each of them is published.
        stream_configs, samples = load_blackrock_file(
            Path(get_abs_path(input_settings.file)), output_settings.lsl
        )
        stream_group = StreamGroup(stream_configs)
    else:
        raise ValueError(f"Unsupported input type: {input_type}")

    with stream_group.open_connection():
        streamer = streamers.LSLStreamer(
            stream_group.lsl_outputs,
            samples,
            streamer_settings.lsl_chunk_frequency,
            streamer_settings.stream_indefinitely,
        )
        try:
            streamer.stream()
        except KeyboardInterrupt:
            logger.info("CTRL+C received. Exiting...")
# Start the streamer when this module is executed as a script.
if __name__ == "__main__":
    run()