Skip to content

Commit 2a86215

Browse files
authored
Merge pull request #1 from genematx/setup
Validation Workflows
2 parents 96c7c90 + a5495a3 commit 2a86215

File tree

4 files changed

+250
-32
lines changed

4 files changed

+250
-32
lines changed

data_validation.py

Lines changed: 216 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,232 @@
1-
import os
2-
import time as ttime
1+
import time
32

43
from prefect import flow, task, get_run_logger
54
from prefect.blocks.system import Secret
5+
66
from tiled.client import from_profile
7+
from tiled.client.array import ArrayClient
8+
from tiled.client.dataframe import DataFrameClient
9+
from tiled.client.utils import handle_error
10+
from tiled.mimetypes import DEFAULT_ADAPTERS_BY_MIMETYPE as ADAPTERS_BY_MIMETYPE
11+
from tiled.utils import safe_json_dump
712

813
BEAMLINE_OR_ENDSTATION = "arpes"
914

1015

16+
class ValidationException(Exception):
    """Base class for data-validation failures.

    Carries an optional run ``uid`` so callers can tell which run failed.
    """

    def __init__(self, message, uid=None):
        super().__init__(message)
        # uid of the run this failure relates to, if known
        self.uid = uid


class ReadingValidationException(ValidationException):
    """Raised when reading the underlying data for a key fails."""


class RunValidationException(ValidationException):
    """Raised when the run itself is incomplete or malformed."""


class MetadataValidationException(ValidationException):
    """Raised when run metadata fails validation."""
30+
31+
32+
def validate(root_client, fix_errors=True, try_reading=True, raise_on_error=False, ignore_errors=None):
    """Validate the given BlueskyRun client for completeness and data integrity.

    Parameters
    ----------
    root_client : tiled.client.run.RunClient
        The Run client to validate.
    fix_errors : bool, optional
        Whether to attempt to fix structural errors found during validation.
        Default is True.
    try_reading : bool, optional
        Whether to attempt reading the data for external data keys.
        Default is True.
    raise_on_error : bool, optional
        Whether to raise an exception on the first validation error encountered.
        Default is False.
    ignore_errors : list of str, optional
        List of error messages to ignore during reading validation.
        Default is None (treated as an empty list).

    Returns
    -------
    bool
        True if validation passed without errors, False otherwise.
    """
    logger = get_run_logger()
    # Avoid a shared mutable default argument; normalize None to an empty list.
    ignore_errors = [] if ignore_errors is None else ignore_errors

    # Check if there's a Stop document in the run
    if "stop" not in root_client.metadata:
        logger.error("The Run is not complete: missing the Stop document")
        if raise_on_error:
            raise RunValidationException("Missing Stop document in the run")

    # Check all streams and data keys
    errored_keys, notes = [], []
    # Newer catalogs nest streams under a 'streams' node; older ones are flat.
    streams_node = root_client['streams'] if 'streams' in root_client.keys() else root_client
    for sname, stream in streams_node.items():
        for data_key in stream.base:
            if data_key == "internal":
                continue

            data_client = stream[data_key]
            # Only externally managed data sources need structural validation.
            if data_client.data_sources()[0].management != "external":
                continue

            # Validate data structure
            title = f"Validation of data key '{sname}/{data_key}'"
            try:
                _notes = validate_structure(data_client, fix_errors=fix_errors)
                notes.extend([title + ": " + note for note in _notes])
            except Exception as e:
                # Collapse multi-line error text into a single-line note.
                msg = f"{type(e).__name__}: " + str(e).replace("\n", " ").replace("\r", "").strip()
                msg = title + f" failed with error: {msg}"
                logger.warning(msg)
                notes.append(msg)

            # Validate reading of the data
            if try_reading:
                try:
                    validate_reading(data_client, ignore_errors=ignore_errors)
                except Exception as e:
                    errored_keys.append((sname, data_key, str(e)))
                    logger.error(f"Reading validation of '{sname}/{data_key}' failed with error: {e}")
                    if raise_on_error:
                        # Bare raise preserves the original traceback.
                        raise

            # Brief pause between keys to avoid hammering the server.
            time.sleep(0.1)

    if try_reading and (not errored_keys):
        logger.info("Reading validation completed successfully.")

    # Update the root metadata with validation notes
    if notes:
        existing_notes = root_client.metadata.get("notes", [])
        root_client.update_metadata(
            {"notes": existing_notes + notes},
            drop_revision=True,
        )

    return not errored_keys
