85 changes: 65 additions & 20 deletions aeon/transformations/collection/dictionary_based/_paa.py
@@ -3,8 +3,10 @@
__maintainer__ = []

import numpy as np
from numba import get_num_threads, njit, prange, set_num_threads

from aeon.transformations.collection import BaseCollectionTransformer
from aeon.utils.validation import check_n_jobs


class PAA(BaseCollectionTransformer):
@@ -39,12 +41,14 @@ class PAA(BaseCollectionTransformer):

    _tags = {
        "capability:multivariate": True,
        "capability:multithreading": True,
        "fit_is_empty": True,
        "algorithm_type": "dictionary",
    }

    def __init__(self, n_segments=8):
    def __init__(self, n_segments=8, n_jobs=1):
        self.n_segments = n_segments
        self.n_jobs = n_jobs

        super().__init__()

@@ -71,7 +75,6 @@ def _transform(self, X, y=None):
        # of segments is 3, the indices will be [0:3], [3:6] and [6:10]
        # so 3 segments, two of length 3 and one of length 4
        split_segments = np.array_split(all_indices, self.n_segments)

        # If the series length is divisible by the number of segments
        # then the transformation can be done in one line
        # If not, a for loop is needed only on the segments while
@@ -82,13 +85,13 @@
            return X_paa

        else:
            n_samples, n_channels, _ = X.shape
            X_paa = np.zeros(shape=(n_samples, n_channels, self.n_segments))

            for _s, segment in enumerate(split_segments):
                if X[:, :, segment].shape[-1] > 0:  # avoids mean of empty slice error
                    X_paa[:, :, _s] = X[:, :, segment].mean(axis=-1)

            prev_threads = get_num_threads()
            _n_jobs = check_n_jobs(self.n_jobs)
            set_num_threads(_n_jobs)
            X_paa = _parallel_paa_transform(
                X, n_segments=self.n_segments, split_segments=split_segments
            )
            set_num_threads(prev_threads)
            return X_paa

def inverse_paa(self, X, original_length):
@@ -110,17 +113,17 @@ def inverse_paa(self, X, original_length):
            return np.repeat(X, repeats=int(original_length / self.n_segments), axis=-1)

        else:
            n_samples, n_channels, _ = X.shape
            X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length))

            all_indices = np.arange(original_length)
            split_segments = np.array_split(all_indices, self.n_segments)

            for _s, segment in enumerate(split_segments):
                X_inverse_paa[:, :, segment] = np.repeat(
                    X[:, :, [_s]], repeats=len(segment), axis=-1
                )

            split_segments = np.array_split(np.arange(original_length), self.n_segments)
            prev_threads = get_num_threads()
            _n_jobs = check_n_jobs(self.n_jobs)
            set_num_threads(_n_jobs)
            X_inverse_paa = _parallel_inverse_paa_transform(
                X,
                original_length=original_length,
                n_segments=self.n_segments,
                split_segments=split_segments,
            )
            set_num_threads(prev_threads)
            return X_inverse_paa

@classmethod
@@ -143,3 +146,45 @@ def _get_test_params(cls, parameter_set="default"):
"""
params = {"n_segments": 10}
return params
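
For intuition about the split_segments argument consumed below, a quick illustration of np.array_split's uneven splitting (illustrative values, not part of the diff):

import numpy as np

# 10 timepoints split into 3 segments: lengths differ by at most one,
# and np.array_split places the longer segments first.
split = np.array_split(np.arange(10), 3)
print([s.tolist() for s in split])
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]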


@njit(parallel=True, cache=True, fastmath=True)
def _parallel_paa_transform(X, n_segments, split_segments):
    """Parallelized PAA for uneven segment splits using Numba."""
    n_samples, n_channels, _ = X.shape
    X_paa = np.zeros((n_samples, n_channels, n_segments), dtype=X.dtype)

    for _s in prange(n_segments):  # parallel over segments
        segment = split_segments[_s]
        seg_len = segment.shape[0]

        if seg_len == 0:
            continue  # skip empty segment

        for i in range(n_samples):
            for j in range(n_channels):
Contributor: prange?

Contributor (author): Wouldn't that spawn too many threads, given that the outer loop already uses prange?

Contributor: You have used three nested prange loops in _parallel_inverse_paa_transform?

But, in fact, Numba seems to ignore nested prange loops. I did not know this before:

"Loop serialization occurs when any number of prange driven loops are present inside another prange driven loop. In this case the outermost of all the prange loops executes in parallel and any inner prange loops (nested or otherwise) are treated as standard range based loops. Essentially, nested parallelism does not occur."

https://numba.pydata.org/numba-doc/dev/user/parallel.html

Contributor (author): Yep, missed it. However, I'd also like to point out that the previously implemented Numba functions in SAX use multiple nested pranges as well, which is essentially dead code. I can remove it in this PR?

Contributor: Not sure who wrote these. @hadifawaz1999 ?

Member (@hadifawaz1999, Jul 31, 2025): Yes, I wrote these. I don't see the issue from the discussion, though; SAX was working fine, so why is it dead code? @patrickzib
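
A minimal sketch of the loop-serialization behavior quoted above (a toy function, not part of this PR): the inner prange compiles, but because the outer loop is already parallel, Numba executes it as an ordinary range.

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def row_sums(X):
    out = np.zeros(X.shape[0])
    for i in prange(X.shape[0]):  # parallelized
        for j in prange(X.shape[1]):  # serialized: treated as a plain range
            out[i] += X[i, j]
    return out

row_sums(np.random.rand(4, 3))  # runs fine; only the outer loop is parallel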

                acc = 0.0
                for k in range(seg_len):
Contributor: Why not use .mean() here, as in the original version?

Contributor (author): Numba doesn't have an implementation for mean.

Contributor: Not sure I understand. Both of these seem to work fine (I might be missing something):

        for i in range(n_samples):
            for j in range(n_channels):
                acc = X[i, j, segment].mean()
        for i in range(n_samples):
            for j in range(n_channels):
                acc = np.mean(X[i, j, segment])

Contributor (author): I meant that we used axis=-1 in the original implementation, and Numba doesn't support that optional argument. Hence, I implemented it this way. :))

Contributor: Thanks, I understand. Yet, why not use one of the two alternatives presented above, which do not require axis=-1?
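
For reference, a toy illustration of the constraint under discussion (hypothetical function names; assumes a recent Numba): np.mean over a slice compiles in nopython mode, while the axis keyword does not.

import numpy as np
from numba import njit

@njit
def slice_mean(x):
    return np.mean(x)  # supported: np.mean without an axis argument

@njit
def axis_mean(X):
    return np.mean(X, axis=-1)  # unsupported: raises a TypingError when compiled

X = np.random.rand(2, 3, 10)
slice_mean(X[0, 0, :4])  # fine
# axis_mean(X)           # fails to compile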

                    acc += X[i, j, segment[k]]
                X_paa[i, j, _s] = acc / seg_len

    return X_paa


@njit(parallel=True, cache=True, fastmath=True)
def _parallel_inverse_paa_transform(X, original_length, n_segments, split_segments):
    """Parallelize the inverse PAA transformation for cases where the series
    length is not divisible by the number of segments.
    """
    n_samples, n_channels, _ = X.shape
    X_inverse_paa = np.zeros(shape=(n_samples, n_channels, original_length))

    for _s in prange(n_segments):
        segment = split_segments[_s]
        for idx in prange(len(segment)):
            t = segment[idx]
            for i in prange(n_samples):
                for j in prange(n_channels):
                    X_inverse_paa[i, j, t] = X[i, j, _s]

    return X_inverse_paa
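
For context, a minimal usage sketch of the n_jobs parameter this diff adds (data shapes are illustrative):

import numpy as np
from aeon.transformations.collection.dictionary_based import PAA

X = np.random.rand(10, 2, 100)  # (n_cases, n_channels, n_timepoints)

# 100 timepoints do not divide evenly into 8 segments, so the parallel
# Numba path above is taken; n_jobs bounds the number of threads used.
paa = PAA(n_segments=8, n_jobs=4)
X_paa = paa.fit_transform(X)                          # (10, 2, 8)
X_back = paa.inverse_paa(X_paa, original_length=100)  # (10, 2, 100)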
34 changes: 33 additions & 1 deletion aeon/transformations/collection/dictionary_based/_sax.py
@@ -167,7 +167,11 @@ def _get_sax_symbols(self, X_paa):
        sax_symbols : np.ndarray of shape = (n_cases, n_channels, n_segments)
            The output of the SAX transformation using np.digitize
        """
        sax_symbols = np.digitize(x=X_paa, bins=self.breakpoints)
        prev_threads = get_num_threads()
        _n_jobs = check_n_jobs(self.n_jobs)
        set_num_threads(_n_jobs)
        sax_symbols = _parallel_get_sax_symbols(X_paa, breakpoints=self.breakpoints)
        set_num_threads(prev_threads)
        return sax_symbols

    def inverse_sax(self, X, original_length, y=None):
@@ -292,3 +296,31 @@ def _invert_sax_symbols(sax_symbols, n_timepoints, breakpoints_mid):
    ]

    return sax_inverse


@njit(fastmath=True, cache=True, parallel=True)
Member: Not 100% sure what the gain would be of doing all this. If it's not significant, I am not for doing this: instead of simply writing sax_symbols = np.digitize(x=X_paa, bins=self.breakpoints), we get 30+ new lines of code with nested loops and prange.

Contributor (author): I have done a benchmark in the past; I think it performs a little better for n_jobs > 2, and significantly better than np.digitize beyond 4 threads.

Member: Having dealt with MONSTER datasets lately, not being able to parallelize things for such huge datasets is a big downside. It might not be significant for UCR/UEA, but it can help a lot as dataset sizes grow.

I'm not a fan of the flatten and reshape, though, but I guess you did it to avoid the nested-loop parallelism problem?

For clarity, you could add that breakpoints need to be sorted for the function to work.

Contributor: You might want to check _sfa_fast.py.

I use np.digitize within a parallel for-loop. This approach strikes a balance between using only explicit loops and relying solely on vectorized digitize operations.

        for a in prange(dfts.shape[0]):
            for i in range(word_length):  # range(dfts.shape[2]):
                words[a, : dfts.shape[1]] = (
                    words[a, : dfts.shape[1]] << letter_bits
                ) | np.digitize(dfts[a, :, i], breakpoints[i], right=True)

def _parallel_get_sax_symbols(x, breakpoints, right=False):
    """Parallel version of `np.digitize`; `breakpoints` must be sorted ascending."""
    x_flat = x.flatten()
    result = np.empty(x_flat.shape[0], dtype=np.intp)

    for i in prange(x_flat.shape[0]):
        val = x_flat[i]
        bin_idx = 0

        if right:
            for j in range(len(breakpoints)):
                if val <= breakpoints[j]:
                    bin_idx = j
                    break
                bin_idx = j + 1
        else:
            for j in range(len(breakpoints)):
                if val < breakpoints[j]:
                    bin_idx = j
                    break
                bin_idx = j + 1

        result[i] = bin_idx

    return result.reshape(x.shape)
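
A quick equivalence check for the helper above (illustrative; assumes sorted breakpoints, as suggested in the review):

import numpy as np

X_paa = np.random.randn(5, 2, 8)
breakpoints = np.array([-0.67, 0.0, 0.67])  # must be sorted ascending

# The parallel helper should agree with np.digitize bin for bin.
assert np.array_equal(
    _parallel_get_sax_symbols(X_paa, breakpoints=breakpoints),
    np.digitize(x=X_paa, bins=breakpoints),
)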