Skip to content

Commit 922c972

Browse files
committed
Make the MovingWindow and PeriodicFeatureExtractor generic
Because the resampler is now generic, we can also make the `MovingWindow` and `PeriodicFeatureExtractor` generic so they can return a specialized quantity instead of a unit-less quantity, again improving performance and safety. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 310c158 commit 922c972

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

src/frequenz/sdk/timeseries/_moving_window.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,21 @@
99
import math
1010
from collections.abc import Sequence
1111
from datetime import datetime, timedelta
12-
from typing import SupportsIndex, overload
12+
from typing import Generic, SupportsIndex, overload
1313

1414
import numpy as np
1515
from frequenz.channels import Broadcast, Receiver, Sender
1616
from numpy.typing import ArrayLike
1717

1818
from ..actor._background_service import BackgroundService
19-
from ._base_types import UNIX_EPOCH, Sample
20-
from ._quantities import Quantity
19+
from ._base_types import UNIX_EPOCH, Sample, SupportsFloatT
2120
from ._resampling import Resampler, ResamplerConfig
2221
from ._ringbuffer import OrderedRingBuffer
2322

2423
_logger = logging.getLogger(__name__)
2524

2625

27-
class MovingWindow(BackgroundService):
26+
class MovingWindow(BackgroundService, Generic[SupportsFloatT]):
2827
"""
2928
A data window that moves with the latest datapoints of a data stream.
3029
@@ -130,9 +129,9 @@ async def run() -> None:
130129
def __init__( # pylint: disable=too-many-arguments
131130
self,
132131
size: timedelta,
133-
resampled_data_recv: Receiver[Sample[Quantity]],
132+
resampled_data_recv: Receiver[Sample[SupportsFloatT]],
134133
input_sampling_period: timedelta,
135-
resampler_config: ResamplerConfig | None = None,
134+
resampler_config: ResamplerConfig[SupportsFloatT] | None = None,
136135
align_to: datetime = UNIX_EPOCH,
137136
*,
138137
name: str | None = None,
@@ -166,8 +165,8 @@ def __init__( # pylint: disable=too-many-arguments
166165

167166
self._sampling_period = input_sampling_period
168167

169-
self._resampler: Resampler | None = None
170-
self._resampler_sender: Sender[Sample[Quantity]] | None = None
168+
self._resampler: Resampler[SupportsFloatT] | None = None
169+
self._resampler_sender: Sender[Sample[SupportsFloatT]] | None = None
171170

172171
if resampler_config:
173172
assert (
@@ -182,7 +181,9 @@ def __init__( # pylint: disable=too-many-arguments
182181
size.total_seconds() / self._sampling_period.total_seconds()
183182
)
184183

185-
self._resampled_data_recv = resampled_data_recv
184+
self._resampled_data_recv: Receiver[Sample[SupportsFloatT]] = (
185+
resampled_data_recv
186+
)
186187
self._buffer = OrderedRingBuffer(
187188
np.empty(shape=num_samples, dtype=float),
188189
sampling_period=self._sampling_period,
@@ -341,11 +342,11 @@ def _configure_resampler(self) -> None:
341342
"""Configure the components needed to run the resampler."""
342343
assert self._resampler is not None
343344

344-
async def sink_buffer(sample: Sample[Quantity]) -> None:
345+
async def sink_buffer(sample: Sample[SupportsFloatT]) -> None:
345346
if sample.value is not None:
346347
self._buffer.update(sample)
347348

348-
resampler_channel = Broadcast[Sample[Quantity]]("average")
349+
resampler_channel = Broadcast[Sample[SupportsFloatT]]("average")
349350
self._resampler_sender = resampler_channel.new_sender()
350351
self._resampler.add_timeseries(
351352
"avg", resampler_channel.new_receiver(), sink_buffer

src/frequenz/sdk/timeseries/_periodic_feature_extractor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
import logging
1717
from dataclasses import dataclass
1818
from datetime import datetime, timedelta
19+
from typing import Generic
1920

2021
import numpy as np
2122
from numpy.typing import NDArray
2223

2324
from .._internal._math import is_close_to_zero
2425
from ._moving_window import MovingWindow
26+
from ._quantities import SupportsFloatT
2527
from ._ringbuffer import OrderedRingBuffer
2628

2729
_logger = logging.getLogger(__name__)
@@ -48,7 +50,7 @@ class RelativePositions:
4850
"""The relative position of the next incoming sample."""
4951

5052

51-
class PeriodicFeatureExtractor:
53+
class PeriodicFeatureExtractor(Generic[SupportsFloatT]):
5254
"""
5355
A feature extractor for historical timeseries data.
5456
@@ -106,7 +108,7 @@ class PeriodicFeatureExtractor:
106108

107109
def __init__(
108110
self,
109-
moving_window: MovingWindow,
111+
moving_window: MovingWindow[SupportsFloatT],
110112
period: timedelta,
111113
) -> None:
112114
"""
@@ -119,7 +121,7 @@ def __init__(
119121
Raises:
120122
ValueError: If the MovingWindow size is not a integer multiple of the period.
121123
"""
122-
self._moving_window = moving_window
124+
self._moving_window: MovingWindow[SupportsFloatT] = moving_window
123125

124126
self._sampling_period = self._moving_window.sampling_period
125127
"""The sampling_period as float to use it for indexing of samples."""

0 commit comments

Comments
 (0)