Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions alembic/versions/c8d3e5f1a204_add_allowed_subnets_to_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""add allowed_subnets to devices

Revision ID: c8d3e5f1a204
Revises: b7e2f4a1c903
Create Date: 2026-06-05 10:30:00.000000

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision: str = 'c8d3e5f1a204'
down_revision: Union[str, None] = 'b7e2f4a1c903'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.add_column('devices', sa.Column('allowed_subnets', postgresql.JSON(astext_type=sa.Text()), nullable=False, server_default='[]'))


def downgrade() -> None:
op.drop_column('devices', 'allowed_subnets')
35 changes: 34 additions & 1 deletion tests/e2e/test_admin_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,37 @@ async def test_config_dialog_shows_wg_config(page: Page, test_user):
await expect(page.get_by_role("button", name="Download .conf")).to_be_visible()

# QR code should be rendered
await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000)
await expect(page.locator(".q-dialog img")).to_be_visible(timeout=5_000)


async def test_create_device_with_relay_subnets(page: Page, test_user):
"""Admin creates a device with relay subnets for site-to-site VPN."""
await _go_to_admin_devices(page)
await page.get_by_role("button", name="Add Device").click()
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)

await page.locator("input[aria-label='Device Name']").fill("site-gateway")
await page.locator("input[aria-label='Description (optional)']").fill("Site-to-site gateway")

# Scroll to relay configuration section
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")

# Fill in relay subnets
await page.locator(".q-dialog input[aria-label='Routed Subnets (optional)']").fill("192.168.1.0/24, 10.20.0.0/16")

await page.get_by_role("button", name="Create").click()

# Should see config dialog
await expect(page.get_by_text("Config for site-gateway")).to_be_visible(timeout=10_000)
await page.get_by_role("button", name="Close").click()
await page.wait_for_timeout(500)

# Verify device was created with relay subnets in DB
async with async_session() as session:
result = await session.execute(
select(Device).where(Device.name == "site-gateway")
.order_by(Device.inserted_at.desc()).limit(1)
)
device = result.scalar_one()
assert device.allowed_subnets == ["192.168.1.0/24", "10.20.0.0/16"]
assert device.description == "Site-to-site gateway"
22 changes: 22 additions & 0 deletions tests/e2e/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,25 @@ async def test_add_device_requires_name(page: Page, test_user: UserModel):
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)
await page.get_by_role("button", name="Create").click()
await expect(page.get_by_text("Device name is required")).to_be_visible(timeout=5_000)


async def test_add_device_with_relay_subnets(page: Page, test_user: UserModel):
"""Test creating a device with relay subnets for site-to-site VPN."""
await login(page)
await expect(page.get_by_text("My Devices")).to_be_visible(timeout=10_000)

await page.get_by_role("button", name="Add Device").click()
await expect(page.get_by_text("New Device")).to_be_visible(timeout=5_000)

await page.locator("input[aria-label='Device Name']").fill("Gateway Router")

# Scroll to relay configuration section
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")

# Fill in relay subnets
await page.locator("input[aria-label='Routed Subnets (optional)']").fill("192.168.1.0/24, 10.20.0.0/16")

await page.get_by_role("button", name="Create").click()

# Should see config dialog with the device name
await expect(page.get_by_text("Config for Gateway Router")).to_be_visible(timeout=10_000)
87 changes: 87 additions & 0 deletions tests/test_firewall.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,90 @@ def test_build_rule_expr_single_port():
def test_build_rule_expr_no_port():
expr = _build_rule_expr("0.0.0.0/0", "accept", port_type=None, port_range=None)
assert expr == "ip daddr 0.0.0.0/0 accept"


# --- Tests for relay subnet support ---


async def test_add_device_jump_rule_with_allowed_subnets():
"""Test that add_device_jump_rule creates rules for tunnel IPs and relay subnets."""
from unittest.mock import AsyncMock, patch
from wiregui.services.firewall import add_device_jump_rule

with patch("wiregui.services.firewall._nft_batch") as mock_nft:
mock_nft.return_value = None

await add_device_jump_rule(
user_id="test-user-id",
device_ipv4="10.3.2.5",
device_ipv6="fd00::3:2:5",
allowed_subnets=["192.168.1.0/24", "10.20.0.0/16", "fd00:1::/64"]
)

# Verify nft_batch was called with correct commands
mock_nft.assert_called_once()
commands = mock_nft.call_args[0][0]

# Should have 5 rules: 2 tunnel IPs + 3 subnets
assert len(commands) == 5
assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in commands)
assert any("ip6 saddr fd00::3:2:5 jump" in cmd for cmd in commands)
assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in commands)
assert any("ip saddr 10.20.0.0/16 jump" in cmd for cmd in commands)
assert any("ip6 saddr fd00:1::/64 jump" in cmd for cmd in commands)


