Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion sotodlib/preprocess/pcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, step_cfgs):
self.save_cfgs = step_cfgs.get("save")
self.select_cfgs = step_cfgs.get("select")
self.plot_cfgs = step_cfgs.get("plot")
self.stats_cfgs = step_cfgs.get("stats")
self.skip_on_sim = step_cfgs.get("skip_on_sim", False)
def process(self, aman, proc_aman, sim=False):
""" This function makes changes to the time ordered data AxisManager.
Expand Down Expand Up @@ -154,6 +155,23 @@ def plot(self, aman, proc_aman, filename):
return
raise NotImplementedError

def calc_stats(self, aman, proc_aman):
    """Calculate stats from a preprocessing function to be saved in a
    preprocessing stats ObsDb.

    Subclasses that support stats override this. The base implementation
    is a no-op when no ``stats`` config was supplied, and otherwise
    signals that the step has no stats support.

    Arguments
    ---------
    aman : AxisManager
        The time ordered data
    proc_aman : AxisManager
        Any information generated by previous elements in the preprocessing
        pipeline.

    Raises
    ------
    NotImplementedError
        If a ``stats`` config was supplied but this step does not
        implement stats calculation.
    """
    if self.stats_cfgs is not None:
        raise NotImplementedError

@classmethod
def gen_metric(cls, meta, proc_aman):
""" Generate a QA metric from the output of this process.
Expand Down Expand Up @@ -447,6 +465,7 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
process.calc_and_save()
process.save() ## called by process.calc_and_save()
process.select()
process.calc_stats()

Arguments
---------
Expand Down Expand Up @@ -535,7 +554,7 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
process.select(aman, proc_aman)
proc_aman.restrict('dets', aman.dets.vals)
self.logger.debug(f"{proc_aman.dets.count} detectors remaining")

if aman.dets.count == 0:
success = process.name
break
Expand All @@ -549,6 +568,13 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
full.move("frequency_cutoffs", None)
full.wrap("frequency_cutoffs", proc_aman["frequency_cutoffs"])

# calc stats
aman.wrap("stats", core.AxisManager())
aman.stats.wrap("ndets", proc_aman.dets.count)
aman.stats.wrap("nsamps", proc_aman.samps.count)
for step, process in enumerate(self):
process.calc_stats(aman, full)

return full, success


Expand Down
22 changes: 22 additions & 0 deletions sotodlib/preprocess/preprocess_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,28 @@ def get_preprocess_db(configs, group_by, logger=None):
return db


def get_preprocess_stats_db(statsdb_path, group_by):
    """Make or load a preprocessing stats ObsDb using preprocessing grouping
    as indexes.

    Arguments
    ----------
    statsdb_path : str
        The path to the stats db file.
    group_by : list of str
        The list of keys used to group the detectors.

    Returns
    -------
    statsdb : ObsDb
        ObsDb object
    """
    # Keep only the trailing component of dotted group keys
    # (e.g. "dets.wafer_slot" -> "wafer_slot").
    index_cols = [key.rsplit('.', 1)[-1] for key in group_by]
    return core.metadata.ObsDb(map_file=statsdb_path, wafer_info=index_cols)


def swap_archive(config, fpath):
"""Update the configuration archive policy filename,
create an output archive directory if it doesn't exist,
Expand Down
113 changes: 113 additions & 0 deletions sotodlib/preprocess/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,22 @@ def plot(self, aman, proc_aman, filename):
plot_det_bias_flags(aman, proc_aman['det_bias_flags'], rfrac_range=self.calc_cfgs['rfrac_range'],
psat_range=self.calc_cfgs['psat_range'], filename=filename.replace('{name}', f'{ufm}_bias_cuts_venn'))

def calc_stats(self, aman, proc_aman):
    """Wrap per-flag cut counts from the ``det_bias_flags`` product into
    ``aman.stats``. One ``<field>_cuts`` entry is written for each field
    of the product whose name contains "flags"."""
    # Skip unless both the stats and save configs are enabled
    # (None or an explicit False disables this step's stats).
    for cfg in (self.stats_cfgs, self.save_cfgs):
        if cfg is None or (isinstance(cfg, bool) and not cfg):
            return

    dbc_aman = proc_aman["det_bias_flags"]
    valid = has_any_cuts(dbc_aman.valid)
    for name in dbc_aman._assignments.keys():
        if "flags" not in name:
            continue
        aman.stats.wrap(
            f"{name}_cuts",
            np.sum(has_any_cuts(dbc_aman[name])[valid]),
        )


class Trends(_FracFlaggedMixIn, _Preprocess):
"""Calculate the trends in the data to look for unlocked detectors. All
Expand Down Expand Up @@ -174,6 +190,17 @@ def plot(self, aman, proc_aman, filename):
ufm = det.split('_')[2]
plot_trending_flags(aman, proc_aman['trends'], filename=filename.replace('{name}', f'{ufm}_trending_flags'))

