Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ class CapabilitiesConstants(BaseModel):

MAGNITUDE_PATTERN_VALUE_MIN: Decimal
MAGNITUDE_PATTERN_VALUE_MAX: Decimal
MAX_NET_DETUNING: Decimal
MAX_NET_DETUNING: Decimal | None
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def validate_net_detuning_with_warning(
program (Program): The given program
time_points (np.ndarray): The time points for both global and local detunings
global_detuning_coefs (np.ndarray): The values of global detuning
local_detuning_patterns (List): The pattern of local detuning
local_detuning_patterns (list): The pattern of local detuning
local_detuning_coefs (np.ndarray): The values of local detuning
capabilities (CapabilitiesConstants): The capability constants

Expand Down Expand Up @@ -102,3 +102,99 @@ def validate_net_detuning_with_warning(
# Return immediately if there is an atom has net detuning
# exceeding MAX_NET_DETUNING at a time point
return program


Copy link
Contributor

@maolinml maolinml Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just want to remind myself, are these four functions the only things that needs to be moved from the service side to the client side? Also it maybe good to comment on the top of each of these functions that they are only for device emulator and not for AHS local simulator [so that people won't be confused that why they are not used in this repo].

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, don't we also need "validate_pattern_precision"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validate_pattern_precision is never used by the Device validators so I chose not to bring it over to the Default Simulator repo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the additional comments makes sense for explaining the use of these helpers!

def validate_time_separation(times: list[Decimal], min_time_separation: Decimal, name: str) -> None:
"""
Used in Device Emulation; Validate that the time points in a time series are separated by at
least min_time_separation.

Args:
times (list[Decimal]): A list of time points in a time series.
min_time_separation (Decimal): The minimal amount of time any two time points should be
separated by.
name (str): The name of the time series, used for logging.

Raises:
ValueError: If any two subsequent time points (assuming the time points are sorted
in ascending order) are separated by less than min_time_separation.
"""
for i in range(len(times) - 1):
time_diff = times[i + 1] - times[i]
if time_diff < min_time_separation:
raise ValueError(
f"Time points of {name} time_series, {i} ({times[i]}) and "
f"{i + 1} ({times[i + 1]}), are too close; they are separated "
f"by {time_diff} seconds. It must be at least {min_time_separation} seconds"
)


def validate_value_precision(values: list[Decimal], max_precision: Decimal, name: str) -> None:
"""
Used in Device Emulation; Validate that the precision of a set of values do not
exceed max_precision.

Args:
times (list[Decimal]): A list of values from a time series to validate.
max_precision (Decimal): The maximum allowed precision.
name (str): The name of the time series, used for logging.

Raises:
ValueError: If any of the given values is defined with precision exceeding max_precision.
"""
for idx, v in enumerate(values):
if v % max_precision != 0:
raise ValueError(
f"Value {idx} ({v}) in {name} time_series is defined with too many digits; "
f"it must be an integer multiple of {max_precision}"
)


def validate_max_absolute_slope(
times: list[Decimal], values: list[Decimal], max_slope: Decimal, name: str
):
"""
Used in Device Emulation; Validate that the magnitude of the slope between any
two subsequent points in a time series (time points provided in ascending order) does not
exceed max_slope.

Args:
times (list[Decimal]): A list of time points in a time series.
max_slope (Decimal): The maximum allowed rate of change between points in the time series.
name (str): The name of the time series, used for logging.

Raises:
ValueError: if at any time the time series (times, values)
rises/falls faster than allowed.
"""
for idx in range(len(values) - 1):
slope = (values[idx + 1] - values[idx]) / (times[idx + 1] - times[idx])
if abs(slope) > max_slope:
raise ValueError(
f"For the {name} field, rate of change of values "
f"(between the {idx}-th and the {idx + 1}-th times) "
f"is {abs(slope)}, more than {max_slope}"
)


def validate_time_precision(times: list[Decimal], time_precision: Decimal, name: str):
"""
Used in Device Emulation; Validate that the precision of a set of time points do not
exceed max_precision.

Args:
times (list[Decimal]): A list of time points to validate.
max_precision (Decimal): The maximum allowed precision.
name (str): The name of the time series, used for logging.

Raises:
ValueError: If any of the given time points is defined with
precision exceeding max_precision.
"""
for idx, t in enumerate(times):
if t % time_precision != 0:
raise ValueError(
f"time point {idx} ({t}) of {name} time_series is "
f"defined with too many digits; it must be an "
f"integer multiple of {time_precision}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def net_detuning_must_not_exceed_max_net_detuning(cls, values):
# If no local detuning, we simply return the values
# because there are separate validators to validate
# the global driving fields in the program
if not len(local_detuning):
if not len(local_detuning) or not capabilities.MAX_NET_DETUNING:
return values

detuning_times = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from decimal import Decimal

import pytest

from braket.analog_hamiltonian_simulator.rydberg.validators.field_validator_util import (
validate_max_absolute_slope,
validate_time_precision,
validate_time_separation,
validate_value_precision,
)


@pytest.mark.parametrize(
"times, min_time_separation, fail",
[
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("2.5"), Decimal("4")],
Decimal("1e-3"),
True,
),
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("2.5"), Decimal("4")],
Decimal("1e-6"),
False,
),
(
[Decimal("0.0"), Decimal("1"), Decimal("2"), Decimal("3"), Decimal("4")],
Decimal("1e-3"),
False,
),
],
)
def test_validate_time_separation(times, min_time_separation, fail):
if fail:
with pytest.raises(ValueError):
validate_time_separation(times, min_time_separation, "test")
else:
try:
validate_time_separation(times, min_time_separation, "test")
except ValueError as e:
pytest.fail(f"Failed valid validate_min_time_separation: {str(e)}")


