From 9c3ad64d9e417b0530311ca317527410e9abb3c2 Mon Sep 17 00:00:00 2001 From: Francois Herbert Date: Fri, 5 Jun 2026 11:00:02 +1200 Subject: [PATCH 1/3] feat: add allowed_subnets for VPN relay configuration --- ...e5f1a204_add_allowed_subnets_to_devices.py | 27 ++++++++ tests/e2e/test_admin_devices.py | 35 +++++++++- tests/e2e/test_devices.py | 22 ++++++ tests/test_services.py | 68 ++++++++++++++++++- wiregui/models/device.py | 3 + wiregui/pages/admin/devices.py | 16 +++++ wiregui/pages/devices.py | 18 +++++ wiregui/services/events.py | 4 +- 8 files changed, 190 insertions(+), 3 deletions(-) create mode 100644 alembic/versions/c8d3e5f1a204_add_allowed_subnets_to_devices.py diff --git a/alembic/versions/c8d3e5f1a204_add_allowed_subnets_to_devices.py b/alembic/versions/c8d3e5f1a204_add_allowed_subnets_to_devices.py new file mode 100644 index 0000000..b7bfd0c --- /dev/null +++ b/alembic/versions/c8d3e5f1a204_add_allowed_subnets_to_devices.py @@ -0,0 +1,27 @@ +"""add allowed_subnets to devices + +Revision ID: c8d3e5f1a204 +Revises: b7e2f4a1c903 +Create Date: 2026-06-05 10:30:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = 'c8d3e5f1a204' +down_revision: Union[str, None] = 'b7e2f4a1c903' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('devices', sa.Column('allowed_subnets', postgresql.JSON(astext_type=sa.Text()), nullable=False, server_default='[]')) + + +def downgrade() -> None: + op.drop_column('devices', 'allowed_subnets') diff --git a/tests/e2e/test_admin_devices.py b/tests/e2e/test_admin_devices.py index b44a262..e80eda0 100644 --- a/tests/e2e/test_admin_devices.py +++ b/tests/e2e/test_admin_devices.py @@ -236,4 +236,37 @@ async def test_config_dialog_shows_wg_config(page: Page, test_user): await expect(page.get_by_role("button", name="Download .conf")).to_be_visible() # QR code should be rendered - await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000) \ No newline at end of file + await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000) + + +async def test_create_device_with_relay_subnets(page: Page, test_user): + """Admin creates a device with relay subnets for site-to-site VPN.""" + await _go_to_admin_devices(page) + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) + + await page.locator("input[aria-label='Device Name']").fill("site-gateway") + await page.locator("input[aria-label='Description (optional)']").fill("Site-to-site gateway") + + # Scroll to relay configuration section + await page.evaluate("window.scrollTo(0, document.body.scrollHeight)") + + # Fill in relay subnets + await page.locator(".q-dialog input[aria-label='Routed Subnets (optional)']").fill("192.168.1.0/24, 10.20.0.0/16") + + await page.get_by_role("button", name="Create").click() + + # Should see config dialog + await expect(page.get_by_text("Config for site-gateway")).to_be_visible(timeout=10_000) + await page.get_by_role("button", name="Close").click() + await page.wait_for_timeout(500) + + # Verify device was created with relay subnets in DB + async with async_session() as session: + result = await session.execute( + select(Device).where(Device.name == "site-gateway") + .order_by(Device.inserted_at.desc()).limit(1) + ) + device = result.scalar_one() + assert device.allowed_subnets == ["192.168.1.0/24", "10.20.0.0/16"] + assert device.description == "Site-to-site gateway" \ No newline at end of file diff --git a/tests/e2e/test_devices.py b/tests/e2e/test_devices.py index 805910a..53be9b1 100644 --- a/tests/e2e/test_devices.py +++ b/tests/e2e/test_devices.py @@ -30,3 +30,25 @@ async def test_add_device_requires_name(page: Page, test_user: UserModel): await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) await page.get_by_role("button", name="Create").click() await expect(page.get_by_text("Device name is required")).to_be_visible(timeout=5_000) + + +async def test_add_device_with_relay_subnets(page: Page, test_user: UserModel): + """Test creating a device with relay subnets for site-to-site VPN.""" + await login(page) + await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000) + + await page.get_by_role("button", name="Add Device").click() + await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000) + + await page.locator("input[aria-label='Device Name']").fill("Gateway Router") + + # Scroll to relay configuration section + await page.evaluate("window.scrollTo(0, document.body.scrollHeight)") + + # Fill in relay subnets + await page.locator("input[aria-label='Routed Subnets (optional)']").fill("192.168.1.0/24, 10.20.0.0/16") + + await page.get_by_role("button", name="Create").click() + + # Should see config dialog with the device name + await expect(page.get_by_text("Config for Gateway Router")).to_be_visible(timeout=10_000) diff --git a/tests/test_services.py b/tests/test_services.py index f74a4e9..69bd9dd 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -4,7 +4,7 @@ from wiregui.models.device import Device from wiregui.models.rule import Rule -from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created +from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created, _device_allowed_ips def _make_device(**kwargs) -> Device: @@ -20,6 +20,51 @@ def _make_device(**kwargs) -> Device: return Device(**defaults) +# --- _device_allowed_ips tests --- + + +def test_device_allowed_ips_basic(): + """Test _device_allowed_ips returns tunnel IPs with /32 and /128.""" + device = _make_device() + ips = _device_allowed_ips(device) + assert ips == ["10.3.2.5/32", "fd00::3:2:5/128"] + + +def test_device_allowed_ips_with_relay_subnets(): + """Test _device_allowed_ips includes relay subnets.""" + device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"]) + ips = _device_allowed_ips(device) + assert ips == ["10.3.2.5/32", "fd00::3:2:5/128", "192.168.1.0/24", "10.20.0.0/16"] + + +def test_device_allowed_ips_ipv4_only(): + """Test _device_allowed_ips with only IPv4.""" + device = _make_device(ipv6=None) + ips = _device_allowed_ips(device) + assert ips == ["10.3.2.5/32"] + + +def test_device_allowed_ips_ipv6_only(): + """Test _device_allowed_ips with only IPv6.""" + device = _make_device(ipv4=None) + ips = _device_allowed_ips(device) + assert ips == ["fd00::3:2:5/128"] + + +def test_device_allowed_ips_relay_only(): + """Test _device_allowed_ips with only relay subnets (no tunnel IPs).""" + device = _make_device(ipv4=None, ipv6=None, allowed_subnets=["192.168.1.0/24"]) + ips = _device_allowed_ips(device) + assert ips == ["192.168.1.0/24"] + + +def test_device_allowed_ips_empty(): + """Test _device_allowed_ips with no IPs or subnets.""" + device = _make_device(ipv4=None, ipv6=None, allowed_subnets=[]) + ips = _device_allowed_ips(device) + assert ips == [] + + # --- Events (WG disabled) --- @@ -55,6 +100,27 @@ async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_setting await on_device_created(device) +@patch("wiregui.services.events.get_settings") +@patch("wiregui.services.events.firewall") +@patch("wiregui.services.events.wireguard") +async def test_on_device_created_with_relay_subnets(mock_wg, mock_fw, mock_settings): + """Test that device creation with relay subnets passes correct allowed_ips to WireGuard.""" + mock_settings.return_value.wg_enabled = True + mock_wg.add_peer = AsyncMock() + mock_fw.add_user_chain = AsyncMock() + mock_fw.add_device_jump_rule = AsyncMock() + + device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"]) + await on_device_created(device) + + # Verify WireGuard peer was added with tunnel IPs + relay subnets + mock_wg.add_peer.assert_awaited_once_with( + public_key="pk-test", + allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128", "192.168.1.0/24", "10.20.0.0/16"], + preshared_key="psk-test", + ) + + # --- Rule events --- diff --git a/wiregui/models/device.py b/wiregui/models/device.py index 31fb865..a27b228 100644 --- a/wiregui/models/device.py +++ b/wiregui/models/device.py @@ -33,6 +33,9 @@ class Device(SQLModel, table=True): # Assigned tunnel addresses ipv4: str | None = Field(default=None, unique=True) ipv6: str | None = Field(default=None, unique=True) + + # Additional subnets this peer routes (for site-to-site / relay configuration) + allowed_subnets: list[str] = Field(default_factory=list, sa_column=Column(JSON, default=[])) # Peer stats (updated periodically from WireGuard) remote_ip: str | None = None diff --git a/wiregui/pages/admin/devices.py b/wiregui/pages/admin/devices.py index 34a1d1c..8c40df5 100644 --- a/wiregui/pages/admin/devices.py +++ b/wiregui/pages/admin/devices.py @@ -132,6 +132,8 @@ async def create_device(): if not create_use_default_keepalive.value and create_keepalive.value else None), allowed_ips=([s.strip() for s in create_allowed_ips.value.split(",") if s.strip()] if not create_use_default_ips.value and create_allowed_ips.value else []), + allowed_subnets=([s.strip() for s in create_allowed_subnets.value.split(",") if s.strip()] + if create_allowed_subnets.value else []), ) session.add(device) await session.commit() @@ -175,6 +177,7 @@ def _reset_create_form(): create_endpoint.value = _defaults["endpoint"] create_mtu.value = _defaults["mtu"] create_keepalive.value = _defaults["keepalive"] + create_allowed_subnets.value = "" # --- Edit device --- edit_device_id = {"value": None} @@ -198,6 +201,7 @@ async def open_edit(device_id: str): edit_mtu.value = str(device.mtu) if device.mtu else "" edit_keepalive.value = str(device.persistent_keepalive) if device.persistent_keepalive else "" edit_allowed_ips.value = ", ".join(device.allowed_ips) if device.allowed_ips else "" + edit_allowed_subnets.value = ", ".join(device.allowed_subnets) if device.allowed_subnets else "" edit_dialog.open() async def save_edit(): @@ -228,6 +232,8 @@ async def save_edit(): device.persistent_keepalive = int(edit_keepalive.value) if edit_keepalive.value else None if not device.use_default_allowed_ips: device.allowed_ips = [s.strip() for s in edit_allowed_ips.value.split(",") if s.strip()] + + device.allowed_subnets = [s.strip() for s in edit_allowed_subnets.value.split(",") if s.strip()] session.add(device) await session.commit() @@ -337,6 +343,11 @@ def on_admin_row_click(e): create_use_default_keepalive = ui.switch("Use default Keepalive", value=True) create_keepalive = ui.input("Persistent Keepalive", value=_defaults["keepalive"]).props("outlined dense").classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v) + ui.separator().classes("q-my-sm") + ui.label("Relay / Site-to-Site Configuration").classes("text-subtitle2") + ui.label("Additional subnets this device routes (comma-separated CIDRs, e.g., 192.168.1.0/24)").classes("text-caption text-grey") + create_allowed_subnets = ui.input("Routed Subnets (optional)").props("outlined dense").classes("w-full") + with ui.row().classes("w-full justify-end q-mt-sm"): ui.button("Cancel", on_click=create_dialog.close).props("flat") ui.button("Create", on_click=create_device).props("color=primary") @@ -368,6 +379,11 @@ def on_admin_row_click(e): edit_use_default_keepalive = ui.switch("Use default Keepalive", value=True) edit_keepalive = ui.input("Persistent Keepalive").props("outlined dense").classes("w-full").bind_enabled_from(edit_use_default_keepalive, "value", backward=lambda v: not v) + ui.separator().classes("q-my-sm") + ui.label("Relay / Site-to-Site Configuration").classes("text-subtitle2") + ui.label("Additional subnets this device routes (comma-separated CIDRs, e.g., 192.168.1.0/24)").classes("text-caption text-grey") + edit_allowed_subnets = ui.input("Routed Subnets (optional)").props("outlined dense").classes("w-full") + with ui.row().classes("w-full justify-end q-mt-sm"): ui.button("Cancel", on_click=edit_dialog.close).props("flat") ui.button("Save", on_click=save_edit).props("color=primary") diff --git a/wiregui/pages/devices.py b/wiregui/pages/devices.py index f9d7d93..db48aec 100644 --- a/wiregui/pages/devices.py +++ b/wiregui/pages/devices.py @@ -119,6 +119,8 @@ async def create_device(): if not create_use_default_keepalive.value and create_keepalive.value else None), allowed_ips=([s.strip() for s in create_allowed_ips.value.split(",") if s.strip()] if not create_use_default_ips.value and create_allowed_ips.value else []), + allowed_subnets=([s.strip() for s in create_allowed_subnets.value.split(",") if s.strip()] + if create_allowed_subnets.value else []), ) session.add(device) await session.commit() @@ -162,6 +164,7 @@ def _reset_create_form(): create_endpoint.value = _defaults["endpoint"] create_mtu.value = _defaults["mtu"] create_keepalive.value = _defaults["keepalive"] + create_allowed_subnets.value = "" # --- Delete device --- async def delete_device(device_id: str): @@ -256,6 +259,11 @@ def on_row_click(e): "outlined dense" ).classes("w-full").bind_enabled_from(create_use_default_keepalive, "value", backward=lambda v: not v) + ui.separator().classes("q-my-sm") + ui.label("Relay / Site-to-Site Configuration").classes("text-subtitle2") + ui.label("Additional subnets this device routes (comma-separated CIDRs, e.g., 192.168.1.0/24)").classes("text-caption text-grey") + create_allowed_subnets = ui.input("Routed Subnets (optional)").props("outlined dense").classes("w-full") + with ui.row().classes("w-full justify-end q-mt-md"): ui.button("Cancel", on_click=create_dialog.close).props("flat") ui.button("Create", on_click=create_device).props("color=primary") @@ -305,6 +313,8 @@ async def save_edit(): d.persistent_keepalive = int(edit_keepalive.value) if edit_keepalive.value else None if not d.use_default_allowed_ips: d.allowed_ips = [s.strip() for s in edit_allowed_ips.value.split(",") if s.strip()] + + d.allowed_subnets = [s.strip() for s in edit_allowed_subnets.value.split(",") if s.strip()] session.add(d) await session.commit() @@ -555,6 +565,14 @@ async def refresh_stats(): edit_use_default_keepalive, "value", backward=lambda v: not v ) + ui.separator().classes("q-my-sm") + ui.label("Relay / Site-to-Site Configuration").classes("text-subtitle2") + ui.label("Additional subnets this device routes (comma-separated CIDRs, e.g., 192.168.1.0/24)").classes("text-caption text-grey") + edit_allowed_subnets = ui.input( + "Routed Subnets (optional)", + value=", ".join(device.allowed_subnets) if device.allowed_subnets else "" + ).props("outlined dense").classes("w-full") + ui.button("Save Changes", on_click=save_edit).props("color=primary").classes("q-mt-md") # Danger zone diff --git a/wiregui/services/events.py b/wiregui/services/events.py index 2136b67..f445b7e 100644 --- a/wiregui/services/events.py +++ b/wiregui/services/events.py @@ -12,12 +12,14 @@ def _device_allowed_ips(device: Device) -> list[str]: - """Build the allowed-ips list for a device peer (its tunnel addresses).""" + """Build the allowed-ips list for a device peer (its tunnel addresses + relay subnets).""" ips = [] if device.ipv4: ips.append(f"{device.ipv4}/32") if device.ipv6: ips.append(f"{device.ipv6}/128") + if device.allowed_subnets: + ips.extend(device.allowed_subnets) return ips From 48e4c1828445447705a5b19ddb80b964fb3ef6ae Mon Sep 17 00:00:00 2001 From: Francois Herbert Date: Fri, 5 Jun 2026 12:35:35 +1200 Subject: [PATCH 2/3] feat: Ensure user firewall chain is jumped to when src address is from users device allowed subnets --- tests/test_firewall.py | 87 ++++++++++++++++++++++++++++++++++++ tests/test_services.py | 10 ++++- wiregui/services/events.py | 18 ++++++-- wiregui/services/firewall.py | 48 ++++++++++++++++++-- wiregui/tasks/reconcile.py | 7 ++- 5 files changed, 161 insertions(+), 9 deletions(-) diff --git a/tests/test_firewall.py b/tests/test_firewall.py index 24b04ce..5a193d4 100644 --- a/tests/test_firewall.py +++ b/tests/test_firewall.py @@ -38,3 +38,90 @@ def test_build_rule_expr_single_port(): def test_build_rule_expr_no_port(): expr = _build_rule_expr("0.0.0.0/0", "accept", port_type=None, port_range=None) assert expr == "ip daddr 0.0.0.0/0 accept" + + +# --- Tests for relay subnet support --- + + +async def test_add_device_jump_rule_with_allowed_subnets(): + """Test that add_device_jump_rule creates rules for tunnel IPs and relay subnets.""" + from unittest.mock import AsyncMock, patch + from wiregui.services.firewall import add_device_jump_rule + + with patch("wiregui.services.firewall._nft_batch") as mock_nft: + mock_nft.return_value = None + + await add_device_jump_rule( + user_id="test-user-id", + device_ipv4="10.3.2.5", + device_ipv6="fd00::3:2:5", + allowed_subnets=["192.168.1.0/24", "10.20.0.0/16", "fd00:1::/64"] + ) + + # Verify nft_batch was called with correct commands + mock_nft.assert_called_once() + commands = mock_nft.call_args[0][0] + + # Should have 5 rules: 2 tunnel IPs + 3 subnets + assert len(commands) == 5 + assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in commands) + assert any("ip6 saddr fd00::3:2:5 jump" in cmd for cmd in commands) + assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in commands) + assert any("ip saddr 10.20.0.0/16 jump" in cmd for cmd in commands) + assert any("ip6 saddr fd00:1::/64 jump" in cmd for cmd in commands) + + +async def test_add_device_jump_rule_ipv4_subnet_only(): + """Test add_device_jump_rule with only IPv4 relay subnet.""" + from unittest.mock import AsyncMock, patch + from wiregui.services.firewall import add_device_jump_rule + + with patch("wiregui.services.firewall._nft_batch") as mock_nft: + mock_nft.return_value = None + + await add_device_jump_rule( + user_id="test-user-id", + device_ipv4="10.3.2.5", + device_ipv6=None, + allowed_subnets=["192.168.1.0/24"] + ) + + commands = mock_nft.call_args[0][0] + assert len(commands) == 2 + assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in commands) + assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in commands) + + +async def test_rebuild_all_rules_with_allowed_subnets(): + """Test that rebuild_all_rules includes relay subnets in jump rules.""" + from unittest.mock import patch + from wiregui.services.firewall import rebuild_all_rules + + with patch("wiregui.services.firewall._nft_batch") as mock_nft, \ + patch("wiregui.services.firewall._list_user_chains") as mock_list: + mock_nft.return_value = None + mock_list.return_value = set() + + await rebuild_all_rules([{ + "user_id": "user-123", + "devices": [ + { + "ipv4": "10.3.2.5", + "ipv6": "fd00::3:2:5", + "allowed_subnets": ["192.168.1.0/24", "10.20.0.0/16"] + } + ], + "rules": [] + }]) + + # Verify nft_batch was called + mock_nft.assert_called_once() + commands = mock_nft.call_args[0][0] + + # Check that jump rules include both tunnel IPs and relay subnets + forward_rules = [cmd for cmd in commands if "forward" in cmd and "jump" in cmd] + assert len(forward_rules) == 4 # 2 tunnel IPs + 2 subnets + assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in forward_rules) + assert any("ip6 saddr fd00::3:2:5 jump" in cmd for cmd in forward_rules) + assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in forward_rules) + assert any("ip saddr 10.20.0.0/16 jump" in cmd for cmd in forward_rules) diff --git a/tests/test_services.py b/tests/test_services.py index 69bd9dd..199f811 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -104,7 +104,7 @@ async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_setting @patch("wiregui.services.events.firewall") @patch("wiregui.services.events.wireguard") async def test_on_device_created_with_relay_subnets(mock_wg, mock_fw, mock_settings): - """Test that device creation with relay subnets passes correct allowed_ips to WireGuard.""" + """Test that device creation with relay subnets passes correct allowed_ips to WireGuard and firewall.""" mock_settings.return_value.wg_enabled = True mock_wg.add_peer = AsyncMock() mock_fw.add_user_chain = AsyncMock() @@ -119,6 +119,14 @@ async def test_on_device_created_with_relay_subnets(mock_wg, mock_fw, mock_setti allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128", "192.168.1.0/24", "10.20.0.0/16"], preshared_key="psk-test", ) + + # Verify firewall jump rule was added with relay subnets + mock_fw.add_device_jump_rule.assert_awaited_once_with( + "00000000-0000-0000-0000-000000000000", + "10.3.2.5", + "fd00::3:2:5", + ["192.168.1.0/24", "10.20.0.0/16"], + ) # --- Rule events --- diff --git a/wiregui/services/events.py b/wiregui/services/events.py index f445b7e..c7715ce 100644 --- a/wiregui/services/events.py +++ b/wiregui/services/events.py @@ -44,7 +44,10 @@ async def on_device_created(device: Device) -> None: # Ensure user chain exists before adding jump rules await firewall.add_user_chain(str(device.user_id)) await firewall.add_device_jump_rule( - str(device.user_id), device.ipv4, device.ipv6, + str(device.user_id), + device.ipv4, + device.ipv6, + device.allowed_subnets, ) except Exception as e: logger.error("Failed to add firewall jump rule for device {}: {}", device.name, e) @@ -62,7 +65,7 @@ async def on_device_deleted(device: Device) -> None: async def on_device_updated(device: Device) -> None: - """Update WireGuard peer after a device is modified.""" + """Update WireGuard peer and firewall after a device is modified.""" if not get_settings().wg_enabled: return try: @@ -73,6 +76,12 @@ async def on_device_updated(device: Device) -> None: ) except Exception as e: logger.error("Failed to update WG peer for device {}: {}", device.name, e) + + # Rebuild firewall rules for this user to update allowed_subnets + try: + await _rebuild_user_chain(str(device.user_id)) + except Exception as e: + logger.error("Failed to rebuild firewall rules for device {}: {}", device.name, e) # --- Rule events --- @@ -127,7 +136,10 @@ async def _rebuild_user_chain(user_id: str) -> None: await firewall.rebuild_all_rules([{ "user_id": user_id, - "devices": [{"ipv4": d.ipv4, "ipv6": d.ipv6} for d in devices], + "devices": [ + {"ipv4": d.ipv4, "ipv6": d.ipv6, "allowed_subnets": d.allowed_subnets} + for d in devices + ], "rules": [ {"destination": r.destination, "action": r.action, "port_type": r.port_type, "port_range": r.port_range} diff --git a/wiregui/services/firewall.py b/wiregui/services/firewall.py index ea7c0d0..a875e6d 100644 --- a/wiregui/services/firewall.py +++ b/wiregui/services/firewall.py @@ -101,10 +101,24 @@ async def remove_user_chain(user_id: str) -> None: logger.debug("Remove user chain {}: {}", chain, e) -async def add_device_jump_rule(user_id: str, device_ipv4: str | None, device_ipv6: str | None) -> None: - """Add jump rules in the forward chain to route device traffic to the user chain.""" +async def add_device_jump_rule( + user_id: str, + device_ipv4: str | None, + device_ipv6: str | None, + allowed_subnets: list[str] | None = None, +) -> None: + """Add jump rules in the forward chain to route device traffic to the user chain. + + Args: + user_id: User ID for the chain + device_ipv4: Device tunnel IPv4 address + device_ipv6: Device tunnel IPv6 address + allowed_subnets: Additional relay subnets this device routes + """ chain = _user_chain_name(user_id) commands = [] + + # Add jump rules for tunnel IPs if device_ipv4: commands.append( f"add rule inet {TABLE_NAME} forward ip saddr {device_ipv4} jump {chain}" @@ -113,9 +127,25 @@ async def add_device_jump_rule(user_id: str, device_ipv4: str | None, device_ipv commands.append( f"add rule inet {TABLE_NAME} forward ip6 saddr {device_ipv6} jump {chain}" ) + + # Add jump rules for relay subnets + if allowed_subnets: + for subnet in allowed_subnets: + if ":" in subnet: + # IPv6 + commands.append( + f"add rule inet {TABLE_NAME} forward ip6 saddr {subnet} jump {chain}" + ) + else: + # IPv4 + commands.append( + f"add rule inet {TABLE_NAME} forward ip saddr {subnet} jump {chain}" + ) + if commands: await _nft_batch(commands) - logger.debug("Jump rules added for device {}/{} -> {}", device_ipv4, device_ipv6, chain) + logger.debug("Jump rules added for device {}/{} + {} subnets -> {}", + device_ipv4, device_ipv6, len(allowed_subnets or []), chain) async def apply_rule(user_id: str, destination: str, action: str, port_type: str | None = None, port_range: str | None = None) -> None: @@ -133,7 +163,7 @@ async def rebuild_all_rules(users_devices_rules: list[dict]) -> None: Args: users_devices_rules: list of dicts with keys: - user_id, devices (list of {ipv4, ipv6}), rules (list of {destination, action, port_type, port_range}) + user_id, devices (list of {ipv4, ipv6, allowed_subnets}), rules (list of {destination, action, port_type, port_range}) """ # Discover existing user_ chains so we can remove orphans existing_user_chains = await _list_user_chains() @@ -164,10 +194,20 @@ async def rebuild_all_rules(users_devices_rules: list[dict]) -> None: user_id = entry["user_id"] chain = _user_chain_name(user_id) for dev in entry.get("devices", []): + # Add jump rules for tunnel IPs if dev.get("ipv4"): commands.append(f"add rule inet {TABLE_NAME} forward ip saddr {dev['ipv4']} jump {chain}") if dev.get("ipv6"): commands.append(f"add rule inet {TABLE_NAME} forward ip6 saddr {dev['ipv6']} jump {chain}") + + # Add jump rules for relay subnets + for subnet in dev.get("allowed_subnets", []): + if ":" in subnet: + # IPv6 + commands.append(f"add rule inet {TABLE_NAME} forward ip6 saddr {subnet} jump {chain}") + else: + # IPv4 + commands.append(f"add rule inet {TABLE_NAME} forward ip saddr {subnet} jump {chain}") # Remove orphaned user chains (must happen after forward chain is flushed # so there are no remaining jump references to these chains) diff --git a/wiregui/tasks/reconcile.py b/wiregui/tasks/reconcile.py index 3c5634c..2846893 100644 --- a/wiregui/tasks/reconcile.py +++ b/wiregui/tasks/reconcile.py @@ -34,6 +34,8 @@ async def reconcile() -> None: ips.append(f"{device.ipv4}/32") if device.ipv6: ips.append(f"{device.ipv6}/128") + if device.allowed_subnets: + ips.extend(device.allowed_subnets) try: await wireguard.add_peer( public_key=device.public_key, @@ -75,7 +77,10 @@ async def _reconcile_firewall(devices: list[Device], rules: list[Rule]) -> None: entries.append({ "user_id": uid, - "devices": [{"ipv4": d.ipv4, "ipv6": d.ipv6} for d in user_devices], + "devices": [ + {"ipv4": d.ipv4, "ipv6": d.ipv6, "allowed_subnets": d.allowed_subnets} + for d in user_devices + ], "rules": [ {"destination": r.destination, "action": r.action, "port_type": r.port_type, "port_range": r.port_range} From 0dd25f1336218862bd289bb0e40caa513f103eaa Mon Sep 17 00:00:00 2001 From: Francois Herbert Date: Fri, 5 Jun 2026 12:54:05 +1200 Subject: [PATCH 3/3] feat: Add routes for peer allowed ip list --- tests/test_services.py | 6 ++- tests/test_wireguard_extended.py | 92 +++++++++++++++++++++++++++++++- wiregui/services/events.py | 29 ++++++++-- wiregui/services/wireguard.py | 55 +++++++++++++++++++ wiregui/tasks/reconcile.py | 23 ++++++++ 5 files changed, 200 insertions(+), 5 deletions(-) diff --git a/tests/test_services.py b/tests/test_services.py index 199f811..8572144 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -104,9 +104,10 @@ async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_setting @patch("wiregui.services.events.firewall") @patch("wiregui.services.events.wireguard") async def test_on_device_created_with_relay_subnets(mock_wg, mock_fw, mock_settings): - """Test that device creation with relay subnets passes correct allowed_ips to WireGuard and firewall.""" + """Test that device creation with relay subnets passes correct allowed_ips to WireGuard, adds routes, and configures firewall.""" mock_settings.return_value.wg_enabled = True mock_wg.add_peer = AsyncMock() + mock_wg.add_routes = AsyncMock() mock_fw.add_user_chain = AsyncMock() mock_fw.add_device_jump_rule = AsyncMock() @@ -120,6 +121,9 @@ async def test_on_device_created_with_relay_subnets(mock_wg, mock_fw, mock_setti preshared_key="psk-test", ) + # Verify routes were added for relay subnets + mock_wg.add_routes.assert_awaited_once_with(["192.168.1.0/24", "10.20.0.0/16"]) + # Verify firewall jump rule was added with relay subnets mock_fw.add_device_jump_rule.assert_awaited_once_with( "00000000-0000-0000-0000-000000000000", diff --git a/tests/test_wireguard_extended.py b/tests/test_wireguard_extended.py index ab848df..a5d5444 100644 --- a/tests/test_wireguard_extended.py +++ b/tests/test_wireguard_extended.py @@ -7,6 +7,8 @@ set_private_key, set_listen_port, configure_interface, + add_routes, + remove_routes, ) @@ -111,4 +113,92 @@ async def test_configure_interface_sets_key_and_port(mock_session_cls, mock_run) args = mock_run.call_args[0][0] assert args[0:3] == ["wg", "set", "wg-test"] assert "private-key" in args - assert "listen-port" in args \ No newline at end of file + assert "listen-port" in args + + +# ========== add_routes ========== + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_ipv4(mock_run): + """add_routes should call ip route add for IPv4 subnets.""" + mock_run.return_value = "" + await add_routes(["192.168.1.0/24", "10.20.0.0/16"], iface="wg-test") + + assert mock_run.await_count == 2 + calls = [c[0][0] for c in mock_run.call_args_list] + assert calls[0] == ["ip", "-4", "route", "add", "192.168.1.0/24", "dev", "wg-test"] + assert calls[1] == ["ip", "-4", "route", "add", "10.20.0.0/16", "dev", "wg-test"] + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_ipv6(mock_run): + """add_routes should call ip -6 route add for IPv6 subnets.""" + mock_run.return_value = "" + await add_routes(["fd00:1::/64", "fd00:2::/48"], iface="wg-test") + + assert mock_run.await_count == 2 + calls = [c[0][0] for c in mock_run.call_args_list] + assert calls[0] == ["ip", "-6", "route", "add", "fd00:1::/64", "dev", "wg-test"] + assert calls[1] == ["ip", "-6", "route", "add", "fd00:2::/48", "dev", "wg-test"] + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_mixed(mock_run): + """add_routes should handle mixed IPv4 and IPv6.""" + mock_run.return_value = "" + await add_routes(["192.168.1.0/24", "fd00:1::/64"], iface="wg-test") + + assert mock_run.await_count == 2 + calls = [c[0][0] for c in mock_run.call_args_list] + assert calls[0] == ["ip", "-4", "route", "add", "192.168.1.0/24", "dev", "wg-test"] + assert calls[1] == ["ip", "-6", "route", "add", "fd00:1::/64", "dev", "wg-test"] + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_empty_list(mock_run): + """add_routes with empty list should not call ip route.""" + await add_routes([], iface="wg-test") + mock_run.assert_not_awaited() + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_add_routes_already_exists(mock_run): + """add_routes should not fail if route already exists.""" + mock_run.side_effect = RuntimeError("RTNETLINK answers: File exists") + # Should not raise + await add_routes(["192.168.1.0/24"], iface="wg-test") + mock_run.assert_awaited_once() + + +# ========== remove_routes ========== + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_remove_routes_ipv4(mock_run): + """remove_routes should call ip route del for IPv4 subnets.""" + mock_run.return_value = "" + await remove_routes(["192.168.1.0/24", "10.20.0.0/16"], iface="wg-test") + + assert mock_run.await_count == 2 + calls = [c[0][0] for c in mock_run.call_args_list] + assert calls[0] == ["ip", "-4", "route", "del", "192.168.1.0/24", "dev", "wg-test"] + assert calls[1] == ["ip", "-4", "route", "del", "10.20.0.0/16", "dev", "wg-test"] + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_remove_routes_ipv6(mock_run): + """remove_routes should call ip -6 route del for IPv6 subnets.""" + mock_run.return_value = "" + await remove_routes(["fd00:1::/64"], iface="wg-test") + + mock_run.assert_awaited_once_with(["ip", "-6", "route", "del", "fd00:1::/64", "dev", "wg-test"]) + + +@patch("wiregui.services.wireguard._run", new_callable=AsyncMock) +async def test_remove_routes_not_found(mock_run): + """remove_routes should not fail if route doesn't exist.""" + mock_run.side_effect = RuntimeError("RTNETLINK answers: No such process") + # Should not raise + await remove_routes(["192.168.1.0/24"], iface="wg-test") + mock_run.assert_awaited_once() \ No newline at end of file diff --git a/wiregui/services/events.py b/wiregui/services/events.py index c7715ce..618a7f1 100644 --- a/wiregui/services/events.py +++ b/wiregui/services/events.py @@ -27,7 +27,7 @@ def _device_allowed_ips(device: Device) -> list[str]: async def on_device_created(device: Device) -> None: - """Configure WireGuard peer and firewall after a new device is created.""" + """Configure WireGuard peer, routes, and firewall after a new device is created.""" settings = get_settings() if not settings.wg_enabled: return @@ -39,6 +39,13 @@ async def on_device_created(device: Device) -> None: ) except Exception as e: logger.error("Failed to add WG peer for device {}: {}", device.name, e) + + # Add routes for relay subnets + if device.allowed_subnets: + try: + await wireguard.add_routes(device.allowed_subnets) + except Exception as e: + logger.error("Failed to add routes for device {}: {}", device.name, e) try: # Ensure user chain exists before adding jump rules @@ -54,18 +61,25 @@ async def on_device_created(device: Device) -> None: async def on_device_deleted(device: Device) -> None: - """Remove WireGuard peer after a device is deleted.""" + """Remove WireGuard peer and routes after a device is deleted.""" if not get_settings().wg_enabled: return try: await wireguard.remove_peer(public_key=device.public_key) except Exception as e: logger.error("Failed to remove WG peer for device {}: {}", device.name, e) + + # Remove routes for relay subnets + if device.allowed_subnets: + try: + await wireguard.remove_routes(device.allowed_subnets) + except Exception as e: + logger.error("Failed to remove routes for device {}: {}", device.name, e) # Firewall jump rules are cleaned up on next rebuild async def on_device_updated(device: Device) -> None: - """Update WireGuard peer and firewall after a device is modified.""" + """Update WireGuard peer, routes, and firewall after a device is modified.""" if not get_settings().wg_enabled: return try: @@ -77,6 +91,15 @@ async def on_device_updated(device: Device) -> None: except Exception as e: logger.error("Failed to update WG peer for device {}: {}", device.name, e) + # Note: We can't easily diff old vs new allowed_subnets here without fetching old state. + # The reconcile task will clean up orphaned routes periodically. + # For now, just ensure current routes exist. + if device.allowed_subnets: + try: + await wireguard.add_routes(device.allowed_subnets) + except Exception as e: + logger.error("Failed to add routes for device {}: {}", device.name, e) + # Rebuild firewall rules for this user to update allowed_subnets try: await _rebuild_user_chain(str(device.user_id)) diff --git a/wiregui/services/wireguard.py b/wiregui/services/wireguard.py index 6d29160..35a2943 100644 --- a/wiregui/services/wireguard.py +++ b/wiregui/services/wireguard.py @@ -186,3 +186,58 @@ async def get_peers(iface: str | None = None) -> list[PeerInfo]: tx_bytes=tx_bytes, )) return peers + + +async def add_routes(subnets: list[str], iface: str | None = None) -> None: + """Add IP routes for relay subnets through the WireGuard interface. + + Args: + subnets: List of CIDR subnets to route through WireGuard + iface: WireGuard interface name (defaults to config value) + """ + if not subnets: + return + + settings = get_settings() + iface = iface or settings.wg_interface + + for subnet in subnets: + try: + if ":" in subnet: + # IPv6 + await _run(["ip", "-6", "route", "add", subnet, "dev", iface]) + else: + # IPv4 + await _run(["ip", "-4", "route", "add", subnet, "dev", iface]) + logger.debug("Route added: {} via {}", subnet, iface) + except RuntimeError as e: + # Route might already exist, log but don't fail + if "File exists" not in str(e): + logger.warning("Failed to add route for {}: {}", subnet, e) + + +async def remove_routes(subnets: list[str], iface: str | None = None) -> None: + """Remove IP routes for relay subnets. + + Args: + subnets: List of CIDR subnets to remove routes for + iface: WireGuard interface name (defaults to config value) + """ + if not subnets: + return + + settings = get_settings() + iface = iface or settings.wg_interface + + for subnet in subnets: + try: + if ":" in subnet: + # IPv6 + await _run(["ip", "-6", "route", "del", subnet, "dev", iface]) + else: + # IPv4 + await _run(["ip", "-4", "route", "del", subnet, "dev", iface]) + logger.debug("Route removed: {} via {}", subnet, iface) + except RuntimeError as e: + # Route might not exist, log but don't fail + logger.debug("Failed to remove route for {}: {}", subnet, e) diff --git a/wiregui/tasks/reconcile.py b/wiregui/tasks/reconcile.py index 2846893..b3d621e 100644 --- a/wiregui/tasks/reconcile.py +++ b/wiregui/tasks/reconcile.py @@ -60,6 +60,9 @@ async def reconcile() -> None: # Rebuild all firewall rules from DB await _reconcile_firewall(devices, rules) + + # Reconcile routes for relay subnets + await _reconcile_routes(devices) async def _reconcile_firewall(devices: list[Device], rules: list[Rule]) -> None: @@ -92,3 +95,23 @@ async def _reconcile_firewall(devices: list[Device], rules: list[Rule]) -> None: await firewall.rebuild_all_rules(entries) except Exception as e: logger.error("Reconcile: firewall rebuild failed: {}", e) + + +async def _reconcile_routes(devices: list[Device]) -> None: + """Ensure all relay subnet routes exist for devices in the database.""" + all_subnets = [] + for device in devices: + if device.allowed_subnets: + all_subnets.extend(device.allowed_subnets) + + if not all_subnets: + logger.debug("No relay subnets configured, skipping route reconciliation") + return + + # Add routes for all relay subnets + # Note: add_routes is idempotent (ignores "File exists" errors) + try: + await wireguard.add_routes(all_subnets) + logger.info("Reconciled {} relay subnet route(s)", len(all_subnets)) + except Exception as e: + logger.error("Reconcile: route sync failed: {}", e)