def calc_stats(self, aman, proc_aman):
    """Wrap the trending-flags cut count into ``aman.stats`` as
    ``<base_name>_cuts``.

    Arguments
    ---------
    aman : AxisManager
        The time ordered data; stats are wrapped into ``aman.stats``.
    proc_aman : AxisManager
        Pipeline products; must contain the saved ``trends`` product.
    """
    # Match the guard used by the sibling calc_stats implementations:
    # None or an explicit False disables stats for this step. The
    # original only checked for None, so stats_cfgs=False still ran.
    if (
        self.stats_cfgs is None
        or (isinstance(self.stats_cfgs, bool) and not self.stats_cfgs)
        or self.save_cfgs is None
        or (isinstance(self.save_cfgs, bool) and not self.save_cfgs)
    ):
        return

    trend_aman = proc_aman["trends"]

    # Bug fix: stats_cfgs may be the bool True (the guard above only
    # rejects None/False), and bool has no .get -- fall back to the
    # default base name in that case, as the other steps do.
    if isinstance(self.stats_cfgs, bool):
        base_name = "trends"
    else:
        base_name = self.stats_cfgs.get("base_name", "trends")

    m = has_any_cuts(trend_aman.valid)
    aman.stats.wrap(
        f"{base_name}_cuts",
        np.sum(has_any_cuts(trend_aman.trend_flags)[m]),
    )


class GlitchDetection(_FracFlaggedMixIn, _Preprocess):
"""Run glitch detection algorithm to find glitches. All calculation configs
Expand Down Expand Up @@ -256,6 +283,26 @@ def plot(self, aman, proc_aman, filename):
plot_ds_factor=self.plot_cfgs.get("plot_ds_factor", 50), filename=filename.replace('{name}', f'{ufm}_glitch_signal_diff'))
plot_flag_stats(aman, proc_aman[self.glitch_name], flag_type='glitches', filename=filename.replace('{name}', f'{ufm}_glitch_stats'))

def calc_stats(self, aman, proc_aman):
    """Wrap glitch statistics into ``aman.stats``:
    ``<glitch_name>_cuts`` (fully-flagged detectors among fully-valid
    ones) and ``<glitch_name>_counts`` (total flag counts among
    detectors with any valid samples)."""
    # Skip unless both the stats and save configs are enabled
    # (None or an explicit False disables this step's stats).
    for cfg in (self.stats_cfgs, self.save_cfgs):
        if cfg is None or (isinstance(cfg, bool) and not cfg):
            return

    glitch_aman = proc_aman[self.glitch_name]

    valid_all = has_all_cut(glitch_aman.valid)
    aman.stats.wrap(
        f"{self.glitch_name}_cuts",
        np.sum(has_all_cut(glitch_aman.glitch_flags)[valid_all]),
    )

    valid_any = has_any_cuts(glitch_aman.valid)
    aman.stats.wrap(
        f"{self.glitch_name}_counts",
        np.sum(count_cuts(glitch_aman.glitch_flags)[valid_any]),
    )


class FixJumps(_Preprocess):
"""
Repairs the jump heights given a set of jump flags and heights.
Expand Down Expand Up @@ -376,6 +423,31 @@ def plot(self, aman, proc_aman, filename):
plot_ds_factor=self.plot_cfgs.get("plot_ds_factor", 50), filename=filename.replace('{name}', f'{ufm}_jump_signal_diff'))
plot_flag_stats(aman, proc_aman[name], flag_type='jumps', filename=filename.replace('{name}', f'{ufm}_jumps_stats'))