@pytest.mark.parametrize(
"values, max_precision, fail",
[
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("2.5"), Decimal("4")],
Decimal("1e-3"),
True,
),
(
[Decimal("0.0"), Decimal("1e-9"), Decimal("2e-5"), Decimal("3e-4"), Decimal("5.0")],
Decimal("1e-6"),
True,
),
(
[
Decimal("0.0"),
Decimal("0.00089"),
Decimal("2e-4"),
Decimal("0.003"),
Decimal("0.21"),
Decimal("1"),
],
Decimal("1e-5"),
False,
),
],
)
def test_validate_value_precision(values, max_precision, fail):
if fail:
with pytest.raises(ValueError):
validate_value_precision(values, max_precision, "test")
else:
try:
validate_value_precision(values, max_precision, "test")
except ValueError as e:
pytest.fail(f"Failed valid validate_value_precision: {str(e)}")


@pytest.mark.parametrize(
"times, values, max_slope, fail",
[
(
[Decimal("0.0"), Decimal("1.0"), Decimal("2.0"), Decimal("3.0")],
[Decimal("0.0"), Decimal("2.1"), Decimal("3.2"), Decimal("3.9")],
Decimal("2.0"),
True,
),
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("3")],
[Decimal("0.0"), Decimal("1.2"), Decimal("2.34"), Decimal("2.39")],
Decimal("1.5e5"),
False,
),
(
[Decimal("0.0"), Decimal("1.0"), Decimal("2e-5"), Decimal("3")],
[Decimal("0.0"), Decimal("1.2"), Decimal("2.34"), Decimal("2.39")],
Decimal("1e4"),
False,
),
],
)
def test_validate_max_absolute_slope(times, values, max_slope, fail):
if fail:
with pytest.raises(ValueError):
validate_max_absolute_slope(times, values, max_slope, "test")
else:
try:
validate_max_absolute_slope(times, values, max_slope, "test")
except ValueError as e:
pytest.fail(f"Failed valid validate_max_absolute_slope: {str(e)}")


@pytest.mark.parametrize(
"times, max_precision, fail",
[
(
[Decimal("0.0"), Decimal("1e-5"), Decimal("2e-5"), Decimal("2.5"), Decimal("4")],
Decimal("1.3"),
True,
),
(
[Decimal("0.0"), Decimal("1e-9"), Decimal("2e-5"), Decimal("3e-4"), Decimal("5.0")],
Decimal("1e-6"),
True,
),
(
[Decimal("0"), Decimal("1e-07"), Decimal("3.9e-06"), Decimal("4e-06")],
Decimal("1e-09"),
False,
),
],
)
def test_validate_time_precision(times, max_precision, fail):
if fail:
with pytest.raises(ValueError):
validate_time_precision(times, max_precision, "test")
else:
try:
validate_time_precision(times, max_precision, "test")
except ValueError as e:
pytest.fail(f"Failed valid validate_min_time_precision: {str(e)}")