diff --git a/src/braket/default_simulator/openqasm/_helpers/functions.py b/src/braket/default_simulator/openqasm/_helpers/functions.py index 002540b9..311b51d1 100644 --- a/src/braket/default_simulator/openqasm/_helpers/functions.py +++ b/src/braket/default_simulator/openqasm/_helpers/functions.py @@ -99,10 +99,16 @@ [BooleanLiteral(xv.value ^ yv.value) for xv, yv in zip(x.values, y.values)] ), getattr(BinaryOperator, "<<"): lambda x, y: ArrayLiteral( - x.values[y.value :] + [BooleanLiteral(False) for _ in range(y.value)] + x.values[len(y.values) :] + [BooleanLiteral(False) for _ in range(len(y.values))] + if isinstance(y, ArrayLiteral) + else x.values[y.value :] + [BooleanLiteral(False) for _ in range(y.value)] ), getattr(BinaryOperator, ">>"): lambda x, y: ArrayLiteral( - [BooleanLiteral(False) for _ in range(y.value)] + x.values[: len(x.values) - y.value] + [BooleanLiteral(False) for _ in range(len(y.values))] + + x.values[: len(x.values) - len(y.values)] + if isinstance(y, ArrayLiteral) + else [BooleanLiteral(False) for _ in range(y.value)] + + x.values[: len(x.values) - y.value] ), getattr(UnaryOperator, "~"): lambda x: ArrayLiteral( [BooleanLiteral(not v.value) for v in x.values] diff --git a/src/braket/default_simulator/openqasm/interpreter.py b/src/braket/default_simulator/openqasm/interpreter.py index 5503b11a..dc6a106b 100644 --- a/src/braket/default_simulator/openqasm/interpreter.py +++ b/src/braket/default_simulator/openqasm/interpreter.py @@ -248,8 +248,13 @@ def _(self, node: Identifier) -> LiteralType: @visit.register def _(self, node: QubitDeclaration) -> None: - size = self.visit(node.size).value if node.size else 1 - self.context.add_qubits(node.qubit.name, size) + size_arg = self.visit(node.size) + if isinstance(size_arg, ArrayLiteral) and size_arg: + size = "".join(str(cast_to(IntegerLiteral, qubit).value) for qubit in size_arg.values) + self.context.add_qubits(node.qubit.name, int(size, 2)) + else: + size = size_arg.value if size_arg else 1 + self.context.add_qubits(node.qubit.name, size) @visit.register def _(self, node: QuantumReset) -> None: diff --git a/test/unit_tests/braket/default_simulator/openqasm/test_interpreter.py b/test/unit_tests/braket/default_simulator/openqasm/test_interpreter.py index 20b9e2f4..28d49f5f 100644 --- a/test/unit_tests/braket/default_simulator/openqasm/test_interpreter.py +++ b/test/unit_tests/braket/default_simulator/openqasm/test_interpreter.py @@ -2245,6 +2245,80 @@ def test_measure_qubit_out_of_range(qasm, expected): Interpreter().build_circuit(qasm) +@pytest.mark.parametrize( + "qasm, expected", + [ + ( + """ + bit[2] b; + qubit["10"] r1; + b = measure r1; + """, + [0, 1], + ), + ( + """ + bit[3] b; + qubit["11"] r1; + b = measure r1; + """, + [0, 1, 2], + ), + ( + """ + bit[1] b; + qubit[!"1"] r1; + b = measure r1; + """, + [], + ), + ( + """ + qubit["1" ^ "0"] r1; + """, + [], + ), + ( + """ + bit[1] b; + qubit["1" != "0"] r1; + b = measure r1; + """, + [0], + ), + ( + """ + bit[1] b; + qubit["1" == "0"] r1; + b = measure r1; + """, + [], + ), + ( + """ + bit[1] b; + qubit[1] r1; + h r1["0" << "1"]; + b = measure r1; + """, + [0], + ), + ( + """ + bit[2] b; + qubit[1] r1; + h r1["0" >> "1"]; + b = measure r1; + """, + [0], + ), + ], +) +def test_circuit_from_string_literal(qasm, expected): + circ = Interpreter().build_circuit(source=qasm) + assert expected == circ.measured_qubits + + @pytest.mark.parametrize( "qasm,error_message", [