def calc_stats(self, aman, proc_aman):
    """Wrap jump statistics into ``aman.stats``:
    ``<base_name>_cuts`` (fully-flagged detectors among fully-valid
    ones) and ``<base_name>_counts`` (total flag counts among detectors
    with any valid samples).

    Arguments
    ---------
    aman : AxisManager
        The time ordered data; stats are wrapped into ``aman.stats``.
    proc_aman : AxisManager
        Pipeline products; must contain the saved jumps product.
    """
    if (
        self.stats_cfgs is None
        or (isinstance(self.stats_cfgs, bool) and not self.stats_cfgs)
        or self.save_cfgs is None
        or (isinstance(self.save_cfgs, bool) and not self.save_cfgs)
    ):
        return

    # Bug fix: save_cfgs may be the bool True here (the guard above only
    # rejects None/False), and bool has no .get -- fall back to the
    # default wrap name in that case, mirroring the stats_cfgs handling.
    if isinstance(self.save_cfgs, bool):
        jumps_aman = proc_aman["jumps"]
    else:
        jumps_aman = proc_aman[self.save_cfgs.get("jumps_name", "jumps")]

    if isinstance(self.stats_cfgs, bool):
        base_name = "jumps"
    else:
        base_name = self.stats_cfgs.get("base_name", "jumps")

    m = has_all_cut(jumps_aman.valid)
    aman.stats.wrap(
        f"{base_name}_cuts",
        np.sum(has_all_cut(jumps_aman.jump_flag)[m]),
    )

    m = has_any_cuts(jumps_aman.valid)
    aman.stats.wrap(
        f"{base_name}_counts",
        np.sum(count_cuts(jumps_aman.jump_flag)[m]),
    )


class PSDCalc(_Preprocess):
""" Calculate the PSD of the data and add it to the AxisManager under the
"psd" field.
Expand Down Expand Up @@ -849,6 +921,47 @@ def select(self, meta, proc_aman=None, in_place=True):
else:
return keep

def calc_stats(self, aman, proc_aman):
    """Wrap noise summary statistics into ``aman.stats``: mean/std of the
    white-noise levels (when present) and of each noise-model fit
    coefficient (when ``self.fit`` is set), over valid detectors."""
    # Skip unless both the stats and save configs are enabled
    # (None or an explicit False disables this step's stats).
    for cfg in (self.stats_cfgs, self.save_cfgs):
        if cfg is None or (isinstance(cfg, bool) and not cfg):
            return

    # Locate the saved noise product, honoring a custom wrap_name when
    # save_cfgs is a dict that provides one.
    if isinstance(self.save_cfgs, bool) or not self.save_cfgs.get("wrap_name", None):
        noise_aman = proc_aman["noise"]
    else:
        noise_aman = proc_aman[self.save_cfgs["wrap_name"]]

    m = has_all_cut(noise_aman.valid)

    if isinstance(self.stats_cfgs, bool):
        base_name = "noise"
    else:
        base_name = self.stats_cfgs.get("base_name", "noise")

    if "white_noise" in noise_aman:
        wn = noise_aman.white_noise[m]
        aman.stats.wrap(f"mean_{base_name}_white_noise", np.nanmean(wn))
        aman.stats.wrap(f"std_{base_name}_white_noise", np.nanstd(wn))

    if self.fit:
        for i, coeff in enumerate(noise_aman.noise_model_coeffs.vals):
            # Skip if white noise was calculated separately above.
            if coeff == "white_noise" and "white_noise" in noise_aman:
                continue
            vals = noise_aman.fit[:, i][m]
            aman.stats.wrap(f"mean_{base_name}_{coeff}", np.nanmean(vals))
            aman.stats.wrap(f"std_{base_name}_{coeff}", np.nanstd(vals))



class Calibrate(_Preprocess):
"""Calibrate the timestreams based on some provided information.
Expand Down
38 changes: 36 additions & 2 deletions sotodlib/site_pipeline/multilayer_preprocess_tod.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,12 @@ def multilayer_preprocess_tod(obs_id: str,
compress=compress,
)

return out_dict_init, out_dict_proc, errors
if "stats" in aman:
stats = {k: aman.stats[k] for k in aman.stats._assignments.keys()}
else:
stats = None

return out_dict_init, out_dict_proc, errors, stats


