Source code for neural_data_simulator.core.outputs

"""A collection of outputs that can be used by NDS."""
import abc
from dataclasses import dataclass
import logging
from typing import Any, Callable, IO, List, Optional, Union

from numpy import ndarray
import numpy as np
import pylsl

from neural_data_simulator.core.samples import Samples
from neural_data_simulator.core.settings import LSLOutputModel


class Output(abc.ABC):
    """Abstract base class for a destination that samples can be sent to.

    Concrete outputs implement `channel_count`, `connect` and `_send`. The
    public `send` method checks the shape of the data against
    `channel_count` and then delegates to `_send`.
    """

    @property
    @abc.abstractmethod
    def channel_count(self) -> int:
        """Return the number of channels."""

    def wait_for_consumers(self, timeout: int) -> bool:
        """Wait for consumers to connect until the timeout expires.

        This default implementation does not wait and always reports success.

        Args:
            timeout: Timeout in seconds.

        Returns:
            True if consumers are connected, False otherwise.
        """
        return True

    def has_consumers(self) -> bool:
        """Return whether consumers are connected; always True by default."""
        return True

    @abc.abstractmethod
    def connect(self) -> None:
        """Connect to output."""

    def disconnect(self) -> None:
        """Disconnect from output. The default implementation does nothing."""

    @abc.abstractmethod
    def _send(self, samples: Samples) -> None:
        """Send samples to output.

        Args:
            samples: Samples to output.
        """

    def _validate_data_shape(self, data: ndarray) -> None:
        """Check that the channel axis of `data` matches `channel_count`.

        Empty data is accepted as is. Only 1-D and 2-D data is checked; for
        both, the channel axis is the last one.

        Args:
            data: The data about to be sent.

        Raises:
            ValueError: The size of the last axis differs from `channel_count`.
        """
        if len(data) == 0:
            return

        rank = len(data.shape)
        if rank not in (1, 2):
            # Data of any other rank is passed through unchecked.
            return

        received = data.shape[-1]
        if received != self.channel_count:
            raise ValueError(
                f"Output expects data with {self.channel_count} channels, "
                f"received data with {received} channels"
            )

    def send(self, samples: Samples) -> Samples:
        """Push samples to output and return the data unchanged.

        Args:
            samples: Samples to output.

        Returns:
            The input samples unchanged.

        Raises:
            ValueError: The sample data does not match `channel_count`.
        """
        self._validate_data_shape(samples.data)
        self._send(samples)
        return samples
class ConsoleOutput(Output):
    """Represents an output device that prints to the terminal."""

    def __init__(self, channel_count: int):
        """Initialize the ConsoleOutput class.

        Args:
            channel_count: Number of channels for this output.
        """
        self._channel_count = channel_count
        self.logger = logging.getLogger(__name__)

    @property
    def channel_count(self) -> int:
        """The number of channels.

        Returns:
            Number of channels of the output.
        """
        return self._channel_count

    def _send(self, samples: Samples) -> None:
        """Print the samples to standard output.

        The timestamps are placed in the first column and the data in the
        following columns.

        Args:
            samples: :class:`neural_data_simulator.core.samples.Samples`
              dataclass with timestamps and data.
        """
        table = np.column_stack([samples.timestamps, samples.data])
        text = np.array2string(table)
        print(text)

    def connect(self) -> None:
        """Connect to the device. Nothing needs to be set up for the console."""
