Skip to content
Open
3 changes: 3 additions & 0 deletions zntrack/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def name(self) -> str:
def nwd(self):
if self.tmp_path is not None:
return self.tmp_path
if "nwd" not in self.node.__dict__:
self.node.__dict__["nwd"] = get_nwd(self.node)

return get_nwd(self.node)

@property
Expand Down
30 changes: 24 additions & 6 deletions zntrack/utils/node_wd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pathlib
import shutil
import typing as t
import warnings

import znflow.utils

Expand Down Expand Up @@ -48,11 +47,30 @@ def get_nwd(node: "Node") -> pathlib.Path:
try:
return node.__dict__["nwd"]
except KeyError:
warnings.warn(
"Using the NWD outside a project context"
" can not guarantee unique directories."
)
return pathlib.Path(NWD_PATH, node.__class__.__name__)
if node.name is None:
raise ValueError("Unable to determine node name.")
if (
node.state.remote is None
and node.state.rev is None
and node.state.state != NodeStatusEnum.FINISHED
):
nwd = pathlib.Path(NWD_PATH, node.name)
else:
try:
with node.state.fs.open(ZNTRACK_FILE_PATH) as f:
zntrack_config = json.load(f)
nwd = zntrack_config[node.name]["nwd"]
nwd = json.loads(json.dumps(nwd), cls=znjson.ZnDecoder)
except (FileNotFoundError, KeyError):
nwd = pathlib.Path(NWD_PATH, node.name)

if node.state.group is not None:
# strip the groups from node_name
to_replace = "_".join(node.state.group.name) + "_"
replacement = "/".join(node.state.group.name) + "/"
nwd = pathlib.Path(str(nwd).replace(to_replace, replacement))

return nwd


class NWDReplaceHandler(znflow.utils.IterableHandler):
Expand Down