def _check_init_jobdb(
Expand Down Expand Up @@ -378,6 +383,11 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
pp_util.get_preprocess_db(configs_init, group_by, logger)
pp_util.get_preprocess_db(configs_proc, group_by, logger)

# get stats db
statsdb_path = configs_proc.get("statsdb", None)
if statsdb_path is not None:
statsdb = pp_util.get_preprocess_stats_db(statsdb_path, group_by)

futures = []
futures_dict = {}
obs_errors = {}
Expand All @@ -401,6 +411,7 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],

total = len(futures)

add_obs_cols = True
pb_name = f"pb_{str(int(time.time()))}.txt"
with open(pb_name, 'w') as f:
for future in tqdm(as_completed_callable(futures), total=total,
Expand All @@ -409,7 +420,7 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
obs_id, group = futures_dict[future]
out_meta = (obs_id, group)
try:
out_dict_init, out_dict_proc, errors = future.result()
out_dict_init, out_dict_proc, errors, stats = future.result()
obs_errors[obs_id].append({'group': group, 'error': errors[0]})
logger.info(f"{obs_id}: {group} extracted successfully")
except Exception as e:
Expand Down Expand Up @@ -451,6 +462,29 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
_t.value = errors[0]
else:
j.jstate = JState.done

# update statsdb
if (
errors[0] is None
and statsdb_path is not None
and stats is not None
):
if add_obs_cols == True:
stats_keys = []
for k, v in stats.items():
if isinstance(v, int):
t = "int"
elif isinstance(v, float):
t = "float"
elif isinstance(v, str):
t = "string"

stats_keys.append(f"{k} {t}")

statsdb.add_obs_columns(stats_keys, ignore_duplicates=True)
add_obs_cols = False
statsdb.update_obs((obs_id, *group), stats)

if raise_error:
n_obs_fail = 0
n_groups_fail = 0
Expand Down
39 changes: 37 additions & 2 deletions sotodlib/site_pipeline/preprocess_tod.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def preprocess_tod(configs: Union[str, dict],
A tuple containing the error from PreprocessError, an error message,
and the traceback. Each will be None if preproc_or_load_group finished
successfully.
stats : dict
A dictionary storing calculated preprocessing stats
"""
logger = pp_util.init_logger("preprocess", verbosity=verbosity)

Expand All @@ -142,7 +144,12 @@ def preprocess_tod(configs: Union[str, dict],
compress=compress,
)

return out_dict, errors
if "stats" in aman:
stats = {k: aman.stats[k] for k in aman.stats._assignments.keys()}
else:
stats = None

return out_dict, errors, stats


def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
Expand Down Expand Up @@ -286,6 +293,11 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
# ensure db exists up front to prevent race conditions
pp_util.get_preprocess_db(configs, group_by, logger)

# get stats db
statsdb_path = configs.get("statsdb", None)
if statsdb_path is not None:
statsdb = pp_util.get_preprocess_stats_db(statsdb_path, group_by)

futures = []
futures_dict = {}
obs_errors = {}
Expand All @@ -308,6 +320,7 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],

total = len(futures)

add_obs_cols = True
pb_name = f"pb_{str(int(time.time()))}.txt"
with open(pb_name, 'w') as f:
for future in tqdm(as_completed_callable(futures), total=total,
Expand All @@ -316,7 +329,7 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
obs_id, group = futures_dict[future]
out_meta = (obs_id, group)
try:
out_dict, errors = future.result()
out_dict, errors, stats = future.result()
obs_errors[obs_id].append({'group': group, 'error': errors[0]})
logger.info(f"{obs_id}: {group} extracted successfully")
except Exception as e:
Expand Down Expand Up @@ -349,6 +362,28 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"],
else:
j.jstate = JState.done

# update statsdb
if (
errors[0] is None
and statsdb_path is not None
and stats is not None
):
if add_obs_cols == True:
stats_keys = []
for k, v in stats.items():
if isinstance(v, int):
t = "int"
elif isinstance(v, float):
t = "float"
elif isinstance(v, str):
t = "string"

stats_keys.append(f"{k} {t}")

statsdb.add_obs_columns(stats_keys, ignore_duplicates=True)
add_obs_cols = False
statsdb.update_obs((obs_id, *group), stats)

if raise_error:
n_obs_fail = 0
n_groups_fail = 0
Expand Down