# Source code for neural_data_simulator.scripts.run_ephys_generator

r"""Script that starts the ephys generator.

By default, the script loads ``settings.yaml`` from the configs directory
(see :mod:`neural_data_simulator.scripts.post_install_config`). The script can use a
different config file specified via the `\--settings-path` argument.

The config file has an `ephys_generator` section where the settings for the input,
output, noise, etc. can be adjusted. By default, the ephys generator expects to read
data from an LSL stream and output to an LSL outlet. In the absence of the
input stream, the ephys generator will not be able to start.
"""

import argparse
import contextlib
import logging
from pathlib import Path
from typing import cast, Optional

import numpy as np
from rich.pretty import pprint
import yaml

from neural_data_simulator.core import inputs
from neural_data_simulator.core import outputs
from neural_data_simulator.core.ephys_generator import ContinuousData
from neural_data_simulator.core.ephys_generator import LSLSpikeRateInputAdapter
from neural_data_simulator.core.ephys_generator import NoiseData
from neural_data_simulator.core.ephys_generator import ProcessOutput
from neural_data_simulator.core.ephys_generator import SpikeRateInput
from neural_data_simulator.core.ephys_generator import SpikeRateTestingInput
from neural_data_simulator.core.ephys_generator import Spikes
from neural_data_simulator.core.ephys_generator import Waveforms
from neural_data_simulator.core.health_checker import HealthChecker
from neural_data_simulator.core.outputs import StreamConfig
from neural_data_simulator.core.settings import EphysGeneratorEndpointType
from neural_data_simulator.core.settings import EphysGeneratorSettings
from neural_data_simulator.core.settings import Settings
from neural_data_simulator.util.runtime import configure_logger
from neural_data_simulator.util.runtime import get_configs_dir
from neural_data_simulator.util.runtime import initialize_logger
from neural_data_simulator.util.runtime import unwrap
from neural_data_simulator.util.settings_loader import check_config_override_str
from neural_data_simulator.util.settings_loader import load_settings

# Name passed to initialize_logger() and configure_logger() by run().
SCRIPT_NAME = "nds-ephys-generator"
# Logger for this module.
logger = logging.getLogger(__name__)


def _setup_test_input(n_channels: int, n_units_per_channel: int) -> SpikeRateInput:
    """Create a testing spike rate input driven by constant spike rates.

    Args:
        n_channels: How many input channels to simulate.
        n_units_per_channel: How many neurons each electrode captures.

    Returns:
        A ``SpikeRateTestingInput`` sized for ``n_channels * n_units_per_channel``
        units, consumable by the ephys generator process.
    """
    total_units = n_channels * n_units_per_channel
    return SpikeRateTestingInput(n_channels, total_units)


def _setup_LSL_input(
    stream_name: str, connection_timeout: float
) -> LSLSpikeRateInputAdapter:
    """Wrap an LSL inlet so it can be consumed as a spike rate input.

    This function does not call ``connect`` on the returned adapter; the
    caller is expected to do so.

    Args:
        stream_name: Name of the LSL stream to read spike rates from.
        connection_timeout: Upper bound on the time spent trying to connect
          to the LSL input stream.

    Returns:
        An adapter around the LSL inlet that data can be read from.
    """
    return LSLSpikeRateInputAdapter(
        inputs.LSLInput(stream_name, connection_timeout)
    )


def _setup_LSL_output(config: StreamConfig) -> outputs.LSLOutputDevice:
    """Create an LSL output device for ``config`` and connect it.

    Args:
        config: Configuration of the LSL stream to publish.

    Returns:
        The connected output device, which the ephys generator uses to
          publish data.
    """
    device = outputs.LSLOutputDevice(config)
    # Connect before handing the device to the caller.
    device.connect()
    return device


def _get_process_output_params(
    ephys_generator_settings: EphysGeneratorSettings,
) -> ProcessOutput.Params:
    """Build the ``ProcessOutput`` parameters from the ephys generator settings.

    Args:
        ephys_generator_settings: The ``ephys_generator`` section of the settings.

    Returns:
        Parameters built, in order, from the units per channel, the LSL chunk
        frequency, the raw data frequency and the resolution.
    """
    cfg = ephys_generator_settings
    positional_args = (
        cfg.n_units_per_channel,
        cfg.lsl_chunk_frequency,
        cfg.raw_data_frequency,
        cfg.resolution,
    )
    return ProcessOutput.Params(*positional_args)


def _get_continuous_data_params(
    ephys_generator_settings: EphysGeneratorSettings,
) -> ContinuousData.Params:
    """Build the ``ContinuousData`` parameters from the ephys generator settings.

    Args:
        ephys_generator_settings: The ``ephys_generator`` section of the settings.

    Returns:
        Parameters built from the raw data frequency, units per channel,
        waveform length and the LFP output's frequency and filter settings.
    """
    lfp_settings = ephys_generator_settings.output.lfp
    return ContinuousData.Params(
        ephys_generator_settings.raw_data_frequency,
        ephys_generator_settings.n_units_per_channel,
        ephys_generator_settings.waveforms.n_samples,
        lfp_settings.data_frequency,
        lfp_settings.filter_cutoff,
        lfp_settings.filter_order,
    )


