Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions sky/skypilot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ def __init__(self,

SKYPILOT_CONFIG_LOCK_PATH = '~/.sky/locks/.skypilot_config.lock'

_WARNED_DISALLOWED_KEYS_CACHE: typing.Set[Tuple[str, Tuple[str, ...]]] = set()
_WARNED_DISALLOWED_KEYS_CACHE_LOCK = threading.Lock()
_WARNED_DISALLOWED_KEYS_CACHE_SIZE = 1000


def get_skypilot_config_lock_path() -> str:
"""Get the path for the SkyPilot config lock file."""
Expand Down Expand Up @@ -731,12 +735,25 @@ def override_skypilot_config(
# Only warn if there is a diff in disallowed override keys, as the client
# use the same config file when connecting to a local server.
if disallowed_diff_keys:
logger.warning(
f'The following keys ({json.dumps(disallowed_diff_keys)}) have '
'different values in the client SkyPilot config with the server '
'and will be ignored. Remove these keys to disable this warning. '
'If you want to specify it, please modify it on server side or '
'contact your administrator.')
run_id = common_utils.get_usage_run_id()
should_warn = True
if run_id:
cache_key = (run_id, tuple(sorted(disallowed_diff_keys)))
with _WARNED_DISALLOWED_KEYS_CACHE_LOCK:
should_warn = cache_key not in _WARNED_DISALLOWED_KEYS_CACHE
if should_warn:
if (len(_WARNED_DISALLOWED_KEYS_CACHE) >
_WARNED_DISALLOWED_KEYS_CACHE_SIZE):
_WARNED_DISALLOWED_KEYS_CACHE.clear()
_WARNED_DISALLOWED_KEYS_CACHE.add(cache_key)

if should_warn:
logger.warning(
f'The following keys ({json.dumps(disallowed_diff_keys)}) have '
'different values in the client SkyPilot config with the '
'server and will be ignored. Remove these keys to disable '
'this warning. If you want to specify it, please modify it '
'on server side or contact your administrator.')
config = original_config.get_nested(
keys=tuple(),
default_value=None,
Expand Down
153 changes: 153 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,159 @@ def test_override_skypilot_config_with_disallowed_keys(monkeypatch, tmp_path):
'your administrator.')


@mock.patch('sky.skylet.constants.SKIPPED_CLIENT_OVERRIDE_KEYS',
[('aws', 'vpc_name'), ('aws', 'ssh_proxy_command')])
def test_override_skypilot_config_warning_deduplication(monkeypatch, tmp_path):
"""Test that warning messages are deduplicated per run_id."""
from sky.usage import constants as usage_constants

# Clear the cache before testing
skypilot_config._WARNED_DISALLOWED_KEYS_CACHE.clear()

with mock.patch('sky.skypilot_config.logger') as mock_logger:
mock_logger.getEffectiveLevel.return_value = INFO
os.environ.pop(skypilot_config.ENV_VAR_SKYPILOT_CONFIG, None)
# Create original config file
config_path = tmp_path / 'config.yaml'
_create_config_file(config_path)
monkeypatch.setattr(skypilot_config, '_GLOBAL_CONFIG_PATH', config_path)
skypilot_config.safe_reload_config()

# Verify config is loaded and has expected value
assert skypilot_config.get_nested(('aws', 'vpc_name'), None) == VPC_NAME

run_id_1 = 'test-run-id-1'
run_id_2 = 'test-run-id-2'

# Test 1: Warning appears once per run_id
annotations.clear_request_level_cache()
monkeypatch.setenv(usage_constants.USAGE_RUN_ID_ENV_VAR, run_id_1)
override_configs_1 = {'aws': {'vpc_name': 'override-vpc',}}
with skypilot_config.override_skypilot_config(override_configs_1):
pass
# First call should log warning
assert mock_logger.warning.call_count == 1
mock_logger.warning.reset_mock()

# Second call with same run_id should NOT log warning (deduplicated)
override_configs_2 = {'aws': {'vpc_name': 'override-vpc',}}
with skypilot_config.override_skypilot_config(override_configs_2):
pass
mock_logger.warning.assert_not_called()

# Test 2: Warning appears again with different run_id
# Verify cache has the first entry
assert len(skypilot_config._WARNED_DISALLOWED_KEYS_CACHE) == 1
# Verify the cache entry is for run_id_1
cache_entry_run_id_1 = (run_id_1, ('aws.vpc_name',))
assert cache_entry_run_id_1 in skypilot_config._WARNED_DISALLOWED_KEYS_CACHE

# Change run_id and verify it's different
annotations.clear_request_level_cache()
monkeypatch.setenv(usage_constants.USAGE_RUN_ID_ENV_VAR, run_id_2)
assert os.environ.get(usage_constants.USAGE_RUN_ID_ENV_VAR) == run_id_2

# The cache entry for run_id_2 should not exist yet
cache_entry_run_id_2 = (run_id_2, ('aws.vpc_name',))
assert cache_entry_run_id_2 not in skypilot_config._WARNED_DISALLOWED_KEYS_CACHE

# Call with new run_id - should trigger warning
override_configs_3 = {'aws': {'vpc_name': 'override-vpc',}}
with skypilot_config.override_skypilot_config(override_configs_3):
pass
# Should log warning again for new run_id (different cache key)
assert mock_logger.warning.call_count == 1, (
f'Expected 1 warning call, got {mock_logger.warning.call_count}. '
f'Cache size: {len(skypilot_config._WARNED_DISALLOWED_KEYS_CACHE)}. '
f'Cache contents: {list(skypilot_config._WARNED_DISALLOWED_KEYS_CACHE)}'
)
# Cache should now have 2 entries (one per run_id)
assert len(skypilot_config._WARNED_DISALLOWED_KEYS_CACHE) == 2
assert cache_entry_run_id_2 in skypilot_config._WARNED_DISALLOWED_KEYS_CACHE
mock_logger.warning.reset_mock()

# Test 3: Backward compatibility - warning appears when run_id is not set
annotations.clear_request_level_cache()
monkeypatch.delenv(usage_constants.USAGE_RUN_ID_ENV_VAR, raising=False)
override_configs_no_run_id = {'aws': {'vpc_name': 'override-vpc',}}
with skypilot_config.override_skypilot_config(
override_configs_no_run_id):
pass
# Should log warning when run_id is not set
mock_logger.warning.assert_called_once()
mock_logger.warning.reset_mock()

# Test 4: Different disallowed keys get separate warnings
annotations.clear_request_level_cache()
monkeypatch.setenv(usage_constants.USAGE_RUN_ID_ENV_VAR, run_id_1)
override_configs_different_key = {
'aws': {
'ssh_proxy_command': 'override-command',
}
}
with skypilot_config.override_skypilot_config(
override_configs_different_key):
pass
# Should log warning for different keys even with same run_id
mock_logger.warning.assert_called_once()


@mock.patch('sky.skylet.constants.SKIPPED_CLIENT_OVERRIDE_KEYS',
[('aws', 'vpc_name')])
def test_override_skypilot_config_warning_cache_clearing(monkeypatch, tmp_path):
"""Test that warning cache is cleared when it exceeds 1000 entries."""
from sky.usage import constants as usage_constants

# Clear the cache before testing
skypilot_config._WARNED_DISALLOWED_KEYS_CACHE.clear()

with mock.patch('sky.skypilot_config.logger') as mock_logger:
mock_logger.getEffectiveLevel.return_value = INFO
os.environ.pop(skypilot_config.ENV_VAR_SKYPILOT_CONFIG, None)
# Create original config file
config_path = tmp_path / 'config.yaml'
_create_config_file(config_path)
monkeypatch.setattr(skypilot_config, '_GLOBAL_CONFIG_PATH', config_path)
skypilot_config.safe_reload_config()

# Fill cache to just under the limit using different run_ids
# Each run_id creates a unique cache entry
for i in range(1000):
run_id = f'test-run-id-{i}'
annotations.clear_request_level_cache()
monkeypatch.setenv(usage_constants.USAGE_RUN_ID_ENV_VAR, run_id)
# Create fresh dict for each iteration since pop_nested modifies it
override_configs = {'aws': {'vpc_name': 'override-vpc',}}
with skypilot_config.override_skypilot_config(override_configs):
pass

# Verify cache has 1000 entries
assert len(skypilot_config._WARNED_DISALLOWED_KEYS_CACHE) == 1000

# Add one more entry (1001st), which should NOT trigger cache clearing yet
# (clearing happens when cache > 1000, i.e., at 1001)
run_id_1000 = 'test-run-id-1000'
annotations.clear_request_level_cache()
monkeypatch.setenv(usage_constants.USAGE_RUN_ID_ENV_VAR, run_id_1000)
override_configs_1000 = {'aws': {'vpc_name': 'override-vpc',}}
with skypilot_config.override_skypilot_config(override_configs_1000):
pass
# Cache should now have 1001 entries (not cleared yet)
assert len(skypilot_config._WARNED_DISALLOWED_KEYS_CACHE) == 1001

# Add one more entry (1002nd), which should trigger cache clearing
run_id_1001 = 'test-run-id-1001'
annotations.clear_request_level_cache()
monkeypatch.setenv(usage_constants.USAGE_RUN_ID_ENV_VAR, run_id_1001)
override_configs_1001 = {'aws': {'vpc_name': 'override-vpc',}}
with skypilot_config.override_skypilot_config(override_configs_1001):
pass
# Cache should be cleared and only have 1 entry now (the 1002nd entry)
assert len(skypilot_config._WARNED_DISALLOWED_KEYS_CACHE) == 1
assert (run_id_1001, (
'aws.vpc_name',)) in skypilot_config._WARNED_DISALLOWED_KEYS_CACHE


def test_hierarchical_server_config(monkeypatch, tmp_path):
"""Test that hierarchical server config is loaded correctly."""
# prepare a clean test environment
Expand Down
Loading