diff --git a/zntrack/state.py b/zntrack/state.py index edf1ceb0..5786bcc4 100644 --- a/zntrack/state.py +++ b/zntrack/state.py @@ -87,6 +87,9 @@ def name(self) -> str: def nwd(self): if self.tmp_path is not None: return self.tmp_path + if "nwd" not in self.node.__dict__: + self.node.__dict__["nwd"] = get_nwd(self.node) + return get_nwd(self.node) @property diff --git a/zntrack/utils/node_wd.py b/zntrack/utils/node_wd.py index e9380cd4..239157f1 100644 --- a/zntrack/utils/node_wd.py +++ b/zntrack/utils/node_wd.py @@ -5,7 +5,6 @@ import pathlib import shutil import typing as t -import warnings import znflow.utils @@ -48,11 +47,30 @@ def get_nwd(node: "Node") -> pathlib.Path: try: return node.__dict__["nwd"] except KeyError: - warnings.warn( - "Using the NWD outside a project context" - " can not guarantee unique directories." - ) - return pathlib.Path(NWD_PATH, node.__class__.__name__) + if node.name is None: + raise ValueError("Unable to determine node name.") + if ( + node.state.remote is None + and node.state.rev is None + and node.state.state != NodeStatusEnum.FINISHED + ): + nwd = pathlib.Path(NWD_PATH, node.name) + else: + try: + with node.state.fs.open(ZNTRACK_FILE_PATH) as f: + zntrack_config = json.load(f) + nwd = zntrack_config[node.name]["nwd"] + nwd = json.loads(json.dumps(nwd), cls=znjson.ZnDecoder) + except (FileNotFoundError, KeyError): + nwd = pathlib.Path(NWD_PATH, node.name) + + if node.state.group is not None: + # strip the groups from node_name + to_replace = "_".join(node.state.group.name) + "_" + replacement = "/".join(node.state.group.name) + "/" + nwd = pathlib.Path(str(nwd).replace(to_replace, replacement)) + + return nwd class NWDReplaceHandler(znflow.utils.IterableHandler):