def _get_spikes_params(
    ephys_generator_settings: EphysGeneratorSettings,
) -> Spikes.Params:
    """Build the ``Spikes`` parameters from the ephys generator settings.

    Args:
        ephys_generator_settings: The ``ephys_generator`` section of the settings.

    Returns:
        Parameters built from the raw data frequency, the units per channel
        and the refractory time.
    """
    cfg = ephys_generator_settings
    return Spikes.Params(
        cfg.raw_data_frequency, cfg.n_units_per_channel, cfg.refractory_time
    )


def _get_waveforms_params(
    ephys_generator_settings: EphysGeneratorSettings,
) -> Waveforms.Params:
    """Build the ``Waveforms`` parameters from the ephys generator settings.

    Args:
        ephys_generator_settings: The ``ephys_generator`` section of the settings.

    Returns:
        Parameters built from the waveform prototypes, the unit-to-prototype
        mapping and the number of samples per waveform.
    """
    wf = ephys_generator_settings.waveforms
    prototypes, mapping = wf.prototypes, wf.unit_prototype_mapping
    return Waveforms.Params(prototypes, mapping, wf.n_samples)


def _set_random_seed(random_seed: Optional[int]) -> None:
    """Seed NumPy's global random number generator when a seed is configured.

    Args:
        random_seed: Seed for ``np.random``. ``None`` leaves the generator
          unseeded; any integer, including ``0``, is applied.
    """
    # Compare against None explicitly: 0 is a valid seed and must not be
    # mistaken for "no seed" (run() also forwards the same setting to
    # NoiseData unchanged).
    if random_seed is None:
        return
    logger.info("Using random seed '%s'", random_seed)
    np.random.seed(random_seed)


def _parse_args():
    """Parse the command-line arguments of the ephys generator script.

    Returns:
        The parsed ``argparse.Namespace`` with ``settings_path``, ``overrides``
        and ``print_settings_only`` attributes.
    """
    default_settings_path = Path(get_configs_dir()).joinpath("settings.yaml")
    overrides_help = (
        "Specify settings overrides as key-value pairs, separated by spaces. "
        "For example: -o log_level=DEBUG ephys_generator.input.type=testing"
    )

    arg_parser = argparse.ArgumentParser(
        description="Simulate electrophysiology data from firing rates.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument(
        "--settings-path",
        type=Path,
        default=default_settings_path,
        help="Path to the settings.yaml file.",
    )
    arg_parser.add_argument(
        "--overrides",
        "-o",
        nargs="*",
        type=check_config_override_str,
        help=overrides_help,
    )
    arg_parser.add_argument(
        "--print-settings-only",
        "-p",
        action="store_true",
        help="Parse/print the settings and exit.",
    )
    return arg_parser.parse_args()


def _setup_spike_rate_input(
    ephys_generator_settings: "EphysGeneratorSettings",
) -> "SpikeRateInput":
    """Create the spike rate input selected in the settings.

    For the LSL input type, the input is also connected.

    Args:
        ephys_generator_settings: The ``ephys_generator`` section of the settings.

    Returns:
        The spike rate input the ephys generator reads from.

    Raises:
        ValueError: If the configured input type is neither LSL nor testing.
    """
    input_settings = ephys_generator_settings.input
    if input_settings.type == EphysGeneratorEndpointType.LSL:
        lsl_input_settings = unwrap(input_settings.lsl)
        spike_rate_input = _setup_LSL_input(
            lsl_input_settings.stream_name, lsl_input_settings.connection_timeout
        )
        spike_rate_input.connect()
        return spike_rate_input
    if input_settings.type == EphysGeneratorEndpointType.TESTING:
        n_channels = unwrap(input_settings.testing).n_channels
        return _setup_test_input(
            n_channels, ephys_generator_settings.n_units_per_channel
        )
    raise ValueError(f"Unexpected input type {input_settings.type}")