class FileOutput(Output):
    """Represents an output device that writes to a file."""

    def __init__(self, channel_count: int, file_name: str = "output.csv"):
        """Initialize FileOutput class.

        Args:
            channel_count: Number of channels for this output.
            file_name: File path to write the samples via the `send` method.
              Defaults to "output.csv".
        """
        self.logger = logging.getLogger(__name__)
        self._channel_count = channel_count
        # Handle of the output file; None whenever the output is not connected.
        self.file: Optional[IO[Any]] = None
        self.file_name = file_name

    @property
    def channel_count(self) -> int:
        """The number of channels.

        Returns:
            Number of channels of the output.
        """
        return self._channel_count

    def _send(self, samples: Samples) -> None:
        """Write the samples into the file as comma-separated rows.

        Each row holds the timestamp followed by the data values. Nothing is
        written while the output is not connected.

        Args:
            samples: :class:`neural_data_simulator.core.samples.Samples`
              dataclass.
        """
        if self.file is not None:
            timestamps_and_data = np.column_stack((samples.timestamps, samples.data))
            np.savetxt(self.file, timestamps_and_data, delimiter=",", fmt="%f")

    def connect(self) -> None:
        """Open the output file for writing.

        A handle left open by an earlier `connect` is closed first so that it
        is not leaked when it gets replaced.
        """
        if self.file is not None:
            self.file.close()
            self.file = None
        self.logger.info(f"Opening output file {self.file_name}")
        self.file = open(self.file_name, "w")

    def disconnect(self) -> None:
        """Close the output file and forget the handle.

        Resetting `file` to None keeps a later `send` from writing to a
        closed file.
        """
        if self.file is not None:
            self.file.close()
            self.file = None
@dataclass
class StreamConfig:
    """Parameters of an LSL stream."""

    name: str
    """LSL stream name."""
    type: str
    """LSL stream type."""
    source_id: str
    """LSL source id."""
    acquisition: dict
    """Information regarding the acquisition device."""
    sample_rate: Union[float, Callable[[], float]]
    """Sampling rate in Hz."""
    channel_format: str
    """Stream data type, for example `float32` or `int32`."""
    channel_labels: List[str]
    """Channel labels. The number of labels must match the number of channels."""

    @classmethod
    def from_lsl_settings(
        cls,
        lsl_settings: LSLOutputModel,
        sampling_rate: Union[float, Callable],
        n_channels: int,
    ):
        """Create a StreamConfig from an LSLOutputModel.

        When the settings carry no channel labels, the labels default to the
        channel indices "0" to "n_channels - 1".

        Args:
            lsl_settings: :class:`neural_data_simulator.core.settings.LSLOutputModel`
              instance.
            sampling_rate: Sampling rate in Hz.
            n_channels: Number of channels.

        Raises:
            ValueError: The number of configured channel labels differs from
              `n_channels`.
        """
        instrument = lsl_settings.instrument
        acquisition = {
            "manufacturer": instrument.manufacturer,
            "model": instrument.model,
            "instrument_id": instrument.id,
        }

        labels = lsl_settings.channel_labels
        if labels is None:
            labels = list(map(str, range(n_channels)))
        elif len(labels) != n_channels:
            raise ValueError(
                f"Number of channel labels ({len(labels)}) does not match "
                f"number of channels ({n_channels})"
            )

        return StreamConfig(
            name=lsl_settings.stream_name,
            type=lsl_settings.stream_type,
            source_id=lsl_settings.source_id,
            acquisition=acquisition,
            sample_rate=sampling_rate,
            channel_format=lsl_settings.channel_format,
            channel_labels=labels,
        )
