-
Notifications
You must be signed in to change notification settings - Fork 216
[ENH] Parallelize SAX and PAA transformers #2980
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f1b04ce
cea39a5
f36f223
6b8d269
6b15711
1bf5b6b
a6cc829
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,10 @@ | |
__maintainer__ = [] | ||
|
||
import numpy as np | ||
from numba import get_num_threads, njit, prange, set_num_threads | ||
|
||
from aeon.transformations.collection import BaseCollectionTransformer | ||
from aeon.utils.validation import check_n_jobs | ||
|
||
|
||
class PAA(BaseCollectionTransformer): | ||
|
@@ -39,12 +41,14 @@ class PAA(BaseCollectionTransformer): | |
|
||
_tags = { | ||
"capability:multivariate": True, | ||
"capability:multithreading": True, | ||
"fit_is_empty": True, | ||
"algorithm_type": "dictionary", | ||
} | ||
|
||
def __init__(self, n_segments=8): | ||
def __init__(self, n_segments=8, n_jobs=1): | ||
self.n_segments = n_segments | ||
self.n_jobs = n_jobs | ||
|
||
super().__init__() | ||
|
||
|
@@ -71,7 +75,6 @@ def _transform(self, X, y=None): | |
# of segments is 3, the indices will be [0:3], [3:6] and [6:10] | ||
# so 3 segments, two of length 3 and one of length 4 | ||
split_segments = np.array_split(all_indices, self.n_segments) | ||
|
||
# If the series length is divisible by the number of segments | ||
# then the transformation can be done in one line | ||
# If not, a for loop is needed only on the segments while | ||
|
@@ -82,13 +85,13 @@ def _transform(self, X, y=None): | |
return X_paa | ||
|
||
else: | ||
n_samples, n_channels, _ = X.shape | ||
X_paa = np.zeros(shape=(n_samples, n_channels, self.n_segments)) | ||
|
||
for _s, segment in enumerate(split_segments): | ||
if X[:, :, segment].shape[-1] > 0: # avoids mean of empty slice error | ||
X_paa[:, :, _s] = X[:, :, segment].mean(axis=-1) | ||
|
||
prev_threads = get_num_threads() | ||
_n_jobs = check_n_jobs(self.n_jobs) | ||
set_num_threads(_n_jobs) | ||
X_paa = _parallel_paa_transform( | ||
X, n_segments=self.n_segments, split_segments=split_segments | ||
) | ||
set_num_threads(prev_threads) | ||
return X_paa | ||
|
||
def inverse_paa(self, X, original_length): | ||
|
@@ -110,17 +113,17 @@ def inverse_paa(self, X, original_length): | |
return np.repeat(X, repeats=int(original_length / self.n_segments), axis=-1) | ||
|
||
else: | ||
n_samples, n_channels, _ = X.shape | ||
X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length)) | ||
|
||
all_indices = np.arange(original_length) | ||
split_segments = np.array_split(all_indices, self.n_segments) | ||
|
||
for _s, segment in enumerate(split_segments): | ||
X_inverse_paa[:, :, segment] = np.repeat( | ||
X[:, :, [_s]], repeats=len(segment), axis=-1 | ||
) | ||
|
||
split_segments = np.array_split(np.arange(original_length), self.n_segments) | ||
prev_threads = get_num_threads() | ||
_n_jobs = check_n_jobs(self.n_jobs) | ||
set_num_threads(_n_jobs) | ||
X_inverse_paa = _parallel_inverse_paa_transform( | ||
X, | ||
original_length=original_length, | ||
n_segments=self.n_segments, | ||
split_segments=split_segments, | ||
) | ||
set_num_threads(prev_threads) | ||
return X_inverse_paa | ||
|
||
@classmethod | ||
|
@@ -143,3 +146,45 @@ def _get_test_params(cls, parameter_set="default"): | |
""" | ||
params = {"n_segments": 10} | ||
return params | ||
|
||
|
||
@njit(parallel=True, cache=True, fastmath=True)
def _parallel_paa_transform(X, n_segments, split_segments):
    """Parallelized PAA transform for uneven segment splits.

    Used when the series length is not divisible by ``n_segments``, so the
    segments produced by ``np.array_split`` have unequal lengths.

    Parameters
    ----------
    X : np.ndarray of shape (n_samples, n_channels, n_timepoints)
        Input collection of time series.
    n_segments : int
        Number of PAA segments.
    split_segments : list of np.ndarray
        One index array per segment, as produced by ``np.array_split``.

    Returns
    -------
    np.ndarray of shape (n_samples, n_channels, n_segments)
        Mean value of each segment, preserving the dtype of ``X``.
    """
    n_samples, n_channels, _ = X.shape
    X_paa = np.zeros((n_samples, n_channels, n_segments), dtype=X.dtype)

    # Parallelize over segments only: numba ignores nested prange loops,
    # so the inner loops are deliberately plain `range`.
    for _s in prange(n_segments):
        segment = split_segments[_s]
        seg_len = segment.shape[0]

        if seg_len == 0:
            # Skip empty segments — avoids a mean-of-empty-slice error
            # (and a division by zero below).
            continue

        for i in range(n_samples):
            for j in range(n_channels):
                acc = 0.0
                for k in range(seg_len):
                    acc += X[i, j, segment[k]]
                X_paa[i, j, _s] = acc / seg_len

    return X_paa
|
||
|
||
@njit(parallel=True, cache=True, fastmath=True)
def _parallel_inverse_paa_transform(X, original_length, n_segments, split_segments):
    """Parallelized inverse PAA transform for uneven segment splits.

    Used when ``original_length`` is not divisible by ``n_segments``: each
    segment's PAA value is repeated over that segment's original time points.

    Parameters
    ----------
    X : np.ndarray of shape (n_samples, n_channels, n_segments)
        PAA-transformed collection.
    original_length : int
        Length of the series before the PAA transform.
    n_segments : int
        Number of PAA segments.
    split_segments : list of np.ndarray
        One index array per segment, as produced by ``np.array_split``.

    Returns
    -------
    np.ndarray of shape (n_samples, n_channels, original_length)
        The reconstructed (piecewise-constant) series.
    """
    n_samples, n_channels, _ = X.shape
    X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length))

    # Only the outermost loop is parallel: numba does not parallelize
    # nested prange loops, so the inner loops use plain `range`. Each
    # iteration of `_s` writes to a disjoint set of time points `t`,
    # so there are no write conflicts between parallel iterations.
    for _s in prange(n_segments):
        segment = split_segments[_s]
        for idx in range(len(segment)):
            t = segment[idx]
            for i in range(n_samples):
                for j in range(n_channels):
                    X_inverse_paa[i, j, t] = X[i, j, _s]

    return X_inverse_paa
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -167,7 +167,11 @@ def _get_sax_symbols(self, X_paa): | |
sax_symbols : np.ndarray of shape = (n_cases, n_channels, n_segments) | ||
The output of the SAX transformation using np.digitize | ||
""" | ||
sax_symbols = np.digitize(x=X_paa, bins=self.breakpoints) | ||
prev_threads = get_num_threads() | ||
_n_jobs = check_n_jobs(self.n_jobs) | ||
set_num_threads(_n_jobs) | ||
sax_symbols = _parallel_get_sax_symbols(X_paa, breakpoints=self.breakpoints) | ||
set_num_threads(prev_threads) | ||
return sax_symbols | ||
|
||
def inverse_sax(self, X, original_length, y=None): | ||
|
@@ -292,3 +296,31 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid): | |
] | ||
|
||
return sax_inverse | ||
|
||
|
||
@njit(fastmath=True, cache=True, parallel=True)
def _parallel_get_sax_symbols(x, breakpoints, right=False):
    """Parallel equivalent of ``np.digitize`` for the SAX transform.

    Parameters
    ----------
    x : np.ndarray
        Values to discretize (any shape).
    breakpoints : np.ndarray of shape (n_breakpoints,)
        Bin edges. Must be sorted in ascending order for the result to
        match ``np.digitize``.
    right : bool, default=False
        If False, bin ``i`` satisfies ``breakpoints[i-1] <= v < breakpoints[i]``;
        if True, ``breakpoints[i-1] < v <= breakpoints[i]`` — matching the
        ``right`` argument of ``np.digitize``.

    Returns
    -------
    np.ndarray of shape ``x.shape`` and dtype ``np.intp``
        Bin index of each value in ``x``.
    """
    x_flat = x.flatten()
    result = np.empty(x_flat.shape[0], dtype=np.intp)
    n_breakpoints = len(breakpoints)

    for i in prange(x_flat.shape[0]):
        val = x_flat[i]
        # Default: value lies above every breakpoint. Initializing here
        # (instead of relying on the loop variable leaking out of the
        # `for`) is also safe when `breakpoints` is empty.
        bin_idx = n_breakpoints

        if right:
            for j in range(n_breakpoints):
                if val <= breakpoints[j]:
                    bin_idx = j
                    break
        else:
            for j in range(n_breakpoints):
                if val < breakpoints[j]:
                    bin_idx = j
                    break

        result[i] = bin_idx

    return result.reshape(x.shape)
Uh oh!
There was an error while loading. Please reload this page.