def _setup_outputs(
    ephys_generator_settings: "EphysGeneratorSettings",
    channel_count: int,
    cleanup: contextlib.ExitStack,
) -> "ProcessOutput.LSLOutputs":
    """Create and connect the raw, LFP and spike events LSL outputs.

    Each output's ``disconnect`` is registered on ``cleanup`` right after it is
    connected, so outputs that are already open get disconnected even if a
    later setup step fails.

    Args:
        ephys_generator_settings: The ``ephys_generator`` section of the settings.
        channel_count: Number of channels of the spike rate input.
        cleanup: Exit stack that takes ownership of the outputs' disconnection.

    Returns:
        The three outputs bundled for the ephys generator process.
    """
    output_settings = ephys_generator_settings.output

    continuous_data_output = _setup_LSL_output(
        StreamConfig.from_lsl_settings(
            output_settings.raw.lsl,
            ephys_generator_settings.raw_data_frequency,
            channel_count,
        )
    )
    cleanup.callback(continuous_data_output.disconnect)

    lfp_output = _setup_LSL_output(
        StreamConfig.from_lsl_settings(
            output_settings.lfp.lsl,
            output_settings.lfp.data_frequency,
            channel_count,
        )
    )
    cleanup.callback(lfp_output.disconnect)

    spike_events_config = StreamConfig.from_lsl_settings(
        output_settings.spike_events.lsl,
        0,  # irregular sampling rate
        ephys_generator_settings.waveforms.n_samples,
    )
    # Prepend the "channels" and "units" labels (presumably one column each
    # identifying the spiking channel/unit) to the config's own labels.
    spike_events_config.channel_labels = [
        "channels",
        "units",
    ] + spike_events_config.channel_labels
    spike_events_output = _setup_LSL_output(spike_events_config)
    cleanup.callback(spike_events_output.disconnect)

    return ProcessOutput.LSLOutputs(
        continuous_data_output, lfp_output, spike_events_output
    )


def _create_process_output(
    ephys_generator_settings: "EphysGeneratorSettings",
    spike_rate_input: "SpikeRateInput",
    lsl_outputs: "ProcessOutput.LSLOutputs",
) -> "ProcessOutput":
    """Assemble the ephys generator process from the settings and endpoints.

    Args:
        ephys_generator_settings: The ``ephys_generator`` section of the settings.
        spike_rate_input: Input the process reads spike rates from.
        lsl_outputs: Outputs the process publishes data to.

    Returns:
        The (not yet started) ephys generator process.
    """
    process_output_params = _get_process_output_params(ephys_generator_settings)
    channel_count = spike_rate_input.channel_count

    noise_settings = ephys_generator_settings.noise
    noise_data = NoiseData(
        channel_count,
        noise_settings.beta,
        noise_settings.standard_deviation,
        noise_settings.fmin,
        noise_settings.samples,
        ephys_generator_settings.random_seed,
    )
    continuous_data = ContinuousData(
        noise_data,
        channel_count,
        _get_continuous_data_params(ephys_generator_settings),
    )
    n_units = channel_count * ephys_generator_settings.n_units_per_channel
    waveforms = Waveforms(_get_waveforms_params(ephys_generator_settings), n_units)
    spikes = Spikes(
        channel_count,
        waveforms,
        _get_spikes_params(ephys_generator_settings),
    )

    optimal_num_samples_per_iteration = int(
        ephys_generator_settings.raw_data_frequency
        / ephys_generator_settings.lsl_chunk_frequency
    )
    health_checker = HealthChecker(
        int(ephys_generator_settings.lsl_chunk_frequency),
        optimal_num_samples_per_iteration,
    )
    return ProcessOutput(
        continuous_data,
        spikes,
        spike_rate_input,
        lsl_outputs,
        process_output_params,
        health_checker,
    )


def run() -> None:
    """Load the configuration and start the ephys generator.

    Parses the command-line arguments and loads the settings. With
    ``--print-settings-only`` the settings are printed and the function
    returns. Otherwise the spike rate input and the LSL outputs are set up and
    the generator runs until it stops or CTRL+C is received. On exit the
    process is stopped and every output that was connected is disconnected,
    including when setup fails part-way.

    Raises:
        ValueError: If the configured input type is neither LSL nor testing.
    """
    initialize_logger(SCRIPT_NAME)
    args = _parse_args()
    settings: Settings = cast(
        Settings,
        load_settings(
            args.settings_path,
            settings_parser=Settings,
            override_dotlist=args.overrides,
        ),
    )
    if args.print_settings_only:
        pprint(settings)
        return

    # Configure logging before seeding so the seed message honors the
    # configured log level.
    configure_logger(SCRIPT_NAME, settings.log_level)
    _set_random_seed(settings.ephys_generator.random_seed)
    logger.debug(f"run_ephys_generator settings:\n{yaml.dump(settings.dict())}")

    ephys_generator_settings = settings.ephys_generator
    spike_rate_input = _setup_spike_rate_input(ephys_generator_settings)
    try:
        with contextlib.ExitStack() as cleanup:
            lsl_outputs = _setup_outputs(
                ephys_generator_settings, spike_rate_input.channel_count, cleanup
            )
            po = _create_process_output(
                ephys_generator_settings, spike_rate_input, lsl_outputs
            )
            # Registered last so it runs first (LIFO): the process is stopped
            # before its outputs are disconnected.
            cleanup.callback(po.stop)
            try:
                po.start()
            except KeyboardInterrupt:
                logger.info("CTRL+C received. Exiting...")
    finally:
        # Drop this function's reference to the input; presumably this lets
        # an LSL inlet be released promptly — confirm against the adapter.
        del spike_rate_input


if __name__ == "__main__":
    run()