class LSLOutputDevice(Output):
    """An output device that can be used to stream data via LSL."""

    # Maps the configured channel format name to the numpy dtype that the
    # data is cast to before it is pushed. "double64" is a 64-bit float, so
    # it maps to np.float64 rather than the platform-dependent np.longdouble.
    _CHANNEL_FORMAT_DTYPES = {
        "float32": np.float32,
        "double64": np.float64,
        "int8": np.int8,
        "int16": np.int16,
        "int32": np.int32,
        "int64": np.int64,
    }

    def __init__(self, stream_config: StreamConfig):
        """Initialize the LSL Output Device from a StreamConfig.

        Args:
            stream_config: :class:`neural_data_simulator.outputs.StreamConfig`
              instance.
        """
        self.logger = logging.getLogger(__name__)
        self._stream_config = stream_config
        # Both stay None until `connect` is called.
        self._outlet: Optional[pylsl.StreamOutlet] = None
        self._stream_info: Optional[pylsl.StreamInfo] = None
        self._stream_configured = False

    @classmethod
    def from_lsl_settings(
        cls,
        lsl_settings: LSLOutputModel,
        sampling_rate: Union[float, Callable],
        n_channels: int,
    ):
        """Initialize from :class:`neural_data_simulator.core.settings.LSLOutputModel`.

        Args:
            lsl_settings: :class:`neural_data_simulator.core.settings.LSLOutputModel`
              instance.
            sampling_rate: Sampling rate in Hz.
            n_channels: Number of channels.
        """
        stream_config = StreamConfig.from_lsl_settings(
            lsl_settings=lsl_settings,
            sampling_rate=sampling_rate,
            n_channels=n_channels,
        )
        return LSLOutputDevice(stream_config)

    @property
    def _dtype(self):
        """Return the numpy data type of the stream.

        Raises:
            ValueError: The configured channel format is not supported.
        """
        channel_format = self._stream_config.channel_format
        dtype = self._CHANNEL_FORMAT_DTYPES.get(channel_format)
        if dtype is None:
            raise ValueError(f"Unsupported channel format: {channel_format}")
        return dtype

    @property
    def channel_count(self) -> int:
        """The number of channels.

        Returns:
            Number of channels of the output, taken from the number of
            configured channel labels.
        """
        return len(self._stream_config.channel_labels)

    @property
    def sample_rate(self) -> Union[float, Callable[[], float]]:
        """Sample rate of the stream.

        Returns:
            The sample rate in Hz, or a callable returning it, exactly as
            stored in the stream configuration.
        """
        return self._stream_config.sample_rate

    @property
    def name(self) -> str:
        """The name of the stream.

        Returns:
            The configured name of the output stream.
        """
        return self._stream_config.name

    def _check_connection(self):
        """Ensure that the outlet exists.

        Raises:
            ConnectionError: `connect` has not been called yet.
        """
        if self._outlet is None:
            raise ConnectionError(
                "LSL StreamOutlet is not connected, ensure you run connect before send."
            )

    def _send(self, samples: Samples):
        """Push the data to the LSL outlet one sample at a time.

        Args:
            samples: :class:`neural_data_simulator.core.samples.Samples`
              dataclass with timestamps and data.

        Raises:
            ConnectionError: LSL StreamOutlet is not connected. `connect`
              should be called before `send`.
        """
        for timestamp, data_point in zip(samples.timestamps, samples.data):
            self.send_as_sample(data_point, timestamp)

    def send_as_chunk(self, data: ndarray, timestamp: Optional[float] = None):
        """Send a list of data points to the LSL outlet.

        Args:
            data: An array of data points.
            timestamp: An optional timestamp corresponding to the data points.

        Raises:
            ConnectionError: LSL StreamOutlet is not connected. `connect`
              should be called before `send`.
            ValueError: There was nothing to send because the data array is
              empty, or the data does not match the channel count.
        """
        self._check_data(data)
        self._check_connection()
        assert self._outlet is not None  # narrows the Optional for type checkers
        # cast data to expected channel format
        data_out = data.astype(self._dtype)
        # NOTE(review): a timestamp of 0.0 is falsy and takes the
        # no-timestamp branch; pylsl appears to treat 0.0 as "use the
        # current time" as well - confirm.
        if timestamp:
            self._outlet.push_chunk(data_out, timestamp)
        else:
            self._outlet.push_chunk(data_out)

    def send_as_sample(self, data: ndarray, timestamp: Optional[float] = None):
        """Send a single sample with the corresponding timestamp.

        A sample consisting of a data point per channel will be pushed to the
        LSL outlet together with an optional timestamp.

        Args:
            data: A single data point as an array of 1 value per channel.
            timestamp: An optional timestamp corresponding to the data point.

        Raises:
            ConnectionError: LSL StreamOutlet is not connected. `connect`
              should be called before `send`.
            ValueError: There was nothing to send because the data array is
              empty, or the data does not match the channel count.
        """
        self._check_data(data)
        self._check_connection()
        assert self._outlet is not None  # narrows the Optional for type checkers
        # cast data to expected channel format
        data_out = data.astype(self._dtype)
        # NOTE(review): a timestamp of 0.0 is falsy and takes the
        # no-timestamp branch; pylsl appears to treat 0.0 as "use the
        # current time" as well - confirm.
        if timestamp:
            self._outlet.push_sample(data_out, timestamp)
        else:
            self._outlet.push_sample(data_out)

    def _check_data(self, data: ndarray):
        """Reject empty data and data with the wrong number of channels.

        Raises:
            ValueError: The data array is empty or its shape does not match
              the channel count.
        """
        if len(data) == 0:
            self.logger.debug("No data to output")
            raise ValueError("No data to output")
        self._validate_data_shape(data)

    @staticmethod
    def _get_open_stream_names() -> List[str]:
        """Return the names of the streams found by `pylsl.resolve_streams`."""
        stream_infos = pylsl.resolve_streams()
        return [lsl_stream_info.name() for lsl_stream_info in stream_infos]

    def connect(self):
        """Connect to the LSL stream.

        Builds the stream info from the stream configuration and creates the
        outlet for it.
        """
        self.logger.info("Initializing LSL output stream...")
        self._stream_info = LSLOutputDevice._get_info_from_config(self._stream_config)
        outlet_buffer_time_in_s = 1
        self._outlet = pylsl.StreamOutlet(
            self._stream_info, max_buffered=outlet_buffer_time_in_s
        )
        self.logger.info(f"Created LSL output stream: '{self._stream_info.name()}'")
        self._stream_configured = True

    def disconnect(self):
        """Forget the connection to the LSL stream."""
        self.logger.info("Destroying LSL output stream...")
        # Rebinding to None drops this object's reference to the outlet.
        self._outlet = None
        self._stream_info = None
        self._stream_configured = False

    def has_consumers(self) -> bool:
        """Check if there are consumers connected to the stream.

        Returns:
            True if there are consumers, False if there aren't any or the
            output is not connected.
        """
        if self._outlet:
            return self._outlet.have_consumers()
        return False

    def wait_for_consumers(self, timeout: int) -> bool:
        """Wait for consumers to connect until the timeout expires.

        Args:
            timeout: Timeout in seconds.

        Returns:
            True if consumers are connected, False otherwise. Returns False
            immediately when the output is not connected.
        """
        if self._outlet:
            return self._outlet.wait_for_consumers(timeout=timeout)
        return False

    @staticmethod
    def _build_xml_from_dict(node: pylsl.XMLElement, dict_info: dict):
        """Build XML recursively.

        Nested dictionaries become child nodes; every other value is added
        as a child value converted with `str`.
        """
        for k, v in dict_info.items():
            if isinstance(v, dict):
                new_node = node.append_child(k)
                LSLOutputDevice._build_xml_from_dict(new_node, v)
            else:
                node.append_child_value(k, str(v))

    @staticmethod
    def _get_info_from_config(config: StreamConfig) -> pylsl.StreamInfo:
        """Create the LSL stream info described by `config`.

        The channel labels and the acquisition details are added to the
        description of the stream info.
        """
        # The sample rate may be given as a callable that is resolved here.
        if callable(config.sample_rate):
            sample_rate = config.sample_rate()
        else:
            sample_rate = config.sample_rate
        out_info = pylsl.stream_info(
            name=config.name,
            type=config.type,
            channel_count=len(config.channel_labels),
            nominal_srate=sample_rate,
            channel_format=config.channel_format,  # type: ignore[arg-type]
            source_id=config.source_id,
        )
        channels_xml = out_info.desc().append_child("channels")
        for ch_label in config.channel_labels:
            chan = channels_xml.append_child("channel")
            chan.append_child_value("label", ch_label)
        acq = out_info.desc().append_child("acquisition")
        LSLOutputDevice._build_xml_from_dict(acq, config.acquisition)
        return out_info