async def test_add_device_jump_rule_ipv4_subnet_only():
"""Test add_device_jump_rule with only IPv4 relay subnet."""
from unittest.mock import AsyncMock, patch
from wiregui.services.firewall import add_device_jump_rule

with patch("wiregui.services.firewall._nft_batch") as mock_nft:
mock_nft.return_value = None

await add_device_jump_rule(
user_id="test-user-id",
device_ipv4="10.3.2.5",
device_ipv6=None,
allowed_subnets=["192.168.1.0/24"]
)

commands = mock_nft.call_args[0][0]
assert len(commands) == 2
assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in commands)
assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in commands)


async def test_rebuild_all_rules_with_allowed_subnets():
"""Test that rebuild_all_rules includes relay subnets in jump rules."""
from unittest.mock import patch
from wiregui.services.firewall import rebuild_all_rules

with patch("wiregui.services.firewall._nft_batch") as mock_nft, \
patch("wiregui.services.firewall._list_user_chains") as mock_list:
mock_nft.return_value = None
mock_list.return_value = set()

await rebuild_all_rules([{
"user_id": "user-123",
"devices": [
{
"ipv4": "10.3.2.5",
"ipv6": "fd00::3:2:5",
"allowed_subnets": ["192.168.1.0/24", "10.20.0.0/16"]
}
],
"rules": []
}])

# Verify nft_batch was called
mock_nft.assert_called_once()
commands = mock_nft.call_args[0][0]

# Check that jump rules include both tunnel IPs and relay subnets
forward_rules = [cmd for cmd in commands if "forward" in cmd and "jump" in cmd]
assert len(forward_rules) == 4 # 2 tunnel IPs + 2 subnets
assert any("ip saddr 10.3.2.5 jump" in cmd for cmd in forward_rules)
assert any("ip6 saddr fd00::3:2:5 jump" in cmd for cmd in forward_rules)
assert any("ip saddr 192.168.1.0/24 jump" in cmd for cmd in forward_rules)
assert any("ip saddr 10.20.0.0/16 jump" in cmd for cmd in forward_rules)
80 changes: 79 additions & 1 deletion tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from wiregui.models.device import Device
from wiregui.models.rule import Rule
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created
from wiregui.services.events import on_device_created, on_device_deleted, on_device_updated, on_rule_created, _device_allowed_ips


def _make_device(**kwargs) -> Device:
Expand All @@ -20,6 +20,51 @@ def _make_device(**kwargs) -> Device:
return Device(**defaults)


# --- _device_allowed_ips tests ---


def test_device_allowed_ips_basic():
"""Test _device_allowed_ips returns tunnel IPs with /32 and /128."""
device = _make_device()
ips = _device_allowed_ips(device)
assert ips == ["10.3.2.5/32", "fd00::3:2:5/128"]


def test_device_allowed_ips_with_relay_subnets():
"""Test _device_allowed_ips includes relay subnets."""
device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"])
ips = _device_allowed_ips(device)
assert ips == ["10.3.2.5/32", "fd00::3:2:5/128", "192.168.1.0/24", "10.20.0.0/16"]


def test_device_allowed_ips_ipv4_only():
"""Test _device_allowed_ips with only IPv4."""
device = _make_device(ipv6=None)
ips = _device_allowed_ips(device)
assert ips == ["10.3.2.5/32"]


def test_device_allowed_ips_ipv6_only():
"""Test _device_allowed_ips with only IPv6."""
device = _make_device(ipv4=None)
ips = _device_allowed_ips(device)
assert ips == ["fd00::3:2:5/128"]


def test_device_allowed_ips_relay_only():
"""Test _device_allowed_ips with only relay subnets (no tunnel IPs)."""
device = _make_device(ipv4=None, ipv6=None, allowed_subnets=["192.168.1.0/24"])
ips = _device_allowed_ips(device)
assert ips == ["192.168.1.0/24"]


def test_device_allowed_ips_empty():
"""Test _device_allowed_ips with no IPs or subnets."""
device = _make_device(ipv4=None, ipv6=None, allowed_subnets=[])
ips = _device_allowed_ips(device)
assert ips == []


# --- Events (WG disabled) ---