114+
115+
116+
def validate_reading(data_client, ignore_errors=None):
    """Attempt to read a sample of the data behind a data key.

    For arrays, reads the first and last element; for tables, reads the whole
    table. Errors whose message contains any entry of ``ignore_errors`` are
    logged and suppressed; anything else is re-raised as a
    ReadingValidationException.

    Parameters
    ----------
    data_client : tiled client node (ArrayClient or DataFrameClient)
        The client for the data key being validated.
    ignore_errors : list of str, optional
        Substrings of error messages to ignore. Default is None (empty list).

    Raises
    ------
    ReadingValidationException
        If reading fails with a non-ignored error.
    """
    logger = get_run_logger()
    # Avoid a shared mutable default argument; normalize None to an empty list.
    ignore_errors = [] if ignore_errors is None else ignore_errors

    data_key = data_client.item['id']
    sname = data_client.item['attributes']['ancestors'][-1]  # stream name

    if isinstance(data_client, ArrayClient):
        ndim = len(data_client.shape)
        try:
            # Indexing with a tuple is equivalent to c[*(0,)*ndim] but also
            # works on Python < 3.11 (subscript unpacking is 3.11+ syntax).
            data_client[(0,) * ndim]   # try to read the first element
            data_client[(-1,) * ndim]  # try to read the last element
        except Exception as e:
            # e.args may be empty; fall back to str(e) to avoid IndexError.
            reason = str(e.args[0]) if e.args else str(e)
            if any(msg in reason for msg in ignore_errors):
                logger.info(f"Ignoring array reading error: {sname}/{data_key}: {reason}")
            else:
                raise ReadingValidationException(f"Array reading failed with error: {reason}") from e

    elif isinstance(data_client, DataFrameClient):
        try:
            data_client.read()  # try to read the entire table
        except Exception as e:
            reason = str(e.args[0]) if e.args else str(e)
            if any(msg in reason for msg in ignore_errors):
                logger.info(f"Ignoring table reading error: {sname}/{data_key}: {reason}")
            else:
                raise ReadingValidationException(f"Table reading failed with error: {reason}") from e

    else:
        logger.warning(f"Validation of '{data_key=}' is not supported with client of type {type(data_client)}.")
