"""Utilities for handling data in the desired NDS format."""
from __future__ import annotations
from dataclasses import dataclass
import errno
import os
from numpy import ndarray
import numpy as np
[docs]@dataclass
class Samples:
"""Unified collection of timestamps and data points."""
timestamps: ndarray
"""Timestamps for each data sample. Each row corresponds to a data sample."""
data: ndarray
"""Array of data samples. Each row corresponds to a data sample, while each column
corresponds to a dimension of the data sample."""
def __post_init__(self):
"""Execute validation checks on the data."""
_validate_inputs(self.timestamps, self.data)
def __len__(self):
"""Return the number of data points."""
return next(iter(np.shape(self.timestamps)), 0)
@property
def empty(self) -> bool:
"""Check if the samples are empty.
Returns:
True if there are no data points.
"""
return self.__len__() == 0
def __eq__(self, o: object) -> bool:
"""Compare data and timestamps of two samples objects."""
if not isinstance(o, Samples):
return False
return np.array_equal(self.data, o.data) and np.array_equal(
self.timestamps, o.timestamps
)
[docs] @classmethod
def empty_samples(cls) -> Samples:
"""Create an empty samples instance.
Returns:
Samples instance with empty timestamps and data arrays.
"""
return Samples(np.array([]), np.array([]))
[docs] @classmethod
def load_from_npz(
cls,
filepath: str,
timestamps_array_name: str = "timestamps",
data_array_name: str = "data",
) -> Samples:
"""Load the timestamps and data from the file into a new samples instance.
Args:
filepath: `.npz` file path with the timestamps and data
timestamps_array_name: Name of the timestamp array defined
when creating the file (see `np.savez` documentation for details). The
loaded array should be in the shape of (Nx1), N = number of samples.
Defaults to "timestamps"
data_array_name: Name of the data array defined when creating the file
(see `np.savez` documentation for details). The loaded array should be
in the shape of (NxM), N = number of samples and M is the number of
channels. Defaults to "data".
"""
file_data = cls._load_file_data(filepath)
return Samples(
timestamps=file_data[timestamps_array_name], data=file_data[data_array_name]
)
@classmethod
def _check_file_exists(cls, filepath: str):
if not os.path.exists(filepath):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), filepath)
@classmethod
def _load_file_data(cls, filepath: str) -> ndarray:
cls._check_file_exists(filepath)
return np.load(filepath)
def _validate_inputs(timestamps: ndarray, data: ndarray):
_validate_timestamps_shape(timestamps)
_validate_data_shape(data)
_validate_shapes_match(timestamps, data)
def _validate_timestamps_shape(timestamps: ndarray):
if (len(timestamps.shape) == 2 and timestamps.shape[1] > 1) or (
len(timestamps.shape) > 2
):
raise ValueError(
"Timestamps should be of shape (N, 1) or (N,) where N is the number of"
" samples"
)
def _validate_data_shape(data: ndarray):
if (len(data.shape) == 2 and data.shape[1] == 0) or (len(data.shape) > 2):
raise ValueError(
"data should be of shape (N, C) where N is the number of data points and C"
" is the number of channels"
)
def _validate_shapes_match(timestamps: ndarray, data: ndarray):
"""Execute validation checks on the data."""
data_points = next(iter(np.shape(data)), 0)
time_points = next(iter(np.shape(timestamps)), 0)
if data_points != time_points:
raise ValueError(
f"Number of data points ({data_points}) does not match the"
f" number of timestamps ({time_points})"
)