@@ -22,17 +22,23 @@ def _reduce_all_columns_wrapper(*args, columns=None, udf, **kwargs):
 
 
 def _process_lc(
-    row: dict[str, object], *, n_src: int, lc_col: str, length_col: str, rng: np.random.Generator
+    row: dict[str, object],
+    *,
+    n_src: int,
+    subsample_src: bool,
+    lc_col: str,
+    length_col: str,
+    rng: np.random.Generator,
 ) -> dict[str, np.ndarray]:
     lc_length = row.pop(length_col)
-    idx = rng.choice(lc_length, size=n_src, replace=False)
+    idx = rng.choice(lc_length, size=n_src, replace=False) if subsample_src else np.arange(lc_length)
 
     result: dict[str, np.ndarray] = {}
     for col, value in row.items():
         if col.startswith(f"{lc_col}."):
             result[col] = value[idx]
         else:
-            result[f"{lc_col}.{col}"] = np.full(n_src, value)
+            result[f"{lc_col}.{col}"] = np.full(len(idx), value)
     return result
 
 
@@ -41,6 +47,7 @@ def _process_partition(
     pixel: HealpixPixel,
     *,
     n_src: int,
+    subsample_src: bool,
     lc_col: str,
     id_col: str,
     hash_range: tuple[int, int] | None,
@@ -77,6 +84,7 @@ def _process_partition(
         columns=columns,
         udf=_process_lc,
         n_src=n_src,
+        subsample_src=subsample_src,
         lc_col=lc_col,
         length_col=length_col,
         rng=rng,
@@ -92,6 +100,7 @@ def lsdb_nested_series_data_generator(
     id_col: str = "id",
     client: dask.distributed.Client | None,
     n_src: int,
+    subsample_src: bool = True,
     partitions_per_chunk: int | None,
     hash_range: tuple[int, int] | None = None,
     loop: bool = False,
@@ -101,8 +110,10 @@ def lsdb_nested_series_data_generator(
 
     The data is pre-fetched on the background, 'n_workers' number
     of partitions per time (derived from `client` object).
-    It filters out light curves with less than `n_src` observations,
-    and selects `n_src` random observations per light curve.
+    Filters out light curves with fewer than `n_src` observations.
+    If `subsample_src` is ``True``, selects exactly `n_src` random observations
+    per light curve. If ``False``, all observations from qualifying light curves
+    are included.
 
     Parameters
     ----------
@@ -118,7 +129,12 @@ def lsdb_nested_series_data_generator(
         value. If Dask client is given, the data would be fetched on the
         background.
     n_src : int
-        Number of random observations per light curve.
+        Minimum number of observations required per light curve. Also the
+        subsample target when `subsample_src` is ``True``.
+    subsample_src : bool, optional
+        If ``True`` (default), randomly subsample exactly `n_src` observations
+        per light curve. If ``False``, include all observations from qualifying
+        light curves.
     partitions_per_chunk : int
         Number of `catalog` partitions load in memory simultaneously.
         This changes the randomness.
@@ -151,6 +167,7 @@ def lsdb_nested_series_data_generator(
         _process_partition,
         include_pixel=True,
         n_src=n_src,
+        subsample_src=subsample_src,
         lc_col=lc_col,
         id_col=id_col,
         hash_range=hash_range,
@@ -205,7 +222,8 @@ class LSDBIterableDataset(IterableDataset):
         Number of batches to yield. If `splits` is used, it will be the size
         of the first subset.
     n_src : int
-        Number of random observations per light curve.
+        Number of random observations per light curve. Light curves with fewer
+        than `n_src` observations are filtered out.
     partitions_per_chunk : int or None
         Number of `catalog` partitions per time, if None it is derived
         from the number of dask workers associated with `Client` (one if
0 commit comments