143+
144+
145+
def validate_structure(data_client, fix_errors=False) -> list[str]:
    """Compare the cataloged structure of an external data source against the
    structure its adapter actually reads from the data files.

    Shape, chunks, dtype, and dimension names are checked. With
    ``fix_errors=False`` the first mismatch raises ValueError; with
    ``fix_errors=True`` the cataloged structure is patched to match the
    adapter's view and the corrected data source is pushed back to the server.

    Parameters
    ----------
    data_client : tiled client node for an externally managed data key
    fix_errors : bool, optional
        Whether to fix mismatches in place instead of raising. Default False.

    Returns
    -------
    list[str]
        Human-readable notes describing each fix applied (empty if none).
    """
    logger = get_run_logger()

    # Only the first data source is validated — presumably each external key
    # has exactly one; TODO confirm against the catalog schema.
    data_source = data_client.data_sources()[0]
    uris = [asset.data_uri for asset in data_source.assets]
    structure = data_client.structure()
    notes = []

    # Initialize adapter from uris and determine the structure as read by the adapter
    adapter_class = ADAPTERS_BY_MIMETYPE[data_source.mimetype]
    true_structure = adapter_class.from_uris(*uris, **data_source.parameters).structure()
    true_data_type = true_structure.data_type
    true_shape = true_structure.shape
    true_chunks = true_structure.chunks

    # Validate structure components
    if structure.shape != true_shape:
        if not fix_errors:
            raise ValueError(f"Shape mismatch: {structure.shape} != {true_shape}")
        else:
            msg = f"Fixed shape mismatch: {structure.shape} -> {true_shape}"
            logger.warning(msg)
            structure.shape = true_shape
            notes.append(msg)

    if structure.chunks != true_chunks:
        if not fix_errors:
            raise ValueError(f"Chunk shape mismatch: {structure.chunks} != {true_chunks}")
        else:
            # Report only the leading chunk size along each dimension to keep
            # the note readable; the full chunking is still applied below.
            _true_chunk_shape = tuple(c[0] for c in true_chunks)
            _chunk_shape = tuple(c[0] for c in structure.chunks)
            msg = f"Fixed chunk shape mismatch: {_chunk_shape} -> {_true_chunk_shape}"
            logger.warning(msg)
            structure.chunks = true_chunks
            notes.append(msg)

    if structure.data_type != true_data_type:
        if not fix_errors:
            raise ValueError(f"dtype mismatch: {structure.data_type} != {true_data_type}")
        else:
            msg = f"Fixed dtype mismatch: {structure.data_type.to_numpy_dtype()} -> {true_data_type.to_numpy_dtype()}"  # noqa
            logger.warning(msg)
            structure.data_type = true_data_type
            notes.append(msg)

    # Dimension names are only checked for count, not for content.
    if structure.dims and (len(structure.dims) != len(true_shape)):
        if not fix_errors:
            raise ValueError(f"Number of dimension names mismatch for a {len(true_shape)}-dimensional array: {structure.dims}")  # noqa
        else:
            old_dims = structure.dims
            if len(old_dims) < len(true_shape):
                # Too few names: prepend "time" (presumably the event axis —
                # verify against the run model) and pad with generic dim names.
                structure.dims = ("time",) + old_dims + tuple(f"dim{i}" for i in range(len(old_dims)+1, len(true_shape)))
            else:
                # Too many names: truncate to the true rank.
                structure.dims = old_dims[: len(true_shape)]
            msg = f"Fixed dimension names: {old_dims} -> {structure.dims}"
            logger.warning(msg)
            notes.append(msg)

    # Update the data source structure if any fixes were applied
    if notes:
        data_source.structure = structure
        # PUT the corrected data source to the server's data_source endpoint,
        # derived from the node's metadata URI; handle_error raises on HTTP errors.
        handle_error(
            data_client.context.http_client.put(
                data_client.uri.replace("/api/v1/metadata/", "/api/v1/data_source/", 1),
                content=safe_json_dump({"data_source": data_source}),
            )
        ).json()

    return notes
214+
215+
11216
@task(retries=2, retry_delay_seconds=10)
def data_validation_task(uid, beamline_acronym=BEAMLINE_OR_ENDSTATION):
    """Validate a single run, identified by ``uid``, in the beamline's catalog.

    Loads the beamline-specific Tiled API key from a Prefect Secret block,
    connects to the "nsls2" Tiled profile, and runs ``validate`` on the run,
    logging the elapsed time.
    """
    logger = get_run_logger()

    # Authenticate against Tiled with the per-beamline API key.
    secret = Secret.load(f"tiled-{beamline_acronym}-api-key", _sync=True)
    tiled_client = from_profile("nsls2", api_key=secret.get())
    run_client = tiled_client[beamline_acronym]["migration"][uid]

    logger.info(f"Validating uid {uid}")
    t0 = time.monotonic()
    validate(run_client, fix_errors=True, try_reading=True, raise_on_error=False)
    elapsed_time = time.monotonic() - t0
    logger.info(f"{elapsed_time = }")
28228

29229

30230
@flow(log_prints=True)
def data_validation_flow(uid):
    """Prefect flow entry point: delegate validation of ``uid`` to the task."""
    data_validation_task(uid)

end_of_run_workflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from prefect import task, flow, get_run_logger
2-
from data_validation import data_validation
2+
from data_validation import data_validation_flow
33

44
@task
55
def log_completion():
@@ -9,5 +9,5 @@ def log_completion():
99
@flow(log_prints=True)
def end_of_run_workflow(stop_doc):
    """End-of-run hook: validate the run named in the Stop document, then log."""
    # The Stop document carries the uid of its Run Start document.
    data_validation_flow(stop_doc["run_start"])
    log_completion()

pixi.lock

Lines changed: 30 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pixi.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ platforms = ["linux-64"]
66
[dependencies]
77
prefect = "3.*"
88
python = "<3.13"
9-
tiled-client = ">=0.1.6"
9+
tiled-client = ">=0.2.1"
1010
prefect-docker = "*"
1111
databroker = "*"
12+
bluesky-tiled-plugins = "*"

0 commit comments

Comments
 (0)