From cd70db841f0def31424e3ef62f31f1ed2b38dd72 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 14:25:55 +0200 Subject: [PATCH 01/23] Bugfix --- CHANGELOG.md | 53 ++ .../plans/2026-04-29-mermaid-export.md | 321 -------- .../specs/2026-04-29-mermaid-export-design.md | 174 ---- src/tensor_network_editor/_public_codegen.py | 6 +- src/tensor_network_editor/app/_protocol.py | 4 + .../app/_session_requests.py | 4 + src/tensor_network_editor/app/routes.py | 2 + src/tensor_network_editor/app/session.py | 4 + src/tensor_network_editor/app/static/app.css | 25 + .../app/static/index.html | 38 + .../js/actions/designMutationPipeline.js | 3 - .../app/static/js/core/dom.js | 20 +- .../js/properties/entityPropertiesMarkup.js | 13 +- .../js/properties/overviewPropertiesMarkup.js | 13 +- .../app/static/js/properties/properties.js | 12 + .../properties/propertiesRenderersEntities.js | 43 +- .../properties/propertiesRenderersOverview.js | 12 +- .../js/services/editorSessionService.js | 14 +- .../static/js/session/sessionEditorFlows.js | 2 + .../static/js/session/sessionTemplateFlows.js | 92 ++- .../static/js/shell/editorShellBindings.js | 42 +- .../app/static/js/state/editorStore.js | 6 + .../app/static/js/state/historySnapshots.js | 3 + .../app/static/js/state/state.js | 1 + .../app/static/js/utils/utilitiesGeometry.js | 28 +- .../app/static/js/utils/utilitiesLayout.js | 42 + .../js/utils/utilitiesUiToolbarActionState.js | 45 +- .../utils/utilitiesUiToolbarDerivedState.js | 28 +- .../codegen/modes/_linear_periodic/carry.py | 20 +- .../modes/_linear_periodic/graph_carry.py | 71 +- .../codegen/modes/grid_periodic.py | 13 +- .../codegen/modes/linear_periodic.py | 13 +- .../codegen/modes/tree_periodic.py | 13 +- src/tensor_network_editor/codegen/registry.py | 4 + .../codegen/shared/roundtrip.py | 5 +- .../validation/_validation_linear_periodic.py | 71 ++ src/tensor_network_editor/rendering.py | 680 +++++++++++++++- tests/codegen/test_generators.py | 619 ++++++++++++++ .../codegen/test_linear_periodic_internals.py | 205 +++++ tests/test_app_assets.py | 79 +- tests/test_frontend_architecture.py | 83 +- tests/test_frontend_runtime.py | 758 +++++++++++++++++- tests/test_models_validation.py | 33 + tests/test_rendering.py | 473 ++++++++++- 44 files changed, 3549 insertions(+), 641 deletions(-) delete mode 100644 docs/superpowers/plans/2026-04-29-mermaid-export.md delete mode 100644 docs/superpowers/specs/2026-04-29-mermaid-export-design.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 902720b..d0ec7c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,15 +10,68 @@ All notable changes to this project will be documented in this file. set (`PNG`, `SVG`, `PDF`, `TikZ/LaTeX`, `Graphviz/DOT`, and `Mermaid`) and clarifies that recommended startup flows can use built-in templates, session templates, and reusable subnetwork fragments. +- Browser-editor `For`-mode code generation now keeps the commented + `TNE_SPEC_B64` round-trip metadata at the end of the generated Python source + instead of the beginning, and the editor only includes it when the new + `Metadata` checkbox is enabled in the `Code` panel. +- Browser-editor `For` mode no longer disables template settings or + selection-based `Extract`, `To Library`, and `To Template` actions just for + being in a periodic editor view; those actions now stay available for normal + tensors and only reject virtual boundary cells such as `next`, `previous`, + grid side cells, or tree parent/child placeholders. ### Added - Static exports now include a `Mermaid` flowchart renderer for documentation workflows, with matching support in the Python API, CLI `render` subcommand, and browser editor export menu. +- The editor `Reflow` popover now offers simple horizontal and vertical + alignment controls plus a 90° clockwise rotation action that also rotates the + selected tensor ports to keep their orientation consistent. +- Static geometric exports now choose shape-aware directions for free indices in + linear, circular, and 2D-grid layouts, use a stable local fallback for + irregular layouts, and draw dangling stubs with a length of two tensor radii. ### Fixed +- Linear-periodic `For`-mode validation now rejects carry plans that the code + generator cannot realize, so the editor no longer reports some + multi-boundary manual schemes as valid during analysis only to fail later + when generating Python code. +- Linear-periodic `For`-mode carry code generation now keeps non-interface + labels from the previous payload distinct from the current cell's local + labels, so valid manual schemes with repeated index names across cells no + longer collapse accidentally during periodic carry simulation. +- Linear-periodic `For`-mode `tensorkrowch` carry generation now keeps local + open edges on stable current-cell edge objects while exporting repeated + carry interfaces from the materialized result node, so periodic helpers no + longer hand later loop iterations a stale leaf edge. +- Linear-periodic `For`-mode `tensorkrowch` carry helpers now reattach + intermediate contraction results before later manual steps reuse them, so + valid periodic plans no longer lose the shared inter-cell edge during + multi-step cell contractions. +- Linear-periodic `For`-mode `tensorkrowch` carry helpers now materialize + shared edges with `reattach_edges(override=True)` instead of relying on + `network.reset()`, so repeated periodic iterations keep their inter-cell + bond visible to later `contract_between` steps. +- Normal `tensorkrowch` manual code generation no longer injects + `reattach_edges(...)` between ordinary contraction steps, so non-`For` + exports keep the simpler node structure that the standard visualizer already + handles correctly. +- Contraction-scene tensor layering now follows the current visible operands + instead of only the base spec tensor list, so selecting or dragging derived + result tensors in `single`/`contract` keeps their free ports visible above + overlapping front tensors. +- Static exports now keep free-index directions aligned with the network's real + on-canvas orientation, so vertical, diagonal, and rotated-grid layouts no + longer get reinterpreted as axis-aligned during `SVG`, `PNG`, `PDF`, and + `TikZ/LaTeX` rendering. +- `Ctrl/Cmd+Enter` now closes the editor with info reliably from contraction + planner views, preview states, `For` mode, and benchmark mode by registering + the global shortcut listener in capture mode. +- Browser-editor academic exports now invalidate the serialized-spec cache + after layout moves and rotations, so exported figures follow the current + canvas geometry instead of occasionally reusing stale pre-reflow positions. - Mermaid export now renders free indices as labeled dangling-edge terminals instead of boxed open-index nodes, so the flowchart output reads more like a tensor-network leg. diff --git a/docs/superpowers/plans/2026-04-29-mermaid-export.md b/docs/superpowers/plans/2026-04-29-mermaid-export.md deleted file mode 100644 index 0091557..0000000 --- a/docs/superpowers/plans/2026-04-29-mermaid-export.md +++ /dev/null @@ -1,321 +0,0 @@ -# Mermaid Export Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Add a Mermaid export format that users can generate from Python, the CLI, and the browser editor for documentation-friendly tensor-network diagrams. - -**Architecture:** Extend the existing static rendering family with a new text renderer in `rendering.py`, then thread the new format through the `/api/render` backend route, the CLI `render` subcommand, and the browser export UI. Reuse `DotRenderOptions` for label toggles, keep the Mermaid output structure-oriented rather than geometry-oriented, and degrade gracefully for notes and groups. - -**Tech Stack:** Python 3.12, typed package exports, argparse CLI, browser editor JavaScript, pytest, pyright, ruff - ---- - -### Task 1: Add the core Mermaid renderer with TDD - -**Files:** -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\rendering.py` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\__init__.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_rendering.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_api.py` - -- [ ] **Step 1: Write the failing renderer tests** - -Add focused tests to `tests/test_rendering.py` for: - -```python -def test_render_spec_mermaid_returns_flowchart_for_normal_network() -> None: - mermaid = render_spec_mermaid(build_sample_spec()) - - assert mermaid.startswith("flowchart LR\n") - assert 'tensor_tensor_a["A"]' in mermaid - assert 'tensor_tensor_b["B"]' in mermaid - assert 'tensor_tensor_a <-->|"bond_x / x=3"| tensor_tensor_b' in mermaid - - -def test_render_spec_mermaid_can_hide_tensor_index_and_bond_labels() -> None: - mermaid = render_spec_mermaid( - build_sample_spec(), - options=DotRenderOptions( - show_tensor_labels=False, - show_index_labels=False, - show_edge_labels=False, - ), - ) - - assert 'tensor_tensor_a["tensor_a"]' in mermaid - assert 'tensor_tensor_b["tensor_b"]' in mermaid - assert 'bond_x' not in mermaid - assert 'x=3' not in mermaid - - -def test_render_spec_mermaid_includes_hyperedges_groups_and_notes() -> None: - spec = build_three_tensor_hyperedge_spec() - mermaid = render_spec_mermaid(spec) - - assert "subgraph group_demo [Demo Group]" in mermaid - assert 'hyperedge_h["shared_h"]' in mermaid - assert '%% Note: Check the contraction order' in mermaid - - -def test_render_spec_mermaid_writes_output_path(tmp_path: Path) -> None: - output_path = tmp_path / "network.mmd" - - mermaid = render_spec_mermaid(build_sample_spec(), output_path=output_path) - - assert output_path.read_text(encoding="utf-8") == mermaid -``` - -Add a public API coverage check in `tests/test_api.py` similar to the existing static render exports: - -```python -assert tensor_network_editor.render_spec_mermaid is render_spec_mermaid -``` - -- [ ] **Step 2: Run the renderer tests to verify they fail** - -Run: `.\.venv\Scripts\python -m pytest tests\test_rendering.py -k mermaid` - -Expected: FAIL because `render_spec_mermaid` does not exist yet. - -- [ ] **Step 3: Write the minimal Mermaid renderer** - -Implement in `src/tensor_network_editor/rendering.py`: - -```python -def render_spec_mermaid( - spec: NetworkSpec, - *, - options: DotRenderOptions | None = None, - output_path: StrPath | None = None, -) -> str: - ... -``` - -Add a small renderer class and helpers that: -- emit `flowchart LR` -- create safe Mermaid ids from existing stable ids -- render tensors, pairwise edges, open indices, hyperedge hubs, groups, and note comments -- reuse existing DOT label logic where possible -- write UTF-8 text when `output_path` is provided - -Then export it from `__all__` and the lazy exports in `src/tensor_network_editor/__init__.py`. - -- [ ] **Step 4: Run the renderer tests to verify they pass** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_rendering.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_api.py -k render_spec_mermaid` - -Expected: PASS - -- [ ] **Step 5: Commit** - -```bash -git add src/tensor_network_editor/rendering.py src/tensor_network_editor/__init__.py tests/test_rendering.py tests/test_api.py -git commit -m "Add Mermaid renderer" -``` - -### Task 2: Integrate Mermaid into the backend route and CLI with TDD - -**Files:** -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\cli.py` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\routes.py` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\internal\cli\_cli_handlers.py` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\internal\cli\_cli_parser.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_cli.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_app_routes.py` - -- [ ] **Step 1: Write the failing integration tests** - -Add route coverage in `tests/test_app_routes.py`: - -```python -def test_render_route_returns_mermaid_export(editor_server: EditorServer) -> None: - spec = build_sample_spec() - serialized_spec = {"schema_version": SCHEMA_VERSION, "network": spec.to_dict()} - - payload = request_json( - f"{editor_server.base_url}/api/render", - method="POST", - payload={"format": "mermaid", "spec": serialized_spec}, - ) - - assert payload["format"] == "mermaid" - assert payload["content_type"] == "text/plain;charset=utf-8" - assert payload["text"].startswith("flowchart LR\n") -``` - -Add CLI coverage in `tests/test_cli.py`: - -```python -def test_render_subcommand_writes_mermaid_output(sample_spec: NetworkSpec) -> None: - with ( - patch("tensor_network_editor.cli.load_spec", return_value=sample_spec), - patch("tensor_network_editor.cli.render_spec_mermaid", return_value="flowchart LR\n"), - ): - exit_code = main( - ["render", "saved-network.json", "--format", "mermaid", "--output", "graph.mmd"] - ) - - assert exit_code == 0 - - -def test_render_subcommand_prints_mermaid_when_no_output( - sample_spec: NetworkSpec, - capsys: pytest.CaptureFixture[str], -) -> None: - with ( - patch("tensor_network_editor.cli.load_spec", return_value=sample_spec), - patch("tensor_network_editor.cli.render_spec_mermaid", return_value="flowchart LR\n"), - ): - exit_code = main(["render", "saved-network.json", "--format", "mermaid"]) - - assert exit_code == 0 - assert capsys.readouterr().out == "flowchart LR\n\n" -``` - -- [ ] **Step 2: Run the integration tests to verify they fail** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_app_routes.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_cli.py -k mermaid` - -Expected: FAIL because the route and CLI do not accept `mermaid` yet. - -- [ ] **Step 3: Wire Mermaid into the route and CLI** - -Add `render_spec_mermaid` imports and format branches in: -- `src/tensor_network_editor/cli.py` -- `src/tensor_network_editor/internal/cli/_cli_handlers.py` -- `src/tensor_network_editor/internal/cli/_cli_parser.py` -- `src/tensor_network_editor/app/routes.py` - -Use: - -```python -content_type = "text/plain;charset=utf-8" -``` - -Use `.mmd` as the expected file extension in user-facing messages and downloads. - -- [ ] **Step 4: Run the integration tests to verify they pass** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_app_routes.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_cli.py -k mermaid` - -Expected: PASS - -- [ ] **Step 5: Commit** - -```bash -git add src/tensor_network_editor/cli.py src/tensor_network_editor/app/routes.py src/tensor_network_editor/internal/cli/_cli_handlers.py src/tensor_network_editor/internal/cli/_cli_parser.py tests/test_cli.py tests/test_app_routes.py -git commit -m "Integrate Mermaid export into CLI and API" -``` - -### Task 3: Add Mermaid to the browser editor and docs with TDD - -**Files:** -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\static\index.html` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\static\js\core\dom.js` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\static\js\shell\editorShellBindings.js` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\src\tensor_network_editor\app\static\js\session\sessionEditorFlows.js` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\README.md` -- Modify: `C:\Users\aleja\Documents\draw_to_tensor_network\CHANGELOG.md` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_app_assets.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_frontend_architecture.py` -- Test: `C:\Users\aleja\Documents\draw_to_tensor_network\tests\test_frontend_runtime.py` - -- [ ] **Step 1: Write the failing editor and docs tests** - -Add asset checks in `tests/test_app_assets.py` for: -- `id="export-mermaid-menu-item"` -- `` -- Mermaid listed in help text when export formats are enumerated -- DOM wiring for `exportMermaidMenuItem` - -Add runtime coverage in `tests/test_frontend_runtime.py` similar to the existing academic export flow: - -```javascript -await flows.downloadExportAs("mermaid"); -``` - -Then assert: -- one `renderSpec` call with `payload.format === "mermaid"` -- one text download for `draft_demo.mmd` -- `contentType === "text/plain;charset=utf-8"` - -- [ ] **Step 2: Run the editor tests to verify they fail** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_app_assets.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_frontend_runtime.py -k mermaid` - -Expected: FAIL because the editor does not expose Mermaid yet. - -- [ ] **Step 3: Implement the editor export wiring** - -Update: -- `index.html` to add a Mermaid export menu item and selector option -- `dom.js` to expose `exportMermaidMenuItem` -- `editorShellBindings.js` to bind `downloadExportAs("mermaid")` -- `sessionEditorFlows.js` to add Mermaid to `exportDetails`, use `.mmd`, and route it through text download - -Also update `README.md` and `CHANGELOG.md` to mention Mermaid export support. - -- [ ] **Step 4: Run the editor tests to verify they pass** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_app_assets.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_frontend_runtime.py -k mermaid` - -Expected: PASS - -- [ ] **Step 5: Commit** - -```bash -git add src/tensor_network_editor/app/static/index.html src/tensor_network_editor/app/static/js/core/dom.js src/tensor_network_editor/app/static/js/shell/editorShellBindings.js src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js README.md CHANGELOG.md tests/test_app_assets.py tests/test_frontend_architecture.py tests/test_frontend_runtime.py -git commit -m "Add Mermaid export to browser editor" -``` - -### Task 4: Final verification and cleanup - -**Files:** -- Verify only - -- [ ] **Step 1: Run focused pytest coverage** - -Run: -- `.\.venv\Scripts\python -m pytest tests\test_rendering.py` -- `.\.venv\Scripts\python -m pytest tests\test_api.py -k render_spec_mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_cli.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_app_routes.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_app_assets.py -k mermaid` -- `.\.venv\Scripts\python -m pytest tests\test_frontend_runtime.py -k mermaid` - -Expected: PASS - -- [ ] **Step 2: Run Python quality checks** - -Run: -- `.\.venv\Scripts\python -m ruff check . --fix` -- `.\.venv\Scripts\python -m ruff format .` - -Expected: PASS - -- [ ] **Step 3: Run type checking and note pre-existing failures separately** - -Run: -- `.\.venv\Scripts\python -m pyright` - -Expected: either PASS or only the already-known unrelated failures in: -- `tests/test_app_routes.py` -- `tests/test_app_server.py` -- `tests/test_session.py` - -- [ ] **Step 4: Commit any final cleanup** - -```bash -git add README.md CHANGELOG.md src tests -git commit -m "Polish Mermaid export integration" -``` diff --git a/docs/superpowers/specs/2026-04-29-mermaid-export-design.md b/docs/superpowers/specs/2026-04-29-mermaid-export-design.md deleted file mode 100644 index bae4812..0000000 --- a/docs/superpowers/specs/2026-04-29-mermaid-export-design.md +++ /dev/null @@ -1,174 +0,0 @@ -# Mermaid Export Design - -## Summary - -Add a new static export format, `mermaid`, for tensor-network diagrams. -The goal is to produce a portable text representation that users can paste -directly into GitHub, Markdown documents, or Mermaid-enabled documentation -tools. - -This export is documentation-oriented. It should preserve graph structure and -labels, not editor geometry or visual styling. - -## Goals - -- Add a first-class `mermaid` export alongside `svg`, `png`, `pdf`, `tikz`, - and `dot`. -- Keep the API, CLI, and browser editor export flows consistent. -- Reuse the existing export label toggles for tensor names, index names, and - bond names. -- Generate Mermaid that is robust and easy to paste into Markdown. - -## Non-Goals - -- Preserve canvas positions, exact layout, colors, or node sizes. -- Reproduce the editor appearance inside Mermaid. -- Add a separate `markdown` export format in v1. -- Fail the export because Mermaid cannot express some detail exactly. - -## Recommended Approach - -Implement a new renderer function: - -- `render_spec_mermaid(spec: NetworkSpec, *, options: DotRenderOptions | None = None, output_path: StrPath | None = None) -> str` - -`DotRenderOptions` is the best fit for v1 because Mermaid is also a -text-oriented graph export and needs the same label visibility controls as -`dot`. - -The renderer should emit a complete Mermaid diagram using: - -```text -flowchart LR -``` - -This direction matches the current left-to-right mental model already used in -`dot`. - -## Representation Rules - -### Tensors - -- Each tensor becomes one Mermaid node. -- If `show_tensor_labels` is true, use the tensor name as the visible label. -- If `show_tensor_labels` is false, keep the node but use a minimal fallback - label based on the tensor id so the graph remains valid and readable. - -### Pairwise edges - -- Each standard edge becomes one Mermaid connection between the two tensor - nodes. -- The edge label should follow the current `dot` behavior: - - show bond name and index label when both are enabled - - show only bond name when only bond names are enabled - - show only index name and dimension when only index names are enabled - - show no label when both are disabled - -### Open indices - -- Each open index becomes a terminal Mermaid node connected to its tensor. -- The node label should reuse the same label logic already used by `dot` for - open indices. - -### Hyperedges - -- Each hyperedge becomes a synthetic hub node connected to all endpoint - tensors. -- The hub label should use the hyperedge name when bond labels are enabled. -- Endpoint edge labels should reuse the current `dot` hyperedge endpoint label - logic when index labels are enabled. - -### Groups - -- Each group should become a Mermaid `subgraph`. -- The renderer should place the member tensor nodes inside that `subgraph`. -- If Mermaid cannot faithfully reflect complex overlap or crossing semantics, - the export should still succeed with a simple `subgraph` structure. - -### Notes - -- Notes should not become positioned visual nodes in v1. -- Export each note as a Mermaid comment line: - -```text -%% Note: Check the contraction order -``` - -This keeps note content available for documentation without forcing awkward -diagram geometry. - -## Escaping and Identifiers - -- Mermaid node ids must be generated from safe internal identifiers, not from - raw labels. -- Visible labels must be escaped conservatively so quotes, brackets, newlines, - and punctuation do not break the diagram. -- Reuse existing conservative escaping ideas from `dot` and `tikz`, but keep - Mermaid-specific syntax rules separate in small helper functions. - -## API and Integration - -### Python API - -- Export `render_spec_mermaid` from `tensor_network_editor.rendering`. -- Re-export it from `tensor_network_editor.__init__`. - -### CLI - -- Extend `render --format` with `mermaid`. -- Print to stdout when `--output` is omitted. -- Use `.mmd` as the recommended output extension. -- Label the success message as `Mermaid`. - -### Browser editor - -- Add `Mermaid` to the export menu and the export format selector. -- Route it through the same `/api/render` flow as `tikz` and `dot`. -- Download it as text with the `.mmd` extension. - -### Backend route - -- Extend `/api/render` to accept `format == "mermaid"`. -- Return: - - `format: "mermaid"` - - `text: ` - - `content_type: "text/plain;charset=utf-8"` - -## Error Handling - -- If the spec is valid, Mermaid export should succeed. -- Unsupported visual details must degrade gracefully instead of raising. -- Rendering should only fail for the same categories already used elsewhere, - such as invalid payloads or invalid specs. - -## Testing - -Add focused tests for: - -- `render_spec_mermaid` basic output for a normal network -- label toggle behavior for tensor, index, and bond labels -- open indices and hyperedges -- group and note emission -- escaping of special characters -- API route `/api/render` with `format="mermaid"` -- CLI `render --format mermaid` -- editor menu and export selector wiring -- frontend download flow and output filename extension - -## Documentation Updates - -- Add Mermaid export to `README.md`. -- Add Mermaid export to the editor help text if that text enumerates supported - export formats. -- Add a short `CHANGELOG.md` entry when implementation lands. - -## Rollout Notes - -The first version should stay intentionally simple: - -- structure first -- labels second -- visual fidelity out of scope - -This keeps the renderer predictable, testable, and useful for documentation -without turning Mermaid into a second layout engine. diff --git a/src/tensor_network_editor/_public_codegen.py b/src/tensor_network_editor/_public_codegen.py index 21f52ff..66f566f 100644 --- a/src/tensor_network_editor/_public_codegen.py +++ b/src/tensor_network_editor/_public_codegen.py @@ -27,6 +27,7 @@ def generate_code( *, engine: EngineIdentifier, collection_format: TensorCollectionFormat = TensorCollectionFormat.LIST, + include_roundtrip_metadata: bool = True, output_path: StrPath | None = None, print_code: bool = False, external_data_base_path: StrPath | None = None, @@ -43,7 +44,10 @@ def generate_code( external_data_base_path=external_data_base_path, ) result = _generate_code( - codegen_spec, engine, collection_format=collection_format + codegen_spec, + engine, + collection_format=collection_format, + include_roundtrip_metadata=include_roundtrip_metadata, ) if print_code: log_branch(LOGGER, "Printing generated code to stdout") diff --git a/src/tensor_network_editor/app/_protocol.py b/src/tensor_network_editor/app/_protocol.py index 9e281a9..d623124 100644 --- a/src/tensor_network_editor/app/_protocol.py +++ b/src/tensor_network_editor/app/_protocol.py @@ -43,6 +43,7 @@ class CodegenRequest: serialized_spec: JsonDict engine: EngineIdentifier collection_format: TensorCollectionFormat + include_roundtrip_metadata: bool @dataclass(slots=True, frozen=True) @@ -269,6 +270,9 @@ def parse_codegen_request( serialized_spec=require_serialized_spec(payload), engine=resolve_engine(payload, default_engine), collection_format=resolve_collection_format(payload, default_collection_format), + include_roundtrip_metadata=require_boolean( + payload, "include_roundtrip_metadata", default=True + ), ) diff --git a/src/tensor_network_editor/app/_session_requests.py b/src/tensor_network_editor/app/_session_requests.py index f1a6ac4..01491da 100644 --- a/src/tensor_network_editor/app/_session_requests.py +++ b/src/tensor_network_editor/app/_session_requests.py @@ -30,6 +30,7 @@ def generate_session_request( serialized_spec: Mapping[str, object], engine: EngineIdentifier, collection_format: TensorCollectionFormat | None = None, + include_roundtrip_metadata: bool = True, ) -> CodegenResult: """Generate preview code for one editor request.""" with log_operation( @@ -47,6 +48,7 @@ def generate_session_request( spec, engine, collection_format=_resolve_collection_format(session, collection_format), + include_roundtrip_metadata=include_roundtrip_metadata, validate=False, ) @@ -56,6 +58,7 @@ def complete_session_request( serialized_spec: Mapping[str, object], engine: EngineIdentifier, collection_format: TensorCollectionFormat | None = None, + include_roundtrip_metadata: bool = True, ) -> EditorResult: """Finalize a session request and optionally print or save generated code.""" with log_operation( @@ -75,6 +78,7 @@ def complete_session_request( spec, engine, collection_format=_resolve_collection_format(session, collection_format), + include_roundtrip_metadata=include_roundtrip_metadata, validate=False, ) if session.print_code: diff --git a/src/tensor_network_editor/app/routes.py b/src/tensor_network_editor/app/routes.py index 69cfbd5..01254e6 100644 --- a/src/tensor_network_editor/app/routes.py +++ b/src/tensor_network_editor/app/routes.py @@ -833,6 +833,7 @@ def _handle_session_codegen_request( request.serialized_spec, request.engine, request.collection_format, + request.include_roundtrip_metadata, ) return ok_response(_serialize_generate_result(generate_result)) if operation == "complete": @@ -840,6 +841,7 @@ def _handle_session_codegen_request( request.serialized_spec, request.engine, request.collection_format, + request.include_roundtrip_metadata, ) return ok_response(_serialize_complete_result(complete_result)) raise ValueError(f"Unsupported code generation operation '{operation}'.") diff --git a/src/tensor_network_editor/app/session.py b/src/tensor_network_editor/app/session.py index 37f002f..2ffad11 100644 --- a/src/tensor_network_editor/app/session.py +++ b/src/tensor_network_editor/app/session.py @@ -290,6 +290,7 @@ def generate( serialized_spec: Mapping[str, object], engine: EngineIdentifier, collection_format: TensorCollectionFormat | None = None, + include_roundtrip_metadata: bool = True, ) -> CodegenResult: """Generate preview code without finalizing the session.""" with log_operation( @@ -306,6 +307,7 @@ def generate( serialized_spec, engine, collection_format, + include_roundtrip_metadata, ) def complete( @@ -313,6 +315,7 @@ def complete( serialized_spec: Mapping[str, object], engine: EngineIdentifier, collection_format: TensorCollectionFormat | None = None, + include_roundtrip_metadata: bool = True, ) -> EditorResult: """Finalize the session and store the resulting editor output.""" with log_operation( @@ -335,6 +338,7 @@ def complete( serialized_spec, engine, collection_format, + include_roundtrip_metadata, ) with self._lock: if self._finished_event.is_set() and self._result is not None: diff --git a/src/tensor_network_editor/app/static/app.css b/src/tensor_network_editor/app/static/app.css index 6d294f5..ca7a1ee 100644 --- a/src/tensor_network_editor/app/static/app.css +++ b/src/tensor_network_editor/app/static/app.css @@ -3369,6 +3369,31 @@ textarea[disabled] { align-items: center; } +.code-metadata-toggle { + display: inline-flex; + align-items: center; + gap: 0.45rem; + min-height: var(--canvas-control-height); + padding: 0 0.8rem; + border: 1px solid var(--border-subtle); + border-radius: 999px; + background: var(--surface-subtle); + color: var(--muted); + font-size: 0.78rem; + font-weight: 600; + line-height: 1.1; + cursor: pointer; + user-select: none; +} + +.code-metadata-toggle input { + margin: 0; +} + +.code-metadata-toggle[hidden] { + display: none; +} + .code-format-picker { position: relative; display: inline-flex; diff --git a/src/tensor_network_editor/app/static/index.html b/src/tensor_network_editor/app/static/index.html index 610db92..2e7b50a 100644 --- a/src/tensor_network_editor/app/static/index.html +++ b/src/tensor_network_editor/app/static/index.html @@ -788,6 +788,30 @@

> Auto layout + + +

/> +

${ - linearPeriodicMode - ? '

Subnetwork export and template promotion are not available in For mode yet.

' + subnetworkActionsMessage + ? `

${escapeHtml(subnetworkActionsMessage)}

` : "" }

Drag the group box on the canvas to move all tensors together.

diff --git a/src/tensor_network_editor/app/static/js/properties/overviewPropertiesMarkup.js b/src/tensor_network_editor/app/static/js/properties/overviewPropertiesMarkup.js index 3b71bfb..5a0db9a 100644 --- a/src/tensor_network_editor/app/static/js/properties/overviewPropertiesMarkup.js +++ b/src/tensor_network_editor/app/static/js/properties/overviewPropertiesMarkup.js @@ -78,7 +78,8 @@ export function buildMultiSelectionPropertiesMarkup({ multiIndexDimensionCandidate, showAddIndexAction, hyperedgeCreationCandidate, - linearPeriodicMode, + disableSubnetworkActions, + subnetworkActionsMessage, batchColor, totalElementCount, formatTotalElementCount, @@ -147,7 +148,7 @@ export function buildMultiSelectionPropertiesMarkup({ id="extract-selection-button" type="button" class="button-accent-positive" - ${linearPeriodicMode ? "disabled" : ""} + ${disableSubnetworkActions ? "disabled" : ""} ${buildTooltipAttributes( "Extract", "Extract the selected tensors as a reusable subnetwork.", @@ -160,7 +161,7 @@ export function buildMultiSelectionPropertiesMarkup({ id="save-selection-subnetwork-library-button" type="button" class="button-accent-template" - ${linearPeriodicMode ? "disabled" : ""} + ${disableSubnetworkActions ? "disabled" : ""} ${buildTooltipAttributes( "To Library", "Save the selected tensors to the subnetwork library." @@ -172,7 +173,7 @@ export function buildMultiSelectionPropertiesMarkup({ id="promote-selection-template-button" type="button" class="button-accent-template" - ${linearPeriodicMode ? "disabled" : ""} + ${disableSubnetworkActions ? "disabled" : ""} ${buildTooltipAttributes( "To Template", "Promote the selected tensors to a reusable template." @@ -194,8 +195,8 @@ export function buildMultiSelectionPropertiesMarkup({ ${ - linearPeriodicMode - ? '

Subnetwork export and template promotion are not available in For mode yet.

' + subnetworkActionsMessage + ? `

${escapeHtml(subnetworkActionsMessage)}

` : "" } ` diff --git a/src/tensor_network_editor/app/static/js/properties/properties.js b/src/tensor_network_editor/app/static/js/properties/properties.js index 3c22f52..a839e08 100644 --- a/src/tensor_network_editor/app/static/js/properties/properties.js +++ b/src/tensor_network_editor/app/static/js/properties/properties.js @@ -91,10 +91,22 @@ function createPropertyActions(ctx) { saveSelectionToSubnetworkLibrary: () => ctx.saveSelectionToSubnetworkLibrary(), promoteSelectedSubnetworkToTemplate: () => ctx.promoteSelectedSubnetworkToTemplate(), createGroupFromSelection: () => ctx.createGroupFromSelection(), + findTensorById: (tensorId) => ctx.findTensorById(tensorId), findGroupById: (groupId) => ctx.findGroupById(groupId), findEdgeById: (edgeId) => ctx.findEdgeById(edgeId), findHyperedgeById: (hyperedgeId) => ctx.findHyperedgeById(hyperedgeId), findNoteById: (noteId) => ctx.findNoteById(noteId), + isForBoundaryTensor: resolveContextAction(ctx, "isForBoundaryTensor", () => false), + isLinearPeriodicBoundaryTensor: resolveContextAction( + ctx, + "isLinearPeriodicBoundaryTensor", + () => false + ), + isTreePeriodicBoundaryTensor: resolveContextAction( + ctx, + "isTreePeriodicBoundaryTensor", + () => false + ), getMetadataColor: (metadata, fallbackColor) => ctx.getMetadataColor(metadata, fallbackColor), toggleGroupCollapse: (groupId) => ctx.toggleGroupCollapse(groupId), diff --git a/src/tensor_network_editor/app/static/js/properties/propertiesRenderersEntities.js b/src/tensor_network_editor/app/static/js/properties/propertiesRenderersEntities.js index 73999f6..12460f2 100644 --- a/src/tensor_network_editor/app/static/js/properties/propertiesRenderersEntities.js +++ b/src/tensor_network_editor/app/static/js/properties/propertiesRenderersEntities.js @@ -26,6 +26,28 @@ export function createEntityPropertiesRenderers({ actions, }); + function isStructuralBoundaryTensor(tensor) { + return Boolean( + tensor && + ( + (typeof actions.isForBoundaryTensor === "function" && + actions.isForBoundaryTensor(tensor)) || + (typeof actions.isLinearPeriodicBoundaryTensor === "function" && + actions.isLinearPeriodicBoundaryTensor(tensor)) || + (typeof actions.isTreePeriodicBoundaryTensor === "function" && + actions.isTreePeriodicBoundaryTensor(tensor)) || + tensor.linear_periodic_role === "previous" || + tensor.linear_periodic_role === "next" || + tensor.grid_periodic_role === "up" || + tensor.grid_periodic_role === "right" || + tensor.grid_periodic_role === "down" || + tensor.grid_periodic_role === "left" || + tensor.tree_periodic_role === "parent" || + tensor.tree_periodic_role === "child" + ) + ); + } + function renderGroupProperties(groupId) { const group = actions.findGroupById(groupId); if (!group) { @@ -36,10 +58,20 @@ export function createEntityPropertiesRenderers({ group.metadata, GRAPH_THEME.groupDefault ); - const linearPeriodicMode = - (typeof actions.isForMode === "function" && actions.isForMode()) || - (typeof actions.isLinearPeriodicMode === "function" && - actions.isLinearPeriodicMode()); + const exportableTensorCount = (Array.isArray(group.tensor_ids) + ? group.tensor_ids + : [] + ) + .map((tensorId) => + typeof actions.findTensorById === "function" + ? actions.findTensorById(tensorId) + : null + ) + .filter((tensor) => tensor && !isStructuralBoundaryTensor(tensor)).length; + const disableSubnetworkActions = exportableTensorCount === 0; + const subnetworkActionsMessage = disableSubnetworkActions + ? "Virtual For-mode boundary tensors cannot be exported or promoted as templates." + : ""; const totalElementCount = getTotalElementCountForTensorIds( Array.isArray(group.tensor_ids) ? group.tensor_ids : [] ); @@ -47,7 +79,8 @@ export function createEntityPropertiesRenderers({ propertiesPanel.innerHTML = buildGroupPropertiesMarkup({ group, groupColor, - linearPeriodicMode, + disableSubnetworkActions, + subnetworkActionsMessage, totalElementCount, formatTotalElementCount, renderTrashIcon, diff --git a/src/tensor_network_editor/app/static/js/properties/propertiesRenderersOverview.js b/src/tensor_network_editor/app/static/js/properties/propertiesRenderersOverview.js index 7c48502..a950aee 100644 --- a/src/tensor_network_editor/app/static/js/properties/propertiesRenderersOverview.js +++ b/src/tensor_network_editor/app/static/js/properties/propertiesRenderersOverview.js @@ -95,10 +95,11 @@ export function createOverviewPropertiesRenderers({ .map((entry) => entry.tensor || null) .filter((tensor) => tensor && !isStructuralBoundaryTensor(tensor)) .map((tensor) => tensor.id); - const linearPeriodicMode = - (typeof actions.isForMode === "function" && actions.isForMode()) || - (typeof actions.isLinearPeriodicMode === "function" && - actions.isLinearPeriodicMode()); + const exportableTensorCount = editableTensorIds.length; + const disableSubnetworkActions = hasMultipleTensors && exportableTensorCount === 0; + const subnetworkActionsMessage = disableSubnetworkActions + ? "Virtual For-mode boundary tensors cannot be exported or promoted as templates." + : ""; const batchColor = actions.getBatchColorValue(selectedEntries); const totalElementCount = getSelectionTotalElementCount(selectedEntries); const hyperedgeCreationCandidate = @@ -131,7 +132,8 @@ export function createOverviewPropertiesRenderers({ multiIndexDimensionCandidate, showAddIndexAction: editableTensorIds.length > 0, hyperedgeCreationCandidate, - linearPeriodicMode, + disableSubnetworkActions, + subnetworkActionsMessage, batchColor, totalElementCount, formatTotalElementCount, diff --git a/src/tensor_network_editor/app/static/js/services/editorSessionService.js b/src/tensor_network_editor/app/static/js/services/editorSessionService.js index fb2f618..b9747f7 100644 --- a/src/tensor_network_editor/app/static/js/services/editorSessionService.js +++ b/src/tensor_network_editor/app/static/js/services/editorSessionService.js @@ -1,9 +1,18 @@ -function buildCodegenPayload({ engine, collectionFormat, spec }) { - return { +function buildCodegenPayload({ + engine, + collectionFormat, + includeRoundtripMetadata, + spec, +}) { + const payload = { engine, collection_format: collectionFormat, spec, }; + if (typeof includeRoundtripMetadata === "boolean") { + payload.include_roundtrip_metadata = includeRoundtripMetadata; + } + return payload; } function summarizeSerializedSpec(serializedSpec) { @@ -34,6 +43,7 @@ function summarizeCodegenRequest(request) { return { engine: request.engine, collection_format: request.collectionFormat, + include_roundtrip_metadata: request.includeRoundtripMetadata, ...summarizeSerializedSpec(request.spec), }; } diff --git a/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js b/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js index bd89288..1b790f7 100644 --- a/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js +++ b/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js @@ -61,6 +61,7 @@ export function createSessionEditorFlows({ return sessionService.generateCode({ engine: selectors.getSelectedEngine(), collectionFormat: selectors.getSelectedCollectionFormat(), + includeRoundtripMetadata: state.includeRoundtripMetadata === true, spec: actions.serializeCurrentSpec({ persistViewSnapshots: false }), }); } @@ -232,6 +233,7 @@ export function createSessionEditorFlows({ const payload = await sessionService.completeSession({ engine: selectors.getSelectedEngine(), collectionFormat: selectors.getSelectedCollectionFormat(), + includeRoundtripMetadata: state.includeRoundtripMetadata === true, spec: actions.serializeCurrentSpec({ persistViewSnapshots: true }), }); if (!payload.ok) { diff --git a/src/tensor_network_editor/app/static/js/session/sessionTemplateFlows.js b/src/tensor_network_editor/app/static/js/session/sessionTemplateFlows.js index e78cbb3..09b7582 100644 --- a/src/tensor_network_editor/app/static/js/session/sessionTemplateFlows.js +++ b/src/tensor_network_editor/app/static/js/session/sessionTemplateFlows.js @@ -68,6 +68,45 @@ export function createSessionTemplateFlows({ ); } + function isStructuralBoundaryTensor(tensor) { + return Boolean( + tensor && + ( + (typeof actions.isForBoundaryTensor === "function" && + actions.isForBoundaryTensor(tensor)) || + (typeof actions.isLinearPeriodicBoundaryTensor === "function" && + actions.isLinearPeriodicBoundaryTensor(tensor)) || + (typeof actions.isTreePeriodicBoundaryTensor === "function" && + actions.isTreePeriodicBoundaryTensor(tensor)) || + tensor.linear_periodic_role === "previous" || + tensor.linear_periodic_role === "next" || + tensor.grid_periodic_role === "up" || + tensor.grid_periodic_role === "right" || + tensor.grid_periodic_role === "down" || + tensor.grid_periodic_role === "left" || + tensor.tree_periodic_role === "parent" || + tensor.tree_periodic_role === "child" + ) + ); + } + + function getTemplateEligibleTensorIds(tensorIds) { + const availableTensors = Array.isArray(state.spec?.tensors) ? state.spec.tensors : []; + const tensorById = new Map( + availableTensors.map((tensor) => [tensor.id, tensor]) + ); + return Array.from( + new Set( + (Array.isArray(tensorIds) ? tensorIds : []).filter( + (tensorId) => + typeof tensorId === "string" && + tensorById.has(tensorId) && + !isStructuralBoundaryTensor(tensorById.get(tensorId)) + ) + ) + ); + } + function isBenchmarkSchemeView() { return Boolean( state.benchmarkSession && @@ -167,23 +206,24 @@ export function createSessionTemplateFlows({ } async function extractTemplateSpecByTensorIds(tensorIds, emptySelectionMessage) { - if (isForModeActive()) { + if (!Array.isArray(tensorIds) || !tensorIds.length) { + actions.setStatus(emptySelectionMessage); + return null; + } + const eligibleTensorIds = getTemplateEligibleTensorIds(tensorIds); + if (!eligibleTensorIds.length) { actions.setStatus( - "Templates are only available in normal graph mode.", + "Virtual For-mode boundary tensors cannot be saved as templates.", "error" ); return null; } - if (!Array.isArray(tensorIds) || !tensorIds.length) { - actions.setStatus(emptySelectionMessage); - return null; - } try { const payload = await subnetworkService.extractSubnetwork({ serializedSpec: actions.serializeCurrentSpec({ persistViewSnapshots: false, }), - tensorIds, + tensorIds: eligibleTensorIds, }); if (!payload.ok) { actions.setStatus( @@ -200,23 +240,24 @@ export function createSessionTemplateFlows({ } async function exportSubnetworkByTensorIds(tensorIds, label = "subnetwork") { - if (isForModeActive()) { + if (!Array.isArray(tensorIds) || !tensorIds.length) { + actions.setStatus("Select one or more tensors to extract a subnetwork."); + return; + } + const eligibleTensorIds = getTemplateEligibleTensorIds(tensorIds); + if (!eligibleTensorIds.length) { actions.setStatus( - "Subnetwork export is only available in normal graph mode.", + "Virtual For-mode boundary tensors cannot be extracted as a subnetwork.", "error" ); return; } - if (!Array.isArray(tensorIds) || !tensorIds.length) { - actions.setStatus("Select one or more tensors to extract a subnetwork."); - return; - } try { const payload = await subnetworkService.extractSubnetwork({ serializedSpec: actions.serializeCurrentSpec({ persistViewSnapshots: false, }), - tensorIds, + tensorIds: eligibleTensorIds, }); if (!payload.ok) { actions.setStatus( @@ -608,16 +649,6 @@ export function createSessionTemplateFlows({ tensor_id_count: Array.isArray(tensorIds) ? tensorIds.length : 0, } ); - if (isForModeActive()) { - saveSubnetworkOperation?.finish({ - outcome: "blocked", - }); - actions.setStatus( - "The subnetwork library is only available in normal graph mode.", - "error" - ); - return; - } if (isBenchmarkSchemeView()) { saveSubnetworkOperation?.finish({ outcome: "blocked", @@ -635,6 +666,17 @@ export function createSessionTemplateFlows({ actions.setStatus("Select one or more tensors first."); return; } + const eligibleTensorIds = getTemplateEligibleTensorIds(tensorIds); + if (!eligibleTensorIds.length) { + saveSubnetworkOperation?.finish({ + outcome: "blocked", + }); + actions.setStatus( + "Virtual For-mode boundary tensors cannot be saved to the subnetwork library.", + "error" + ); + return; + } const naming = promptForLibrarySubnetworkName( baseDisplayName, "Subnetwork library save cancelled." @@ -658,7 +700,7 @@ export function createSessionTemplateFlows({ serializedSpec: actions.serializeCurrentSpec({ persistViewSnapshots: false, }), - tensorIds, + tensorIds: eligibleTensorIds, subnetworkName: naming.subnetworkName, tags, overwrite, diff --git a/src/tensor_network_editor/app/static/js/shell/editorShellBindings.js b/src/tensor_network_editor/app/static/js/shell/editorShellBindings.js index 184b59c..3cbe829 100644 --- a/src/tensor_network_editor/app/static/js/shell/editorShellBindings.js +++ b/src/tensor_network_editor/app/static/js/shell/editorShellBindings.js @@ -65,6 +65,8 @@ export function createEditorShellBindings({ expandGeneratedCodeButton, generatedCodeModalBackdrop, generatedCodeModalCloseButton, + codegenRoundtripMetadataField, + codegenRoundtripMetadataCheckbox, templateSelectField, templateSelect, engineSelectField, @@ -81,11 +83,9 @@ export function createEditorShellBindings({ editSessionTemplateMenuItem, openSubnetworkLibraryMenuItem, reflowImportedButton, - reflowAlignLeftButton, - reflowAlignRightButton, - reflowAlignTopButton, - reflowAlignMiddleButton, - reflowAlignBottomButton, + reflowAlignHorizontalButton, + reflowAlignVerticalButton, + reflowRotateSelectionButton, reflowIndicesLeftButton, reflowIndicesRightButton, reflowIndicesTopButton, @@ -368,11 +368,9 @@ export function createEditorShellBindings({ "linear-periodic-next-cell-button", "template-settings-button", "reflow-imported-button", - "reflow-align-left-button", - "reflow-align-right-button", - "reflow-align-top-button", - "reflow-align-middle-button", - "reflow-align-bottom-button", + "reflow-align-horizontal-button", + "reflow-align-vertical-button", + "reflow-rotate-selection-button", "reflow-arrange-chain-button", "reflow-arrange-tree-button", "reflow-arrange-grid-button", @@ -385,6 +383,7 @@ export function createEditorShellBindings({ "reflow-indices-bottom-button", "copy-code-button", "expand-generated-code-button", + "codegen-roundtrip-metadata-field", "template-manager-save-button", "template-manager-discard-button", ].forEach((controlId) => { @@ -398,6 +397,12 @@ export function createEditorShellBindings({ "", "Choose how generated code returns the tensors: list keeps an ordered sequence, matrix arranges row and column structures when the template supports them, and dict returns named entries." ); + shortcutTooltip.applyShortcutHint( + "codegen-roundtrip-metadata-field", + "Metadata", + "", + "Append commented editor metadata at the end of generated For-mode code so it can be reconstructed more faithfully when you import it back later." + ); shortcutTooltip.applyShortcutHint( "reflow-imported-button", "Reflow", @@ -644,6 +649,9 @@ export function createEditorShellBindings({ bindListener(templateSelect, "change", (event) => { setSelectChevronExpanded(templateSelectField, false); actions.handleTemplateSelectionChange(event); + if (typeof templateSelect?.blur === "function") { + templateSelect.blur(); + } }); bindListener( templateParameterPanel, @@ -696,11 +704,9 @@ export function createEditorShellBindings({ bindListener(reflowImportedButton, "click", () => { actions.toggleReflowLayoutPopover(); }); - bindReflowAction(reflowAlignLeftButton, "left"); - bindReflowAction(reflowAlignRightButton, "right"); - bindReflowAction(reflowAlignTopButton, "top"); - bindReflowAction(reflowAlignMiddleButton, "middle"); - bindReflowAction(reflowAlignBottomButton, "bottom"); + bindReflowAction(reflowAlignHorizontalButton, "align-horizontal"); + bindReflowAction(reflowAlignVerticalButton, "align-vertical"); + bindReflowAction(reflowRotateSelectionButton, "rotate-90"); bindReflowIndicesAction(reflowIndicesLeftButton, "left"); bindReflowIndicesAction(reflowIndicesRightButton, "right"); bindReflowIndicesAction(reflowIndicesTopButton, "top"); @@ -795,6 +801,10 @@ export function createEditorShellBindings({ actions.scheduleDraftAutosave(); } }); + bindListener(codegenRoundtripMetadataCheckbox, "change", (event) => { + store.setIncludeRoundtripMetadata(event.target.checked); + actions.updateToolbarState(); + }); bindListener(documentRef, "mousedown", (event) => { if (!isWithinTransientToolbarUi(event.target)) { actions.closeTransientToolbarUi(); @@ -803,7 +813,7 @@ export function createEditorShellBindings({ bindListener(loadInput, "change", actions.loadDesignFromFile); bindListener(subnetworkLoadInput, "change", actions.loadSubnetworkFromFile); bindListener(templateLoadInput, "change", actions.loadSessionTemplatesFromFile); - bindListener(windowRef, "keydown", actions.handleKeydown); + bindListener(windowRef, "keydown", actions.handleKeydown, true); bindListener(windowRef, "resize", actions.handleWindowResize); bindListener(windowRef, "mousemove", actions.handleGlobalMouseMove); bindListener(windowRef, "mouseup", actions.handleGlobalMouseUp); diff --git a/src/tensor_network_editor/app/static/js/state/editorStore.js b/src/tensor_network_editor/app/static/js/state/editorStore.js index ab8c01c..1d9bd12 100644 --- a/src/tensor_network_editor/app/static/js/state/editorStore.js +++ b/src/tensor_network_editor/app/static/js/state/editorStore.js @@ -44,6 +44,11 @@ export function createEditorStore(state) { return state.selectedCollectionFormat; } + function setIncludeRoundtripMetadata(includeRoundtripMetadata) { + state.includeRoundtripMetadata = Boolean(includeRoundtripMetadata); + return state.includeRoundtripMetadata; + } + function setSelectedTheme(themeName) { state.selectedTheme = typeof themeName === "string" ? themeName : state.selectedTheme; return state.selectedTheme; @@ -153,6 +158,7 @@ export function createEditorStore(state) { setAnnotationDefinitions, setSelectedEngine, setSelectedCollectionFormat, + setIncludeRoundtripMetadata, setSelectedTheme, setGeneratedCode, setEditorFinished, diff --git a/src/tensor_network_editor/app/static/js/state/historySnapshots.js b/src/tensor_network_editor/app/static/js/state/historySnapshots.js index fce6765..2d477c8 100644 --- a/src/tensor_network_editor/app/static/js/state/historySnapshots.js +++ b/src/tensor_network_editor/app/static/js/state/historySnapshots.js @@ -105,6 +105,9 @@ export function createHistorySnapshotSupport({ state.undoStack.shift(); } state.redoStack = []; + if (typeof bumpSpecRevision === "function") { + bumpSpecRevision(); + } state.lastMutationClearedCode = clearGeneratedCodePreview(); updateToolbarState(); return true; diff --git a/src/tensor_network_editor/app/static/js/state/state.js b/src/tensor_network_editor/app/static/js/state/state.js index bad90ab..6cef799 100644 --- a/src/tensor_network_editor/app/static/js/state/state.js +++ b/src/tensor_network_editor/app/static/js/state/state.js @@ -25,6 +25,7 @@ export function createInitialState() { availableCollectionFormats: [], selectedEngine: null, selectedCollectionFormat: null, + includeRoundtripMetadata: false, selectedTheme: DEFAULT_EDITOR_THEME_NAME, academicExportLabels: { tensor: true, diff --git a/src/tensor_network_editor/app/static/js/utils/utilitiesGeometry.js b/src/tensor_network_editor/app/static/js/utils/utilitiesGeometry.js index 54c6d3a..23fae54 100644 --- a/src/tensor_network_editor/app/static/js/utils/utilitiesGeometry.js +++ b/src/tensor_network_editor/app/static/js/utils/utilitiesGeometry.js @@ -41,6 +41,22 @@ export function createUtilityGeometryBindings({ ); } + function getVisibleLayeringTensors() { + if (typeof ctx.getVisibleTensors === "function") { + const visibleTensors = ctx.getVisibleTensors(); + if (Array.isArray(visibleTensors)) { + return visibleTensors; + } + } + if (typeof runtime.getVisibleTensors === "function") { + const visibleTensors = runtime.getVisibleTensors(); + if (Array.isArray(visibleTensors)) { + return visibleTensors; + } + } + return Array.isArray(state.spec?.tensors) ? state.spec.tensors : []; + } + function isIndexElevated(indexId, tensorId = null) { return Boolean( isIndexConnected(indexId) || @@ -443,7 +459,7 @@ export function createUtilityGeometryBindings({ } function reconcileTensorOrder() { - const tensorIds = state.spec ? state.spec.tensors.map((tensor) => tensor.id) : []; + const tensorIds = getVisibleLayeringTensors().map((tensor) => tensor.id); const activeTensorIds = new Set(tensorIds); const nextOrder = []; const seenTensorIds = new Set(); @@ -488,13 +504,21 @@ export function createUtilityGeometryBindings({ return; } reconcileTensorOrder(); + const visibleTensors = getVisibleLayeringTensors(); + const visibleTensorById = Object.fromEntries( + visibleTensors.map((tensor) => [tensor.id, tensor]) + ); state.tensorOrder.forEach((tensorId) => { const tensorRank = state.tensorRankById[tensorId] ?? 0; const tensorElement = state.cy.getElementById(tensorId); if (tensorElement && tensorElement.length) { tensorElement.data("zIndex", TENSOR_BASE_Z_INDEX + tensorRank); } - const tensor = runtime.findTensorById(tensorId); + const tensor = + visibleTensorById[tensorId] || + (typeof runtime.findTensorById === "function" + ? runtime.findTensorById(tensorId) + : null); if (!tensor) { return; } diff --git a/src/tensor_network_editor/app/static/js/utils/utilitiesLayout.js b/src/tensor_network_editor/app/static/js/utils/utilitiesLayout.js index 47740f2..f34c880 100644 --- a/src/tensor_network_editor/app/static/js/utils/utilitiesLayout.js +++ b/src/tensor_network_editor/app/static/js/utils/utilitiesLayout.js @@ -121,6 +121,37 @@ export function createUtilityLayoutBindings({ ctx, state, constants }) { }, "Snapped tensors to the grid."); } + function rotateSelectedTensorsClockwise() { + const tensors = getSelectedLayoutTensors(); + if (tensors.length < 2) { + ctx.setStatus("Select at least two tensors to rotate."); + return false; + } + + const bounds = computeTensorBounds(tensors); + const centerX = (bounds.left + bounds.right) / 2; + const centerY = (bounds.top + bounds.bottom) / 2; + + return applyTensorLayoutChange(() => { + tensors.forEach((tensor) => { + const deltaX = tensor.position.x - centerX; + const deltaY = tensor.position.y - centerY; + tensor.position.x = centerX - deltaY; + tensor.position.y = centerY + deltaX; + tensor.indices.forEach((index) => { + const rotatedOffset = { + x: -index.offset.y, + y: index.offset.x, + }; + index.offset = + typeof ctx.clampIndexOffset === "function" + ? ctx.clampIndexOffset(rotatedOffset, tensor) + : rotatedOffset; + }); + }); + }, "Rotated the selected tensors and ports 90° clockwise."); + } + function arrangeSelectedTensors(mode) { const tensorIds = getSelectedLayoutTensorIds(); if (tensorIds.length < 2) { @@ -172,8 +203,18 @@ export function createUtilityLayoutBindings({ ctx, state, constants }) { function applyReflowLayoutAction(layoutAction) { const action = typeof layoutAction === "string" ? layoutAction : ""; + if (action === "align-horizontal") { + return alignSelectedTensors("middle"); + } + if (action === "align-vertical") { + return alignSelectedTensors("center"); + } + if (action === "rotate-90") { + return rotateSelectedTensorsClockwise(); + } if ( action === "left" || + action === "center" || action === "right" || action === "top" || action === "middle" || @@ -269,6 +310,7 @@ export function createUtilityLayoutBindings({ ctx, state, constants }) { arrangeSelectedTensors, distributeSelectedTensors, reflowLastImportedTensors, + rotateSelectedTensorsClockwise, snapSelectedTensorsToGrid, }; } diff --git a/src/tensor_network_editor/app/static/js/utils/utilitiesUiToolbarActionState.js b/src/tensor_network_editor/app/static/js/utils/utilitiesUiToolbarActionState.js index d04b763..5f86e51 100644 --- a/src/tensor_network_editor/app/static/js/utils/utilitiesUiToolbarActionState.js +++ b/src/tensor_network_editor/app/static/js/utils/utilitiesUiToolbarActionState.js @@ -37,11 +37,9 @@ export function createUiToolbarActionStateSupport({ editSessionTemplateMenuItem, openSubnetworkLibraryMenuItem, reflowImportedButton, - reflowAlignLeftButton, - reflowAlignRightButton, - reflowAlignTopButton, - reflowAlignMiddleButton, - reflowAlignBottomButton, + reflowAlignHorizontalButton, + reflowAlignVerticalButton, + reflowRotateSelectionButton, reflowIndicesLeftButton, reflowIndicesRightButton, reflowIndicesTopButton, @@ -56,6 +54,8 @@ export function createUiToolbarActionStateSupport({ reflowSnapGridButton, createGroupButton, generateButton, + codegenRoundtripMetadataField, + codegenRoundtripMetadataCheckbox, benchmarkCompareModal, benchmarkCompareTableBody, } = dom; @@ -77,10 +77,14 @@ export function createUiToolbarActionStateSupport({ templateToolbarGroup, selectedTemplateValue, selectedTensorIds, + selectedExportableTensorIds, graphTensorCount, autoLayoutTensorCount, hasSelectedIndices, } = derivedState; + const exportableTensorIds = Array.isArray(selectedExportableTensorIds) + ? selectedExportableTensorIds + : []; if (undoButton) { undoButton.disabled = state.undoStack.length === 0; @@ -91,6 +95,14 @@ export function createUiToolbarActionStateSupport({ if (generateButton) { generateButton.disabled = !state.spec || !state.selectedEngine; } + if (codegenRoundtripMetadataCheckbox) { + codegenRoundtripMetadataCheckbox.checked = + state.includeRoundtripMetadata === true; + } + setElementHidden( + codegenRoundtripMetadataField, + benchmarkSchemeView || !forMode + ); if (exportMenuItem) { exportMenuItem.disabled = false; } @@ -141,18 +153,19 @@ export function createUiToolbarActionStateSupport({ setMenuItemChecked(themeShinyMenuItem, state.selectedTheme === "shiny"); } if (saveSessionTemplateMenuItem) { - saveSessionTemplateMenuItem.disabled = forMode || selectedTensorIds.length === 0; + saveSessionTemplateMenuItem.disabled = + exportableTensorIds.length === 0; } if (saveSubnetworkLibraryMenuItem) { saveSubnetworkLibraryMenuItem.disabled = - forMode || benchmarkSchemeView || selectedTensorIds.length === 0; + benchmarkSchemeView || exportableTensorIds.length === 0; } if (loadSessionTemplateMenuItem) { loadSessionTemplateMenuItem.disabled = false; } if (exportSessionTemplateMenuItem) { exportSessionTemplateMenuItem.disabled = - forMode || selectedTensorIds.length === 0; + exportableTensorIds.length === 0; } if (editSessionTemplateMenuItem) { editSessionTemplateMenuItem.disabled = state.availableTemplates.length === 0; @@ -190,11 +203,9 @@ export function createUiToolbarActionStateSupport({ } setButtonGroupDisabled( [ - reflowAlignLeftButton, - reflowAlignRightButton, - reflowAlignTopButton, - reflowAlignMiddleButton, - reflowAlignBottomButton, + reflowAlignHorizontalButton, + reflowAlignVerticalButton, + reflowRotateSelectionButton, reflowArrangeChainButton, reflowArrangeTreeButton, reflowArrangeGridButton, @@ -235,7 +246,7 @@ export function createUiToolbarActionStateSupport({ ); if (templateSettingsButton) { templateSettingsButton.disabled = - benchmarkSchemeView || !selectedTemplateValue || forMode; + benchmarkSchemeView || !selectedTemplateValue; setElementHidden( templateSettingsButton.parentElement || templateSettingsButton, benchmarkSchemeView @@ -246,13 +257,11 @@ export function createUiToolbarActionStateSupport({ ? "Template parameters are unavailable while viewing a benchmark scheme." : !selectedTemplateValue ? "Choose a template first." - : forMode - ? "Template parameters are not editable in For mode." - : "Edit template parameters." + : "Edit template parameters for the next template insertion." ); } if ( - (benchmarkSchemeView || !selectedTemplateValue || forMode) && + (benchmarkSchemeView || !selectedTemplateValue) && state.isTemplateSettingsOpen ) { state.isTemplateSettingsOpen = false; diff --git a/src/tensor_network_editor/app/static/js/utils/utilitiesUiToolbarDerivedState.js b/src/tensor_network_editor/app/static/js/utils/utilitiesUiToolbarDerivedState.js index fb42784..5ad2c3e 100644 --- a/src/tensor_network_editor/app/static/js/utils/utilitiesUiToolbarDerivedState.js +++ b/src/tensor_network_editor/app/static/js/utils/utilitiesUiToolbarDerivedState.js @@ -36,7 +36,29 @@ export function createUiToolbarDerivedStateSupport({ || reflowImportedButton.parentElement : insertTemplateButton ? insertTemplateButton.parentElement || insertTemplateButton - : null; + : null; + } + + function isStructuralBoundaryTensor(tensor) { + return Boolean( + tensor && + ( + (typeof runtime.isForBoundaryTensor === "function" && + runtime.isForBoundaryTensor(tensor)) || + (typeof runtime.isLinearPeriodicBoundaryTensor === "function" && + runtime.isLinearPeriodicBoundaryTensor(tensor)) || + (typeof runtime.isTreePeriodicBoundaryTensor === "function" && + runtime.isTreePeriodicBoundaryTensor(tensor)) || + tensor.linear_periodic_role === "previous" || + tensor.linear_periodic_role === "next" || + tensor.grid_periodic_role === "up" || + tensor.grid_periodic_role === "right" || + tensor.grid_periodic_role === "down" || + tensor.grid_periodic_role === "left" || + tensor.tree_periodic_role === "parent" || + tensor.tree_periodic_role === "child" + ) + ); } function getToolbarDerivedState() { @@ -84,6 +106,9 @@ export function createUiToolbarDerivedStateSupport({ const selectedTensors = selectedTensorIds .map((tensorId) => findTensorById(tensorId)) .filter(Boolean); + const selectedExportableTensorIds = selectedTensors + .filter((tensor) => !isStructuralBoundaryTensor(tensor)) + .map((tensor) => tensor.id); const hasSelectedIndices = selectedTensors.some( (tensor) => Array.isArray(tensor.indices) && tensor.indices.length > 0 ); @@ -114,6 +139,7 @@ export function createUiToolbarDerivedStateSupport({ activeBenchmarkScheme, selectedTemplateValue, selectedTensorIds, + selectedExportableTensorIds, graphTensorCount, autoLayoutTensorCount, hasSelectedIndices, diff --git a/src/tensor_network_editor/codegen/modes/_linear_periodic/carry.py b/src/tensor_network_editor/codegen/modes/_linear_periodic/carry.py index 4e1886c..4a88669 100644 --- a/src/tensor_network_editor/codegen/modes/_linear_periodic/carry.py +++ b/src/tensor_network_editor/codegen/modes/_linear_periodic/carry.py @@ -579,12 +579,19 @@ def _resolve_previous_payload_operand_state( for index, operand_id in enumerate(previous_payload_state.interface_operand_ids) if operand_id == selected_operand_id } + renamed_labels = tuple( + incoming_label_by_payload_label.get( + label, + _previous_payload_local_label( + operand_id=selected_operand_id, + axis_index=axis_index, + ), + ) + for axis_index, label in enumerate(selected_state.labels) + ) return ( _CarryOperandState( - labels=tuple( - incoming_label_by_payload_label.get(label, label) - for label in selected_state.labels - ), + labels=renamed_labels, axis_names=selected_state.axis_names, dimensions=selected_state.dimensions, ), @@ -592,6 +599,11 @@ def _resolve_previous_payload_operand_state( ) +def _previous_payload_local_label(*, operand_id: str, axis_index: int) -> str: + """Return a collision-free simulation label for one carried local axis.""" + return f"__previous_payload_{operand_id}_{axis_index}" + + def _find_remaining_operand_id_for_label( *, label: str, diff --git a/src/tensor_network_editor/codegen/modes/_linear_periodic/graph_carry.py b/src/tensor_network_editor/codegen/modes/_linear_periodic/graph_carry.py index 647507d..9b7747b 100644 --- a/src/tensor_network_editor/codegen/modes/_linear_periodic/graph_carry.py +++ b/src/tensor_network_editor/codegen/modes/_linear_periodic/graph_carry.py @@ -8,6 +8,7 @@ from ...backends.tensornetwork import TensorNetworkCodeGenerator from ...shared._linear_periodic_expressions import ( _axis_name_for_engine, + _axis_names_for_engine, _build_remaining_label_expression_map, _carry_cell_key_prefix_expression, _operand_expression, @@ -43,6 +44,19 @@ def _render_carry_cell_helper( collection_format=collection_format, collection_name=collection_name, ) + tracked_export_lines, tracked_export_expressions = ( + _render_tensorkrowch_export_tracking_lines( + simulation=simulation, + collection_format=collection_format, + collection_name=collection_name, + ) + if engine is EngineName.TENSORKROWCH + else ([], {}) + ) + if tracked_export_lines: + if network_connection_lines: + network_connection_lines.append("") + network_connection_lines.extend(tracked_export_lines) previous_interface_lines = _render_carry_boundary_setup( simulation=simulation, engine=engine, @@ -55,6 +69,7 @@ def _render_carry_cell_helper( engine=engine, collection_format=collection_format, collection_name=collection_name, + tracked_export_expressions=tracked_export_expressions, ) return render_linear_periodic_helper( helper_name=helper_name, @@ -125,6 +140,49 @@ def _render_carry_boundary_setup( return lines +def _render_tensorkrowch_export_tracking_lines( + *, + simulation: _CarryPlanSimulation, + collection_format: TensorCollectionFormat, + collection_name: str, +) -> tuple[list[str], dict[str, str]]: + """Track only current-cell local-open edges before later contractions rename them.""" + if not simulation.local_open_labels: + return [], {} + + label_expression_by_label: dict[str, str] = {} + for tensor in simulation.prepared.tensors: + tensor_expression = tensor_collection_reference_by_id( + simulation.prepared, + tensor.spec.id, + collection_format, + collection_name, + ) + runtime_axis_names = _axis_names_for_engine( + EngineName.TENSORKROWCH, + tuple(index.spec.name for index in tensor.indices), + ) + for index, runtime_axis_name in zip( + tensor.indices, + runtime_axis_names, + strict=True, + ): + label_expression_by_label[index.label] = ( + f"{tensor_expression}[{runtime_axis_name!r}]" + ) + + lines = ["# Tracked current-cell edges"] + tracked_expressions: dict[str, str] = {} + for tracked_index, label in enumerate(simulation.local_open_labels): + label_expression = label_expression_by_label.get(label) + if label_expression is None: + continue + variable_name = f"tracked_edge_{tracked_index}" + lines.append(f"{variable_name} = {label_expression}") + tracked_expressions[label] = variable_name + return lines, tracked_expressions + + def _render_carry_plan_lines( *, simulation: _CarryPlanSimulation, @@ -132,6 +190,7 @@ def _render_carry_plan_lines( engine: EngineName, collection_format: TensorCollectionFormat, collection_name: str, + tracked_export_expressions: dict[str, str], ) -> tuple[list[str], list[str]]: """Render all carry-mode contractions and helper epilogue lines.""" if engine is EngineName.TENSORKROWCH and any( @@ -197,14 +256,18 @@ def _render_carry_plan_lines( "results_list.append(tk.contract_between(" f"{left_expression}, {right_expression}))" ) + if step_index < len(simulation.real_steps) - 1: + contraction_lines.append( + "results_list[-1].reattach_edges(override=True)" + ) contraction_lines.append("") final_result_index = ( len(simulation.real_steps) - 1 if simulation.real_steps else None ) output_lines: list[str] = [] - if engine is EngineName.TENSORKROWCH and simulation.outgoing_interface_operand_ids: - for operand_id in dict.fromkeys(simulation.outgoing_interface_operand_ids): + if engine is EngineName.TENSORKROWCH: + for operand_id in dict.fromkeys(simulation.remaining_operand_ids): if operand_id not in simulation.result_index_by_step_id: continue operand_expression = _operand_expression( @@ -213,7 +276,7 @@ def _render_carry_plan_lines( step_result_indexes=simulation.result_index_by_step_id, latest_result_index=final_result_index, ) - output_lines.append(f"{operand_expression}.reattach_edges()") + output_lines.append(f"{operand_expression}.reattach_edges(override=True)") label_expression_by_label = _build_remaining_label_expression_map( remaining_operand_ids=simulation.remaining_operand_ids, @@ -222,6 +285,8 @@ def _render_carry_plan_lines( step_result_indexes=simulation.result_index_by_step_id, latest_result_index=final_result_index, ) + if tracked_export_expressions: + label_expression_by_label.update(tracked_export_expressions) local_open_expressions = [ label_expression_by_label[label] for label in simulation.local_open_labels diff --git a/src/tensor_network_editor/codegen/modes/grid_periodic.py b/src/tensor_network_editor/codegen/modes/grid_periodic.py index fa711dd..9c780f6 100644 --- a/src/tensor_network_editor/codegen/modes/grid_periodic.py +++ b/src/tensor_network_editor/codegen/modes/grid_periodic.py @@ -14,6 +14,7 @@ def generate_grid_periodic_code( engine: EngineName, *, collection_format: TensorCollectionFormat, + include_roundtrip_metadata: bool = True, validate: bool = True, ) -> CodegenResult: """Generate helper-based Python code for the bidimensional periodic mode.""" @@ -34,7 +35,11 @@ def generate_grid_periodic_code( engine=engine, collection_format=collection_format, ) - return with_roundtrip_spec_marker(result, spec=spec) + return ( + with_roundtrip_spec_marker(result, spec=spec) + if include_roundtrip_metadata + else result + ) if engine not in {EngineName.TENSORNETWORK, EngineName.TENSORKROWCH}: raise CodeGenerationError( f"The {engine.value} backend does not support grid periodic code generation." @@ -44,4 +49,8 @@ def generate_grid_periodic_code( engine=engine, collection_format=collection_format, ) - return with_roundtrip_spec_marker(result, spec=spec) + return ( + with_roundtrip_spec_marker(result, spec=spec) + if include_roundtrip_metadata + else result + ) diff --git a/src/tensor_network_editor/codegen/modes/linear_periodic.py b/src/tensor_network_editor/codegen/modes/linear_periodic.py index 853a68b..686b8ad 100644 --- a/src/tensor_network_editor/codegen/modes/linear_periodic.py +++ b/src/tensor_network_editor/codegen/modes/linear_periodic.py @@ -15,6 +15,7 @@ def generate_linear_periodic_code( engine: EngineName, *, collection_format: TensorCollectionFormat, + include_roundtrip_metadata: bool = True, validate: bool = True, ) -> CodegenResult: """Generate helper-based Python code for the linear periodic-chain mode.""" @@ -37,7 +38,11 @@ def generate_linear_periodic_code( collection_format=collection_format, uses_carry_mode=uses_carry_mode, ) - return with_roundtrip_spec_marker(result, spec=spec) + return ( + with_roundtrip_spec_marker(result, spec=spec) + if include_roundtrip_metadata + else result + ) if engine not in {EngineName.TENSORNETWORK, EngineName.TENSORKROWCH}: raise CodeGenerationError( f"The {engine.value} backend does not support linear periodic code generation." @@ -48,4 +53,8 @@ def generate_linear_periodic_code( collection_format=collection_format, uses_carry_mode=uses_carry_mode, ) - return with_roundtrip_spec_marker(result, spec=spec) + return ( + with_roundtrip_spec_marker(result, spec=spec) + if include_roundtrip_metadata + else result + ) diff --git a/src/tensor_network_editor/codegen/modes/tree_periodic.py b/src/tensor_network_editor/codegen/modes/tree_periodic.py index 6e2a6e4..96da512 100644 --- a/src/tensor_network_editor/codegen/modes/tree_periodic.py +++ b/src/tensor_network_editor/codegen/modes/tree_periodic.py @@ -14,6 +14,7 @@ def generate_tree_periodic_code( engine: EngineName, *, collection_format: TensorCollectionFormat, + include_roundtrip_metadata: bool = True, validate: bool = True, ) -> CodegenResult: """Generate helper-based Python code for the tree periodic mode.""" @@ -34,7 +35,11 @@ def generate_tree_periodic_code( engine=engine, collection_format=collection_format, ) - return with_roundtrip_spec_marker(result, spec=spec) + return ( + with_roundtrip_spec_marker(result, spec=spec) + if include_roundtrip_metadata + else result + ) if engine not in {EngineName.TENSORNETWORK, EngineName.TENSORKROWCH}: raise CodeGenerationError( f"The {engine.value} backend does not support tree periodic code generation." @@ -44,4 +49,8 @@ def generate_tree_periodic_code( engine=engine, collection_format=collection_format, ) - return with_roundtrip_spec_marker(result, spec=spec) + return ( + with_roundtrip_spec_marker(result, spec=spec) + if include_roundtrip_metadata + else result + ) diff --git a/src/tensor_network_editor/codegen/registry.py b/src/tensor_network_editor/codegen/registry.py index a2d5903..454b96d 100644 --- a/src/tensor_network_editor/codegen/registry.py +++ b/src/tensor_network_editor/codegen/registry.py @@ -104,6 +104,7 @@ def generate_code( engine: EngineIdentifier, *, collection_format: TensorCollectionFormat = TensorCollectionFormat.LIST, + include_roundtrip_metadata: bool = True, validate: bool = True, ) -> CodegenResult: """Generate Python code through the registered backend generator.""" @@ -115,6 +116,7 @@ def generate_code( spec, normalized_engine, collection_format=collection_format, + include_roundtrip_metadata=include_roundtrip_metadata, validate=validate, ) if spec.tree_periodic_tree is not None and isinstance( @@ -124,6 +126,7 @@ def generate_code( spec, normalized_engine, collection_format=collection_format, + include_roundtrip_metadata=include_roundtrip_metadata, validate=validate, ) if spec.linear_periodic_chain is not None and isinstance( @@ -133,6 +136,7 @@ def generate_code( spec, normalized_engine, collection_format=collection_format, + include_roundtrip_metadata=include_roundtrip_metadata, validate=validate, ) generator = get_generator(normalized_engine) diff --git a/src/tensor_network_editor/codegen/shared/roundtrip.py b/src/tensor_network_editor/codegen/shared/roundtrip.py index 45c130e..6f43fb7 100644 --- a/src/tensor_network_editor/codegen/shared/roundtrip.py +++ b/src/tensor_network_editor/codegen/shared/roundtrip.py @@ -31,11 +31,12 @@ def with_roundtrip_spec_marker( *, spec: NetworkSpec, ) -> CodegenResult: - """Return ``result`` with a leading serialized-spec comment marker.""" + """Return ``result`` with a trailing serialized-spec comment marker.""" marker = "\n".join(render_roundtrip_spec_marker_lines(spec)) + code = result.code.rstrip() return CodegenResult( engine=result.engine, - code=f"{marker}\n{result.code}", + code=f"{code}\n\n{marker}\n", warnings=list(result.warnings), artifacts=dict(result.artifacts), ) diff --git a/src/tensor_network_editor/internal/validation/_validation_linear_periodic.py b/src/tensor_network_editor/internal/validation/_validation_linear_periodic.py index d897279..d392b77 100644 --- a/src/tensor_network_editor/internal/validation/_validation_linear_periodic.py +++ b/src/tensor_network_editor/internal/validation/_validation_linear_periodic.py @@ -4,9 +4,12 @@ from dataclasses import dataclass +from ...codegen.modes._linear_periodic.carry import _build_carry_simulation_map from ...codegen.shared.common import prepare_network +from ...errors import CodeGenerationError from ...models import ( ContractionStepSpec, + EngineName, LinearPeriodicCellName, LinearPeriodicCellSpec, LinearPeriodicChainSpec, @@ -381,6 +384,74 @@ def _validate_linear_periodic_carry_mode( for cell_name, cell in iter_linear_periodic_cells(chain): _validate_linear_periodic_carry_cell(cell_name, cell, issues=issues) + _validate_linear_periodic_carry_codegen_compatibility(chain, issues=issues) + + +def _validate_linear_periodic_carry_codegen_compatibility( + chain: LinearPeriodicChainSpec, + *, + issues: list[ValidationIssue], +) -> None: + """Reject carry plans that the linear-periodic code generator cannot realize.""" + try: + _build_carry_simulation_map( + chain=chain, + engine=EngineName.TENSORNETWORK, + ) + except CodeGenerationError as exc: + append_issue( + issues, + code="linear-periodic-carry-codegen", + message=str(exc), + path=_linear_periodic_codegen_issue_path(chain, str(exc)), + ) + + +def _linear_periodic_codegen_issue_path( + chain: LinearPeriodicChainSpec, + message: str, +) -> str: + """Map one carry-codegen validation failure onto the closest spec path.""" + step_marker = "Carry step '" + cell_marker = "' in cell '" + if message.startswith(step_marker) and cell_marker in message: + step_start = len(step_marker) + step_end = message.find("'", step_start) + cell_start = message.find(cell_marker, step_end) + len(cell_marker) + cell_end = message.find("'", cell_start) + if step_end > step_start and cell_end > cell_start: + step_id = message[step_start:step_end] + cell_name_text = message[cell_start:cell_end] + cell_name = _linear_periodic_cell_name_from_text(cell_name_text) + if cell_name is not None: + return ( + f"{_linear_periodic_cell_prefix(cell_name)}" + f".contraction_plan.steps.{step_id}" + ) + + cell_prefix_marker = "Cell '" + if message.startswith(cell_prefix_marker): + cell_start = len(cell_prefix_marker) + cell_end = message.find("'", cell_start) + if cell_end > cell_start: + cell_name = _linear_periodic_cell_name_from_text( + message[cell_start:cell_end] + ) + if cell_name is not None: + return f"{_linear_periodic_cell_prefix(cell_name)}.contraction_plan" + + active_cell = chain.active_cell + return f"{_linear_periodic_cell_prefix(active_cell)}.contraction_plan" + + +def _linear_periodic_cell_name_from_text( + cell_name_text: str, +) -> LinearPeriodicCellName | None: + """Return the enum value for one serialized linear-periodic cell name.""" + try: + return LinearPeriodicCellName(cell_name_text) + except ValueError: + return None def _validate_linear_periodic_carry_cell( diff --git a/src/tensor_network_editor/rendering.py b/src/tensor_network_editor/rendering.py index ec529d3..7e4adb3 100644 --- a/src/tensor_network_editor/rendering.py +++ b/src/tensor_network_editor/rendering.py @@ -33,6 +33,33 @@ _LIGHT_TEXT_FILL = "#f5f9ff" _DARK_TEXT_FILL = "#091018" _READABLE_TEXT_LUMINANCE_THRESHOLD = 0.62 +_FREE_INDEX_CLEARANCE_MARGIN = 18.0 +_FREE_INDEX_DIAGONAL_DIRECTIONS = ( + CanvasPosition(x=1.0, y=1.0), + CanvasPosition(x=1.0, y=-1.0), + CanvasPosition(x=-1.0, y=1.0), + CanvasPosition(x=-1.0, y=-1.0), +) +_FREE_INDEX_CARDINAL_DIRECTIONS = ( + CanvasPosition(x=1.0, y=0.0), + CanvasPosition(x=-1.0, y=0.0), + CanvasPosition(x=0.0, y=1.0), + CanvasPosition(x=0.0, y=-1.0), +) +_FREE_INDEX_DIRECTION_HINTS = { + "left": CanvasPosition(x=-1.0, y=0.0), + "right": CanvasPosition(x=1.0, y=0.0), + "up": CanvasPosition(x=0.0, y=-1.0), + "down": CanvasPosition(x=0.0, y=1.0), + "north": CanvasPosition(x=0.0, y=-1.0), + "south": CanvasPosition(x=0.0, y=1.0), + "east": CanvasPosition(x=1.0, y=0.0), + "west": CanvasPosition(x=-1.0, y=0.0), + "xp": CanvasPosition(x=1.0, y=0.0), + "xm": CanvasPosition(x=-1.0, y=0.0), + "yp": CanvasPosition(x=0.0, y=-1.0), + "ym": CanvasPosition(x=0.0, y=1.0), +} LOGGER = logging.getLogger(__name__) @@ -288,6 +315,9 @@ def __init__(self, spec: NetworkSpec, options: SvgRenderOptions) -> None: for index_position, index in enumerate(tensor.indices) } self._connected_index_ids = _connected_index_ids(spec) + self._free_index_direction_by_id: dict[str, CanvasPosition] = {} + self._adjacency_by_tensor_id = self._build_tensor_adjacency() + self._component_tensor_ids_by_tensor_id = self._build_component_tensor_ids() def render(self) -> str: """Return the complete SVG document.""" @@ -657,19 +687,26 @@ def open_index_endpoint( tensor: TensorSpec, index: IndexSpec, *, - port_length: float = 34.0, + port_length: float | None = None, ) -> CanvasPosition: direction = self._index_direction(tensor, index) source = self.connection_point(tensor, index) + resolved_port_length = ( + self.open_index_port_length(tensor) if port_length is None else port_length + ) return CanvasPosition( - x=source.x + direction.x * port_length, - y=source.y + direction.y * port_length, + x=source.x + direction.x * resolved_port_length, + y=source.y + direction.y * resolved_port_length, ) def index_label_point(self, tensor: TensorSpec, index: IndexSpec) -> CanvasPosition: direction = self._index_direction(tensor, index) source = self.connection_point(tensor, index) - label_distance = 18.0 if self.is_index_connected(index.id) else 24.0 + label_distance = ( + 18.0 + if self.is_index_connected(index.id) + else self.open_index_port_length(tensor) + 14.0 + ) return CanvasPosition( x=source.x + direction.x * label_distance, y=source.y + direction.y * label_distance + 4.0, @@ -684,6 +721,16 @@ def _svg_text_anchor(self, tensor: TensorSpec, index: IndexSpec) -> str: return "middle" def _index_direction(self, tensor: TensorSpec, index: IndexSpec) -> CanvasPosition: + if not self.is_index_connected(index.id): + return self._free_index_direction(tensor, index) + return self._stored_index_direction(tensor, index) + + def open_index_port_length(self, tensor: TensorSpec) -> float: + return 2.0 * self.tensor_radius(tensor) + + def _stored_index_direction( + self, tensor: TensorSpec, index: IndexSpec + ) -> CanvasPosition: magnitude = hypot(index.offset.x, index.offset.y) if magnitude > 1e-6: return CanvasPosition( @@ -694,6 +741,568 @@ def _index_direction(self, tensor: TensorSpec, index: IndexSpec) -> CanvasPositi angle = -pi / 2 + (2 * pi * index_order / index_count) return CanvasPosition(x=cos(angle), y=sin(angle)) + def _free_index_direction( + self, tensor: TensorSpec, index: IndexSpec + ) -> CanvasPosition: + cached_direction = self._free_index_direction_by_id.get(index.id) + if cached_direction is not None: + return cached_direction + self._assign_free_index_directions(tensor) + cached_direction = self._free_index_direction_by_id.get(index.id) + if cached_direction is not None: + return cached_direction + return self._stored_index_direction(tensor, index) + + def _assign_free_index_directions(self, tensor: TensorSpec) -> None: + free_indices = [ + index for index in tensor.indices if not self.is_index_connected(index.id) + ] + if not free_indices: + return + component_tensors = self._component_tensors_for_tensor(tensor.id) + occupied_directions = self._occupied_directions(tensor) + assigned_directions: list[CanvasPosition] = [] + for index in free_indices: + candidate_directions = self._candidate_directions_for_free_index( + tensor, + index, + component_tensors, + ) + direction = self._choose_best_free_index_direction( + tensor, + component_tensors, + candidate_directions, + [*occupied_directions, *assigned_directions], + ) + self._free_index_direction_by_id[index.id] = direction + assigned_directions.append(direction) + + def _occupied_directions(self, tensor: TensorSpec) -> list[CanvasPosition]: + directions: list[CanvasPosition] = [] + for index in tensor.indices: + if not self.is_index_connected(index.id): + continue + direction = self._connected_index_direction(tensor, index) + if direction is not None: + directions.append(direction) + return directions + + def _connected_index_direction( + self, tensor: TensorSpec, index: IndexSpec + ) -> CanvasPosition | None: + for edge in self._spec.edges: + if ( + edge.left.index_id == index.id + and edge.right.tensor_id in self._tensor_by_id + ): + return _normalize_direction( + CanvasPosition( + x=self._tensor_by_id[edge.right.tensor_id].position.x + - tensor.position.x, + y=self._tensor_by_id[edge.right.tensor_id].position.y + - tensor.position.y, + ) + ) + if ( + edge.right.index_id == index.id + and edge.left.tensor_id in self._tensor_by_id + ): + return _normalize_direction( + CanvasPosition( + x=self._tensor_by_id[edge.left.tensor_id].position.x + - tensor.position.x, + y=self._tensor_by_id[edge.left.tensor_id].position.y + - tensor.position.y, + ) + ) + for hyperedge in self._spec.hyperedges: + endpoint_ids = {endpoint.index_id for endpoint in hyperedge.endpoints} + if index.id not in endpoint_ids: + continue + hub = self._hyperedge_hub_position(hyperedge) + return _normalize_direction( + CanvasPosition(x=hub.x - tensor.position.x, y=hub.y - tensor.position.y) + ) + return None + + def _candidate_directions_for_free_index( + self, + tensor: TensorSpec, + index: IndexSpec, + component_tensors: Sequence[TensorSpec], + ) -> list[CanvasPosition]: + directional_hint = self._directional_index_hint(index) + candidates: list[CanvasPosition] = [] + component_kind = self._classify_component_shape(component_tensors) + if component_kind == "linear": + candidates.extend(self._linear_component_candidates(component_tensors)) + elif component_kind == "circular": + candidates.extend( + self._circular_component_candidates(tensor, component_tensors) + ) + elif component_kind == "grid2d": + candidates.extend( + self._grid_component_candidates(tensor, component_tensors) + ) + if directional_hint is not None: + if component_kind == "generic": + candidates.insert(0, directional_hint) + else: + candidates.append(directional_hint) + candidates.extend(self._generic_component_candidates(tensor, component_tensors)) + candidates.extend(_FREE_INDEX_CARDINAL_DIRECTIONS) + candidates.extend(_FREE_INDEX_DIAGONAL_DIRECTIONS) + return _deduplicate_directions(candidates) + + def _choose_best_free_index_direction( + self, + tensor: TensorSpec, + component_tensors: Sequence[TensorSpec], + candidate_directions: Sequence[CanvasPosition], + occupied_directions: Sequence[CanvasPosition], + ) -> CanvasPosition: + best_direction: CanvasPosition | None = None + best_score = float("-inf") + for candidate_index, candidate_direction in enumerate(candidate_directions): + score = 1000.0 - candidate_index * 10.0 + score -= self._occupied_direction_penalty( + candidate_direction, occupied_directions + ) + score -= self._tensor_clearance_penalty( + tensor, component_tensors, candidate_direction + ) + if score > best_score: + best_score = score + best_direction = candidate_direction + return best_direction or CanvasPosition(x=0.0, y=-1.0) + + def _occupied_direction_penalty( + self, + candidate_direction: CanvasPosition, + occupied_directions: Sequence[CanvasPosition], + ) -> float: + if not occupied_directions: + return 0.0 + closest_alignment = max( + _dot_product(candidate_direction, occupied_direction) + for occupied_direction in occupied_directions + ) + if closest_alignment <= 0.2: + return 0.0 + return (closest_alignment - 0.2) * 500.0 + + def _tensor_clearance_penalty( + self, + tensor: TensorSpec, + component_tensors: Sequence[TensorSpec], + candidate_direction: CanvasPosition, + ) -> float: + radius = 0.0 if self.is_port_tensor(tensor) else self.tensor_radius(tensor) + segment_start = CanvasPosition( + x=tensor.position.x + candidate_direction.x * radius, + y=tensor.position.y + candidate_direction.y * radius, + ) + segment_end = CanvasPosition( + x=segment_start.x + + candidate_direction.x * self.open_index_port_length(tensor), + y=segment_start.y + + candidate_direction.y * self.open_index_port_length(tensor), + ) + penalty = 0.0 + for other_tensor in component_tensors: + if other_tensor.id == tensor.id: + continue + distance = _point_to_segment_distance( + other_tensor.position, segment_start, segment_end + ) + clearance = distance - self.tensor_radius(other_tensor) + if clearance < _FREE_INDEX_CLEARANCE_MARGIN: + penalty += (_FREE_INDEX_CLEARANCE_MARGIN - clearance) * 8.0 + return penalty + + def _directional_index_hint(self, index: IndexSpec) -> CanvasPosition | None: + normalized_name = index.name.strip().lower() + return _FREE_INDEX_DIRECTION_HINTS.get(normalized_name) + + def _linear_component_candidates( + self, component_tensors: Sequence[TensorSpec] + ) -> list[CanvasPosition]: + axis_direction = self._component_primary_axis(component_tensors) + perpendicular_direction = CanvasPosition( + x=-axis_direction.y, y=axis_direction.x + ) + return [ + perpendicular_direction, + CanvasPosition(x=-perpendicular_direction.x, y=-perpendicular_direction.y), + axis_direction, + CanvasPosition(x=-axis_direction.x, y=-axis_direction.y), + ] + + def _circular_component_candidates( + self, tensor: TensorSpec, component_tensors: Sequence[TensorSpec] + ) -> list[CanvasPosition]: + center = _average_position( + [component_tensor.position for component_tensor in component_tensors] + ) + radial_direction = _normalize_direction( + CanvasPosition( + x=tensor.position.x - center.x, + y=tensor.position.y - center.y, + ) + ) + perpendicular_direction = CanvasPosition( + x=-radial_direction.y, y=radial_direction.x + ) + return [ + radial_direction, + perpendicular_direction, + CanvasPosition(x=-perpendicular_direction.x, y=-perpendicular_direction.y), + CanvasPosition(x=-radial_direction.x, y=-radial_direction.y), + ] + + def _grid_component_candidates( + self, tensor: TensorSpec, component_tensors: Sequence[TensorSpec] + ) -> list[CanvasPosition]: + basis_directions = self._grid_component_basis(component_tensors) + if basis_directions is None: + return [] + primary_axis, secondary_axis = basis_directions + primary_projections = sorted( + { + round( + _dot_product(component_tensor.position, primary_axis), + 6, + ) + for component_tensor in component_tensors + } + ) + secondary_projections = sorted( + { + round( + _dot_product(component_tensor.position, secondary_axis), + 6, + ) + for component_tensor in component_tensors + } + ) + primary_projection = round(_dot_product(tensor.position, primary_axis), 6) + secondary_projection = round(_dot_product(tensor.position, secondary_axis), 6) + candidates: list[CanvasPosition] = [] + if ( + secondary_projections + and abs(secondary_projection - secondary_projections[0]) < 1e-6 + ): + candidates.append(CanvasPosition(x=-secondary_axis.x, y=-secondary_axis.y)) + if ( + secondary_projections + and abs(secondary_projection - secondary_projections[-1]) < 1e-6 + ): + candidates.append(secondary_axis) + if ( + primary_projections + and abs(primary_projection - primary_projections[0]) < 1e-6 + ): + candidates.append(CanvasPosition(x=-primary_axis.x, y=-primary_axis.y)) + if ( + primary_projections + and abs(primary_projection - primary_projections[-1]) < 1e-6 + ): + candidates.append(primary_axis) + if len(candidates) >= 2: + diagonal = _normalize_direction( + CanvasPosition( + x=sum(candidate.x for candidate in candidates), + y=sum(candidate.y for candidate in candidates), + ) + ) + candidates.insert(0, diagonal) + return candidates + + def _generic_component_candidates( + self, tensor: TensorSpec, component_tensors: Sequence[TensorSpec] + ) -> list[CanvasPosition]: + neighbor_tensor_ids = self._adjacency_by_tensor_id.get(tensor.id, set()) + neighbor_tensors = [ + self._tensor_by_id[neighbor_tensor_id] + for neighbor_tensor_id in neighbor_tensor_ids + if neighbor_tensor_id in self._tensor_by_id + ] + candidates: list[CanvasPosition] = [] + if neighbor_tensors: + neighbor_center = _average_position( + [neighbor_tensor.position for neighbor_tensor in neighbor_tensors] + ) + candidates.append( + _normalize_direction( + CanvasPosition( + x=tensor.position.x - neighbor_center.x, + y=tensor.position.y - neighbor_center.y, + ) + ) + ) + component_center = _average_position( + [component_tensor.position for component_tensor in component_tensors] + ) + if ( + abs(component_center.x - tensor.position.x) > 1e-6 + or abs(component_center.y - tensor.position.y) > 1e-6 + ): + candidates.append( + _normalize_direction( + CanvasPosition( + x=tensor.position.x - component_center.x, + y=tensor.position.y - component_center.y, + ) + ) + ) + candidates.extend( + sorted( + [*_FREE_INDEX_CARDINAL_DIRECTIONS, *_FREE_INDEX_DIAGONAL_DIRECTIONS], + key=lambda direction: self._generic_direction_sort_key( + tensor, + component_tensors, + direction, + ), + ) + ) + return candidates + + def _generic_direction_sort_key( + self, + tensor: TensorSpec, + component_tensors: Sequence[TensorSpec], + direction: CanvasPosition, + ) -> tuple[float, float]: + penalty = self._tensor_clearance_penalty(tensor, component_tensors, direction) + return (penalty, -abs(direction.x) - abs(direction.y)) + + def _classify_component_shape(self, component_tensors: Sequence[TensorSpec]) -> str: + tensor_ids = {tensor.id for tensor in component_tensors} + if len(tensor_ids) >= 2 and self._is_linear_component(tensor_ids): + return "linear" + if len(tensor_ids) >= 3 and self._is_circular_component(tensor_ids): + return "circular" + if len(tensor_ids) >= 4 and self._is_grid_component(component_tensors): + return "grid2d" + return "generic" + + def _is_linear_component(self, tensor_ids: set[str]) -> bool: + component_tensors = [ + self._tensor_by_id[tensor_id] + for tensor_id in tensor_ids + if tensor_id in self._tensor_by_id + ] + degree_by_tensor_id = { + tensor_id: len(self._adjacency_by_tensor_id.get(tensor_id, set())) + for tensor_id in tensor_ids + } + edge_count = sum(degree_by_tensor_id.values()) // 2 + degree_one_count = sum( + 1 for degree in degree_by_tensor_id.values() if degree == 1 + ) + axis_direction = self._component_primary_axis(component_tensors) + axis_origin = _average_position( + [component_tensor.position for component_tensor in component_tensors] + ) + projections = [ + _dot_product( + CanvasPosition( + x=component_tensor.position.x - axis_origin.x, + y=component_tensor.position.y - axis_origin.y, + ), + axis_direction, + ) + for component_tensor in component_tensors + ] + dominant_span = max(projections, default=0.0) - min(projections, default=0.0) + minor_span = max( + ( + abs( + _cross_product( + axis_direction, + CanvasPosition( + x=component_tensor.position.x - axis_origin.x, + y=component_tensor.position.y - axis_origin.y, + ), + ) + ) + for component_tensor in component_tensors + ), + default=0.0, + ) + return ( + edge_count == len(tensor_ids) - 1 + and max(degree_by_tensor_id.values(), default=0) <= 2 + and degree_one_count == 2 + and (dominant_span <= 1e-6 or minor_span <= dominant_span * 0.18) + ) + + def _is_circular_component(self, tensor_ids: set[str]) -> bool: + degree_by_tensor_id = { + tensor_id: len(self._adjacency_by_tensor_id.get(tensor_id, set())) + for tensor_id in tensor_ids + } + edge_count = sum(degree_by_tensor_id.values()) // 2 + return edge_count == len(tensor_ids) and all( + degree == 2 for degree in degree_by_tensor_id.values() + ) + + def _is_grid_component(self, component_tensors: Sequence[TensorSpec]) -> bool: + basis_directions = self._grid_component_basis(component_tensors) + if basis_directions is None: + return False + primary_axis, secondary_axis = basis_directions + primary_projections = { + round(_dot_product(tensor.position, primary_axis), 6) + for tensor in component_tensors + } + secondary_projections = { + round(_dot_product(tensor.position, secondary_axis), 6) + for tensor in component_tensors + } + if len(primary_projections) < 2 or len(secondary_projections) < 2: + return False + tensor_ids = {tensor.id for tensor in component_tensors} + for edge in self._spec.edges: + if ( + edge.left.tensor_id not in tensor_ids + or edge.right.tensor_id not in tensor_ids + ): + continue + left_tensor = self._tensor_by_id[edge.left.tensor_id] + right_tensor = self._tensor_by_id[edge.right.tensor_id] + edge_direction = _normalize_direction( + CanvasPosition( + x=right_tensor.position.x - left_tensor.position.x, + y=right_tensor.position.y - left_tensor.position.y, + ) + ) + if ( + abs(_dot_product(edge_direction, primary_axis)) < 0.94 + and abs(_dot_product(edge_direction, secondary_axis)) < 0.94 + ): + return False + return True + + def _component_primary_axis( + self, component_tensors: Sequence[TensorSpec] + ) -> CanvasPosition: + if len(component_tensors) < 2: + return CanvasPosition(x=0.0, y=-1.0) + widest_pair: tuple[TensorSpec, TensorSpec] | None = None + widest_distance = float("-inf") + for left_tensor in component_tensors: + for right_tensor in component_tensors: + if left_tensor.id == right_tensor.id: + continue + distance = hypot( + right_tensor.position.x - left_tensor.position.x, + right_tensor.position.y - left_tensor.position.y, + ) + if distance > widest_distance: + widest_distance = distance + widest_pair = (left_tensor, right_tensor) + if widest_pair is None: + return CanvasPosition(x=0.0, y=-1.0) + return _normalize_direction( + CanvasPosition( + x=widest_pair[1].position.x - widest_pair[0].position.x, + y=widest_pair[1].position.y - widest_pair[0].position.y, + ) + ) + + def _grid_component_basis( + self, component_tensors: Sequence[TensorSpec] + ) -> tuple[CanvasPosition, CanvasPosition] | None: + tensor_ids = {tensor.id for tensor in component_tensors} + basis_directions: list[CanvasPosition] = [] + for edge in self._spec.edges: + if ( + edge.left.tensor_id not in tensor_ids + or edge.right.tensor_id not in tensor_ids + ): + continue + left_tensor = self._tensor_by_id[edge.left.tensor_id] + right_tensor = self._tensor_by_id[edge.right.tensor_id] + direction = _canonicalize_undirected_direction( + _normalize_direction( + CanvasPosition( + x=right_tensor.position.x - left_tensor.position.x, + y=right_tensor.position.y - left_tensor.position.y, + ) + ) + ) + if all( + abs(_dot_product(direction, basis)) < 0.94 for basis in basis_directions + ): + basis_directions.append(direction) + if len(basis_directions) != 2: + return None + if abs(_dot_product(basis_directions[0], basis_directions[1])) > 0.3: + return None + return basis_directions[0], basis_directions[1] + + def _build_tensor_adjacency(self) -> dict[str, set[str]]: + adjacency_by_tensor_id: dict[str, set[str]] = { + tensor.id: set() for tensor in self._spec.tensors + } + for edge in self._spec.edges: + adjacency_by_tensor_id.setdefault(edge.left.tensor_id, set()).add( + edge.right.tensor_id + ) + adjacency_by_tensor_id.setdefault(edge.right.tensor_id, set()).add( + edge.left.tensor_id + ) + for hyperedge in self._spec.hyperedges: + endpoint_tensor_ids = [ + endpoint.tensor_id + for endpoint in hyperedge.endpoints + if endpoint.tensor_id in self._tensor_by_id + ] + for left_tensor_id in endpoint_tensor_ids: + for right_tensor_id in endpoint_tensor_ids: + if left_tensor_id == right_tensor_id: + continue + adjacency_by_tensor_id.setdefault(left_tensor_id, set()).add( + right_tensor_id + ) + return adjacency_by_tensor_id + + def _build_component_tensor_ids(self) -> dict[str, list[str]]: + component_tensor_ids_by_tensor_id: dict[str, list[str]] = {} + visited_tensor_ids: set[str] = set() + for tensor in self._spec.tensors: + if tensor.id in visited_tensor_ids: + continue + queue = [tensor.id] + component_tensor_ids: list[str] = [] + visited_tensor_ids.add(tensor.id) + while queue: + current_tensor_id = queue.pop(0) + component_tensor_ids.append(current_tensor_id) + for neighbor_tensor_id in self._adjacency_by_tensor_id.get( + current_tensor_id, set() + ): + if neighbor_tensor_id in visited_tensor_ids: + continue + visited_tensor_ids.add(neighbor_tensor_id) + queue.append(neighbor_tensor_id) + for component_tensor_id in component_tensor_ids: + component_tensor_ids_by_tensor_id[component_tensor_id] = ( + component_tensor_ids + ) + return component_tensor_ids_by_tensor_id + + def _component_tensors_for_tensor(self, tensor_id: str) -> list[TensorSpec]: + component_tensor_ids = self._component_tensor_ids_by_tensor_id.get( + tensor_id, [tensor_id] + ) + return [ + self._tensor_by_id[component_tensor_id] + for component_tensor_id in component_tensor_ids + if component_tensor_id in self._tensor_by_id + ] + class _TikzRenderer: """Small deterministic TikZ renderer for one validated network spec.""" @@ -1684,6 +2293,69 @@ def _parallel_edge_control_point( ) +def _normalize_direction(direction: CanvasPosition) -> CanvasPosition: + magnitude = hypot(direction.x, direction.y) + if magnitude <= 1e-6: + return CanvasPosition(x=0.0, y=-1.0) + return CanvasPosition(x=direction.x / magnitude, y=direction.y / magnitude) + + +def _dot_product(left: CanvasPosition, right: CanvasPosition) -> float: + return left.x * right.x + left.y * right.y + + +def _cross_product(left: CanvasPosition, right: CanvasPosition) -> float: + return left.x * right.y - left.y * right.x + + +def _canonicalize_undirected_direction(direction: CanvasPosition) -> CanvasPosition: + if direction.x < -1e-6: + return CanvasPosition(x=-direction.x, y=-direction.y) + if abs(direction.x) <= 1e-6 and direction.y < -1e-6: + return CanvasPosition(x=-direction.x, y=-direction.y) + return direction + + +def _deduplicate_directions( + directions: Iterable[CanvasPosition], +) -> list[CanvasPosition]: + unique_directions: list[CanvasPosition] = [] + seen_keys: set[tuple[int, int]] = set() + for direction in directions: + normalized_direction = _normalize_direction(direction) + direction_key = ( + round(normalized_direction.x * 1000), + round(normalized_direction.y * 1000), + ) + if direction_key in seen_keys: + continue + seen_keys.add(direction_key) + unique_directions.append(normalized_direction) + return unique_directions + + +def _point_to_segment_distance( + point: CanvasPosition, + segment_start: CanvasPosition, + segment_end: CanvasPosition, +) -> float: + segment_dx = segment_end.x - segment_start.x + segment_dy = segment_end.y - segment_start.y + segment_length_squared = segment_dx * segment_dx + segment_dy * segment_dy + if segment_length_squared <= 1e-6: + return hypot(point.x - segment_start.x, point.y - segment_start.y) + projection = ( + ((point.x - segment_start.x) * segment_dx) + + ((point.y - segment_start.y) * segment_dy) + ) / segment_length_squared + clamped_projection = min(1.0, max(0.0, projection)) + closest_point = CanvasPosition( + x=segment_start.x + segment_dx * clamped_projection, + y=segment_start.y + segment_dy * clamped_projection, + ) + return hypot(point.x - closest_point.x, point.y - closest_point.y) + + def _sample_quadratic_points( source: CanvasPosition, control: CanvasPosition, diff --git a/tests/codegen/test_generators.py b/tests/codegen/test_generators.py index 5876d2b..de1aa60 100644 --- a/tests/codegen/test_generators.py +++ b/tests/codegen/test_generators.py @@ -1,6 +1,8 @@ from __future__ import annotations +import sys from collections.abc import Callable +from types import ModuleType, SimpleNamespace from unittest.mock import patch import pytest @@ -13,6 +15,7 @@ from tensor_network_editor.errors import CodeGenerationError from tensor_network_editor.models import ( CanvasPosition, + ContractionStepSpec, EdgeEndpointRef, EdgeSpec, EngineName, @@ -34,6 +37,7 @@ build_outer_product_plan_spec, build_sample_spec, build_sample_spec_without_plan, + build_three_tensor_complete_plan_spec, build_three_tensor_hyperedge_spec, build_three_tensor_spec, build_three_tensor_spec_without_plan, @@ -263,6 +267,481 @@ def _execute_generated_code( return namespace +class _FakeTensorKrowchEdge: + """Minimal edge object for generated-code regression tests.""" + + def __init__( + self, + node: _FakeTensorKrowchNode, + axis_name: str, + *, + origin: tuple[str, str] | None = None, + ) -> None: + self.node1 = node + self.axis1 = SimpleNamespace(name=axis_name) + self.node2: _FakeTensorKrowchNode | None = None + self.axis2: SimpleNamespace | None = None + self.origin = origin or (node.name, axis_name) + + @classmethod + def from_endpoints( + cls, + *, + node1: _FakeTensorKrowchNode, + axis1_name: str, + node2: _FakeTensorKrowchNode | None = None, + axis2_name: str | None = None, + origin: tuple[str, str] | None = None, + ) -> _FakeTensorKrowchEdge: + """Build one edge with explicit endpoint ownership.""" + edge = cls(node1, axis1_name, origin=origin) + if node2 is not None and axis2_name is not None: + edge.attach_second(node2, axis2_name) + return edge + + def attach_second( + self, + node: _FakeTensorKrowchNode, + axis_name: str, + ) -> None: + self.node2 = node + self.axis2 = SimpleNamespace(name=axis_name) + + def replace_endpoint( + self, + old_node: _FakeTensorKrowchNode, + new_node: _FakeTensorKrowchNode, + new_axis_name: str, + ) -> None: + if self.node1 is old_node: + self.node1 = new_node + self.axis1 = SimpleNamespace(name=new_axis_name) + return + if self.node2 is old_node: + self.node2 = new_node + self.axis2 = SimpleNamespace(name=new_axis_name) + + def is_dangling(self) -> bool: + return self.node2 is None + + def axis_name_for_node( + self, + node: _FakeTensorKrowchNode, + ) -> SimpleNamespace: + """Return the endpoint axis metadata for ``node``.""" + if self.node1 is node: + return self.axis1 + assert self.node2 is node + assert self.axis2 is not None + return self.axis2 + + +class _FakeTensorKrowchNode: + """Minimal node object for generated-code regression tests.""" + + def __init__( + self, + *, + tensor: object, + axes_names: tuple[str, ...], + name: str, + network: object, + ) -> None: + del tensor, network + self.name = name + self.edges_by_axis_name = { + axis_name: _FakeTensorKrowchEdge(self, axis_name) + for axis_name in axes_names + } + self.pending_edges_by_axis_name: dict[str, _FakeTensorKrowchEdge] = {} + self.pending_edge_owner_by_axis_name: dict[str, _FakeTensorKrowchNode] = {} + self.axis_is_node1_by_axis_name = {axis_name: True for axis_name in axes_names} + + def __getitem__(self, axis_name: str) -> _FakeTensorKrowchEdge: + if axis_name in self.edges_by_axis_name: + return self.edges_by_axis_name[axis_name] + return self.pending_edges_by_axis_name[axis_name] + + def reattach_edges(self, override: bool = False) -> None: + for axis_name, edge in list(self.pending_edges_by_axis_name.items()): + owner = self.pending_edge_owner_by_axis_name.pop(axis_name) + owner_is_node1 = edge.node1 is owner + if owner_is_node1: + other_node = edge.node2 + other_axis_name = None if edge.axis2 is None else edge.axis2.name + else: + other_node = edge.node1 + other_axis_name = edge.axis1.name + if override: + if owner_is_node1: + edge.node1 = self + edge.axis1 = SimpleNamespace(name=axis_name) + else: + edge.node2 = self + edge.axis2 = SimpleNamespace(name=axis_name) + self.edges_by_axis_name[axis_name] = edge + else: + self.edges_by_axis_name[axis_name] = ( + _FakeTensorKrowchEdge.from_endpoints( + node1=self if owner_is_node1 else other_node, + axis1_name=axis_name if owner_is_node1 else other_axis_name, + node2=other_node if owner_is_node1 else self, + axis2_name=other_axis_name if owner_is_node1 else axis_name, + origin=edge.origin, + ) + ) + self.axis_is_node1_by_axis_name[axis_name] = owner_is_node1 + self.pending_edges_by_axis_name = {} + + +class _FakeTensorKrowchModule(ModuleType): + """Tiny ``tensorkrowch`` double that exposes fragile axis ordering.""" + + def __init__(self) -> None: + super().__init__("tensorkrowch") + self.Node = _FakeTensorKrowchNode + self.TensorNetwork = _fake_tensorkrowch_network_factory + + @staticmethod + def connect( + left_edge: _FakeTensorKrowchEdge, + right_edge: _FakeTensorKrowchEdge, + ) -> _FakeTensorKrowchEdge: + left_edge.attach_second(right_edge.node1, right_edge.axis1.name) + right_edge.node1.edges_by_axis_name[right_edge.axis1.name] = left_edge + right_edge.node1.axis_is_node1_by_axis_name[right_edge.axis1.name] = False + return left_edge + + @staticmethod + def contract_between( + left_node: _FakeTensorKrowchNode, + right_node: _FakeTensorKrowchNode, + ) -> _FakeTensorKrowchNode: + left_edges = set(left_node.edges_by_axis_name.values()) + right_edges = set(right_node.edges_by_axis_name.values()) + if not left_edges.intersection(right_edges): + raise ValueError( + f"No batch edges or shared edges between nodes {left_node.name} and {right_node.name} found" + ) + shared_edges = left_edges.intersection(right_edges) + surviving_edges_with_owner = [ + (edge, left_node) + for edge in left_node.edges_by_axis_name.values() + if edge not in shared_edges + ] + [ + (edge, right_node) + for edge in right_node.edges_by_axis_name.values() + if edge not in shared_edges + ] + surviving_axis_names = _deduplicate_fake_tensorkrowch_axis_names( + tuple( + edge.axis_name_for_node(owner).name + for edge, owner in surviving_edges_with_owner + ) + ) + result = _FakeTensorKrowchNode( + tensor=None, + axes_names=surviving_axis_names, + name=f"{left_node.name}_{right_node.name}", + network=None, + ) + result.edges_by_axis_name = {} + result.pending_edges_by_axis_name = {} + result.pending_edge_owner_by_axis_name = {} + result.axis_is_node1_by_axis_name = {} + for axis_name, (edge, owner) in zip( + surviving_axis_names, + surviving_edges_with_owner, + strict=True, + ): + result.pending_edges_by_axis_name[axis_name] = edge + result.pending_edge_owner_by_axis_name[axis_name] = owner + result.axis_is_node1_by_axis_name[axis_name] = edge.node1 is owner + return result + + +def _deduplicate_fake_tensorkrowch_axis_names( + axis_names: tuple[str, ...], +) -> tuple[str, ...]: + """Mirror TensorKrowch suffixing for exact duplicate surviving axes.""" + base_names = [ + axis_name.rsplit("_", 1)[0] + if axis_name.rsplit("_", 1)[-1].isdigit() + else axis_name + for axis_name in axis_names + ] + result: list[str] = [] + counts: dict[str, int] = {} + for axis_name in base_names: + index = counts.get(axis_name, 0) + counts[axis_name] = index + 1 + if base_names.count(axis_name) == 1: + result.append(axis_name) + else: + result.append(f"{axis_name}_{index}") + return tuple(result) + + +def _fake_tensorkrowch_network_factory() -> object: + """Return a placeholder TensorNetwork instance for generated code.""" + return SimpleNamespace(reset=lambda: None) + + +class _ResetAwareFakeTensorKrowchNetwork: + """Minimal network object that can resync inherited resultant edges.""" + + def __init__(self) -> None: + self.nodes: list[_ResetAwareFakeTensorKrowchNode] = [] + + def register(self, node: _ResetAwareFakeTensorKrowchNode) -> None: + self.nodes.append(node) + + def reset(self) -> None: + for node in self.nodes: + node.reset_inherited_edges() + + +class _ResetAwareFakeTensorKrowchEdge: + """Edge double that hides inherited-result connections until reset.""" + + def __init__( + self, + node: _ResetAwareFakeTensorKrowchNode, + axis_name: str, + ) -> None: + self.node1 = node + self.axis1 = SimpleNamespace(name=axis_name) + self.node2: _ResetAwareFakeTensorKrowchNode | None = None + self.axis2: SimpleNamespace | None = None + self.origin = (node.name, axis_name) + self.inherited_source_by_result_node: dict[ + _ResetAwareFakeTensorKrowchNode, + tuple[_ResetAwareFakeTensorKrowchNode, str], + ] = {} + + def attach_second( + self, + node: _ResetAwareFakeTensorKrowchNode, + axis_name: str, + ) -> None: + self.node2 = node + self.axis2 = SimpleNamespace(name=axis_name) + + def replace_endpoint( + self, + old_node: _ResetAwareFakeTensorKrowchNode, + new_node: _ResetAwareFakeTensorKrowchNode, + new_axis_name: str, + ) -> None: + if self.node1 is old_node: + self.inherited_source_by_result_node[new_node] = ( + old_node, + self.axis1.name, + ) + self.node1 = new_node + self.axis1 = SimpleNamespace(name=new_axis_name) + self._stale_other_resultant_endpoints(excluded_result_node=new_node) + return + if self.node2 is old_node: + assert self.axis2 is not None + self.inherited_source_by_result_node[new_node] = ( + old_node, + self.axis2.name, + ) + self.node2 = new_node + self.axis2 = SimpleNamespace(name=new_axis_name) + self._stale_other_resultant_endpoints(excluded_result_node=new_node) + + def materialize_leaf_endpoint_for_resultant( + self, + node: _ResetAwareFakeTensorKrowchNode, + ) -> None: + source = self.inherited_source_by_result_node.get(node) + if source is None: + return + source_node, source_axis_name = source + if self.node1 is node: + self.node1 = source_node + self.axis1 = SimpleNamespace(name=source_axis_name) + return + if self.node2 is node: + self.node2 = source_node + self.axis2 = SimpleNamespace(name=source_axis_name) + + def restore_resultant_endpoint( + self, + node: _ResetAwareFakeTensorKrowchNode, + axis_name: str, + ) -> None: + source = self.inherited_source_by_result_node.get(node) + if source is None: + return + source_node, source_axis_name = source + if self.node1 is source_node and self.axis1.name == source_axis_name: + self.node1 = node + self.axis1 = SimpleNamespace(name=axis_name) + return + if ( + self.node2 is source_node + and self.axis2 is not None + and self.axis2.name == source_axis_name + ): + self.node2 = node + self.axis2 = SimpleNamespace(name=axis_name) + + def _stale_other_resultant_endpoints( + self, + *, + excluded_result_node: _ResetAwareFakeTensorKrowchNode, + ) -> None: + """Hide this edge from other inherited-result views until reset.""" + for result_node in tuple(self.inherited_source_by_result_node): + if result_node is excluded_result_node: + continue + if self.node1 is result_node or self.node2 is result_node: + self.materialize_leaf_endpoint_for_resultant(result_node) + + def is_dangling(self) -> bool: + return self.node2 is None + + def axis_name_for_node( + self, + node: _ResetAwareFakeTensorKrowchNode, + ) -> SimpleNamespace: + if self.node1 is node: + return self.axis1 + assert self.node2 is node + assert self.axis2 is not None + return self.axis2 + + def connects_nodes( + self, + left_node: _ResetAwareFakeTensorKrowchNode, + right_node: _ResetAwareFakeTensorKrowchNode, + ) -> bool: + return (self.node1 is left_node and self.node2 is right_node) or ( + self.node1 is right_node and self.node2 is left_node + ) + + +class _ResetAwareFakeTensorKrowchNode: + """Node double that tracks resultant-edge visibility across resets.""" + + def __init__( + self, + *, + tensor: object, + axes_names: tuple[str, ...], + name: str, + network: _ResetAwareFakeTensorKrowchNetwork | None, + ) -> None: + del tensor + self.name = name + self.network = network + self.is_resultant = False + self.edges_by_axis_name = { + axis_name: _ResetAwareFakeTensorKrowchEdge(self, axis_name) + for axis_name in axes_names + } + self.pending_edges_by_axis_name: dict[ + str, + _ResetAwareFakeTensorKrowchEdge, + ] = {} + if network is not None: + network.register(self) + + def __getitem__(self, axis_name: str) -> _ResetAwareFakeTensorKrowchEdge: + if axis_name in self.edges_by_axis_name: + return self.edges_by_axis_name[axis_name] + return self.pending_edges_by_axis_name[axis_name] + + def reattach_edges(self) -> None: + self.edges_by_axis_name.update(self.pending_edges_by_axis_name) + self.pending_edges_by_axis_name = {} + + def reset_inherited_edges(self) -> None: + for axis_name, edge in self.edges_by_axis_name.items(): + edge.restore_resultant_endpoint(self, axis_name) + for axis_name, edge in self.pending_edges_by_axis_name.items(): + edge.restore_resultant_endpoint(self, axis_name) + + +class _ResetAwareFakeTensorKrowchModule(ModuleType): + """TensorKrowch double that requires ``network.reset()`` for inherited edges.""" + + def __init__(self) -> None: + super().__init__("tensorkrowch") + self.Node = _ResetAwareFakeTensorKrowchNode + self.TensorNetwork = _reset_aware_fake_tensorkrowch_network_factory + + @staticmethod + def connect( + left_edge: _ResetAwareFakeTensorKrowchEdge, + right_edge: _ResetAwareFakeTensorKrowchEdge, + ) -> _ResetAwareFakeTensorKrowchEdge: + if left_edge.is_dangling() and left_edge.node1.is_resultant: + left_edge.materialize_leaf_endpoint_for_resultant(left_edge.node1) + left_edge.attach_second(right_edge.node1, right_edge.axis1.name) + right_edge.node1.edges_by_axis_name[right_edge.axis1.name] = left_edge + return left_edge + + @staticmethod + def contract_between( + left_node: _ResetAwareFakeTensorKrowchNode, + right_node: _ResetAwareFakeTensorKrowchNode, + ) -> _ResetAwareFakeTensorKrowchNode: + left_edges = set(left_node.edges_by_axis_name.values()) + right_edges = set(right_node.edges_by_axis_name.values()) + shared_edges = { + edge + for edge in left_edges.intersection(right_edges) + if edge.connects_nodes(left_node, right_node) + } + if not shared_edges: + raise ValueError( + f"No batch edges or shared edges between nodes {left_node.name} and {right_node.name} found" + ) + surviving_edges_with_owner = [ + (edge, right_node) + for edge in right_node.edges_by_axis_name.values() + if edge not in shared_edges + ] + [ + (edge, left_node) + for edge in left_node.edges_by_axis_name.values() + if edge not in shared_edges + ] + surviving_axis_names = _deduplicate_fake_tensorkrowch_axis_names( + tuple( + edge.axis_name_for_node(owner).name + for edge, owner in surviving_edges_with_owner + ) + ) + result = _ResetAwareFakeTensorKrowchNode( + tensor=None, + axes_names=surviving_axis_names, + name=f"{left_node.name}_{right_node.name}", + network=left_node.network, + ) + result.is_resultant = True + result.edges_by_axis_name = {} + result.pending_edges_by_axis_name = {} + for axis_name, (edge, owner) in zip( + surviving_axis_names, + surviving_edges_with_owner, + strict=True, + ): + edge.replace_endpoint(owner, result, axis_name) + result.pending_edges_by_axis_name[axis_name] = edge + return result + + +def _reset_aware_fake_tensorkrowch_network_factory() -> ( + _ResetAwareFakeTensorKrowchNetwork +): + """Return a fake network that models inherited-edge reset semantics.""" + return _ResetAwareFakeTensorKrowchNetwork() + + @pytest.mark.parametrize( ("engine", "expected_snippets"), [ @@ -813,6 +1292,20 @@ def test_periodic_generate_code_emits_roundtrip_metadata_marker() -> None: assert "# TNE_SPEC_B64:" in result.code assert "# Tensor Network Editor linear periodic mode" in result.code + assert result.code.index( + "# Tensor Network Editor linear periodic mode" + ) < result.code.index("# TNE_SPEC_B64:") + + +def test_periodic_generate_code_can_skip_roundtrip_metadata_marker() -> None: + result = generate_code( + build_linear_periodic_chain_spec(), + engine=EngineName.EINSUM_NUMPY, + include_roundtrip_metadata=False, + ) + + assert "# TNE_SPEC_B64:" not in result.code + assert "# Tensor Network Editor linear periodic mode" in result.code @pytest.mark.parametrize("engine", list(EngineName)) @@ -856,6 +1349,16 @@ def test_generate_code_respects_manual_plan_steps( assert "result = results_list[-1]" in result.code +def test_tensorkrowch_normal_codegen_does_not_emit_reattach_edges() -> None: + result = generate_code( + build_three_tensor_complete_plan_spec(), + engine=EngineName.TENSORKROWCH, + ) + + assert "results_list.append(tk.contract_between(" in result.code + assert "reattach_edges(" not in result.code + + @pytest.mark.parametrize("engine", list(EngineName)) def test_generate_code_keeps_partial_manual_plan_as_prefix( engine: EngineName, @@ -1119,6 +1622,122 @@ def test_linear_periodic_carry_codegen_labels_shared_for_sections( assert "previous_payload: dict[str, object]" in result.code +def test_linear_periodic_carry_tensorkrowch_codegen_tracks_boundary_edges_without_axis_order_assumptions() -> ( + None +): + result = generate_code( + build_linear_periodic_carry_chain_spec(), + engine=EngineName.TENSORKROWCH, + ) + fake_torch = ModuleType("torch") + fake_torch.float32 = object() + fake_torch.zeros = lambda shape, dtype=None: (shape, dtype) + fake_tensorkrowch = _FakeTensorKrowchModule() + + with patch.dict( + sys.modules, + { + "torch": fake_torch, + "tensorkrowch": fake_tensorkrowch, + }, + ): + namespace = _execute_generated_code(result.code, n=3) + + open_edges = namespace["open_edges"] + assert isinstance(open_edges, list) + assert len(open_edges) == 4 + assert [edge.origin for edge in open_edges] == [ + ("Initial", "phys"), + ("PeriodicLeft", "phys_l"), + ("PeriodicRight", "phys_r"), + ("Final", "phys"), + ] + + +def test_linear_periodic_carry_tensorkrowch_codegen_executes_when_periodic_cell_contracts_local_pair_before_previous_payload() -> ( + None +): + spec = build_linear_periodic_carry_chain_spec() + spec.linear_periodic_chain.periodic_cell.contraction_plan.steps = [ + ContractionStepSpec( + id="periodic_contract_internal_first", + left_operand_id="periodic_left_tensor", + right_operand_id="periodic_right_tensor", + ), + ContractionStepSpec( + id="periodic_consume_previous_second", + left_operand_id="periodic_contract_internal_first", + right_operand_id="__linear_previous__", + ), + ContractionStepSpec( + id="periodic_carry_last", + left_operand_id="periodic_consume_previous_second", + right_operand_id="__linear_next__", + ), + ] + result = generate_code(spec, engine=EngineName.TENSORKROWCH) + fake_torch = ModuleType("torch") + fake_torch.float32 = object() + fake_torch.zeros = lambda shape, dtype=None: (shape, dtype) + fake_tensorkrowch = _FakeTensorKrowchModule() + + with patch.dict( + sys.modules, + { + "torch": fake_torch, + "tensorkrowch": fake_tensorkrowch, + }, + ): + namespace = _execute_generated_code(result.code, n=5) + + assert "result" in namespace + assert "open_edges" in namespace + + +def test_linear_periodic_carry_tensorkrowch_codegen_materializes_result_edges_with_override() -> ( + None +): + spec = build_linear_periodic_carry_chain_spec() + spec.linear_periodic_chain.periodic_cell.contraction_plan.steps = [ + ContractionStepSpec( + id="periodic_contract_internal_first", + left_operand_id="periodic_left_tensor", + right_operand_id="periodic_right_tensor", + ), + ContractionStepSpec( + id="periodic_consume_previous_second", + left_operand_id="periodic_contract_internal_first", + right_operand_id="__linear_previous__", + ), + ContractionStepSpec( + id="periodic_carry_last", + left_operand_id="periodic_consume_previous_second", + right_operand_id="__linear_next__", + ), + ] + result = generate_code(spec, engine=EngineName.TENSORKROWCH) + fake_torch = ModuleType("torch") + fake_torch.float32 = object() + fake_torch.zeros = lambda shape, dtype=None: (shape, dtype) + fake_tensorkrowch = _FakeTensorKrowchModule() + + with patch.dict( + sys.modules, + { + "torch": fake_torch, + "tensorkrowch": fake_tensorkrowch, + }, + ): + namespace = _execute_generated_code(result.code, n=5) + + assert "reattach_edges(override=True)" in result.code + assert "network.reset()" not in result.code + assert "open_edges.extend([tracked_edge_0, tracked_edge_1])" in result.code + assert "outgoing_interface = [results_list[-1]['right']]" in result.code + assert "result" in namespace + assert "open_edges" in namespace + + @pytest.mark.parametrize("engine", list(EngineName)) def test_linear_periodic_codegen_does_not_stringify_manual_blocks( engine: EngineName, diff --git a/tests/codegen/test_linear_periodic_internals.py b/tests/codegen/test_linear_periodic_internals.py index 60846ab..8de0540 100644 --- a/tests/codegen/test_linear_periodic_internals.py +++ b/tests/codegen/test_linear_periodic_internals.py @@ -4,9 +4,17 @@ from tensor_network_editor.errors import CodeGenerationError from tensor_network_editor.models import ( + CanvasPosition, + ContractionPlanSpec, ContractionStepSpec, + EdgeEndpointRef, + EdgeSpec, EngineName, + IndexSpec, LinearPeriodicCellName, + LinearPeriodicCellSpec, + LinearPeriodicTensorRole, + TensorSpec, ) from tests.factories import build_linear_periodic_carry_chain_spec @@ -175,6 +183,203 @@ def test_simulate_carry_cell_rejects_next_step_that_is_not_final() -> None: ) +def test_simulate_carry_cell_accepts_previous_payload_labels_that_only_collide_by_name() -> ( + None +): + from tensor_network_editor.codegen.modes._linear_periodic.carry import ( + _CarryOperandState, + _CarryPayloadState, + _simulate_carry_cell, + ) + + periodic_cell = LinearPeriodicCellSpec( + tensors=[ + TensorSpec( + id="periodic_previous_boundary", + name="Previous cell", + position=CanvasPosition(x=-100.0, y=140.0), + linear_periodic_role=LinearPeriodicTensorRole.PREVIOUS, + indices=[ + IndexSpec( + id="periodic_previous_slot_1", name="slot_1", dimension=2 + ), + IndexSpec( + id="periodic_previous_slot_2", name="slot_2", dimension=2 + ), + ], + ), + TensorSpec( + id="periodic_next_boundary", + name="Next cell", + position=CanvasPosition(x=540.0, y=140.0), + linear_periodic_role=LinearPeriodicTensorRole.NEXT, + indices=[ + IndexSpec(id="periodic_next_slot_1", name="slot_1", dimension=2), + IndexSpec(id="periodic_next_slot_2", name="slot_2", dimension=2), + ], + ), + TensorSpec( + id="tensor_a1", + name="A1", + position=CanvasPosition(x=-255.0, y=363.0), + indices=[ + IndexSpec(id="a1_right", name="right", dimension=3), + IndexSpec(id="a1_phys", name="phys", dimension=2), + ], + ), + TensorSpec( + id="tensor_a2", + name="A2", + position=CanvasPosition(x=65.0, y=363.0), + indices=[ + IndexSpec(id="a2_left", name="left", dimension=3), + IndexSpec(id="a2_right", name="right", dimension=3), + IndexSpec(id="a2_phys", name="phys", dimension=2), + ], + ), + TensorSpec( + id="tensor_a3", + name="A3", + position=CanvasPosition(x=385.0, y=363.0), + indices=[ + IndexSpec(id="a3_left", name="left", dimension=3), + IndexSpec(id="a3_right", name="right", dimension=3), + IndexSpec(id="a3_phys", name="phys", dimension=2), + ], + ), + TensorSpec( + id="tensor_a4", + name="A4", + position=CanvasPosition(x=705.0, y=363.0), + indices=[ + IndexSpec(id="a4_left", name="left", dimension=3), + IndexSpec(id="a4_phys", name="phys", dimension=2), + ], + ), + ], + edges=[ + EdgeSpec( + id="edge_a1_a2", + name="edge-0-1", + left=EdgeEndpointRef(tensor_id="tensor_a1", index_id="a1_right"), + right=EdgeEndpointRef(tensor_id="tensor_a2", index_id="a2_left"), + ), + EdgeSpec( + id="edge_a2_a3", + name="edge-1-2", + left=EdgeEndpointRef(tensor_id="tensor_a2", index_id="a2_right"), + right=EdgeEndpointRef(tensor_id="tensor_a3", index_id="a3_left"), + ), + EdgeSpec( + id="edge_a3_a4", + name="edge-2-3", + left=EdgeEndpointRef(tensor_id="tensor_a3", index_id="a3_right"), + right=EdgeEndpointRef(tensor_id="tensor_a4", index_id="a4_left"), + ), + EdgeSpec( + id="edge_previous_a1", + name="bond1", + left=EdgeEndpointRef( + tensor_id="tensor_a1", + index_id="a1_phys", + ), + right=EdgeEndpointRef( + tensor_id="periodic_previous_boundary", + index_id="periodic_previous_slot_1", + ), + ), + EdgeSpec( + id="edge_previous_a2", + name="bond2", + left=EdgeEndpointRef( + tensor_id="periodic_previous_boundary", + index_id="periodic_previous_slot_2", + ), + right=EdgeEndpointRef( + tensor_id="tensor_a2", + index_id="a2_phys", + ), + ), + EdgeSpec( + id="edge_a3_next", + name="bond3", + left=EdgeEndpointRef(tensor_id="tensor_a3", index_id="a3_phys"), + right=EdgeEndpointRef( + tensor_id="periodic_next_boundary", + index_id="periodic_next_slot_1", + ), + ), + EdgeSpec( + id="edge_a4_next", + name="bond4", + left=EdgeEndpointRef( + tensor_id="periodic_next_boundary", + index_id="periodic_next_slot_2", + ), + right=EdgeEndpointRef(tensor_id="tensor_a4", index_id="a4_phys"), + ), + ], + contraction_plan=ContractionPlanSpec( + id="periodic_plan", + name="Manual path", + steps=[ + ContractionStepSpec( + id="step_contract_right", + left_operand_id="tensor_a4", + right_operand_id="tensor_a3", + ), + ContractionStepSpec( + id="step_from_previous", + left_operand_id="__linear_previous__", + right_operand_id="tensor_a2", + ), + ContractionStepSpec( + id="step_merge", + left_operand_id="step_from_previous", + right_operand_id="step_contract_right", + ), + ContractionStepSpec( + id="step_to_next", + left_operand_id="step_merge", + right_operand_id="__linear_next__", + ), + ], + ), + ) + previous_payload_state = _CarryPayloadState( + interface_operand_ids=("payload_left", "payload_right"), + interface_labels=("a1_phys", "a2_phys"), + operand_states={ + "payload_left": _CarryOperandState( + labels=("payload_edge", "a1_phys"), + axis_names=("left_payload", "slot_1"), + dimensions=(3, 2), + ), + "payload_right": _CarryOperandState( + labels=("a4_phys", "a3_phys", "payload_edge", "a2_phys"), + axis_names=("carry_0", "carry_1", "bridge", "slot_2"), + dimensions=(2, 2, 3, 2), + ), + }, + ) + + simulation = _simulate_carry_cell( + cell=periodic_cell, + cell_name=LinearPeriodicCellName.PERIODIC, + previous_payload_state=previous_payload_state, + engine=EngineName.TENSORKROWCH, + ) + + assert simulation.carry_operand_id == "step_merge" + assert simulation.outgoing_interface_operand_ids == ("step_merge", "step_merge") + assert ( + simulation.remaining_operand_states["step_merge"].labels.count("a3_phys") == 1 + ) + assert ( + simulation.remaining_operand_states["step_merge"].labels.count("a4_phys") == 1 + ) + + def test_build_carry_simulation_context_collects_interface_state() -> None: from tensor_network_editor.codegen.modes._linear_periodic.carry import ( _build_carry_simulation_context, diff --git a/tests/test_app_assets.py b/tests/test_app_assets.py index 4c47bd1..d4c384f 100644 --- a/tests/test_app_assets.py +++ b/tests/test_app_assets.py @@ -235,6 +235,7 @@ def test_root_groups_export_actions_and_code_generation_controls_as_requested( code_pane_index = html.index('id="sidebar-pane-code"') engine_index = html.index('id="engine-select"') collection_index = html.index('id="collection-format-select"') + metadata_index = html.index('id="codegen-roundtrip-metadata-checkbox"') generate_index = html.index('id="generate-button"') copy_index = html.index('id="copy-code-button"') expand_index = html.index('id="expand-generated-code-button"') @@ -247,8 +248,12 @@ def test_root_groups_export_actions_and_code_generation_controls_as_requested( < generate_index < copy_index < expand_index + < metadata_index < warning_index ) + assert 'id="codegen-roundtrip-metadata-field"' in html + assert 'id="codegen-roundtrip-metadata-checkbox"' in html + assert ">Metadata<" in html assert 'id="copy-code-button"' in html assert 'id="expand-generated-code-button"' in html assert 'id="generated-code-view"' in html @@ -787,6 +792,8 @@ def test_css_asset_styles_grouped_export_and_code_generation_controls( assert "background: #0e639c;" not in body assert ".code-header-controls {" in body assert ".code-header-controls .code-format-picker {" in body + assert ".code-metadata-toggle {" in body + assert ".code-metadata-toggle input {" in body assert ".code-format-picker.select-chevron-field::after {" in body assert ".code-format-picker select {" in body assert "appearance: none;" in body @@ -2103,6 +2110,8 @@ def test_toolbar_assets_route_file_and_template_actions_through_cursor_style_men 'exportFormatSelect: document.getElementById("export-format-select")' in dom_body ) + assert "codegenRoundtripMetadataField: document.getElementById(" in dom_body + assert "codegenRoundtripMetadataCheckbox: document.getElementById(" in dom_body assert ( 'codeGenerationWarning: document.getElementById("code-generation-warning")' in dom_body @@ -2476,7 +2485,9 @@ def test_template_management_assets_expose_toolbar_controls_and_routes( assert 'aria-haspopup="dialog"' in html assert 'id="reflow-layout-popover"' in html assert 'id="reflow-auto-layout-button"' in html - assert 'id="reflow-align-left-button"' not in html + assert 'id="reflow-align-horizontal-button"' in html + assert 'id="reflow-align-vertical-button"' in html + assert 'id="reflow-rotate-selection-button"' in html assert 'id="reflow-indices-left-button"' in html assert 'id="reflow-indices-reset-button"' in html assert 'id="reflow-arrange-chain-button"' in html @@ -2513,19 +2524,40 @@ def test_template_management_assets_expose_toolbar_controls_and_routes( assert 'class="help-close-icon"' in html assert 'Template' not in html assert re.search( - r'
[\s\S]*id="reflow-auto-layout-button"[\s\S]*>\s*Auto layout\s*<[\s\S]*id="reflow-arrange-chain-button"[\s\S]*>\s*Chain\s*<[\s\S]*id="reflow-arrange-tree-button"[\s\S]*>\s*Tree\s*<[\s\S]*id="reflow-arrange-grid-button"[\s\S]*>\s*Grid\s*<[\s\S]*id="reflow-snap-grid-button"[\s\S]*>\s*Snap to Grid\s*<', + r'
[\s\S]*id="reflow-auto-layout-button"[\s\S]*>\s*Auto layout\s*<[\s\S]*id="reflow-align-horizontal-button"[\s\S]*>\s*(?:↔|↔)\s*<[\s\S]*id="reflow-align-vertical-button"[\s\S]*>\s*(?:↕|↕)\s*<[\s\S]*id="reflow-rotate-selection-button"[\s\S]*>\s*(?:↻|↻)\s*<[\s\S]*id="reflow-arrange-chain-button"[\s\S]*>\s*Chain\s*<[\s\S]*id="reflow-arrange-tree-button"[\s\S]*>\s*Tree\s*<[\s\S]*id="reflow-arrange-grid-button"[\s\S]*>\s*Grid\s*<[\s\S]*id="reflow-snap-grid-button"[\s\S]*>\s*Snap to Grid\s*<', html, ) assert re.search( r'Indices[\s\S]*
[\s\S]*id="reflow-indices-left-button"[\s\S]*id="reflow-indices-right-button"[\s\S]*id="reflow-indices-top-button"[\s\S]*id="reflow-indices-bottom-button"[\s\S]*id="reflow-indices-reset-button"', html, ) - assert 'aria-label="Align left"' not in html - assert 'aria-label="Align middle"' not in html + assert 'aria-label="Align horizontally"' in html + assert 'aria-label="Align vertically"' in html + assert 'aria-label="Rotate 90 degrees clockwise"' in html assert 'aria-label="Move indices left"' in html assert ( - 'title="Align left: place selected tensors on the same left edge while keeping them separated."' - not in html + 'title="Align horizontal: line up the selected tensors on the same horizontal axis while keeping their left-to-right spacing."' + in html + ) + assert ( + 'title="Align vertical: line up the selected tensors on the same vertical axis while keeping their top-to-bottom spacing."' + in html + ) + assert ( + 'title="Rotate 90° clockwise: rotate the selected tensors and their ports around the selection center."' + in html + ) + assert not re.search( + r'id="reflow-align-horizontal-button"[\s\S]*>\s*Align horizontal\s*<', + html, + ) + assert not re.search( + r'id="reflow-align-vertical-button"[\s\S]*>\s*Align vertical\s*<', + html, + ) + assert not re.search( + r'id="reflow-rotate-selection-button"[\s\S]*>\s*Rotate 90(?:°|°)\s*<', + html, ) assert ( 'title="Chain: place selected tensors in one ordered row, following bonds when present."' @@ -2622,7 +2654,15 @@ def test_template_management_assets_expose_toolbar_controls_and_routes( in dom_body ) assert ( - 'reflowAlignLeftButton: document.getElementById("reflow-align-left-button")' + 'reflowAlignHorizontalButton: document.getElementById("reflow-align-horizontal-button")' + in dom_body + ) + assert ( + 'reflowAlignVerticalButton: document.getElementById("reflow-align-vertical-button")' + in dom_body + ) + assert ( + 'reflowRotateSelectionButton: document.getElementById("reflow-rotate-selection-button")' in dom_body ) assert ( @@ -2712,10 +2752,33 @@ def test_template_management_assets_expose_toolbar_controls_and_routes( 'bindListener(templateSettingsButton, "click", () => {' in shell_bindings_body ) assert "toggleReflowLayoutPopover" in shell_bindings_body - assert "reflowAlignLeftButton" in shell_bindings_body + assert "reflowAlignHorizontalButton" in shell_bindings_body + assert "reflowAlignVerticalButton" in shell_bindings_body + assert "reflowRotateSelectionButton" in shell_bindings_body assert "reflowIndicesLeftButton" in shell_bindings_body assert "reflowArrangeGridButton" in shell_bindings_body + assert "codegenRoundtripMetadataCheckbox" in shell_bindings_body assert 'bindReflowAction(reflowAutoLayoutButton, "auto");' in shell_bindings_body + assert ( + 'bindReflowAction(reflowAlignHorizontalButton, "align-horizontal");' + in shell_bindings_body + ) + assert ( + 'bindReflowAction(reflowAlignVerticalButton, "align-vertical");' + in shell_bindings_body + ) + assert ( + 'bindReflowAction(reflowRotateSelectionButton, "rotate-90");' + in shell_bindings_body + ) + assert ( + 'bindListener(codegenRoundtripMetadataCheckbox, "change", (event) => {' + in shell_bindings_body + ) + assert ( + "store.setIncludeRoundtripMetadata(event.target.checked);" + in shell_bindings_body + ) assert "applyReflowIndicesAction" in shell_bindings_body assert ( 'bindListener(loadSessionTemplateMenuItem, "click", () => {' diff --git a/tests/test_frontend_architecture.py b/tests/test_frontend_architecture.py index 3c4f190..9d77af9 100644 --- a/tests/test_frontend_architecture.py +++ b/tests/test_frontend_architecture.py @@ -1197,6 +1197,7 @@ def test_editor_services_route_session_requests_through_explicit_dependencies( await sessionService.generateCode({{ engine: "quimb", collectionFormat: "dict", + includeRoundtripMetadata: true, spec: {{ schema_version: 4, network: {{ id: "network_demo" }} }}, }}); await sessionService.renderSpec({{ @@ -1218,6 +1219,9 @@ def test_editor_services_route_session_requests_through_explicit_dependencies( if (calls[1].payload.collection_format !== "dict") {{ throw new Error(`Expected collection_format=dict, received ${{calls[1].payload.collection_format}}.`); }} + if (calls[1].payload.include_roundtrip_metadata !== true) {{ + throw new Error(`Expected include_roundtrip_metadata=true, received ${{calls[1].payload.include_roundtrip_metadata}}.`); + }} if (calls[2].path !== "/api/render") {{ throw new Error(`Unexpected render path: ${{calls[2].path}}`); }} @@ -3817,11 +3821,18 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( state.selectedEngine = engine; storeCalls.push({{ step: "setSelectedEngine", engine }}); }}, - setSelectedCollectionFormat(collectionFormat) {{ - state.selectedCollectionFormat = collectionFormat; - storeCalls.push({{ step: "setSelectedCollectionFormat", collectionFormat }}); - }}, - }}; + setSelectedCollectionFormat(collectionFormat) {{ + state.selectedCollectionFormat = collectionFormat; + storeCalls.push({{ step: "setSelectedCollectionFormat", collectionFormat }}); + }}, + setIncludeRoundtripMetadata(includeRoundtripMetadata) {{ + state.includeRoundtripMetadata = Boolean(includeRoundtripMetadata); + storeCalls.push({{ + step: "setIncludeRoundtripMetadata", + includeRoundtripMetadata: state.includeRoundtripMetadata, + }}); + }}, + }}; const flowEvents = []; const bootstrapFlow = bootstrapFlowModule.createEditorBootstrapFlow({{ state, @@ -3921,8 +3932,8 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( windowRef: {{ innerWidth: 800, innerHeight: 600, - addEventListener(type, handler) {{ - windowListeners.push(type); + addEventListener(type, handler, options) {{ + windowListeners.push({{ type, options }}); }}, }}, }}); @@ -4007,12 +4018,29 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( generatedCodeModal: getButton("generated-code-modal"), generatedCodeModalBackdrop: getButton("generated-code-modal-backdrop"), generatedCodeModalCloseButton: getButton("generated-code-modal-close-button"), + codegenRoundtripMetadataField: getButton("codegen-roundtrip-metadata-field"), + codegenRoundtripMetadataCheckbox: {{ + checked: false, + listeners: {{}}, + addEventListener(type, handler) {{ this.listeners[type] = handler; }}, + change(event) {{ + this.checked = Boolean(event?.target?.checked); + this.listeners.change?.(event); + }}, + }}, templateSelectField: getButton("template-select-field"), engineSelectField: getButton("engine-select-field"), collectionFormatSelectField: getButton("collection-format-select-field"), templateSelect: {{ value: "mps", - addEventListener(type, handler) {{ this[type] = handler; }}, + listeners: {{}}, + addEventListener(type, handler) {{ this.listeners[type] = handler; }}, + mousedown(event) {{ this.listeners.mousedown?.(event); }}, + change(event) {{ this.listeners.change?.(event); }}, + blur() {{ + flowEvents.push("templateSelect.blur"); + this.listeners.blur?.({{ target: this }}); + }}, }}, templateSettingsButton: getButton("template-settings-button"), templateSettingsPopover: getButton("template-settings-popover"), @@ -4026,11 +4054,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( editSessionTemplateMenuItem: getButton("edit-session-template-menu-item"), openSubnetworkLibraryMenuItem: getButton("open-subnetwork-library-menu-item"), reflowImportedButton: getButton("reflow-imported-button"), - reflowAlignLeftButton: getButton("reflow-align-left-button"), - reflowAlignRightButton: getButton("reflow-align-right-button"), - reflowAlignTopButton: getButton("reflow-align-top-button"), - reflowAlignMiddleButton: getButton("reflow-align-middle-button"), - reflowAlignBottomButton: getButton("reflow-align-bottom-button"), + reflowAlignHorizontalButton: getButton("reflow-align-horizontal-button"), + reflowAlignVerticalButton: getButton("reflow-align-vertical-button"), + reflowRotateSelectionButton: getButton("reflow-rotate-selection-button"), reflowIndicesLeftButton: getButton("reflow-indices-left-button"), reflowIndicesRightButton: getButton("reflow-indices-right-button"), reflowIndicesTopButton: getButton("reflow-indices-top-button"), @@ -4198,8 +4224,8 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( dom, documentRef: tooltipDocument, windowRef: {{ - addEventListener(type, handler) {{ - windowListeners.push(type); + addEventListener(type, handler, options) {{ + windowListeners.push({{ type, options }}); }}, }}, actions: shellActions, @@ -4211,6 +4237,7 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( getButton("expand-generated-code-button").click(); dom.generatedCodeModalBackdrop.click(); dom.generatedCodeModalCloseButton.click(); + dom.codegenRoundtripMetadataCheckbox.change({{ target: {{ checked: true }} }}); dom.engineSelect.change({{ target: {{ value: "cotengra" }} }}); dom.fileMenuButton.click(); dom.exportSubmenuShell.mouseenter(); @@ -4230,6 +4257,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( dom.templateSettingsButton.click(); dom.reflowImportedButton.click(); dom.reflowAutoLayoutButton.click(); + dom.reflowAlignHorizontalButton.click(); + dom.reflowAlignVerticalButton.click(); + dom.reflowRotateSelectionButton.click(); dom.reflowArrangeGridButton.click(); dom.reflowIndicesResetButton.click(); dom.templateManagerCloseButton.click(); @@ -4249,6 +4279,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( if (dom.templateSelectField.attributes["data-expanded"] !== "false") {{ throw new Error("Expected template select change to collapse the disclosure indicator."); }} + if (!flowEvents.includes("templateSelect.blur")) {{ + throw new Error("Expected template selection changes to blur the dropdown so keyboard shortcuts do not stay trapped in the select."); + }} dom.engineSelect.mousedown({{ target: dom.engineSelect }}); if (dom.engineSelectField.attributes["data-expanded"] !== "true") {{ throw new Error("Expected engine select mouse down to mark the disclosure as expanded."); @@ -4265,6 +4298,9 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( if (dom.collectionFormatSelectField.attributes["data-expanded"] !== "false") {{ throw new Error("Expected collection format select change to collapse the disclosure indicator."); }} + if (state.includeRoundtripMetadata !== true) {{ + throw new Error(`Expected metadata checkbox changes to update state, received ${{state.includeRoundtripMetadata}}.`); + }} if (!flowEvents.includes("generateCode")) {{ throw new Error(`Expected toolbar generate binding to invoke the injected action, received ${{JSON.stringify(flowEvents)}}.`); }} @@ -4330,6 +4366,15 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( if (!flowEvents.includes("applyReflowLayoutAction:auto")) {{ throw new Error(`Expected the Auto layout action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`); }} + if (!flowEvents.includes("applyReflowLayoutAction:align-horizontal")) {{ + throw new Error(`Expected the horizontal alignment action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`); + }} + if (!flowEvents.includes("applyReflowLayoutAction:align-vertical")) {{ + throw new Error(`Expected the vertical alignment action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`); + }} + if (!flowEvents.includes("applyReflowLayoutAction:rotate-90")) {{ + throw new Error(`Expected the rotate action to dispatch through the Reflow popover, received ${{JSON.stringify(flowEvents)}}.`); + }} if (!flowEvents.includes("applyReflowLayoutAction:grid")) {{ throw new Error(`Expected the Reflow popover actions to dispatch the requested layout, received ${{JSON.stringify(flowEvents)}}.`); }} @@ -4415,6 +4460,14 @@ def test_shell_modules_expose_explicit_bootstrap_flow_and_toolbar_bindings( ) {{ throw new Error("Expected the Code tab to expose its tooltip description."); }} + const keydownBinding = windowListeners.find( + (entry) => entry && entry.type === "keydown" + ); + if (!keydownBinding || keydownBinding.options !== true) {{ + throw new Error( + `Expected the global keydown shortcut listener to register in capture mode, received ${{JSON.stringify(windowListeners)}}.` + ); + }} """, ) diff --git a/tests/test_frontend_runtime.py b/tests/test_frontend_runtime.py index 6a3f9dd..8cdab50 100644 --- a/tests/test_frontend_runtime.py +++ b/tests/test_frontend_runtime.py @@ -7056,6 +7056,604 @@ def _write_port_layering_runtime_regression_script(tmp_path: Path) -> Path: return script_path +def _write_contraction_scene_port_layering_runtime_regression_script( + tmp_path: Path, +) -> Path: + script_path = tmp_path / "contraction_scene_port_layering_runtime_regression.mjs" + geometry_module_path = ( + REPO_ROOT + / "src" + / "tensor_network_editor" + / "app" + / "static" + / "js" + / "utils/utilitiesGeometry.js" + ) + script_path.write_text( + textwrap.dedent( + f""" + import {{ pathToFileURL }} from "node:url"; + + const geometryUrl = pathToFileURL({str(geometry_module_path)!r}).href; + const {{ createUtilityGeometryBindings }} = await import(geometryUrl); + + function createFakeElement(id, initialZIndex) {{ + let zIndex = initialZIndex; + return {{ + length: 1, + id() {{ + return id; + }}, + data(name, value) {{ + if (value === undefined) {{ + return name === "zIndex" ? zIndex : undefined; + }} + if (name === "zIndex") {{ + zIndex = value; + }} + return undefined; + }}, + }}; + }} + + const visibleDerivedTensor = {{ + id: "scene-step-ab", + name: "A-B", + position: {{ x: 100, y: 100 }}, + size: {{ width: 160, height: 84 }}, + indices: [ + {{ + id: "scene-step-ab_open", + name: "open", + dimension: 2, + offset: {{ x: 38, y: 0 }}, + metadata: {{}}, + }}, + ], + isDerived: true, + sourceTensorIds: ["tensor_a", "tensor_b"], + metadata: {{}}, + }}; + const visibleFrontTensor = {{ + id: "tensor_front", + name: "Front", + position: {{ x: 130, y: 100 }}, + size: {{ width: 140, height: 84 }}, + indices: [ + {{ + id: "front_open", + name: "front", + dimension: 2, + offset: {{ x: -38, y: 0 }}, + metadata: {{}}, + }}, + ], + isDerived: false, + sourceTensorIds: ["tensor_front"], + metadata: {{}}, + }}; + + const elementMap = new Map([ + ["scene-step-ab", createFakeElement("scene-step-ab", 10)], + ["scene-step-ab_open", createFakeElement("scene-step-ab_open", 10.2)], + ["tensor_front", createFakeElement("tensor_front", 11)], + ["front_open", createFakeElement("front_open", 11.2)], + ]); + + const state = {{ + activeTensorDrag: null, + cy: {{ + getElementById(id) {{ + return elementMap.get(id) || {{ length: 0, data() {{ return undefined; }} }}; + }}, + edges() {{ + return []; + }}, + }}, + pendingIndexId: null, + selectionIds: ["scene-step-ab"], + spec: {{ + tensors: [ + {{ + id: "tensor_a", + indices: [], + position: {{ x: 0, y: 0 }}, + }}, + {{ + id: "tensor_b", + indices: [], + position: {{ x: 0, y: 0 }}, + }}, + {{ + id: "tensor_front", + indices: visibleFrontTensor.indices, + position: visibleFrontTensor.position, + }}, + ], + }}, + tensorOrder: [], + tensorRankById: {{}}, + }}; + const runtime = {{ + asFiniteNumber(value, fallbackValue) {{ + return Number.isFinite(value) ? value : fallbackValue; + }}, + findConnectionByIndexId() {{ + return null; + }}, + findEdgeByIndexId() {{ + return null; + }}, + findHyperedgeByIndexId() {{ + return null; + }}, + findTensorById(tensorId) {{ + return ( + state.spec.tensors.find((tensor) => tensor.id === tensorId) || null + ); + }}, + getVisibleTensors() {{ + return [visibleDerivedTensor, visibleFrontTensor]; + }}, + indexLabelNodeId(indexId) {{ + return `${{indexId}}__label`; + }}, + }}; + + const geometry = createUtilityGeometryBindings({{ + ctx: {{ state }}, + state, + constants: {{ + TENSOR_WIDTH: 140, + TENSOR_HEIGHT: 84, + MIN_TENSOR_WIDTH: 96, + MIN_TENSOR_HEIGHT: 60, + INDEX_RADIUS: 10, + INDEX_PADDING: 6, + }}, + runtime, + }}); + Object.assign(runtime, geometry); + + geometry.applyTensorLayerData(); + + const selectedOpenZIndex = elementMap.get("scene-step-ab_open").data("zIndex"); + const frontTensorZIndex = elementMap.get("tensor_front").data("zIndex"); + + if (!(selectedOpenZIndex > frontTensorZIndex)) {{ + throw new Error( + `A selected derived contraction tensor should keep its open ports visible above front tensors: open=${{selectedOpenZIndex}}, front=${{frontTensorZIndex}}.` + ); + }} + """ + ), + encoding="utf-8", + ) + return script_path + + +def _write_contraction_scene_base_tensor_port_layering_runtime_regression_script( + tmp_path: Path, +) -> Path: + script_path = ( + tmp_path / "contraction_scene_base_tensor_port_layering_runtime_regression.mjs" + ) + _copy_runtime_bundle( + tmp_path, + { + "state.runtime.mjs": "state/state.js", + "utilities.runtime.mjs": "utils/utilities.js", + "historySelection.runtime.mjs": "graph/historySelection.js", + "contractionScene.runtime.mjs": "graph/contractionScene.js", + }, + _RUNTIME_EDITOR_SUPPORT_MODULES, + ) + script_path.write_text( + textwrap.dedent( + """ + import { pathToFileURL } from "node:url"; + + function createClassList() { + return { + add() {}, + remove() {}, + toggle() {}, + }; + } + + function createButton() { + return { + disabled: false, + classList: createClassList(), + addEventListener() {}, + focus() {}, + }; + } + + function createSpec() { + return { + id: "network_manual_anchor", + name: "manual-anchor", + tensors: [ + { + id: "tensor_a", + name: "A", + position: { x: 120, y: 140 }, + size: { width: 140, height: 84 }, + metadata: {}, + indices: [ + { + id: "tensor_a_left", + name: "left", + dimension: 2, + offset: { x: -38, y: 0 }, + metadata: {}, + }, + { + id: "tensor_a_bond", + name: "bond", + dimension: 3, + offset: { x: 38, y: 0 }, + metadata: {}, + }, + ], + }, + { + id: "tensor_b", + name: "B", + position: { x: 360, y: 220 }, + size: { width: 140, height: 84 }, + metadata: {}, + indices: [ + { + id: "tensor_b_bond", + name: "bond", + dimension: 3, + offset: { x: -38, y: 0 }, + metadata: {}, + }, + { + id: "tensor_b_right", + name: "carry", + dimension: 5, + offset: { x: 38, y: 0 }, + metadata: {}, + }, + ], + }, + { + id: "tensor_c", + name: "C", + position: { x: 620, y: 300 }, + size: { width: 140, height: 84 }, + metadata: {}, + indices: [ + { + id: "tensor_c_left", + name: "carry", + dimension: 5, + offset: { x: -38, y: 0 }, + metadata: {}, + }, + { + id: "tensor_c_right", + name: "right", + dimension: 7, + offset: { x: 38, y: 0 }, + metadata: {}, + }, + ], + }, + ], + groups: [], + edges: [ + { + id: "edge_ab", + name: "bond_ab", + left: { tensor_id: "tensor_a", index_id: "tensor_a_bond" }, + right: { tensor_id: "tensor_b", index_id: "tensor_b_bond" }, + metadata: {}, + }, + { + id: "edge_bc", + name: "bond_bc", + left: { tensor_id: "tensor_b", index_id: "tensor_b_right" }, + right: { tensor_id: "tensor_c", index_id: "tensor_c_left" }, + metadata: {}, + }, + ], + notes: [], + contraction_plan: { + id: "plan_chain", + name: "Chain path", + steps: [ + { + id: "step_ab", + left_operand_id: "tensor_a", + right_operand_id: "tensor_b", + }, + ], + }, + metadata: {}, + }; + } + + function createFakeElement(id) { + let zIndex = null; + const classes = new Set(); + let selected = false; + return { + length: 1, + id() { + return id; + }, + data(name, value) { + if (value === undefined) { + return name === "zIndex" ? zIndex : undefined; + } + if (name === "zIndex") { + zIndex = value; + } + return undefined; + }, + select() { + selected = true; + }, + unselect() { + selected = false; + }, + addClass(className) { + classes.add(className); + }, + removeClass(className) { + classes.delete(className); + }, + hasClass(className) { + return classes.has(className); + }, + isSelected() { + return selected; + }, + position() {}, + selectable() {}, + grabbable() {}, + }; + } + + const baseUrl = new URL("./", import.meta.url); + const [stateModule, utilitiesModule, historyModule, contractionSceneModule] = + await Promise.all([ + import(new URL("./state.runtime.mjs", baseUrl).href), + import(new URL("./utilities.runtime.mjs", baseUrl).href), + import(new URL("./historySelection.runtime.mjs", baseUrl).href), + import(new URL("./contractionScene.runtime.mjs", baseUrl).href), + ]); + + const { createInitialState } = stateModule; + const { registerUtilities } = utilitiesModule; + const { registerHistorySelection } = historyModule; + const { registerContractionScene } = contractionSceneModule; + + const ctx = { + state: createInitialState(), + constants: { + TENSOR_WIDTH: 140, + TENSOR_HEIGHT: 84, + MIN_TENSOR_WIDTH: 96, + MIN_TENSOR_HEIGHT: 60, + INDEX_RADIUS: 10, + INDEX_PADDING: 6, + NOTE_WIDTH: 220, + NOTE_HEIGHT: 120, + NOTE_MIN_WIDTH: 120, + NOTE_MIN_HEIGHT: 90, + HISTORY_LIMIT: 100, + REDO_SHORTCUT_LABEL: "Ctrl+Shift+Z", + DEFAULT_INDEX_SLOTS: [ + { x: -38, y: 0 }, + { x: 38, y: 0 }, + { x: 0, y: -24 }, + { x: 0, y: 24 }, + ], + }, + dom: { + workspace: {}, + statusMessage: { textContent: "", classList: createClassList() }, + propertiesPanel: { innerHTML: "" }, + generatedCode: { value: "" }, + engineSelect: { options: [], value: "tensornetwork" }, + collectionFormatSelect: { options: [], value: "list" }, + exportFormatSelect: { value: "py" }, + addNoteButton: createButton(), + connectButton: { classList: createClassList() }, + loadInput: {}, + undoButton: createButton(), + redoButton: createButton(), + exportButton: createButton(), + toggleLinearPeriodicButton: { classList: createClassList() }, + linearPeriodicPreviousCellButton: createButton(), + linearPeriodicCellLabel: { textContent: "" }, + linearPeriodicNextCellButton: createButton(), + templateSelect: { value: "" }, + templateParameterPanel: { hidden: true }, + templateGraphSizeLabel: { textContent: "" }, + templateGraphSizeInput: { value: "2", min: "1" }, + templateBondDimensionInput: { value: "3", min: "1" }, + templatePhysicalDimensionInput: { value: "2", min: "1" }, + insertTemplateButton: createButton(), + createGroupButton: createButton(), + helpButton: createButton(), + helpModal: { classList: createClassList() }, + helpBackdrop: createButton(), + helpCloseButton: createButton(), + canvasShell: { + getBoundingClientRect() { + return { left: 0, top: 0, width: 1000, height: 800 }; + }, + }, + groupLayer: {}, + resizeLayer: {}, + notesLayer: {}, + selectionBox: {}, + minimapCanvas: {}, + sidebar: {}, + plannerPanel: { + innerHTML: "", + querySelectorAll() { + return []; + }, + }, + generateButton: createButton(), + }, + apiGet: async () => null, + apiPost: async () => null, + window: { + structuredClone: globalThis.structuredClone, + crypto: globalThis.crypto, + setTimeout, + clearTimeout, + confirm: () => true, + }, + document: { + activeElement: null, + createElement() { + return { + value: "", + textContent: "", + selected: false, + appendChild() {}, + click() {}, + }; + }, + getElementById() { + return createButton(); + }, + querySelectorAll() { + return []; + }, + }, + cytoscape: null, + tensorWidth: (tensor) => tensor?.size?.width ?? 140, + tensorHeight: (tensor) => tensor?.size?.height ?? 84, + render: () => {}, + renderOverlayDecorations: () => {}, + renderMinimap: () => {}, + renderPlanner: () => {}, + renderSidebarTabs: () => {}, + refreshContractionAnalysis: () => {}, + syncPendingInteractionClasses: () => {}, + setActiveSidebarTab: () => {}, + updateToolbarState: () => {}, + captureEditableFocus: () => null, + restoreEditableFocus: () => {}, + }; + + registerUtilities(ctx); + registerContractionScene(ctx); + registerHistorySelection(ctx); + + ctx.state.selectedEngine = "tensornetwork"; + ctx.state.selectedCollectionFormat = "list"; + ctx.state.spec = ctx.normalizeSpec(createSpec()); + + const scene = ctx.buildContractionScene(); + if (!scene) { + throw new Error("Expected a contraction scene after the manual step."); + } + const elementMap = new Map(); + scene.tensors.forEach((tensor) => { + elementMap.set(tensor.id, createFakeElement(tensor.id)); + tensor.indices.forEach((index) => { + elementMap.set(index.id, createFakeElement(index.id)); + elementMap.set(`${index.id}__label`, createFakeElement(`${index.id}__label`)); + }); + }); + + ctx.state.cy = { + batch(action) { + action(); + }, + getElementById(id) { + return ( + elementMap.get(id) || { + length: 0, + data() { + return undefined; + }, + select() {}, + unselect() {}, + addClass() {}, + removeClass() {}, + position() {}, + selectable() {}, + grabbable() {}, + } + ); + }, + edges() { + return { + forEach() {}, + }; + }, + $(selector) { + if (selector === ":selected") { + return { + forEach(callback) { + elementMap.forEach((element) => { + if (element.isSelected()) { + callback(element); + } + }); + }, + }; + } + if (selector === ".is-selection-highlight") { + return { + forEach(callback) { + elementMap.forEach((element) => { + if (element.hasClass("is-selection-highlight")) { + callback(element); + } + }); + }, + }; + } + return { + forEach() {}, + }; + }, + }; + + ctx.bringTensorToFront("tensor_c"); + ctx.setSelection(["tensor_c"], { primaryId: "tensor_c" }); + + const selectedBaseTensor = scene.operandMap.tensor_c; + const selectedOpenPort = selectedBaseTensor.indices.find( + (index) => index.name === "right" + ); + const derivedTensor = scene.tensors.find((tensor) => tensor.isDerived); + + if (!ctx.state.tensorOrder.includes(derivedTensor.id)) { + throw new Error( + `Expected tensor layering order to track visible contraction operands, received ${JSON.stringify(ctx.state.tensorOrder)}.` + ); + } + const selectedOpenPortZIndex = elementMap + .get(selectedOpenPort.id) + .data("zIndex"); + const derivedTensorZIndex = elementMap.get(derivedTensor.id).data("zIndex"); + if (!(selectedOpenPortZIndex > derivedTensorZIndex)) { + throw new Error( + `A selected base tensor in contraction view should keep its free port visible above derived front tensors: open=${selectedOpenPortZIndex}, derived=${derivedTensorZIndex}.` + ); + } + """ + ), + encoding="utf-8", + ) + return script_path + + def _write_planner_auto_shortcut_runtime_regression_script(tmp_path: Path) -> Path: script_path = tmp_path / "planner_auto_shortcut_runtime_regression.mjs" _copy_js_modules(tmp_path, _SHORTCUT_RUNTIME_DEPENDENCY_MODULES) @@ -9574,6 +10172,15 @@ def _write_metadata_properties_runtime_regression_script(tmp_path: Path) -> Path if (!propertiesPanel.innerHTML.includes('id="add-index-to-selection-button"')) { throw new Error("Mixed selections should keep the bulk Add index action when editable tensors remain."); } + if (/id="extract-selection-button"[^>]*disabled/.test(propertiesPanel.innerHTML)) { + throw new Error("Mixed selections with editable tensors should keep Extract enabled."); + } + if (/id="save-selection-subnetwork-library-button"[^>]*disabled/.test(propertiesPanel.innerHTML)) { + throw new Error("Mixed selections with editable tensors should keep To Library enabled."); + } + if (/id="promote-selection-template-button"[^>]*disabled/.test(propertiesPanel.innerHTML)) { + throw new Error("Mixed selections with editable tensors should keep To Template enabled."); + } document.getElementById("add-index-to-selection-button").click(); const editableTensorAfter = ctx.state.spec.tensors.find( (candidate) => candidate.id === "tensor_a" @@ -11602,6 +12209,52 @@ def test_graph_model_layers_open_ports_below_front_tensors( ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_contraction_scene_selection_keeps_derived_open_ports_visible( + tmp_path: Path, +) -> None: + script_path = _write_contraction_scene_port_layering_runtime_regression_script( + tmp_path + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The contraction-scene port layering runtime regression script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_contraction_scene_selection_keeps_base_tensor_open_ports_visible( + tmp_path: Path, +) -> None: + script_path = ( + _write_contraction_scene_base_tensor_port_layering_runtime_regression_script( + tmp_path + ) + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The contraction-scene base-tensor port layering runtime regression script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_copy_shortcut_prefers_native_text_selection_over_graph_copy( tmp_path: Path, @@ -12120,11 +12773,9 @@ def _write_utility_runtime_contract_script(tmp_path: Path) -> Path: }), parentElement: reflowLayoutShell, }, - reflowAlignLeftButton: createButton(), - reflowAlignRightButton: createButton(), - reflowAlignTopButton: createButton(), - reflowAlignMiddleButton: createButton(), - reflowAlignBottomButton: createButton(), + reflowAlignHorizontalButton: createButton(), + reflowAlignVerticalButton: createButton(), + reflowRotateSelectionButton: createButton(), reflowIndicesLeftButton: createButton(), reflowIndicesRightButton: createButton(), reflowIndicesTopButton: createButton(), @@ -12361,6 +13012,12 @@ def _write_utility_runtime_contract_script(tmp_path: Path) -> Path: if (ctx.dom.reflowAutoLayoutButton.disabled) { throw new Error("Auto layout should stay enabled when the whole graph can be arranged."); } + runtime.isLinearPeriodicMode = () => true; + runtime.updateToolbarState(); + if (ctx.dom.templateSettingsButton.disabled) { + throw new Error("Template settings should stay enabled in For mode because they only affect future insertions."); + } + runtime.isLinearPeriodicMode = () => false; ctx.state.selectionIds = ["tensor_a"]; runtime.isBenchmarkMode = () => true; runtime.getBenchmarkSession = () => ({ @@ -13236,6 +13893,7 @@ def _write_interaction_session_dependency_injection_runtime_script( generatedCode: "", selectedEngine: "quimb", selectedCollectionFormat: "dict", + includeRoundtripMetadata: true, templateDefinitions: {}, availableTemplates: [], templateCatalogWarnings: [], @@ -13332,6 +13990,9 @@ def _write_interaction_session_dependency_injection_runtime_script( if (generateCall.payload.engine !== "quimb" || generateCall.payload.collectionFormat !== "dict") { throw new Error(`Unexpected generate payload: ${JSON.stringify(generateCall.payload)}.`); } + if (generateCall.payload.includeRoundtripMetadata !== true) { + throw new Error(`Expected includeRoundtripMetadata=true in the injected generate payload, received ${JSON.stringify(generateCall.payload)}.`); + } if (dom.generatedCode.value.trim() !== "result = 1") { throw new Error(`Expected injected preview sync to receive stripped code, received ${dom.generatedCode.value}.`); } @@ -15069,6 +15730,93 @@ def _write_layout_subnetwork_runtime_regression_script(tmp_path: Path) -> Path: } } + ctx.state.selectionIds = ["tensor_a", "tensor_b"]; + ctx.state.primarySelectionId = "tensor_b"; + ctx.state.spec.tensors[0].position = { x: 100, y: 100 }; + ctx.state.spec.tensors[1].position = { x: 260, y: 220 }; + ctx.state.spec.tensors[0].indices[0].offset = { x: 20, y: -10 }; + ctx.state.spec.tensors[1].indices[0].offset = { x: 16, y: -8 }; + ctx.state.spec.tensors[1].indices[1].offset = { x: 20, y: 10 }; + ctx.applyReflowLayoutAction("align-horizontal"); + const horizontalAlignmentYs = ctx.state.spec.tensors + .slice(0, 2) + .map((tensor) => tensor.position.y); + if (!horizontalAlignmentYs.every((value) => value === horizontalAlignmentYs[0])) { + throw new Error( + `Horizontal alignment should align tensor centers on the y axis, received ${horizontalAlignmentYs.join(", ")}.` + ); + } + ctx.applyReflowLayoutAction("align-vertical"); + const verticalAlignmentXs = ctx.state.spec.tensors + .slice(0, 2) + .map((tensor) => tensor.position.x); + if (!verticalAlignmentXs.every((value) => value === verticalAlignmentXs[0])) { + throw new Error( + `Vertical alignment should align tensor centers on the x axis, received ${verticalAlignmentXs.join(", ")}.` + ); + } + + ctx.state.spec.tensors[0].position = { x: 100, y: 100 }; + ctx.state.spec.tensors[1].position = { x: 260, y: 220 }; + ctx.state.spec.tensors[0].indices[0].offset = { x: 20, y: -10 }; + ctx.state.spec.tensors[1].indices[0].offset = { x: 16, y: -8 }; + ctx.state.spec.tensors[1].indices[1].offset = { x: 20, y: 10 }; + ctx.serializeCurrentSpec(); + ctx.applyReflowLayoutAction("rotate-90"); + const tensorARotated = ctx.findTensorById("tensor_a"); + const tensorBRotated = ctx.findTensorById("tensor_b"); + if (tensorARotated.position.x !== 240 || tensorARotated.position.y !== 80) { + throw new Error( + `Rotate 90 should move tensor A clockwise around the selection center, received ${JSON.stringify(tensorARotated.position)}.` + ); + } + if (tensorBRotated.position.x !== 120 || tensorBRotated.position.y !== 240) { + throw new Error( + `Rotate 90 should move tensor B clockwise around the selection center, received ${JSON.stringify(tensorBRotated.position)}.` + ); + } + if (JSON.stringify(tensorARotated.indices[0].offset) !== JSON.stringify({ x: 10, y: 20 })) { + throw new Error( + `Rotate 90 should rotate tensor A ports, received ${JSON.stringify(tensorARotated.indices[0].offset)}.` + ); + } + if (JSON.stringify(tensorBRotated.indices[0].offset) !== JSON.stringify({ x: 8, y: 16 })) { + throw new Error( + `Rotate 90 should rotate tensor B first port, received ${JSON.stringify(tensorBRotated.indices[0].offset)}.` + ); + } + if (JSON.stringify(tensorBRotated.indices[1].offset) !== JSON.stringify({ x: -10, y: 20 })) { + throw new Error( + `Rotate 90 should rotate tensor B second port, received ${JSON.stringify(tensorBRotated.indices[1].offset)}.` + ); + } + if (ctx.state.selectionIds.join(",") !== "tensor_a,tensor_b") { + throw new Error("Rotate 90 should preserve the selected tensors."); + } + const serializedAfterRotate = ctx.serializeCurrentSpec(); + const serializedTensorA = serializedAfterRotate.network.tensors.find( + (tensor) => tensor.id === "tensor_a" + ); + const serializedTensorB = serializedAfterRotate.network.tensors.find( + (tensor) => tensor.id === "tensor_b" + ); + if ( + serializedTensorA.position.x !== 240 || serializedTensorA.position.y !== 80 + ) { + throw new Error( + `serializeCurrentSpec should invalidate its cache after layout changes for tensor A, received ${JSON.stringify(serializedTensorA.position)}.` + ); + } + if ( + serializedTensorB.position.x !== 120 || serializedTensorB.position.y !== 240 + ) { + throw new Error( + `serializeCurrentSpec should invalidate its cache after layout changes for tensor B, received ${JSON.stringify(serializedTensorB.position)}.` + ); + } + + ctx.state.selectionIds = ["tensor_a", "tensor_b", "tensor_c"]; + ctx.state.primarySelectionId = "tensor_c"; ctx.state.spec.tensors[0].position.x = 100; ctx.state.spec.tensors[1].position.x = 260; ctx.state.spec.tensors[2].position.x = 460; diff --git a/tests/test_models_validation.py b/tests/test_models_validation.py index ee76551..183270d 100644 --- a/tests/test_models_validation.py +++ b/tests/test_models_validation.py @@ -836,6 +836,39 @@ def test_validate_spec_accepts_linear_periodic_partial_carry_chain() -> None: assert validate_spec(build_linear_periodic_partial_carry_chain_spec()) == [] +def test_validate_spec_rejects_linear_periodic_previous_step_that_merges_multiple_payload_operands() -> ( + None +): + spec = build_linear_periodic_partial_carry_chain_spec() + assert spec.linear_periodic_chain is not None + periodic_cell = spec.linear_periodic_chain.periodic_cell + assert periodic_cell.contraction_plan is not None + periodic_cell.contraction_plan.steps = [ + ContractionStepSpec( + id="merge_previous_locals", + left_operand_id="periodic_previous_left_tensor", + right_operand_id="periodic_previous_right_tensor", + ), + ContractionStepSpec( + id="consume_previous_payload", + left_operand_id="__linear_previous__", + right_operand_id="merge_previous_locals", + ), + ContractionStepSpec( + id="carry_next_left", + left_operand_id="periodic_next_left_tensor", + right_operand_id="__linear_next__", + ), + ] + + issue = find_issue(validate_spec(spec), "linear-periodic-carry-codegen") + + assert issue.path == ( + "linear_periodic_chain.periodic_cell.contraction_plan.steps.consume_previous_payload" + ) + assert "one previous carry operand per step" in issue.message + + def test_build_carry_validation_context_internal_helper_collects_interface_state() -> ( None ): diff --git a/tests/test_rendering.py b/tests/test_rendering.py index c684740..dac1776 100644 --- a/tests/test_rendering.py +++ b/tests/test_rendering.py @@ -1,17 +1,26 @@ from __future__ import annotations import re +from math import hypot from pathlib import Path from typing import Any from xml.etree import ElementTree as ET import pytest -from tensor_network_editor.models import NetworkSpec +from tensor_network_editor.models import ( + CanvasPosition, + EdgeEndpointRef, + EdgeSpec, + IndexSpec, + NetworkSpec, + TensorSpec, +) from tensor_network_editor.rendering import ( DotRenderOptions, SvgRenderOptions, TikzRenderOptions, + _number, _SvgRenderer, render_spec_dot, render_spec_mermaid, @@ -19,7 +28,12 @@ render_spec_svg, render_spec_tikz, ) -from tests.factories import build_sample_spec, build_three_tensor_hyperedge_spec +from tensor_network_editor.templates import TemplateParameters, build_template_spec +from tests.factories import ( + build_sample_spec, + build_three_tensor_hyperedge_spec, + build_three_tensor_spec, +) def _build_colored_parallel_edge_spec() -> NetworkSpec: @@ -107,6 +121,309 @@ def _build_three_parallel_edge_spec() -> NetworkSpec: return spec +def _build_cycle_spec() -> NetworkSpec: + return NetworkSpec( + id="network_cycle", + name="cycle", + tensors=[ + TensorSpec( + id="tensor_a", + name="A", + position=CanvasPosition(x=120.0, y=120.0), + indices=[ + IndexSpec(id="tensor_a_free", name="fa", dimension=2), + IndexSpec(id="tensor_a_ab", name="ab", dimension=3), + IndexSpec(id="tensor_a_da", name="da", dimension=5), + ], + ), + TensorSpec( + id="tensor_b", + name="B", + position=CanvasPosition(x=280.0, y=120.0), + indices=[ + IndexSpec(id="tensor_b_free", name="fb", dimension=2), + IndexSpec(id="tensor_b_ab", name="ab", dimension=3), + IndexSpec(id="tensor_b_bc", name="bc", dimension=7), + ], + ), + TensorSpec( + id="tensor_c", + name="C", + position=CanvasPosition(x=280.0, y=280.0), + indices=[ + IndexSpec(id="tensor_c_free", name="fc", dimension=2), + IndexSpec(id="tensor_c_bc", name="bc", dimension=7), + IndexSpec(id="tensor_c_cd", name="cd", dimension=11), + ], + ), + TensorSpec( + id="tensor_d", + name="D", + position=CanvasPosition(x=120.0, y=280.0), + indices=[ + IndexSpec(id="tensor_d_free", name="fd", dimension=2), + IndexSpec(id="tensor_d_cd", name="cd", dimension=11), + IndexSpec(id="tensor_d_da", name="da", dimension=5), + ], + ), + ], + edges=[ + EdgeSpec( + id="edge_ab", + name="ab", + left=EdgeEndpointRef(tensor_id="tensor_a", index_id="tensor_a_ab"), + right=EdgeEndpointRef(tensor_id="tensor_b", index_id="tensor_b_ab"), + ), + EdgeSpec( + id="edge_bc", + name="bc", + left=EdgeEndpointRef(tensor_id="tensor_b", index_id="tensor_b_bc"), + right=EdgeEndpointRef(tensor_id="tensor_c", index_id="tensor_c_bc"), + ), + EdgeSpec( + id="edge_cd", + name="cd", + left=EdgeEndpointRef(tensor_id="tensor_c", index_id="tensor_c_cd"), + right=EdgeEndpointRef(tensor_id="tensor_d", index_id="tensor_d_cd"), + ), + EdgeSpec( + id="edge_da", + name="da", + left=EdgeEndpointRef(tensor_id="tensor_d", index_id="tensor_d_da"), + right=EdgeEndpointRef(tensor_id="tensor_a", index_id="tensor_a_da"), + ), + ], + ) + + +def _build_grid_export_spec() -> NetworkSpec: + tensors: list[TensorSpec] = [] + edges: list[EdgeSpec] = [] + for row_index in range(3): + for column_index in range(3): + tensor_id = f"tensor_{row_index}_{column_index}" + indices = [ + IndexSpec( + id=f"{tensor_id}_free", + name=f"f_{row_index}_{column_index}", + dimension=2, + ) + ] + if column_index < 2: + indices.append( + IndexSpec( + id=f"{tensor_id}_right", + name=f"h_{row_index}_{column_index}", + dimension=3, + ) + ) + if column_index > 0: + indices.append( + IndexSpec( + id=f"{tensor_id}_left", + name=f"h_{row_index}_{column_index - 1}", + dimension=3, + ) + ) + if row_index < 2: + indices.append( + IndexSpec( + id=f"{tensor_id}_down", + name=f"v_{row_index}_{column_index}", + dimension=5, + ) + ) + if row_index > 0: + indices.append( + IndexSpec( + id=f"{tensor_id}_up", + name=f"v_{row_index - 1}_{column_index}", + dimension=5, + ) + ) + tensors.append( + TensorSpec( + id=tensor_id, + name=f"T{row_index}{column_index}", + position=CanvasPosition( + x=120.0 + 140.0 * column_index, + y=120.0 + 140.0 * row_index, + ), + indices=indices, + ) + ) + for row_index in range(3): + for column_index in range(2): + left_tensor_id = f"tensor_{row_index}_{column_index}" + right_tensor_id = f"tensor_{row_index}_{column_index + 1}" + edge_name = f"h_{row_index}_{column_index}" + edges.append( + EdgeSpec( + id=f"edge_{edge_name}", + name=edge_name, + left=EdgeEndpointRef( + tensor_id=left_tensor_id, + index_id=f"{left_tensor_id}_right", + ), + right=EdgeEndpointRef( + tensor_id=right_tensor_id, + index_id=f"{right_tensor_id}_left", + ), + ) + ) + for row_index in range(2): + for column_index in range(3): + top_tensor_id = f"tensor_{row_index}_{column_index}" + bottom_tensor_id = f"tensor_{row_index + 1}_{column_index}" + edge_name = f"v_{row_index}_{column_index}" + edges.append( + EdgeSpec( + id=f"edge_{edge_name}", + name=edge_name, + left=EdgeEndpointRef( + tensor_id=top_tensor_id, + index_id=f"{top_tensor_id}_down", + ), + right=EdgeEndpointRef( + tensor_id=bottom_tensor_id, + index_id=f"{bottom_tensor_id}_up", + ), + ) + ) + return NetworkSpec( + id="network_grid_export", + name="grid-export", + tensors=tensors, + edges=edges, + ) + + +def _build_vertical_three_tensor_spec() -> NetworkSpec: + spec = build_three_tensor_spec() + spec.tensors[0].position = CanvasPosition(x=240.0, y=80.0) + spec.tensors[1].position = CanvasPosition(x=240.0, y=240.0) + spec.tensors[2].position = CanvasPosition(x=240.0, y=400.0) + return spec + + +def _build_vertical_three_tensor_named_hint_spec() -> NetworkSpec: + spec = _build_vertical_three_tensor_spec() + spec.tensors[0].indices[0].name = "up" + return spec + + +def _build_diagonal_three_tensor_spec() -> NetworkSpec: + spec = build_three_tensor_spec() + spec.tensors[0].position = CanvasPosition(x=80.0, y=80.0) + spec.tensors[1].position = CanvasPosition(x=240.0, y=240.0) + spec.tensors[2].position = CanvasPosition(x=400.0, y=400.0) + return spec + + +def _build_rotated_grid_export_spec() -> NetworkSpec: + spec = _build_grid_export_spec() + center = CanvasPosition(x=240.0, y=240.0) + column_step = CanvasPosition(x=100.0, y=100.0) + row_step = CanvasPosition(x=-100.0, y=100.0) + for tensor in spec.tensors: + _, row_text, column_text = tensor.id.split("_") + row_index = int(row_text) + column_index = int(column_text) + tensor.position = CanvasPosition( + x=center.x + + (column_index - 1) * column_step.x + + (row_index - 1) * row_step.x, + y=center.y + + (column_index - 1) * column_step.y + + (row_index - 1) * row_step.y, + ) + return spec + + +def _build_vertical_mpo_export_spec() -> NetworkSpec: + spec = build_template_spec( + "mpo", + TemplateParameters( + graph_size=4, + bond_dimension=3, + physical_dimension=2, + boundary_condition="open", + j=1.0, + h=1.0, + ), + ) + for tensor_index, tensor in enumerate(spec.tensors): + tensor.position = CanvasPosition(x=240.0, y=80.0 + tensor_index * 160.0) + return spec + + +def _build_generic_export_spec() -> NetworkSpec: + return NetworkSpec( + id="network_generic_export", + name="generic-export", + tensors=[ + TensorSpec( + id="tensor_center", + name="Center", + position=CanvasPosition(x=220.0, y=200.0), + indices=[ + IndexSpec(id="tensor_center_free", name="free", dimension=2), + IndexSpec(id="tensor_center_right", name="r", dimension=3), + IndexSpec(id="tensor_center_down", name="d", dimension=5), + ], + ), + TensorSpec( + id="tensor_right", + name="Right", + position=CanvasPosition(x=360.0, y=180.0), + indices=[ + IndexSpec(id="tensor_right_left", name="r", dimension=3), + ], + ), + TensorSpec( + id="tensor_down", + name="Down", + position=CanvasPosition(x=260.0, y=340.0), + indices=[ + IndexSpec(id="tensor_down_up", name="d", dimension=5), + ], + ), + ], + edges=[ + EdgeSpec( + id="edge_center_right", + name="r", + left=EdgeEndpointRef( + tensor_id="tensor_center", index_id="tensor_center_right" + ), + right=EdgeEndpointRef( + tensor_id="tensor_right", index_id="tensor_right_left" + ), + ), + EdgeSpec( + id="edge_center_down", + name="d", + left=EdgeEndpointRef( + tensor_id="tensor_center", index_id="tensor_center_down" + ), + right=EdgeEndpointRef( + tensor_id="tensor_down", index_id="tensor_down_up" + ), + ), + ], + ) + + +def _dot(left: CanvasPosition, right: CanvasPosition) -> float: + return left.x * right.x + left.y * right.y + + +def _normalize(vector: CanvasPosition) -> CanvasPosition: + magnitude = hypot(vector.x, vector.y) + assert magnitude > 1e-9 + return CanvasPosition(x=vector.x / magnitude, y=vector.y / magnitude) + + def _svg_text_content(svg: str) -> list[str]: root = ET.fromstring(svg) text_nodes = root.findall(".//{http://www.w3.org/2000/svg}text") @@ -166,6 +483,146 @@ def test_academic_svg_and_tikz_exports_use_tensor_circles_and_dangling_ports() - assert r"\draw[tne open index]" in tikz +def test_export_geometry_prefers_perpendicular_free_index_directions_for_linear_chain() -> ( + None +): + spec = build_three_tensor_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + + left_tensor = spec.tensors[0] + left_index = left_tensor.indices[0] + direction = renderer._index_direction(left_tensor, left_index) + source = renderer.connection_point(left_tensor, left_index) + target = renderer.open_index_endpoint(left_tensor, left_index) + + assert abs(direction.x) < 0.25 + assert abs(direction.y) > 0.9 + assert hypot(target.x - source.x, target.y - source.y) == pytest.approx( + 2.0 * renderer.tensor_radius(left_tensor) + ) + + +def test_export_geometry_respects_vertical_linear_chain_orientation() -> None: + spec = _build_vertical_three_tensor_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + + first_tensor = spec.tensors[0] + free_index = first_tensor.indices[0] + direction = renderer._index_direction(first_tensor, free_index) + + assert abs(direction.x) > 0.9 + assert abs(direction.y) < 0.25 + + +def test_export_geometry_prefers_linear_component_orientation_over_named_hints() -> ( + None +): + spec = _build_vertical_three_tensor_named_hint_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + + first_tensor = spec.tensors[0] + free_index = first_tensor.indices[0] + direction = renderer._index_direction(first_tensor, free_index) + + assert abs(direction.x) > 0.9 + assert abs(direction.y) < 0.25 + + +def test_export_geometry_respects_diagonal_linear_chain_orientation() -> None: + spec = _build_diagonal_three_tensor_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + + first_tensor = spec.tensors[0] + free_index = first_tensor.indices[0] + direction = renderer._index_direction(first_tensor, free_index) + chain_axis = _normalize(CanvasPosition(x=1.0, y=1.0)) + diagonal_perpendicular = _normalize(CanvasPosition(x=-1.0, y=1.0)) + + assert abs(_dot(direction, chain_axis)) < 0.25 + assert abs(_dot(direction, diagonal_perpendicular)) > 0.9 + + +def test_export_geometry_prefers_vertical_mpo_component_orientation_over_index_names() -> ( + None +): + spec = _build_vertical_mpo_export_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + first_tensor = spec.tensors[0] + bra_index = next(index for index in first_tensor.indices if index.name == "bra") + ket_index = next(index for index in first_tensor.indices if index.name == "ket") + bra_direction = renderer._index_direction(first_tensor, bra_index) + ket_direction = renderer._index_direction(first_tensor, ket_index) + + assert abs(bra_direction.x) > 0.9 + assert abs(ket_direction.x) > 0.9 + assert abs(bra_direction.y) < 0.25 + assert abs(ket_direction.y) < 0.25 + assert _dot(bra_direction, ket_direction) < -0.85 + + +def test_export_geometry_points_cycle_free_indices_outward() -> None: + spec = _build_cycle_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + cycle_center = CanvasPosition(x=200.0, y=200.0) + + for tensor in spec.tensors: + free_index = tensor.indices[0] + direction = renderer._index_direction(tensor, free_index) + radial = _normalize( + CanvasPosition( + x=tensor.position.x - cycle_center.x, + y=tensor.position.y - cycle_center.y, + ) + ) + assert _dot(direction, radial) > 0.85 + + +def test_export_geometry_points_grid_boundary_free_indices_outward() -> None: + spec = _build_grid_export_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + expectations = { + "tensor_0_1": CanvasPosition(x=0.0, y=-1.0), + "tensor_1_0": CanvasPosition(x=-1.0, y=0.0), + "tensor_1_2": CanvasPosition(x=1.0, y=0.0), + "tensor_2_1": CanvasPosition(x=0.0, y=1.0), + } + + for tensor_id, expected_direction in expectations.items(): + tensor = next(tensor for tensor in spec.tensors if tensor.id == tensor_id) + free_index = tensor.indices[0] + direction = renderer._index_direction(tensor, free_index) + assert _dot(direction, expected_direction) > 0.85 + + +def test_export_geometry_points_rotated_grid_boundary_free_indices_outward() -> None: + spec = _build_rotated_grid_export_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + expectations = { + "tensor_0_1": _normalize(CanvasPosition(x=1.0, y=-1.0)), + "tensor_1_0": _normalize(CanvasPosition(x=-1.0, y=-1.0)), + "tensor_1_2": _normalize(CanvasPosition(x=1.0, y=1.0)), + "tensor_2_1": _normalize(CanvasPosition(x=-1.0, y=1.0)), + } + + for tensor_id, expected_direction in expectations.items(): + tensor = next(tensor for tensor in spec.tensors if tensor.id == tensor_id) + free_index = tensor.indices[0] + direction = renderer._index_direction(tensor, free_index) + assert _dot(direction, expected_direction) > 0.85 + + +def test_export_geometry_generic_free_indices_point_away_from_local_neighbors() -> None: + spec = _build_generic_export_spec() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + center_tensor = spec.tensors[0] + free_index = center_tensor.indices[0] + + direction = renderer._index_direction(center_tensor, free_index) + away_from_neighbors = _normalize(CanvasPosition(x=-180.0, y=-120.0)) + + assert _dot(direction, away_from_neighbors) > 0.75 + + def test_academic_svg_tikz_and_dot_preserve_entity_colors_and_parallel_edges() -> None: spec = _build_colored_parallel_edge_spec() @@ -283,12 +740,20 @@ def test_academic_parallel_edges_curve_far_enough_to_separate_three_bonds() -> N def test_academic_edges_reach_tensor_centers_in_svg_and_tikz() -> None: spec = _assign_demo_index_offsets() - edge_render_infos = _SvgRenderer(spec, SvgRenderOptions())._edge_render_infos() + renderer = _SvgRenderer(spec, SvgRenderOptions()) + edge_render_infos = renderer._edge_render_infos() + bounds = renderer._compute_bounds(edge_render_infos) tikz = render_spec_tikz(spec) assert edge_render_infos[0].source == spec.tensors[0].position assert edge_render_infos[0].target == spec.tensors[1].position - assert "(150, 116) -- (390, 116)" in tikz + expected_segment = ( + f"({_number(edge_render_infos[0].source.x - bounds.x1)}, " + f"{_number(bounds.y2 - edge_render_infos[0].source.y)}) -- " + f"({_number(edge_render_infos[0].target.x - bounds.x1)}, " + f"{_number(bounds.y2 - edge_render_infos[0].target.y)})" + ) + assert expected_segment in tikz def test_academic_svg_renderer_can_hide_tensor_index_and_bond_labels() -> None: From 7dd05407377be021a68c0d7f09a07e368b2355de Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 14:28:23 +0200 Subject: [PATCH 02/23] Version update --- CHANGELOG.md | 2 ++ CITATION.cff | 4 ++-- src/tensor_network_editor/_version.py | 2 +- tests/test_app_routes.py | 4 ++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0ec7c6..f740f21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +## [0.5.0] - 2026-04-30 + ### Changed - The browser editor's `Info` help panel now mentions the full current export diff --git a/CITATION.cff b/CITATION.cff index 9e7717d..061f699 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,8 +5,8 @@ type: software authors: - family-names: "Mata Ali" given-names: "Alejandro" -version: "0.4.0" -date-released: "2026-04-25" +version: "0.5.0" +date-released: "2026-04-30" repository-code: "https://github.com/DOKOS-TAYOS/Tensor-Network-Editor" url: "https://github.com/DOKOS-TAYOS/Tensor-Network-Editor" license: "MIT" diff --git a/src/tensor_network_editor/_version.py b/src/tensor_network_editor/_version.py index 3a58b17..bcfe245 100644 --- a/src/tensor_network_editor/_version.py +++ b/src/tensor_network_editor/_version.py @@ -4,4 +4,4 @@ from typing import Final -__version__: Final[str] = "0.4.0" +__version__: Final[str] = "0.5.0" diff --git a/tests/test_app_routes.py b/tests/test_app_routes.py index d36d7e8..ed18170 100644 --- a/tests/test_app_routes.py +++ b/tests/test_app_routes.py @@ -12,7 +12,7 @@ import pytest -from tensor_network_editor import generate_code +from tensor_network_editor import __version__, generate_code from tensor_network_editor.analysis import analyze_contraction from tensor_network_editor.app._protocol import JsonDict from tensor_network_editor.app.routes import handle_bootstrap @@ -63,7 +63,7 @@ def test_bootstrap_returns_session_contract( } assert payload["app_metadata"] == { "repository_url": "https://github.com/DOKOS-TAYOS/Tensor-Network-Editor", - "version": "0.4.0", + "version": __version__, "license_name": "MIT", "author_name": "Alejandro Mata Ali", } From 8a659da16c32b6dce2a31a1200f7ebc73022ffac Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 14:30:51 +0200 Subject: [PATCH 03/23] Fix pyright typing in codegen test doubles --- tests/codegen/test_generators.py | 62 +++++++++++++++++++++++--------- tests/test_protocol.py | 2 ++ 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/tests/codegen/test_generators.py b/tests/codegen/test_generators.py index de1aa60..c3832d6 100644 --- a/tests/codegen/test_generators.py +++ b/tests/codegen/test_generators.py @@ -381,15 +381,28 @@ def reattach_edges(self, override: bool = False) -> None: edge.axis2 = SimpleNamespace(name=axis_name) self.edges_by_axis_name[axis_name] = edge else: - self.edges_by_axis_name[axis_name] = ( - _FakeTensorKrowchEdge.from_endpoints( - node1=self if owner_is_node1 else other_node, - axis1_name=axis_name if owner_is_node1 else other_axis_name, - node2=other_node if owner_is_node1 else self, - axis2_name=other_axis_name if owner_is_node1 else axis_name, - origin=edge.origin, + if owner_is_node1: + self.edges_by_axis_name[axis_name] = ( + _FakeTensorKrowchEdge.from_endpoints( + node1=self, + axis1_name=axis_name, + node2=other_node, + axis2_name=other_axis_name, + origin=edge.origin, + ) + ) + else: + assert other_node is not None + assert other_axis_name is not None + self.edges_by_axis_name[axis_name] = ( + _FakeTensorKrowchEdge.from_endpoints( + node1=other_node, + axis1_name=other_axis_name, + node2=self, + axis2_name=axis_name, + origin=edge.origin, + ) ) - ) self.axis_is_node1_by_axis_name[axis_name] = owner_is_node1 self.pending_edges_by_axis_name = {} @@ -460,6 +473,23 @@ def contract_between( return result +class _FakeTorchModule(ModuleType): + """Tiny ``torch`` double for generated-code regression tests.""" + + float32: object + + def __init__(self) -> None: + super().__init__("torch") + self.float32 = object() + + @staticmethod + def zeros( + shape: tuple[int, ...], + dtype: object | None = None, + ) -> tuple[tuple[int, ...], object | None]: + return (shape, dtype) + + def _deduplicate_fake_tensorkrowch_axis_names( axis_names: tuple[str, ...], ) -> tuple[str, ...]: @@ -1629,9 +1659,7 @@ def test_linear_periodic_carry_tensorkrowch_codegen_tracks_boundary_edges_withou build_linear_periodic_carry_chain_spec(), engine=EngineName.TENSORKROWCH, ) - fake_torch = ModuleType("torch") - fake_torch.float32 = object() - fake_torch.zeros = lambda shape, dtype=None: (shape, dtype) + fake_torch = _FakeTorchModule() fake_tensorkrowch = _FakeTensorKrowchModule() with patch.dict( @@ -1658,6 +1686,8 @@ def test_linear_periodic_carry_tensorkrowch_codegen_executes_when_periodic_cell_ None ): spec = build_linear_periodic_carry_chain_spec() + assert spec.linear_periodic_chain is not None + assert spec.linear_periodic_chain.periodic_cell.contraction_plan is not None spec.linear_periodic_chain.periodic_cell.contraction_plan.steps = [ ContractionStepSpec( id="periodic_contract_internal_first", @@ -1676,9 +1706,7 @@ def test_linear_periodic_carry_tensorkrowch_codegen_executes_when_periodic_cell_ ), ] result = generate_code(spec, engine=EngineName.TENSORKROWCH) - fake_torch = ModuleType("torch") - fake_torch.float32 = object() - fake_torch.zeros = lambda shape, dtype=None: (shape, dtype) + fake_torch = _FakeTorchModule() fake_tensorkrowch = _FakeTensorKrowchModule() with patch.dict( @@ -1698,6 +1726,8 @@ def test_linear_periodic_carry_tensorkrowch_codegen_materializes_result_edges_wi None ): spec = build_linear_periodic_carry_chain_spec() + assert spec.linear_periodic_chain is not None + assert spec.linear_periodic_chain.periodic_cell.contraction_plan is not None spec.linear_periodic_chain.periodic_cell.contraction_plan.steps = [ ContractionStepSpec( id="periodic_contract_internal_first", @@ -1716,9 +1746,7 @@ def test_linear_periodic_carry_tensorkrowch_codegen_materializes_result_edges_wi ), ] result = generate_code(spec, engine=EngineName.TENSORKROWCH) - fake_torch = ModuleType("torch") - fake_torch.float32 = object() - fake_torch.zeros = lambda shape, dtype=None: (shape, dtype) + fake_torch = _FakeTorchModule() fake_tensorkrowch = _FakeTensorKrowchModule() with patch.dict( diff --git a/tests/test_protocol.py b/tests/test_protocol.py index f0daa36..cd64cec 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -88,6 +88,7 @@ def test_parse_codegen_request_uses_defaults_when_optional_fields_are_missing( serialized_spec=serialized_sample_spec, engine=EngineName.EINSUM_TORCH, collection_format=TensorCollectionFormat.DICT, + include_roundtrip_metadata=True, ) @@ -111,6 +112,7 @@ def test_parse_codegen_request_honors_explicit_engine_and_collection_format( serialized_spec=serialized_sample_spec, engine=EngineName.QUIMB, collection_format=TensorCollectionFormat.MATRIX, + include_roundtrip_metadata=True, ) From 87fdfaf746e96ea063ecf504d88a2a63a7385add Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 14:36:04 +0200 Subject: [PATCH 04/23] Stabilize app asset DOM ref assertions --- session.log.1 | 14 ++++++ session.log.2 | 10 ++++ tests/test_app_assets.py | 106 +++++++++++++++++++++++++++------------ 3 files changed, 99 insertions(+), 31 deletions(-) create mode 100644 session.log.1 create mode 100644 session.log.2 diff --git a/session.log.1 b/session.log.1 new file mode 100644 index 0000000..057d98a --- /dev/null +++ b/session.log.1 @@ -0,0 +1,14 @@ +2026-04-30 14:34:29,668 DEBUG tensor_network_editor: Runtime diagnostics: python=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\.venv\Scripts\python.exe cwd=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor package=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\src\tensor_network_editor version=0.5.0 current_checkout=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor editable_install=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor +2026-04-30 14:34:29,669 DEBUG tensor_network_editor.cli: Resolved CLI command arguments command=edit engine=tensorkrowch python_import_mode=static python_reconstruction_level=auto +2026-04-30 14:34:29,669 DEBUG tensor_network_editor.cli: CLI command finished command=edit outcome=success elapsed_ms=2 +2026-04-30 14:34:44,092 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=2048 backup_count=7 +2026-04-30 14:34:44,092 DEBUG tensor_network_editor.editor: Editor launch started engine=tensorkrowch mode=dark +2026-04-30 14:34:44,092 DEBUG tensor_network_editor.editor: Editor launch finished engine=tensorkrowch mode=dark outcome=success elapsed_ms=0 +2026-04-30 14:35:01,638 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=2048 backup_count=7 +2026-04-30 14:35:01,638 DEBUG tensor_network_editor.editor: Editor launch started engine=einsum_numpy mode=colorblind spec_mode=normal edge_count=1 group_count=1 hyperedge_count=0 note_count=1 tensor_count=2 +2026-04-30 14:35:01,638 DEBUG tensor_network_editor.editor: Editor launch finished engine=einsum_numpy mode=colorblind spec_mode=normal outcome=success elapsed_ms=0 edge_count=1 group_count=1 hyperedge_count=0 note_count=1 tensor_count=2 +2026-04-30 14:35:24,584 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=10485760 backup_count=5 +2026-04-30 14:35:24,584 DEBUG tensor_network_editor.cli: CLI command started command=edit +2026-04-30 14:35:24,585 DEBUG tensor_network_editor: Runtime diagnostics: python=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\.venv\Scripts\python.exe cwd=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor package=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\src\tensor_network_editor version=0.5.0 current_checkout=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor editable_install=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor +2026-04-30 14:35:24,585 DEBUG tensor_network_editor.cli: Resolved CLI command arguments command=edit engine=tensorkrowch python_import_mode=static python_reconstruction_level=auto +2026-04-30 14:35:24,585 DEBUG tensor_network_editor.cli: CLI command finished command=edit outcome=success elapsed_ms=1 diff --git a/session.log.2 b/session.log.2 new file mode 100644 index 0000000..33b0bd7 --- /dev/null +++ b/session.log.2 @@ -0,0 +1,10 @@ +2026-04-30 14:34:02,852 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=2048 backup_count=7 +2026-04-30 14:34:02,852 DEBUG tensor_network_editor.editor: Editor launch started engine=einsum_numpy mode=colorblind spec_mode=normal edge_count=1 group_count=1 hyperedge_count=0 note_count=1 tensor_count=2 +2026-04-30 14:34:02,852 DEBUG tensor_network_editor.editor: Editor launch finished engine=einsum_numpy mode=colorblind spec_mode=normal outcome=success elapsed_ms=0 edge_count=1 group_count=1 hyperedge_count=0 note_count=1 tensor_count=2 +2026-04-30 14:34:29,663 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=10485760 backup_count=5 +2026-04-30 14:34:29,664 DEBUG tensor_network_editor.cli: CLI command started command=edit +2026-04-30 14:34:29,665 DEBUG tensor_network_editor: Runtime diagnostics: python=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\.venv\Scripts\python.exe cwd=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor package=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\src\tensor_network_editor version=0.5.0 current_checkout=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor editable_install=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor +2026-04-30 14:34:29,665 DEBUG tensor_network_editor.cli: Resolved CLI command arguments command=edit engine=tensorkrowch python_import_mode=static python_reconstruction_level=auto +2026-04-30 14:34:29,665 DEBUG tensor_network_editor.cli: CLI command finished command=edit outcome=success elapsed_ms=1 +2026-04-30 14:34:29,667 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=2048 backup_count=7 +2026-04-30 14:34:29,668 DEBUG tensor_network_editor.cli: CLI command started command=edit diff --git a/tests/test_app_assets.py b/tests/test_app_assets.py index d4c384f..cafc22b 100644 --- a/tests/test_app_assets.py +++ b/tests/test_app_assets.py @@ -2469,6 +2469,7 @@ def test_template_management_assets_expose_toolbar_controls_and_routes( session_template_manager_body = request_text( f"{editor_server.base_url}/js/session/sessionTemplateManager.js" ) + dom_body_without_whitespace = "".join(dom_body.split()) assert re.search( r'', @@ -2635,64 +2636,107 @@ def test_template_management_assets_expose_toolbar_controls_and_routes( assert "Support on YouTube" in about_section.group("body") assert 'href="https://www.youtube.com/@whenphysics"' in about_section.group("body") assert ( - 'templateSettingsButton: document.getElementById("template-settings-button")' - in dom_body + 'templateSettingsButton: document.getElementById("template-settings-button")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'templateSettingsPopover: document.getElementById("template-settings-popover")' - in dom_body + 'templateSettingsPopover: document.getElementById("template-settings-popover")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'templateParameterPanel: document.getElementById("template-parameter-panel")' - in dom_body + 'templateParameterPanel: document.getElementById("template-parameter-panel")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'templateLoadInput: document.getElementById("template-load-input")' in dom_body + 'templateLoadInput: document.getElementById("template-load-input")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'reflowLayoutPopover: document.getElementById("reflow-layout-popover")' - in dom_body + 'reflowLayoutPopover: document.getElementById("reflow-layout-popover")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'reflowAlignHorizontalButton: document.getElementById("reflow-align-horizontal-button")' - in dom_body + 'reflowAlignHorizontalButton: document.getElementById("reflow-align-horizontal-button")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'reflowAlignVerticalButton: document.getElementById("reflow-align-vertical-button")' - in dom_body + 'reflowAlignVerticalButton: document.getElementById("reflow-align-vertical-button")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'reflowRotateSelectionButton: document.getElementById("reflow-rotate-selection-button")' - in dom_body + 'reflowRotateSelectionButton: document.getElementById("reflow-rotate-selection-button")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'reflowIndicesLeftButton: document.getElementById("reflow-indices-left-button")' - in dom_body + 'reflowIndicesLeftButton: document.getElementById("reflow-indices-left-button")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'reflowArrangeChainButton: document.getElementById("reflow-arrange-chain-button")' - in dom_body + 'reflowArrangeChainButton: document.getElementById("reflow-arrange-chain-button")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'reflowSnapGridButton: document.getElementById("reflow-snap-grid-button")' - in dom_body + 'reflowSnapGridButton: document.getElementById("reflow-snap-grid-button")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'aboutSchemaVersion: document.getElementById("about-schema-version")' - in dom_body + 'aboutSchemaVersion: document.getElementById("about-schema-version")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) assert ( - 'templateManagerModal: document.getElementById("template-manager-modal")' - in dom_body + 'templateManagerModal: document.getElementById("template-manager-modal")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) - assert "templateManagerSaveButton: document.getElementById(" in dom_body - assert "templateManagerCloseButton: document.getElementById(" in dom_body - assert "templateManagerDiscardButton: document.getElementById(" in dom_body assert ( - 'templateCatalogWarning: document.getElementById("template-catalog-warning")' - in dom_body + "templateManagerSaveButton: document.getElementById(".replace(" ", "") + in dom_body_without_whitespace + ) + assert ( + "templateManagerCloseButton: document.getElementById(".replace(" ", "") + in dom_body_without_whitespace + ) + assert ( + "templateManagerDiscardButton: document.getElementById(".replace(" ", "") + in dom_body_without_whitespace + ) + assert ( + 'templateCatalogWarning: document.getElementById("template-catalog-warning")'.replace( + " ", "" + ) + in dom_body_without_whitespace + ) + assert ( + 'helpSharedHeader: document.getElementById("help-shared-header")'.replace( + " ", "" + ) + in dom_body_without_whitespace ) - assert 'helpSharedHeader: document.getElementById("help-shared-header")' in dom_body assert ".help-about-grid {" in body assert "grid-template-columns: repeat(3, minmax(0, 1fr));" in body assert ".help-dialog-close {" in body From fbb79e14e15848b152f36dd18e87c0e137e56066 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 14:47:09 +0200 Subject: [PATCH 05/23] Removed logs --- session.log.1 | 14 -------------- session.log.2 | 10 ---------- 2 files changed, 24 deletions(-) delete mode 100644 session.log.1 delete mode 100644 session.log.2 diff --git a/session.log.1 b/session.log.1 deleted file mode 100644 index 057d98a..0000000 --- a/session.log.1 +++ /dev/null @@ -1,14 +0,0 @@ -2026-04-30 14:34:29,668 DEBUG tensor_network_editor: Runtime diagnostics: python=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\.venv\Scripts\python.exe cwd=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor package=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\src\tensor_network_editor version=0.5.0 current_checkout=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor editable_install=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor -2026-04-30 14:34:29,669 DEBUG tensor_network_editor.cli: Resolved CLI command arguments command=edit engine=tensorkrowch python_import_mode=static python_reconstruction_level=auto -2026-04-30 14:34:29,669 DEBUG tensor_network_editor.cli: CLI command finished command=edit outcome=success elapsed_ms=2 -2026-04-30 14:34:44,092 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=2048 backup_count=7 -2026-04-30 14:34:44,092 DEBUG tensor_network_editor.editor: Editor launch started engine=tensorkrowch mode=dark -2026-04-30 14:34:44,092 DEBUG tensor_network_editor.editor: Editor launch finished engine=tensorkrowch mode=dark outcome=success elapsed_ms=0 -2026-04-30 14:35:01,638 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=2048 backup_count=7 -2026-04-30 14:35:01,638 DEBUG tensor_network_editor.editor: Editor launch started engine=einsum_numpy mode=colorblind spec_mode=normal edge_count=1 group_count=1 hyperedge_count=0 note_count=1 tensor_count=2 -2026-04-30 14:35:01,638 DEBUG tensor_network_editor.editor: Editor launch finished engine=einsum_numpy mode=colorblind spec_mode=normal outcome=success elapsed_ms=0 edge_count=1 group_count=1 hyperedge_count=0 note_count=1 tensor_count=2 -2026-04-30 14:35:24,584 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=10485760 backup_count=5 -2026-04-30 14:35:24,584 DEBUG tensor_network_editor.cli: CLI command started command=edit -2026-04-30 14:35:24,585 DEBUG tensor_network_editor: Runtime diagnostics: python=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\.venv\Scripts\python.exe cwd=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor package=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\src\tensor_network_editor version=0.5.0 current_checkout=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor editable_install=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor -2026-04-30 14:35:24,585 DEBUG tensor_network_editor.cli: Resolved CLI command arguments command=edit engine=tensorkrowch python_import_mode=static python_reconstruction_level=auto -2026-04-30 14:35:24,585 DEBUG tensor_network_editor.cli: CLI command finished command=edit outcome=success elapsed_ms=1 diff --git a/session.log.2 b/session.log.2 deleted file mode 100644 index 33b0bd7..0000000 --- a/session.log.2 +++ /dev/null @@ -1,10 +0,0 @@ -2026-04-30 14:34:02,852 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=2048 backup_count=7 -2026-04-30 14:34:02,852 DEBUG tensor_network_editor.editor: Editor launch started engine=einsum_numpy mode=colorblind spec_mode=normal edge_count=1 group_count=1 hyperedge_count=0 note_count=1 tensor_count=2 -2026-04-30 14:34:02,852 DEBUG tensor_network_editor.editor: Editor launch finished engine=einsum_numpy mode=colorblind spec_mode=normal outcome=success elapsed_ms=0 edge_count=1 group_count=1 hyperedge_count=0 note_count=1 tensor_count=2 -2026-04-30 14:34:29,663 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=10485760 backup_count=5 -2026-04-30 14:34:29,664 DEBUG tensor_network_editor.cli: CLI command started command=edit -2026-04-30 14:34:29,665 DEBUG tensor_network_editor: Runtime diagnostics: python=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\.venv\Scripts\python.exe cwd=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor package=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\src\tensor_network_editor version=0.5.0 current_checkout=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor editable_install=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor -2026-04-30 14:34:29,665 DEBUG tensor_network_editor.cli: Resolved CLI command arguments command=edit engine=tensorkrowch python_import_mode=static python_reconstruction_level=auto -2026-04-30 14:34:29,665 DEBUG tensor_network_editor.cli: CLI command finished command=edit outcome=success elapsed_ms=1 -2026-04-30 14:34:29,667 INFO tensor_network_editor: Configured persistent file logging path=C:\Users\alejandro.mata\Documents\Tensor-Network-Editor\session.log max_bytes=2048 backup_count=7 -2026-04-30 14:34:29,668 DEBUG tensor_network_editor.cli: CLI command started command=edit From ef0606dc2433b3f6c3d264cdc4e333a8a4498240 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 14:58:53 +0200 Subject: [PATCH 06/23] Removed dead code --- CHANGELOG.md | 7 + .../codegen/backends/einsum.py | 15 -- .../internal/_logging.py | 29 +--- .../internal/modes/_grid_periodic.py | 8 - .../internal/modes/_tree_periodic.py | 8 - .../internal/templates/_template_catalog.py | 145 +++++++----------- src/tensor_network_editor/rendering.py | 22 --- 7 files changed, 66 insertions(+), 168 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f740f21..d22cc68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Changed + +- Removed a few unused internal helpers from logging, periodic-mode utilities, + rendering, and einsum code generation, and deduplicated built-in template + defaults so the catalog now keeps each template's default parameters in one + shared definition. + ## [0.5.0] - 2026-04-30 ### Changed diff --git a/src/tensor_network_editor/codegen/backends/einsum.py b/src/tensor_network_editor/codegen/backends/einsum.py index 39a3b69..98063c4 100644 --- a/src/tensor_network_editor/codegen/backends/einsum.py +++ b/src/tensor_network_editor/codegen/backends/einsum.py @@ -27,7 +27,6 @@ from ..shared.common import ( CodeSection, PreparedNetwork, - PreparedTensor, container_name_for_format, prepare_network, render_code_sections, @@ -704,17 +703,3 @@ def _render_remaining_label_sequence( if use_string_labels: return [symbol_map[label] for label in labels] return [f"label_{label_to_int[label]}" for label in labels] - - @staticmethod - def _build_equation( - tensors: list[PreparedTensor], - output_labels: list[str], - symbol_map: dict[str, str], - ) -> str: - """Build a standard einsum equation string for the prepared tensors.""" - input_terms = [ - "".join(symbol_map[index.label] for index in tensor.indices) - for tensor in tensors - ] - output_term = "".join(symbol_map[label] for label in output_labels) - return ",".join(input_terms) + "->" + output_term diff --git a/src/tensor_network_editor/internal/_logging.py b/src/tensor_network_editor/internal/_logging.py index 54d08ed..f0f28f0 100644 --- a/src/tensor_network_editor/internal/_logging.py +++ b/src/tensor_network_editor/internal/_logging.py @@ -116,13 +116,6 @@ def frontend_payload(self) -> dict[str, object]: } -class ContextFormatter(logging.Formatter): - """Formatter used by package handlers.""" - - def format(self, record: logging.LogRecord) -> str: - return super().format(record) - - @contextmanager def package_logging_scope( requested_level_name: str | None, @@ -401,24 +394,6 @@ def validate_positive_log_setting(value: int, *, name: str) -> int: return value -def _merge_context_fields(record: logging.LogRecord) -> dict[str, str]: - merged: dict[str, str] = dict(_LOG_CONTEXT.get()) - record_context = getattr(record, "tne_context", None) - if isinstance(record_context, dict): - for key, value in record_context.items(): - normalized = _normalize_context_value(value) - if normalized is not None: - merged[key] = normalized - ordered: dict[str, str] = {} - for key in _CONTEXT_FIELD_ORDER: - value = merged.pop(key, None) - if isinstance(value, str): - ordered[key] = value - for key in sorted(merged): - ordered[key] = merged[key] - return ordered - - def _resolve_effective_log_level_name( requested_level_name: str | None, *, @@ -443,7 +418,7 @@ def _create_stream_handler(level_name: str) -> logging.Handler: stream_handler = logging.StreamHandler() stream_handler.name = _STREAM_HANDLER_NAME stream_handler.setLevel(LOG_LEVEL_VALUES[level_name]) - stream_handler.setFormatter(ContextFormatter(_STREAM_FORMAT)) + stream_handler.setFormatter(logging.Formatter(_STREAM_FORMAT)) return stream_handler @@ -462,7 +437,7 @@ def _create_file_handler( ) file_handler.name = f"{_FILE_HANDLER_NAME_PREFIX}{path}" file_handler.setLevel(LOG_LEVEL_VALUES[level_name]) - file_handler.setFormatter(ContextFormatter(_FILE_FORMAT)) + file_handler.setFormatter(logging.Formatter(_FILE_FORMAT)) return file_handler diff --git a/src/tensor_network_editor/internal/modes/_grid_periodic.py b/src/tensor_network_editor/internal/modes/_grid_periodic.py index 51644c3..afa41db 100644 --- a/src/tensor_network_editor/internal/modes/_grid_periodic.py +++ b/src/tensor_network_editor/internal/modes/_grid_periodic.py @@ -29,9 +29,6 @@ GridPeriodicTensorRole.DOWN: GRID_PERIODIC_DOWN_OPERAND_ID, GridPeriodicTensorRole.LEFT: GRID_PERIODIC_LEFT_OPERAND_ID, } -GRID_PERIODIC_RESERVED_OPERAND_IDS = frozenset( - GRID_PERIODIC_RESERVED_OPERAND_ID_BY_ROLE.values() -) @dataclass(slots=True, frozen=True) @@ -319,11 +316,6 @@ def grid_periodic_reserved_operand_id_for_role( return GRID_PERIODIC_RESERVED_OPERAND_ID_BY_ROLE[role] -def is_grid_periodic_reserved_operand_id(operand_id: str) -> bool: - """Return ``True`` when ``operand_id`` is a reserved 2D boundary operand.""" - return operand_id in GRID_PERIODIC_RESERVED_OPERAND_IDS - - def _analysis_tensor_id(tensor: TensorSpec) -> str: """Return the analysis operand id for a grid-periodic cell tensor.""" if tensor.grid_periodic_role is None: diff --git a/src/tensor_network_editor/internal/modes/_tree_periodic.py b/src/tensor_network_editor/internal/modes/_tree_periodic.py index ed7aa41..845217a 100644 --- a/src/tensor_network_editor/internal/modes/_tree_periodic.py +++ b/src/tensor_network_editor/internal/modes/_tree_periodic.py @@ -329,14 +329,6 @@ def tree_periodic_reserved_operand_id_for_tensor(tensor: TensorSpec) -> str | No return None -def is_tree_periodic_reserved_operand_id(operand_id: str) -> bool: - """Return ``True`` when ``operand_id`` is a reserved tree boundary operand.""" - return operand_id == TREE_PERIODIC_PARENT_OPERAND_ID or ( - operand_id.startswith(TREE_PERIODIC_CHILD_OPERAND_ID_PREFIX) - and operand_id.endswith("__") - ) - - def _analysis_tensor_id(tensor: TensorSpec) -> str: """Return the analysis operand id for a tree-periodic cell tensor.""" return tree_periodic_reserved_operand_id_for_tensor(tensor) or tensor.id diff --git a/src/tensor_network_editor/internal/templates/_template_catalog.py b/src/tensor_network_editor/internal/templates/_template_catalog.py index 9887bb7..3ecf795 100644 --- a/src/tensor_network_editor/internal/templates/_template_catalog.py +++ b/src/tensor_network_editor/internal/templates/_template_catalog.py @@ -330,144 +330,113 @@ def _serialize_template_parameter_payload( return defaults, minimums +_MPS_DEFAULTS = TemplateParameters( + graph_size=4, + bond_dimension=3, + physical_dimension=2, +) +_MPO_DEFAULTS = TemplateParameters( + graph_size=4, + bond_dimension=3, + physical_dimension=2, + boundary_condition="open", + j=1.0, + h=1.0, +) +_PEPS_DEFAULTS = TemplateParameters( + graph_size=3, + bond_dimension=3, + physical_dimension=2, +) +_MERA_DEFAULTS = TemplateParameters( + graph_size=3, + bond_dimension=3, + physical_dimension=2, +) +_TTN_DEFAULTS = TemplateParameters( + depth=3, + bond_dimension=3, + physical_dimension=2, + leaf_physical_legs=True, + root_open_leg=False, + isometric=False, +) +_PEPO_DEFAULTS = TemplateParameters( + graph_size=3, + bond_dimension=3, + physical_dimension=2, +) +_TEBD_GATE_LAYER_DEFAULTS = TemplateParameters( + graph_size=4, + bond_dimension=3, + physical_dimension=2, +) + + TEMPLATE_DEFINITIONS: dict[str, TemplateDefinition] = { "mps": TemplateDefinition( name="mps", display_name="MPS", graph_size_label="Sites", - defaults=TemplateParameters( - graph_size=4, - bond_dimension=3, - physical_dimension=2, - ), - parameter_fields=_build_mps_parameter_fields( - TemplateParameters( - graph_size=4, - bond_dimension=3, - physical_dimension=2, - ) - ), + defaults=_MPS_DEFAULTS, + parameter_fields=_build_mps_parameter_fields(_MPS_DEFAULTS), ), "mpo": TemplateDefinition( name="mpo", display_name="MPO", graph_size_label="Sites", - defaults=TemplateParameters( - graph_size=4, - bond_dimension=3, - physical_dimension=2, - boundary_condition="open", - j=1.0, - h=1.0, - ), - parameter_fields=_build_mpo_parameter_fields( - TemplateParameters( - graph_size=4, - bond_dimension=3, - physical_dimension=2, - boundary_condition="open", - j=1.0, - h=1.0, - ) - ), + defaults=_MPO_DEFAULTS, + parameter_fields=_build_mpo_parameter_fields(_MPO_DEFAULTS), ), "peps_2x2": TemplateDefinition( name="peps_2x2", display_name="PEPS", graph_size_label="Side length", - defaults=TemplateParameters( - graph_size=3, - bond_dimension=3, - physical_dimension=2, - ), + defaults=_PEPS_DEFAULTS, parameter_fields=_build_standard_parameter_fields( size_field_name="graph_size", size_field_label="Graph size (Side length)", - defaults=TemplateParameters( - graph_size=3, - bond_dimension=3, - physical_dimension=2, - ), + defaults=_PEPS_DEFAULTS, ), ), "mera": TemplateDefinition( name="mera", display_name="MERA", graph_size_label="Depth", - defaults=TemplateParameters( - graph_size=3, - bond_dimension=3, - physical_dimension=2, - ), + defaults=_MERA_DEFAULTS, parameter_fields=_build_standard_parameter_fields( size_field_name="graph_size", size_field_label="Graph size (Depth)", - defaults=TemplateParameters( - graph_size=3, - bond_dimension=3, - physical_dimension=2, - ), + defaults=_MERA_DEFAULTS, ), ), "ttn": TemplateDefinition( name="ttn", display_name="TTN", graph_size_label="Depth", - defaults=TemplateParameters( - depth=3, - bond_dimension=3, - physical_dimension=2, - leaf_physical_legs=True, - root_open_leg=False, - isometric=False, - ), - parameter_fields=_build_ttn_parameter_fields( - TemplateParameters( - depth=3, - bond_dimension=3, - physical_dimension=2, - leaf_physical_legs=True, - root_open_leg=False, - isometric=False, - ) - ), + defaults=_TTN_DEFAULTS, + parameter_fields=_build_ttn_parameter_fields(_TTN_DEFAULTS), ), "pepo": TemplateDefinition( name="pepo", display_name="PEPO", graph_size_label="Side length", - defaults=TemplateParameters( - graph_size=3, - bond_dimension=3, - physical_dimension=2, - ), + defaults=_PEPO_DEFAULTS, parameter_fields=_build_standard_parameter_fields( size_field_name="graph_size", size_field_label="Graph size (Side length)", - defaults=TemplateParameters( - graph_size=3, - bond_dimension=3, - physical_dimension=2, - ), + defaults=_PEPO_DEFAULTS, ), ), "tebd_gate_layer": TemplateDefinition( name="tebd_gate_layer", display_name="TEBD Gate Layer", graph_size_label="Sites", - defaults=TemplateParameters( - graph_size=4, - bond_dimension=3, - physical_dimension=2, - ), + defaults=_TEBD_GATE_LAYER_DEFAULTS, parameter_fields=_build_standard_parameter_fields( size_field_name="graph_size", size_field_label="Graph size (Sites)", - defaults=TemplateParameters( - graph_size=4, - bond_dimension=3, - physical_dimension=2, - ), + defaults=_TEBD_GATE_LAYER_DEFAULTS, ), ), } diff --git a/src/tensor_network_editor/rendering.py b/src/tensor_network_editor/rendering.py index 7e4adb3..de5f2c0 100644 --- a/src/tensor_network_editor/rendering.py +++ b/src/tensor_network_editor/rendering.py @@ -25,7 +25,6 @@ from .types import StrPath from .validation import ensure_valid_spec -_INDEX_RADIUS = 10.0 _GROUP_PADDING = 28.0 _NOTE_WIDTH = 210.0 _NOTE_HEIGHT = 82.0 @@ -2356,27 +2355,6 @@ def _point_to_segment_distance( return hypot(point.x - closest_point.x, point.y - closest_point.y) -def _sample_quadratic_points( - source: CanvasPosition, - control: CanvasPosition, - target: CanvasPosition, - *, - segment_count: int = 24, -) -> list[CanvasPosition]: - return [ - CanvasPosition( - x=((1 - t) ** 2) * source.x - + 2 * (1 - t) * t * control.x - + (t**2) * target.x, - y=((1 - t) ** 2) * source.y - + 2 * (1 - t) * t * control.y - + (t**2) * target.y, - ) - for step in range(segment_count + 1) - for t in [step / segment_count] - ] - - def _wrap_text(text: str, *, max_chars: int) -> list[str]: words = text.split() if not words: From 8047c4223b7c0d631011fb961f42363d3f88aa20 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 16:58:02 +0200 Subject: [PATCH 07/23] Clean up periodic codegen internals --- CHANGELOG.md | 3 + .../modes/_grid_periodic/array_einsum.py | 63 ++++------- .../modes/_grid_periodic/array_quimb.py | 70 +++++------- .../modes/_grid_periodic/array_shared.py | 83 +++++++++++++++ .../codegen/modes/_periodic_codegen.py | 57 ++++++++++ .../modes/_tree_periodic/array_einsum.py | 80 ++++---------- .../modes/_tree_periodic/array_quimb.py | 84 +++++---------- .../modes/_tree_periodic/array_shared.py | 100 ++++++++++++++++++ .../codegen/modes/grid_periodic.py | 58 +++++----- .../codegen/modes/linear_periodic.py | 63 +++++------ .../codegen/modes/tree_periodic.py | 58 +++++----- tests/codegen/test_common.py | 70 ++++++++++++ tests/codegen/test_grid_periodic_internals.py | 37 ++++++- tests/codegen/test_tree_periodic_internals.py | 33 ++++++ 14 files changed, 564 insertions(+), 295 deletions(-) create mode 100644 src/tensor_network_editor/codegen/modes/_grid_periodic/array_shared.py create mode 100644 src/tensor_network_editor/codegen/modes/_periodic_codegen.py create mode 100644 src/tensor_network_editor/codegen/modes/_tree_periodic/array_shared.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d22cc68..5b11e2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ All notable changes to this project will be documented in this file. rendering, and einsum code generation, and deduplicated built-in template defaults so the catalog now keeps each template's default parameters in one shared definition. +- Periodic code generation now routes linear, grid, and tree modes through one + shared internal dispatcher, and the grid/tree array helpers reuse shared cell + preparation utilities instead of repeating the same setup in each backend. ## [0.5.0] - 2026-04-30 diff --git a/src/tensor_network_editor/codegen/modes/_grid_periodic/array_einsum.py b/src/tensor_network_editor/codegen/modes/_grid_periodic/array_einsum.py index b4fcb9a..2ca7d29 100644 --- a/src/tensor_network_editor/codegen/modes/_grid_periodic/array_einsum.py +++ b/src/tensor_network_editor/codegen/modes/_grid_periodic/array_einsum.py @@ -2,10 +2,7 @@ from __future__ import annotations -from ....internal.modes._grid_periodic import ( - GridPeriodicInterfacePort, - build_internal_grid_periodic_cell_network, -) +from ....internal.modes._grid_periodic import GridPeriodicInterfacePort from ....models import ( EngineName, GridPeriodicCellName, @@ -17,21 +14,19 @@ from ...shared.common import ( CodeSection, PreparedNetwork, - container_name_for_format, - prepare_network, - render_tensor_collection_assignment, - render_tensor_collection_initialization, tensor_collection_reference_by_id, ) from .array_helpers import ( _GRID_CELL_KIND_OFFSET, _build_interface_slot_by_label, _build_local_label_offsets, - _build_ports_by_role, _einsum_interface_expression, _runtime_cell_coordinate_expressions, ) -from .common import _cell_from_grid +from .array_shared import ( + build_grid_array_cell_context, + render_grid_array_tensor_sections, +) from .shared import _RenderedCellHelper, render_grid_periodic_helper @@ -45,22 +40,15 @@ def _render_einsum_cell_helper( collection_format: TensorCollectionFormat, ) -> _RenderedCellHelper: """Render one grid cell helper for an einsum backend.""" - cell = _cell_from_grid(grid, cell_name) - internal_spec = build_internal_grid_periodic_cell_network( - cell, + context = build_grid_array_cell_context( + grid=grid, cell_name=cell_name, - include_contraction_plan=False, + collection_format=collection_format, ) - prepared = prepare_network(internal_spec) - collection_name = container_name_for_format(collection_format) - ports_by_role = _build_ports_by_role(cell=cell, cell_name=cell_name) - interface_index_ids = { - port.internal_index_id for ports in ports_by_role.values() for port in ports - } label_expression_by_label = _build_einsum_label_expression_map( - prepared=prepared, + prepared=context.prepared, cell_name=cell_name, - ports_by_role=ports_by_role, + ports_by_role=context.ports_by_role, ) module_alias = "np" if engine is EngineName.EINSUM_NUMPY else "torch" zero_suffix = ( @@ -68,29 +56,24 @@ def _render_einsum_cell_helper( if engine is EngineName.EINSUM_TORCH else ", dtype=float)" ) - tensor_collection_lines = render_tensor_collection_initialization( - collection_name, - collection_format, - ) - tensor_construction_lines = render_tensor_collection_assignment( - collection_name=collection_name, - collection_format=collection_format, - prepared=prepared, - tensor_value_by_id={ - tensor.spec.id: f"{module_alias}.zeros({tensor.spec.shape!r}{zero_suffix}" - for tensor in prepared.tensors - }, - include_initialization=False, + tensor_collection_lines, tensor_construction_lines = ( + render_grid_array_tensor_sections( + context=context, + tensor_value_by_id={ + tensor.spec.id: f"{module_alias}.zeros({tensor.spec.shape!r}{zero_suffix}" + for tensor in context.prepared.tensors + }, + ) ) output_lines = ["cell_operands = []", "cell_operand_labels = []"] - for tensor in prepared.tensors: + for tensor in context.prepared.tensors: output_lines.append( "cell_operands.append(" + tensor_collection_reference_by_id( - prepared, + context.prepared, tensor.spec.id, collection_format, - collection_name, + context.collection_name, ) + ")" ) @@ -107,8 +90,8 @@ def _render_einsum_cell_helper( + _render_python_list_expression( [ label_expression_by_label[index.label] - for index in prepared.open_indices - if index.spec.id not in interface_index_ids + for index in context.prepared.open_indices + if index.spec.id not in context.interface_index_ids ] ), "return {", diff --git a/src/tensor_network_editor/codegen/modes/_grid_periodic/array_quimb.py b/src/tensor_network_editor/codegen/modes/_grid_periodic/array_quimb.py index 17cb37a..b434fc1 100644 --- a/src/tensor_network_editor/codegen/modes/_grid_periodic/array_quimb.py +++ b/src/tensor_network_editor/codegen/modes/_grid_periodic/array_quimb.py @@ -2,10 +2,7 @@ from __future__ import annotations -from ....internal.modes._grid_periodic import ( - GridPeriodicInterfacePort, - build_internal_grid_periodic_cell_network, -) +from ....internal.modes._grid_periodic import GridPeriodicInterfacePort from ....models import ( GridPeriodicCellName, GridPeriodicGridSpec, @@ -16,20 +13,18 @@ from ...shared.common import ( CodeSection, PreparedNetwork, - container_name_for_format, flattened_tensor_collection_expression, - prepare_network, - render_tensor_collection_assignment, - render_tensor_collection_initialization, ) from .array_helpers import ( _build_interface_slot_by_label, _build_local_label_offsets, - _build_ports_by_role, _quimb_interface_expression, _runtime_cell_coordinate_expressions, ) -from .common import _cell_from_grid +from .array_shared import ( + build_grid_array_cell_context, + render_grid_array_tensor_sections, +) from .shared import _RenderedCellHelper, render_grid_periodic_helper @@ -42,50 +37,41 @@ def _render_quimb_cell_helper( collection_format: TensorCollectionFormat, ) -> _RenderedCellHelper: """Render one grid cell helper for the ``quimb`` backend.""" - cell = _cell_from_grid(grid, cell_name) - internal_spec = build_internal_grid_periodic_cell_network( - cell, + context = build_grid_array_cell_context( + grid=grid, cell_name=cell_name, - include_contraction_plan=False, + collection_format=collection_format, ) - prepared = prepare_network(internal_spec) - collection_name = container_name_for_format(collection_format) - ports_by_role = _build_ports_by_role(cell=cell, cell_name=cell_name) - interface_index_ids = { - port.internal_index_id for ports in ports_by_role.values() for port in ports - } label_expression_by_label = _build_quimb_label_expression_map( - prepared=prepared, + prepared=context.prepared, cell_name=cell_name, - ports_by_role=ports_by_role, - ) - tensor_collection_lines = render_tensor_collection_initialization( - collection_name, - collection_format, + ports_by_role=context.ports_by_role, ) - tensor_construction_lines = render_tensor_collection_assignment( - collection_name=collection_name, - collection_format=collection_format, - prepared=prepared, - tensor_value_by_id={ - tensor.spec.id: ( - f"qtn.Tensor(data=np.zeros({tensor.spec.shape!r}, dtype=float), " - f"inds={_render_python_tuple_expression([label_expression_by_label[index.label] for index in tensor.indices])}, " - f"tags={(tensor.spec.name,)!r})" - ) - for tensor in prepared.tensors - }, - include_initialization=False, + tensor_collection_lines, tensor_construction_lines = ( + render_grid_array_tensor_sections( + context=context, + tensor_value_by_id={ + tensor.spec.id: ( + f"qtn.Tensor(data=np.zeros({tensor.spec.shape!r}, dtype=float), " + f"inds={_render_python_tuple_expression([label_expression_by_label[index.label] for index in tensor.indices])}, " + f"tags={(tensor.spec.name,)!r})" + ) + for tensor in context.prepared.tensors + }, + ) ) output_lines = [ "cell_tensors = " - + flattened_tensor_collection_expression(collection_format, collection_name), + + flattened_tensor_collection_expression( + collection_format, + context.collection_name, + ), "open_inds = " + _render_python_tuple_expression( [ label_expression_by_label[index.label] - for index in prepared.open_indices - if index.spec.id not in interface_index_ids + for index in context.prepared.open_indices + if index.spec.id not in context.interface_index_ids ] ), "return {", diff --git a/src/tensor_network_editor/codegen/modes/_grid_periodic/array_shared.py b/src/tensor_network_editor/codegen/modes/_grid_periodic/array_shared.py new file mode 100644 index 0000000..3a64f74 --- /dev/null +++ b/src/tensor_network_editor/codegen/modes/_grid_periodic/array_shared.py @@ -0,0 +1,83 @@ +"""Shared cell-preparation helpers for grid-periodic array codegen.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from ....internal.modes._grid_periodic import ( + GridPeriodicInterfacePort, + build_internal_grid_periodic_cell_network, +) +from ....models import ( + GridPeriodicCellName, + GridPeriodicGridSpec, + GridPeriodicTensorRole, + TensorCollectionFormat, +) +from ...shared.common import ( + PreparedNetwork, + container_name_for_format, + prepare_network, + render_tensor_collection_assignment, + render_tensor_collection_initialization, +) +from .array_helpers import _build_ports_by_role +from .common import _cell_from_grid + + +@dataclass(slots=True, frozen=True) +class GridArrayCellContext: + """Prepared render context shared by grid-periodic array helper builders.""" + + prepared: PreparedNetwork + collection_format: TensorCollectionFormat + collection_name: str + ports_by_role: dict[GridPeriodicTensorRole, tuple[GridPeriodicInterfacePort, ...]] + interface_index_ids: frozenset[str] + + +def build_grid_array_cell_context( + *, + grid: GridPeriodicGridSpec, + cell_name: GridPeriodicCellName, + collection_format: TensorCollectionFormat, +) -> GridArrayCellContext: + """Build the shared prepared context for one array-backed grid cell helper.""" + cell = _cell_from_grid(grid, cell_name) + prepared = prepare_network( + build_internal_grid_periodic_cell_network( + cell, + cell_name=cell_name, + include_contraction_plan=False, + ) + ) + ports_by_role = _build_ports_by_role(cell=cell, cell_name=cell_name) + return GridArrayCellContext( + prepared=prepared, + collection_format=collection_format, + collection_name=container_name_for_format(collection_format), + ports_by_role=ports_by_role, + interface_index_ids=frozenset( + port.internal_index_id for ports in ports_by_role.values() for port in ports + ), + ) + + +def render_grid_array_tensor_sections( + *, + context: GridArrayCellContext, + tensor_value_by_id: dict[str, str], +) -> tuple[list[str], list[str]]: + """Render the shared tensor collection sections for one grid cell helper.""" + tensor_collection_lines = render_tensor_collection_initialization( + context.collection_name, + context.collection_format, + ) + tensor_construction_lines = render_tensor_collection_assignment( + collection_name=context.collection_name, + collection_format=context.collection_format, + prepared=context.prepared, + tensor_value_by_id=tensor_value_by_id, + include_initialization=False, + ) + return tensor_collection_lines, tensor_construction_lines diff --git a/src/tensor_network_editor/codegen/modes/_periodic_codegen.py b/src/tensor_network_editor/codegen/modes/_periodic_codegen.py new file mode 100644 index 0000000..d78cf97 --- /dev/null +++ b/src/tensor_network_editor/codegen/modes/_periodic_codegen.py @@ -0,0 +1,57 @@ +"""Shared dispatch helpers for periodic code-generation entrypoints.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TypeVar + +from ...errors import CodeGenerationError +from ...models import CodegenResult, EngineName, NetworkSpec +from ..shared.roundtrip import with_roundtrip_spec_marker + +_PeriodicPayloadT = TypeVar("_PeriodicPayloadT") +_ARRAY_ENGINES: frozenset[EngineName] = frozenset( + { + EngineName.QUIMB, + EngineName.EINSUM_NUMPY, + EngineName.EINSUM_TORCH, + } +) +_GRAPH_ENGINES: frozenset[EngineName] = frozenset( + { + EngineName.TENSORNETWORK, + EngineName.TENSORKROWCH, + } +) + + +def dispatch_periodic_codegen( + *, + spec: NetworkSpec, + payload: _PeriodicPayloadT | None, + missing_payload_message: str, + unsupported_backend_label: str, + engine: EngineName, + include_roundtrip_metadata: bool, + array_renderer: Callable[[_PeriodicPayloadT], CodegenResult], + graph_renderer: Callable[[_PeriodicPayloadT], CodegenResult], +) -> CodegenResult: + """Route periodic code generation to the array or graph backend family.""" + if payload is None: + raise CodeGenerationError(missing_payload_message) + + if engine in _ARRAY_ENGINES: + result = array_renderer(payload) + elif engine in _GRAPH_ENGINES: + result = graph_renderer(payload) + else: + raise CodeGenerationError( + f"The {engine.value} backend does not support " + f"{unsupported_backend_label} code generation." + ) + + return ( + with_roundtrip_spec_marker(result, spec=spec) + if include_roundtrip_metadata + else result + ) diff --git a/src/tensor_network_editor/codegen/modes/_tree_periodic/array_einsum.py b/src/tensor_network_editor/codegen/modes/_tree_periodic/array_einsum.py index 679a9cb..759bcf2 100644 --- a/src/tensor_network_editor/codegen/modes/_tree_periodic/array_einsum.py +++ b/src/tensor_network_editor/codegen/modes/_tree_periodic/array_einsum.py @@ -2,35 +2,26 @@ from __future__ import annotations -from ....internal.modes._tree_periodic import ( - build_internal_tree_periodic_cell_network, - build_tree_periodic_interface_ports, -) from ....models import ( EngineName, TensorCollectionFormat, TreePeriodicCellName, - TreePeriodicTensorRole, TreePeriodicTreeSpec, ) from ...shared._linear_periodic_expressions import _render_python_list_expression from ...shared.common import ( CodeSection, - container_name_for_format, flattened_tensor_collection_expression, - prepare_network, - render_tensor_collection_assignment, - render_tensor_collection_initialization, ) from .array_helpers import ( _build_einsum_label_expression_map, _render_child_interface_lines, ) -from .common import ( - _build_child_ports_by_index, - _cell_from_tree, - _render_parent_interface_validation, +from .array_shared import ( + build_tree_array_cell_context, + render_tree_array_tensor_sections, ) +from .common import _render_parent_interface_validation from .shared import _RenderedTreeCellHelper, render_tree_periodic_helper @@ -44,63 +35,38 @@ def _render_einsum_cell_helper( collection_format: TensorCollectionFormat, ) -> _RenderedTreeCellHelper: """Render one tree cell helper for an einsum backend.""" - cell = _cell_from_tree(tree, cell_name) - prepared = prepare_network( - build_internal_tree_periodic_cell_network( - cell, - cell_name=cell_name, - include_contraction_plan=False, - ) - ) - collection_name = container_name_for_format(collection_format) - parent_ports = build_tree_periodic_interface_ports( - cell, - cell_name=cell_name, - role=TreePeriodicTensorRole.PARENT, - ) - child_ports_by_index = _build_child_ports_by_index( + context = build_tree_array_cell_context( tree=tree, - cell=cell, cell_name=cell_name, + collection_format=collection_format, ) label_expression_by_label = _build_einsum_label_expression_map( - prepared=prepared, + prepared=context.prepared, cell_name=cell_name, - parent_ports=parent_ports, - child_ports_by_index=child_ports_by_index, + parent_ports=context.parent_ports, + child_ports_by_index=context.child_ports_by_index, ) - interface_index_ids = {port.internal_index_id for port in parent_ports} | { - port.internal_index_id - for ports in child_ports_by_index.values() - for port in ports - } module_alias = "np" if engine is EngineName.EINSUM_NUMPY else "torch" zero_suffix = ( ", dtype=torch.float32)" if engine is EngineName.EINSUM_TORCH else ", dtype=float)" ) - tensor_collection_lines = render_tensor_collection_initialization( - collection_name, - collection_format, - ) - tensor_value_by_id = { - tensor.spec.id: f"{module_alias}.zeros({tensor.spec.shape!r}{zero_suffix}" - for tensor in prepared.tensors - } - tensor_construction_lines = render_tensor_collection_assignment( - collection_name=collection_name, - collection_format=collection_format, - prepared=prepared, - tensor_value_by_id=tensor_value_by_id, - include_initialization=False, + tensor_collection_lines, tensor_construction_lines = ( + render_tree_array_tensor_sections( + context=context, + tensor_value_by_id={ + tensor.spec.id: f"{module_alias}.zeros({tensor.spec.shape!r}{zero_suffix}" + for tensor in context.prepared.tensors + }, + ) ) - output_lines = _render_parent_interface_validation(parent_ports) + output_lines = _render_parent_interface_validation(context.parent_ports) output_lines.extend( [ "cell_operands = " + flattened_tensor_collection_expression( - collection_format, collection_name + collection_format, context.collection_name ), "cell_operand_labels = [", *[ @@ -109,19 +75,19 @@ def _render_einsum_cell_helper( [label_expression_by_label[index.label] for index in tensor.indices] ) + "," - for tensor in prepared.tensors + for tensor in context.prepared.tensors ], "]", *_render_child_interface_lines( cell_name=cell_name, - child_ports_by_index=child_ports_by_index, + child_ports_by_index=context.child_ports_by_index, ), "open_labels = " + _render_python_list_expression( [ label_expression_by_label[index.label] - for index in prepared.open_indices - if index.spec.id not in interface_index_ids + for index in context.prepared.open_indices + if index.spec.id not in context.interface_index_ids ] ), "return {", diff --git a/src/tensor_network_editor/codegen/modes/_tree_periodic/array_quimb.py b/src/tensor_network_editor/codegen/modes/_tree_periodic/array_quimb.py index 62f8ca9..8633ce6 100644 --- a/src/tensor_network_editor/codegen/modes/_tree_periodic/array_quimb.py +++ b/src/tensor_network_editor/codegen/modes/_tree_periodic/array_quimb.py @@ -2,34 +2,25 @@ from __future__ import annotations -from ....internal.modes._tree_periodic import ( - build_internal_tree_periodic_cell_network, - build_tree_periodic_interface_ports, -) from ....models import ( TensorCollectionFormat, TreePeriodicCellName, - TreePeriodicTensorRole, TreePeriodicTreeSpec, ) from ...shared._linear_periodic_expressions import _render_python_tuple_expression from ...shared.common import ( CodeSection, - container_name_for_format, flattened_tensor_collection_expression, - prepare_network, - render_tensor_collection_assignment, - render_tensor_collection_initialization, ) from .array_helpers import ( _build_quimb_label_expression_map, _render_child_interface_lines, ) -from .common import ( - _build_child_ports_by_index, - _cell_from_tree, - _render_parent_interface_validation, +from .array_shared import ( + build_tree_array_cell_context, + render_tree_array_tensor_sections, ) +from .common import _render_parent_interface_validation from .shared import _RenderedTreeCellHelper, render_tree_periodic_helper @@ -42,72 +33,47 @@ def _render_quimb_cell_helper( collection_format: TensorCollectionFormat, ) -> _RenderedTreeCellHelper: """Render one tree cell helper for the ``quimb`` backend.""" - cell = _cell_from_tree(tree, cell_name) - prepared = prepare_network( - build_internal_tree_periodic_cell_network( - cell, - cell_name=cell_name, - include_contraction_plan=False, - ) - ) - collection_name = container_name_for_format(collection_format) - parent_ports = build_tree_periodic_interface_ports( - cell, - cell_name=cell_name, - role=TreePeriodicTensorRole.PARENT, - ) - child_ports_by_index = _build_child_ports_by_index( + context = build_tree_array_cell_context( tree=tree, - cell=cell, cell_name=cell_name, + collection_format=collection_format, ) label_expression_by_label = _build_quimb_label_expression_map( - prepared=prepared, + prepared=context.prepared, cell_name=cell_name, - parent_ports=parent_ports, - child_ports_by_index=child_ports_by_index, + parent_ports=context.parent_ports, + child_ports_by_index=context.child_ports_by_index, ) - interface_index_ids = {port.internal_index_id for port in parent_ports} | { - port.internal_index_id - for ports in child_ports_by_index.values() - for port in ports - } - tensor_collection_lines = render_tensor_collection_initialization( - collection_name, - collection_format, - ) - tensor_value_by_id = { - tensor.spec.id: ( - f"qtn.Tensor(data=np.zeros({tensor.spec.shape!r}, dtype=float), " - f"inds={_render_python_tuple_expression([label_expression_by_label[index.label] for index in tensor.indices])}, " - f"tags={(tensor.spec.name, tensor.spec.id)!r})" + tensor_collection_lines, tensor_construction_lines = ( + render_tree_array_tensor_sections( + context=context, + tensor_value_by_id={ + tensor.spec.id: ( + f"qtn.Tensor(data=np.zeros({tensor.spec.shape!r}, dtype=float), " + f"inds={_render_python_tuple_expression([label_expression_by_label[index.label] for index in tensor.indices])}, " + f"tags={(tensor.spec.name, tensor.spec.id)!r})" + ) + for tensor in context.prepared.tensors + }, ) - for tensor in prepared.tensors - } - tensor_construction_lines = render_tensor_collection_assignment( - collection_name=collection_name, - collection_format=collection_format, - prepared=prepared, - tensor_value_by_id=tensor_value_by_id, - include_initialization=False, ) - output_lines = _render_parent_interface_validation(parent_ports) + output_lines = _render_parent_interface_validation(context.parent_ports) output_lines.extend( [ "network_tensors = " + flattened_tensor_collection_expression( - collection_format, collection_name + collection_format, context.collection_name ), *_render_child_interface_lines( cell_name=cell_name, - child_ports_by_index=child_ports_by_index, + child_ports_by_index=context.child_ports_by_index, ), "open_inds = " + _render_python_tuple_expression( [ label_expression_by_label[index.label] - for index in prepared.open_indices - if index.spec.id not in interface_index_ids + for index in context.prepared.open_indices + if index.spec.id not in context.interface_index_ids ] ), "return {", diff --git a/src/tensor_network_editor/codegen/modes/_tree_periodic/array_shared.py b/src/tensor_network_editor/codegen/modes/_tree_periodic/array_shared.py new file mode 100644 index 0000000..3c473d5 --- /dev/null +++ b/src/tensor_network_editor/codegen/modes/_tree_periodic/array_shared.py @@ -0,0 +1,100 @@ +"""Shared cell-preparation helpers for tree-periodic array codegen.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from ....internal.modes._tree_periodic import ( + TreePeriodicInterfacePort, + build_internal_tree_periodic_cell_network, + build_tree_periodic_interface_ports, +) +from ....models import ( + TensorCollectionFormat, + TreePeriodicCellName, + TreePeriodicTensorRole, + TreePeriodicTreeSpec, +) +from ...shared.common import ( + PreparedNetwork, + container_name_for_format, + prepare_network, + render_tensor_collection_assignment, + render_tensor_collection_initialization, +) +from .common import _build_child_ports_by_index, _cell_from_tree + + +@dataclass(slots=True, frozen=True) +class TreeArrayCellContext: + """Prepared render context shared by tree-periodic array helper builders.""" + + prepared: PreparedNetwork + collection_format: TensorCollectionFormat + collection_name: str + parent_ports: tuple[TreePeriodicInterfacePort, ...] + child_ports_by_index: dict[int, tuple[TreePeriodicInterfacePort, ...]] + interface_index_ids: frozenset[str] + + +def build_tree_array_cell_context( + *, + tree: TreePeriodicTreeSpec, + cell_name: TreePeriodicCellName, + collection_format: TensorCollectionFormat, +) -> TreeArrayCellContext: + """Build the shared prepared context for one array-backed tree cell helper.""" + cell = _cell_from_tree(tree, cell_name) + prepared = prepare_network( + build_internal_tree_periodic_cell_network( + cell, + cell_name=cell_name, + include_contraction_plan=False, + ) + ) + parent_ports = build_tree_periodic_interface_ports( + cell, + cell_name=cell_name, + role=TreePeriodicTensorRole.PARENT, + ) + child_ports_by_index = _build_child_ports_by_index( + tree=tree, + cell=cell, + cell_name=cell_name, + ) + interface_index_ids = frozenset( + {port.internal_index_id for port in parent_ports} + | { + port.internal_index_id + for ports in child_ports_by_index.values() + for port in ports + } + ) + return TreeArrayCellContext( + prepared=prepared, + collection_format=collection_format, + collection_name=container_name_for_format(collection_format), + parent_ports=parent_ports, + child_ports_by_index=child_ports_by_index, + interface_index_ids=interface_index_ids, + ) + + +def render_tree_array_tensor_sections( + *, + context: TreeArrayCellContext, + tensor_value_by_id: dict[str, str], +) -> tuple[list[str], list[str]]: + """Render the shared tensor collection sections for one tree cell helper.""" + tensor_collection_lines = render_tensor_collection_initialization( + context.collection_name, + context.collection_format, + ) + tensor_construction_lines = render_tensor_collection_assignment( + collection_name=context.collection_name, + collection_format=context.collection_format, + prepared=context.prepared, + tensor_value_by_id=tensor_value_by_id, + include_initialization=False, + ) + return tensor_collection_lines, tensor_construction_lines diff --git a/src/tensor_network_editor/codegen/modes/grid_periodic.py b/src/tensor_network_editor/codegen/modes/grid_periodic.py index 9c780f6..a9d5a8c 100644 --- a/src/tensor_network_editor/codegen/modes/grid_periodic.py +++ b/src/tensor_network_editor/codegen/modes/grid_periodic.py @@ -2,11 +2,16 @@ from __future__ import annotations -from ...errors import CodeGenerationError -from ...models import CodegenResult, EngineName, NetworkSpec, TensorCollectionFormat -from ..shared.roundtrip import with_roundtrip_spec_marker +from ...models import ( + CodegenResult, + EngineName, + GridPeriodicGridSpec, + NetworkSpec, + TensorCollectionFormat, +) from ._grid_periodic_array_renderers import generate_array_grid_periodic_code from ._grid_periodic_graph_renderers import generate_graph_grid_periodic_code +from ._periodic_codegen import dispatch_periodic_codegen def generate_grid_periodic_code( @@ -19,38 +24,29 @@ def generate_grid_periodic_code( ) -> CodegenResult: """Generate helper-based Python code for the bidimensional periodic mode.""" del validate - if spec.grid_periodic_grid is None: - raise CodeGenerationError( - "Grid periodic code generation requires a grid payload." - ) - grid = spec.grid_periodic_grid - if engine in { - EngineName.QUIMB, - EngineName.EINSUM_NUMPY, - EngineName.EINSUM_TORCH, - }: - result = generate_array_grid_periodic_code( - grid=grid, + + def render_array(resolved_grid: GridPeriodicGridSpec) -> CodegenResult: + return generate_array_grid_periodic_code( + grid=resolved_grid, engine=engine, collection_format=collection_format, ) - return ( - with_roundtrip_spec_marker(result, spec=spec) - if include_roundtrip_metadata - else result - ) - if engine not in {EngineName.TENSORNETWORK, EngineName.TENSORKROWCH}: - raise CodeGenerationError( - f"The {engine.value} backend does not support grid periodic code generation." + + def render_graph(resolved_grid: GridPeriodicGridSpec) -> CodegenResult: + return generate_graph_grid_periodic_code( + grid=resolved_grid, + engine=engine, + collection_format=collection_format, ) - result = generate_graph_grid_periodic_code( - grid=grid, + + return dispatch_periodic_codegen( + spec=spec, + payload=grid, + missing_payload_message="Grid periodic code generation requires a grid payload.", + unsupported_backend_label="grid periodic", engine=engine, - collection_format=collection_format, - ) - return ( - with_roundtrip_spec_marker(result, spec=spec) - if include_roundtrip_metadata - else result + include_roundtrip_metadata=include_roundtrip_metadata, + array_renderer=render_array, + graph_renderer=render_graph, ) diff --git a/src/tensor_network_editor/codegen/modes/linear_periodic.py b/src/tensor_network_editor/codegen/modes/linear_periodic.py index 686b8ad..cb143fe 100644 --- a/src/tensor_network_editor/codegen/modes/linear_periodic.py +++ b/src/tensor_network_editor/codegen/modes/linear_periodic.py @@ -2,12 +2,17 @@ from __future__ import annotations -from ...errors import CodeGenerationError from ...internal.modes._linear_periodic import linear_periodic_chain_uses_carry_mode -from ...models import CodegenResult, EngineName, NetworkSpec, TensorCollectionFormat -from ..shared.roundtrip import with_roundtrip_spec_marker +from ...models import ( + CodegenResult, + EngineName, + LinearPeriodicChainSpec, + NetworkSpec, + TensorCollectionFormat, +) from ._linear_periodic_array_renderers import generate_array_linear_periodic_code from ._linear_periodic_graph_renderers import generate_graph_linear_periodic_code +from ._periodic_codegen import dispatch_periodic_codegen def generate_linear_periodic_code( @@ -20,41 +25,31 @@ def generate_linear_periodic_code( ) -> CodegenResult: """Generate helper-based Python code for the linear periodic-chain mode.""" del validate - if spec.linear_periodic_chain is None: - raise CodeGenerationError( - "Linear periodic code generation requires a chain payload." - ) - chain = spec.linear_periodic_chain - uses_carry_mode = linear_periodic_chain_uses_carry_mode(chain) - if engine in { - EngineName.QUIMB, - EngineName.EINSUM_NUMPY, - EngineName.EINSUM_TORCH, - }: - result = generate_array_linear_periodic_code( - chain=chain, + + def render_array(resolved_chain: LinearPeriodicChainSpec) -> CodegenResult: + return generate_array_linear_periodic_code( + chain=resolved_chain, engine=engine, collection_format=collection_format, - uses_carry_mode=uses_carry_mode, + uses_carry_mode=linear_periodic_chain_uses_carry_mode(resolved_chain), ) - return ( - with_roundtrip_spec_marker(result, spec=spec) - if include_roundtrip_metadata - else result - ) - if engine not in {EngineName.TENSORNETWORK, EngineName.TENSORKROWCH}: - raise CodeGenerationError( - f"The {engine.value} backend does not support linear periodic code generation." + + def render_graph(resolved_chain: LinearPeriodicChainSpec) -> CodegenResult: + return generate_graph_linear_periodic_code( + chain=resolved_chain, + engine=engine, + collection_format=collection_format, + uses_carry_mode=linear_periodic_chain_uses_carry_mode(resolved_chain), ) - result = generate_graph_linear_periodic_code( - chain=chain, + + return dispatch_periodic_codegen( + spec=spec, + payload=chain, + missing_payload_message="Linear periodic code generation requires a chain payload.", + unsupported_backend_label="linear periodic", engine=engine, - collection_format=collection_format, - uses_carry_mode=uses_carry_mode, - ) - return ( - with_roundtrip_spec_marker(result, spec=spec) - if include_roundtrip_metadata - else result + include_roundtrip_metadata=include_roundtrip_metadata, + array_renderer=render_array, + graph_renderer=render_graph, ) diff --git a/src/tensor_network_editor/codegen/modes/tree_periodic.py b/src/tensor_network_editor/codegen/modes/tree_periodic.py index 96da512..6a82195 100644 --- a/src/tensor_network_editor/codegen/modes/tree_periodic.py +++ b/src/tensor_network_editor/codegen/modes/tree_periodic.py @@ -2,9 +2,14 @@ from __future__ import annotations -from ...errors import CodeGenerationError -from ...models import CodegenResult, EngineName, NetworkSpec, TensorCollectionFormat -from ..shared.roundtrip import with_roundtrip_spec_marker +from ...models import ( + CodegenResult, + EngineName, + NetworkSpec, + TensorCollectionFormat, + TreePeriodicTreeSpec, +) +from ._periodic_codegen import dispatch_periodic_codegen from ._tree_periodic_array_renderers import generate_array_tree_periodic_code from ._tree_periodic_graph_renderers import generate_graph_tree_periodic_code @@ -19,38 +24,29 @@ def generate_tree_periodic_code( ) -> CodegenResult: """Generate helper-based Python code for the tree periodic mode.""" del validate - if spec.tree_periodic_tree is None: - raise CodeGenerationError( - "Tree periodic code generation requires a tree payload." - ) - tree = spec.tree_periodic_tree - if engine in { - EngineName.QUIMB, - EngineName.EINSUM_NUMPY, - EngineName.EINSUM_TORCH, - }: - result = generate_array_tree_periodic_code( - tree=tree, + + def render_array(resolved_tree: TreePeriodicTreeSpec) -> CodegenResult: + return generate_array_tree_periodic_code( + tree=resolved_tree, engine=engine, collection_format=collection_format, ) - return ( - with_roundtrip_spec_marker(result, spec=spec) - if include_roundtrip_metadata - else result - ) - if engine not in {EngineName.TENSORNETWORK, EngineName.TENSORKROWCH}: - raise CodeGenerationError( - f"The {engine.value} backend does not support tree periodic code generation." + + def render_graph(resolved_tree: TreePeriodicTreeSpec) -> CodegenResult: + return generate_graph_tree_periodic_code( + tree=resolved_tree, + engine=engine, + collection_format=collection_format, ) - result = generate_graph_tree_periodic_code( - tree=tree, + + return dispatch_periodic_codegen( + spec=spec, + payload=tree, + missing_payload_message="Tree periodic code generation requires a tree payload.", + unsupported_backend_label="tree periodic", engine=engine, - collection_format=collection_format, - ) - return ( - with_roundtrip_spec_marker(result, spec=spec) - if include_roundtrip_metadata - else result + include_roundtrip_metadata=include_roundtrip_metadata, + array_renderer=render_array, + graph_renderer=render_graph, ) diff --git a/tests/codegen/test_common.py b/tests/codegen/test_common.py index 5179f30..0dc4e62 100644 --- a/tests/codegen/test_common.py +++ b/tests/codegen/test_common.py @@ -16,6 +16,8 @@ ) from tensor_network_editor.models import ( CanvasPosition, + CodegenResult, + EngineName, NetworkSpec, TensorCollectionFormat, TensorSpec, @@ -158,3 +160,71 @@ def test_render_helper_function_lines_indents_rendered_sections() -> None: ) assert helper_lines == ["def build_cell(slot_index: int) -> dict[str, object]:"] + + +def test_dispatch_periodic_codegen_routes_supported_backends_and_roundtrip() -> None: + from tensor_network_editor.codegen.modes._periodic_codegen import ( + dispatch_periodic_codegen, + ) + + seen_calls: list[tuple[str, str]] = [] + + def render_array(payload: str) -> CodegenResult: + seen_calls.append(("array", payload)) + return CodegenResult(engine=EngineName.EINSUM_NUMPY, code="array_result = 1\n") + + def render_graph(payload: str) -> CodegenResult: + seen_calls.append(("graph", payload)) + return CodegenResult(engine=EngineName.TENSORNETWORK, code="graph_result = 1\n") + + spec = build_three_tensor_hyperedge_spec() + + array_result = dispatch_periodic_codegen( + spec=spec, + payload="array-payload", + missing_payload_message="missing payload", + unsupported_backend_label="periodic", + engine=EngineName.EINSUM_NUMPY, + include_roundtrip_metadata=True, + array_renderer=render_array, + graph_renderer=render_graph, + ) + graph_result = dispatch_periodic_codegen( + spec=spec, + payload="graph-payload", + missing_payload_message="missing payload", + unsupported_backend_label="periodic", + engine=EngineName.TENSORNETWORK, + include_roundtrip_metadata=False, + array_renderer=render_array, + graph_renderer=render_graph, + ) + + assert seen_calls == [("array", "array-payload"), ("graph", "graph-payload")] + assert "# TNE_SPEC_B64:" in array_result.code + assert graph_result.code == "graph_result = 1\n" + + +def test_dispatch_periodic_codegen_rejects_missing_payload() -> None: + from tensor_network_editor.codegen.modes._periodic_codegen import ( + dispatch_periodic_codegen, + ) + from tensor_network_editor.errors import CodeGenerationError + + with pytest.raises(CodeGenerationError, match="grid payload"): + dispatch_periodic_codegen( + spec=NetworkSpec(name="missing payload"), + payload=None, + missing_payload_message="Grid periodic code generation requires a grid payload.", + unsupported_backend_label="grid periodic", + engine=EngineName.EINSUM_NUMPY, + include_roundtrip_metadata=False, + array_renderer=lambda payload: CodegenResult( + engine=EngineName.EINSUM_NUMPY, + code=f"{payload}\n", + ), + graph_renderer=lambda payload: CodegenResult( + engine=EngineName.TENSORNETWORK, + code=f"{payload}\n", + ), + ) diff --git a/tests/codegen/test_grid_periodic_internals.py b/tests/codegen/test_grid_periodic_internals.py index 6b31c71..1856e30 100644 --- a/tests/codegen/test_grid_periodic_internals.py +++ b/tests/codegen/test_grid_periodic_internals.py @@ -1,6 +1,9 @@ from __future__ import annotations -from tensor_network_editor.models import GridPeriodicCellName +from tensor_network_editor.models import ( + GridPeriodicCellName, + TensorCollectionFormat, +) from tests.factories import build_grid_periodic_grid_spec @@ -72,3 +75,35 @@ def test_grid_periodic_internal_helpers_keep_shared_labels_and_main_flow() -> No "result = network_nodes[0] if len(network_nodes) == 1 else None" ] assert "output_labels.extend(bottom_right_cell['open_labels'])" in einsum_main_lines + + +def test_grid_periodic_array_shared_helpers_build_context_and_sections() -> None: + from tensor_network_editor.codegen.modes._grid_periodic.array_shared import ( + build_grid_array_cell_context, + render_grid_array_tensor_sections, + ) + + grid = build_grid_periodic_grid_spec().grid_periodic_grid + assert grid is not None + + context = build_grid_array_cell_context( + grid=grid, + cell_name=GridPeriodicCellName.TOP_LEFT, + collection_format=TensorCollectionFormat.LIST, + ) + tensor_collection_lines, tensor_construction_lines = ( + render_grid_array_tensor_sections( + context=context, + tensor_value_by_id={ + tensor.spec.id: f"value_{tensor.variable_name}" + for tensor in context.prepared.tensors + }, + ) + ) + + assert context.collection_name == "tensors" + assert context.prepared.tensors + assert context.interface_index_ids + assert tensor_collection_lines == ["tensors = []"] + assert any(line.startswith("# Tensor ") for line in tensor_construction_lines) + assert any("tensors.append(value_" in line for line in tensor_construction_lines) diff --git a/tests/codegen/test_tree_periodic_internals.py b/tests/codegen/test_tree_periodic_internals.py index 164efbd..384beec 100644 --- a/tests/codegen/test_tree_periodic_internals.py +++ b/tests/codegen/test_tree_periodic_internals.py @@ -124,3 +124,36 @@ def test_tree_periodic_array_helpers_keep_child_interfaces_and_backend_tensor_bu assert "np.zeros(" in numpy_helper_body assert "torch.zeros(" in torch_helper_body assert "np.zeros(" not in torch_helper_body + + +def test_tree_periodic_array_shared_helpers_build_context_and_sections() -> None: + from tensor_network_editor.codegen.modes._tree_periodic.array_shared import ( + build_tree_array_cell_context, + render_tree_array_tensor_sections, + ) + + tree = build_tree_periodic_tree_spec().tree_periodic_tree + assert tree is not None + + context = build_tree_array_cell_context( + tree=tree, + cell_name=TreePeriodicCellName.ROOT, + collection_format=TensorCollectionFormat.LIST, + ) + tensor_collection_lines, tensor_construction_lines = ( + render_tree_array_tensor_sections( + context=context, + tensor_value_by_id={ + tensor.spec.id: f"value_{tensor.variable_name}" + for tensor in context.prepared.tensors + }, + ) + ) + + assert context.collection_name == "tensors" + assert context.parent_ports == () + assert tuple(context.child_ports_by_index) == tuple(range(tree.branching_factor)) + assert context.interface_index_ids + assert tensor_collection_lines == ["tensors = []"] + assert any(line.startswith("# Tensor ") for line in tensor_construction_lines) + assert any("tensors.append(value_" in line for line in tensor_construction_lines) From b2f7949cf88b055bb1781e33947fcc7cb278bef5 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 18:06:34 +0200 Subject: [PATCH 08/23] Refactor render helpers and route assembly --- CHANGELOG.md | 3 + src/tensor_network_editor/app/routes.py | 225 +++++++++++--------- src/tensor_network_editor/rendering.py | 261 +++++++++++++++--------- tests/test_app_routes.py | 71 +++++++ tests/test_rendering.py | 72 +++++++ 5 files changed, 436 insertions(+), 196 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b11e2e..ed365f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ All notable changes to this project will be documented in this file. - Periodic code generation now routes linear, grid, and tree modes through one shared internal dispatcher, and the grid/tree array helpers reuse shared cell preparation utilities instead of repeating the same setup in each backend. +- Static rendering helpers and the `/api/render` route now share more of their + internal option parsing, validation, and response assembly logic instead of + repeating the same flow per export format. ## [0.5.0] - 2026-04-30 diff --git a/src/tensor_network_editor/app/routes.py b/src/tensor_network_editor/app/routes.py index 01254e6..f6ae51c 100644 --- a/src/tensor_network_editor/app/routes.py +++ b/src/tensor_network_editor/app/routes.py @@ -91,6 +91,7 @@ _MAX_FRONTEND_CLIENT_LOG_EVENTS = 200 _MAX_FRONTEND_CLIENT_LOG_MESSAGE_LENGTH = 400 _MAX_FRONTEND_CLIENT_LOG_CONTEXT_VALUE_LENGTH = 200 +_RenderFormat = Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"] @dataclass(slots=True, frozen=True) @@ -102,6 +103,15 @@ class _FrontendClientLogEvent: context: dict[str, object] +@dataclass(slots=True, frozen=True) +class _RenderLabelOptions: + """Shared label-visibility flags for academic render routes.""" + + show_tensor_labels: bool + show_index_labels: bool + show_edge_labels: bool + + def _route_context( session: EditorSession | None, route: str, @@ -327,109 +337,20 @@ def handle_render(session: EditorSession, payload: JsonDict) -> JsonResponse: """Render the current editor payload to an academic text format.""" del session with log_operation( - LOGGER, "Render route", context={"route": "/api/render"} + LOGGER, "Render route", context=_route_context(None, "/api/render") ) as success_context: try: render_format = _resolve_render_format(payload) serialized_spec = require_serialized_spec(payload) spec = deserialize_spec(serialized_spec, validate=False) + label_options = _resolve_render_label_options(payload) success_context["format"] = render_format success_context.update(summarize_spec_counts(spec)) - svg_options = SvgRenderOptions( - show_tensor_labels=require_boolean( - payload, "show_tensor_names", default=True - ), - show_index_labels=require_boolean( - payload, "show_index_names", default=True - ), - show_edge_labels=require_boolean( - payload, "show_bond_names", default=True - ), + response_payload = _build_render_response( + render_format, + spec, + label_options, ) - if render_format == "tikz": - text = render_spec_tikz( - spec, - options=TikzRenderOptions( - show_tensor_labels=require_boolean( - payload, "show_tensor_names", default=True - ), - show_index_labels=require_boolean( - payload, "show_index_names", default=True - ), - show_edge_labels=require_boolean( - payload, "show_bond_names", default=True - ), - ), - ) - content_type = "text/x-tex;charset=utf-8" - response_payload: JsonDict = { - "format": render_format, - "text": text, - "content_type": content_type, - } - elif render_format == "dot": - text = render_spec_dot( - spec, - options=DotRenderOptions( - show_tensor_labels=require_boolean( - payload, "show_tensor_names", default=True - ), - show_index_labels=require_boolean( - payload, "show_index_names", default=True - ), - show_edge_labels=require_boolean( - payload, "show_bond_names", default=True - ), - ), - ) - content_type = "text/vnd.graphviz;charset=utf-8" - response_payload = { - "format": render_format, - "text": text, - "content_type": content_type, - } - elif render_format == "mermaid": - text = render_spec_mermaid( - spec, - options=DotRenderOptions( - show_tensor_labels=require_boolean( - payload, "show_tensor_names", default=True - ), - show_index_labels=require_boolean( - payload, "show_index_names", default=True - ), - show_edge_labels=require_boolean( - payload, "show_bond_names", default=True - ), - ), - ) - content_type = "text/plain;charset=utf-8" - response_payload = { - "format": render_format, - "text": text, - "content_type": content_type, - } - elif render_format == "svg": - text = render_spec_svg(spec, options=svg_options) - response_payload = { - "format": render_format, - "text": text, - "content_type": "image/svg+xml;charset=utf-8", - } - elif render_format == "png": - binary = render_spec_png(spec, options=svg_options) - response_payload = { - "format": render_format, - "base64": base64.b64encode(binary).decode("ascii"), - "content_type": "image/png", - } - else: - binary = render_spec_pdf(spec, options=svg_options) - response_payload = { - "format": render_format, - "base64": base64.b64encode(binary).decode("ascii"), - "content_type": "application/pdf", - } except ValueError as exc: return bad_request_response(str(exc)) except SerializationError as exc: @@ -790,22 +711,126 @@ def _serialize_generate_result(result: CodegenResult) -> JsonDict: def _resolve_render_format( payload: JsonDict, -) -> Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"]: +) -> _RenderFormat: raw_format = payload.get("format") if not isinstance(raw_format, str) or not raw_format.strip(): raise ValueError("Missing 'format' payload.") normalized_format = raw_format.strip().lower() if normalized_format in {"tikz", "dot", "mermaid", "svg", "png", "pdf"}: - return cast( - Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"], - normalized_format, - ) + return cast(_RenderFormat, normalized_format) raise ValueError( "Unsupported render format " f"'{raw_format}'. Expected 'tikz', 'dot', 'mermaid', 'svg', 'png', or 'pdf'." ) +def _resolve_render_label_options(payload: JsonDict) -> _RenderLabelOptions: + """Return shared render-label visibility flags for one request payload.""" + return _RenderLabelOptions( + show_tensor_labels=require_boolean(payload, "show_tensor_names", default=True), + show_index_labels=require_boolean(payload, "show_index_names", default=True), + show_edge_labels=require_boolean(payload, "show_bond_names", default=True), + ) + + +def _svg_render_options(label_options: _RenderLabelOptions) -> SvgRenderOptions: + """Return SVG/PNG/PDF render options derived from shared label flags.""" + return SvgRenderOptions( + show_tensor_labels=label_options.show_tensor_labels, + show_index_labels=label_options.show_index_labels, + show_edge_labels=label_options.show_edge_labels, + ) + + +def _tikz_render_options(label_options: _RenderLabelOptions) -> TikzRenderOptions: + """Return TikZ render options derived from shared label flags.""" + return TikzRenderOptions( + show_tensor_labels=label_options.show_tensor_labels, + show_index_labels=label_options.show_index_labels, + show_edge_labels=label_options.show_edge_labels, + ) + + +def _dot_render_options(label_options: _RenderLabelOptions) -> DotRenderOptions: + """Return DOT/Mermaid render options derived from shared label flags.""" + return DotRenderOptions( + show_tensor_labels=label_options.show_tensor_labels, + show_index_labels=label_options.show_index_labels, + show_edge_labels=label_options.show_edge_labels, + ) + + +def _build_text_render_response( + render_format: _RenderFormat, + text: str, + *, + content_type: str, +) -> JsonDict: + """Return one text-based render response payload.""" + return { + "format": render_format, + "text": text, + "content_type": content_type, + } + + +def _build_binary_render_response( + render_format: _RenderFormat, + binary: bytes, + *, + content_type: str, +) -> JsonDict: + """Return one binary render response payload encoded for JSON transport.""" + return { + "format": render_format, + "base64": base64.b64encode(binary).decode("ascii"), + "content_type": content_type, + } + + +def _build_render_response( + render_format: _RenderFormat, + spec: NetworkSpec, + label_options: _RenderLabelOptions, +) -> JsonDict: + """Return the serialized academic render payload for one format request.""" + if render_format == "tikz": + return _build_text_render_response( + render_format, + render_spec_tikz(spec, options=_tikz_render_options(label_options)), + content_type="text/x-tex;charset=utf-8", + ) + if render_format == "dot": + return _build_text_render_response( + render_format, + render_spec_dot(spec, options=_dot_render_options(label_options)), + content_type="text/vnd.graphviz;charset=utf-8", + ) + if render_format == "mermaid": + return _build_text_render_response( + render_format, + render_spec_mermaid(spec, options=_dot_render_options(label_options)), + content_type="text/plain;charset=utf-8", + ) + if render_format == "svg": + return _build_text_render_response( + render_format, + render_spec_svg(spec, options=_svg_render_options(label_options)), + content_type="image/svg+xml;charset=utf-8", + ) + if render_format == "png": + return _build_binary_render_response( + render_format, + render_spec_png(spec, options=_svg_render_options(label_options)), + content_type="image/png", + ) + return _build_binary_render_response( + render_format, + render_spec_pdf(spec, options=_svg_render_options(label_options)), + content_type="application/pdf", + ) + + def _serialize_complete_result(result: EditorResult) -> JsonDict: """Serialize a complete-route editor result.""" return serialize_editor_result(result) diff --git a/src/tensor_network_editor/rendering.py b/src/tensor_network_editor/rendering.py index de5f2c0..8cae178 100644 --- a/src/tensor_network_editor/rendering.py +++ b/src/tensor_network_editor/rendering.py @@ -3,13 +3,13 @@ from __future__ import annotations import logging -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass from html import escape from importlib import import_module from io import BytesIO from math import ceil, cos, hypot, isfinite, pi, sin -from typing import Any +from typing import Any, TypeVar from xml.sax.saxutils import quoteattr from .internal._logging import log_operation, summarize_spec_counts @@ -61,6 +61,8 @@ } LOGGER = logging.getLogger(__name__) +_RenderOptionsT = TypeVar("_RenderOptionsT") +_RenderedOutputT = TypeVar("_RenderedOutputT", str, bytes) @dataclass(slots=True, frozen=True) @@ -143,25 +145,78 @@ class _RenderedEdge: stroke: str -def render_spec_svg( +def _render_context( spec: NetworkSpec, *, - options: SvgRenderOptions | None = None, - output_path: StrPath | None = None, -) -> str: - """Render one tensor-network specification as a standalone SVG string.""" - context = { - "format": "svg", + format_name: str, + output_path: StrPath | None, +) -> dict[str, object]: + """Return shared logging context for one static render operation.""" + return { + "format": format_name, "output_path": output_path, **summarize_spec_counts(spec), } - with log_operation(LOGGER, "Render spec", context=context): - resolved_options = options or SvgRenderOptions() + + +def _validate_positive_render_scale( + scale: object, + *, + description: str, +) -> float: + """Normalize one render scale after validating positivity and finiteness.""" + if isinstance(scale, bool) or not isinstance(scale, (int, float)): + raise ValueError(f"{description} must be a positive finite number.") + resolved_scale = float(scale) + if not isfinite(resolved_scale) or resolved_scale <= 0: + raise ValueError(f"{description} must be a positive finite number.") + return resolved_scale + + +def _render_spec_output( + spec: NetworkSpec, + *, + format_name: str, + options: _RenderOptionsT, + output_path: StrPath | None, + description: str, + render: Callable[[NetworkSpec, _RenderOptionsT], _RenderedOutputT], + writer: Callable[..., None], +) -> _RenderedOutputT: + """Validate, render, and optionally persist one static export.""" + with log_operation( + LOGGER, + "Render spec", + context=_render_context(spec, format_name=format_name, output_path=output_path), + ): validated_spec = ensure_valid_spec(spec) - svg = _MatplotlibRenderer(validated_spec, resolved_options).render_svg() + rendered = render(validated_spec, options) if output_path is not None: - write_utf8_text(output_path, svg, description="SVG network rendering") - return svg + writer(output_path, rendered, description=description) + return rendered + + +def render_spec_svg( + spec: NetworkSpec, + *, + options: SvgRenderOptions | None = None, + output_path: StrPath | None = None, +) -> str: + """Render one tensor-network specification as a standalone SVG string.""" + resolved_options = options or SvgRenderOptions() + + def _render(validated_spec: NetworkSpec, options: SvgRenderOptions) -> str: + return _MatplotlibRenderer(validated_spec, options).render_svg() + + return _render_spec_output( + spec, + format_name="svg", + options=resolved_options, + output_path=output_path, + description="SVG network rendering", + render=_render, + writer=write_utf8_text, + ) def render_spec_png( @@ -172,26 +227,28 @@ def render_spec_png( output_path: StrPath | None = None, ) -> bytes: """Render one tensor-network specification as PNG bytes using Matplotlib.""" - context = { - "format": "png", - "output_path": output_path, - **summarize_spec_counts(spec), - } - with log_operation(LOGGER, "Render spec", context=context): - if isinstance(scale, bool) or not isinstance(scale, (int, float)): - raise ValueError("PNG render scale must be a positive finite number.") - if not isfinite(float(scale)) or scale <= 0: - raise ValueError("PNG render scale must be a positive finite number.") - resolved_options = options or SvgRenderOptions() - validated_spec = ensure_valid_spec(spec) - png = _MatplotlibRenderer( + resolved_scale = _validate_positive_render_scale( + scale, + description="PNG render scale", + ) + resolved_options = options or SvgRenderOptions() + + def _render(validated_spec: NetworkSpec, options: SvgRenderOptions) -> bytes: + return _MatplotlibRenderer( validated_spec, - resolved_options, - scale=float(scale), + options, + scale=resolved_scale, ).render_png() - if output_path is not None: - write_binary(output_path, png, description="PNG network rendering") - return png + + return _render_spec_output( + spec, + format_name="png", + options=resolved_options, + output_path=output_path, + description="PNG network rendering", + render=_render, + writer=write_binary, + ) def render_spec_pdf( @@ -202,24 +259,28 @@ def render_spec_pdf( output_path: StrPath | None = None, ) -> bytes: """Render one tensor-network specification as PDF bytes using Matplotlib.""" - context = { - "format": "pdf", - "output_path": output_path, - **summarize_spec_counts(spec), - } - with log_operation(LOGGER, "Render spec", context=context): - if isinstance(scale, bool) or not isinstance(scale, (int, float)): - raise ValueError("PDF render scale must be a positive finite number.") - if not isfinite(float(scale)) or scale <= 0: - raise ValueError("PDF render scale must be a positive finite number.") - resolved_options = options or SvgRenderOptions() - validated_spec = ensure_valid_spec(spec) - pdf = _MatplotlibRenderer( - validated_spec, resolved_options, scale=float(scale) + resolved_scale = _validate_positive_render_scale( + scale, + description="PDF render scale", + ) + resolved_options = options or SvgRenderOptions() + + def _render(validated_spec: NetworkSpec, options: SvgRenderOptions) -> bytes: + return _MatplotlibRenderer( + validated_spec, + options, + scale=resolved_scale, ).render_pdf() - if output_path is not None: - write_binary(output_path, pdf, description="PDF network rendering") - return pdf + + return _render_spec_output( + spec, + format_name="pdf", + options=resolved_options, + output_path=output_path, + description="PDF network rendering", + render=_render, + writer=write_binary, + ) def render_spec_tikz( @@ -229,25 +290,33 @@ def render_spec_tikz( output_path: StrPath | None = None, ) -> str: """Render one tensor-network specification as a standalone TikZ picture.""" - context = { - "format": "tikz", - "output_path": output_path, - **summarize_spec_counts(spec), - } - with log_operation(LOGGER, "Render spec", context=context): - resolved_options = options or TikzRenderOptions() - if ( - isinstance(resolved_options.scale, bool) - or not isinstance(resolved_options.scale, (int, float)) - or not isfinite(float(resolved_options.scale)) - or resolved_options.scale <= 0 - ): - raise ValueError("TikZ render scale must be a positive finite number.") - validated_spec = ensure_valid_spec(spec) - tikz = _TikzRenderer(validated_spec, resolved_options).render() - if output_path is not None: - write_utf8_text(output_path, tikz, description="TikZ network rendering") - return tikz + resolved_options = options or TikzRenderOptions() + validated_scale = _validate_positive_render_scale( + resolved_options.scale, + description="TikZ render scale", + ) + resolved_options = TikzRenderOptions( + scale=validated_scale, + global_width=resolved_options.global_width, + show_index_labels=resolved_options.show_index_labels, + show_edge_labels=resolved_options.show_edge_labels, + include_groups=resolved_options.include_groups, + include_notes=resolved_options.include_notes, + show_tensor_labels=resolved_options.show_tensor_labels, + ) + + def _render(validated_spec: NetworkSpec, options: TikzRenderOptions) -> str: + return _TikzRenderer(validated_spec, options).render() + + return _render_spec_output( + spec, + format_name="tikz", + options=resolved_options, + output_path=output_path, + description="TikZ network rendering", + render=_render, + writer=write_utf8_text, + ) def render_spec_dot( @@ -257,20 +326,20 @@ def render_spec_dot( output_path: StrPath | None = None, ) -> str: """Render one tensor-network specification as a Graphviz/DOT graph.""" - context = { - "format": "dot", - "output_path": output_path, - **summarize_spec_counts(spec), - } - with log_operation(LOGGER, "Render spec", context=context): - resolved_options = options or DotRenderOptions() - validated_spec = ensure_valid_spec(spec) - dot = _DotRenderer(validated_spec, resolved_options).render() - if output_path is not None: - write_utf8_text( - output_path, dot, description="Graphviz/DOT network rendering" - ) - return dot + resolved_options = options or DotRenderOptions() + + def _render(validated_spec: NetworkSpec, options: DotRenderOptions) -> str: + return _DotRenderer(validated_spec, options).render() + + return _render_spec_output( + spec, + format_name="dot", + options=resolved_options, + output_path=output_path, + description="Graphviz/DOT network rendering", + render=_render, + writer=write_utf8_text, + ) def render_spec_mermaid( @@ -280,20 +349,20 @@ def render_spec_mermaid( output_path: StrPath | None = None, ) -> str: """Render one tensor-network specification as a Mermaid flowchart.""" - context = { - "format": "mermaid", - "output_path": output_path, - **summarize_spec_counts(spec), - } - with log_operation(LOGGER, "Render spec", context=context): - resolved_options = options or DotRenderOptions() - validated_spec = ensure_valid_spec(spec) - mermaid = _MermaidRenderer(validated_spec, resolved_options).render() - if output_path is not None: - write_utf8_text( - output_path, mermaid, description="Mermaid network rendering" - ) - return mermaid + resolved_options = options or DotRenderOptions() + + def _render(validated_spec: NetworkSpec, options: DotRenderOptions) -> str: + return _MermaidRenderer(validated_spec, options).render() + + return _render_spec_output( + spec, + format_name="mermaid", + options=resolved_options, + output_path=output_path, + description="Mermaid network rendering", + render=_render, + writer=write_utf8_text, + ) class _SvgRenderer: diff --git a/tests/test_app_routes.py b/tests/test_app_routes.py index ed18170..c626672 100644 --- a/tests/test_app_routes.py +++ b/tests/test_app_routes.py @@ -448,6 +448,77 @@ def test_render_route_applies_label_options_to_svg_png_and_pdf( assert options.show_edge_labels is False +def test_resolve_render_label_options_reads_shared_payload_flags() -> None: + import tensor_network_editor.app.routes as routes_module + + options = routes_module._resolve_render_label_options( + { + "show_tensor_names": False, + "show_index_names": True, + "show_bond_names": False, + } + ) + + assert options.show_tensor_labels is False + assert options.show_index_labels is True + assert options.show_edge_labels is False + + +def test_build_render_response_preserves_text_and_binary_payload_shapes() -> None: + import tensor_network_editor.app.routes as routes_module + + spec = build_sample_spec() + label_options = routes_module._RenderLabelOptions( + show_tensor_labels=False, + show_index_labels=False, + show_edge_labels=False, + ) + + with ( + patch.object( + routes_module, + "render_spec_tikz", + return_value=r"\begin{tikzpicture}", + ) as render_tikz_mock, + patch.object( + routes_module, + "render_spec_pdf", + return_value=b"%PDF-1.4", + ) as render_pdf_mock, + ): + tikz_payload = routes_module._build_render_response( + "tikz", + spec, + label_options, + ) + pdf_payload = routes_module._build_render_response( + "pdf", + spec, + label_options, + ) + + tikz_options = render_tikz_mock.call_args.kwargs["options"] + pdf_options = render_pdf_mock.call_args.kwargs["options"] + + assert tikz_payload == { + "format": "tikz", + "text": r"\begin{tikzpicture}", + "content_type": "text/x-tex;charset=utf-8", + } + assert tikz_options.show_tensor_labels is False + assert tikz_options.show_index_labels is False + assert tikz_options.show_edge_labels is False + + assert pdf_payload == { + "format": "pdf", + "base64": base64.b64encode(b"%PDF-1.4").decode("ascii"), + "content_type": "application/pdf", + } + assert pdf_options.show_tensor_labels is False + assert pdf_options.show_index_labels is False + assert pdf_options.show_edge_labels is False + + def test_render_route_rejects_unsupported_academic_format( editor_server: EditorServer, ) -> None: diff --git a/tests/test_rendering.py b/tests/test_rendering.py index dac1776..1db1bd7 100644 --- a/tests/test_rendering.py +++ b/tests/test_rendering.py @@ -1014,6 +1014,78 @@ def reject_matplotlib_modules() -> tuple[object, object, object, object]: rendering_module.render_spec_pdf(build_sample_spec()) +def test_validate_positive_render_scale_normalizes_and_rejects_invalid_values() -> None: + import tensor_network_editor.rendering as rendering_module + + assert rendering_module._validate_positive_render_scale( + 2, + description="PNG render scale", + ) == pytest.approx(2.0) + assert rendering_module._validate_positive_render_scale( + 1.5, + description="TikZ render scale", + ) == pytest.approx(1.5) + + for invalid_scale in (True, 0, -1, float("inf"), float("nan"), "2"): + with pytest.raises( + ValueError, + match="PNG render scale must be a positive finite number.", + ): + rendering_module._validate_positive_render_scale( + invalid_scale, + description="PNG render scale", + ) + + +def test_render_spec_output_validates_renders_and_writes_output( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + import tensor_network_editor.rendering as rendering_module + + spec = build_sample_spec() + validated_spec = build_three_tensor_spec() + output_path = tmp_path / "network.svg" + calls: dict[str, Any] = {} + + def fake_validate(received_spec: NetworkSpec) -> NetworkSpec: + calls["validate"] = received_spec + return validated_spec + + def fake_render( + received_spec: NetworkSpec, + received_options: SvgRenderOptions, + ) -> str: + calls["render"] = (received_spec, received_options) + return "" + + def fake_write( + path: Path, + content: str, + *, + description: str, + ) -> None: + calls["write"] = (path, content, description) + + monkeypatch.setattr(rendering_module, "ensure_valid_spec", fake_validate) + options = SvgRenderOptions(show_tensor_labels=False) + + rendered = rendering_module._render_spec_output( + spec, + format_name="svg", + options=options, + output_path=output_path, + description="SVG network rendering", + render=fake_render, + writer=fake_write, + ) + + assert rendered == "" + assert calls["validate"] is spec + assert calls["render"] == (validated_spec, options) + assert calls["write"] == (output_path, "", "SVG network rendering") + + def test_render_spec_png_returns_png_bytes_and_writes_output_path( tmp_path: Path, ) -> None: From 6427657da06584bca0c70a3c40c253f75b1709bd Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 18:17:59 +0200 Subject: [PATCH 09/23] Modularize internal template builders --- CHANGELOG.md | 3 + .../internal/templates/_template_builders.py | 931 +----------------- .../templates/_template_builders_common.py | 132 +++ .../templates/_template_builders_grid.py | 333 +++++++ .../templates/_template_builders_linear.py | 245 +++++ .../templates/_template_builders_tree.py | 277 ++++++ tests/test_template_catalog_internal.py | 64 ++ 7 files changed, 1076 insertions(+), 909 deletions(-) create mode 100644 src/tensor_network_editor/internal/templates/_template_builders_common.py create mode 100644 src/tensor_network_editor/internal/templates/_template_builders_grid.py create mode 100644 src/tensor_network_editor/internal/templates/_template_builders_linear.py create mode 100644 src/tensor_network_editor/internal/templates/_template_builders_tree.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ed365f5..eeef722 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ All notable changes to this project will be documented in this file. - Static rendering helpers and the `/api/render` route now share more of their internal option parsing, validation, and response assembly logic instead of repeating the same flow per export format. +- Internal built-in template builders are now split by family with shared + construction primitives, while the existing template catalog and public + template APIs keep the same behavior and registration order. ## [0.5.0] - 2026-04-30 diff --git a/src/tensor_network_editor/internal/templates/_template_builders.py b/src/tensor_network_editor/internal/templates/_template_builders.py index 5032f59..75ad6d9 100644 --- a/src/tensor_network_editor/internal/templates/_template_builders.py +++ b/src/tensor_network_editor/internal/templates/_template_builders.py @@ -1,22 +1,20 @@ -"""Builders for the package's built-in tensor-network templates.""" +"""Facade and registry entrypoints for internal template builders.""" from __future__ import annotations -from collections.abc import Callable -from typing import cast - -from ...models import ( - CanvasPosition, - EdgeEndpointRef, - EdgeSpec, - IndexSpec, - NetworkSpec, - TensorDataMode, - TensorDataSpec, - TensorSpec, -) +from ...models import NetworkSpec from ...validation import ensure_valid_spec -from ..models._model_tensor_data import TensorNumericLiteral +from ._template_builders_grid import ( + _build_pepo_template, + _build_peps_template, + _build_tebd_gate_layer_template, +) +from ._template_builders_linear import ( + _build_linear_chain_template, + _build_mpo_template, + _build_mps_template, +) +from ._template_builders_tree import _build_mera_template, _build_ttn_template from ._template_catalog import ( TEMPLATE_DEFINITIONS, TemplateParameters, @@ -26,30 +24,10 @@ validate_template_parameters, ) -HORIZONTAL_SPACING = 320.0 -VERTICAL_SPACING = 280.0 -LAYER_SPACING = 210.0 -TREE_LEAF_SPACING = 220.0 -LEFT_OFFSET = (-58.0, 0.0) -RIGHT_OFFSET = (58.0, 0.0) -UP_OFFSET = (0.0, -28.0) -DOWN_OFFSET = (0.0, 28.0) -LOWER_LEFT_OFFSET = (-24.0, 34.0) -LOWER_RIGHT_OFFSET = (24.0, 34.0) -UPPER_PHYSICAL_OFFSET = (-26.0, -54.0) -LOWER_PHYSICAL_OFFSET = (26.0, 42.0) -GATE_UPPER_LEFT_OFFSET = (-36.0, -38.0) -GATE_UPPER_RIGHT_OFFSET = (36.0, -38.0) -GATE_LOWER_LEFT_OFFSET = (-36.0, 38.0) -GATE_LOWER_RIGHT_OFFSET = (36.0, 38.0) -TemplateIndexConfig = tuple[str, int | None, tuple[float, float]] -LinearChainSiteIndexBuilder = Callable[ - [int, int, TemplateParameters], list[TemplateIndexConfig] -] - def build_template( - template_name: str, parameters: TemplateParameters | None = None + template_name: str, + parameters: TemplateParameters | None = None, ) -> NetworkSpec: """Build and validate the named built-in template.""" definition = get_template_definition(template_name) @@ -65,878 +43,6 @@ def build_template( return ensure_valid_spec(builder(resolved_parameters)) -def _build_mps_template(parameters: TemplateParameters) -> NetworkSpec: - """Build an MPS template with the requested site count and dimensions.""" - spec = _build_linear_chain_template( - "mps", - parameters, - tensor_name_prefix="A", - spacing=HORIZONTAL_SPACING, - site_index_builder=_build_mps_site_indices, - periodic=parameters.boundary_condition == "periodic", - ) - return _apply_mps_template_configuration(spec, parameters) - - -def _build_mpo_template(parameters: TemplateParameters) -> NetworkSpec: - """Build an MPO template with the requested site count and dimensions.""" - spec = _build_linear_chain_template( - "mpo", - parameters, - tensor_name_prefix="W", - spacing=330.0, - site_index_builder=_build_mpo_site_indices, - periodic=parameters.boundary_condition == "periodic", - ) - return _apply_mpo_template_configuration(spec, parameters) - - -def _build_linear_chain_template( - template_name: str, - parameters: TemplateParameters, - *, - tensor_name_prefix: str, - spacing: float, - site_index_builder: LinearChainSiteIndexBuilder, - periodic: bool = False, -) -> NetworkSpec: - """Build one left-to-right chain template from a per-site index factory.""" - length = _resolve_graph_size(parameters) - tensors = [ - _make_tensor( - f"tensor_{site_index}", - f"{tensor_name_prefix}{site_index + 1}", - spacing * site_index, - 0.0, - site_index_builder(site_index, length, parameters), - ) - for site_index in range(length) - ] - definition = TEMPLATE_DEFINITIONS[template_name] - spec_name = ( - definition.display_name - if length == definition.defaults.graph_size - else f"{definition.display_name} ({length} {definition.graph_size_label.lower()})" - ) - return NetworkSpec( - id=f"template_{template_name}_{length}", - name=spec_name, - tensors=tensors, - edges=_make_linear_chain_edges(tensors, periodic=periodic), - ) - - -def _build_mps_site_indices( - site_index: int, - length: int, - parameters: TemplateParameters, -) -> list[TemplateIndexConfig]: - """Return the named index layout for one MPS site.""" - tensor_indices: list[TemplateIndexConfig] = [] - if parameters.boundary_condition == "periodic" or site_index > 0: - tensor_indices.append(("left", parameters.bond_dimension, LEFT_OFFSET)) - if parameters.boundary_condition == "periodic" or site_index < length - 1: - tensor_indices.append(("right", parameters.bond_dimension, RIGHT_OFFSET)) - tensor_indices.append(("phys", parameters.physical_dimension, DOWN_OFFSET)) - return tensor_indices - - -def _build_mpo_site_indices( - site_index: int, - length: int, - parameters: TemplateParameters, -) -> list[TemplateIndexConfig]: - """Return the named index layout for one MPO site.""" - tensor_indices: list[TemplateIndexConfig] = [] - if parameters.boundary_condition == "periodic" or site_index > 0: - tensor_indices.append(("left", parameters.bond_dimension, LEFT_OFFSET)) - if parameters.boundary_condition == "periodic" or site_index < length - 1: - tensor_indices.append(("right", parameters.bond_dimension, RIGHT_OFFSET)) - tensor_indices.extend( - [ - ("bra", parameters.physical_dimension, UP_OFFSET), - ("ket", parameters.physical_dimension, DOWN_OFFSET), - ] - ) - return tensor_indices - - -def _make_linear_chain_edges( - tensors: list[TensorSpec], - *, - periodic: bool = False, -) -> list[EdgeSpec]: - """Return the right-to-left bonds between adjacent chain tensors.""" - edges = [ - _make_edge( - f"edge_{site_index}_{site_index + 1}", - tensors[site_index], - "right", - tensors[site_index + 1], - "left", - ) - for site_index in range(len(tensors) - 1) - ] - if periodic and len(tensors) > 1: - edges.append( - _make_edge( - f"edge_{len(tensors)}_1", - tensors[-1], - "right", - tensors[0], - "left", - ) - ) - return edges - - -def _build_peps_template(parameters: TemplateParameters) -> NetworkSpec: - """Build the requested PEPS template variant.""" - if _resolve_graph_size(parameters) == 2: - return _build_default_peps_template(parameters) - return _build_generic_peps_template(parameters) - - -def _build_default_peps_template(parameters: TemplateParameters) -> NetworkSpec: - """Build the default 2x2 PEPS layout.""" - tensors = [ - _make_tensor( - "tensor_a", - "A", - 0.0, - 0.0, - [ - ("right", parameters.bond_dimension, RIGHT_OFFSET), - ("down", parameters.bond_dimension, DOWN_OFFSET), - ("phys", parameters.physical_dimension, LOWER_LEFT_OFFSET), - ], - ), - _make_tensor( - "tensor_b", - "B", - 340.0, - 0.0, - [ - ("left", parameters.bond_dimension, LEFT_OFFSET), - ("down", parameters.bond_dimension, DOWN_OFFSET), - ("phys", parameters.physical_dimension, LOWER_RIGHT_OFFSET), - ], - ), - _make_tensor( - "tensor_c", - "C", - 0.0, - VERTICAL_SPACING, - [ - ("right", parameters.bond_dimension, RIGHT_OFFSET), - ("up", parameters.bond_dimension, UP_OFFSET), - ("phys", parameters.physical_dimension, LOWER_LEFT_OFFSET), - ], - ), - _make_tensor( - "tensor_d", - "D", - 340.0, - VERTICAL_SPACING, - [ - ("left", parameters.bond_dimension, LEFT_OFFSET), - ("up", parameters.bond_dimension, UP_OFFSET), - ("phys", parameters.physical_dimension, LOWER_RIGHT_OFFSET), - ], - ), - ] - edges = [ - _make_edge("edge_ab", tensors[0], "right", tensors[1], "left"), - _make_edge("edge_cd", tensors[2], "right", tensors[3], "left"), - _make_edge("edge_ac", tensors[0], "down", tensors[2], "up"), - _make_edge("edge_bd", tensors[1], "down", tensors[3], "up"), - ] - return NetworkSpec( - id="template_peps_2", - name="PEPS 2x2", - tensors=tensors, - edges=edges, - ) - - -def _build_generic_peps_template(parameters: TemplateParameters) -> NetworkSpec: - """Build a square PEPS grid larger than the default 2x2 layout.""" - size = _resolve_graph_size(parameters) - tensors: list[TensorSpec] = [] - tensor_lookup: dict[tuple[int, int], TensorSpec] = {} - for row_index in range(size): - for column_index in range(size): - tensor_indices: list[TemplateIndexConfig] = [] - if column_index > 0: - tensor_indices.append(("left", parameters.bond_dimension, LEFT_OFFSET)) - if column_index < size - 1: - tensor_indices.append( - ("right", parameters.bond_dimension, RIGHT_OFFSET) - ) - if row_index > 0: - tensor_indices.append(("up", parameters.bond_dimension, UP_OFFSET)) - if row_index < size - 1: - tensor_indices.append(("down", parameters.bond_dimension, DOWN_OFFSET)) - tensor_indices.append( - ( - "phys", - parameters.physical_dimension, - LOWER_LEFT_OFFSET if column_index % 2 == 0 else LOWER_RIGHT_OFFSET, - ) - ) - tensor = _make_tensor( - f"tensor_r{row_index + 1}_c{column_index + 1}", - _grid_tensor_name(row_index, column_index), - 340.0 * column_index, - VERTICAL_SPACING * row_index, - tensor_indices, - ) - tensors.append(tensor) - tensor_lookup[(row_index, column_index)] = tensor - edges = [] - for row_index in range(size): - for column_index in range(size): - current_tensor = tensor_lookup[(row_index, column_index)] - if column_index + 1 < size: - edges.append( - _make_edge( - f"edge_r{row_index + 1}_c{column_index + 1}_right", - current_tensor, - "right", - tensor_lookup[(row_index, column_index + 1)], - "left", - ) - ) - if row_index + 1 < size: - edges.append( - _make_edge( - f"edge_r{row_index + 1}_c{column_index + 1}_down", - current_tensor, - "down", - tensor_lookup[(row_index + 1, column_index)], - "up", - ) - ) - return NetworkSpec( - id=f"template_peps_{size}", - name=f"PEPS {size}x{size}", - tensors=tensors, - edges=edges, - ) - - -def _build_mera_template(parameters: TemplateParameters) -> NetworkSpec: - """Build the requested MERA template variant.""" - if ( - _resolve_graph_size(parameters) - == TEMPLATE_DEFINITIONS["mera"].defaults.graph_size - ): - return _build_default_mera_template(parameters) - return _build_generic_mera_template(parameters) - - -def _build_default_mera_template(parameters: TemplateParameters) -> NetworkSpec: - """Build the default depth-3 MERA layout.""" - tensors = [ - _make_tensor( - "tensor_top", - "Top", - 320.0, - 0.0, - [ - ("left", parameters.bond_dimension, LEFT_OFFSET), - ("right", parameters.bond_dimension, RIGHT_OFFSET), - ], - ), - _make_tensor( - "tensor_mid_left", - "Mid L", - 120.0, - 210.0, - [ - ("up", parameters.bond_dimension, UP_OFFSET), - ("left", parameters.bond_dimension, LEFT_OFFSET), - ("down", parameters.bond_dimension, DOWN_OFFSET), - ], - ), - _make_tensor( - "tensor_mid_right", - "Mid R", - 520.0, - 210.0, - [ - ("up", parameters.bond_dimension, UP_OFFSET), - ("down", parameters.bond_dimension, DOWN_OFFSET), - ("right", parameters.bond_dimension, RIGHT_OFFSET), - ], - ), - _make_tensor( - "tensor_leaf_left", - "Leaf L", - 0.0, - 420.0, - [ - ("up", parameters.bond_dimension, UP_OFFSET), - ("phys", parameters.physical_dimension, DOWN_OFFSET), - ], - ), - _make_tensor( - "tensor_leaf_mid", - "Leaf M", - 320.0, - 420.0, - [ - ("left", parameters.bond_dimension, LEFT_OFFSET), - ("right", parameters.bond_dimension, RIGHT_OFFSET), - ("phys", parameters.physical_dimension, DOWN_OFFSET), - ], - ), - _make_tensor( - "tensor_leaf_right", - "Leaf R", - 640.0, - 420.0, - [ - ("up", parameters.bond_dimension, UP_OFFSET), - ("phys", parameters.physical_dimension, DOWN_OFFSET), - ], - ), - ] - edges = [ - _make_edge("edge_top_left", tensors[0], "left", tensors[1], "up"), - _make_edge("edge_top_right", tensors[0], "right", tensors[2], "up"), - _make_edge("edge_left_leaf", tensors[1], "left", tensors[3], "up"), - _make_edge("edge_center_leaf", tensors[1], "down", tensors[4], "left"), - _make_edge("edge_right_center", tensors[2], "down", tensors[4], "right"), - _make_edge("edge_right_leaf", tensors[2], "right", tensors[5], "up"), - ] - return NetworkSpec( - id="template_mera_3", - name="MERA", - tensors=tensors, - edges=edges, - ) - - -def _build_generic_mera_template(parameters: TemplateParameters) -> NetworkSpec: - """Build a generic MERA layout with the requested depth.""" - depth = _resolve_graph_size(parameters) - levels: list[list[TensorSpec]] = [] - for level_index in range(depth): - level_tensors = [] - for position_index in range(level_index + 1): - tensor_indices: list[TemplateIndexConfig] = [] - if position_index > 0: - tensor_indices.append( - ("up_left", parameters.bond_dimension, LEFT_OFFSET) - ) - if position_index < level_index: - tensor_indices.append( - ("up_right", parameters.bond_dimension, UP_OFFSET) - ) - if level_index < depth - 1: - tensor_indices.append( - ("down_left", parameters.bond_dimension, LOWER_LEFT_OFFSET) - ) - tensor_indices.append( - ("down_right", parameters.bond_dimension, LOWER_RIGHT_OFFSET) - ) - if level_index == depth - 1: - tensor_indices.append( - ("phys", parameters.physical_dimension, DOWN_OFFSET) - ) - tensor = _make_tensor( - f"tensor_l{level_index + 1}_{position_index + 1}", - f"L{level_index + 1}-{position_index + 1}", - position_index * HORIZONTAL_SPACING - + ((depth - level_index - 1) * HORIZONTAL_SPACING) / 2, - level_index * LAYER_SPACING, - tensor_indices, - ) - level_tensors.append(tensor) - levels.append(level_tensors) - edges = [] - for level_index in range(depth - 1): - for position_index, tensor in enumerate(levels[level_index]): - left_child = levels[level_index + 1][position_index] - right_child = levels[level_index + 1][position_index + 1] - edges.append( - _make_edge( - f"edge_l{level_index + 1}_{position_index + 1}_left", - tensor, - "down_left", - left_child, - "up_right", - ) - ) - edges.append( - _make_edge( - f"edge_l{level_index + 1}_{position_index + 1}_right", - tensor, - "down_right", - right_child, - "up_left", - ) - ) - return NetworkSpec( - id=f"template_mera_{depth}", - name=f"MERA depth {depth}", - tensors=[tensor for level in levels for tensor in level], - edges=edges, - ) - - -def _build_ttn_template(parameters: TemplateParameters) -> NetworkSpec: - """Build the canonical TTN layout.""" - depth = _resolve_ttn_depth(parameters) - spec = _build_generic_ttn_template(parameters, depth=depth) - spec.id = f"template_ttn_{depth}" - spec.name = f"TTN depth {depth}" - spec.metadata = { - "template_name": "ttn", - "depth": depth, - "leaf_physical_legs": parameters.leaf_physical_legs, - "root_open_leg": parameters.root_open_leg, - "isometric": parameters.isometric, - } - for tensor in spec.tensors: - tensor.metadata = { - "role": "isometry" if parameters.isometric else "tensor", - "family": "ttn", - "isometric": parameters.isometric, - } - _annotate_physics_1d_indices(tensor, symmetry="none") - return spec - - -def _build_generic_ttn_template( - parameters: TemplateParameters, - *, - depth: int, -) -> NetworkSpec: - """Build a generic TTN with the requested depth.""" - levels: list[list[TensorSpec]] = [] - for level_index in range(depth): - level_tensors = [] - node_count = 2**level_index - for position_index in range(node_count): - tensor_indices: list[TemplateIndexConfig] = [] - if level_index > 0: - tensor_indices.append(("up", parameters.bond_dimension, UP_OFFSET)) - if level_index < depth - 1: - tensor_indices.append(("left", parameters.bond_dimension, LEFT_OFFSET)) - tensor_indices.append( - ("right", parameters.bond_dimension, RIGHT_OFFSET) - ) - if level_index == 0 and parameters.root_open_leg: - tensor_indices.append(("out", parameters.bond_dimension, UP_OFFSET)) - if level_index == depth - 1 and parameters.leaf_physical_legs: - tensor_indices.append( - ("phys", parameters.physical_dimension, DOWN_OFFSET) - ) - x_position = ( - ((2 * position_index + 1) * (2 ** (depth - level_index - 1)) - 1) - * TREE_LEAF_SPACING - / 2 - ) - tensor = _make_tensor( - f"tensor_l{level_index + 1}_{position_index + 1}", - f"L{level_index + 1}-{position_index + 1}", - x_position, - level_index * LAYER_SPACING, - tensor_indices, - ) - level_tensors.append(tensor) - levels.append(level_tensors) - edges = [] - for level_index in range(depth - 1): - for position_index, tensor in enumerate(levels[level_index]): - left_child = levels[level_index + 1][position_index * 2] - right_child = levels[level_index + 1][position_index * 2 + 1] - edges.append( - _make_edge( - f"edge_l{level_index + 1}_{position_index + 1}_left", - tensor, - "left", - left_child, - "up", - ) - ) - edges.append( - _make_edge( - f"edge_l{level_index + 1}_{position_index + 1}_right", - tensor, - "right", - right_child, - "up", - ) - ) - return NetworkSpec( - id=f"template_ttn_{depth}", - name=f"TTN depth {depth}", - tensors=[tensor for level in levels for tensor in level], - edges=edges, - ) - - -def _grid_tensor_name(row_index: int, column_index: int) -> str: - """Return a readable tensor name for a PEPS grid position.""" - if row_index < 26: - return f"{chr(ord('A') + row_index)}{column_index + 1}" - return f"R{row_index + 1}C{column_index + 1}" - - -def _build_pepo_template(parameters: TemplateParameters) -> NetworkSpec: - """Build a square PEPO operator grid with bra and ket physical legs.""" - size = _resolve_graph_size(parameters) - tensors: list[TensorSpec] = [] - tensor_lookup: dict[tuple[int, int], TensorSpec] = {} - for row_index in range(size): - for column_index in range(size): - tensor_indices: list[TemplateIndexConfig] = [] - if column_index > 0: - tensor_indices.append(("left", parameters.bond_dimension, LEFT_OFFSET)) - if column_index < size - 1: - tensor_indices.append( - ("right", parameters.bond_dimension, RIGHT_OFFSET) - ) - if row_index > 0: - tensor_indices.append(("up", parameters.bond_dimension, UP_OFFSET)) - if row_index < size - 1: - tensor_indices.append(("down", parameters.bond_dimension, DOWN_OFFSET)) - tensor_indices.extend( - [ - ("bra", parameters.physical_dimension, UPPER_PHYSICAL_OFFSET), - ("ket", parameters.physical_dimension, LOWER_PHYSICAL_OFFSET), - ] - ) - tensor = _make_tensor( - f"tensor_r{row_index + 1}_c{column_index + 1}", - _grid_tensor_name(row_index, column_index), - 340.0 * column_index, - VERTICAL_SPACING * row_index, - tensor_indices, - ) - tensors.append(tensor) - tensor_lookup[(row_index, column_index)] = tensor - edges = [] - for row_index in range(size): - for column_index in range(size): - current_tensor = tensor_lookup[(row_index, column_index)] - if column_index + 1 < size: - edges.append( - _make_edge( - f"edge_r{row_index + 1}_c{column_index + 1}_right", - current_tensor, - "right", - tensor_lookup[(row_index, column_index + 1)], - "left", - ) - ) - if row_index + 1 < size: - edges.append( - _make_edge( - f"edge_r{row_index + 1}_c{column_index + 1}_down", - current_tensor, - "down", - tensor_lookup[(row_index + 1, column_index)], - "up", - ) - ) - return NetworkSpec( - id=f"template_pepo_{size}", - name=f"PEPO {size}x{size}", - tensors=tensors, - edges=edges, - ) - - -def _build_tebd_gate_layer_template(parameters: TemplateParameters) -> NetworkSpec: - """Build an MPS chain with an even TEBD two-site gate layer.""" - site_count = _resolve_graph_size(parameters) - site_tensors = [ - _make_tensor( - f"tensor_site_{site_index + 1}", - f"A{site_index + 1}", - HORIZONTAL_SPACING * site_index, - 0.0, - _build_mps_site_indices(site_index, site_count, parameters), - ) - for site_index in range(site_count) - ] - gate_tensors = [ - _make_tensor( - f"tensor_gate_{site_index + 1}_{site_index + 2}", - f"G{site_index + 1}-{site_index + 2}", - HORIZONTAL_SPACING * (site_index + 0.5), - 220.0, - [ - ( - "out_left", - parameters.physical_dimension, - GATE_UPPER_LEFT_OFFSET, - ), - ( - "out_right", - parameters.physical_dimension, - GATE_UPPER_RIGHT_OFFSET, - ), - ("in_left", parameters.physical_dimension, GATE_LOWER_LEFT_OFFSET), - ("in_right", parameters.physical_dimension, GATE_LOWER_RIGHT_OFFSET), - ], - ) - for site_index in range(0, site_count - 1, 2) - ] - edges = _make_linear_chain_edges(site_tensors) - for gate_index, gate_tensor in enumerate(gate_tensors): - left_site_index = gate_index * 2 - right_site_index = left_site_index + 1 - edges.extend( - [ - _make_edge( - f"edge_gate_{left_site_index + 1}_{right_site_index + 1}_left", - site_tensors[left_site_index], - "phys", - gate_tensor, - "in_left", - ), - _make_edge( - f"edge_gate_{left_site_index + 1}_{right_site_index + 1}_right", - site_tensors[right_site_index], - "phys", - gate_tensor, - "in_right", - ), - ] - ) - for tensor in site_tensors: - tensor.metadata = { - "role": "state", - "state": "tebd_input", - "symmetry": "z2", - "tags": "tebd mps site", - } - _annotate_physics_1d_indices(tensor, symmetry="z2") - for tensor in gate_tensors: - tensor.metadata = { - "role": "gate", - "symmetry": "z2", - "tags": "tebd even layer", - } - _annotate_physics_1d_indices(tensor, symmetry="z2") - definition = TEMPLATE_DEFINITIONS["tebd_gate_layer"] - spec_name = ( - definition.display_name - if site_count == definition.defaults.graph_size - else f"{definition.display_name} ({site_count} {definition.graph_size_label.lower()})" - ) - return NetworkSpec( - id=f"template_tebd_gate_layer_{site_count}", - name=spec_name, - tensors=[*site_tensors, *gate_tensors], - edges=edges, - ) - - -def _annotate_physics_1d_indices(tensor: TensorSpec, *, symmetry: str) -> None: - """Add guided metadata to standard 1D physics template indices.""" - for index in tensor.indices: - if index.name in {"left", "right"}: - index.metadata = {"leg_kind": "bond", "symmetry": symmetry} - else: - index.metadata = {"leg_kind": "physical", "symmetry": symmetry} - - -def _apply_mps_template_configuration( - spec: NetworkSpec, - parameters: TemplateParameters, -) -> NetworkSpec: - """Attach metadata and tensor initialization presets to the built MPS.""" - spec.metadata = { - "template_name": "mps", - "role": "state", - "boundary_condition": parameters.boundary_condition, - "symmetry": parameters.symmetry, - "initial_state": parameters.initial_state, - } - for tensor_index, tensor in enumerate(spec.tensors): - tensor.metadata = { - "role": "state", - "family": "mps", - "symmetry": parameters.symmetry, - "initial_state": parameters.initial_state, - } - _annotate_physics_1d_indices(tensor, symmetry=parameters.symmetry) - tensor.tensor_data = _build_mps_tensor_data( - tensor, - tensor_index=tensor_index, - parameters=parameters, - ) - return spec - - -def _apply_mpo_template_configuration( - spec: NetworkSpec, - parameters: TemplateParameters, -) -> NetworkSpec: - """Attach semantic MPO metadata to the built operator chain.""" - spec.metadata = { - "template_name": "mpo", - "role": "operator", - "boundary_condition": parameters.boundary_condition, - "j": parameters.j, - "h": parameters.h, - } - for tensor in spec.tensors: - tensor.metadata = { - "role": "operator", - "family": "mpo", - "boundary_condition": parameters.boundary_condition, - "j": parameters.j, - "h": parameters.h, - } - _annotate_physics_1d_indices(tensor, symmetry="none") - return spec - - -def _build_mps_tensor_data( - tensor: TensorSpec, - *, - tensor_index: int, - parameters: TemplateParameters, -) -> TensorDataSpec: - """Return the tensor-data initializer matching the chosen MPS preset.""" - if parameters.initial_state == "zeros": - return TensorDataSpec(mode=TensorDataMode.ZEROS) - if parameters.initial_state == "random": - return TensorDataSpec( - mode=TensorDataMode.RANDOM, - seed=tensor_index, - ) - if parameters.initial_state == "all_up": - return _build_mps_literal_state_tensor_data(tensor, basis_index=0) - if parameters.initial_state == "all_down": - return _build_mps_literal_state_tensor_data(tensor, basis_index=1) - return _build_mps_literal_state_tensor_data( - tensor, - basis_index=tensor_index % 2, - ) - - -def _build_mps_literal_state_tensor_data( - tensor: TensorSpec, - *, - basis_index: int, -) -> TensorDataSpec: - """Build one explicit basis-state tensor embedded into the current shape.""" - values = _build_zero_literal(list(tensor.shape)) - _set_nested_literal_value( - values, - [0] * (len(tensor.shape) - 1) + [basis_index], - 1.0, - ) - return TensorDataSpec( - mode=TensorDataMode.LITERAL, - values=values, - ) - - -def _build_zero_literal(shape: list[int]) -> TensorNumericLiteral: - """Build one nested zero-filled literal matching the provided shape.""" - if not shape: - return 0.0 - return [_build_zero_literal(shape[1:]) for _ in range(shape[0])] - - -def _set_nested_literal_value( - values: TensorNumericLiteral, - index_path: list[int], - value: float, -) -> None: - """Assign one scalar inside a nested tensor literal structure.""" - current_values = cast(list[TensorNumericLiteral], values) - for index in index_path[:-1]: - current_values = cast(list[TensorNumericLiteral], current_values[index]) - current_values[index_path[-1]] = value - - -def _make_tensor( - tensor_id: str, - name: str, - x: float, - y: float, - indices: list[TemplateIndexConfig], -) -> TensorSpec: - """Create one template tensor with named indices and canvas placement.""" - return TensorSpec( - id=tensor_id, - name=name, - position=CanvasPosition(x=x, y=y), - indices=[ - _make_named_index(tensor_id, suffix, dimension, offset) - for suffix, dimension, offset in indices - ], - ) - - -def _make_named_index( - tensor_id: str, - suffix: str, - dimension: int | None, - offset: tuple[float, float], -) -> IndexSpec: - """Create one named index for a template tensor.""" - return IndexSpec( - id=f"{tensor_id}_{suffix}", - name=suffix, - dimension=_resolve_required_dimension(dimension), - offset=CanvasPosition(x=offset[0], y=offset[1]), - ) - - -def _resolve_required_dimension(dimension: int | None) -> int: - """Return one validated template index dimension.""" - if dimension is None: - raise ValueError("Template index dimensions must be resolved before building.") - return dimension - - -def _make_edge( - edge_id: str, - left_tensor: TensorSpec, - left_index_suffix: str, - right_tensor: TensorSpec, - right_index_suffix: str, -) -> EdgeSpec: - """Create one template edge between two named tensor indices.""" - return EdgeSpec( - id=edge_id, - name=edge_id.replace("_", "-"), - left=EdgeEndpointRef( - tensor_id=left_tensor.id, - index_id=f"{left_tensor.id}_{left_index_suffix}", - ), - right=EdgeEndpointRef( - tensor_id=right_tensor.id, - index_id=f"{right_tensor.id}_{right_index_suffix}", - ), - ) - - -def _resolve_graph_size(parameters: TemplateParameters) -> int: - """Return the validated graph-size parameter for size-based templates.""" - if parameters.graph_size is None: - raise ValueError("Template parameter 'graph_size' is required.") - return parameters.graph_size - - -def _resolve_ttn_depth(parameters: TemplateParameters) -> int: - """Return the validated depth parameter for the TTN template.""" - if parameters.depth is None: - raise ValueError("Template parameter 'depth' is required.") - return parameters.depth - - def register_builtin_templates() -> None: """Register the built-in templates in their stable display order.""" register_template( @@ -981,3 +87,10 @@ def register_builtin_templates() -> None: _build_tebd_gate_layer_template, overwrite=True, ) + + +__all__ = [ + "build_template", + "register_builtin_templates", + "_build_linear_chain_template", +] diff --git a/src/tensor_network_editor/internal/templates/_template_builders_common.py b/src/tensor_network_editor/internal/templates/_template_builders_common.py new file mode 100644 index 0000000..4940206 --- /dev/null +++ b/src/tensor_network_editor/internal/templates/_template_builders_common.py @@ -0,0 +1,132 @@ +"""Shared primitives and constants for internal template builders.""" + +from __future__ import annotations + +from typing import cast + +from ...models import CanvasPosition, EdgeEndpointRef, EdgeSpec, IndexSpec, TensorSpec +from ..models._model_tensor_data import TensorNumericLiteral +from ._template_catalog import TemplateParameters + +HORIZONTAL_SPACING = 320.0 +VERTICAL_SPACING = 280.0 +LAYER_SPACING = 210.0 +TREE_LEAF_SPACING = 220.0 +LEFT_OFFSET = (-58.0, 0.0) +RIGHT_OFFSET = (58.0, 0.0) +UP_OFFSET = (0.0, -28.0) +DOWN_OFFSET = (0.0, 28.0) +LOWER_LEFT_OFFSET = (-24.0, 34.0) +LOWER_RIGHT_OFFSET = (24.0, 34.0) +UPPER_PHYSICAL_OFFSET = (-26.0, -54.0) +LOWER_PHYSICAL_OFFSET = (26.0, 42.0) +GATE_UPPER_LEFT_OFFSET = (-36.0, -38.0) +GATE_UPPER_RIGHT_OFFSET = (36.0, -38.0) +GATE_LOWER_LEFT_OFFSET = (-36.0, 38.0) +GATE_LOWER_RIGHT_OFFSET = (36.0, 38.0) +TemplateIndexConfig = tuple[str, int | None, tuple[float, float]] + + +def _annotate_physics_1d_indices(tensor: TensorSpec, *, symmetry: str) -> None: + """Add guided metadata to standard 1D physics template indices.""" + for index in tensor.indices: + if index.name in {"left", "right"}: + index.metadata = {"leg_kind": "bond", "symmetry": symmetry} + else: + index.metadata = {"leg_kind": "physical", "symmetry": symmetry} + + +def _build_zero_literal(shape: list[int]) -> TensorNumericLiteral: + """Build one nested zero-filled literal matching the provided shape.""" + if not shape: + return 0.0 + return [_build_zero_literal(shape[1:]) for _ in range(shape[0])] + + +def _set_nested_literal_value( + values: TensorNumericLiteral, + index_path: list[int], + value: float, +) -> None: + """Assign one scalar inside a nested tensor literal structure.""" + current_values = cast(list[TensorNumericLiteral], values) + for index in index_path[:-1]: + current_values = cast(list[TensorNumericLiteral], current_values[index]) + current_values[index_path[-1]] = value + + +def _make_tensor( + tensor_id: str, + name: str, + x: float, + y: float, + indices: list[TemplateIndexConfig], +) -> TensorSpec: + """Create one template tensor with named indices and canvas placement.""" + return TensorSpec( + id=tensor_id, + name=name, + position=CanvasPosition(x=x, y=y), + indices=[ + _make_named_index(tensor_id, suffix, dimension, offset) + for suffix, dimension, offset in indices + ], + ) + + +def _make_named_index( + tensor_id: str, + suffix: str, + dimension: int | None, + offset: tuple[float, float], +) -> IndexSpec: + """Create one named index for a template tensor.""" + return IndexSpec( + id=f"{tensor_id}_{suffix}", + name=suffix, + dimension=_resolve_required_dimension(dimension), + offset=CanvasPosition(x=offset[0], y=offset[1]), + ) + + +def _resolve_required_dimension(dimension: int | None) -> int: + """Return one validated template index dimension.""" + if dimension is None: + raise ValueError("Template index dimensions must be resolved before building.") + return dimension + + +def _make_edge( + edge_id: str, + left_tensor: TensorSpec, + left_index_suffix: str, + right_tensor: TensorSpec, + right_index_suffix: str, +) -> EdgeSpec: + """Create one template edge between two named tensor indices.""" + return EdgeSpec( + id=edge_id, + name=edge_id.replace("_", "-"), + left=EdgeEndpointRef( + tensor_id=left_tensor.id, + index_id=f"{left_tensor.id}_{left_index_suffix}", + ), + right=EdgeEndpointRef( + tensor_id=right_tensor.id, + index_id=f"{right_tensor.id}_{right_index_suffix}", + ), + ) + + +def _resolve_graph_size(parameters: TemplateParameters) -> int: + """Return the validated graph-size parameter for size-based templates.""" + if parameters.graph_size is None: + raise ValueError("Template parameter 'graph_size' is required.") + return parameters.graph_size + + +def _resolve_ttn_depth(parameters: TemplateParameters) -> int: + """Return the validated depth parameter for the TTN template.""" + if parameters.depth is None: + raise ValueError("Template parameter 'depth' is required.") + return parameters.depth diff --git a/src/tensor_network_editor/internal/templates/_template_builders_grid.py b/src/tensor_network_editor/internal/templates/_template_builders_grid.py new file mode 100644 index 0000000..e8d06e9 --- /dev/null +++ b/src/tensor_network_editor/internal/templates/_template_builders_grid.py @@ -0,0 +1,333 @@ +"""Grid and layer-oriented internal template builders.""" + +from __future__ import annotations + +from collections.abc import Callable + +from ...models import EdgeSpec, NetworkSpec, TensorSpec +from ._template_builders_common import ( + DOWN_OFFSET, + GATE_LOWER_LEFT_OFFSET, + GATE_LOWER_RIGHT_OFFSET, + GATE_UPPER_LEFT_OFFSET, + GATE_UPPER_RIGHT_OFFSET, + HORIZONTAL_SPACING, + LEFT_OFFSET, + LOWER_LEFT_OFFSET, + LOWER_PHYSICAL_OFFSET, + LOWER_RIGHT_OFFSET, + RIGHT_OFFSET, + UP_OFFSET, + UPPER_PHYSICAL_OFFSET, + VERTICAL_SPACING, + TemplateIndexConfig, + _annotate_physics_1d_indices, + _make_edge, + _make_tensor, + _resolve_graph_size, +) +from ._template_builders_linear import _build_mps_site_indices, _make_linear_chain_edges +from ._template_catalog import TEMPLATE_DEFINITIONS, TemplateParameters + +GridSiteIndexBuilder = Callable[ + [int, int, int, TemplateParameters], list[TemplateIndexConfig] +] + + +def _build_peps_template(parameters: TemplateParameters) -> NetworkSpec: + """Build the requested PEPS template variant.""" + if _resolve_graph_size(parameters) == 2: + return _build_default_peps_template(parameters) + return _build_generic_peps_template(parameters) + + +def _build_default_peps_template(parameters: TemplateParameters) -> NetworkSpec: + """Build the default 2x2 PEPS layout.""" + tensors = [ + _make_tensor( + "tensor_a", + "A", + 0.0, + 0.0, + [ + ("right", parameters.bond_dimension, RIGHT_OFFSET), + ("down", parameters.bond_dimension, DOWN_OFFSET), + ("phys", parameters.physical_dimension, LOWER_LEFT_OFFSET), + ], + ), + _make_tensor( + "tensor_b", + "B", + 340.0, + 0.0, + [ + ("left", parameters.bond_dimension, LEFT_OFFSET), + ("down", parameters.bond_dimension, DOWN_OFFSET), + ("phys", parameters.physical_dimension, LOWER_RIGHT_OFFSET), + ], + ), + _make_tensor( + "tensor_c", + "C", + 0.0, + VERTICAL_SPACING, + [ + ("right", parameters.bond_dimension, RIGHT_OFFSET), + ("up", parameters.bond_dimension, UP_OFFSET), + ("phys", parameters.physical_dimension, LOWER_LEFT_OFFSET), + ], + ), + _make_tensor( + "tensor_d", + "D", + 340.0, + VERTICAL_SPACING, + [ + ("left", parameters.bond_dimension, LEFT_OFFSET), + ("up", parameters.bond_dimension, UP_OFFSET), + ("phys", parameters.physical_dimension, LOWER_RIGHT_OFFSET), + ], + ), + ] + edges = [ + _make_edge("edge_ab", tensors[0], "right", tensors[1], "left"), + _make_edge("edge_cd", tensors[2], "right", tensors[3], "left"), + _make_edge("edge_ac", tensors[0], "down", tensors[2], "up"), + _make_edge("edge_bd", tensors[1], "down", tensors[3], "up"), + ] + return NetworkSpec( + id="template_peps_2", + name="PEPS 2x2", + tensors=tensors, + edges=edges, + ) + + +def _build_generic_peps_template(parameters: TemplateParameters) -> NetworkSpec: + """Build a square PEPS grid larger than the default 2x2 layout.""" + size = _resolve_graph_size(parameters) + tensors, edges = _build_square_grid_tensors_and_edges( + size=size, + parameters=parameters, + site_index_builder=_build_peps_grid_site_indices, + ) + return NetworkSpec( + id=f"template_peps_{size}", + name=f"PEPS {size}x{size}", + tensors=tensors, + edges=edges, + ) + + +def _build_pepo_template(parameters: TemplateParameters) -> NetworkSpec: + """Build a square PEPO operator grid with bra and ket physical legs.""" + size = _resolve_graph_size(parameters) + tensors, edges = _build_square_grid_tensors_and_edges( + size=size, + parameters=parameters, + site_index_builder=_build_pepo_grid_site_indices, + ) + return NetworkSpec( + id=f"template_pepo_{size}", + name=f"PEPO {size}x{size}", + tensors=tensors, + edges=edges, + ) + + +def _build_square_grid_tensors_and_edges( + *, + size: int, + parameters: TemplateParameters, + site_index_builder: GridSiteIndexBuilder, +) -> tuple[list[TensorSpec], list[EdgeSpec]]: + """Build the tensors and nearest-neighbor edges for one square grid.""" + tensors: list[TensorSpec] = [] + tensor_lookup: dict[tuple[int, int], TensorSpec] = {} + for row_index in range(size): + for column_index in range(size): + tensor = _make_tensor( + f"tensor_r{row_index + 1}_c{column_index + 1}", + _grid_tensor_name(row_index, column_index), + 340.0 * column_index, + VERTICAL_SPACING * row_index, + site_index_builder(row_index, column_index, size, parameters), + ) + tensors.append(tensor) + tensor_lookup[(row_index, column_index)] = tensor + edges: list[EdgeSpec] = [] + for row_index in range(size): + for column_index in range(size): + current_tensor = tensor_lookup[(row_index, column_index)] + if column_index + 1 < size: + edges.append( + _make_edge( + f"edge_r{row_index + 1}_c{column_index + 1}_right", + current_tensor, + "right", + tensor_lookup[(row_index, column_index + 1)], + "left", + ) + ) + if row_index + 1 < size: + edges.append( + _make_edge( + f"edge_r{row_index + 1}_c{column_index + 1}_down", + current_tensor, + "down", + tensor_lookup[(row_index + 1, column_index)], + "up", + ) + ) + return tensors, edges + + +def _build_peps_grid_site_indices( + row_index: int, + column_index: int, + size: int, + parameters: TemplateParameters, +) -> list[TemplateIndexConfig]: + """Return the index layout for one PEPS grid tensor.""" + tensor_indices = _build_grid_neighbor_indices( + row_index=row_index, + column_index=column_index, + size=size, + parameters=parameters, + ) + tensor_indices.append( + ( + "phys", + parameters.physical_dimension, + LOWER_LEFT_OFFSET if column_index % 2 == 0 else LOWER_RIGHT_OFFSET, + ) + ) + return tensor_indices + + +def _build_pepo_grid_site_indices( + row_index: int, + column_index: int, + size: int, + parameters: TemplateParameters, +) -> list[TemplateIndexConfig]: + """Return the index layout for one PEPO grid tensor.""" + tensor_indices = _build_grid_neighbor_indices( + row_index=row_index, + column_index=column_index, + size=size, + parameters=parameters, + ) + tensor_indices.extend( + [ + ("bra", parameters.physical_dimension, UPPER_PHYSICAL_OFFSET), + ("ket", parameters.physical_dimension, LOWER_PHYSICAL_OFFSET), + ] + ) + return tensor_indices + + +def _build_grid_neighbor_indices( + *, + row_index: int, + column_index: int, + size: int, + parameters: TemplateParameters, +) -> list[TemplateIndexConfig]: + """Return the horizontal and vertical bond indices for one grid tensor.""" + tensor_indices: list[TemplateIndexConfig] = [] + if column_index > 0: + tensor_indices.append(("left", parameters.bond_dimension, LEFT_OFFSET)) + if column_index < size - 1: + tensor_indices.append(("right", parameters.bond_dimension, RIGHT_OFFSET)) + if row_index > 0: + tensor_indices.append(("up", parameters.bond_dimension, UP_OFFSET)) + if row_index < size - 1: + tensor_indices.append(("down", parameters.bond_dimension, DOWN_OFFSET)) + return tensor_indices + + +def _grid_tensor_name(row_index: int, column_index: int) -> str: + """Return a readable tensor name for a square-grid position.""" + if row_index < 26: + return f"{chr(ord('A') + row_index)}{column_index + 1}" + return f"R{row_index + 1}C{column_index + 1}" + + +def _build_tebd_gate_layer_template(parameters: TemplateParameters) -> NetworkSpec: + """Build an MPS chain with an even TEBD two-site gate layer.""" + site_count = _resolve_graph_size(parameters) + site_tensors = [ + _make_tensor( + f"tensor_site_{site_index + 1}", + f"A{site_index + 1}", + HORIZONTAL_SPACING * site_index, + 0.0, + _build_mps_site_indices(site_index, site_count, parameters), + ) + for site_index in range(site_count) + ] + gate_tensors = [ + _make_tensor( + f"tensor_gate_{site_index + 1}_{site_index + 2}", + f"G{site_index + 1}-{site_index + 2}", + HORIZONTAL_SPACING * (site_index + 0.5), + 220.0, + [ + ("out_left", parameters.physical_dimension, GATE_UPPER_LEFT_OFFSET), + ("out_right", parameters.physical_dimension, GATE_UPPER_RIGHT_OFFSET), + ("in_left", parameters.physical_dimension, GATE_LOWER_LEFT_OFFSET), + ("in_right", parameters.physical_dimension, GATE_LOWER_RIGHT_OFFSET), + ], + ) + for site_index in range(0, site_count - 1, 2) + ] + edges = _make_linear_chain_edges(site_tensors) + for gate_index, gate_tensor in enumerate(gate_tensors): + left_site_index = gate_index * 2 + right_site_index = left_site_index + 1 + edges.extend( + [ + _make_edge( + f"edge_gate_{left_site_index + 1}_{right_site_index + 1}_left", + site_tensors[left_site_index], + "phys", + gate_tensor, + "in_left", + ), + _make_edge( + f"edge_gate_{left_site_index + 1}_{right_site_index + 1}_right", + site_tensors[right_site_index], + "phys", + gate_tensor, + "in_right", + ), + ] + ) + for tensor in site_tensors: + tensor.metadata = { + "role": "state", + "state": "tebd_input", + "symmetry": "z2", + "tags": "tebd mps site", + } + _annotate_physics_1d_indices(tensor, symmetry="z2") + for tensor in gate_tensors: + tensor.metadata = { + "role": "gate", + "symmetry": "z2", + "tags": "tebd even layer", + } + _annotate_physics_1d_indices(tensor, symmetry="z2") + definition = TEMPLATE_DEFINITIONS["tebd_gate_layer"] + spec_name = ( + definition.display_name + if site_count == definition.defaults.graph_size + else f"{definition.display_name} ({site_count} {definition.graph_size_label.lower()})" + ) + return NetworkSpec( + id=f"template_tebd_gate_layer_{site_count}", + name=spec_name, + tensors=[*site_tensors, *gate_tensors], + edges=edges, + ) diff --git a/src/tensor_network_editor/internal/templates/_template_builders_linear.py b/src/tensor_network_editor/internal/templates/_template_builders_linear.py new file mode 100644 index 0000000..f34812d --- /dev/null +++ b/src/tensor_network_editor/internal/templates/_template_builders_linear.py @@ -0,0 +1,245 @@ +"""Linear-family internal template builders.""" + +from __future__ import annotations + +from collections.abc import Callable + +from ...models import EdgeSpec, NetworkSpec, TensorDataMode, TensorDataSpec, TensorSpec +from ._template_builders_common import ( + DOWN_OFFSET, + HORIZONTAL_SPACING, + LEFT_OFFSET, + RIGHT_OFFSET, + UP_OFFSET, + TemplateIndexConfig, + _annotate_physics_1d_indices, + _build_zero_literal, + _make_edge, + _make_tensor, + _resolve_graph_size, + _set_nested_literal_value, +) +from ._template_catalog import TEMPLATE_DEFINITIONS, TemplateParameters + +LinearChainSiteIndexBuilder = Callable[ + [int, int, TemplateParameters], list[TemplateIndexConfig] +] + + +def _build_mps_template(parameters: TemplateParameters) -> NetworkSpec: + """Build an MPS template with the requested site count and dimensions.""" + spec = _build_linear_chain_template( + "mps", + parameters, + tensor_name_prefix="A", + spacing=HORIZONTAL_SPACING, + site_index_builder=_build_mps_site_indices, + periodic=parameters.boundary_condition == "periodic", + ) + return _apply_mps_template_configuration(spec, parameters) + + +def _build_mpo_template(parameters: TemplateParameters) -> NetworkSpec: + """Build an MPO template with the requested site count and dimensions.""" + spec = _build_linear_chain_template( + "mpo", + parameters, + tensor_name_prefix="W", + spacing=330.0, + site_index_builder=_build_mpo_site_indices, + periodic=parameters.boundary_condition == "periodic", + ) + return _apply_mpo_template_configuration(spec, parameters) + + +def _build_linear_chain_template( + template_name: str, + parameters: TemplateParameters, + *, + tensor_name_prefix: str, + spacing: float, + site_index_builder: LinearChainSiteIndexBuilder, + periodic: bool = False, +) -> NetworkSpec: + """Build one left-to-right chain template from a per-site index factory.""" + length = _resolve_graph_size(parameters) + tensors = [ + _make_tensor( + f"tensor_{site_index}", + f"{tensor_name_prefix}{site_index + 1}", + spacing * site_index, + 0.0, + site_index_builder(site_index, length, parameters), + ) + for site_index in range(length) + ] + definition = TEMPLATE_DEFINITIONS[template_name] + spec_name = ( + definition.display_name + if length == definition.defaults.graph_size + else f"{definition.display_name} ({length} {definition.graph_size_label.lower()})" + ) + return NetworkSpec( + id=f"template_{template_name}_{length}", + name=spec_name, + tensors=tensors, + edges=_make_linear_chain_edges(tensors, periodic=periodic), + ) + + +def _build_mps_site_indices( + site_index: int, + length: int, + parameters: TemplateParameters, +) -> list[TemplateIndexConfig]: + """Return the named index layout for one MPS site.""" + tensor_indices: list[TemplateIndexConfig] = [] + if parameters.boundary_condition == "periodic" or site_index > 0: + tensor_indices.append(("left", parameters.bond_dimension, LEFT_OFFSET)) + if parameters.boundary_condition == "periodic" or site_index < length - 1: + tensor_indices.append(("right", parameters.bond_dimension, RIGHT_OFFSET)) + tensor_indices.append(("phys", parameters.physical_dimension, DOWN_OFFSET)) + return tensor_indices + + +def _build_mpo_site_indices( + site_index: int, + length: int, + parameters: TemplateParameters, +) -> list[TemplateIndexConfig]: + """Return the named index layout for one MPO site.""" + tensor_indices: list[TemplateIndexConfig] = [] + if parameters.boundary_condition == "periodic" or site_index > 0: + tensor_indices.append(("left", parameters.bond_dimension, LEFT_OFFSET)) + if parameters.boundary_condition == "periodic" or site_index < length - 1: + tensor_indices.append(("right", parameters.bond_dimension, RIGHT_OFFSET)) + tensor_indices.extend( + [ + ("bra", parameters.physical_dimension, UP_OFFSET), + ("ket", parameters.physical_dimension, DOWN_OFFSET), + ] + ) + return tensor_indices + + +def _make_linear_chain_edges( + tensors: list[TensorSpec], + *, + periodic: bool = False, +) -> list[EdgeSpec]: + """Return the right-to-left bonds between adjacent chain tensors.""" + edges = [ + _make_edge( + f"edge_{site_index}_{site_index + 1}", + tensors[site_index], + "right", + tensors[site_index + 1], + "left", + ) + for site_index in range(len(tensors) - 1) + ] + if periodic and len(tensors) > 1: + edges.append( + _make_edge( + f"edge_{len(tensors)}_1", + tensors[-1], + "right", + tensors[0], + "left", + ) + ) + return edges + + +def _apply_mps_template_configuration( + spec: NetworkSpec, + parameters: TemplateParameters, +) -> NetworkSpec: + """Attach metadata and tensor initialization presets to the built MPS.""" + spec.metadata = { + "template_name": "mps", + "role": "state", + "boundary_condition": parameters.boundary_condition, + "symmetry": parameters.symmetry, + "initial_state": parameters.initial_state, + } + for tensor_index, tensor in enumerate(spec.tensors): + tensor.metadata = { + "role": "state", + "family": "mps", + "symmetry": parameters.symmetry, + "initial_state": parameters.initial_state, + } + _annotate_physics_1d_indices(tensor, symmetry=parameters.symmetry) + tensor.tensor_data = _build_mps_tensor_data( + tensor, + tensor_index=tensor_index, + parameters=parameters, + ) + return spec + + +def _apply_mpo_template_configuration( + spec: NetworkSpec, + parameters: TemplateParameters, +) -> NetworkSpec: + """Attach semantic MPO metadata to the built operator chain.""" + spec.metadata = { + "template_name": "mpo", + "role": "operator", + "boundary_condition": parameters.boundary_condition, + "j": parameters.j, + "h": parameters.h, + } + for tensor in spec.tensors: + tensor.metadata = { + "role": "operator", + "family": "mpo", + "boundary_condition": parameters.boundary_condition, + "j": parameters.j, + "h": parameters.h, + } + _annotate_physics_1d_indices(tensor, symmetry="none") + return spec + + +def _build_mps_tensor_data( + tensor: TensorSpec, + *, + tensor_index: int, + parameters: TemplateParameters, +) -> TensorDataSpec: + """Return the tensor-data initializer matching the chosen MPS preset.""" + if parameters.initial_state == "zeros": + return TensorDataSpec(mode=TensorDataMode.ZEROS) + if parameters.initial_state == "random": + return TensorDataSpec( + mode=TensorDataMode.RANDOM, + seed=tensor_index, + ) + if parameters.initial_state == "all_up": + return _build_mps_literal_state_tensor_data(tensor, basis_index=0) + if parameters.initial_state == "all_down": + return _build_mps_literal_state_tensor_data(tensor, basis_index=1) + return _build_mps_literal_state_tensor_data( + tensor, + basis_index=tensor_index % 2, + ) + + +def _build_mps_literal_state_tensor_data( + tensor: TensorSpec, + *, + basis_index: int, +) -> TensorDataSpec: + """Build one explicit basis-state tensor embedded into the current shape.""" + values = _build_zero_literal(list(tensor.shape)) + _set_nested_literal_value( + values, + [0] * (len(tensor.shape) - 1) + [basis_index], + 1.0, + ) + return TensorDataSpec( + mode=TensorDataMode.LITERAL, + values=values, + ) diff --git a/src/tensor_network_editor/internal/templates/_template_builders_tree.py b/src/tensor_network_editor/internal/templates/_template_builders_tree.py new file mode 100644 index 0000000..404388f --- /dev/null +++ b/src/tensor_network_editor/internal/templates/_template_builders_tree.py @@ -0,0 +1,277 @@ +"""Tree-family internal template builders.""" + +from __future__ import annotations + +from ...models import NetworkSpec, TensorSpec +from ._template_builders_common import ( + DOWN_OFFSET, + HORIZONTAL_SPACING, + LAYER_SPACING, + LEFT_OFFSET, + LOWER_LEFT_OFFSET, + LOWER_RIGHT_OFFSET, + RIGHT_OFFSET, + TREE_LEAF_SPACING, + UP_OFFSET, + TemplateIndexConfig, + _annotate_physics_1d_indices, + _make_edge, + _make_tensor, + _resolve_graph_size, + _resolve_ttn_depth, +) +from ._template_catalog import TEMPLATE_DEFINITIONS, TemplateParameters + + +def _build_mera_template(parameters: TemplateParameters) -> NetworkSpec: + """Build the requested MERA template variant.""" + if ( + _resolve_graph_size(parameters) + == TEMPLATE_DEFINITIONS["mera"].defaults.graph_size + ): + return _build_default_mera_template(parameters) + return _build_generic_mera_template(parameters) + + +def _build_default_mera_template(parameters: TemplateParameters) -> NetworkSpec: + """Build the default depth-3 MERA layout.""" + tensors = [ + _make_tensor( + "tensor_top", + "Top", + 320.0, + 0.0, + [ + ("left", parameters.bond_dimension, LEFT_OFFSET), + ("right", parameters.bond_dimension, RIGHT_OFFSET), + ], + ), + _make_tensor( + "tensor_mid_left", + "Mid L", + 120.0, + 210.0, + [ + ("up", parameters.bond_dimension, UP_OFFSET), + ("left", parameters.bond_dimension, LEFT_OFFSET), + ("down", parameters.bond_dimension, DOWN_OFFSET), + ], + ), + _make_tensor( + "tensor_mid_right", + "Mid R", + 520.0, + 210.0, + [ + ("up", parameters.bond_dimension, UP_OFFSET), + ("down", parameters.bond_dimension, DOWN_OFFSET), + ("right", parameters.bond_dimension, RIGHT_OFFSET), + ], + ), + _make_tensor( + "tensor_leaf_left", + "Leaf L", + 0.0, + 420.0, + [ + ("up", parameters.bond_dimension, UP_OFFSET), + ("phys", parameters.physical_dimension, DOWN_OFFSET), + ], + ), + _make_tensor( + "tensor_leaf_mid", + "Leaf M", + 320.0, + 420.0, + [ + ("left", parameters.bond_dimension, LEFT_OFFSET), + ("right", parameters.bond_dimension, RIGHT_OFFSET), + ("phys", parameters.physical_dimension, DOWN_OFFSET), + ], + ), + _make_tensor( + "tensor_leaf_right", + "Leaf R", + 640.0, + 420.0, + [ + ("up", parameters.bond_dimension, UP_OFFSET), + ("phys", parameters.physical_dimension, DOWN_OFFSET), + ], + ), + ] + edges = [ + _make_edge("edge_top_left", tensors[0], "left", tensors[1], "up"), + _make_edge("edge_top_right", tensors[0], "right", tensors[2], "up"), + _make_edge("edge_left_leaf", tensors[1], "left", tensors[3], "up"), + _make_edge("edge_center_leaf", tensors[1], "down", tensors[4], "left"), + _make_edge("edge_right_center", tensors[2], "down", tensors[4], "right"), + _make_edge("edge_right_leaf", tensors[2], "right", tensors[5], "up"), + ] + return NetworkSpec( + id="template_mera_3", + name="MERA", + tensors=tensors, + edges=edges, + ) + + +def _build_generic_mera_template(parameters: TemplateParameters) -> NetworkSpec: + """Build a generic MERA layout with the requested depth.""" + depth = _resolve_graph_size(parameters) + levels: list[list[TensorSpec]] = [] + for level_index in range(depth): + level_tensors: list[TensorSpec] = [] + for position_index in range(level_index + 1): + tensor_indices: list[TemplateIndexConfig] = [] + if position_index > 0: + tensor_indices.append( + ("up_left", parameters.bond_dimension, LEFT_OFFSET) + ) + if position_index < level_index: + tensor_indices.append( + ("up_right", parameters.bond_dimension, UP_OFFSET) + ) + if level_index < depth - 1: + tensor_indices.append( + ("down_left", parameters.bond_dimension, LOWER_LEFT_OFFSET) + ) + tensor_indices.append( + ("down_right", parameters.bond_dimension, LOWER_RIGHT_OFFSET) + ) + if level_index == depth - 1: + tensor_indices.append( + ("phys", parameters.physical_dimension, DOWN_OFFSET) + ) + tensor = _make_tensor( + f"tensor_l{level_index + 1}_{position_index + 1}", + f"L{level_index + 1}-{position_index + 1}", + position_index * HORIZONTAL_SPACING + + ((depth - level_index - 1) * HORIZONTAL_SPACING) / 2, + level_index * LAYER_SPACING, + tensor_indices, + ) + level_tensors.append(tensor) + levels.append(level_tensors) + edges = [] + for level_index in range(depth - 1): + for position_index, tensor in enumerate(levels[level_index]): + left_child = levels[level_index + 1][position_index] + right_child = levels[level_index + 1][position_index + 1] + edges.append( + _make_edge( + f"edge_l{level_index + 1}_{position_index + 1}_left", + tensor, + "down_left", + left_child, + "up_right", + ) + ) + edges.append( + _make_edge( + f"edge_l{level_index + 1}_{position_index + 1}_right", + tensor, + "down_right", + right_child, + "up_left", + ) + ) + return NetworkSpec( + id=f"template_mera_{depth}", + name=f"MERA depth {depth}", + tensors=[tensor for level in levels for tensor in level], + edges=edges, + ) + + +def _build_ttn_template(parameters: TemplateParameters) -> NetworkSpec: + """Build the canonical TTN layout.""" + depth = _resolve_ttn_depth(parameters) + spec = _build_generic_ttn_template(parameters, depth=depth) + spec.id = f"template_ttn_{depth}" + spec.name = f"TTN depth {depth}" + spec.metadata = { + "template_name": "ttn", + "depth": depth, + "leaf_physical_legs": parameters.leaf_physical_legs, + "root_open_leg": parameters.root_open_leg, + "isometric": parameters.isometric, + } + for tensor in spec.tensors: + tensor.metadata = { + "role": "isometry" if parameters.isometric else "tensor", + "family": "ttn", + "isometric": parameters.isometric, + } + _annotate_physics_1d_indices(tensor, symmetry="none") + return spec + + +def _build_generic_ttn_template( + parameters: TemplateParameters, + *, + depth: int, +) -> NetworkSpec: + """Build a generic TTN with the requested depth.""" + levels: list[list[TensorSpec]] = [] + for level_index in range(depth): + level_tensors: list[TensorSpec] = [] + node_count = 2**level_index + for position_index in range(node_count): + tensor_indices: list[TemplateIndexConfig] = [] + if level_index > 0: + tensor_indices.append(("up", parameters.bond_dimension, UP_OFFSET)) + if level_index < depth - 1: + tensor_indices.append(("left", parameters.bond_dimension, LEFT_OFFSET)) + tensor_indices.append( + ("right", parameters.bond_dimension, RIGHT_OFFSET) + ) + if level_index == 0 and parameters.root_open_leg: + tensor_indices.append(("out", parameters.bond_dimension, UP_OFFSET)) + if level_index == depth - 1 and parameters.leaf_physical_legs: + tensor_indices.append( + ("phys", parameters.physical_dimension, DOWN_OFFSET) + ) + x_position = ( + ((2 * position_index + 1) * (2 ** (depth - level_index - 1)) - 1) + * TREE_LEAF_SPACING + / 2 + ) + tensor = _make_tensor( + f"tensor_l{level_index + 1}_{position_index + 1}", + f"L{level_index + 1}-{position_index + 1}", + x_position, + level_index * LAYER_SPACING, + tensor_indices, + ) + level_tensors.append(tensor) + levels.append(level_tensors) + edges = [] + for level_index in range(depth - 1): + for position_index, tensor in enumerate(levels[level_index]): + left_child = levels[level_index + 1][position_index * 2] + right_child = levels[level_index + 1][position_index * 2 + 1] + edges.append( + _make_edge( + f"edge_l{level_index + 1}_{position_index + 1}_left", + tensor, + "left", + left_child, + "up", + ) + ) + edges.append( + _make_edge( + f"edge_l{level_index + 1}_{position_index + 1}_right", + tensor, + "right", + right_child, + "up", + ) + ) + return NetworkSpec( + id=f"template_ttn_{depth}", + name=f"TTN depth {depth}", + tensors=[tensor for level in levels for tensor in level], + edges=edges, + ) diff --git a/tests/test_template_catalog_internal.py b/tests/test_template_catalog_internal.py index 1967cb3..e1e31b5 100644 --- a/tests/test_template_catalog_internal.py +++ b/tests/test_template_catalog_internal.py @@ -1,5 +1,7 @@ from __future__ import annotations +import importlib + import pytest from tensor_network_editor.internal.models._model_tensor_data import TensorDataMode @@ -8,6 +10,8 @@ build_template, ) from tensor_network_editor.internal.templates._template_catalog import ( + _reset_template_registry_for_tests, + get_template_builder, get_template_definition, list_template_names, serialize_template_definitions, @@ -108,6 +112,66 @@ def test_template_builders_internal_dispatches_to_specific_builder() -> None: assert len(spec.tensors) == 5 +def test_template_builder_facade_reexports_family_modules() -> None: + try: + linear_module = importlib.import_module( + "tensor_network_editor.internal.templates._template_builders_linear" + ) + grid_module = importlib.import_module( + "tensor_network_editor.internal.templates._template_builders_grid" + ) + tree_module = importlib.import_module( + "tensor_network_editor.internal.templates._template_builders_tree" + ) + except ModuleNotFoundError as exc: + pytest.fail(f"Expected split template-builder modules to exist: {exc}") + + _reset_template_registry_for_tests() + + assert _build_linear_chain_template is linear_module._build_linear_chain_template + assert get_template_builder("mps").__module__ == linear_module.__name__ + assert get_template_builder("peps_2x2").__module__ == grid_module.__name__ + assert get_template_builder("mera").__module__ == tree_module.__name__ + + +def test_template_builder_common_module_exposes_shared_primitives() -> None: + try: + common_module = importlib.import_module( + "tensor_network_editor.internal.templates._template_builders_common" + ) + except ModuleNotFoundError as exc: + pytest.fail(f"Expected shared template-builder primitives module: {exc}") + + left_tensor = common_module._make_tensor( + "tensor_left", + "Left", + 10.0, + 20.0, + [("right", 3, (58.0, 0.0))], + ) + right_tensor = common_module._make_tensor( + "tensor_right", + "Right", + 40.0, + 20.0, + [("left", 3, (-58.0, 0.0))], + ) + edge = common_module._make_edge( + "edge_left_right", + left_tensor, + "right", + right_tensor, + "left", + ) + + assert left_tensor.indices[0].id == "tensor_left_right" + assert right_tensor.indices[0].id == "tensor_right_left" + assert edge.left.tensor_id == "tensor_left" + assert edge.left.index_id == "tensor_left_right" + assert edge.right.tensor_id == "tensor_right" + assert edge.right.index_id == "tensor_right_left" + + def test_linear_chain_template_helper_reuses_catalog_metadata() -> None: spec = _build_linear_chain_template( "mpo", From 37ee8470e6a1444fd082e35497a3b348e5b06c60 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 30 Apr 2026 20:15:00 +0200 Subject: [PATCH 10/23] Optimize large static render geometry --- CHANGELOG.md | 3 + src/tensor_network_editor/rendering.py | 369 +++++++++++++++++-------- tests/test_rendering.py | 72 +++++ 3 files changed, 324 insertions(+), 120 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eeef722..33d259e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ All notable changes to this project will be documented in this file. - Internal built-in template builders are now split by family with shared construction primitives, while the existing template catalog and public template APIs keep the same behavior and registration order. +- Large static renders now reuse connected-component geometry and connected + index direction lookups instead of recomputing the same layout heuristics for + every free index, which substantially reduces hot-path SVG/PNG/TikZ latency. ## [0.5.0] - 2026-04-30 diff --git a/src/tensor_network_editor/rendering.py b/src/tensor_network_editor/rendering.py index 8cae178..0b2378c 100644 --- a/src/tensor_network_editor/rendering.py +++ b/src/tensor_network_editor/rendering.py @@ -3,13 +3,15 @@ from __future__ import annotations import logging +from collections import deque from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass +from functools import lru_cache from html import escape from importlib import import_module from io import BytesIO from math import ceil, cos, hypot, isfinite, pi, sin -from typing import Any, TypeVar +from typing import Any, Literal, TypeVar from xml.sax.saxutils import quoteattr from .internal._logging import log_operation, summarize_spec_counts @@ -145,6 +147,22 @@ class _RenderedEdge: stroke: str +_ComponentKind = Literal["linear", "circular", "grid2d", "generic"] + + +@dataclass(slots=True, frozen=True) +class _ComponentGeometryProfile: + """Reusable geometry derived once for one connected tensor component.""" + + tensors: tuple[TensorSpec, ...] + center: CanvasPosition + kind: _ComponentKind + primary_axis: CanvasPosition + grid_basis: tuple[CanvasPosition, CanvasPosition] | None + primary_projections: tuple[float, ...] = () + secondary_projections: tuple[float, ...] = () + + def _render_context( spec: NetworkSpec, *, @@ -372,6 +390,10 @@ def __init__(self, spec: NetworkSpec, options: SvgRenderOptions) -> None: self._spec = spec self._options = options self._tensor_by_id = {tensor.id: tensor for tensor in spec.tensors} + self._tensor_radius_by_id = { + tensor.id: max(24.0, min(tensor.size.width, tensor.size.height) / 2.0) + for tensor in spec.tensors + } self._index_by_id = { index.id: (tensor, index) for tensor in spec.tensors @@ -386,6 +408,8 @@ def __init__(self, spec: NetworkSpec, options: SvgRenderOptions) -> None: self._free_index_direction_by_id: dict[str, CanvasPosition] = {} self._adjacency_by_tensor_id = self._build_tensor_adjacency() self._component_tensor_ids_by_tensor_id = self._build_component_tensor_ids() + self._connected_index_direction_by_id = self._build_connected_index_directions() + self._component_profile_by_tensor_id = self._build_component_profiles() def render(self) -> str: """Return the complete SVG document.""" @@ -730,7 +754,10 @@ def _index_position(tensor: TensorSpec, index: IndexSpec) -> CanvasPosition: ) def tensor_radius(self, tensor: TensorSpec) -> float: - return max(24.0, min(tensor.size.width, tensor.size.height) / 2.0) + return self._tensor_radius_by_id.get( + tensor.id, + max(24.0, min(tensor.size.width, tensor.size.height) / 2.0), + ) def is_port_tensor(self, tensor: TensorSpec) -> bool: return ( @@ -856,68 +883,36 @@ def _occupied_directions(self, tensor: TensorSpec) -> list[CanvasPosition]: return directions def _connected_index_direction( - self, tensor: TensorSpec, index: IndexSpec + self, _tensor: TensorSpec, index: IndexSpec ) -> CanvasPosition | None: - for edge in self._spec.edges: - if ( - edge.left.index_id == index.id - and edge.right.tensor_id in self._tensor_by_id - ): - return _normalize_direction( - CanvasPosition( - x=self._tensor_by_id[edge.right.tensor_id].position.x - - tensor.position.x, - y=self._tensor_by_id[edge.right.tensor_id].position.y - - tensor.position.y, - ) - ) - if ( - edge.right.index_id == index.id - and edge.left.tensor_id in self._tensor_by_id - ): - return _normalize_direction( - CanvasPosition( - x=self._tensor_by_id[edge.left.tensor_id].position.x - - tensor.position.x, - y=self._tensor_by_id[edge.left.tensor_id].position.y - - tensor.position.y, - ) - ) - for hyperedge in self._spec.hyperedges: - endpoint_ids = {endpoint.index_id for endpoint in hyperedge.endpoints} - if index.id not in endpoint_ids: - continue - hub = self._hyperedge_hub_position(hyperedge) - return _normalize_direction( - CanvasPosition(x=hub.x - tensor.position.x, y=hub.y - tensor.position.y) - ) - return None + return self._connected_index_direction_by_id.get(index.id) def _candidate_directions_for_free_index( self, tensor: TensorSpec, index: IndexSpec, - component_tensors: Sequence[TensorSpec], + _component_tensors: Sequence[TensorSpec], ) -> list[CanvasPosition]: directional_hint = self._directional_index_hint(index) candidates: list[CanvasPosition] = [] - component_kind = self._classify_component_shape(component_tensors) + component_profile = self._component_profile_for_tensor(tensor.id) + component_kind = component_profile.kind if component_kind == "linear": - candidates.extend(self._linear_component_candidates(component_tensors)) + candidates.extend(self._linear_component_candidates(component_profile)) elif component_kind == "circular": candidates.extend( - self._circular_component_candidates(tensor, component_tensors) + self._circular_component_candidates(tensor, component_profile) ) elif component_kind == "grid2d": candidates.extend( - self._grid_component_candidates(tensor, component_tensors) + self._grid_component_candidates(tensor, component_profile) ) if directional_hint is not None: if component_kind == "generic": candidates.insert(0, directional_hint) else: candidates.append(directional_hint) - candidates.extend(self._generic_component_candidates(tensor, component_tensors)) + candidates.extend(self._generic_component_candidates(tensor, component_profile)) candidates.extend(_FREE_INDEX_CARDINAL_DIRECTIONS) candidates.extend(_FREE_INDEX_DIAGONAL_DIRECTIONS) return _deduplicate_directions(candidates) @@ -993,9 +988,9 @@ def _directional_index_hint(self, index: IndexSpec) -> CanvasPosition | None: return _FREE_INDEX_DIRECTION_HINTS.get(normalized_name) def _linear_component_candidates( - self, component_tensors: Sequence[TensorSpec] + self, component_profile: _ComponentGeometryProfile ) -> list[CanvasPosition]: - axis_direction = self._component_primary_axis(component_tensors) + axis_direction = component_profile.primary_axis perpendicular_direction = CanvasPosition( x=-axis_direction.y, y=axis_direction.x ) @@ -1007,15 +1002,12 @@ def _linear_component_candidates( ] def _circular_component_candidates( - self, tensor: TensorSpec, component_tensors: Sequence[TensorSpec] + self, tensor: TensorSpec, component_profile: _ComponentGeometryProfile ) -> list[CanvasPosition]: - center = _average_position( - [component_tensor.position for component_tensor in component_tensors] - ) radial_direction = _normalize_direction( CanvasPosition( - x=tensor.position.x - center.x, - y=tensor.position.y - center.y, + x=tensor.position.x - component_profile.center.x, + y=tensor.position.y - component_profile.center.y, ) ) perpendicular_direction = CanvasPosition( @@ -1029,30 +1021,14 @@ def _circular_component_candidates( ] def _grid_component_candidates( - self, tensor: TensorSpec, component_tensors: Sequence[TensorSpec] + self, tensor: TensorSpec, component_profile: _ComponentGeometryProfile ) -> list[CanvasPosition]: - basis_directions = self._grid_component_basis(component_tensors) + basis_directions = component_profile.grid_basis if basis_directions is None: return [] primary_axis, secondary_axis = basis_directions - primary_projections = sorted( - { - round( - _dot_product(component_tensor.position, primary_axis), - 6, - ) - for component_tensor in component_tensors - } - ) - secondary_projections = sorted( - { - round( - _dot_product(component_tensor.position, secondary_axis), - 6, - ) - for component_tensor in component_tensors - } - ) + primary_projections = component_profile.primary_projections + secondary_projections = component_profile.secondary_projections primary_projection = round(_dot_product(tensor.position, primary_axis), 6) secondary_projection = round(_dot_product(tensor.position, secondary_axis), 6) candidates: list[CanvasPosition] = [] @@ -1087,7 +1063,7 @@ def _grid_component_candidates( return candidates def _generic_component_candidates( - self, tensor: TensorSpec, component_tensors: Sequence[TensorSpec] + self, tensor: TensorSpec, component_profile: _ComponentGeometryProfile ) -> list[CanvasPosition]: neighbor_tensor_ids = self._adjacency_by_tensor_id.get(tensor.id, set()) neighbor_tensors = [ @@ -1108,18 +1084,15 @@ def _generic_component_candidates( ) ) ) - component_center = _average_position( - [component_tensor.position for component_tensor in component_tensors] - ) if ( - abs(component_center.x - tensor.position.x) > 1e-6 - or abs(component_center.y - tensor.position.y) > 1e-6 + abs(component_profile.center.x - tensor.position.x) > 1e-6 + or abs(component_profile.center.y - tensor.position.y) > 1e-6 ): candidates.append( _normalize_direction( CanvasPosition( - x=tensor.position.x - component_center.x, - y=tensor.position.y - component_center.y, + x=tensor.position.x - component_profile.center.x, + y=tensor.position.y - component_profile.center.y, ) ) ) @@ -1128,7 +1101,7 @@ def _generic_component_candidates( [*_FREE_INDEX_CARDINAL_DIRECTIONS, *_FREE_INDEX_DIAGONAL_DIRECTIONS], key=lambda direction: self._generic_direction_sort_key( tensor, - component_tensors, + component_profile.tensors, direction, ), ) @@ -1145,21 +1118,90 @@ def _generic_direction_sort_key( return (penalty, -abs(direction.x) - abs(direction.y)) def _classify_component_shape(self, component_tensors: Sequence[TensorSpec]) -> str: + if not component_tensors: + return "generic" + return self._component_profile_for_tensor(component_tensors[0].id).kind + + def _build_component_profile( + self, + component_tensors: Sequence[TensorSpec], + ) -> _ComponentGeometryProfile: + resolved_component_tensors = tuple(component_tensors) + center = _average_position( + [ + component_tensor.position + for component_tensor in resolved_component_tensors + ] + ) tensor_ids = {tensor.id for tensor in component_tensors} - if len(tensor_ids) >= 2 and self._is_linear_component(tensor_ids): - return "linear" - if len(tensor_ids) >= 3 and self._is_circular_component(tensor_ids): - return "circular" - if len(tensor_ids) >= 4 and self._is_grid_component(component_tensors): - return "grid2d" - return "generic" - - def _is_linear_component(self, tensor_ids: set[str]) -> bool: - component_tensors = [ - self._tensor_by_id[tensor_id] - for tensor_id in tensor_ids - if tensor_id in self._tensor_by_id - ] + primary_axis = self._component_primary_axis(resolved_component_tensors) + kind: _ComponentKind = "generic" + grid_basis: tuple[CanvasPosition, CanvasPosition] | None = None + primary_projections: tuple[float, ...] = () + secondary_projections: tuple[float, ...] = () + + if len(tensor_ids) >= 2 and self._is_linear_component( + tensor_ids, + resolved_component_tensors, + axis_direction=primary_axis, + axis_origin=center, + ): + kind = "linear" + elif len(tensor_ids) >= 3 and self._is_circular_component(tensor_ids): + kind = "circular" + elif len(tensor_ids) >= 4: + grid_basis = self._grid_component_basis(resolved_component_tensors) + if grid_basis is not None and self._is_grid_component( + resolved_component_tensors, + tensor_ids=tensor_ids, + basis_directions=grid_basis, + ): + kind = "grid2d" + primary_axis, secondary_axis = grid_basis + primary_projections = tuple( + sorted( + { + round(_dot_product(tensor.position, primary_axis), 6) + for tensor in resolved_component_tensors + } + ) + ) + secondary_projections = tuple( + sorted( + { + round(_dot_product(tensor.position, secondary_axis), 6) + for tensor in resolved_component_tensors + } + ) + ) + + return _ComponentGeometryProfile( + tensors=resolved_component_tensors, + center=center, + kind=kind, + primary_axis=primary_axis, + grid_basis=grid_basis, + primary_projections=primary_projections, + secondary_projections=secondary_projections, + ) + + def _is_linear_component( + self, + tensor_ids: set[str], + component_tensors: Sequence[TensorSpec] | None = None, + *, + axis_direction: CanvasPosition | None = None, + axis_origin: CanvasPosition | None = None, + ) -> bool: + resolved_component_tensors = tuple( + component_tensors + if component_tensors is not None + else [ + self._tensor_by_id[tensor_id] + for tensor_id in tensor_ids + if tensor_id in self._tensor_by_id + ] + ) degree_by_tensor_id = { tensor_id: len(self._adjacency_by_tensor_id.get(tensor_id, set())) for tensor_id in tensor_ids @@ -1168,33 +1210,44 @@ def _is_linear_component(self, tensor_ids: set[str]) -> bool: degree_one_count = sum( 1 for degree in degree_by_tensor_id.values() if degree == 1 ) - axis_direction = self._component_primary_axis(component_tensors) - axis_origin = _average_position( - [component_tensor.position for component_tensor in component_tensors] + resolved_axis_direction = ( + self._component_primary_axis(resolved_component_tensors) + if axis_direction is None + else axis_direction + ) + resolved_axis_origin = ( + _average_position( + [ + component_tensor.position + for component_tensor in resolved_component_tensors + ] + ) + if axis_origin is None + else axis_origin ) projections = [ _dot_product( CanvasPosition( - x=component_tensor.position.x - axis_origin.x, - y=component_tensor.position.y - axis_origin.y, + x=component_tensor.position.x - resolved_axis_origin.x, + y=component_tensor.position.y - resolved_axis_origin.y, ), - axis_direction, + resolved_axis_direction, ) - for component_tensor in component_tensors + for component_tensor in resolved_component_tensors ] dominant_span = max(projections, default=0.0) - min(projections, default=0.0) minor_span = max( ( abs( _cross_product( - axis_direction, + resolved_axis_direction, CanvasPosition( - x=component_tensor.position.x - axis_origin.x, - y=component_tensor.position.y - axis_origin.y, + x=component_tensor.position.x - resolved_axis_origin.x, + y=component_tensor.position.y - resolved_axis_origin.y, ), ) ) - for component_tensor in component_tensors + for component_tensor in resolved_component_tensors ), default=0.0, ) @@ -1215,11 +1268,26 @@ def _is_circular_component(self, tensor_ids: set[str]) -> bool: degree == 2 for degree in degree_by_tensor_id.values() ) - def _is_grid_component(self, component_tensors: Sequence[TensorSpec]) -> bool: - basis_directions = self._grid_component_basis(component_tensors) - if basis_directions is None: + def _is_grid_component( + self, + component_tensors: Sequence[TensorSpec], + *, + tensor_ids: set[str] | None = None, + basis_directions: tuple[CanvasPosition, CanvasPosition] | None = None, + ) -> bool: + resolved_basis_directions = ( + self._grid_component_basis(component_tensors) + if basis_directions is None + else basis_directions + ) + if resolved_basis_directions is None: return False - primary_axis, secondary_axis = basis_directions + primary_axis, secondary_axis = resolved_basis_directions + resolved_tensor_ids = ( + {tensor.id for tensor in component_tensors} + if tensor_ids is None + else tensor_ids + ) primary_projections = { round(_dot_product(tensor.position, primary_axis), 6) for tensor in component_tensors @@ -1230,11 +1298,10 @@ def _is_grid_component(self, component_tensors: Sequence[TensorSpec]) -> bool: } if len(primary_projections) < 2 or len(secondary_projections) < 2: return False - tensor_ids = {tensor.id for tensor in component_tensors} for edge in self._spec.edges: if ( - edge.left.tensor_id not in tensor_ids - or edge.right.tensor_id not in tensor_ids + edge.left.tensor_id not in resolved_tensor_ids + or edge.right.tensor_id not in resolved_tensor_ids ): continue left_tensor = self._tensor_by_id[edge.left.tensor_id] @@ -1310,6 +1377,41 @@ def _grid_component_basis( return None return basis_directions[0], basis_directions[1] + def _build_connected_index_directions(self) -> dict[str, CanvasPosition]: + connected_index_direction_by_id: dict[str, CanvasPosition] = {} + for edge in self._spec.edges: + left_index = self._index_by_id.get(edge.left.index_id) + right_index = self._index_by_id.get(edge.right.index_id) + if left_index is None or right_index is None: + continue + left_tensor, left_index_spec = left_index + right_tensor, right_index_spec = right_index + left_to_right = _normalize_direction( + CanvasPosition( + x=right_tensor.position.x - left_tensor.position.x, + y=right_tensor.position.y - left_tensor.position.y, + ) + ) + connected_index_direction_by_id[left_index_spec.id] = left_to_right + connected_index_direction_by_id[right_index_spec.id] = CanvasPosition( + x=-left_to_right.x, + y=-left_to_right.y, + ) + for hyperedge in self._spec.hyperedges: + hub = self._hyperedge_hub_position(hyperedge) + for endpoint in hyperedge.endpoints: + tensor_index = self._index_by_id.get(endpoint.index_id) + if tensor_index is None: + continue + tensor, index = tensor_index + connected_index_direction_by_id[index.id] = _normalize_direction( + CanvasPosition( + x=hub.x - tensor.position.x, + y=hub.y - tensor.position.y, + ) + ) + return connected_index_direction_by_id + def _build_tensor_adjacency(self) -> dict[str, set[str]]: adjacency_by_tensor_id: dict[str, set[str]] = { tensor.id: set() for tensor in self._spec.tensors @@ -1342,11 +1444,11 @@ def _build_component_tensor_ids(self) -> dict[str, list[str]]: for tensor in self._spec.tensors: if tensor.id in visited_tensor_ids: continue - queue = [tensor.id] + queue: deque[str] = deque([tensor.id]) component_tensor_ids: list[str] = [] visited_tensor_ids.add(tensor.id) while queue: - current_tensor_id = queue.pop(0) + current_tensor_id = queue.popleft() component_tensor_ids.append(current_tensor_id) for neighbor_tensor_id in self._adjacency_by_tensor_id.get( current_tensor_id, set() @@ -1361,15 +1463,41 @@ def _build_component_tensor_ids(self) -> dict[str, list[str]]: ) return component_tensor_ids_by_tensor_id - def _component_tensors_for_tensor(self, tensor_id: str) -> list[TensorSpec]: - component_tensor_ids = self._component_tensor_ids_by_tensor_id.get( - tensor_id, [tensor_id] - ) - return [ + def _build_component_profiles(self) -> dict[str, _ComponentGeometryProfile]: + component_profile_by_tensor_id: dict[str, _ComponentGeometryProfile] = {} + for tensor in self._spec.tensors: + if tensor.id in component_profile_by_tensor_id: + continue + component_tensors = self._component_tensors_from_ids( + self._component_tensor_ids_by_tensor_id.get(tensor.id, [tensor.id]) + ) + component_profile = self._build_component_profile(component_tensors) + for component_tensor in component_tensors: + component_profile_by_tensor_id[component_tensor.id] = component_profile + return component_profile_by_tensor_id + + def _component_tensors_from_ids( + self, + component_tensor_ids: Sequence[str], + ) -> tuple[TensorSpec, ...]: + return tuple( self._tensor_by_id[component_tensor_id] for component_tensor_id in component_tensor_ids if component_tensor_id in self._tensor_by_id - ] + ) + + def _component_profile_for_tensor( + self, + tensor_id: str, + ) -> _ComponentGeometryProfile: + component_profile = self._component_profile_by_tensor_id.get(tensor_id) + if component_profile is not None: + return component_profile + tensor = self._tensor_by_id[tensor_id] + return self._build_component_profile((tensor,)) + + def _component_tensors_for_tensor(self, tensor_id: str) -> tuple[TensorSpec, ...]: + return self._component_profile_for_tensor(tensor_id).tensors class _TikzRenderer: @@ -2694,6 +2822,7 @@ def _horizontal_alignment(svg_anchor: str) -> str: return "center" +@lru_cache(maxsize=1) def _load_matplotlib_modules() -> tuple[Any, Any, Any, Any]: """Import Matplotlib lazily for academic SVG/PNG/PDF exports.""" try: diff --git a/tests/test_rendering.py b/tests/test_rendering.py index 1db1bd7..ab8c16b 100644 --- a/tests/test_rendering.py +++ b/tests/test_rendering.py @@ -826,6 +826,46 @@ def counting_edge_render_infos(self: Any) -> list[Any]: assert edge_render_info_call_count == 1 +def test_render_spec_svg_reuses_component_axis_geometry_within_one_render( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import tensor_network_editor.rendering as rendering_module + + pytest.importorskip("matplotlib") + spec = build_template_spec( + "mps", + TemplateParameters( + graph_size=12, + bond_dimension=3, + physical_dimension=2, + boundary_condition="open", + initial_state="zeros", + ), + ) + component_primary_axis_call_count = 0 + original_component_primary_axis = ( + rendering_module._SvgRenderer._component_primary_axis + ) + + def counting_component_primary_axis( + self: Any, + component_tensors: list[TensorSpec], + ) -> CanvasPosition: + nonlocal component_primary_axis_call_count + component_primary_axis_call_count += 1 + return original_component_primary_axis(self, component_tensors) + + monkeypatch.setattr( + rendering_module._SvgRenderer, + "_component_primary_axis", + counting_component_primary_axis, + ) + + render_spec_svg(spec) + + assert component_primary_axis_call_count == 1 + + def test_render_spec_svg_keeps_labels_as_svg_text_elements() -> None: pytest.importorskip("matplotlib") @@ -1014,6 +1054,38 @@ def reject_matplotlib_modules() -> tuple[object, object, object, object]: rendering_module.render_spec_pdf(build_sample_spec()) +def test_load_matplotlib_modules_memoizes_imports( + monkeypatch: pytest.MonkeyPatch, +) -> None: + import tensor_network_editor.rendering as rendering_module + + if hasattr(rendering_module._load_matplotlib_modules, "cache_clear"): + rendering_module._load_matplotlib_modules.cache_clear() + import_call_counts: dict[str, int] = {} + original_import_module = rendering_module.import_module + + def counting_import_module(name: str) -> Any: + import_call_counts[name] = import_call_counts.get(name, 0) + 1 + return original_import_module(name) + + monkeypatch.setattr( + rendering_module, + "import_module", + counting_import_module, + ) + + first_modules = rendering_module._load_matplotlib_modules() + second_modules = rendering_module._load_matplotlib_modules() + + assert second_modules == first_modules + assert import_call_counts == { + "matplotlib": 1, + "matplotlib.pyplot": 1, + "matplotlib.patches": 1, + "matplotlib.path": 1, + } + + def test_validate_positive_render_scale_normalizes_and_rejects_invalid_values() -> None: import tensor_network_editor.rendering as rendering_module From 03d576566d6531f108844dd6cbb9af7c107116b2 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Fri, 1 May 2026 16:36:13 +0200 Subject: [PATCH 11/23] Add pywebview editor mode and stabilize asset tests --- .gitignore | 1 + CHANGELOG.md | 16 ++ README.md | 27 ++- docs/api.md | 17 +- docs/cli.md | 21 +- docs/extended_guide.md | 15 +- docs/getting-started.md | 13 +- docs/installation.md | 13 +- docs/user-guide.md | 4 +- scripts/clean.bat | 1 + scripts/clean.sh | 1 + src/tensor_network_editor/__init__.py | 4 +- src/tensor_network_editor/app/session.py | 65 ++++- src/tensor_network_editor/cli.py | 12 +- src/tensor_network_editor/editor.py | 67 +++++- .../internal/cli/_cli_handlers.py | 21 +- .../internal/cli/_cli_parser.py | 12 +- tests/app_support.py | 55 ++++- tests/test_api.py | 26 +- tests/test_app_assets.py | 11 +- tests/test_app_support.py | 100 ++++++++ tests/test_cli.py | 39 +++ tests/test_scripts.py | 6 + tests/test_session.py | 226 ++++++++++++++++++ 24 files changed, 724 insertions(+), 49 deletions(-) create mode 100644 tests/test_app_support.py diff --git a/.gitignore b/.gitignore index de9b256..ebc3123 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ __pycache__/ node_modules/ playwright-report/ test-results/ +session.log* *.log .DS_Store Thumbs.db diff --git a/CHANGELOG.md b/CHANGELOG.md index 33d259e..9754f28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,8 +4,22 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Added + +- The editor now supports an explicit UI launch mode across the CLI and Python + API: browser by default, `pywebview` with the optional `desktop` extra, or a + server-only mode that prints the local URL without opening a window. + ### Changed +- `pywebview` editor launches now open their native window maximized by + default, so the desktop mode starts with the same roomy workspace users + usually expect from the browser flow. +- Test cleanup scripts now remove `session.log*` artifacts, and the repository + ignores those rotating session logs explicitly. +- Shared HTTP test helpers now give bundled editor assets more time to load, + which reduces intermittent timeout failures when the local test server is + under load. - Removed a few unused internal helpers from logging, periodic-mode utilities, rendering, and einsum code generation, and deduplicated built-in template defaults so the catalog now keeps each template's default parameters in one @@ -22,6 +36,8 @@ All notable changes to this project will be documented in this file. - Large static renders now reuse connected-component geometry and connected index direction lookups instead of recomputing the same layout heuristics for every free index, which substantially reduces hot-path SVG/PNG/TikZ latency. +- CLI `edit` now exposes `--ui {browser,pywebview,server}` while keeping + `--no-browser` as a compatibility alias for the server-only mode. ## [0.5.0] - 2026-04-30 diff --git a/README.md b/README.md index 2ebb101..537f3d0 100644 --- a/README.md +++ b/README.md @@ -89,10 +89,10 @@ offline use, and generated code you can inspect. ## Why This Project -- Draw tensor-network diagrams in a local browser session. +- Draw tensor-network diagrams in a local browser or `pywebview` desktop session. - Save and reload backend-independent JSON designs. -- Recover the previous local browser session from a project draft if the tab is - closed before you save. +- Recover the previous local editor session from a project draft if the window + or tab is closed before you save. - Generate code for `tensornetwork`, `quimb`, `tensorkrowch`, `einsum_numpy`, and `einsum_torch`. - Render designs to static SVG, TikZ/LaTeX, Graphviz/DOT, or Mermaid from Python, the @@ -135,9 +135,9 @@ offline use, and generated code you can inspect. - Get structural analysis with FLOP and MAC cost summaries. - Use the package from the CLI or directly from Python. -The editor opens in your browser, but the server runs locally on your own -machine. No Node runtime or cloud service is needed for normal use. A future -desktop wrapper such as `pywebview` may sit on top of this local flow, but the +The editor server runs locally on your own machine. By default it opens in your +browser, and you can also ask for a native `pywebview` window with the optional +`desktop` extra. No Node runtime or cloud service is needed for normal use. The browser-served editor remains the core interface and compatibility target. ## Minimal Installation @@ -176,7 +176,20 @@ tensor-network-editor edit ``` This command starts a local server and waits until you press `Done` or -`Cancel` in the browser session. +`Cancel` in the editor session. + +Open the same local editor in a native `pywebview` window: + +```bash +python -m pip install "tensor-network-editor[desktop]" +tensor-network-editor edit --ui pywebview +``` + +Start only the local server and open the printed URL yourself: + +```bash +tensor-network-editor edit --ui server +``` Pick a color theme when you launch the editor: diff --git a/docs/api.md b/docs/api.md index 4d30ca1..3d9fc99 100644 --- a/docs/api.md +++ b/docs/api.md @@ -26,6 +26,7 @@ The package exposes the main functions and models at the top level: from tensor_network_editor import ( EngineName, EditorThemeName, + EditorUiMode, NetworkBuilder, NetworkSpec, PythonLoadOptions, @@ -69,7 +70,7 @@ Useful public modules: | Module | Use it for | | --- | --- | -| `tensor_network_editor.editor` | `EditorLaunchOptions`, `EditorThemeName`, and `open_editor(...)` | +| `tensor_network_editor.editor` | `EditorLaunchOptions`, `EditorThemeName`, `EditorUiMode`, and `open_editor(...)` | | `tensor_network_editor.builder` | fluent `NetworkBuilder`, `TensorHandle`, and `IndexHandle` helpers | | `tensor_network_editor.io` | JSON/Python loading, saving, `serialize_spec(...)`, and `SCHEMA_VERSION` | | `tensor_network_editor.models` | data classes, result models, enums, and periodic-mode types | @@ -88,8 +89,7 @@ user-facing API. ## Launch the Editor -Use `open_editor(...)` when you want a local browser editing session from -Python. +Use `open_editor(...)` when you want a local editing session from Python. ```python from tensor_network_editor import EngineName @@ -101,7 +101,7 @@ def main() -> None: options=EditorLaunchOptions( default_engine=EngineName.EINSUM_NUMPY, theme="light", - open_browser=True, + ui_mode="browser", ), ) @@ -122,7 +122,8 @@ Main parameters: - `options.default_collection_format`: initial tensor collection layout - `options.theme`: initial color theme, one of `dark`, `light`, `contrast`, `colorblind`, or `shiny` -- `options.open_browser`: open the browser automatically +- `options.ui_mode`: choose `browser`, `pywebview`, or `server` +- `options.open_browser`: legacy browser/server compatibility flag - `options.host`: local host address, default `127.0.0.1` - `options.port`: local port, default `0` so the OS chooses one - `options.print_code`: print generated code after confirmation @@ -136,6 +137,12 @@ Main parameters: leave it unset to use the project-local default under `.tensor-network-editor/drafts/` +Install the optional desktop extra before using `options.ui_mode="pywebview"`: + +```bash +python -m pip install "tensor-network-editor[desktop]" +``` + Return value: - `None` when the user cancels diff --git a/docs/cli.md b/docs/cli.md index eae10be..d562108 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -26,7 +26,7 @@ reusable-subnetwork catalogs. ## Launch the Editor -Start the local browser editor: +Start the local editor: ```bash tensor-network-editor edit @@ -42,6 +42,8 @@ Useful options: tensor-network-editor edit --load my_network.json tensor-network-editor edit --engine quimb tensor-network-editor edit --theme light +tensor-network-editor edit --ui pywebview +tensor-network-editor edit --ui server tensor-network-editor edit --save-code generated_network.py tensor-network-editor edit --print-code tensor-network-editor edit --no-browser @@ -53,8 +55,19 @@ You can combine them: tensor-network-editor edit --load my_network.json --engine quimb --save-code generated_network.py ``` -Use `--no-browser` when you want to start the local server but open the printed -URL manually. +By default, `edit` opens the local URL in your browser. + +Use `--ui pywebview` when you want the same local editor inside a native +desktop window and you have installed the optional desktop extra: + +```bash +python -m pip install "tensor-network-editor[desktop]" +tensor-network-editor edit --ui pywebview +``` + +Use `--ui server` when you want to start the local server but open the printed +URL manually. `--no-browser` remains as a compatibility alias for the same +server-only mode. Use `--theme` to choose the editor colors at startup. Available themes are `dark`, `light`, `contrast`, `colorblind`, and `shiny`; `dark` is the default. @@ -185,7 +198,7 @@ from tensor_network_editor.editor import EditorLaunchOptions, open_editor open_editor( options=EditorLaunchOptions( - open_browser=False, + ui_mode="server", log_file_path="tne-editor.log", log_file_max_bytes=10_485_760, log_file_backup_count=5, diff --git a/docs/extended_guide.md b/docs/extended_guide.md index 8d0f6d2..a8e7964 100644 --- a/docs/extended_guide.md +++ b/docs/extended_guide.md @@ -53,9 +53,9 @@ the design and target another backend later. The editor itself runs locally. The package starts a Python HTTP server on your machine, opens a browser tab by default, and waits until you press `Done` or -`Cancel`. Normal use does not require Node.js or a cloud service. A future -desktop wrapper such as `pywebview` can sit on top of the same local server, -but the browser-served editor remains the primary supported surface. +`Cancel`. You can also ask for a native `pywebview` window with the optional +desktop extra. Normal use does not require Node.js or a cloud service, and the +browser-served editor remains the primary supported surface. ## Choosing The Right Tool @@ -178,6 +178,13 @@ Use `--no-browser` when automatic browser opening is blocked, when you work over SSH, or when you want to copy the printed local URL into a browser manually. +Open the same local editor in a native desktop window: + +```bash +python -m pip install "tensor-network-editor[desktop]" +tensor-network-editor edit --ui pywebview +``` + From Python: ```python @@ -190,7 +197,7 @@ def main() -> None: options=EditorLaunchOptions( default_engine=EngineName.EINSUM_NUMPY, theme="light", - open_browser=True, + ui_mode="browser", ) ) if result is None: diff --git a/docs/getting-started.md b/docs/getting-started.md index c5f2f60..ed96613 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -52,13 +52,22 @@ What happens: - the command waits until you press `Done` or `Cancel` - `Done` returns the final design and generated code for the selected engine +If you prefer a native desktop window instead of the browser, install the +optional desktop extra and run: + +```bash +python -m pip install "tensor-network-editor[desktop]" +tensor-network-editor edit --ui pywebview +``` + If your environment cannot open a browser automatically, use: ```bash -tensor-network-editor edit --no-browser +tensor-network-editor edit --ui server ``` -Then open the local URL printed in the terminal. +Then open the local URL printed in the terminal. `--no-browser` still works as +a compatibility alias for the same server-only mode. You can also choose the editor colors when the session starts: diff --git a/docs/installation.md b/docs/installation.md index 68d4913..dc40b28 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -80,8 +80,10 @@ python -m pip install "tensor-network-editor[desktop]" ``` The `desktop` extra installs `pywebview` for environments that want a desktop -webview dependency available. The standard documented workflow is still the -local browser editor. +webview dependency available. After installing it, you can launch the editor in +its own native window with `tensor-network-editor edit --ui pywebview` or +`EditorLaunchOptions(ui_mode="pywebview")`. The standard default workflow is +still the local browser editor. You can combine extras: @@ -227,6 +229,13 @@ tensor-network-editor edit If the browser does not open, see [troubleshooting.md#the-browser-did-not-open-automatically](troubleshooting.md#the-browser-did-not-open-automatically). +If you prefer a native desktop window instead of the browser, install the +desktop extra and launch: + +```bash +tensor-network-editor edit --ui pywebview +``` + ## Cleanup Scripts The repository includes cleanup scripts for generated local artifacts: diff --git a/docs/user-guide.md b/docs/user-guide.md index 003616e..35c8748 100644 --- a/docs/user-guide.md +++ b/docs/user-guide.md @@ -158,11 +158,11 @@ Python: ```python from tensor_network_editor.editor import EditorLaunchOptions, open_editor -open_editor(options=EditorLaunchOptions(theme="contrast")) +open_editor(options=EditorLaunchOptions(theme="contrast", ui_mode="browser")) ``` Available themes are `dark`, `light`, `contrast`, `colorblind`, and `shiny`. -The choice only affects the browser editor appearance; saved network JSON and +The choice only affects the editor appearance; saved network JSON and recoverable drafts keep the same model data. ## Templates diff --git a/scripts/clean.bat b/scripts/clean.bat index 31fc285..7b46220 100644 --- a/scripts/clean.bat +++ b/scripts/clean.bat @@ -22,6 +22,7 @@ call :remove_glob_dirs_warn ".\pytest-cache-files-*" call :remove_glob_files ".\.coverage" call :remove_glob_files ".\.coverage.*" call :remove_glob_files ".\coverage.xml" +call :remove_glob_files ".\session.log*" call :remove_dir "__pycache__" call :remove_named_dirs ".\src" "__pycache__" diff --git a/scripts/clean.sh b/scripts/clean.sh index 7d2fffa..0ce26b6 100644 --- a/scripts/clean.sh +++ b/scripts/clean.sh @@ -98,6 +98,7 @@ remove_glob_dirs_warn "./pytest-cache-files-*" remove_file_pattern "./.coverage" remove_file_pattern "./.coverage.*" remove_file_pattern "./coverage.xml" +remove_file_pattern "./session.log*" remove_dir "__pycache__" remove_named_dirs "./src" "__pycache__" diff --git a/src/tensor_network_editor/__init__.py b/src/tensor_network_editor/__init__.py index a06978e..6ec5d09 100644 --- a/src/tensor_network_editor/__init__.py +++ b/src/tensor_network_editor/__init__.py @@ -17,7 +17,7 @@ from .analysis import analyze_contraction, analyze_spec from .builder import IndexHandle, NetworkBuilder, TensorHandle from .canonicalization import canonicalize_spec - from .editor import EditorLaunchOptions, EditorThemeName, open_editor + from .editor import EditorLaunchOptions, EditorThemeName, EditorUiMode, open_editor from .internal.diffing._diffing import diff_specs, semantic_diff_specs from .io import PythonLoadOptions, load_python_spec, load_spec, save_spec from .linting import lint_spec @@ -76,6 +76,7 @@ "EdgeSpec", "EditorLaunchOptions", "EditorThemeName", + "EditorUiMode", "EditorResult", "EngineName", "DotRenderOptions", @@ -130,6 +131,7 @@ "EdgeSpec": ".models", "EditorLaunchOptions": ".editor", "EditorThemeName": ".editor", + "EditorUiMode": ".editor", "EditorResult": ".models", "EngineName": ".models", "DotRenderOptions": ".rendering", diff --git a/src/tensor_network_editor/app/session.py b/src/tensor_network_editor/app/session.py index 2ffad11..c3828b4 100644 --- a/src/tensor_network_editor/app/session.py +++ b/src/tensor_network_editor/app/session.py @@ -7,6 +7,7 @@ import threading import webbrowser from collections.abc import Callable, Mapping, Sequence +from importlib import import_module from pathlib import Path from types import FrameType from typing import Any, Literal @@ -47,6 +48,7 @@ LOGGER = logging.getLogger(__name__) SignalHandler = Callable[[int, FrameType | None], Any] +SessionUiMode = Literal["browser", "pywebview", "server"] def _print_editor_url(base_url: str) -> None: @@ -64,6 +66,55 @@ def _print_browser_open_fallback_message(base_url: str) -> None: _print_editor_url(base_url) +def _import_pywebview() -> Any: + """Import the optional pywebview module on demand.""" + try: + return import_module("webview") + except ModuleNotFoundError as exc: + raise RuntimeError( + "pywebview mode requires the optional desktop extra. Install it with " + 'python -m pip install "tensor-network-editor[desktop]".' + ) from exc + + +def _run_pywebview_session( + session: EditorSession, base_url: str +) -> EditorResult | None: + """Open the local editor in a pywebview window and wait for the result.""" + if threading.current_thread() is not threading.main_thread(): + raise RuntimeError("pywebview mode must be launched from the main thread.") + + try: + pywebview = _import_pywebview() + except ModuleNotFoundError as exc: + raise RuntimeError( + "pywebview mode requires the optional desktop extra. Install it with " + 'python -m pip install "tensor-network-editor[desktop]".' + ) from exc + pywebview_window = pywebview.create_window( + "Tensor Network Editor", + base_url, + maximized=True, + ) + + def _handle_window_closed(*_args: object) -> None: + """Cancel the editor session when the native window is closed.""" + if not session.is_finished(): + session.cancel() + + def _wait_for_session_and_close_window(window: Any) -> None: + """Close the native window after the editor session finishes.""" + wait_for_editor_result(session) + try: + window.destroy() + except Exception: + return None + + pywebview_window.events.closed += _handle_window_closed + pywebview.start(_wait_for_session_and_close_window, pywebview_window) + return wait_for_editor_result(session) + + class EditorSession: """Mutable session state shared between the HTTP server and the caller.""" @@ -394,6 +445,7 @@ def launch_editor_session( default_engine: EngineIdentifier = EngineName.TENSORKROWCH, default_collection_format: TensorCollectionFormat = TensorCollectionFormat.LIST, theme: EditorThemeName = DEFAULT_EDITOR_THEME, + ui_mode: SessionUiMode | None = None, open_browser: bool = True, host: str = "127.0.0.1", port: int = 0, @@ -416,6 +468,7 @@ def launch_editor_session( default_collection_format: Initial tensor collection layout for generated code. theme: Visual theme selected for this editor session. + ui_mode: Explicit UI launch mode for the editor session. open_browser: Whether to ask the system browser to open the local URL. host: Local host interface to bind. port: Local port to bind. Use ``0`` for an ephemeral port. @@ -441,6 +494,7 @@ def launch_editor_session( Raises: KeyboardInterrupt: If the session is interrupted from the main thread. """ + from ..editor import resolve_editor_ui_mode from .server import EditorServer active_logging_runtime = get_active_logging_runtime() @@ -472,6 +526,10 @@ def launch_editor_session( draft_path=draft_path, ) server = EditorServer(session=session, host=host, port=port) + effective_ui_mode = resolve_editor_ui_mode( + ui_mode=ui_mode, + open_browser=open_browser, + ) previous_sigint_handler: SignalHandler | int | None = None server_started = False @@ -485,6 +543,7 @@ def launch_editor_session( "session": session.session_id, "engine": engine_name_to_text(default_engine), "mode": theme, + "ui_mode": effective_ui_mode, }, ): if threading.current_thread() is threading.main_thread(): @@ -501,9 +560,11 @@ def _handle_sigint(_signum: int, _frame: FrameType | None) -> None: server_started = True if _on_server_ready is not None: _on_server_ready(server.base_url) - should_print_editor_url = not open_browser + if effective_ui_mode == "pywebview": + return _run_pywebview_session(session, server.base_url) + should_print_editor_url = effective_ui_mode == "server" should_print_browser_fallback_message = False - if open_browser: + if effective_ui_mode == "browser": try: with log_operation( LOGGER, diff --git a/src/tensor_network_editor/cli.py b/src/tensor_network_editor/cli.py index fabc17b..bb9855b 100644 --- a/src/tensor_network_editor/cli.py +++ b/src/tensor_network_editor/cli.py @@ -120,9 +120,14 @@ def main(argv: Sequence[str] | None = None) -> int: """Run the CLI and return a process-friendly exit code.""" args_list = list(argv) if argv is not None else sys.argv[1:] try: - parsed_args = cast( - _CommandNamespace, build_command_parser().parse_args(args_list) - ) + parser = build_command_parser() + parsed_args = cast(_CommandNamespace, parser.parse_args(args_list)) + if ( + getattr(parsed_args, "command", None) == "edit" + and getattr(parsed_args, "ui", None) is not None + and getattr(parsed_args, "no_browser", False) + ): + parser.error("cannot combine --ui with --no-browser; use only one.") command_context = _build_cli_command_context(parsed_args) with configure_package_logging( parsed_args.log_level, @@ -188,6 +193,7 @@ def _build_cli_debug_context(args: argparse.Namespace) -> dict[str, object]: debug_context: dict[str, object] = {} for field_name in ( "engine", + "ui", "format", "dtype", "python_import_mode", diff --git a/src/tensor_network_editor/editor.py b/src/tensor_network_editor/editor.py index f8db596..381c716 100644 --- a/src/tensor_network_editor/editor.py +++ b/src/tensor_network_editor/editor.py @@ -29,6 +29,12 @@ from .types import StrPath LOGGER = logging.getLogger(__name__) +EditorUiMode = Literal["browser", "pywebview", "server"] +_SUPPORTED_EDITOR_UI_MODES: tuple[EditorUiMode, ...] = ( + "browser", + "pywebview", + "server", +) @dataclass(slots=True, frozen=True) @@ -38,6 +44,7 @@ class EditorLaunchOptions: default_engine: EngineIdentifier = EngineName.TENSORKROWCH default_collection_format: TensorCollectionFormat = TensorCollectionFormat.LIST theme: EditorThemeName = DEFAULT_EDITOR_THEME + ui_mode: EditorUiMode | None = None open_browser: bool = True host: str = "127.0.0.1" port: int = 0 @@ -55,6 +62,12 @@ class EditorLaunchOptions: def __post_init__(self) -> None: """Normalize and validate theme names passed at runtime.""" object.__setattr__(self, "theme", normalize_editor_theme(self.theme)) + validated_ui_mode = _normalize_editor_ui_mode(self.ui_mode) + object.__setattr__(self, "ui_mode", validated_ui_mode) + _validate_editor_ui_mode_compatibility( + ui_mode=validated_ui_mode, + open_browser=self.open_browser, + ) validate_positive_log_setting( self.log_file_max_bytes, name="log_file_max_bytes", @@ -76,10 +89,15 @@ def open_editor( resolved_theme = ( normalize_editor_theme(theme) if theme is not None else resolved_options.theme ) + effective_ui_mode = resolve_editor_ui_mode( + ui_mode=resolved_options.ui_mode, + open_browser=resolved_options.open_browser, + ) active_logging_runtime = get_active_logging_runtime() context: dict[str, object] = { "engine": resolved_options.default_engine, "mode": resolved_theme, + "ui_mode": effective_ui_mode, } if spec is not None: spec_context = summarize_spec_counts(spec) @@ -111,6 +129,7 @@ def open_editor( default_engine=resolved_options.default_engine, default_collection_format=resolved_options.default_collection_format, theme=resolved_theme, + ui_mode=effective_ui_mode, open_browser=resolved_options.open_browser, host=resolved_options.host, port=resolved_options.port, @@ -149,6 +168,46 @@ def _null_logging_scope() -> _NullLoggingScope: return _NullLoggingScope() +def _normalize_editor_ui_mode(ui_mode: EditorUiMode | None) -> EditorUiMode | None: + """Validate one optional editor UI mode.""" + if ui_mode is None: + return None + if ui_mode not in _SUPPORTED_EDITOR_UI_MODES: + supported_modes = ", ".join(_SUPPORTED_EDITOR_UI_MODES) + raise ValueError( + f"Unsupported editor ui_mode {ui_mode!r}. Expected one of: {supported_modes}." + ) + return ui_mode + + +def _validate_editor_ui_mode_compatibility( + *, + ui_mode: EditorUiMode | None, + open_browser: bool, +) -> None: + """Reject combinations that are incompatible with the legacy browser flag.""" + if ui_mode == "browser" and not open_browser: + raise ValueError("ui_mode='browser' requires open_browser=True.") + if ui_mode == "server" and open_browser: + raise ValueError("ui_mode='server' requires open_browser=False.") + + +def resolve_editor_ui_mode( + *, + ui_mode: EditorUiMode | None, + open_browser: bool, +) -> EditorUiMode: + """Resolve the effective UI mode for one editor launch.""" + normalized_ui_mode = _normalize_editor_ui_mode(ui_mode) + _validate_editor_ui_mode_compatibility( + ui_mode=normalized_ui_mode, + open_browser=open_browser, + ) + if normalized_ui_mode is not None: + return normalized_ui_mode + return "browser" if open_browser else "server" + + def _should_open_editor_logging_scope( log_file_path: StrPath | None, active_logging_runtime: object, @@ -164,4 +223,10 @@ def _should_open_editor_logging_scope( return Path(log_file_path).resolve() != Path(runtime_log_file_path).resolve() -__all__ = ["EditorLaunchOptions", "EditorThemeName", "open_editor"] +__all__ = [ + "EditorLaunchOptions", + "EditorThemeName", + "EditorUiMode", + "open_editor", + "resolve_editor_ui_mode", +] diff --git a/src/tensor_network_editor/internal/cli/_cli_handlers.py b/src/tensor_network_editor/internal/cli/_cli_handlers.py index 7e97635..44ddd51 100644 --- a/src/tensor_network_editor/internal/cli/_cli_handlers.py +++ b/src/tensor_network_editor/internal/cli/_cli_handlers.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Protocol, cast -from ...editor import EditorLaunchOptions +from ...editor import EditorLaunchOptions, EditorUiMode from ...errors import SerializationError from ...io import PythonLoadOptions from ...models import EngineName, NetworkSpec, TensorCollectionFormat, ValidationIssue @@ -46,6 +46,11 @@ from ._cli_doctor import build_doctor_report, format_doctor_report_text LOGGER = logging.getLogger(__name__) +_EDITOR_UI_MODE_TO_OPEN_BROWSER: dict[str, bool] = { + "browser": True, + "pywebview": False, + "server": False, +} def handle_edit_command( @@ -55,6 +60,14 @@ def handle_edit_command( open_editor: Callable[..., object], ) -> int: """Launch the browser editor using explicit edit arguments.""" + ui_mode = cast(EditorUiMode | None, getattr(args, "ui", None)) + if ui_mode is not None and args.no_browser: + raise ValueError("cannot combine --ui with --no-browser; use only one.") + effective_open_browser = ( + not args.no_browser + if ui_mode is None + else _EDITOR_UI_MODE_TO_OPEN_BROWSER[ui_mode] + ) loaded_spec_path = Path(args.load).resolve() if args.load else None load_kwargs = _python_load_kwargs(args) initial_spec = load_spec(args.load, **load_kwargs) if args.load else None @@ -68,7 +81,8 @@ def handle_edit_command( "options": EditorLaunchOptions( default_engine=EngineName(args.engine), theme=args.theme, - open_browser=not args.no_browser, + ui_mode=ui_mode, + open_browser=effective_open_browser, print_code=args.print_code, code_path=code_path, log_file_path=args.log_file, @@ -80,7 +94,8 @@ def handle_edit_command( open_editor_kwargs["options"] = EditorLaunchOptions( default_engine=EngineName(args.engine), theme=args.theme, - open_browser=not args.no_browser, + ui_mode=ui_mode, + open_browser=effective_open_browser, print_code=args.print_code, code_path=code_path, log_file_path=args.log_file, diff --git a/src/tensor_network_editor/internal/cli/_cli_parser.py b/src/tensor_network_editor/internal/cli/_cli_parser.py index 8c70141..18af265 100644 --- a/src/tensor_network_editor/internal/cli/_cli_parser.py +++ b/src/tensor_network_editor/internal/cli/_cli_parser.py @@ -90,7 +90,8 @@ def build_command_parser(handlers: CliHandlerBindings) -> argparse.ArgumentParse subparsers = parser.add_subparsers(dest="command", required=True) edit_parser = subparsers.add_parser( - "edit", help="Launch the local editor in the browser." + "edit", + help="Launch the local editor in a browser, desktop window, or server-only mode.", ) _add_edit_arguments(edit_parser) edit_parser.set_defaults(handler=handlers.handle_edit) @@ -393,7 +394,12 @@ def _add_edit_arguments(parser: argparse.ArgumentParser) -> None: "--theme", choices=list(SUPPORTED_EDITOR_THEMES), default=DEFAULT_EDITOR_THEME, - help="Visual theme used by the browser editor.", + help="Visual theme used by the editor UI.", + ) + parser.add_argument( + "--ui", + choices=["browser", "pywebview", "server"], + help="Choose whether to open the editor in the browser, a pywebview window, or server-only mode.", ) parser.add_argument( "--load", @@ -413,7 +419,7 @@ def _add_edit_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--no-browser", action="store_true", - help="Start the local server without opening the browser automatically.", + help="Legacy alias for --ui server: start the local server without opening a UI automatically.", ) diff --git a/tests/app_support.py b/tests/app_support.py index 4e47739..a43b480 100644 --- a/tests/app_support.py +++ b/tests/app_support.py @@ -1,10 +1,15 @@ from __future__ import annotations import json +import time from typing import Any, cast from urllib.error import HTTPError from urllib.request import Request, urlopen +_ASSET_REQUEST_TIMEOUT_SECONDS = 15.0 +_ASSET_REQUEST_RETRY_COUNT = 3 +_ASSET_REQUEST_RETRY_DELAY_SECONDS = 0.1 + def request_json( url: str, @@ -48,13 +53,51 @@ def request_json_with_status( return exc.code, json.loads(exc.read().decode("utf-8")) +def _read_asset_response(url: str) -> tuple[bytes, dict[str, str]]: + """Read one asset request with retries for transient local-server hiccups.""" + last_error: OSError | None = None + for attempt_index in range(_ASSET_REQUEST_RETRY_COUNT): + try: + with urlopen(url, timeout=_ASSET_REQUEST_TIMEOUT_SECONDS) as response: + body = response.read() + headers = {key: value for key, value in response.headers.items()} + return body, headers + except OSError as exc: + last_error = exc + if attempt_index + 1 >= _ASSET_REQUEST_RETRY_COUNT: + raise + time.sleep(_ASSET_REQUEST_RETRY_DELAY_SECONDS) + if last_error is not None: + raise last_error + raise RuntimeError("Asset request retry loop ended unexpectedly.") + + def request_text(url: str) -> str: - with urlopen(url, timeout=5) as response: - return cast(str, response.read().decode("utf-8")) + body, _headers = _read_asset_response(url) + return cast(str, body.decode("utf-8")) def request_with_headers(url: str) -> tuple[str, dict[str, str]]: - with urlopen(url, timeout=5) as response: - body = response.read().decode("utf-8") - headers = {key: value for key, value in response.headers.items()} - return body, headers + body, headers = _read_asset_response(url) + return body.decode("utf-8"), headers + + +def request_headers(url: str) -> dict[str, str]: + last_error: OSError | None = None + for attempt_index in range(_ASSET_REQUEST_RETRY_COUNT): + try: + with urlopen(url, timeout=_ASSET_REQUEST_TIMEOUT_SECONDS) as response: + return {key: value for key, value in response.headers.items()} + except OSError as exc: + last_error = exc + if attempt_index + 1 >= _ASSET_REQUEST_RETRY_COUNT: + raise + time.sleep(_ASSET_REQUEST_RETRY_DELAY_SECONDS) + if last_error is not None: + raise last_error + raise RuntimeError("Asset header request retry loop ended unexpectedly.") + + +def request_bytes(url: str) -> bytes: + body, _headers = _read_asset_response(url) + return body diff --git a/tests/test_api.py b/tests/test_api.py index 2c8d150..2532c03 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -10,7 +10,7 @@ import tensor_network_editor from tensor_network_editor import generate_code as _generate_code -from tensor_network_editor.editor import EditorLaunchOptions, open_editor +from tensor_network_editor.editor import EditorLaunchOptions, EditorUiMode, open_editor from tensor_network_editor.errors import ( CodeGenerationError, PackageIOError, @@ -152,6 +152,7 @@ def test_package_root_exports_supported_public_api() -> None: "EdgeSpec", "EditorLaunchOptions", "EditorThemeName", + "EditorUiMode", "EditorResult", "EngineName", "DotRenderOptions", @@ -233,6 +234,7 @@ def test_editor_launch_options_defaults_match_public_contract() -> None: assert options.default_engine is EngineName.TENSORKROWCH assert options.default_collection_format is TensorCollectionFormat.LIST assert options.theme == "dark" + assert options.ui_mode is None assert options.open_browser is True assert options.host == "127.0.0.1" assert options.port == 0 @@ -248,6 +250,26 @@ def test_editor_launch_options_rejects_unknown_theme() -> None: EditorLaunchOptions(theme="sepia") # type: ignore[arg-type] +def test_editor_ui_mode_type_alias_matches_public_contract() -> None: + assert EditorUiMode == Literal["browser", "pywebview", "server"] + + +@pytest.mark.parametrize( + ("ui_mode", "open_browser", "expected_message"), + [ + ("browser", False, "ui_mode='browser' requires open_browser=True"), + ("server", True, "ui_mode='server' requires open_browser=False"), + ], +) +def test_editor_launch_options_rejects_conflicting_browser_flags( + ui_mode: EditorUiMode, + open_browser: bool, + expected_message: str, +) -> None: + with pytest.raises(ValueError, match=expected_message): + EditorLaunchOptions(ui_mode=ui_mode, open_browser=open_browser) + + def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> None: launch_result = object() @@ -261,6 +283,7 @@ def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> N default_engine=EngineName.EINSUM_NUMPY, default_collection_format=TensorCollectionFormat.DICT, theme="colorblind", + ui_mode="pywebview", open_browser=False, host="0.0.0.0", port=8123, @@ -281,6 +304,7 @@ def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> N default_engine=EngineName.EINSUM_NUMPY, default_collection_format=TensorCollectionFormat.DICT, theme="colorblind", + ui_mode="pywebview", open_browser=False, host="0.0.0.0", port=8123, diff --git a/tests/test_app_assets.py b/tests/test_app_assets.py index cafc22b..efcdca7 100644 --- a/tests/test_app_assets.py +++ b/tests/test_app_assets.py @@ -8,7 +8,12 @@ import pytest from tensor_network_editor.app.server import EditorServer -from tests.app_support import request_text, request_with_headers +from tests.app_support import ( + request_bytes, + request_headers, + request_text, + request_with_headers, +) def request_runtime_bundle(editor_server: EditorServer, *relative_paths: str) -> str: @@ -548,8 +553,8 @@ def test_vendor_asset_is_served_locally(editor_server: EditorServer) -> None: def test_favicon_asset_is_served_locally(editor_server: EditorServer) -> None: + body = request_bytes(f"{editor_server.base_url}/favicon.ico") with urlopen(f"{editor_server.base_url}/favicon.ico", timeout=5) as response: - body = response.read() headers = dict(response.info().items()) assert body @@ -3527,7 +3532,7 @@ def test_static_assets_disable_browser_cache( editor_server: EditorServer, path: str, ) -> None: - _, headers = request_with_headers(f"{editor_server.base_url}{path}") + headers = request_headers(f"{editor_server.base_url}{path}") assert "no-store" in headers["Cache-Control"] assert headers["Pragma"] == "no-cache" diff --git a/tests/test_app_support.py b/tests/test_app_support.py new file mode 100644 index 0000000..fd8881b --- /dev/null +++ b/tests/test_app_support.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from unittest.mock import patch + +from tests import app_support + + +class _FakeResponse: + def __init__(self, body: str) -> None: + self._body = body.encode("utf-8") + self.status = 200 + self.headers = {"Cache-Control": "no-store"} + + def __enter__(self) -> _FakeResponse: + return self + + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: + del exc_type, exc, traceback + return None + + def read(self) -> bytes: + return self._body + + +def test_request_text_uses_shared_asset_timeout() -> None: + recorded_timeout: list[float] = [] + + def fake_urlopen(url: str, timeout: float) -> _FakeResponse: + recorded_timeout.append(timeout) + assert url == "http://example.test/" + return _FakeResponse("body") + + with patch("tests.app_support.urlopen", side_effect=fake_urlopen): + body = app_support.request_text("http://example.test/") + + assert body == "body" + assert recorded_timeout == [app_support._ASSET_REQUEST_TIMEOUT_SECONDS] + + +def test_request_with_headers_uses_shared_asset_timeout() -> None: + recorded_timeout: list[float] = [] + + def fake_urlopen(url: str, timeout: float) -> _FakeResponse: + recorded_timeout.append(timeout) + assert url == "http://example.test/app.css" + return _FakeResponse("css") + + with patch("tests.app_support.urlopen", side_effect=fake_urlopen): + body, headers = app_support.request_with_headers("http://example.test/app.css") + + assert body == "css" + assert headers == {"Cache-Control": "no-store"} + assert recorded_timeout == [app_support._ASSET_REQUEST_TIMEOUT_SECONDS] + + +def test_request_headers_uses_shared_asset_timeout_without_reading_body() -> None: + recorded_timeout: list[float] = [] + response = _FakeResponse("body") + + def fake_urlopen(url: str, timeout: float) -> _FakeResponse: + recorded_timeout.append(timeout) + assert url == "http://example.test/vendor.js" + return response + + with patch("tests.app_support.urlopen", side_effect=fake_urlopen): + headers = app_support.request_headers("http://example.test/vendor.js") + + assert headers == {"Cache-Control": "no-store"} + assert recorded_timeout == [app_support._ASSET_REQUEST_TIMEOUT_SECONDS] + + +def test_read_asset_response_retries_transient_os_errors() -> None: + attempts = 0 + + def fake_urlopen(url: str, timeout: float) -> _FakeResponse: + nonlocal attempts + attempts += 1 + assert url == "http://example.test/retry.js" + assert timeout == app_support._ASSET_REQUEST_TIMEOUT_SECONDS + if attempts < 3: + raise TimeoutError("temporary timeout") + return _FakeResponse("ok") + + with patch("tests.app_support.urlopen", side_effect=fake_urlopen): + body, headers = app_support._read_asset_response("http://example.test/retry.js") + + assert body == b"ok" + assert headers == {"Cache-Control": "no-store"} + assert attempts == 3 + + +def test_request_bytes_uses_shared_asset_fetcher() -> None: + with patch( + "tests.app_support._read_asset_response", + return_value=(b"icon", {"Content-Type": "image/x-icon"}), + ) as read_asset_response_mock: + body = app_support.request_bytes("http://example.test/favicon.ico") + + assert body == b"icon" + read_asset_response_mock.assert_called_once_with("http://example.test/favicon.ico") diff --git a/tests/test_cli.py b/tests/test_cli.py index 6b6276f..f90aa48 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -253,6 +253,45 @@ def test_edit_subcommand_passes_explicit_log_file_path() -> None: ) +def test_edit_subcommand_accepts_explicit_browser_ui_mode() -> None: + with patch("tensor_network_editor.cli.open_editor") as open_editor_mock: + exit_code = main(["edit", "--ui", "pywebview"]) + + assert exit_code == 0 + open_editor_mock.assert_called_once_with( + spec=None, + options=EditorLaunchOptions( + ui_mode="pywebview", + open_browser=False, + ), + ) + + +def test_edit_subcommand_ui_server_matches_no_browser_alias() -> None: + with patch("tensor_network_editor.cli.open_editor") as open_editor_mock: + exit_code = main(["edit", "--ui", "server"]) + + assert exit_code == 0 + open_editor_mock.assert_called_once_with( + spec=None, + options=EditorLaunchOptions( + ui_mode="server", + open_browser=False, + ), + ) + + +def test_edit_subcommand_rejects_ui_and_no_browser_combination( + capsys: pytest.CaptureFixture[str], +) -> None: + with patch("tensor_network_editor.cli.open_editor") as open_editor_mock: + exit_code = main(["edit", "--ui", "browser", "--no-browser"]) + + assert exit_code == 2 + open_editor_mock.assert_not_called() + assert "cannot combine --ui with --no-browser" in capsys.readouterr().err + + def test_edit_subcommand_passes_explicit_log_rotation_settings() -> None: with patch("tensor_network_editor.cli.open_editor") as open_editor_mock: exit_code = main( diff --git a/tests/test_scripts.py b/tests/test_scripts.py index 703dc40..3848faa 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -68,6 +68,9 @@ def seed_generated_artifacts(root: Path) -> None: root / ".coverage", root / ".coverage.unit", root / "coverage.xml", + root / "session.log", + root / "session.log.1", + root / "session.log.7", ] for file_path in files_to_create: file_path.write_text("temporary", encoding="utf-8") @@ -91,6 +94,9 @@ def assert_cleanup_removed_artifacts(root: Path) -> None: root / ".coverage", root / ".coverage.unit", root / "coverage.xml", + root / "session.log", + root / "session.log.1", + root / "session.log.7", ] for path in removed_paths: assert not path.exists() diff --git a/tests/test_session.py b/tests/test_session.py index 0a5ed5c..4ab6be9 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -628,6 +628,232 @@ class FakeThread: assert "http://127.0.0.1:43210" in captured +def test_launch_editor_session_pywebview_requires_main_thread( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + class FakeThread: + name = "worker" + + monkeypatch.setattr( + session_module.threading, "current_thread", lambda: FakeThread() + ) + + with pytest.raises(RuntimeError, match="pywebview mode must be launched"): + session_module.launch_editor_session(ui_mode="pywebview") + + +def test_launch_editor_session_pywebview_missing_dependency_raises_clear_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + main_thread = FakeMainThread() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr( + session_module, + "_import_pywebview", + lambda: (_ for _ in ()).throw(ModuleNotFoundError("No module named 'webview'")), + ) + + with pytest.raises(RuntimeError, match="tensor-network-editor\\[desktop\\]"): + session_module.launch_editor_session(ui_mode="pywebview") + + +def test_launch_editor_session_pywebview_closes_window_after_completion( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + completed_result = EditorResult( + spec=build_blank_network_spec(), + engine=EngineName.EINSUM_NUMPY, + confirmed=True, + ) + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.closed = FakeEventHook() + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + self.destroy_calls = 0 + + def destroy(self) -> None: + self.destroy_calls += 1 + + class FakePywebview: + def __init__(self) -> None: + self.created_urls: list[str] = [] + self.created_maximized: list[bool] = [] + self.window = FakeWindow() + self.start_calls = 0 + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + ) -> FakeWindow: + assert title == "Tensor Network Editor" + self.created_urls.append(url) + self.created_maximized.append(maximized) + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + self.start_calls += 1 + cast(Any, callback)(window) + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + main_thread = FakeMainThread() + pywebview = FakePywebview() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr(session_module, "_import_pywebview", lambda: pywebview) + monkeypatch.setattr( + session_module, + "wait_for_editor_result", + lambda _session: completed_result, + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is completed_result + assert pywebview.created_urls == ["http://127.0.0.1:43210"] + assert pywebview.created_maximized == [True] + assert pywebview.start_calls == 1 + assert pywebview.window.destroy_calls == 1 + + +def test_launch_editor_session_pywebview_window_close_cancels_session( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.closed = FakeEventHook() + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + + def destroy(self) -> None: + return None + + class FakePywebview: + def __init__(self) -> None: + self.window = FakeWindow() + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + ) -> FakeWindow: + del title, url, maximized + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + del callback, window + self.window.events.closed.fire() + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + main_thread = FakeMainThread() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr( + session_module, + "_import_pywebview", + lambda: FakePywebview(), + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is None + + def test_open_editor_passes_template_catalog_path( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, From ace5cccb5ef8c5d6d7161a25df416e7d31ff791a Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Fri, 1 May 2026 17:35:48 +0200 Subject: [PATCH 12/23] Add native pywebview save dialogs for exports --- CHANGELOG.md | 4 + src/tensor_network_editor/app/session.py | 81 ++++++++++ .../static/js/session/sessionEditorFlows.js | 61 +++++++- .../static/js/session/sessionTemplateFlows.js | 32 +++- .../static/js/session/sessionUiAdapters.js | 49 +++++- tests/test_frontend_architecture.py | 99 +++++++++++- tests/test_frontend_runtime.py | 147 +++++++++++++++++- tests/test_session.py | 124 ++++++++++++++- 8 files changed, 575 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9754f28..efccb28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ All notable changes to this project will be documented in this file. - `pywebview` editor launches now open their native window maximized by default, so the desktop mode starts with the same roomy workspace users usually expect from the browser flow. +- `pywebview` exports now open a native `Save As` dialog and write the selected + file from Python, so desktop-mode JSON, Python, and academic exports no + longer disappear into the embedded browser backend's implicit download + folder. - Test cleanup scripts now remove `session.log*` artifacts, and the repository ignores those rotating session logs explicitly. - Shared HTTP test helpers now give bundled editor assets more time to load, diff --git a/src/tensor_network_editor/app/session.py b/src/tensor_network_editor/app/session.py index c3828b4..f6ae6ec 100644 --- a/src/tensor_network_editor/app/session.py +++ b/src/tensor_network_editor/app/session.py @@ -6,6 +6,7 @@ import signal import threading import webbrowser +from base64 import b64decode from collections.abc import Callable, Mapping, Sequence from importlib import import_module from pathlib import Path @@ -77,6 +78,83 @@ def _import_pywebview() -> Any: ) from exc +class _PywebviewExportApi: + """Expose native save-file helpers to the embedded pywebview frontend.""" + + def __init__(self, pywebview_module: Any) -> None: + """Store the imported pywebview module for later dialog calls.""" + self._pywebview_module = pywebview_module + self._window: Any | None = None + + def bind_window(self, window: Any) -> None: + """Attach the created pywebview window once it exists.""" + self._window = window + + def save_text_file( + self, + filename: str, + text: str, + content_type: str = "text/plain;charset=utf-8", + ) -> bool: + """Prompt for a path and write one UTF-8 text file.""" + del content_type + output_path = self._select_output_path(filename) + if output_path is None: + return False + output_path.write_text(text, encoding="utf-8") + return True + + def save_binary_file( + self, + filename: str, + base64_payload: str, + content_type: str = "application/octet-stream", + ) -> bool: + """Prompt for a path and write one decoded binary export file.""" + del content_type + output_path = self._select_output_path(filename) + if output_path is None: + return False + output_path.write_bytes(b64decode(base64_payload)) + return True + + def _select_output_path(self, filename: str) -> Path | None: + """Ask pywebview for a target save path and normalize the response.""" + if self._window is None: + raise RuntimeError("pywebview export API is not bound to a window.") + dialog_result = self._window.create_file_dialog( + self._pywebview_module.SAVE_DIALOG, + save_filename=filename, + file_types=self._build_file_types(filename), + ) + if dialog_result is None: + return None + if isinstance(dialog_result, str): + return Path(dialog_result) + if isinstance(dialog_result, Sequence) and dialog_result: + first_entry = dialog_result[0] + if isinstance(first_entry, str) and first_entry: + return Path(first_entry) + return None + + def _build_file_types(self, filename: str) -> tuple[str, ...]: + """Build a compact pywebview filter tuple from one filename.""" + suffix = Path(filename).suffix.lower() + if not suffix: + return () + label = { + ".dot": "DOT", + ".json": "JSON", + ".mmd": "Mermaid", + ".pdf": "PDF", + ".png": "PNG", + ".py": "Python", + ".svg": "SVG", + ".tex": "LaTeX", + }.get(suffix, suffix.removeprefix(".").upper()) + return (f"{label} (*{suffix})",) + + def _run_pywebview_session( session: EditorSession, base_url: str ) -> EditorResult | None: @@ -91,11 +169,14 @@ def _run_pywebview_session( "pywebview mode requires the optional desktop extra. Install it with " 'python -m pip install "tensor-network-editor[desktop]".' ) from exc + pywebview_export_api = _PywebviewExportApi(pywebview) pywebview_window = pywebview.create_window( "Tensor Network Editor", base_url, maximized=True, + js_api=pywebview_export_api, ) + pywebview_export_api.bind_window(pywebview_window) def _handle_window_closed(*_args: object) -> None: """Cancel the editor session when the native window is closed.""" diff --git a/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js b/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js index 1b790f7..9a3f75a 100644 --- a/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js +++ b/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js @@ -57,6 +57,20 @@ export function createSessionEditorFlows({ actions.setStatus(safeMessage, "error"); } + function wasFileSaved(saveResult) { + return saveResult !== false; + } + + async function downloadTextFile(filename, text, contentType) { + return wasFileSaved( + await sessionUi.downloadText(filename, text, contentType) + ); + } + + async function downloadBlobFile(filename, blobLike) { + return wasFileSaved(await sessionUi.downloadBlob(filename, blobLike)); + } + async function requestGeneratedCode() { return sessionService.generateCode({ engine: selectors.getSelectedEngine(), @@ -295,11 +309,11 @@ export function createSessionEditorFlows({ } } - function saveDesign() { + async function saveDesign() { const saveDesignOperation = startFlowOperation("Save design", { operation: "design.save", }); - sessionUi.downloadText( + const didSave = await downloadTextFile( `${actions.sanitizeFilename(state.spec.name || "tensor-network")}.json`, JSON.stringify( actions.serializeCurrentSpec({ persistViewSnapshots: true }), @@ -308,6 +322,13 @@ export function createSessionEditorFlows({ ), "application/json;charset=utf-8" ); + if (!didSave) { + actions.setStatus("Design save cancelled."); + saveDesignOperation?.finish({ + outcome: "cancelled", + }); + return; + } void clearSavedDraft({ silent: true, resumeAutosave: true }); actions.setStatus("Design downloaded as JSON."); saveDesignOperation?.finish({ @@ -436,11 +457,19 @@ export function createSessionEditorFlows({ } store.setGeneratedCode(actions.stripImportLines(payload.code)); syncGeneratedCodePreview(state.generatedCode); - sessionUi.downloadText( + const didSave = await downloadTextFile( `${actions.sanitizeFilename(state.spec.name || "tensor-network")}-${actions.sanitizeFilename(selectors.getSelectedEngine() || "engine")}.py`, payload.code, "text/x-python;charset=utf-8" ); + if (!didSave) { + actions.setStatus("Python export cancelled."); + pythonExportOperation?.finish({ + outcome: "cancelled", + engine: payload.engine, + }); + return; + } actions.setStatus(`Exported ${payload.engine} Python code.`, "success"); pythonExportOperation?.finish({ outcome: "downloaded", @@ -479,7 +508,11 @@ export function createSessionEditorFlows({ sourceContentType: svgPayload.content_type || "image/svg+xml;charset=utf-8", }); - sessionUi.downloadBlob(filename, pngBlob); + const didSave = await downloadBlobFile(filename, pngBlob); + if (!didSave) { + actions.setStatus("PNG export cancelled."); + return false; + } actions.setStatus("Exported a PNG file.", "success"); return true; } @@ -561,18 +594,34 @@ export function createSessionEditorFlows({ return; } if (exportDetails.responseKind === "binary") { - sessionUi.downloadBlob( + const didSave = await downloadBlobFile( filename, new Blob([decodeBase64ToUint8Array(payload.base64 || "")], { type: payload.content_type || exportDetails.contentType, }) ); + if (!didSave) { + actions.setStatus(`${exportDetails.label} export cancelled.`); + exportOperation?.finish({ + outcome: "cancelled", + format, + }); + return; + } } else { - sessionUi.downloadText( + const didSave = await downloadTextFile( filename, payload.text || "", payload.content_type || exportDetails.contentType ); + if (!didSave) { + actions.setStatus(`${exportDetails.label} export cancelled.`); + exportOperation?.finish({ + outcome: "cancelled", + format, + }); + return; + } } actions.setStatus(`Exported a ${exportDetails.label} file.`, "success"); exportOperation?.finish({ diff --git a/src/tensor_network_editor/app/static/js/session/sessionTemplateFlows.js b/src/tensor_network_editor/app/static/js/session/sessionTemplateFlows.js index 09b7582..44081c8 100644 --- a/src/tensor_network_editor/app/static/js/session/sessionTemplateFlows.js +++ b/src/tensor_network_editor/app/static/js/session/sessionTemplateFlows.js @@ -42,6 +42,10 @@ export function createSessionTemplateFlows({ : null; } + function wasFileSaved(saveResult) { + return saveResult !== false; + } + function getGroupById(groupId) { return actions.findGroupById(groupId); } @@ -281,11 +285,17 @@ export function createSessionTemplateFlows({ if (payload.spec && payload.spec.network) { payload.spec.network.name = resolvedDisplayName; } - sessionUi.downloadText( - `${actions.sanitizeFilename(resolvedDisplayName || "subnetwork")}.json`, - JSON.stringify(payload.spec, null, 2), - "application/json;charset=utf-8" + const didSave = wasFileSaved( + await sessionUi.downloadText( + `${actions.sanitizeFilename(resolvedDisplayName || "subnetwork")}.json`, + JSON.stringify(payload.spec, null, 2), + "application/json;charset=utf-8" + ) ); + if (!didSave) { + actions.setStatus("Subnetwork export cancelled."); + return; + } actions.setStatus(`Saved ${resolvedDisplayName} as JSON.`, "success"); } catch (error) { actions.setStatus(`Could not export the subnetwork: ${error.message}`, "error"); @@ -411,11 +421,17 @@ export function createSessionTemplateFlows({ serializedSpec, actions.sanitizeFilename ); - sessionUi.downloadText( - `${actions.sanitizeFilename(displayName || "template")}.json`, - JSON.stringify(payload, null, 2), - "application/json;charset=utf-8" + const didSave = wasFileSaved( + await sessionUi.downloadText( + `${actions.sanitizeFilename(displayName || "template")}.json`, + JSON.stringify(payload, null, 2), + "application/json;charset=utf-8" + ) ); + if (!didSave) { + actions.setStatus("Template export cancelled."); + return; + } actions.setStatus(`Exported ${displayName} as a reusable template.`, "success"); } diff --git a/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js b/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js index 952d799..1deab6b 100644 --- a/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js +++ b/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js @@ -45,10 +45,46 @@ export function createSessionUiAdapters({ } await clipboard.writeText(text); }; + const resolvedPywebviewApi = + windowRef && + windowRef.pywebview && + windowRef.pywebview.api && + typeof windowRef.pywebview.api.save_text_file === "function" && + typeof windowRef.pywebview.api.save_binary_file === "function" + ? windowRef.pywebview.api + : null; + const encodeBytesToBase64 = (bytes) => { + if ( + typeof Buffer !== "undefined" && + typeof Buffer.from === "function" + ) { + return Buffer.from(bytes).toString("base64"); + } + if (typeof globalThis.btoa !== "function") { + throw new Error("Binary downloads are not available in this browser."); + } + let binaryText = ""; + for (const byte of bytes) { + binaryText += String.fromCharCode(byte); + } + return globalThis.btoa(binaryText); + }; const resolvedDownloadBlob = typeof downloadBlob === "function" ? downloadBlob - : (filename, blobLike) => { + : async (filename, blobLike) => { + if (resolvedPywebviewApi) { + if (!blobLike || typeof blobLike.arrayBuffer !== "function") { + throw new Error("Binary downloads are not available in this browser."); + } + const buffer = await blobLike.arrayBuffer(); + const bytes = new Uint8Array(buffer); + return resolvedPywebviewApi.save_binary_file( + filename, + encodeBytesToBase64(bytes), + blobLike.type || "application/octet-stream" + ); + } if (!documentRef || !urlRef || typeof urlRef.createObjectURL !== "function") { throw new Error("File downloads are not available in this browser."); } @@ -57,15 +93,22 @@ export function createSessionUiAdapters({ anchor.download = filename; anchor.click(); urlRef.revokeObjectURL(anchor.href); + return true; }; const resolvedDownloadText = typeof downloadText === "function" ? downloadText - : (filename, text, contentType = "text/plain;charset=utf-8") => { + : async (filename, text, contentType = "text/plain;charset=utf-8") => { + if (resolvedPywebviewApi) { + return resolvedPywebviewApi.save_text_file(filename, text, contentType); + } if (typeof blobCtor !== "function") { throw new Error("Text downloads are not available in this browser."); } - resolvedDownloadBlob(filename, new blobCtor([text], { type: contentType })); + return resolvedDownloadBlob( + filename, + new blobCtor([text], { type: contentType }) + ); }; const resolvedRasterizeSvgToPng = typeof rasterizeSvgToPng === "function" diff --git a/tests/test_frontend_architecture.py b/tests/test_frontend_architecture.py index 9d77af9..c3e1b2f 100644 --- a/tests/test_frontend_architecture.py +++ b/tests/test_frontend_architecture.py @@ -4758,8 +4758,12 @@ def test_editor_shell_helper_modules_expose_explicit_ui_and_invalidation_adapter throw new Error("Session UI confirm adapter should forward the injected result."); }} await sessionUi.copyText("result = 1"); - sessionUi.downloadText("demo.json", "{{}}", "application/json"); - sessionUi.downloadBlob("demo.py", {{ type: "text/x-python" }}); + await Promise.resolve( + sessionUi.downloadText("demo.json", "{{}}", "application/json") + ); + await Promise.resolve( + sessionUi.downloadBlob("demo.py", {{ type: "text/x-python" }}) + ); sessionUi.closeWindow(); if (!uiEvents.some((event) => event.kind === "copy" && event.text === "result = 1")) {{ throw new Error(`Expected injected copy adapter to run, received ${{JSON.stringify(uiEvents)}}.`); @@ -4929,6 +4933,97 @@ def test_editor_shell_helper_modules_expose_explicit_ui_and_invalidation_adapter ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_ui_adapters_use_pywebview_save_api_when_available( + tmp_path: Path, +) -> None: + script_path = _write_runtime_script( + tmp_path, + "session_ui_pywebview_save.mjs", + f""" + import {{ pathToFileURL }} from "node:url"; + + const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href; + const sessionUiModule = await import(sessionUiUrl); + + const calls = []; + class FakeBlob {{ + constructor(parts, options = {{}}) {{ + this.parts = parts; + this.type = options.type || ""; + }} + + async arrayBuffer() {{ + const firstPart = this.parts[0]; + if (!(firstPart instanceof Uint8Array)) {{ + throw new Error("Expected the test blob to receive Uint8Array content."); + }} + return firstPart.buffer.slice( + firstPart.byteOffset, + firstPart.byteOffset + firstPart.byteLength + ); + }} + }} + const sessionUi = sessionUiModule.createSessionUiAdapters({{ + windowRef: {{ + pywebview: {{ + api: {{ + async save_text_file(filename, text, contentType) {{ + calls.push({{ type: "text", filename, text, contentType }}); + return true; + }}, + async save_binary_file(filename, base64Payload, contentType) {{ + calls.push({{ type: "binary", filename, base64Payload, contentType }}); + return true; + }}, + }}, + }}, + }}, + blobCtor: FakeBlob, + }}); + + const textSaved = await sessionUi.downloadText( + "demo.json", + "{{\\"ok\\":true}}", + "application/json;charset=utf-8" + ); + const binarySaved = await sessionUi.downloadBlob( + "demo.pdf", + new FakeBlob([Uint8Array.from([0, 1, 2, 255])], {{ type: "application/pdf" }}) + ); + + if (textSaved !== true || binarySaved !== true) {{ + throw new Error(`Expected pywebview saves to resolve true, received ${{JSON.stringify({{ textSaved, binarySaved }})}}.`); + }} + const textCall = calls.find((entry) => entry.type === "text"); + const binaryCall = calls.find((entry) => entry.type === "binary"); + if (!textCall || textCall.filename !== "demo.json") {{ + throw new Error(`Expected text export to use the pywebview API, received ${{JSON.stringify(calls)}}.`); + }} + if ( + !binaryCall || + binaryCall.filename !== "demo.pdf" || + binaryCall.base64Payload !== "AAEC/w==" + ) {{ + throw new Error(`Expected binary export to send base64 bytes through the pywebview API, received ${{JSON.stringify(calls)}}.`); + }} + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The pywebview session-ui adapter runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_benchmark_helper_modules_build_comparison_rows_and_history_state( tmp_path: Path, diff --git a/tests/test_frontend_runtime.py b/tests/test_frontend_runtime.py index 8cdab50..77378c5 100644 --- a/tests/test_frontend_runtime.py +++ b/tests/test_frontend_runtime.py @@ -14236,8 +14236,7 @@ def _write_session_editor_draft_autosave_runtime_script(tmp_path: Path) -> Path: throw new Error(`Expected draft-save flow logging, received ${JSON.stringify(flowLog)}.`); } - flows.saveDesign(); - await Promise.resolve(); + await flows.saveDesign(); if (!calls.some((entry) => entry.type === "clearDraft")) { throw new Error(`Expected explicit JSON save to clear the draft, received ${JSON.stringify(calls)}.`); } @@ -14546,6 +14545,130 @@ def _write_session_editor_png_fallback_runtime_script(tmp_path: Path) -> Path: return script_path +def _write_session_editor_save_cancelled_runtime_script(tmp_path: Path) -> Path: + script_path = tmp_path / "session_editor_save_cancelled.mjs" + _copy_js_modules(tmp_path, _SESSION_EDITOR_FLOWS_DEPENDENCY_MODULES) + + script_path.write_text( + textwrap.dedent( + """ + const baseUrl = new URL("./", import.meta.url); + const { createSessionEditorFlows } = await import( + new URL("./session/sessionEditorFlows.js", baseUrl).href + ); + + const calls = []; + const flowLog = []; + const state = { + spec: { name: "draft demo" }, + generatedCode: "", + editorFinished: false, + draftAutosaveReady: true, + draftAutosaveTimer: null, + draftAutosaveDirty: false, + draftAutosaveSaving: false, + }; + + const flows = createSessionEditorFlows({ + dom: { + exportFormatSelect: { value: "json" }, + generatedCode: { value: "" }, + loadInput: { value: "" }, + }, + state, + logger: { + startOperation(name, context = {}) { + flowLog.push({ type: "start", name, context }); + return { + finish(nextContext = {}) { + flowLog.push({ type: "finish", name, context: nextContext }); + }, + fail(error, nextContext = {}) { + flowLog.push({ + type: "fail", + name, + message: error.message, + context: nextContext, + }); + }, + }; + }, + }, + store: { + setGeneratedCode() {}, + setEditorFinished() {}, + }, + selectors: { + getSelectedEngine: () => "quimb", + getSelectedCollectionFormat: () => "dict", + }, + services: { + session: { + async clearDraft() { + calls.push({ type: "clearDraft" }); + return { ok: true }; + }, + }, + }, + commands: { + syncGeneratedCodePreview() {}, + }, + sessionUi: { + async downloadText(filename, text, contentType) { + calls.push({ type: "downloadText", filename, text, contentType }); + return false; + }, + closeWindow() {}, + schedule() { + return 0; + }, + }, + actions: { + serializeCurrentSpec({ persistViewSnapshots }) { + return { + schema_version: 2, + persistViewSnapshots, + network: { id: "network_draft", name: "draft demo" }, + }; + }, + sanitizeFilename: (value) => value.replace(/\\s+/g, "_"), + setStatus(message, level = "info") { + calls.push({ type: "status", message, level }); + }, + }, + }); + + await flows.saveDesign(); + + if (calls.some((entry) => entry.type === "clearDraft")) { + throw new Error(`Cancelling the save dialog should not clear the draft, received ${JSON.stringify(calls)}.`); + } + const cancelStatus = calls.find( + (entry) => + entry.type === "status" && + entry.level === "info" && + entry.message === "Design save cancelled." + ); + if (!cancelStatus) { + throw new Error(`Expected a friendly cancellation status, received ${JSON.stringify(calls)}.`); + } + if ( + !flowLog.some( + (entry) => + entry.type === "finish" && + entry.name === "Save design" && + entry.context.outcome === "cancelled" + ) + ) { + throw new Error(`Expected cancelled save-design flow logging, received ${JSON.stringify(flowLog)}.`); + } + """ + ), + encoding="utf-8", + ) + return script_path + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_session_editor_flows_fall_back_to_svg_when_png_render_fails( tmp_path: Path, @@ -14566,6 +14689,26 @@ def test_session_editor_flows_fall_back_to_svg_when_png_render_fails( ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_editor_flows_report_save_cancelled_without_clearing_draft( + tmp_path: Path, +) -> None: + script_path = _write_session_editor_save_cancelled_runtime_script(tmp_path) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The session-editor save-cancelled runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + def _write_tensor_initializer_parsing_runtime_script(tmp_path: Path) -> Path: script_path = tmp_path / "tensor_initializer_parsing.mjs" _copy_js_modules( diff --git a/tests/test_session.py b/tests/test_session.py index 4ab6be9..fdaec0f 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -3,6 +3,7 @@ import logging import signal import threading +from base64 import b64encode from collections.abc import Iterator from importlib import import_module from pathlib import Path @@ -15,6 +16,7 @@ from tensor_network_editor.app._protocol import JsonDict from tensor_network_editor.app.session import ( EditorSession, + _PywebviewExportApi, build_blank_network_spec, wait_for_editor_result, ) @@ -719,6 +721,7 @@ class FakePywebview: def __init__(self) -> None: self.created_urls: list[str] = [] self.created_maximized: list[bool] = [] + self.created_js_apis: list[object] = [] self.window = FakeWindow() self.start_calls = 0 @@ -728,10 +731,12 @@ def create_window( url: str, *, maximized: bool = False, + js_api: object | None = None, ) -> FakeWindow: assert title == "Tensor Network Editor" self.created_urls.append(url) self.created_maximized.append(maximized) + self.created_js_apis.append(js_api) return self.window def start(self, callback: object, window: FakeWindow) -> None: @@ -772,6 +777,8 @@ class FakeMainThread: assert result is completed_result assert pywebview.created_urls == ["http://127.0.0.1:43210"] assert pywebview.created_maximized == [True] + assert len(pywebview.created_js_apis) == 1 + assert isinstance(pywebview.created_js_apis[0], _PywebviewExportApi) assert pywebview.start_calls == 1 assert pywebview.window.destroy_calls == 1 @@ -814,8 +821,9 @@ def create_window( url: str, *, maximized: bool = False, + js_api: object | None = None, ) -> FakeWindow: - del title, url, maximized + del title, url, maximized, js_api return self.window def start(self, callback: object, window: FakeWindow) -> None: @@ -854,6 +862,120 @@ class FakeMainThread: assert result is None +def test_pywebview_export_api_writes_text_file_to_selected_path( + tmp_path: Path, +) -> None: + output_path = tmp_path / "demo.json" + + class FakePywebview: + SAVE_DIALOG = object() + + class FakeWindow: + def __init__(self) -> None: + self.dialog_calls: list[dict[str, object]] = [] + + def create_file_dialog( + self, + dialog_type: object, + *, + save_filename: str, + file_types: tuple[str, ...], + ) -> tuple[str]: + self.dialog_calls.append( + { + "dialog_type": dialog_type, + "save_filename": save_filename, + "file_types": file_types, + } + ) + return (str(output_path),) + + api = _PywebviewExportApi(FakePywebview()) + window = FakeWindow() + api.bind_window(window) + + saved = api.save_text_file( + "demo.json", + "{\n \"ok\": true\n}\n", + "application/json;charset=utf-8", + ) + + assert saved is True + assert output_path.read_text(encoding="utf-8") == '{\n "ok": true\n}\n' + assert window.dialog_calls == [ + { + "dialog_type": FakePywebview.SAVE_DIALOG, + "save_filename": "demo.json", + "file_types": ("JSON (*.json)",), + } + ] + + +def test_pywebview_export_api_returns_false_when_save_dialog_is_cancelled( + tmp_path: Path, +) -> None: + output_path = tmp_path / "demo.json" + + class FakePywebview: + SAVE_DIALOG = object() + + class FakeWindow: + def create_file_dialog( + self, + dialog_type: object, + *, + save_filename: str, + file_types: tuple[str, ...], + ) -> tuple[str, ...]: + del dialog_type, save_filename, file_types + return () + + api = _PywebviewExportApi(FakePywebview()) + api.bind_window(FakeWindow()) + + saved = api.save_text_file( + "demo.json", + '{"ok": true}', + "application/json;charset=utf-8", + ) + + assert saved is False + assert output_path.exists() is False + + +def test_pywebview_export_api_writes_binary_file_to_selected_path( + tmp_path: Path, +) -> None: + output_path = tmp_path / "demo.pdf" + binary_payload = b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n" + + class FakePywebview: + SAVE_DIALOG = object() + + class FakeWindow: + def create_file_dialog( + self, + dialog_type: object, + *, + save_filename: str, + file_types: tuple[str, ...], + ) -> tuple[str]: + del dialog_type, save_filename, file_types + return (str(output_path),) + + api = _PywebviewExportApi(FakePywebview()) + api.bind_window(FakeWindow()) + + saved = api.save_binary_file( + "demo.pdf", + b64encode(binary_payload).decode("ascii"), + "application/pdf", + ) + + assert saved is True + assert output_path.read_bytes() == binary_payload + + def test_open_editor_passes_template_catalog_path( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, From 17364eb2e9bf806ebd5a717810dcfe10cb759acc Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Fri, 1 May 2026 22:08:21 +0200 Subject: [PATCH 13/23] Icon window --- CHANGELOG.md | 11 + src/tensor_network_editor/app/session.py | 36 ++++ .../app/static/js/bootstrap.js | 11 +- .../static/js/session/sessionUiAdapters.js | 4 +- tests/test_frontend_architecture.py | 201 ++++++++++++++++++ tests/test_session.py | 104 ++++++++- 6 files changed, 363 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index efccb28..0266780 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,17 @@ All notable changes to this project will be documented in this file. file from Python, so desktop-mode JSON, Python, and academic exports no longer disappear into the embedded browser backend's implicit download folder. +- `pywebview` export actions now detect the native save API lazily at export + time instead of only during page startup, so the desktop `Save As` dialog + still appears even when the webview bridge finishes attaching just after the + editor UI initializes. +- The editor bootstrap now starts immediately when the document is already in + `interactive` or `complete`, which fixes `pywebview` windows that could show + the shell markup without wiring toolbar actions, canvas interactions, or the + template bootstrap if `DOMContentLoaded` had already fired. +- Windows `pywebview` launches now reuse the packaged + [`favicon.ico`](src/tensor_network_editor/app/static/favicon.ico) for the + native window icon instead of inheriting the default Python executable icon. - Test cleanup scripts now remove `session.log*` artifacts, and the repository ignores those rotating session logs explicitly. - Shared HTTP test helpers now give bundled editor assets more time to load, diff --git a/src/tensor_network_editor/app/session.py b/src/tensor_network_editor/app/session.py index f6ae6ec..dc7c681 100644 --- a/src/tensor_network_editor/app/session.py +++ b/src/tensor_network_editor/app/session.py @@ -78,6 +78,39 @@ def _import_pywebview() -> Any: ) from exc +def _resolve_pywebview_icon_path() -> Path: + """Return the packaged desktop icon used for the native pywebview window.""" + return Path(__file__).resolve().parent / "static" / "favicon.ico" + + +def _apply_pywebview_native_window_icon(window: Any) -> None: + """Apply the packaged icon to the native pywebview window when supported.""" + icon_path = _resolve_pywebview_icon_path() + if not icon_path.is_file(): + return + native_window = getattr(window, "native", None) + if native_window is None: + return + try: + from System.Drawing import Icon as DrawingIcon # type: ignore[import-not-found] + except Exception: + return + try: + native_window.Icon = DrawingIcon(str(icon_path)) + if hasattr(native_window, "ShowIcon"): + native_window.ShowIcon = True + except Exception as exc: + log_branch( + LOGGER, + "Could not apply the native pywebview window icon", + level=logging.WARNING, + context={ + "icon_path": str(icon_path), + "error": str(exc), + }, + ) + + class _PywebviewExportApi: """Expose native save-file helpers to the embedded pywebview frontend.""" @@ -177,6 +210,9 @@ def _run_pywebview_session( js_api=pywebview_export_api, ) pywebview_export_api.bind_window(pywebview_window) + pywebview_window.events.before_show += lambda: _apply_pywebview_native_window_icon( + pywebview_window + ) def _handle_window_closed(*_args: object) -> None: """Cancel the editor session when the native window is closed.""" diff --git a/src/tensor_network_editor/app/static/js/bootstrap.js b/src/tensor_network_editor/app/static/js/bootstrap.js index b85df79..c94a17b 100644 --- a/src/tensor_network_editor/app/static/js/bootstrap.js +++ b/src/tensor_network_editor/app/static/js/bootstrap.js @@ -37,10 +37,17 @@ export function startEditor(ctx) { redoShortcutLabel: ctx.constants.REDO_SHORTCUT_LABEL, }); - document.addEventListener("DOMContentLoaded", () => { + function initializeEditor() { shellBindings.attachToolbarHandlers(); bootstrapFlow.bootstrap().catch((error) => { actions.setStatus(`Failed to load the editor: ${error.message}`, "error"); }); - }); + } + + if (document.readyState === "loading") { + document.addEventListener("DOMContentLoaded", initializeEditor, { once: true }); + return; + } + + initializeEditor(); } diff --git a/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js b/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js index 1deab6b..9df16b5 100644 --- a/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js +++ b/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js @@ -45,7 +45,7 @@ export function createSessionUiAdapters({ } await clipboard.writeText(text); }; - const resolvedPywebviewApi = + const resolvePywebviewApi = () => windowRef && windowRef.pywebview && windowRef.pywebview.api && @@ -73,6 +73,7 @@ export function createSessionUiAdapters({ typeof downloadBlob === "function" ? downloadBlob : async (filename, blobLike) => { + const resolvedPywebviewApi = resolvePywebviewApi(); if (resolvedPywebviewApi) { if (!blobLike || typeof blobLike.arrayBuffer !== "function") { throw new Error("Binary downloads are not available in this browser."); @@ -99,6 +100,7 @@ export function createSessionUiAdapters({ typeof downloadText === "function" ? downloadText : async (filename, text, contentType = "text/plain;charset=utf-8") => { + const resolvedPywebviewApi = resolvePywebviewApi(); if (resolvedPywebviewApi) { return resolvedPywebviewApi.save_text_file(filename, text, contentType); } diff --git a/tests/test_frontend_architecture.py b/tests/test_frontend_architecture.py index c3e1b2f..eba17a4 100644 --- a/tests/test_frontend_architecture.py +++ b/tests/test_frontend_architecture.py @@ -5024,6 +5024,207 @@ class FakeBlob {{ ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_ui_adapters_detect_pywebview_save_api_added_after_creation( + tmp_path: Path, +) -> None: + script_path = _write_runtime_script( + tmp_path, + "session_ui_pywebview_late_save.mjs", + f""" + import {{ pathToFileURL }} from "node:url"; + + const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href; + const sessionUiModule = await import(sessionUiUrl); + + const windowRef = {{}}; + const uiCalls = []; + const sessionUi = sessionUiModule.createSessionUiAdapters({{ + windowRef, + documentRef: {{ + createElement() {{ + uiCalls.push({{ type: "web-download" }}); + return {{ + click() {{ + uiCalls.push({{ type: "web-download-click" }}); + }}, + }}; + }}, + }}, + urlRef: {{ + createObjectURL() {{ + return "blob:test"; + }}, + revokeObjectURL() {{ + return undefined; + }}, + }}, + blobCtor: class FakeBlob {{ + constructor(parts, options = {{}}) {{ + this.parts = parts; + this.type = options.type || ""; + }} + }}, + }}); + + windowRef.pywebview = {{ + api: {{ + async save_text_file(filename, text, contentType) {{ + uiCalls.push({{ type: "pywebview", filename, text, contentType }}); + return true; + }}, + async save_binary_file() {{ + throw new Error("Unexpected binary save in text export test."); + }}, + }}, + }}; + + await sessionUi.downloadText( + "late.json", + "{{\\"late\\": true}}", + "application/json;charset=utf-8" + ); + + if (!uiCalls.some((entry) => entry.type === "pywebview")) {{ + throw new Error(`Expected late pywebview injection to be honored, received ${{JSON.stringify(uiCalls)}}.`); + }} + if (uiCalls.some((entry) => entry.type === "web-download")) {{ + throw new Error(`Expected pywebview save path instead of web download fallback, received ${{JSON.stringify(uiCalls)}}.`); + }} + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The late pywebview session-ui adapter runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_start_editor_bootstraps_immediately_when_dom_is_already_ready( + tmp_path: Path, +) -> None: + bootstrap_source = ( + REPO_ROOT + / "src" + / "tensor_network_editor" + / "app" + / "static" + / "js" + / "bootstrap.js" + ).read_text(encoding="utf-8") + (tmp_path / "shell").mkdir(parents=True, exist_ok=True) + (tmp_path / "bootstrap.js").write_text(bootstrap_source, encoding="utf-8") + (tmp_path / "shell" / "editorBootstrapFlow.js").write_text( + """ + export function createEditorBootstrapFlow() { + return { + async bootstrap() { + globalThis.__bootstrapCalls.push("bootstrap"); + return {}; + }, + }; + } + """, + encoding="utf-8", + ) + (tmp_path / "shell" / "shellActions.js").write_text( + """ + export function createShellActions() { + return { + setStatus(message, level = "info") { + globalThis.__bootstrapCalls.push(`status:${level}:${message}`); + }, + }; + } + """, + encoding="utf-8", + ) + (tmp_path / "shell" / "editorShellBindings.js").write_text( + """ + export function createEditorShellBindings() { + return { + attachToolbarHandlers() { + globalThis.__bootstrapCalls.push("attachToolbarHandlers"); + }, + }; + } + """, + encoding="utf-8", + ) + (tmp_path / "shell" / "shortcutTooltip.js").write_text( + """ + export function createShortcutTooltip() { + return { + attachShortcutTooltipHandlers() {}, + }; + } + """, + encoding="utf-8", + ) + script_path = _write_runtime_script( + tmp_path, + "bootstrap_dom_ready.mjs", + """ + globalThis.__bootstrapCalls = []; + const bootstrapUrl = new URL("./bootstrap.js", import.meta.url).href; + const bootstrapModule = await import(bootstrapUrl); + + const documentRef = { + readyState: "complete", + addEventListener(type, handler) { + globalThis.__bootstrapCalls.push(`listener:${type}`); + this.listener = handler; + }, + }; + const ctx = { + state: {}, + store: {}, + window: { + confirm() { + return false; + }, + }, + document: documentRef, + services: { session: {} }, + logger: null, + constants: { REDO_SHORTCUT_LABEL: "Ctrl+Shift+Z" }, + }; + + bootstrapModule.startEditor(ctx); + await Promise.resolve(); + + if (!globalThis.__bootstrapCalls.includes("attachToolbarHandlers")) { + throw new Error(`Expected toolbar handlers to attach immediately, received ${JSON.stringify(globalThis.__bootstrapCalls)}.`); + } + if (!globalThis.__bootstrapCalls.includes("bootstrap")) { + throw new Error(`Expected bootstrap to run immediately for a ready document, received ${JSON.stringify(globalThis.__bootstrapCalls)}.`); + } + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The bootstrap DOM-ready runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_benchmark_helper_modules_build_comparison_rows_and_history_state( tmp_path: Path, diff --git a/tests/test_session.py b/tests/test_session.py index fdaec0f..5ff0aec 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -707,6 +707,7 @@ def fire(self) -> None: class FakeWindowEvents: def __init__(self) -> None: + self.before_show = FakeEventHook() self.closed = FakeEventHook() class FakeWindow: @@ -741,6 +742,7 @@ def create_window( def start(self, callback: object, window: FakeWindow) -> None: self.start_calls += 1 + self.window.events.before_show.fire() cast(Any, callback)(window) class FakeEditorServer: @@ -783,6 +785,105 @@ class FakeMainThread: assert pywebview.window.destroy_calls == 1 +def test_launch_editor_session_pywebview_applies_native_icon_before_show( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + completed_result = EditorResult( + spec=build_blank_network_spec(), + engine=EngineName.EINSUM_NUMPY, + confirmed=True, + ) + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.before_show = FakeEventHook() + self.closed = FakeEventHook() + + class FakeNativeWindow: + Icon = None + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + self.native = FakeNativeWindow() + + def destroy(self) -> None: + return None + + class FakePywebview: + def __init__(self) -> None: + self.window = FakeWindow() + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + js_api: object | None = None, + ) -> FakeWindow: + del title, url, maximized, js_api + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + self.window.events.before_show.fire() + cast(Any, callback)(window) + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + applied_windows: list[object] = [] + main_thread = FakeMainThread() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr(session_module, "_import_pywebview", lambda: FakePywebview()) + monkeypatch.setattr( + session_module, + "wait_for_editor_result", + lambda _session: completed_result, + ) + monkeypatch.setattr( + session_module, + "_apply_pywebview_native_window_icon", + lambda window: applied_windows.append(window), + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is completed_result + assert len(applied_windows) == 1 + assert isinstance(applied_windows[0], FakeWindow) + + def test_launch_editor_session_pywebview_window_close_cancels_session( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -802,6 +903,7 @@ def fire(self) -> None: class FakeWindowEvents: def __init__(self) -> None: + self.before_show = FakeEventHook() self.closed = FakeEventHook() class FakeWindow: @@ -896,7 +998,7 @@ def create_file_dialog( saved = api.save_text_file( "demo.json", - "{\n \"ok\": true\n}\n", + '{\n "ok": true\n}\n', "application/json;charset=utf-8", ) From db8d3356fb169c0dfd763dda5dbdcc2f6e2a79dc Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Sat, 2 May 2026 12:13:52 +0200 Subject: [PATCH 14/23] Stabilize desktop editor flows and lighten benchmark history --- CHANGELOG.md | 22 ++ src/tensor_network_editor/app/server.py | 112 +++++- src/tensor_network_editor/app/session.py | 20 +- .../static/js/session/sessionUiAdapters.js | 9 +- .../app/static/js/state/historySnapshots.js | 88 +++-- tests/test_app_assets.py | 6 + tests/test_app_server.py | 121 +++++++ tests/test_frontend_architecture.py | 324 +++++++++++++++++- tests/test_session.py | 189 ++++++++++ 9 files changed, 845 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0266780..9965028 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,9 @@ All notable changes to this project will be documented in this file. time instead of only during page startup, so the desktop `Save As` dialog still appears even when the webview bridge finishes attaching just after the editor UI initializes. +- `pywebview` export actions now detect text and binary save capabilities + independently, so desktop exports still use the native `Save As` dialog even + when an embedded backend exposes only one of the two save methods at first. - The editor bootstrap now starts immediately when the document is already in `interactive` or `complete`, which fixes `pywebview` windows that could show the shell markup without wiring toolbar actions, canvas interactions, or the @@ -30,6 +33,25 @@ All notable changes to this project will be documented in this file. - Windows `pywebview` launches now reuse the packaged [`favicon.ico`](src/tensor_network_editor/app/static/favicon.ico) for the native window icon instead of inheriting the default Python executable icon. +- `pywebview` desktop launches now treat the native window-icon hook as best + effort, so backends that do not expose a `before_show` event still open + correctly instead of crashing during startup. +- `pywebview` desktop launches now also tolerate backends with partial window + event hooks, so missing `closed` callbacks no longer crash the editor during + startup. +- Local `EditorServer` startup now waits until a real loopback asset request can + be served before reporting readiness, which stabilizes rapid restart cycles + in tests and makes `_on_server_ready` URLs immediately usable. +- `EditorServer.stop()` is now safe even if it runs before `start()`, so early + cleanup paths no longer risk hanging while waiting for a serve loop that + never began. +- Repeated `EditorServer` startups now reuse the shared static-asset cache + without forcing an immediate full rescan of the asset tree every time, which + trims bursty local startup overhead while still refreshing changed assets + shortly afterward. +- Editor undo/redo snapshots now keep benchmark-mode session history lighter by + stripping inactive scheme view snapshots and ephemeral compare-modal state, + while the active scheme still restores its exact contraction layouts. - Test cleanup scripts now remove `session.log*` artifacts, and the repository ignores those rotating session logs explicitly. - Shared HTTP test helpers now give bundled editor assets more time to load, diff --git a/src/tensor_network_editor/app/server.py b/src/tensor_network_editor/app/server.py index c0018e3..7ef44cc 100644 --- a/src/tensor_network_editor/app/server.py +++ b/src/tensor_network_editor/app/server.py @@ -6,6 +6,7 @@ import logging import mimetypes import threading +import time from collections.abc import Callable from dataclasses import dataclass from http import HTTPStatus @@ -14,6 +15,7 @@ from pathlib import Path from typing import Protocol, TypeAlias, cast from urllib.parse import urlparse +from urllib.request import urlopen from ..internal._logging import ( bind_log_context, @@ -35,9 +37,15 @@ LOGGER = logging.getLogger(__name__) _SERVE_FOREVER_POLL_INTERVAL_SECONDS: float = 0.05 +_STARTUP_READY_TIMEOUT_SECONDS: float = 5.0 +_STARTUP_READY_POLL_INTERVAL_SECONDS: float = 0.01 +_STARTUP_READY_REQUEST_TIMEOUT_SECONDS: float = 0.2 +_RESPONSE_WRITE_CHUNK_SIZE_BYTES: int = 64 * 1024 +_STATIC_ASSET_CACHE_VALIDATION_INTERVAL_SECONDS: float = 0.5 _MAX_REQUEST_BODY_BYTES: int = 1_048_576 _STATIC_ASSET_CACHE_LOCK = threading.Lock() _STATIC_ASSET_CACHE_BY_ROOT: dict[Path, _StaticAssetCache] = {} +_STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT: dict[Path, float] = {} _UNEXPECTED_INTERNAL_ERROR_MESSAGE = "Unexpected internal error." _UNEXPECTED_INTERNAL_ERROR_GUIDANCE = ( "Try again. If the problem continues, check the terminal output for this " @@ -198,11 +206,14 @@ def _build_static_asset_cache( def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache: """Return a shared static asset cache for one editor static directory.""" resolved_static_dir = static_dir.resolve() - scanned_files = _scan_static_asset_files(resolved_static_dir) - current_signature = _build_static_asset_source_signature(scanned_files) with _STATIC_ASSET_CACHE_LOCK: + validation_started_at = time.monotonic() cache = _STATIC_ASSET_CACHE_BY_ROOT.get(resolved_static_dir) + last_validated_at = _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.get( + resolved_static_dir + ) if cache is None: + scanned_files = _scan_static_asset_files(resolved_static_dir) with log_operation( LOGGER, "Static asset cache build", @@ -213,9 +224,29 @@ def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache: scanned_files=scanned_files, ) _STATIC_ASSET_CACHE_BY_ROOT[resolved_static_dir] = cache + _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT[resolved_static_dir] = ( + validation_started_at + ) success_context["after"] = cache.asset_version success_context["asset_count"] = len(cache.body_by_relative_path) return cache + if ( + last_validated_at is not None + and validation_started_at - last_validated_at + < _STATIC_ASSET_CACHE_VALIDATION_INTERVAL_SECONDS + ): + log_branch( + LOGGER, + "Static asset cache reused", + context={ + "path": resolved_static_dir, + "after": cache.asset_version, + "asset_count": len(cache.body_by_relative_path), + }, + ) + return cache + scanned_files = _scan_static_asset_files(resolved_static_dir) + current_signature = _build_static_asset_source_signature(scanned_files) if cache.source_signature != current_signature: with log_operation( LOGGER, @@ -230,11 +261,17 @@ def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache: scanned_files=scanned_files, ) _STATIC_ASSET_CACHE_BY_ROOT[resolved_static_dir] = refreshed_cache + _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT[resolved_static_dir] = ( + validation_started_at + ) success_context["after"] = refreshed_cache.asset_version success_context["asset_count"] = len( refreshed_cache.body_by_relative_path ) return refreshed_cache + _STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT[resolved_static_dir] = ( + validation_started_at + ) log_branch( LOGGER, "Static asset cache reused", @@ -309,6 +346,7 @@ def __init__( ) self._server = ThreadingHTTPServer((host, port), self._build_handler()) self._thread = threading.Thread(target=self._serve_forever, daemon=True) + self._serve_forever_ready = threading.Event() @property def base_url(self) -> str: @@ -322,6 +360,11 @@ def base_url(self) -> str: def start(self) -> None: """Start serving requests in a background thread.""" self._thread.start() + try: + self._wait_until_ready() + except Exception: + self._cleanup_failed_start() + raise log_branch( LOGGER, f"Editor server started at {self.base_url}", @@ -331,9 +374,7 @@ def start(self) -> None: def stop(self) -> None: """Stop the server and wait for the worker thread to exit.""" - self._server.shutdown() - self._server.server_close() - self._thread.join(timeout=5) + self._stop_server_worker() log_branch( LOGGER, "Editor server stopped", @@ -343,8 +384,61 @@ def stop(self) -> None: def _serve_forever(self) -> None: """Serve requests with a short shutdown polling interval.""" + self._serve_forever_ready.set() self._server.serve_forever(poll_interval=_SERVE_FOREVER_POLL_INTERVAL_SECONDS) + def _wait_until_ready(self) -> None: + """Block until loopback requests can read one fully served asset.""" + deadline = time.monotonic() + _STARTUP_READY_TIMEOUT_SECONDS + if not self._serve_forever_ready.wait(timeout=_STARTUP_READY_TIMEOUT_SECONDS): + raise RuntimeError( + "Editor server did not enter the serving loop before the startup timeout elapsed." + ) + + last_error: OSError | None = None + while True: + remaining_seconds = deadline - time.monotonic() + if remaining_seconds <= 0: + break + request_timeout_seconds = min( + _STARTUP_READY_REQUEST_TIMEOUT_SECONDS, + remaining_seconds, + ) + try: + self._probe_loopback_readiness(request_timeout_seconds) + except OSError as exc: + last_error = exc + time.sleep(min(_STARTUP_READY_POLL_INTERVAL_SECONDS, remaining_seconds)) + continue + return + + if last_error is None: + raise RuntimeError( + "Editor server readiness probe timed out before any loopback request succeeded." + ) + raise RuntimeError( + "Editor server did not become ready to serve loopback requests before the startup timeout elapsed." + ) from last_error + + def _probe_loopback_readiness(self, timeout_seconds: float) -> None: + """Read one small static asset to verify the server serves full responses.""" + with urlopen( + f"{self.base_url}/favicon.ico", timeout=timeout_seconds + ) as response: + response.read() + + def _stop_server_worker(self) -> None: + """Best-effort shutdown that is safe before the serve loop starts.""" + if self._thread.is_alive() and self._serve_forever_ready.is_set(): + self._server.shutdown() + self._server.server_close() + if self._thread.ident is not None: + self._thread.join(timeout=5) + + def _cleanup_failed_start(self) -> None: + """Best-effort cleanup when startup fails after allocating the server socket.""" + self._stop_server_worker() + def _build_handler(self) -> type[BaseHTTPRequestHandler]: """Build the request-handler class bound to this server instance.""" session = self.session @@ -598,7 +692,13 @@ def _write_bytes(self, status: int, body: bytes, content_type: str) -> None: self.send_header("Connection", "close") self._write_no_cache_headers() self.end_headers() - self.wfile.write(body) + body_view = memoryview(body) + for offset in range( + 0, len(body_view), _RESPONSE_WRITE_CHUNK_SIZE_BYTES + ): + next_offset = offset + _RESPONSE_WRITE_CHUNK_SIZE_BYTES + self.wfile.write(body_view[offset:next_offset]) + self.wfile.flush() def _write_no_cache_headers(self) -> None: """Emit headers that disable browser and intermediary caching.""" diff --git a/src/tensor_network_editor/app/session.py b/src/tensor_network_editor/app/session.py index dc7c681..451a1a1 100644 --- a/src/tensor_network_editor/app/session.py +++ b/src/tensor_network_editor/app/session.py @@ -210,9 +210,14 @@ def _run_pywebview_session( js_api=pywebview_export_api, ) pywebview_export_api.bind_window(pywebview_window) - pywebview_window.events.before_show += lambda: _apply_pywebview_native_window_icon( - pywebview_window - ) + window_events = getattr(pywebview_window, "events", None) + before_show_event = getattr(window_events, "before_show", None) + if before_show_event is not None: + before_show_event += lambda: _apply_pywebview_native_window_icon( + pywebview_window + ) + else: + _apply_pywebview_native_window_icon(pywebview_window) def _handle_window_closed(*_args: object) -> None: """Cancel the editor session when the native window is closed.""" @@ -227,7 +232,9 @@ def _wait_for_session_and_close_window(window: Any) -> None: except Exception: return None - pywebview_window.events.closed += _handle_window_closed + closed_event = getattr(window_events, "closed", None) + if closed_event is not None: + closed_event += _handle_window_closed pywebview.start(_wait_for_session_and_close_window, pywebview_window) return wait_for_editor_result(session) @@ -647,6 +654,11 @@ def launch_editor_session( ui_mode=ui_mode, open_browser=open_browser, ) + if ( + effective_ui_mode == "pywebview" + and threading.current_thread() is not threading.main_thread() + ): + raise RuntimeError("pywebview mode must be launched from the main thread.") previous_sigint_handler: SignalHandler | int | None = None server_started = False diff --git a/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js b/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js index 9df16b5..bc678fd 100644 --- a/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js +++ b/src/tensor_network_editor/app/static/js/session/sessionUiAdapters.js @@ -45,12 +45,11 @@ export function createSessionUiAdapters({ } await clipboard.writeText(text); }; - const resolvePywebviewApi = () => + const resolvePywebviewApiMethod = (methodName) => windowRef && windowRef.pywebview && windowRef.pywebview.api && - typeof windowRef.pywebview.api.save_text_file === "function" && - typeof windowRef.pywebview.api.save_binary_file === "function" + typeof windowRef.pywebview.api[methodName] === "function" ? windowRef.pywebview.api : null; const encodeBytesToBase64 = (bytes) => { @@ -73,7 +72,7 @@ export function createSessionUiAdapters({ typeof downloadBlob === "function" ? downloadBlob : async (filename, blobLike) => { - const resolvedPywebviewApi = resolvePywebviewApi(); + const resolvedPywebviewApi = resolvePywebviewApiMethod("save_binary_file"); if (resolvedPywebviewApi) { if (!blobLike || typeof blobLike.arrayBuffer !== "function") { throw new Error("Binary downloads are not available in this browser."); @@ -100,7 +99,7 @@ export function createSessionUiAdapters({ typeof downloadText === "function" ? downloadText : async (filename, text, contentType = "text/plain;charset=utf-8") => { - const resolvedPywebviewApi = resolvePywebviewApi(); + const resolvedPywebviewApi = resolvePywebviewApiMethod("save_text_file"); if (resolvedPywebviewApi) { return resolvedPywebviewApi.save_text_file(filename, text, contentType); } diff --git a/src/tensor_network_editor/app/static/js/state/historySnapshots.js b/src/tensor_network_editor/app/static/js/state/historySnapshots.js index 2d477c8..b0d45bf 100644 --- a/src/tensor_network_editor/app/static/js/state/historySnapshots.js +++ b/src/tensor_network_editor/app/static/js/state/historySnapshots.js @@ -19,25 +19,48 @@ export function createHistorySnapshotSupport({ refreshContractionAnalysis, setStatus, }) { - function restoreBenchmarkSession(snapshotBenchmarkSession) { - const nextBenchmarkSession = - snapshotBenchmarkSession && typeof snapshotBenchmarkSession === "object" - ? deepClone(snapshotBenchmarkSession) + function createHistorySnapshotBenchmarkPlan(plan) { + if (!plan || typeof plan !== "object") { + return { + id: "", + name: "", + steps: [], + view_snapshots: [], + metadata: {}, + }; + } + return { + id: typeof plan.id === "string" ? plan.id : "", + name: typeof plan.name === "string" ? plan.name : "", + steps: Array.isArray(plan.steps) ? deepClone(plan.steps) : [], + view_snapshots: [], + metadata: + plan.metadata && typeof plan.metadata === "object" + ? deepClone(plan.metadata) + : {}, + }; + } + + function createHistorySnapshotBenchmarkSession(benchmarkSession) { + const sourceSession = + benchmarkSession && typeof benchmarkSession === "object" + ? benchmarkSession : createEmptyBenchmarkSession(); - const compareModal = - nextBenchmarkSession.compareModal && - typeof nextBenchmarkSession.compareModal === "object" - ? nextBenchmarkSession.compareModal - : createEmptyBenchmarkCompareState(); - compareModal.rows = Array.isArray(compareModal.rows) - ? compareModal.rows - : compareModal.tableModel && Array.isArray(compareModal.tableModel.rows) - ? compareModal.tableModel.rows - : []; - nextBenchmarkSession.compareModal = compareModal; - nextBenchmarkSession.schemes = Array.isArray(nextBenchmarkSession.schemes) - ? nextBenchmarkSession.schemes - : []; + const nextBenchmarkSession = { + enabled: Boolean(sourceSession.enabled), + activePosition: Number.isInteger(sourceSession.activePosition) + ? sourceSession.activePosition + : 0, + originalPlan: sourceSession.originalPlan + ? createHistorySnapshotBenchmarkPlan(sourceSession.originalPlan) + : null, + schemes: Array.isArray(sourceSession.schemes) + ? sourceSession.schemes.map((scheme) => + createHistorySnapshotBenchmarkPlan(scheme) + ) + : [], + compareModal: createEmptyBenchmarkCompareState(), + }; nextBenchmarkSession.activePosition = Number.isInteger( nextBenchmarkSession.activePosition ) @@ -51,6 +74,16 @@ export function createHistorySnapshotSupport({ ) { nextBenchmarkSession.activePosition = nextBenchmarkSession.schemes.length; } + return nextBenchmarkSession; + } + + function restoreBenchmarkSession( + snapshotBenchmarkSession, + restoredContractionPlan = null + ) { + const nextBenchmarkSession = createHistorySnapshotBenchmarkSession( + snapshotBenchmarkSession + ); state.benchmarkSession = nextBenchmarkSession; if (!nextBenchmarkSession.enabled) { return; @@ -66,7 +99,13 @@ export function createHistorySnapshotSupport({ state.spec.contraction_plan = null; return; } - state.spec.contraction_plan = deepClone(activeScheme); + const exactActiveScheme = + restoredContractionPlan && + typeof restoredContractionPlan === "object" && + (!activeScheme.id || restoredContractionPlan.id === activeScheme.id) + ? restoredContractionPlan + : null; + state.spec.contraction_plan = exactActiveScheme || deepClone(activeScheme); nextBenchmarkSession.schemes[nextBenchmarkSession.activePosition - 1] = state.spec.contraction_plan; } @@ -84,7 +123,9 @@ export function createHistorySnapshotSupport({ return { spec: snapshotSpec == null ? deepClone(state.spec) : snapshotSpec, tensorOrder: Array.isArray(state.tensorOrder) ? [...state.tensorOrder] : [], - benchmarkSession: deepClone(state.benchmarkSession || createEmptyBenchmarkSession()), + benchmarkSession: createHistorySnapshotBenchmarkSession( + state.benchmarkSession + ), }; } @@ -118,7 +159,12 @@ export function createHistorySnapshotSupport({ state.tensorOrder = Array.isArray(snapshot.tensorOrder) ? [...snapshot.tensorOrder] : []; - restoreBenchmarkSession(snapshot.benchmarkSession); + restoreBenchmarkSession( + snapshot.benchmarkSession, + state.spec && typeof state.spec === "object" + ? state.spec.contraction_plan || null + : null + ); if (typeof bumpSpecRevision === "function") { bumpSpecRevision(); } diff --git a/tests/test_app_assets.py b/tests/test_app_assets.py index efcdca7..c61420b 100644 --- a/tests/test_app_assets.py +++ b/tests/test_app_assets.py @@ -3231,6 +3231,12 @@ def test_editor_assets_use_lookup_caches_and_lighter_history_paths( assert "function createHistorySnapshot(" not in history_body assert "function getSelectedEntries(" not in history_body assert "function createHistorySnapshotSupport(" in history_snapshots_body + assert "function createHistorySnapshotBenchmarkSession(" in history_snapshots_body + assert "function createHistorySnapshotBenchmarkPlan(" in history_snapshots_body + assert ( + "benchmarkSession: deepClone(state.benchmarkSession || createEmptyBenchmarkSession())" + not in history_snapshots_body + ) assert "function createSelectionEntrySupport(" in history_selection_body assert "function createDesignMutationPipeline(" in history_pipeline_body assert "structuredClone" in utilities_body diff --git a/tests/test_app_server.py b/tests/test_app_server.py index f2945bb..d0f6fc5 100644 --- a/tests/test_app_server.py +++ b/tests/test_app_server.py @@ -2,9 +2,12 @@ import logging import os +import subprocess +import sys from http import HTTPStatus from pathlib import Path from typing import Protocol, cast +from urllib.request import urlopen import pytest @@ -132,6 +135,61 @@ def test_editor_servers_reuse_static_asset_cache_between_instances() -> None: second_server._server.server_close() +def test_editor_server_stop_returns_cleanly_before_start() -> None: + env = os.environ.copy() + src_path = str((Path.cwd() / "src").resolve()) + existing_python_path = env.get("PYTHONPATH") + env["PYTHONPATH"] = ( + src_path + if not existing_python_path + else f"{src_path}{os.pathsep}{existing_python_path}" + ) + script = """ +from tensor_network_editor.app.server import EditorServer +from tensor_network_editor.app.session import EditorSession + +server = EditorServer(EditorSession()) +server.stop() +print("STOPPED", flush=True) +""".strip() + + result = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + check=False, + cwd=Path.cwd(), + env=env, + text=True, + timeout=2, + ) + + assert result.returncode == 0, result.stdout + result.stderr + assert result.stdout.strip() == "STOPPED" + + +def test_editor_server_start_makes_shell_and_vendor_assets_immediately_readable_across_rapid_restarts() -> ( + None +): + for _ in range(5): + server = EditorServer(EditorSession(initial_spec=build_sample_spec())) + + server.start() + try: + with urlopen(f"{server.base_url}/", timeout=2) as response: + html = response.read().decode("utf-8") + with urlopen( + f"{server.base_url}/vendor/cytoscape.min.js", timeout=2 + ) as response: + vendor_body = response.read().decode("utf-8") + vendor_headers = dict(response.headers.items()) + finally: + server.stop() + + assert "Tensor Network Editor" in html + assert "cytoscape" in vendor_body + assert vendor_headers["Content-Type"].startswith("application/javascript") + + def test_editor_index_response_embeds_session_runtime_config() -> None: first_server = EditorServer(EditorSession(initial_spec=build_sample_spec())) second_server = EditorServer(EditorSession(initial_spec=build_sample_spec())) @@ -170,8 +228,10 @@ def test_editor_index_response_embeds_session_runtime_config() -> None: def test_static_asset_cache_refreshes_when_static_files_change( tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, ) -> None: static_dir = tmp_path / "static" + resolved_static_dir = static_dir.resolve() asset_path = static_dir / "js" / "app.js" asset_path.parent.mkdir(parents=True) (static_dir / "index.html").write_text( @@ -179,6 +239,13 @@ def test_static_asset_cache_refreshes_when_static_files_change( encoding="utf-8", ) asset_path.write_text("console.log('first');", encoding="utf-8") + monotonic_time = 100.0 + + monkeypatch.setattr(app_server.time, "monotonic", lambda: monotonic_time) + app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None) + app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop( + resolved_static_dir, None + ) first_cache = app_server._get_static_asset_cache(static_dir) @@ -188,6 +255,7 @@ def test_static_asset_cache_refreshes_when_static_files_change( + 1_000_000_000 ) os.utime(asset_path, ns=(future_timestamp_ns, future_timestamp_ns)) + monotonic_time += 1.0 refreshed_cache = app_server._get_static_asset_cache(static_dir) @@ -210,6 +278,7 @@ def test_static_asset_cache_reuses_one_scan_per_build_or_refresh( encoding="utf-8", ) asset_path.write_text("console.log('first');", encoding="utf-8") + monotonic_time = 100.0 scan_calls: list[Path] = [] original_scan = app_server._scan_static_asset_files @@ -218,8 +287,12 @@ def recording_scan(path: Path) -> list[tuple[Path, str, int, int]]: scan_calls.append(path.resolve()) return original_scan(path) + monkeypatch.setattr(app_server.time, "monotonic", lambda: monotonic_time) monkeypatch.setattr(app_server, "_scan_static_asset_files", recording_scan) app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None) + app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop( + resolved_static_dir, None + ) first_cache = app_server._get_static_asset_cache(static_dir) @@ -232,6 +305,7 @@ def recording_scan(path: Path) -> list[tuple[Path, str, int, int]]: + 1_000_000_000 ) os.utime(asset_path, ns=(future_timestamp_ns, future_timestamp_ns)) + monotonic_time += 1.0 refreshed_cache = app_server._get_static_asset_cache(static_dir) @@ -241,6 +315,42 @@ def recording_scan(path: Path) -> list[tuple[Path, str, int, int]]: assert scan_calls == [resolved_static_dir, resolved_static_dir] +def test_static_asset_cache_skips_immediate_rescan_for_rapid_reuse( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + static_dir = tmp_path / "static" + resolved_static_dir = static_dir.resolve() + asset_path = static_dir / "js" / "app.js" + asset_path.parent.mkdir(parents=True) + (static_dir / "index.html").write_text( + "", + encoding="utf-8", + ) + asset_path.write_text("console.log('first');", encoding="utf-8") + monotonic_time = 100.0 + + scan_calls: list[Path] = [] + original_scan = app_server._scan_static_asset_files + + def recording_scan(path: Path) -> list[tuple[Path, str, int, int]]: + scan_calls.append(path.resolve()) + return original_scan(path) + + monkeypatch.setattr(app_server.time, "monotonic", lambda: monotonic_time) + monkeypatch.setattr(app_server, "_scan_static_asset_files", recording_scan) + app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None) + app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop( + resolved_static_dir, None + ) + + first_cache = app_server._get_static_asset_cache(static_dir) + second_cache = app_server._get_static_asset_cache(static_dir) + + assert first_cache is second_cache + assert scan_calls == [resolved_static_dir] + + def test_static_asset_cache_logs_build_and_reuse( tmp_path: Path, caplog: pytest.LogCaptureFixture, @@ -255,6 +365,9 @@ def test_static_asset_cache_logs_build_and_reuse( ) asset_path.write_text("console.log('first');", encoding="utf-8") app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None) + app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop( + resolved_static_dir, None + ) with caplog.at_level(logging.DEBUG, logger="tensor_network_editor"): first_cache = app_server._get_static_asset_cache(static_dir) @@ -270,6 +383,7 @@ def test_static_asset_cache_logs_build_and_reuse( def test_static_asset_cache_logs_refresh_with_version_context( tmp_path: Path, caplog: pytest.LogCaptureFixture, + monkeypatch: pytest.MonkeyPatch, ) -> None: static_dir = tmp_path / "static" asset_path = static_dir / "js" / "app.js" @@ -280,7 +394,13 @@ def test_static_asset_cache_logs_refresh_with_version_context( encoding="utf-8", ) asset_path.write_text("console.log('first');", encoding="utf-8") + monotonic_time = 100.0 + + monkeypatch.setattr(app_server.time, "monotonic", lambda: monotonic_time) app_server._STATIC_ASSET_CACHE_BY_ROOT.pop(resolved_static_dir, None) + app_server._STATIC_ASSET_CACHE_LAST_VALIDATED_AT_BY_ROOT.pop( + resolved_static_dir, None + ) first_cache = app_server._get_static_asset_cache(static_dir) asset_path.write_text("console.log('second');", encoding="utf-8") @@ -289,6 +409,7 @@ def test_static_asset_cache_logs_refresh_with_version_context( + 1_000_000_000 ) os.utime(asset_path, ns=(future_timestamp_ns, future_timestamp_ns)) + monotonic_time += 1.0 with caplog.at_level(logging.DEBUG, logger="tensor_network_editor"): refreshed_cache = app_server._get_static_asset_cache(static_dir) diff --git a/tests/test_frontend_architecture.py b/tests/test_frontend_architecture.py index eba17a4..94a2d0f 100644 --- a/tests/test_frontend_architecture.py +++ b/tests/test_frontend_architecture.py @@ -5108,6 +5108,186 @@ def test_session_ui_adapters_detect_pywebview_save_api_added_after_creation( ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_ui_adapters_use_partial_pywebview_text_api_when_available( + tmp_path: Path, +) -> None: + script_path = _write_runtime_script( + tmp_path, + "session_ui_pywebview_partial_text_save.mjs", + f""" + import {{ pathToFileURL }} from "node:url"; + + const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href; + const sessionUiModule = await import(sessionUiUrl); + + const calls = []; + const sessionUi = sessionUiModule.createSessionUiAdapters({{ + windowRef: {{ + pywebview: {{ + api: {{ + async save_text_file(filename, text, contentType) {{ + calls.push({{ type: "text", filename, text, contentType }}); + return true; + }}, + }}, + }}, + }}, + documentRef: {{ + createElement() {{ + calls.push({{ type: "web-download" }}); + return {{ + click() {{ + calls.push({{ type: "web-download-click" }}); + }}, + }}; + }}, + }}, + urlRef: {{ + createObjectURL() {{ + calls.push({{ type: "object-url" }}); + return "blob:test"; + }}, + revokeObjectURL() {{ + return undefined; + }}, + }}, + blobCtor: class FakeBlob {{ + constructor(parts, options = {{}}) {{ + this.parts = parts; + this.type = options.type || ""; + }} + }}, + }}); + + const saved = await sessionUi.downloadText( + "partial.json", + "{{\\"partial\\": true}}", + "application/json;charset=utf-8" + ); + + if (saved !== true) {{ + throw new Error(`Expected the partial pywebview text save to resolve true, received ${{saved}}.`); + }} + if (!calls.some((entry) => entry.type === "text")) {{ + throw new Error(`Expected downloadText() to use save_text_file(), received ${{JSON.stringify(calls)}}.`); + }} + if (calls.some((entry) => entry.type === "web-download" || entry.type === "object-url")) {{ + throw new Error(`Expected no web-download fallback when save_text_file() exists, received ${{JSON.stringify(calls)}}.`); + }} + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The partial pywebview text-save runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_session_ui_adapters_use_partial_pywebview_binary_api_when_available( + tmp_path: Path, +) -> None: + script_path = _write_runtime_script( + tmp_path, + "session_ui_pywebview_partial_binary_save.mjs", + f""" + import {{ pathToFileURL }} from "node:url"; + + const sessionUiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "session" / "sessionUiAdapters.js")!r}).href; + const sessionUiModule = await import(sessionUiUrl); + + const calls = []; + class FakeBlob {{ + constructor(parts, options = {{}}) {{ + this.parts = parts; + this.type = options.type || ""; + }} + + async arrayBuffer() {{ + const firstPart = this.parts[0]; + if (!(firstPart instanceof Uint8Array)) {{ + throw new Error("Expected Uint8Array content in the binary export test blob."); + }} + return firstPart.buffer.slice( + firstPart.byteOffset, + firstPart.byteOffset + firstPart.byteLength + ); + }} + }} + const sessionUi = sessionUiModule.createSessionUiAdapters({{ + windowRef: {{ + pywebview: {{ + api: {{ + async save_binary_file(filename, base64Payload, contentType) {{ + calls.push({{ type: "binary", filename, base64Payload, contentType }}); + return true; + }}, + }}, + }}, + }}, + documentRef: {{ + createElement() {{ + calls.push({{ type: "web-download" }}); + return {{ + click() {{ + calls.push({{ type: "web-download-click" }}); + }}, + }}; + }}, + }}, + urlRef: {{ + createObjectURL() {{ + calls.push({{ type: "object-url" }}); + return "blob:test"; + }}, + revokeObjectURL() {{ + return undefined; + }}, + }}, + blobCtor: FakeBlob, + }}); + + const saved = await sessionUi.downloadBlob( + "partial.pdf", + new FakeBlob([Uint8Array.from([0, 1, 2, 255])], {{ type: "application/pdf" }}) + ); + + if (saved !== true) {{ + throw new Error(`Expected the partial pywebview binary save to resolve true, received ${{saved}}.`); + }} + const binaryCall = calls.find((entry) => entry.type === "binary"); + if (!binaryCall || binaryCall.base64Payload !== "AAEC/w==") {{ + throw new Error(`Expected downloadBlob() to use save_binary_file(), received ${{JSON.stringify(calls)}}.`); + }} + if (calls.some((entry) => entry.type === "web-download" || entry.type === "object-url")) {{ + throw new Error(`Expected no web-download fallback when save_binary_file() exists, received ${{JSON.stringify(calls)}}.`); + }} + """, + ) + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The partial pywebview binary-save runtime script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_start_editor_bootstraps_immediately_when_dom_is_already_ready( tmp_path: Path, @@ -5341,7 +5521,21 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state( const historyEvents = []; const historyState = {{ - spec: {{ id: "network_demo" }}, + spec: {{ + id: "network_demo", + contraction_plan: {{ + id: "scheme_beta", + name: "Beta", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "tensor_a" }}], + }}, + ], + metadata: {{}}, + }}, + }}, tensorOrder: ["tensor_a"], undoStack: [], redoStack: [], @@ -5363,13 +5557,47 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state( benchmarkSession: {{ enabled: true, activePosition: 2, - originalPlan: {{ id: "original_plan", name: "Original", steps: [], metadata: {{}} }}, + originalPlan: {{ + id: "original_plan", + name: "Original", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "original_tensor" }}], + }}, + ], + metadata: {{}}, + }}, schemes: [ - {{ id: "scheme_alpha", name: "Alpha", steps: [], metadata: {{}} }}, - {{ id: "scheme_beta", name: "Beta", steps: [], metadata: {{}} }}, + {{ + id: "scheme_alpha", + name: "Alpha", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "alpha_tensor" }}], + }}, + ], + metadata: {{}}, + }}, + {{ + id: "scheme_beta", + name: "Beta", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "beta_tensor" }}], + }}, + ], + metadata: {{}}, + }}, ], compareModal: {{ open: true, + tableModel: {{ rows: [{{ scheme_id: "scheme_alpha" }}] }}, rows: [{{ scheme_id: "scheme_alpha" }}], activeRequestId: 7, }}, @@ -5399,19 +5627,80 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state( if (!snapshot.benchmarkSession || snapshot.benchmarkSession.activePosition !== 2) {{ throw new Error(`Expected history snapshots to capture benchmark session state, received ${{JSON.stringify(snapshot)}}.`); }} + if (snapshot.benchmarkSession.compareModal.open || snapshot.benchmarkSession.compareModal.activeRequestId !== 0) {{ + throw new Error(`Expected history snapshots to reset ephemeral benchmark compare state, received ${{JSON.stringify(snapshot.benchmarkSession.compareModal)}}.`); + }} + if (snapshot.benchmarkSession.compareModal.rows.length !== 0 || snapshot.benchmarkSession.compareModal.tableModel !== null) {{ + throw new Error(`Expected history snapshots to strip compare rows and table models, received ${{JSON.stringify(snapshot.benchmarkSession.compareModal)}}.`); + }} + if (snapshot.benchmarkSession.originalPlan.view_snapshots.length !== 0) {{ + throw new Error(`Expected history snapshots to strip original-plan view snapshots, received ${{JSON.stringify(snapshot.benchmarkSession.originalPlan)}}.`); + }} + if (snapshot.benchmarkSession.schemes.some((scheme) => scheme.view_snapshots.length !== 0)) {{ + throw new Error(`Expected history snapshots to strip inactive benchmark view snapshots, received ${{JSON.stringify(snapshot.benchmarkSession.schemes)}}.`); + }} + if (snapshot.spec.contraction_plan.view_snapshots.length !== 1) {{ + throw new Error(`Expected the active scheme view snapshots to stay in the main spec snapshot, received ${{JSON.stringify(snapshot.spec.contraction_plan)}}.`); + }} historySupport.restoreHistorySnapshot({{ - spec: {{ id: "restored_network" }}, + spec: {{ + id: "restored_network", + contraction_plan: {{ + id: "scheme_restored", + name: "Restored", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 1, + operand_layouts: [{{ operand_id: "restored_tensor" }}], + }}, + ], + metadata: {{}}, + }}, + }}, tensorOrder: ["tensor_b"], benchmarkSession: {{ enabled: true, activePosition: 1, - originalPlan: null, - schemes: [{{ id: "scheme_restored", name: "Restored", steps: [], metadata: {{}} }}], + originalPlan: {{ + id: "restored_original", + name: "Restored Original", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "restored_original_tensor" }}], + }}, + ], + metadata: {{}}, + }}, + schemes: [ + {{ + id: "scheme_restored", + name: "Restored", + steps: [], + view_snapshots: [], + metadata: {{}}, + }}, + {{ + id: "scheme_inactive", + name: "Inactive", + steps: [], + view_snapshots: [ + {{ + applied_step_count: 0, + operand_layouts: [{{ operand_id: "inactive_tensor" }}], + }}, + ], + metadata: {{}}, + }}, + ], compareModal: {{ - open: false, - rows: [], - activeRequestId: 0, + open: true, + tableModel: {{ rows: [{{ scheme_id: "scheme_restored" }}] }}, + rows: [{{ scheme_id: "scheme_restored" }}], + activeRequestId: 9, }}, }}, }}); @@ -5419,6 +5708,21 @@ def test_benchmark_helper_modules_build_comparison_rows_and_history_state( if (!historyState.benchmarkSession || historyState.benchmarkSession.activePosition !== 1) {{ throw new Error(`Expected history restore to recover benchmark session state, received ${{JSON.stringify(historyState.benchmarkSession)}}.`); }} + if (historyState.benchmarkSession.compareModal.open || historyState.benchmarkSession.compareModal.rows.length !== 0 || historyState.benchmarkSession.compareModal.tableModel !== null) {{ + throw new Error(`Expected history restore to keep benchmark compare state ephemeral, received ${{JSON.stringify(historyState.benchmarkSession.compareModal)}}.`); + }} + if (historyState.benchmarkSession.originalPlan.view_snapshots.length !== 0) {{ + throw new Error(`Expected history restore to keep original-plan snapshots lazy, received ${{JSON.stringify(historyState.benchmarkSession.originalPlan)}}.`); + }} + if (historyState.benchmarkSession.schemes[1].view_snapshots.length !== 0) {{ + throw new Error(`Expected inactive benchmark schemes to stay lightweight after restore, received ${{JSON.stringify(historyState.benchmarkSession.schemes)}}.`); + }} + if (historyState.benchmarkSession.schemes[0] !== historyState.spec.contraction_plan) {{ + throw new Error("Expected history restore to re-link the active benchmark scheme to the restored contraction plan."); + }} + if (historyState.spec.contraction_plan.view_snapshots.length !== 1) {{ + throw new Error(`Expected the restored active scheme to keep its exact view snapshots, received ${{JSON.stringify(historyState.spec.contraction_plan)}}.`); + }} """, ) diff --git a/tests/test_session.py b/tests/test_session.py index 5ff0aec..53dc03c 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -884,6 +884,195 @@ class FakeMainThread: assert isinstance(applied_windows[0], FakeWindow) +def test_launch_editor_session_pywebview_applies_native_icon_without_before_show( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + completed_result = EditorResult( + spec=build_blank_network_spec(), + engine=EngineName.EINSUM_NUMPY, + confirmed=True, + ) + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.closed = FakeEventHook() + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + + def destroy(self) -> None: + return None + + class FakePywebview: + def __init__(self) -> None: + self.window = FakeWindow() + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + js_api: object | None = None, + ) -> FakeWindow: + del title, url, maximized, js_api + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + cast(Any, callback)(window) + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + applied_windows: list[object] = [] + main_thread = FakeMainThread() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr(session_module, "_import_pywebview", lambda: FakePywebview()) + monkeypatch.setattr( + session_module, + "wait_for_editor_result", + lambda _session: completed_result, + ) + monkeypatch.setattr( + session_module, + "_apply_pywebview_native_window_icon", + lambda window: applied_windows.append(window), + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is completed_result + assert len(applied_windows) == 1 + assert isinstance(applied_windows[0], FakeWindow) + + +def test_launch_editor_session_pywebview_tolerates_missing_closed_event( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tensor_network_editor.app import session as session_module + + completed_result = EditorResult( + spec=build_blank_network_spec(), + engine=EngineName.EINSUM_NUMPY, + confirmed=True, + ) + + class FakeEventHook: + def __init__(self) -> None: + self._callbacks: list[object] = [] + + def __iadd__(self, callback: object) -> FakeEventHook: + self._callbacks.append(callback) + return self + + def fire(self) -> None: + for callback in list(self._callbacks): + cast(Any, callback)() + + class FakeWindowEvents: + def __init__(self) -> None: + self.before_show = FakeEventHook() + + class FakeWindow: + def __init__(self) -> None: + self.events = FakeWindowEvents() + self.destroy_calls = 0 + + def destroy(self) -> None: + self.destroy_calls += 1 + + class FakePywebview: + def __init__(self) -> None: + self.window = FakeWindow() + + def create_window( + self, + title: str, + url: str, + *, + maximized: bool = False, + js_api: object | None = None, + ) -> FakeWindow: + del title, url, maximized, js_api + return self.window + + def start(self, callback: object, window: FakeWindow) -> None: + self.window.events.before_show.fire() + cast(Any, callback)(window) + + class FakeEditorServer: + def __init__(self, *args: object, **kwargs: object) -> None: + del args, kwargs + self.base_url = "http://127.0.0.1:43210" + + def start(self) -> None: + return None + + def stop(self) -> None: + return None + + class FakeMainThread: + name = "MainThread" + + applied_windows: list[object] = [] + main_thread = FakeMainThread() + pywebview = FakePywebview() + monkeypatch.setattr( + "tensor_network_editor.app.server.EditorServer", + FakeEditorServer, + ) + monkeypatch.setattr(session_module.threading, "main_thread", lambda: main_thread) + monkeypatch.setattr(session_module.threading, "current_thread", lambda: main_thread) + monkeypatch.setattr(session_module, "_import_pywebview", lambda: pywebview) + monkeypatch.setattr( + session_module, + "wait_for_editor_result", + lambda _session: completed_result, + ) + monkeypatch.setattr( + session_module, + "_apply_pywebview_native_window_icon", + lambda window: applied_windows.append(window), + ) + + result = session_module.launch_editor_session(ui_mode="pywebview") + + assert result is completed_result + assert applied_windows == [pywebview.window] + assert pywebview.window.destroy_calls == 1 + + def test_launch_editor_session_pywebview_window_close_cancels_session( monkeypatch: pytest.MonkeyPatch, ) -> None: From da4e04c692ba9765bf03ba2d86e718a1102ff33e Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Sat, 2 May 2026 12:17:27 +0200 Subject: [PATCH 15/23] Version update --- CHANGELOG.md | 2 ++ CITATION.cff | 4 ++-- src/tensor_network_editor/_version.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9965028..431c171 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +## [1.0.0] - 2026-05-02 + ### Added - The editor now supports an explicit UI launch mode across the CLI and Python diff --git a/CITATION.cff b/CITATION.cff index 061f699..ceb2485 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,8 +5,8 @@ type: software authors: - family-names: "Mata Ali" given-names: "Alejandro" -version: "0.5.0" -date-released: "2026-04-30" +version: "1.0.0" +date-released: "2026-05-02" repository-code: "https://github.com/DOKOS-TAYOS/Tensor-Network-Editor" url: "https://github.com/DOKOS-TAYOS/Tensor-Network-Editor" license: "MIT" diff --git a/src/tensor_network_editor/_version.py b/src/tensor_network_editor/_version.py index bcfe245..3f00b16 100644 --- a/src/tensor_network_editor/_version.py +++ b/src/tensor_network_editor/_version.py @@ -4,4 +4,4 @@ from typing import Final -__version__: Final[str] = "0.5.0" +__version__: Final[str] = "1.0.0" From 9ad4668f9d6e7b6156cec6cd44496229ca7079ae Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Sat, 2 May 2026 12:32:46 +0200 Subject: [PATCH 16/23] Final --- CHANGELOG.md | 3 +++ MANIFEST.in | 2 -- README.md | 3 +-- tests/test_packaging.py | 21 +++++++++++++++++++++ 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 431c171..597315d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ All notable changes to this project will be documented in this file. ### Changed +- Publishing polish: the README no longer advertises a removed `png` extra, and + `MANIFEST.in` no longer carries redundant exclusions for non-package + directories, which keeps `python -m build` quieter. - `pywebview` editor launches now open their native window maximized by default, so the desktop mode starts with the same roomy workspace users usually expect from the browser flow. diff --git a/MANIFEST.in b/MANIFEST.in index 426ccfa..dd3e848 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,5 +3,3 @@ include THIRD_PARTY_LICENSES include README.md recursive-include src/tensor_network_editor/app/static *.css *.html *.ico *.js include src/tensor_network_editor/py.typed -recursive-exclude docs/images * -recursive-exclude tests * diff --git a/README.md b/README.md index 537f3d0..9fc861a 100644 --- a/README.md +++ b/README.md @@ -218,8 +218,7 @@ tensor-network-editor doctor my_network.json tensor-network-editor doctor my_network.json --format json ``` -Render one saved design as SVG, PDF, TikZ/LaTeX, Graphviz/DOT, Mermaid, or with the -optional `png` extra, PNG: +Render one saved design as SVG, PDF, TikZ/LaTeX, Graphviz/DOT, Mermaid, or PNG: ```bash tensor-network-editor render my_network.json --format svg --output figure.svg diff --git a/tests/test_packaging.py b/tests/test_packaging.py index db4bc28..aaf610f 100644 --- a/tests/test_packaging.py +++ b/tests/test_packaging.py @@ -103,6 +103,27 @@ def test_project_metadata_declares_required_matplotlib_dependency_and_backend_ex assert "png" not in optional_dependencies +def test_docs_do_not_advertise_removed_png_extra() -> None: + readme_text = (Path.cwd() / "README.md").read_text(encoding="utf-8") + installation_text = (Path.cwd() / "docs" / "installation.md").read_text( + encoding="utf-8" + ) + + assert "tensor-network-editor[png]" not in readme_text + assert "optional `png` extra" not in readme_text + assert "tensor-network-editor[png]" not in installation_text + + +def test_manifest_omits_redundant_non_package_exclusions() -> None: + manifest_text = (Path.cwd() / "MANIFEST.in").read_text(encoding="utf-8") + + assert "docs/images" not in manifest_text + assert "prune tests" not in manifest_text + assert "tests" not in manifest_text + assert "recursive-exclude docs/images *" not in manifest_text + assert "recursive-exclude tests *" not in manifest_text + + def test_third_party_notices_describe_bundled_asset_scope() -> None: third_party_text = (Path.cwd() / "THIRD_PARTY_LICENSES").read_text(encoding="utf-8") readme_text = (Path.cwd() / "README.md").read_text(encoding="utf-8") From b9bbb1560904e155d9348a13ad1c76fb8ee8c4b4 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Sat, 2 May 2026 12:36:19 +0200 Subject: [PATCH 17/23] Production Ready --- CHANGELOG.md | 2 ++ CITATION.cff | 2 +- README.md | 1 + pyproject.toml | 4 ++-- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 597315d..a8c31d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ All notable changes to this project will be documented in this file. ### Changed +- PyPI trove classifiers and README now state **Production/Stable** readiness + (replacing the previous Beta development-status marker). - Publishing polish: the README no longer advertises a removed `png` extra, and `MANIFEST.in` no longer carries redundant exclusions for non-package directories, which keeps `python -m build` quieter. diff --git a/CITATION.cff b/CITATION.cff index ceb2485..a20b2e6 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -10,7 +10,7 @@ date-released: "2026-05-02" repository-code: "https://github.com/DOKOS-TAYOS/Tensor-Network-Editor" url: "https://github.com/DOKOS-TAYOS/Tensor-Network-Editor" license: "MIT" -abstract: "A local Python package and browser editor for drawing tensor networks, saving versioned JSON designs, and generating readable Python code for tensor-network backends." +abstract: "A production-ready local Python package and browser editor for drawing tensor networks, saving versioned JSON designs, and generating readable Python code for tensor-network backends." keywords: - "tensor networks" - "scientific computing" diff --git a/README.md b/README.md index 9fc861a..1d7de57 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ [![Python 3.11+](https://img.shields.io/badge/python-3.11%2B-blue)](https://github.com/DOKOS-TAYOS/Tensor-Network-Editor) [![Windows%20%7C%20Linux](https://img.shields.io/badge/platform-Windows%20%7C%20Linux-0A7BBB)](https://github.com/DOKOS-TAYOS/Tensor-Network-Editor/actions/workflows/ci.yml) [![MIT License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE) +[![Stability](https://img.shields.io/badge/stability-production--ready-brightgreen)](https://pypi.org/project/tensor-network-editor/) `tensor-network-editor` is a local Python package for drawing tensor networks, saving them as versioned JSON, and generating readable Python code for several diff --git a/pyproject.toml b/pyproject.toml index 6ae9a8c..e2b2ac1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "tensor-network-editor" dynamic = ["version"] -description = "Local visual editor for tensor networks: versioned JSON diagrams and Python code for einsum and optional backends." +description = "Production-ready local visual editor for tensor networks: versioned JSON diagrams and Python code for einsum and optional backends." readme = "README.md" requires-python = ">=3.11" license = "MIT" @@ -38,7 +38,7 @@ keywords = [ "visualization", ] classifiers = [ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", "Intended Audience :: Science/Research", "Operating System :: OS Independent", From bea93e13d0c7c615dc34cc2b94763258cef243b9 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 14 May 2026 10:36:43 +0200 Subject: [PATCH 18/23] Theme academic image exports by editor mode --- CHANGELOG.md | 6 + src/tensor_network_editor/app/routes.py | 104 +++++++++++++++++- .../js/services/editorSessionService.js | 20 ++-- .../static/js/session/sessionEditorFlows.js | 1 + src/tensor_network_editor/rendering.py | 44 +++++--- tests/test_app_routes.py | 53 +++++++++ tests/test_frontend_runtime.py | 12 ++ tests/test_rendering.py | 27 +++++ 8 files changed, 239 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a8c31d8..75b71db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Changed + +- Academic SVG/PNG/PDF exports now inherit the active editor theme for figure + text/background colors; light themes use white PDF backgrounds, while SVG and + PNG exports can preserve transparent backgrounds. + ## [1.0.0] - 2026-05-02 ### Added diff --git a/src/tensor_network_editor/app/routes.py b/src/tensor_network_editor/app/routes.py index f6ae51c..122534b 100644 --- a/src/tensor_network_editor/app/routes.py +++ b/src/tensor_network_editor/app/routes.py @@ -9,6 +9,7 @@ from http import HTTPStatus from typing import Literal, cast +from .._themes import DEFAULT_EDITOR_THEME, EditorThemeName, normalize_editor_theme from ..errors import ( CodeGenerationError, PackageIOError, @@ -112,6 +113,58 @@ class _RenderLabelOptions: show_edge_labels: bool +_IMAGE_EXPORT_THEME_OVERRIDES: dict[EditorThemeName, dict[str, str]] = { + "dark": { + "background": "#0b0d12", + "edge_stroke": "#7e8aa3", + "index_fill": "#d7ae68", + "group_stroke": "#8f7cf7", + "note_fill": "#252b34", + "text_fill": "#f2f5f8", + "muted_text_fill": "#c6d3e6", + }, + "light": { + "background": "#ffffff", + "edge_stroke": "#64748b", + "index_fill": "#b45309", + "group_stroke": "#6d28d9", + "note_fill": "#ffffff", + "text_fill": "#172033", + "muted_text_fill": "#475569", + }, + "contrast": { + "background": "#000000", + "edge_stroke": "#ffffff", + "index_fill": "#ffff00", + "hyperedge_stroke": "#ff5f5f", + "group_stroke": "#ff00ff", + "note_fill": "#101010", + "text_fill": "#ffffff", + "muted_text_fill": "#ffffff", + }, + "colorblind": { + "background": "#ffffff", + "edge_stroke": "#5b5b5b", + "index_fill": "#e69f00", + "hyperedge_stroke": "#d55e00", + "group_stroke": "#cc79a7", + "note_fill": "#ffffff", + "text_fill": "#202124", + "muted_text_fill": "#5b5b5b", + }, + "shiny": { + "background": "#070915", + "edge_stroke": "#94a3b8", + "index_fill": "#facc15", + "hyperedge_stroke": "#fb7185", + "group_stroke": "#e879f9", + "note_fill": "#11152c", + "text_fill": "#f8fafc", + "muted_text_fill": "#c4b5fd", + }, +} + + def _route_context( session: EditorSession | None, route: str, @@ -344,12 +397,15 @@ def handle_render(session: EditorSession, payload: JsonDict) -> JsonResponse: serialized_spec = require_serialized_spec(payload) spec = deserialize_spec(serialized_spec, validate=False) label_options = _resolve_render_label_options(payload) + render_theme = _resolve_render_theme(payload) success_context["format"] = render_format + success_context["theme"] = render_theme success_context.update(summarize_spec_counts(spec)) response_payload = _build_render_response( render_format, spec, label_options, + theme=render_theme, ) except ValueError as exc: return bad_request_response(str(exc)) @@ -733,12 +789,29 @@ def _resolve_render_label_options(payload: JsonDict) -> _RenderLabelOptions: ) -def _svg_render_options(label_options: _RenderLabelOptions) -> SvgRenderOptions: +def _resolve_render_theme(payload: JsonDict) -> EditorThemeName: + """Return the editor theme requested for one render payload.""" + raw_theme = payload.get("theme") + if raw_theme is None: + return DEFAULT_EDITOR_THEME + if not isinstance(raw_theme, str): + raise ValueError("'theme' must be a string when provided.") + return normalize_editor_theme(raw_theme) + + +def _svg_render_options( + label_options: _RenderLabelOptions, + *, + render_format: _RenderFormat, + theme: EditorThemeName, +) -> SvgRenderOptions: """Return SVG/PNG/PDF render options derived from shared label flags.""" return SvgRenderOptions( show_tensor_labels=label_options.show_tensor_labels, show_index_labels=label_options.show_index_labels, show_edge_labels=label_options.show_edge_labels, + transparent_background=render_format in {"svg", "png"}, + **_IMAGE_EXPORT_THEME_OVERRIDES[theme], ) @@ -792,6 +865,8 @@ def _build_render_response( render_format: _RenderFormat, spec: NetworkSpec, label_options: _RenderLabelOptions, + *, + theme: EditorThemeName = DEFAULT_EDITOR_THEME, ) -> JsonDict: """Return the serialized academic render payload for one format request.""" if render_format == "tikz": @@ -815,18 +890,39 @@ def _build_render_response( if render_format == "svg": return _build_text_render_response( render_format, - render_spec_svg(spec, options=_svg_render_options(label_options)), + render_spec_svg( + spec, + options=_svg_render_options( + label_options, + render_format=render_format, + theme=theme, + ), + ), content_type="image/svg+xml;charset=utf-8", ) if render_format == "png": return _build_binary_render_response( render_format, - render_spec_png(spec, options=_svg_render_options(label_options)), + render_spec_png( + spec, + options=_svg_render_options( + label_options, + render_format=render_format, + theme=theme, + ), + ), content_type="image/png", ) return _build_binary_render_response( render_format, - render_spec_pdf(spec, options=_svg_render_options(label_options)), + render_spec_pdf( + spec, + options=_svg_render_options( + label_options, + render_format=render_format, + theme=theme, + ), + ), content_type="application/pdf", ) diff --git a/src/tensor_network_editor/app/static/js/services/editorSessionService.js b/src/tensor_network_editor/app/static/js/services/editorSessionService.js index b9747f7..6c930af 100644 --- a/src/tensor_network_editor/app/static/js/services/editorSessionService.js +++ b/src/tensor_network_editor/app/static/js/services/editorSessionService.js @@ -83,20 +83,26 @@ export function createEditorSessionService({ apiGet, apiPost }) { showTensorNames = true, showIndexNames = true, showBondNames = true, + theme = null, }) { + const payload = { + format, + spec, + show_tensor_names: showTensorNames, + show_index_names: showIndexNames, + show_bond_names: showBondNames, + }; + if (typeof theme === "string" && theme.trim()) { + payload.theme = theme; + } return apiPost( "/api/render", - { - format, - spec, - show_tensor_names: showTensorNames, - show_index_names: showIndexNames, - show_bond_names: showBondNames, - }, + payload, { operation: "render", context: { format, + theme: payload.theme || null, ...summarizeSerializedSpec(spec), }, } diff --git a/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js b/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js index 9a3f75a..a182708 100644 --- a/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js +++ b/src/tensor_network_editor/app/static/js/session/sessionEditorFlows.js @@ -491,6 +491,7 @@ export function createSessionEditorFlows({ showTensorNames: state.academicExportLabels.tensor, showIndexNames: state.academicExportLabels.index, showBondNames: state.academicExportLabels.bond, + theme: state.selectedTheme, }; } diff --git a/src/tensor_network_editor/rendering.py b/src/tensor_network_editor/rendering.py index 0b2378c..0e87e63 100644 --- a/src/tensor_network_editor/rendering.py +++ b/src/tensor_network_editor/rendering.py @@ -88,6 +88,7 @@ class SvgRenderOptions: text_fill: str = "#f2f5f8" muted_text_fill: str = "#aeb9c7" font_family: str = "Arial, sans-serif" + transparent_background: bool = False @dataclass(slots=True, frozen=True) @@ -426,12 +427,13 @@ def render(self) -> str: f'{_number(width)} {_number(height)}">' ), f"{_text(self._spec.name)}", - ( + ] + if not self._options.transparent_background: + lines.append( f"" - ), - ] + ) if self._options.include_groups: lines.extend(self._render_groups()) lines.extend(self._render_edges(edge_render_infos)) @@ -2104,10 +2106,10 @@ def _build_figure( canvas_height / self._BASE_DPI, ), dpi=self._BASE_DPI, - facecolor=self._options.background, + facecolor=self._figure_facecolor(), ) axes = figure.add_axes((0.0, 0.0, 1.0, 1.0)) - axes.set_facecolor(self._options.background) + axes.set_facecolor(self._figure_facecolor()) axes.set_xlim(bounds.x1, bounds.x1 + canvas_width) axes.set_ylim(bounds.y1 + canvas_height, bounds.y1) axes.set_aspect("equal", adjustable="box") @@ -2121,14 +2123,21 @@ def _build_figure( canvas_height, ) + def _figure_facecolor(self) -> str: + """Return the Matplotlib figure/axes background color.""" + if self._options.transparent_background: + return "none" + return self._options.background + def _savefig_kwargs(self, *, file_format: str) -> dict[str, Any]: + figure_facecolor = self._figure_facecolor() save_kwargs: dict[str, Any] = { "format": file_format, "bbox_inches": None, "pad_inches": 0.0, - "facecolor": self._options.background, - "edgecolor": self._options.background, - "transparent": False, + "facecolor": figure_facecolor, + "edgecolor": figure_facecolor, + "transparent": self._options.transparent_background, "metadata": self._render_metadata(file_format=file_format), } if file_format == "png": @@ -2160,15 +2169,16 @@ def _draw_scene( canvas_width: int, canvas_height: int, ) -> None: - background = patches_module.Rectangle( - (bounds.x1, bounds.y1), - canvas_width, - canvas_height, - facecolor=self._options.background, - edgecolor="none", - zorder=-100, - ) - axes.add_patch(background) + if not self._options.transparent_background: + background = patches_module.Rectangle( + (bounds.x1, bounds.y1), + canvas_width, + canvas_height, + facecolor=self._options.background, + edgecolor="none", + zorder=-100, + ) + axes.add_patch(background) if self._options.include_groups: self._render_groups(axes, patches_module) self._render_edges(axes, patches_module, path_module, edge_render_infos) diff --git a/tests/test_app_routes.py b/tests/test_app_routes.py index c626672..4d2929b 100644 --- a/tests/test_app_routes.py +++ b/tests/test_app_routes.py @@ -448,6 +448,59 @@ def test_render_route_applies_label_options_to_svg_png_and_pdf( assert options.show_edge_labels is False +def test_render_route_applies_theme_to_image_export_options( + editor_server: EditorServer, +) -> None: + spec = build_sample_spec() + serialized_spec = { + "schema_version": SCHEMA_VERSION, + "network": spec.to_dict(), + } + + with ( + patch( + "tensor_network_editor.app.routes.render_spec_svg", + return_value="", + ) as render_svg_mock, + patch( + "tensor_network_editor.app.routes.render_spec_pdf", + return_value=b"%PDF-1.4", + ) as render_pdf_mock, + ): + svg_payload = request_json( + f"{editor_server.base_url}/api/render", + method="POST", + payload={ + "format": "svg", + "spec": serialized_spec, + "theme": "light", + }, + ) + pdf_payload = request_json( + f"{editor_server.base_url}/api/render", + method="POST", + payload={ + "format": "pdf", + "spec": serialized_spec, + "theme": "light", + }, + ) + + svg_options = render_svg_mock.call_args.kwargs["options"] + pdf_options = render_pdf_mock.call_args.kwargs["options"] + + assert svg_payload["ok"] is True + assert pdf_payload["ok"] is True + assert svg_options.background == "#ffffff" + assert svg_options.text_fill == "#172033" + assert svg_options.muted_text_fill == "#475569" + assert svg_options.transparent_background is True + assert pdf_options.background == "#ffffff" + assert pdf_options.text_fill == "#172033" + assert pdf_options.muted_text_fill == "#475569" + assert pdf_options.transparent_background is False + + def test_resolve_render_label_options_reads_shared_payload_flags() -> None: import tensor_network_editor.app.routes as routes_module diff --git a/tests/test_frontend_runtime.py b/tests/test_frontend_runtime.py index 77378c5..096344e 100644 --- a/tests/test_frontend_runtime.py +++ b/tests/test_frontend_runtime.py @@ -14052,6 +14052,7 @@ def _write_session_editor_draft_autosave_runtime_script(tmp_path: Path) -> Path: spec: { name: "draft demo" }, generatedCode: "", editorFinished: false, + selectedTheme: "light", draftAutosaveReady: true, draftAutosaveTimer: null, draftAutosaveDirty: false, @@ -14309,6 +14310,13 @@ def _write_session_editor_draft_autosave_runtime_script(tmp_path: Path) -> Path: ) { throw new Error(`Academic exports should persist view snapshots, received ${JSON.stringify(calls)}.`); } + if ( + svgRenderCall.payload.theme !== "light" || + pngRenderCall.payload.theme !== "light" || + pdfRenderCall.payload.theme !== "light" + ) { + throw new Error(`SVG/PNG/PDF exports should include the active theme, received ${JSON.stringify(calls)}.`); + } if (!svgDownloadCall || svgDownloadCall.contentType !== "image/svg+xml;charset=utf-8") { throw new Error(`Expected SVG export to download a .svg file, received ${JSON.stringify(calls)}.`); } @@ -15061,6 +15069,7 @@ def _write_editor_session_service_validate_python_runtime_script( await service.renderSpec({ format: "dot", spec: { schema_version: 2, network: { id: "network_draft" } }, + theme: "light", }); await service.clearDraft(); @@ -15109,6 +15118,9 @@ def _write_editor_session_service_validate_python_runtime_script( if (apiCalls[4].payload.format !== "dot" || apiCalls[4].payload.spec.network.id !== "network_draft") { throw new Error(`Expected renderSpec to keep format and spec payloads, received ${JSON.stringify(apiCalls[4])}.`); } + if (apiCalls[4].payload.theme !== "light") { + throw new Error(`Expected renderSpec to include the current theme, received ${JSON.stringify(apiCalls[4])}.`); + } if (apiCalls[5].path !== "/api/draft/clear" || apiCalls[5].method !== "POST") { throw new Error(`Expected clearDraft to POST /api/draft/clear, received ${JSON.stringify(apiCalls[5])}.`); } diff --git a/tests/test_rendering.py b/tests/test_rendering.py index ab8c16b..75001cc 100644 --- a/tests/test_rendering.py +++ b/tests/test_rendering.py @@ -800,6 +800,20 @@ def test_render_spec_svg_writes_output_path(tmp_path: Path) -> None: assert output_path.read_text(encoding="utf-8") == svg +def test_render_spec_svg_omits_solid_background_when_transparent() -> None: + pytest.importorskip("matplotlib") + + svg = render_spec_svg( + build_sample_spec(), + options=SvgRenderOptions( + background="#abcdef", + transparent_background=True, + ), + ) + + assert "#abcdef" not in svg + + def test_render_spec_svg_reuses_edge_geometry_within_one_render( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -1172,6 +1186,19 @@ def test_render_spec_png_returns_png_bytes_and_writes_output_path( assert output_path.read_bytes() == png_bytes +def test_render_spec_png_uses_alpha_channel_when_transparent() -> None: + pytest.importorskip("matplotlib") + from tensor_network_editor.rendering import render_spec_png + + png_bytes = render_spec_png( + build_sample_spec(), + options=SvgRenderOptions(transparent_background=True), + ) + + assert png_bytes[12:16] == b"IHDR" + assert png_bytes[25] == 6 + + def test_render_spec_pdf_returns_pdf_bytes_and_writes_output_path( tmp_path: Path, ) -> None: From 6566e31e01ccb09dbb8036c441b7b15b7d778315 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 14 May 2026 11:33:45 +0200 Subject: [PATCH 19/23] Harden local editor API security --- .github/workflows/ci.yml | 5 + CHANGELOG.md | 12 + THIRD_PARTY_LICENSES | 17 ++ pyproject.toml | 1 + .../app/_analysis_services.py | 2 + src/tensor_network_editor/app/_limits.py | 180 +++++++++++++++ src/tensor_network_editor/app/_protocol.py | 10 + .../app/_session_requests.py | 3 + .../app/_subnetwork_library_services.py | 5 + .../app/_subnetwork_services.py | 5 + .../app/_template_services.py | 6 + src/tensor_network_editor/app/routes.py | 37 ++- src/tensor_network_editor/app/server.py | 211 +++++++++++++++++- src/tensor_network_editor/app/session.py | 9 +- .../app/static/js/core/editorContext.js | 13 ++ .../app/static/js/core/frontendLogger.js | 24 +- .../js/interactions/interactionsShortcuts.js | 10 +- .../app/static/js/services/api.js | 32 ++- src/tensor_network_editor/editor.py | 31 +++ tests/app_support.py | 50 +++++ tests/test_api.py | 17 ++ tests/test_app_routes.py | 202 ++++++++++++++++- tests/test_app_server.py | 53 +++++ tests/test_frontend_runtime.py | 192 ++++++++++++++++ tests/test_packaging.py | 16 ++ 25 files changed, 1127 insertions(+), 16 deletions(-) create mode 100644 src/tensor_network_editor/app/_limits.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2ea55c1..8c8426f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -71,6 +71,11 @@ jobs: run: | & $env:VENV_PYTHON -m pyright + - name: Run dependency security audit + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.12' + run: | + & $env:VENV_PYTHON -m pip_audit --skip-editable + - name: Run tests run: | & $env:VENV_PYTHON -m pytest -q diff --git a/CHANGELOG.md b/CHANGELOG.md index 75b71db..dd74b09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +### Security + +- The local editor server now requires a per-session API token for `/api/*` + requests, rejects non-JSON POST bodies, blocks untrusted `Host`/`Origin` + headers, and refuses non-loopback binds unless `allow_remote=True` is set + explicitly. +- Editor API routes now reject excessive tensor-network payloads and oversized + template parameters before running expensive validation, rendering, + contraction analysis, code generation, or subnetwork operations. +- CI now runs a dependency vulnerability audit with `pip-audit` as part of the + development dependency set. + ### Changed - Academic SVG/PNG/PDF exports now inherit the active editor theme for figure diff --git a/THIRD_PARTY_LICENSES b/THIRD_PARTY_LICENSES index 090d73b..79a84df 100644 --- a/THIRD_PARTY_LICENSES +++ b/THIRD_PARTY_LICENSES @@ -8,6 +8,8 @@ Runtime pip-installed dependencies are not bundled into this source distribution or wheel. Required installs (`matplotlib`, `opt_einsum`) and optional extras such as `numpy`, `torch`, `quimb`, `tensornetwork`, `tensorkrowch`, or `pywebview` remain covered by their own licenses and notices. +Development tools installed through the `dev` extra are also not bundled and +remain covered by their own upstream licenses and notices. Runtime dependency notice ------------------------- @@ -28,6 +30,21 @@ Matplotlib is a required runtime dependency for academic SVG/PNG/PDF rendering. vendored into this repository. Its own distribution carries the authoritative license text and notices. +Development dependency notice +----------------------------- + +pip-audit is a development and CI dependency used to scan Python dependencies +for known vulnerabilities. + +- Package: pip-audit +- Version range used by this project: `>=2.7` +- Project: https://pypi.org/project/pip-audit/ +- License: Apache Software License +- Notice handling: + pip-audit is installed as an external development dependency rather than + being vendored into this repository. Its own distribution carries the + authoritative license text and notices. + 1. Cytoscape.js - Bundled file: `src/tensor_network_editor/app/static/vendor/cytoscape.min.js` diff --git a/pyproject.toml b/pyproject.toml index e2b2ac1..3e5073f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ torch = ["torch>=2.0"] dev = [ "build>=1.2", "mypy>=1.10", + "pip-audit>=2.7", "pyright>=1.1", "pytest>=8.2", "ruff>=0.6", diff --git a/src/tensor_network_editor/app/_analysis_services.py b/src/tensor_network_editor/app/_analysis_services.py index 0822cf4..ad42cf4 100644 --- a/src/tensor_network_editor/app/_analysis_services.py +++ b/src/tensor_network_editor/app/_analysis_services.py @@ -15,6 +15,7 @@ from ..internal.analysis._contraction_analysis import _analyze_validated_contraction from ..internal.analysis._contraction_analysis_types import ContractionAnalysisResult from ..models import NetworkSpec, ValidationIssue +from ._limits import enforce_spec_api_limits LOGGER = logging.getLogger(__name__) @@ -32,6 +33,7 @@ def analyze_serialized_contraction( emit_start=False, ) as success_context: spec = deserialize_spec_fn(serialized_spec) + enforce_spec_api_limits(spec) issues = validate_spec_fn(spec) if issues: log_branch( diff --git a/src/tensor_network_editor/app/_limits.py b/src/tensor_network_editor/app/_limits.py new file mode 100644 index 0000000..208298f --- /dev/null +++ b/src/tensor_network_editor/app/_limits.py @@ -0,0 +1,180 @@ +"""Complexity limits for local editor API payloads.""" + +from __future__ import annotations + +from collections.abc import Iterator +from dataclasses import dataclass + +from ..internal.models._model_periodic import LinearPeriodicCellSpec +from ..internal.templates._template_catalog import TemplateParameters +from ..models import NetworkSpec, TensorSpec + +MAX_API_TENSORS = 512 +MAX_API_INDICES = 4096 +MAX_API_CONNECTIONS = 4096 +MAX_API_TENSOR_RANK = 64 +MAX_API_INDEX_DIMENSION = 1_000_000 +MAX_API_TENSOR_CARDINALITY = 10_000_000 +MAX_API_TEMPLATE_LINEAR_GRAPH_SIZE = 512 +MAX_API_TEMPLATE_GRID_SIDE_LENGTH = 32 +MAX_API_TEMPLATE_TREE_DEPTH = 10 +MAX_API_TEMPLATE_DIMENSION = 4096 +_GRID_TEMPLATE_NAMES = frozenset({"peps_2x2", "pepo"}) +_TREE_TEMPLATE_NAMES = frozenset({"mera", "ttn"}) + + +@dataclass(slots=True) +class _SpecComplexity: + """Accumulated size information for one editor API payload.""" + + tensor_count: int = 0 + index_count: int = 0 + connection_count: int = 0 + + +def enforce_spec_api_limits(spec: NetworkSpec) -> None: + """Reject a network spec that is too expensive for the local HTTP API.""" + complexity = _SpecComplexity() + for tensors, edge_count in _iter_spec_parts(spec): + complexity.tensor_count += len(tensors) + complexity.connection_count += edge_count + for tensor in tensors: + _enforce_tensor_api_limits(tensor) + complexity.index_count += len(tensor.indices) + + complexity.connection_count += sum( + len(hyperedge.endpoints) for hyperedge in spec.hyperedges + ) + _enforce_count_limit( + name="tensors", + count=complexity.tensor_count, + limit=MAX_API_TENSORS, + ) + _enforce_count_limit( + name="indices", + count=complexity.index_count, + limit=MAX_API_INDICES, + ) + _enforce_count_limit( + name="connections", + count=complexity.connection_count, + limit=MAX_API_CONNECTIONS, + ) + + +def enforce_template_api_limits( + template_name: str, + parameters: TemplateParameters | None, +) -> None: + """Reject built-in template parameters that would create huge payloads.""" + if parameters is None: + return + graph_size_limit = _template_graph_size_limit(template_name) + if parameters.graph_size is not None and parameters.graph_size > graph_size_limit: + raise ValueError( + "Template parameter 'graph_size' " + f"is {parameters.graph_size}, above the API limit of {graph_size_limit}." + ) + _enforce_optional_template_dimension( + parameters.bond_dimension, + field_name="bond_dimension", + ) + _enforce_optional_template_dimension( + parameters.physical_dimension, + field_name="physical_dimension", + ) + + +def _iter_spec_parts(spec: NetworkSpec) -> Iterator[tuple[list[TensorSpec], int]]: + """Yield tensor and edge collections stored in a spec payload.""" + yield spec.tensors, len(spec.edges) + if spec.linear_periodic_chain is not None: + for cell in ( + spec.linear_periodic_chain.initial_cell, + spec.linear_periodic_chain.periodic_cell, + spec.linear_periodic_chain.final_cell, + ): + yield from _iter_cell_parts(cell) + if spec.grid_periodic_grid is not None: + for cell in ( + spec.grid_periodic_grid.top_left_cell, + spec.grid_periodic_grid.top_cell, + spec.grid_periodic_grid.top_right_cell, + spec.grid_periodic_grid.left_cell, + spec.grid_periodic_grid.center_cell, + spec.grid_periodic_grid.right_cell, + spec.grid_periodic_grid.bottom_left_cell, + spec.grid_periodic_grid.bottom_cell, + spec.grid_periodic_grid.bottom_right_cell, + ): + yield from _iter_cell_parts(cell) + if spec.tree_periodic_tree is not None: + for cell in ( + spec.tree_periodic_tree.root_cell, + spec.tree_periodic_tree.branch_cell, + spec.tree_periodic_tree.leaf_cell, + ): + yield from _iter_cell_parts(cell) + + +def _iter_cell_parts( + cell: LinearPeriodicCellSpec, +) -> Iterator[tuple[list[TensorSpec], int]]: + """Yield tensor and edge collections stored in one periodic cell.""" + yield cell.tensors, len(cell.edges) + + +def _enforce_tensor_api_limits(tensor: TensorSpec) -> None: + """Reject one tensor whose local shape is too expensive.""" + rank = len(tensor.indices) + if rank > MAX_API_TENSOR_RANK: + raise ValueError( + f"Tensor '{tensor.name}' has rank {rank}, " + f"above the API limit of {MAX_API_TENSOR_RANK}." + ) + cardinality = 1 + for index in tensor.indices: + if index.dimension > MAX_API_INDEX_DIMENSION: + raise ValueError( + f"Index '{index.name}' on tensor '{tensor.name}' has dimension " + f"{index.dimension}, above the API limit of {MAX_API_INDEX_DIMENSION}." + ) + if index.dimension > 0: + cardinality *= index.dimension + if cardinality > MAX_API_TENSOR_CARDINALITY: + raise ValueError( + f"Tensor '{tensor.name}' spans {cardinality} elements, " + f"above the API limit of {MAX_API_TENSOR_CARDINALITY}." + ) + + +def _enforce_count_limit(*, name: str, count: int, limit: int) -> None: + """Reject one aggregate count when it exceeds its API limit.""" + if count <= limit: + return + raise ValueError( + f"Network contains {count} {name}, above the API limit of {limit}." + ) + + +def _enforce_optional_template_dimension( + value: int | None, + *, + field_name: str, +) -> None: + """Reject template dimensions that would produce very large tensors.""" + if value is None or value <= MAX_API_TEMPLATE_DIMENSION: + return + raise ValueError( + f"Template parameter '{field_name}' is {value}, " + f"above the API limit of {MAX_API_TEMPLATE_DIMENSION}." + ) + + +def _template_graph_size_limit(template_name: str) -> int: + """Return the graph-size limit appropriate for one template family.""" + if template_name in _GRID_TEMPLATE_NAMES: + return MAX_API_TEMPLATE_GRID_SIDE_LENGTH + if template_name in _TREE_TEMPLATE_NAMES: + return MAX_API_TEMPLATE_TREE_DEPTH + return MAX_API_TEMPLATE_LINEAR_GRAPH_SIZE diff --git a/src/tensor_network_editor/app/_protocol.py b/src/tensor_network_editor/app/_protocol.py index d623124..8cb4ef6 100644 --- a/src/tensor_network_editor/app/_protocol.py +++ b/src/tensor_network_editor/app/_protocol.py @@ -412,6 +412,16 @@ def bad_request_response(message: str) -> JsonResponse: return HTTPStatus.BAD_REQUEST, {"ok": False, "message": message} +def forbidden_response(message: str) -> JsonResponse: + """Return a standard forbidden JSON response.""" + return HTTPStatus.FORBIDDEN, {"ok": False, "message": message} + + +def unsupported_media_type_response(message: str) -> JsonResponse: + """Return a standard unsupported-media-type JSON response.""" + return HTTPStatus.UNSUPPORTED_MEDIA_TYPE, {"ok": False, "message": message} + + def not_found_response() -> JsonResponse: """Return a standard not-found JSON response.""" return HTTPStatus.NOT_FOUND, {"ok": False, "message": "Not found."} diff --git a/src/tensor_network_editor/app/_session_requests.py b/src/tensor_network_editor/app/_session_requests.py index 01491da..cd20cd0 100644 --- a/src/tensor_network_editor/app/_session_requests.py +++ b/src/tensor_network_editor/app/_session_requests.py @@ -17,6 +17,7 @@ EngineIdentifier, TensorCollectionFormat, ) +from ._limits import enforce_spec_api_limits if TYPE_CHECKING: from .session import EditorSession @@ -39,6 +40,7 @@ def generate_session_request( context={"engine": engine_name_to_text(engine)}, ): spec = deserialize_spec(serialized_spec) + enforce_spec_api_limits(spec) log_branch( LOGGER, "Deserialized preview spec", @@ -69,6 +71,7 @@ def complete_session_request( context={"engine": engine_name_to_text(engine)}, ): spec = deserialize_spec(serialized_spec) + enforce_spec_api_limits(spec) log_branch( LOGGER, "Deserialized completion spec", diff --git a/src/tensor_network_editor/app/_subnetwork_library_services.py b/src/tensor_network_editor/app/_subnetwork_library_services.py index ecc1716..09a3e08 100644 --- a/src/tensor_network_editor/app/_subnetwork_library_services.py +++ b/src/tensor_network_editor/app/_subnetwork_library_services.py @@ -14,6 +14,7 @@ ) from ..models import CanvasPosition, NetworkSpec from ._bootstrap_payloads import build_subnetwork_catalog_payload +from ._limits import enforce_spec_api_limits from ._protocol import JsonDict if TYPE_CHECKING: @@ -44,7 +45,9 @@ def save_serialized_subnetwork_to_library( LOGGER, "Reusable subnetwork save", context=context ) as success_context: spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) saved_spec = extract_subnetwork_spec(spec, tensor_ids=tensor_ids) + enforce_spec_api_limits(saved_spec) session.save_project_subnetwork( subnetwork_name, saved_spec, @@ -138,9 +141,11 @@ def prepare_saved_subnetwork_for_insertion( context=context, ) as success_context: spec = session.build_saved_subnetwork(subnetwork_name) + enforce_spec_api_limits(spec) prepared_spec = prepare_subnetwork_for_insertion( spec, target_center=target_center, ) + enforce_spec_api_limits(prepared_spec) success_context.update(summarize_spec_counts(prepared_spec)) return prepared_spec diff --git a/src/tensor_network_editor/app/_subnetwork_services.py b/src/tensor_network_editor/app/_subnetwork_services.py index 5893417..5007258 100644 --- a/src/tensor_network_editor/app/_subnetwork_services.py +++ b/src/tensor_network_editor/app/_subnetwork_services.py @@ -12,6 +12,7 @@ prepare_subnetwork_for_insertion, ) from ..models import CanvasPosition, NetworkSpec +from ._limits import enforce_spec_api_limits LOGGER = logging.getLogger(__name__) @@ -28,7 +29,9 @@ def extract_serialized_subnetwork( context={"tensor_id_count": len(tensor_ids)}, ): spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) extracted_spec = extract_subnetwork_spec(spec, tensor_ids=tensor_ids) + enforce_spec_api_limits(extracted_spec) log_branch( LOGGER, "Extracted transient reusable subnetwork", @@ -45,10 +48,12 @@ def prepare_serialized_subnetwork_for_insertion( """Deserialize one payload and prepare it for editor insertion.""" with log_operation(LOGGER, "Transient subnetwork insertion preparation"): spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) prepared_spec = prepare_subnetwork_for_insertion( spec, target_center=target_center, ) + enforce_spec_api_limits(prepared_spec) log_branch( LOGGER, "Prepared transient subnetwork for insertion", diff --git a/src/tensor_network_editor/app/_template_services.py b/src/tensor_network_editor/app/_template_services.py index 924d27c..56b5187 100644 --- a/src/tensor_network_editor/app/_template_services.py +++ b/src/tensor_network_editor/app/_template_services.py @@ -16,6 +16,7 @@ parse_template_parameters, ) from ._bootstrap_payloads import build_template_catalog_payload +from ._limits import enforce_spec_api_limits, enforce_template_api_limits from ._protocol import JsonDict if TYPE_CHECKING: @@ -36,6 +37,7 @@ def build_template_from_payload( if session.has_project_template(template_name): log_branch(LOGGER, "Loading template from the project catalog") spec = session.build_project_template(template_name) + enforce_spec_api_limits(spec) success_context.update(summarize_spec_counts(spec)) success_context["status"] = "project" return spec @@ -43,7 +45,9 @@ def build_template_from_payload( template_name, raw_parameters, ) + enforce_template_api_limits(template_name, parameters) spec = build_template_spec(template_name, parameters) + enforce_spec_api_limits(spec) success_context.update(summarize_spec_counts(spec)) success_context["status"] = "global" return spec @@ -68,7 +72,9 @@ def promote_serialized_subnetwork_to_template( LOGGER, "Template promotion", context=context ) as success_context: spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) promoted_spec = extract_subnetwork_spec(spec, tensor_ids=tensor_ids) + enforce_spec_api_limits(promoted_spec) promoted_spec.name = session.build_project_template_display_name(template_name) session.save_project_template( template_name, diff --git a/src/tensor_network_editor/app/routes.py b/src/tensor_network_editor/app/routes.py index 122534b..ece031c 100644 --- a/src/tensor_network_editor/app/routes.py +++ b/src/tensor_network_editor/app/routes.py @@ -7,7 +7,7 @@ from collections.abc import Callable from dataclasses import dataclass from http import HTTPStatus -from typing import Literal, cast +from typing import Literal, TypedDict, cast from .._themes import DEFAULT_EDITOR_THEME, EditorThemeName, normalize_editor_theme from ..errors import ( @@ -41,6 +41,7 @@ from ..types import JSONValue from ..validation import validate_spec from ._drafts import clear_project_draft, load_project_draft, save_project_draft +from ._limits import enforce_spec_api_limits from ._protocol import ( JsonDict, JsonResponse, @@ -95,6 +96,19 @@ _RenderFormat = Literal["tikz", "dot", "mermaid", "svg", "png", "pdf"] +class _ImageExportThemeOverride(TypedDict, total=False): + """Theme override fields supported by image render options.""" + + background: str + edge_stroke: str + group_stroke: str + hyperedge_stroke: str + index_fill: str + muted_text_fill: str + note_fill: str + text_fill: str + + @dataclass(slots=True, frozen=True) class _FrontendClientLogEvent: """Validated frontend log event ready for persistence.""" @@ -113,7 +127,7 @@ class _RenderLabelOptions: show_edge_labels: bool -_IMAGE_EXPORT_THEME_OVERRIDES: dict[EditorThemeName, dict[str, str]] = { +_IMAGE_EXPORT_THEME_OVERRIDES: dict[EditorThemeName, _ImageExportThemeOverride] = { "dark": { "background": "#0b0d12", "edge_stroke": "#7e8aa3", @@ -285,6 +299,15 @@ def handle_validate(session: EditorSession, payload: JsonDict) -> JsonResponse: level=logging.WARNING, ) return bad_request_response("Missing 'spec' or 'python_code' payload.") + try: + enforce_spec_api_limits(spec) + except ValueError as exc: + log_branch( + LOGGER, + f"Validation request exceeded API limits: {exc}", + level=logging.WARNING, + ) + return bad_request_response(str(exc)) issues = validate_spec(spec) if issues: log_branch( @@ -396,6 +419,7 @@ def handle_render(session: EditorSession, payload: JsonDict) -> JsonResponse: render_format = _resolve_render_format(payload) serialized_spec = require_serialized_spec(payload) spec = deserialize_spec(serialized_spec, validate=False) + enforce_spec_api_limits(spec) label_options = _resolve_render_label_options(payload) render_theme = _resolve_render_theme(payload) success_context["format"] = render_format @@ -498,6 +522,13 @@ def handle_analyze_contraction( level=logging.WARNING, ) return bad_request_response(str(exc)) + except ValueError as exc: + log_branch( + LOGGER, + f"Contraction analysis request exceeded API limits: {exc}", + level=logging.WARNING, + ) + return bad_request_response(str(exc)) except SpecValidationError as exc: log_branch( LOGGER, @@ -968,6 +999,8 @@ def _handle_session_codegen_request( raise ValueError(f"Unsupported code generation operation '{operation}'.") except SerializationError as exc: return bad_request_response(str(exc)) + except ValueError as exc: + return bad_request_response(str(exc)) except CodeGenerationError as exc: return bad_request_response(str(exc)) except PackageIOError as exc: diff --git a/src/tensor_network_editor/app/server.py b/src/tensor_network_editor/app/server.py index 7ef44cc..0023288 100644 --- a/src/tensor_network_editor/app/server.py +++ b/src/tensor_network_editor/app/server.py @@ -2,9 +2,12 @@ from __future__ import annotations +import hmac +import ipaddress import json import logging import mimetypes +import secrets import threading import time from collections.abc import Callable @@ -29,9 +32,11 @@ JsonDict, JsonResponse, bad_request_response, + forbidden_response, internal_server_error_response, not_found_response, read_json, + unsupported_media_type_response, ) from .session import EditorSession @@ -54,6 +59,8 @@ _QUIET_MISSING_STATIC_ASSET_PATHS: frozenset[str] = frozenset({"/favicon.ico"}) _ScannedStaticAssetFile: TypeAlias = tuple[Path, str, int, int] _RUNTIME_CONFIG_PLACEHOLDER = "__TNE_RUNTIME_CONFIG__" +_API_TOKEN_HEADER = "X-TNE-Session-Token" +_EXPECTED_JSON_CONTENT_TYPE = "application/json" class SupportsReadBytes(Protocol): @@ -94,6 +101,75 @@ def _read_request_body_bytes(reader: SupportsReadBytes, content_length: int) -> return b"".join(chunks) +def _is_loopback_host_name(host_name: str) -> bool: + """Return whether a hostname literal is safe for local-only editor serving.""" + normalized_host = host_name.strip().strip("[]").rstrip(".").lower() + if normalized_host in {"localhost"} or normalized_host.endswith(".localhost"): + return True + if "%" in normalized_host: + normalized_host = normalized_host.split("%", 1)[0] + try: + address = ipaddress.ip_address(normalized_host) + except ValueError: + return False + return address.is_loopback + + +def _validate_bind_host(host: str, *, allow_remote: bool) -> None: + """Reject non-loopback bind hosts unless remote serving is explicit.""" + if allow_remote or _is_loopback_host_name(host): + return + raise ValueError( + "Refusing to bind the editor server to a non-loopback host. " + "Use allow_remote=True only when you intentionally expose this local API." + ) + + +def _host_name_from_header(host_header: str | None) -> str | None: + """Extract the hostname portion from one HTTP Host header.""" + if host_header is None: + return None + value = host_header.strip() + if not value: + return None + if value.startswith("["): + end_index = value.find("]") + if end_index <= 1: + return None + return value[1:end_index] + if value.count(":") == 1: + host_name, port_text = value.rsplit(":", 1) + if port_text.isdigit(): + return host_name + return value + + +def _is_trusted_host_header(host_header: str | None, *, allow_remote: bool) -> bool: + """Return whether one Host header is acceptable for this server.""" + if allow_remote: + return bool(host_header and host_header.strip()) + host_name = _host_name_from_header(host_header) + return host_name is not None and _is_loopback_host_name(host_name) + + +def _is_trusted_origin_header(origin_header: str | None, *, allow_remote: bool) -> bool: + """Return whether one optional Origin header is acceptable for API writes.""" + if origin_header is None: + return True + parsed_origin = urlparse(origin_header) + if parsed_origin.scheme not in {"http", "https"}: + return False + return _is_trusted_host_header(parsed_origin.netloc, allow_remote=allow_remote) + + +def _is_json_content_type(content_type: str | None) -> bool: + """Return whether one Content-Type header identifies a JSON request body.""" + if content_type is None: + return False + media_type = content_type.split(";", 1)[0].strip().lower() + return media_type == _EXPECTED_JSON_CONTENT_TYPE + + @dataclass(slots=True, frozen=True) class _BinaryResponse: """Internal response container for pre-encoded bytes.""" @@ -284,26 +360,35 @@ def _get_static_asset_cache(static_dir: Path) -> _StaticAssetCache: return cache -def _build_frontend_runtime_config_payload(session: EditorSession) -> JsonDict: +def _build_frontend_runtime_config_payload( + session: EditorSession, *, api_token: str +) -> JsonDict: """Return the runtime configuration embedded into the editor HTML page.""" return { "session_id": session.session_id, + "api_token": api_token, "frontend_logging": build_frontend_logging_payload(session), } -def _serialize_frontend_runtime_config(session: EditorSession) -> str: +def _serialize_frontend_runtime_config( + session: EditorSession, *, api_token: str +) -> str: """Serialize one session runtime config safely for an inline JSON script.""" - return json.dumps(_build_frontend_runtime_config_payload(session)).replace( - " bytes: +def _render_session_index_body( + index_body: bytes, session: EditorSession, *, api_token: str +) -> bytes: """Return the per-session editor HTML body with embedded runtime config.""" return index_body.replace( _RUNTIME_CONFIG_PLACEHOLDER.encode("utf-8"), - _serialize_frontend_runtime_config(session).encode("utf-8"), + _serialize_frontend_runtime_config(session, api_token=api_token).encode( + "utf-8" + ), ) @@ -325,7 +410,13 @@ class EditorServer: """Serve the browser app and JSON API for one editor session.""" def __init__( - self, session: EditorSession, host: str = "127.0.0.1", port: int = 0 + self, + session: EditorSession, + host: str = "127.0.0.1", + port: int = 0, + *, + allow_remote: bool = False, + api_token: str | None = None, ) -> None: """Initialize the threaded local editor server. @@ -333,16 +424,24 @@ def __init__( session: Shared editor session state served by this HTTP server. host: Local host interface to bind. port: Local port to bind. Use ``0`` for an ephemeral port. + allow_remote: Whether non-loopback bind hosts are allowed. + api_token: Optional pre-generated API token for tests. """ + _validate_bind_host(host, allow_remote=allow_remote) self.session = session self.session_id = session.session_id self.host = host self.port = port + self.allow_remote = allow_remote + self.api_token = api_token or secrets.token_urlsafe(32) + if not self.api_token.strip(): + raise ValueError("Editor API token cannot be empty.") self._static_dir = Path(__file__).resolve().parent / "static" self._static_asset_cache = _get_static_asset_cache(self._static_dir) self._index_body = _render_session_index_body( self._static_asset_cache.index_body, session, + api_token=self.api_token, ) self._server = ThreadingHTTPServer((host, port), self._build_handler()) self._thread = threading.Thread(target=self._serve_forever, daemon=True) @@ -446,6 +545,8 @@ def _build_handler(self) -> type[BaseHTTPRequestHandler]: static_dir = self._static_dir static_asset_cache = self._static_asset_cache index_body = self._index_body + api_token = self.api_token + allow_remote = self.allow_remote def build_index_response() -> _BinaryResponse: """Return the cached main HTML page for this editor session.""" @@ -536,6 +637,10 @@ def do_GET(self) -> None: """Handle one HTTP GET request for assets or bootstrap data.""" parsed = urlparse(self.path) with bind_log_context(session=session_id, route=parsed.path): + if self._reject_untrusted_host(): + return + if self._reject_invalid_api_token(parsed.path): + return try: with log_operation(LOGGER, "Route request"): response = self._dispatch_get(parsed.path) @@ -552,6 +657,14 @@ def do_POST(self) -> None: """Handle one HTTP POST request for the editor JSON API.""" parsed = urlparse(self.path) with bind_log_context(session=session_id, route=parsed.path): + if self._reject_untrusted_host(): + return + if self._reject_untrusted_origin(): + return + if self._reject_invalid_api_token(parsed.path): + return + if self._reject_unsupported_content_type(): + return try: with log_operation(LOGGER, "Route request"): try: @@ -607,6 +720,85 @@ def _dispatch_post(self, path: str, payload: JsonDict) -> JsonResponse: LOGGER.debug(format_log_message(f"Unknown POST path: {path}")) return not_found_response() + def _reject_untrusted_host(self) -> bool: + """Write a forbidden response when the Host header is not local.""" + if _is_trusted_host_header( + self.headers.get("Host"), + allow_remote=allow_remote, + ): + return False + LOGGER.warning( + format_log_message( + "Rejected request with untrusted Host header", + context={"host": self.headers.get("Host")}, + ), + ) + self._prepare_rejected_request_connection() + self._write_response(forbidden_response("Untrusted Host header.")) + return True + + def _reject_untrusted_origin(self) -> bool: + """Write a forbidden response when the Origin header is not local.""" + if _is_trusted_origin_header( + self.headers.get("Origin"), + allow_remote=allow_remote, + ): + return False + LOGGER.warning( + format_log_message( + "Rejected request with untrusted Origin header", + context={"origin": self.headers.get("Origin")}, + ), + ) + self._prepare_rejected_request_connection() + self._write_response(forbidden_response("Untrusted Origin header.")) + return True + + def _reject_invalid_api_token(self, path: str) -> bool: + """Write a forbidden response when an API request lacks the token.""" + if not path.startswith("/api/"): + return False + header_value = self.headers.get(_API_TOKEN_HEADER) + if header_value is not None and hmac.compare_digest( + header_value, + api_token, + ): + return False + LOGGER.warning( + format_log_message( + "Rejected API request with invalid session token" + ), + ) + self._prepare_rejected_request_connection() + self._write_response( + forbidden_response("Invalid editor session token.") + ) + return True + + def _reject_unsupported_content_type(self) -> bool: + """Write an unsupported-media response for non-JSON API writes.""" + if _is_json_content_type(self.headers.get("Content-Type")): + return False + LOGGER.warning( + format_log_message( + "Rejected API request with unsupported Content-Type", + context={"content_type": self.headers.get("Content-Type")}, + ), + ) + self._prepare_rejected_request_connection() + self._write_response( + unsupported_media_type_response( + "Expected Content-Type 'application/json'." + ) + ) + return True + + def _prepare_rejected_request_connection(self) -> None: + """Drain rejected POST bodies before closing the connection.""" + if self.command == "POST": + self._drain_pending_request_body() + self.close_connection = True + def _static_response( self, request_path: str ) -> JsonResponse | _BinaryResponse: @@ -688,6 +880,9 @@ def _write_bytes(self, status: int, body: bytes, content_type: str) -> None: self.send_response(status) self.send_header("Content-Type", content_type) self.send_header("Content-Length", str(len(body))) + self.send_header("X-Content-Type-Options", "nosniff") + self.send_header("Referrer-Policy", "no-referrer") + self.send_header("X-Frame-Options", "DENY") if self.close_connection: self.send_header("Connection", "close") self._write_no_cache_headers() diff --git a/src/tensor_network_editor/app/session.py b/src/tensor_network_editor/app/session.py index 451a1a1..8159ba3 100644 --- a/src/tensor_network_editor/app/session.py +++ b/src/tensor_network_editor/app/session.py @@ -572,6 +572,7 @@ def launch_editor_session( ui_mode: SessionUiMode | None = None, open_browser: bool = True, host: str = "127.0.0.1", + allow_remote: bool = False, port: int = 0, print_code: bool = False, code_path: StrPath | None = None, @@ -595,6 +596,7 @@ def launch_editor_session( ui_mode: Explicit UI launch mode for the editor session. open_browser: Whether to ask the system browser to open the local URL. host: Local host interface to bind. + allow_remote: Whether non-loopback bind hosts are allowed. port: Local port to bind. Use ``0`` for an ephemeral port. print_code: Whether to print generated code after confirmation. code_path: Optional output path for generated code after confirmation. @@ -649,7 +651,12 @@ def launch_editor_session( shared_subnetwork_catalog_path=shared_subnetwork_catalog_path, draft_path=draft_path, ) - server = EditorServer(session=session, host=host, port=port) + server = EditorServer( + session=session, + host=host, + port=port, + allow_remote=allow_remote, + ) effective_ui_mode = resolve_editor_ui_mode( ui_mode=ui_mode, open_browser=open_browser, diff --git a/src/tensor_network_editor/app/static/js/core/editorContext.js b/src/tensor_network_editor/app/static/js/core/editorContext.js index 0ac4784..596d3a1 100644 --- a/src/tensor_network_editor/app/static/js/core/editorContext.js +++ b/src/tensor_network_editor/app/static/js/core/editorContext.js @@ -8,6 +8,16 @@ import { createInitialState } from "../state/state.js"; import { createEditorSelectors } from "../state/editorSelectors.js"; import { createEditorStore } from "../state/editorStore.js"; +function resolveRuntimeApiToken(runtimeConfig) { + const rawToken = + runtimeConfig && typeof runtimeConfig.apiToken === "string" + ? runtimeConfig.apiToken + : runtimeConfig && typeof runtimeConfig.api_token === "string" + ? runtimeConfig.api_token + : null; + return typeof rawToken === "string" && rawToken.trim() ? rawToken.trim() : null; +} + export function createEditorContext({ window, document, @@ -17,13 +27,16 @@ export function createEditorContext({ }) { const state = createInitialState(); const store = createEditorStore(state); + const apiToken = resolveRuntimeApiToken(runtimeConfig); const requestApiGet = (path, options = {}) => apiGet(path, { + apiToken, logger, ...options, }); const requestApiPost = (path, payload, options = {}) => apiPost(path, payload, { + apiToken, logger, ...options, }); diff --git a/src/tensor_network_editor/app/static/js/core/frontendLogger.js b/src/tensor_network_editor/app/static/js/core/frontendLogger.js index 0d276cc..7837afc 100644 --- a/src/tensor_network_editor/app/static/js/core/frontendLogger.js +++ b/src/tensor_network_editor/app/static/js/core/frontendLogger.js @@ -85,6 +85,16 @@ function normalizeRuntimeConfig(runtimeConfig = {}) { typeof rawSessionId === "string" && rawSessionId.trim() ? rawSessionId.trim() : null; + const rawApiToken = + typeof candidate.apiToken === "string" + ? candidate.apiToken + : typeof candidate.api_token === "string" + ? candidate.api_token + : null; + const apiToken = + typeof rawApiToken === "string" && rawApiToken.trim() + ? rawApiToken.trim() + : null; const rawEnabled = frontendLogging.enabled === true; const level = normalizeLevel(frontendLogging.level, rawEnabled); const enabled = rawEnabled || level !== "off"; @@ -100,6 +110,7 @@ function normalizeRuntimeConfig(runtimeConfig = {}) { : null; const persist = frontendLogging.persist === true && transportEndpoint !== null; return { + apiToken, enabled, level, persist, @@ -269,6 +280,7 @@ export function createFrontendLogger( const body = JSON.stringify({ events }); if ( preferBeacon + && !resolvedRuntimeConfig.apiToken && navigatorRef && typeof navigatorRef.sendBeacon === "function" && navigatorRef.sendBeacon(resolvedRuntimeConfig.transportEndpoint, body) @@ -278,10 +290,14 @@ export function createFrontendLogger( if (typeof fetchRef !== "function") { return false; } + const headers = { "Content-Type": "application/json" }; + if (resolvedRuntimeConfig.apiToken) { + headers["X-TNE-Session-Token"] = resolvedRuntimeConfig.apiToken; + } try { await fetchRef(resolvedRuntimeConfig.transportEndpoint, { method: "POST", - headers: { "Content-Type": "application/json" }, + headers, body, keepalive: true, }); @@ -443,6 +459,12 @@ export function createFrontendLogger( }, refreshRuntimeConfig(nextRuntimeConfig = {}) { resolvedRuntimeConfig = normalizeRuntimeConfig({ + apiToken: + typeof nextRuntimeConfig.apiToken === "string" + ? nextRuntimeConfig.apiToken + : typeof nextRuntimeConfig.api_token === "string" + ? nextRuntimeConfig.api_token + : resolvedRuntimeConfig.apiToken, sessionId: typeof nextRuntimeConfig.sessionId === "string" ? nextRuntimeConfig.sessionId diff --git a/src/tensor_network_editor/app/static/js/interactions/interactionsShortcuts.js b/src/tensor_network_editor/app/static/js/interactions/interactionsShortcuts.js index 7e76f2a..ab5ebe4 100644 --- a/src/tensor_network_editor/app/static/js/interactions/interactionsShortcuts.js +++ b/src/tensor_network_editor/app/static/js/interactions/interactionsShortcuts.js @@ -756,9 +756,17 @@ export function createInteractionShortcutBindings({ if (state.editorFinished || typeof fetch !== "function") { return; } + const headers = { "Content-Type": "application/json" }; + const apiToken = + ctx.runtimeConfig && typeof ctx.runtimeConfig.apiToken === "string" + ? ctx.runtimeConfig.apiToken.trim() + : ""; + if (apiToken) { + headers["X-TNE-Session-Token"] = apiToken; + } void fetch("/api/cancel", { method: "POST", - headers: { "Content-Type": "application/json" }, + headers, body: JSON.stringify({}), keepalive: true, }); diff --git a/src/tensor_network_editor/app/static/js/services/api.js b/src/tensor_network_editor/app/static/js/services/api.js index e5364d3..ef85282 100644 --- a/src/tensor_network_editor/app/static/js/services/api.js +++ b/src/tensor_network_editor/app/static/js/services/api.js @@ -29,6 +29,27 @@ function buildErrorMessage({ text, json }) { return typeof text === "string" && text.trim() ? text.trim() : "Request failed."; } +function resolveApiToken(options = {}) { + const rawToken = + typeof options.apiToken === "string" + ? options.apiToken + : typeof options.api_token === "string" + ? options.api_token + : typeof options.sessionToken === "string" + ? options.sessionToken + : null; + return typeof rawToken === "string" && rawToken.trim() ? rawToken.trim() : null; +} + +function buildRequestHeaders(options = {}, baseHeaders = {}) { + const headers = { ...baseHeaders }; + const apiToken = resolveApiToken(options); + if (apiToken) { + headers["X-TNE-Session-Token"] = apiToken; + } + return headers; +} + function requireJsonBody({ json }) { if (json === null) { throw new Error("Expected a JSON response."); @@ -160,7 +181,14 @@ async function performJsonRequest(method, path, init = {}, options = {}) { } export async function apiGet(path, options = {}) { - return performJsonRequest("GET", path, {}, options); + return performJsonRequest( + "GET", + path, + { + headers: buildRequestHeaders(options), + }, + options + ); } export async function apiPost(path, payload, options = {}) { @@ -169,7 +197,7 @@ export async function apiPost(path, payload, options = {}) { path, { method: "POST", - headers: { "Content-Type": "application/json" }, + headers: buildRequestHeaders(options, { "Content-Type": "application/json" }), body: JSON.stringify(payload), }, options diff --git a/src/tensor_network_editor/editor.py b/src/tensor_network_editor/editor.py index 381c716..8368d11 100644 --- a/src/tensor_network_editor/editor.py +++ b/src/tensor_network_editor/editor.py @@ -2,6 +2,7 @@ from __future__ import annotations +import ipaddress import logging from collections.abc import Callable from dataclasses import dataclass @@ -47,6 +48,7 @@ class EditorLaunchOptions: ui_mode: EditorUiMode | None = None open_browser: bool = True host: str = "127.0.0.1" + allow_remote: bool = False port: int = 0 print_code: bool = False code_path: StrPath | None = None @@ -68,6 +70,10 @@ def __post_init__(self) -> None: ui_mode=validated_ui_mode, open_browser=self.open_browser, ) + _validate_editor_bind_options( + host=self.host, + allow_remote=self.allow_remote, + ) validate_positive_log_setting( self.log_file_max_bytes, name="log_file_max_bytes", @@ -132,6 +138,7 @@ def open_editor( ui_mode=effective_ui_mode, open_browser=resolved_options.open_browser, host=resolved_options.host, + allow_remote=resolved_options.allow_remote, port=resolved_options.port, print_code=resolved_options.print_code, code_path=resolved_options.code_path, @@ -192,6 +199,30 @@ def _validate_editor_ui_mode_compatibility( raise ValueError("ui_mode='server' requires open_browser=False.") +def _is_loopback_host_name(host_name: str) -> bool: + """Return whether a host name is safe for local-only editor serving.""" + normalized_host = host_name.strip().strip("[]").rstrip(".").lower() + if normalized_host in {"localhost"} or normalized_host.endswith(".localhost"): + return True + if "%" in normalized_host: + normalized_host = normalized_host.split("%", 1)[0] + try: + address = ipaddress.ip_address(normalized_host) + except ValueError: + return False + return address.is_loopback + + +def _validate_editor_bind_options(*, host: str, allow_remote: bool) -> None: + """Reject non-loopback hosts unless the caller explicitly opts in.""" + if allow_remote or _is_loopback_host_name(host): + return + raise ValueError( + "Refusing to bind the editor server to a non-loopback host. " + "Use allow_remote=True only when you intentionally expose this local API." + ) + + def resolve_editor_ui_mode( *, ui_mode: EditorUiMode | None, diff --git a/tests/app_support.py b/tests/app_support.py index a43b480..b0e369a 100644 --- a/tests/app_support.py +++ b/tests/app_support.py @@ -1,14 +1,21 @@ from __future__ import annotations import json +import re import time from typing import Any, cast from urllib.error import HTTPError +from urllib.parse import urlsplit from urllib.request import Request, urlopen _ASSET_REQUEST_TIMEOUT_SECONDS = 15.0 _ASSET_REQUEST_RETRY_COUNT = 3 _ASSET_REQUEST_RETRY_DELAY_SECONDS = 0.1 +_RUNTIME_CONFIG_RE = re.compile( + r'', + re.DOTALL, +) +_SESSION_TOKEN_BY_ORIGIN: dict[str, str | None] = {} def request_json( @@ -33,6 +40,8 @@ def request_json_with_status( method: str = "GET", payload: dict[str, Any] | None = None, raw_body: bytes | None = None, + session_token: str | None = None, + include_session_token: bool = True, timeout: float = 5.0, ) -> tuple[int, dict[str, Any]]: data = None @@ -45,6 +54,12 @@ def request_json_with_status( elif raw_body is not None: data = raw_body headers["Content-Type"] = "application/json" + if include_session_token: + resolved_session_token = ( + session_token if session_token is not None else _session_token_for_url(url) + ) + if resolved_session_token: + headers["X-TNE-Session-Token"] = resolved_session_token request = Request(url=url, method=method, data=data, headers=headers) try: with urlopen(request, timeout=timeout) as response: @@ -53,6 +68,41 @@ def request_json_with_status( return exc.code, json.loads(exc.read().decode("utf-8")) +def _session_token_for_url(url: str) -> str | None: + """Read the embedded editor API token for a local test server URL.""" + origin = _origin_for_url(url) + if origin is None: + return None + if origin in _SESSION_TOKEN_BY_ORIGIN: + return _SESSION_TOKEN_BY_ORIGIN[origin] + try: + with urlopen(f"{origin}/", timeout=_ASSET_REQUEST_TIMEOUT_SECONDS) as response: + html = response.read().decode("utf-8") + except OSError: + _SESSION_TOKEN_BY_ORIGIN[origin] = None + return None + match = _RUNTIME_CONFIG_RE.search(html) + if match is None: + _SESSION_TOKEN_BY_ORIGIN[origin] = None + return None + try: + payload = json.loads(match.group(1)) + except json.JSONDecodeError: + _SESSION_TOKEN_BY_ORIGIN[origin] = None + return None + token = payload.get("api_token") if isinstance(payload, dict) else None + _SESSION_TOKEN_BY_ORIGIN[origin] = token if isinstance(token, str) else None + return _SESSION_TOKEN_BY_ORIGIN[origin] + + +def _origin_for_url(url: str) -> str | None: + """Return the scheme/authority origin for an absolute URL.""" + parsed = urlsplit(url) + if not parsed.scheme or not parsed.netloc: + return None + return f"{parsed.scheme}://{parsed.netloc}" + + def _read_asset_response(url: str) -> tuple[bytes, dict[str, str]]: """Read one asset request with retries for transient local-server hiccups.""" last_error: OSError | None = None diff --git a/tests/test_api.py b/tests/test_api.py index 2532c03..baec4f0 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -237,6 +237,7 @@ def test_editor_launch_options_defaults_match_public_contract() -> None: assert options.ui_mode is None assert options.open_browser is True assert options.host == "127.0.0.1" + assert options.allow_remote is False assert options.port == 0 assert options.print_code is False assert options.code_path is None @@ -250,6 +251,20 @@ def test_editor_launch_options_rejects_unknown_theme() -> None: EditorLaunchOptions(theme="sepia") # type: ignore[arg-type] +def test_editor_launch_options_rejects_non_loopback_host_without_remote_opt_in() -> ( + None +): + with pytest.raises(ValueError, match="non-loopback"): + EditorLaunchOptions(host="0.0.0.0") + + +def test_editor_launch_options_allows_non_loopback_host_with_remote_opt_in() -> None: + options = EditorLaunchOptions(host="0.0.0.0", allow_remote=True) + + assert options.host == "0.0.0.0" + assert options.allow_remote is True + + def test_editor_ui_mode_type_alias_matches_public_contract() -> None: assert EditorUiMode == Literal["browser", "pywebview", "server"] @@ -286,6 +301,7 @@ def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> N ui_mode="pywebview", open_browser=False, host="0.0.0.0", + allow_remote=True, port=8123, print_code=True, code_path="generated.py", @@ -307,6 +323,7 @@ def test_open_editor_passes_editor_launch_options(sample_spec: NetworkSpec) -> N ui_mode="pywebview", open_browser=False, host="0.0.0.0", + allow_remote=True, port=8123, print_code=True, code_path="generated.py", diff --git a/tests/test_app_routes.py b/tests/test_app_routes.py index 4d2929b..a944311 100644 --- a/tests/test_app_routes.py +++ b/tests/test_app_routes.py @@ -22,7 +22,14 @@ from tensor_network_editor.internal._logging import package_logging_scope from tensor_network_editor.io import SCHEMA_VERSION from tensor_network_editor.io import deserialize_spec as deserialize_spec_impl -from tensor_network_editor.models import EngineName, NetworkSpec, TensorCollectionFormat +from tensor_network_editor.models import ( + CanvasPosition, + EngineName, + IndexSpec, + NetworkSpec, + TensorCollectionFormat, + TensorSpec, +) from tests.app_support import request_json, request_json_with_status from tests.factories import ( build_linear_periodic_carry_chain_spec, @@ -35,6 +42,10 @@ ) from tests.optional_backends import require_light_optional_modules +MAX_EXPECTED_API_TEMPLATE_LINEAR_GRAPH_SIZE = 512 +MAX_EXPECTED_API_TENSOR_RANK = 64 +MAX_EXPECTED_API_TENSORS = 512 + def test_bootstrap_returns_session_contract( editor_server: EditorServer, @@ -745,6 +756,50 @@ def test_validate_route_rejects_invalid_json_with_400( assert payload == {"ok": False, "message": "Request body contains invalid JSON."} +def test_api_post_rejects_missing_session_token( + editor_server: EditorServer, +) -> None: + status, payload = request_json_with_status( + f"{editor_server.base_url}/api/cancel", + method="POST", + payload={}, + include_session_token=False, + ) + + assert status == 403 + assert payload == {"ok": False, "message": "Invalid editor session token."} + + +def test_api_post_rejects_wrong_session_token( + editor_server: EditorServer, +) -> None: + status, payload = request_json_with_status( + f"{editor_server.base_url}/api/cancel", + method="POST", + payload={}, + session_token="wrong-token", + ) + + assert status == 403 + assert payload == {"ok": False, "message": "Invalid editor session token."} + + +def test_api_post_rejects_non_json_content_type( + editor_server: EditorServer, +) -> None: + status, payload = _post_cancel_with_content_type( + editor_server, + content_type="text/plain", + body=b"{}", + ) + + assert status == 415 + assert payload == { + "ok": False, + "message": "Expected Content-Type 'application/json'.", + } + + def test_generate_route_logs_success_with_session_and_timing( editor_server: EditorServer, serialized_sample_spec: dict[str, object], @@ -829,6 +884,28 @@ def test_validate_route_rejects_non_object_json_payload_with_400( assert payload == {"ok": False, "message": "Expected a JSON object payload."} +def test_validate_route_rejects_specs_above_api_tensor_limit( + editor_server: EditorServer, +) -> None: + oversized_spec = _build_many_tensor_spec(MAX_EXPECTED_API_TENSORS + 1) + + status, payload = request_json_with_status( + f"{editor_server.base_url}/api/validate", + method="POST", + payload={ + "spec": { + "schema_version": SCHEMA_VERSION, + "network": oversized_spec.to_dict(), + } + }, + ) + + assert status == 400 + assert payload["ok"] is False + assert f"contains {MAX_EXPECTED_API_TENSORS + 1} tensors" in payload["message"] + assert f"API limit of {MAX_EXPECTED_API_TENSORS}" in payload["message"] + + @pytest.mark.parametrize("legacy_schema_version", [4, 5, 6]) def test_validate_route_rejects_legacy_schema_versions( editor_server: EditorServer, @@ -862,6 +939,10 @@ def _post_validate_with_raw_content_length( try: connection.putrequest("POST", "/api/validate") connection.putheader("Content-Type", "application/json") + connection.putheader( + "X-TNE-Session-Token", + getattr(editor_server, "api_token", "test-token"), + ) connection.putheader("Content-Length", content_length) connection.endheaders() if body: @@ -873,6 +954,125 @@ def _post_validate_with_raw_content_length( connection.close() +def _post_cancel_with_content_type( + editor_server: EditorServer, *, content_type: str, body: bytes +) -> tuple[int, dict[str, object]]: + parsed = urlparse(editor_server.base_url) + host = parsed.hostname + port = parsed.port + if host is None or port is None: + raise AssertionError(f"Unexpected editor base URL: {editor_server.base_url}") + connection = HTTPConnection(host, port, timeout=5) + try: + connection.request( + "POST", + "/api/cancel", + body=body, + headers={ + "Content-Type": content_type, + "X-TNE-Session-Token": getattr( + editor_server, + "api_token", + "test-token", + ), + }, + ) + response = connection.getresponse() + response_payload = json.loads(response.read().decode("utf-8")) + return response.status, cast(dict[str, object], response_payload) + finally: + connection.close() + + +def _build_many_tensor_spec(tensor_count: int) -> NetworkSpec: + return NetworkSpec( + id="network_many_tensors", + name="many tensors", + tensors=[ + TensorSpec( + id=f"tensor_{index}", + name=f"T{index}", + position=CanvasPosition(x=float(index), y=0.0), + indices=[ + IndexSpec( + id=f"tensor_{index}_open", + name="open", + dimension=2, + ) + ], + ) + for index in range(tensor_count) + ], + ) + + +def test_generate_route_rejects_tensor_rank_above_api_limit( + editor_server: EditorServer, +) -> None: + rank = MAX_EXPECTED_API_TENSOR_RANK + 1 + oversized_spec = NetworkSpec( + id="network_high_rank", + name="high rank", + tensors=[ + TensorSpec( + id="tensor_high_rank", + name="High rank", + position=CanvasPosition(x=0.0, y=0.0), + indices=[ + IndexSpec(id=f"idx_{index}", name=f"i{index}", dimension=2) + for index in range(rank) + ], + ) + ], + ) + + status, payload = request_json_with_status( + f"{editor_server.base_url}/api/generate", + method="POST", + payload={ + "engine": EngineName.EINSUM_NUMPY.value, + "spec": { + "schema_version": SCHEMA_VERSION, + "network": oversized_spec.to_dict(), + }, + }, + ) + + assert status == 400 + assert payload["ok"] is False + assert f"rank {rank}" in payload["message"] + assert f"API limit of {MAX_EXPECTED_API_TENSOR_RANK}" in payload["message"] + + +def test_template_route_rejects_excessive_graph_size_before_building( + editor_server: EditorServer, +) -> None: + with patch( + "tensor_network_editor.app._template_services.build_template_spec", + ) as build_template_mock: + status, payload = request_json_with_status( + f"{editor_server.base_url}/api/template", + method="POST", + payload={ + "template": "mps", + "parameters": { + "graph_size": MAX_EXPECTED_API_TEMPLATE_LINEAR_GRAPH_SIZE + 1, + "bond_dimension": 2, + "physical_dimension": 2, + }, + }, + ) + + assert status == 400 + assert payload["ok"] is False + assert "graph_size" in payload["message"] + assert ( + f"API limit of {MAX_EXPECTED_API_TEMPLATE_LINEAR_GRAPH_SIZE}" + in payload["message"] + ) + build_template_mock.assert_not_called() + + def test_generate_route_uses_default_engine_when_missing( editor_server: EditorServer, serialized_sample_spec: dict[str, object], diff --git a/tests/test_app_server.py b/tests/test_app_server.py index d0f6fc5..7ffd4ec 100644 --- a/tests/test_app_server.py +++ b/tests/test_app_server.py @@ -1,12 +1,15 @@ from __future__ import annotations +import json import logging import os import subprocess import sys from http import HTTPStatus +from http.client import HTTPConnection from pathlib import Path from typing import Protocol, cast +from urllib.parse import urlparse from urllib.request import urlopen import pytest @@ -190,6 +193,52 @@ def test_editor_server_start_makes_shell_and_vendor_assets_immediately_readable_ assert vendor_headers["Content-Type"].startswith("application/javascript") +def test_editor_server_rejects_untrusted_host_header() -> None: + server = EditorServer(EditorSession(initial_spec=build_sample_spec())) + server.start() + try: + parsed = urlparse(server.base_url) + if parsed.hostname is None or parsed.port is None: + raise AssertionError(f"Unexpected editor base URL: {server.base_url}") + connection = HTTPConnection(parsed.hostname, parsed.port, timeout=5) + try: + connection.putrequest("GET", "/", skip_host=True) + connection.putheader("Host", "attacker.example") + connection.endheaders() + response = connection.getresponse() + payload = json.loads(response.read().decode("utf-8")) + finally: + connection.close() + finally: + server.stop() + + assert response.status == HTTPStatus.FORBIDDEN + assert payload == {"ok": False, "message": "Untrusted Host header."} + + +def test_editor_server_rejects_non_loopback_bind_without_remote_opt_in() -> None: + def build_remote_server() -> EditorServer: + server = EditorServer( + EditorSession(initial_spec=build_sample_spec()), + host="0.0.0.0", + ) + server._server.server_close() + return server + + with pytest.raises(ValueError, match="non-loopback"): + build_remote_server() + + +def test_editor_server_allows_remote_bind_with_explicit_opt_in() -> None: + server = EditorServer( + EditorSession(initial_spec=build_sample_spec()), + host="0.0.0.0", + allow_remote=True, + ) + + server._server.server_close() + + def test_editor_index_response_embeds_session_runtime_config() -> None: first_server = EditorServer(EditorSession(initial_spec=build_sample_spec())) second_server = EditorServer(EditorSession(initial_spec=build_sample_spec())) @@ -214,8 +263,12 @@ def test_editor_index_response_embeds_session_runtime_config() -> None: assert 'id="tne-runtime-config"' in first_body assert first_server.session_id in first_body assert second_server.session_id in second_body + assert first_server.api_token in first_body + assert second_server.api_token in second_body assert first_server.session_id != second_server.session_id + assert first_server.api_token != second_server.api_token assert first_body != second_body + assert '"api_token":' in first_body assert '"frontend_logging"' in first_body assert '"enabled": false' in first_body assert '"persist": false' in first_body diff --git a/tests/test_frontend_runtime.py b/tests/test_frontend_runtime.py index 096344e..edb878c 100644 --- a/tests/test_frontend_runtime.py +++ b/tests/test_frontend_runtime.py @@ -986,6 +986,198 @@ def test_api_service_logs_request_lifecycle_with_frontend_logger( ) +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_api_service_sends_session_token_header( + tmp_path: Path, +) -> None: + script_path = tmp_path / "api_session_token_header.mjs" + script_path.write_text( + textwrap.dedent( + f""" + import {{ pathToFileURL }} from "node:url"; + + const apiUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "services" / "api.js")!r}).href; + const apiModule = await import(apiUrl); + const calls = []; + + function headerValue(headers, name) {{ + if (headers && typeof headers.get === "function") {{ + return headers.get(name); + }} + return headers?.[name] || headers?.[name.toLowerCase()] || null; + }} + + globalThis.fetch = async (path, options = {{}}) => {{ + calls.push({{ path, options }}); + return new Response(JSON.stringify({{ ok: true }}), {{ + status: 200, + headers: {{ "Content-Type": "application/json" }}, + }}); + }}; + + await apiModule.apiGet("/api/bootstrap", {{ + apiToken: "secret-token", + }}); + await apiModule.apiPost("/api/cancel", {{}}, {{ + apiToken: "secret-token", + }}); + + if (calls.length !== 2) {{ + throw new Error(`Expected two calls, received ${{calls.length}}.`); + }} + for (const call of calls) {{ + const token = headerValue(call.options.headers, "X-TNE-Session-Token"); + if (token !== "secret-token") {{ + throw new Error(`Missing session token header: ${{JSON.stringify(call)}}`); + }} + }} + const contentType = headerValue(calls[1].options.headers, "Content-Type"); + if (contentType !== "application/json") {{ + throw new Error(`Missing JSON content type: ${{JSON.stringify(calls[1])}}`); + }} + """ + ), + encoding="utf-8", + ) + + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The api session token header script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_editor_context_passes_runtime_api_token_to_requests( + tmp_path: Path, +) -> None: + script_path = tmp_path / "editor_context_api_token.mjs" + script_path.write_text( + textwrap.dedent( + f""" + import {{ pathToFileURL }} from "node:url"; + + const contextUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "core" / "editorContext.js")!r}).href; + const contextModule = await import(contextUrl); + const calls = []; + const documentRef = {{ + getElementById() {{ + return null; + }}, + querySelector() {{ + return null; + }}, + }}; + + function headerValue(headers, name) {{ + if (headers && typeof headers.get === "function") {{ + return headers.get(name); + }} + return headers?.[name] || headers?.[name.toLowerCase()] || null; + }} + + globalThis.fetch = async (path, options = {{}}) => {{ + calls.push({{ path, options }}); + return new Response(JSON.stringify({{ ok: true }}), {{ + status: 200, + headers: {{ "Content-Type": "application/json" }}, + }}); + }}; + + const ctx = contextModule.createEditorContext({{ + window: {{}}, + document: documentRef, + cytoscape: null, + runtimeConfig: {{ apiToken: "runtime-secret" }}, + }}); + await ctx.apiPost("/api/cancel", {{}}); + + const token = headerValue(calls[0]?.options?.headers, "X-TNE-Session-Token"); + if (token !== "runtime-secret") {{ + throw new Error(`Missing context token header: ${{JSON.stringify(calls)}}`); + }} + """ + ), + encoding="utf-8", + ) + + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The editor context API token script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + +@pytest.mark.skipif(shutil.which("node") is None, reason="node is required") +def test_runtime_config_reader_normalizes_api_token( + tmp_path: Path, +) -> None: + script_path = tmp_path / "runtime_config_api_token.mjs" + script_path.write_text( + textwrap.dedent( + f""" + import {{ pathToFileURL }} from "node:url"; + + const loggerUrl = pathToFileURL({str(REPO_ROOT / "src" / "tensor_network_editor" / "app" / "static" / "js" / "core" / "frontendLogger.js")!r}).href; + const loggerModule = await import(loggerUrl); + const documentRef = {{ + getElementById(id) {{ + if (id !== "tne-runtime-config") {{ + return null; + }} + return {{ + textContent: JSON.stringify({{ + session_id: "session-1", + api_token: "embedded-token", + frontend_logging: {{ enabled: false }}, + }}), + }}; + }}, + }}; + + const config = loggerModule.readFrontendRuntimeConfig({{ documentRef }}); + if (config.sessionId !== "session-1") {{ + throw new Error(`Unexpected session id: ${{JSON.stringify(config)}}`); + }} + if (config.apiToken !== "embedded-token") {{ + throw new Error(`Unexpected API token: ${{JSON.stringify(config)}}`); + }} + """ + ), + encoding="utf-8", + ) + + completed_process = subprocess.run( + ["node", str(script_path)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert completed_process.returncode == 0, ( + "The runtime config API token script failed.\n" + f"STDOUT:\n{completed_process.stdout}\n" + f"STDERR:\n{completed_process.stderr}" + ) + + @pytest.mark.skipif(shutil.which("node") is None, reason="node is required") def test_frontend_logger_persists_batched_logs_without_api_recursion( tmp_path: Path, diff --git a/tests/test_packaging.py b/tests/test_packaging.py index aaf610f..286509f 100644 --- a/tests/test_packaging.py +++ b/tests/test_packaging.py @@ -103,6 +103,19 @@ def test_project_metadata_declares_required_matplotlib_dependency_and_backend_ex assert "png" not in optional_dependencies +def test_project_metadata_and_ci_enable_dependency_audits() -> None: + pyproject_path = Path.cwd() / "pyproject.toml" + ci_path = Path.cwd() / ".github" / "workflows" / "ci.yml" + + payload = tomllib.loads(pyproject_path.read_text(encoding="utf-8")) + dev_dependencies = payload["project"]["optional-dependencies"]["dev"] + ci_text = ci_path.read_text(encoding="utf-8") + + assert "pip-audit>=2.7" in dev_dependencies + assert "Run dependency security audit" in ci_text + assert "-m pip_audit" in ci_text + + def test_docs_do_not_advertise_removed_png_extra() -> None: readme_text = (Path.cwd() / "README.md").read_text(encoding="utf-8") installation_text = (Path.cwd() / "docs" / "installation.md").read_text( @@ -136,6 +149,9 @@ def test_third_party_notices_describe_bundled_asset_scope() -> None: assert "Runtime pip-installed dependencies are not bundled" in third_party_text assert "Package: Matplotlib" in third_party_text assert "License: Matplotlib license" in third_party_text + assert "Development dependency notice" in third_party_text + assert "Package: pip-audit" in third_party_text + assert "License: Apache Software License" in third_party_text assert "THIRD_PARTY_LICENSES" in readme_text From 2134f10c09fb16be0a6964a68476cb373aa43a99 Mon Sep 17 00:00:00 2001 From: Alejandro Mata Ali Date: Thu, 14 May 2026 12:06:53 +0200 Subject: [PATCH 20/23] Harden editor security posture --- .github/dependabot.yml | 11 +++ .github/workflows/ci.yml | 4 + CHANGELOG.md | 11 +++ README.md | 4 + SECURITY.md | 76 +++++++++++++++++++ THIRD_PARTY_LICENSES | 2 +- docs/api.md | 3 + docs/cli.md | 4 + docs/extended_guide.md | 4 + src/tensor_network_editor/app/server.py | 45 ++++++++++- .../app/static/index.html | 4 +- .../static/js/session/sessionEditorFlows.js | 2 +- .../app/static/vendor/prism-core.min.js | 2 +- .../internal/diffing/_diffing.py | 10 ++- .../internal/io/_python_live_import.py | 2 +- .../internal/io/_python_live_import_runner.py | 2 +- src/tensor_network_editor/rendering.py | 3 +- tests/app_support.py | 2 +- tests/test_app_assets.py | 30 ++++++++ tests/test_frontend_runtime.py | 6 ++ tests/test_packaging.py | 60 +++++++++++++++ 21 files changed, 271 insertions(+), 16 deletions(-) create mode 100644 .github/dependabot.yml create mode 100644 SECURITY.md diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..4776211 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8c8426f..660e901 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -59,6 +59,10 @@ jobs: run: | & $env:VENV_PYTHON -m ruff check . + - name: Run source security lint + run: | + & $env:VENV_PYTHON -m ruff check src --select S + - name: Check Ruff formatting run: | & $env:VENV_PYTHON -m ruff format --check . diff --git a/CHANGELOG.md b/CHANGELOG.md index dd74b09..81d217c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,17 @@ All notable changes to this project will be documented in this file. contraction analysis, code generation, or subnetwork operations. - CI now runs a dependency vulnerability audit with `pip-audit` as part of the development dependency set. +- Bundled PrismJS assets were updated to 1.30.0, and the editor server now + emits a nonce-based Content Security Policy plus additional browser defense + headers. +- Live Python import prompts and docs now state that live import should only be + used with trusted local Python files, because it executes code in a + subprocess with the active Python environment. +- CI now runs Ruff's Bandit security rules against `src`, and Dependabot tracks + Python and GitHub Actions dependency updates. +- Added `SECURITY.md` with private reporting guidance, a maintainer disclosure + checklist, and a PrismJS advisory draft for releases that bundled PrismJS + 1.29.0. ### Changed diff --git a/README.md b/README.md index 1d7de57..fa882f9 100644 --- a/README.md +++ b/README.md @@ -382,6 +382,9 @@ whole load immediately. scripts or imports already resolvable from the active `.venv`. If a Python file depends on sibling modules or path-sensitive imports, prefer the Python API or CLI with the real file path. +- Only use live import with local Python files you trust. Live import executes + the file in a subprocess with the active Python environment, so trusted code + can still read or write local files. - Tensor values in the visual editor support portable built-in initializers, dtype choices, JSON-friendly complex scalars, and external `.npy`, `.npz`, and `.pt` data references. Symbolic expressions are not supported yet. @@ -399,6 +402,7 @@ whole load immediately. - Source code: [github.com/DOKOS-TAYOS/Tensor-Network-Editor](https://github.com/DOKOS-TAYOS/Tensor-Network-Editor) - Changelog: [CHANGELOG.md](CHANGELOG.md) +- Security policy: [SECURITY.md](SECURITY.md) - Example script: [examples/basic_usage.py](examples/basic_usage.py) - Issue tracker: [github.com/DOKOS-TAYOS/Tensor-Network-Editor/issues](https://github.com/DOKOS-TAYOS/Tensor-Network-Editor/issues) - License: [LICENSE](LICENSE) diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..b28987b --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,76 @@ +# Security Policy + +## Reporting a Vulnerability + +Please use GitHub private vulnerability reporting for security issues in this +repository when it is available: + +https://github.com/DOKOS-TAYOS/Tensor-Network-Editor/security/advisories/new + +Do not open a public issue with exploit details, proof-of-concept payloads, or +private environment information. If private reporting is unavailable, open a +public issue asking for a preferred security contact without including the +technical details. + +Useful reports include: + +- affected version or commit +- operating system and Python version +- whether the browser editor, CLI, or Python API is involved +- concise reproduction steps +- expected impact, if known + +## Maintainer Disclosure Checklist + +For a confirmed issue: + +1. Prepare the fix privately or in a normal pull request when the details are + already public. +2. Publish the patched release before publishing the advisory, unless users + need immediate mitigation guidance. +3. Create or update a GitHub Security Advisory with affected versions, patched + versions, severity, impact, workarounds, and references. +4. Mention the fix in `CHANGELOG.md` and the release notes. +5. Consider yanking affected PyPI releases only when discouraging new installs + of those exact releases is safer than leaving them available. Prefer a clear + yank reason that points users to the patched release. + +In short: publish the patched release before publishing the advisory when users +do not need immediate mitigation guidance. + +## PrismJS Advisory Draft + +Use this when publishing the bundled PrismJS update as a repository advisory +for `tensor-network-editor`. This is not a new PrismJS vulnerability; it is a +vendored dependency advisory that points to the upstream issue. + +- Title: Bundled PrismJS before 1.30.0 in the browser-based editor +- Related upstream advisory: CVE-2024-53382 / GHSA-x7hr-w5r2-h6wg +- Affected package: `tensor-network-editor` +- Affected versions: releases that bundle PrismJS 1.29.0 in + `src/tensor_network_editor/app/static/vendor/` +- Patched version: the first release that bundles PrismJS 1.30.0 or later +- Severity: Moderate, matching the upstream PrismJS advisory unless new project + evidence shows a different impact + +Suggested impact text: + +```text +Tensor Network Editor bundled PrismJS 1.29.0 for syntax highlighting in the +browser-based editor. PrismJS versions before 1.30.0 are affected by +CVE-2024-53382 / GHSA-x7hr-w5r2-h6wg. + +Installing or importing the Python package alone does not execute PrismJS. The +affected code path is the browser-based editor. Risk is higher if the local +editor is exposed beyond localhost, or if untrusted HTML-like content can reach +the editor UI. +``` + +Suggested recommendation text: + +```text +Upgrade to the patched Tensor Network Editor release. If you cannot upgrade +immediately, avoid exposing the local editor outside trusted loopback/local +workflows and avoid opening untrusted designs or Python-derived content in the +browser editor. +``` diff --git a/THIRD_PARTY_LICENSES b/THIRD_PARTY_LICENSES index 79a84df..04688bf 100644 --- a/THIRD_PARTY_LICENSES +++ b/THIRD_PARTY_LICENSES @@ -62,7 +62,7 @@ for known vulnerabilities. - Bundled files: `src/tensor_network_editor/app/static/vendor/prism-core.min.js` `src/tensor_network_editor/app/static/vendor/prism-python.min.js` - - Version: 1.29.0 + - Version: 1.30.0 - Project: https://prismjs.com/ - Upstream repository: https://github.com/PrismJS/prism - Copyright: diff --git a/docs/api.md b/docs/api.md index 3d9fc99..04f630e 100644 --- a/docs/api.md +++ b/docs/api.md @@ -273,6 +273,9 @@ Important details: active Python interpreter, supports live `quimb` and `tensornetwork` objects, and accepts `object_name="..."` when several compatible globals exist +- Only use live import with local Python files you trust. Live import executes + the file in a subprocess with the active Python environment, so trusted code + can still read or write local files. - `PythonLoadOptions.reconstruction_level="simple"` rebuilds only the portable network structure: tensors, inferable connections, and portable tensor-data payloads - `PythonLoadOptions.reconstruction_level="best_available"` is currently only supported diff --git a/docs/cli.md b/docs/cli.md index d562108..93f1d6d 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -228,6 +228,10 @@ If `--python-import-mode live` is used on generated source and the live import fails because the backend package is missing, the loader falls back to the static generated-source parser and reports the fallback as a warning. +Only use live import with local Python files you trust. Live import executes +the file in a subprocess with the active Python environment, so trusted code can +still read or write local files. + ## Headless Commands Headless commands work without opening the visual editor: diff --git a/docs/extended_guide.md b/docs/extended_guide.md index a8e7964..0466d33 100644 --- a/docs/extended_guide.md +++ b/docs/extended_guide.md @@ -825,6 +825,10 @@ Generated exports provide the richest round-trip. External static profiles and live imports are intentionally conservative and do not recover editor layout, groups, notes, or manual contraction plans. +Only use live import with local Python files you trust. Live import executes +the file in a subprocess with the active Python environment, so trusted code can +still read or write local files. + If live import is requested for generated source and the backend import fails because the backend package is missing, the loader can fall back to the static generated-source parser and report a warning. diff --git a/src/tensor_network_editor/app/server.py b/src/tensor_network_editor/app/server.py index 0023288..87a5ab7 100644 --- a/src/tensor_network_editor/app/server.py +++ b/src/tensor_network_editor/app/server.py @@ -59,8 +59,13 @@ _QUIET_MISSING_STATIC_ASSET_PATHS: frozenset[str] = frozenset({"/favicon.ico"}) _ScannedStaticAssetFile: TypeAlias = tuple[Path, str, int, int] _RUNTIME_CONFIG_PLACEHOLDER = "__TNE_RUNTIME_CONFIG__" -_API_TOKEN_HEADER = "X-TNE-Session-Token" +_CSP_NONCE_PLACEHOLDER = "__TNE_CSP_NONCE__" +_API_TOKEN_HEADER = "X-TNE-Session-Token" # noqa: S105, RUF100 - header name. _EXPECTED_JSON_CONTENT_TYPE = "application/json" +_PERMISSIONS_POLICY_HEADER = ( + "accelerometer=(), camera=(), geolocation=(), gyroscope=(), " + "magnetometer=(), microphone=(), payment=(), usb=()" +) class SupportsReadBytes(Protocol): @@ -381,7 +386,11 @@ def _serialize_frontend_runtime_config( def _render_session_index_body( - index_body: bytes, session: EditorSession, *, api_token: str + index_body: bytes, + session: EditorSession, + *, + api_token: str, + csp_nonce: str, ) -> bytes: """Return the per-session editor HTML body with embedded runtime config.""" return index_body.replace( @@ -389,9 +398,30 @@ def _render_session_index_body( _serialize_frontend_runtime_config(session, api_token=api_token).encode( "utf-8" ), + ).replace( + _CSP_NONCE_PLACEHOLDER.encode("utf-8"), + csp_nonce.encode("utf-8"), ) +def _build_content_security_policy(*, csp_nonce: str) -> str: + """Return the editor CSP that permits only trusted local assets.""" + directives = [ + "default-src 'self'", + "base-uri 'none'", + "object-src 'none'", + "frame-ancestors 'none'", + "form-action 'none'", + "connect-src 'self'", + "img-src 'self' data: blob:", + f"script-src 'self' 'nonce-{csp_nonce}'", + "style-src 'self' 'unsafe-inline'", + "font-src 'self' data:", + "worker-src 'self' blob:", + ] + return "; ".join(directives) + + def _unexpected_internal_error_response(session_id: str) -> JsonResponse: """Return an actionable but safe error payload for unexpected failures.""" return internal_server_error_response( @@ -436,12 +466,17 @@ def __init__( self.api_token = api_token or secrets.token_urlsafe(32) if not self.api_token.strip(): raise ValueError("Editor API token cannot be empty.") + self._csp_nonce = secrets.token_urlsafe(16) + self._content_security_policy = _build_content_security_policy( + csp_nonce=self._csp_nonce + ) self._static_dir = Path(__file__).resolve().parent / "static" self._static_asset_cache = _get_static_asset_cache(self._static_dir) self._index_body = _render_session_index_body( self._static_asset_cache.index_body, session, api_token=self.api_token, + csp_nonce=self._csp_nonce, ) self._server = ThreadingHTTPServer((host, port), self._build_handler()) self._thread = threading.Thread(target=self._serve_forever, daemon=True) @@ -521,7 +556,7 @@ def _wait_until_ready(self) -> None: def _probe_loopback_readiness(self, timeout_seconds: float) -> None: """Read one small static asset to verify the server serves full responses.""" - with urlopen( + with urlopen( # noqa: S310, RUF100 - probes this loopback server. f"{self.base_url}/favicon.ico", timeout=timeout_seconds ) as response: response.read() @@ -547,6 +582,7 @@ def _build_handler(self) -> type[BaseHTTPRequestHandler]: index_body = self._index_body api_token = self.api_token allow_remote = self.allow_remote + content_security_policy = self._content_security_policy def build_index_response() -> _BinaryResponse: """Return the cached main HTML page for this editor session.""" @@ -883,6 +919,9 @@ def _write_bytes(self, status: int, body: bytes, content_type: str) -> None: self.send_header("X-Content-Type-Options", "nosniff") self.send_header("Referrer-Policy", "no-referrer") self.send_header("X-Frame-Options", "DENY") + self.send_header("Content-Security-Policy", content_security_policy) + self.send_header("Permissions-Policy", _PERMISSIONS_POLICY_HEADER) + self.send_header("Cross-Origin-Resource-Policy", "same-origin") if self.close_connection: self.send_header("Connection", "close") self._write_no_cache_headers() diff --git a/src/tensor_network_editor/app/static/index.html b/src/tensor_network_editor/app/static/index.html index 2e7b50a..7ebb159 100644 --- a/src/tensor_network_editor/app/static/index.html +++ b/src/tensor_network_editor/app/static/index.html @@ -1575,8 +1575,8 @@

Generated code

- - + ', + r']*\bid="tne-runtime-config")[^>]*>(.*?)', re.DOTALL, ) _SESSION_TOKEN_BY_ORIGIN: dict[str, str | None] = {} diff --git a/tests/test_app_assets.py b/tests/test_app_assets.py index c61420b..8d80460 100644 --- a/tests/test_app_assets.py +++ b/tests/test_app_assets.py @@ -114,6 +114,36 @@ def test_root_serves_editor_shell_with_versioned_module_entry( assert headers["Content-Type"].startswith("text/html") +def test_root_serves_editor_shell_with_csp_nonce_and_defensive_headers( + editor_server: EditorServer, +) -> None: + html, headers = request_with_headers(f"{editor_server.base_url}/") + + content_security_policy = headers["Content-Security-Policy"] + nonce_match = re.search( + r"(?:^|;\s*)script-src 'self' 'nonce-([^']+)';", + content_security_policy, + ) + + assert nonce_match is not None + nonce = nonce_match.group(1) + assert nonce + assert "'unsafe-inline'" not in nonce_match.group(0) + assert ( + f'