Expand Down Expand Up @@ -55,6 +100,39 @@ async def test_on_device_created_handles_wg_error(mock_wg, mock_fw, mock_setting
await on_device_created(device)


@patch("wiregui.services.events.get_settings")
@patch("wiregui.services.events.firewall")
@patch("wiregui.services.events.wireguard")
async def test_on_device_created_with_relay_subnets(mock_wg, mock_fw, mock_settings):
"""Test that device creation with relay subnets passes correct allowed_ips to WireGuard, adds routes, and configures firewall."""
mock_settings.return_value.wg_enabled = True
mock_wg.add_peer = AsyncMock()
mock_wg.add_routes = AsyncMock()
mock_fw.add_user_chain = AsyncMock()
mock_fw.add_device_jump_rule = AsyncMock()

device = _make_device(allowed_subnets=["192.168.1.0/24", "10.20.0.0/16"])
await on_device_created(device)

# Verify WireGuard peer was added with tunnel IPs + relay subnets
mock_wg.add_peer.assert_awaited_once_with(
public_key="pk-test",
allowed_ips=["10.3.2.5/32", "fd00::3:2:5/128", "192.168.1.0/24", "10.20.0.0/16"],
preshared_key="psk-test",
)

# Verify routes were added for relay subnets
mock_wg.add_routes.assert_awaited_once_with(["192.168.1.0/24", "10.20.0.0/16"])

# Verify firewall jump rule was added with relay subnets
mock_fw.add_device_jump_rule.assert_awaited_once_with(
"00000000-0000-0000-0000-000000000000",
"10.3.2.5",
"fd00::3:2:5",
["192.168.1.0/24", "10.20.0.0/16"],
)


# --- Rule events ---


Expand Down
92 changes: 91 additions & 1 deletion tests/test_wireguard_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
set_private_key,
set_listen_port,
configure_interface,
add_routes,
remove_routes,
)


Expand Down Expand Up @@ -111,4 +113,92 @@ async def test_configure_interface_sets_key_and_port(mock_session_cls, mock_run)
args = mock_run.call_args[0][0]
assert args[0:3] == ["wg", "set", "wg-test"]
assert "private-key" in args
assert "listen-port" in args
assert "listen-port" in args


# ========== add_routes ==========


@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_add_routes_ipv4(mock_run):
"""add_routes should call ip route add for IPv4 subnets."""
mock_run.return_value = ""
await add_routes(["192.168.1.0/24", "10.20.0.0/16"], iface="wg-test")

assert mock_run.await_count == 2
calls = [c[0][0] for c in mock_run.call_args_list]
assert calls[0] == ["ip", "-4", "route", "add", "192.168.1.0/24", "dev", "wg-test"]
assert calls[1] == ["ip", "-4", "route", "add", "10.20.0.0/16", "dev", "wg-test"]


@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_add_routes_ipv6(mock_run):
"""add_routes should call ip -6 route add for IPv6 subnets."""
mock_run.return_value = ""
await add_routes(["fd00:1::/64", "fd00:2::/48"], iface="wg-test")

assert mock_run.await_count == 2
calls = [c[0][0] for c in mock_run.call_args_list]
assert calls[0] == ["ip", "-6", "route", "add", "fd00:1::/64", "dev", "wg-test"]
assert calls[1] == ["ip", "-6", "route", "add", "fd00:2::/48", "dev", "wg-test"]


@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_add_routes_mixed(mock_run):
"""add_routes should handle mixed IPv4 and IPv6."""
mock_run.return_value = ""
await add_routes(["192.168.1.0/24", "fd00:1::/64"], iface="wg-test")

assert mock_run.await_count == 2
calls = [c[0][0] for c in mock_run.call_args_list]
assert calls[0] == ["ip", "-4", "route", "add", "192.168.1.0/24", "dev", "wg-test"]
assert calls[1] == ["ip", "-6", "route", "add", "fd00:1::/64", "dev", "wg-test"]


@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_add_routes_empty_list(mock_run):
"""add_routes with empty list should not call ip route."""
await add_routes([], iface="wg-test")
mock_run.assert_not_awaited()


@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_add_routes_already_exists(mock_run):
"""add_routes should not fail if route already exists."""
mock_run.side_effect = RuntimeError("RTNETLINK answers: File exists")
# Should not raise
await add_routes(["192.168.1.0/24"], iface="wg-test")
mock_run.assert_awaited_once()


# ========== remove_routes ==========


@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_remove_routes_ipv4(mock_run):
"""remove_routes should call ip route del for IPv4 subnets."""
mock_run.return_value = ""
await remove_routes(["192.168.1.0/24", "10.20.0.0/16"], iface="wg-test")

assert mock_run.await_count == 2
calls = [c[0][0] for c in mock_run.call_args_list]
assert calls[0] == ["ip", "-4", "route", "del", "192.168.1.0/24", "dev", "wg-test"]
assert calls[1] == ["ip", "-4", "route", "del", "10.20.0.0/16", "dev", "wg-test"]


@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_remove_routes_ipv6(mock_run):
"""remove_routes should call ip -6 route del for IPv6 subnets."""
mock_run.return_value = ""
await remove_routes(["fd00:1::/64"], iface="wg-test")

mock_run.assert_awaited_once_with(["ip", "-6", "route", "del", "fd00:1::/64", "dev", "wg-test"])


@patch("wiregui.services.wireguard._run", new_callable=AsyncMock)
async def test_remove_routes_not_found(mock_run):
"""remove_routes should not fail if route doesn't exist."""
mock_run.side_effect = RuntimeError("RTNETLINK answers: No such process")
# Should not raise
await remove_routes(["192.168.1.0/24"], iface="wg-test")
mock_run.assert_awaited_once()
Loading