Skip to content

Commit 00ceffc

Browse files
committed
added remove.py
1 parent c83acb1 commit 00ceffc

File tree

1 file changed

+219
-0
lines changed

1 file changed

+219
-0
lines changed

tests/remove.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright 2024-2025 Open Quantum Design
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from oqd_compiler_infrastructure.interface import TypeReflectBaseModel
17+
from oqd_compiler_infrastructure.rule import RewriteRule
18+
from oqd_compiler_infrastructure.walk import Post
19+
20+
21+
# AST data structures (same as before)
22+
class Expression(TypeReflectBaseModel):
23+
"""Base class for arithmetic expressions.
24+
25+
This class serves as the foundation for all expression types in the AST.
26+
"""
27+
28+
pass
29+
30+
31+
class Number(Expression):
32+
"""Represents a numeric literal.
33+
34+
Attributes:
35+
value (float): The numeric value of the literal.
36+
"""
37+
38+
value: float
39+
40+
41+
class Variable(Expression):
42+
"""Represents a variable in an expression.
43+
44+
Attributes:
45+
name (str): The name of the variable.
46+
"""
47+
48+
name: str
49+
50+
51+
class BinaryOp(Expression):
52+
"""Represents a binary operation.
53+
54+
Attributes:
55+
op (str): The operator (e.g., '+', '-', '*', '/').
56+
left (Expression): The left operand.
57+
right (Expression): The right operand.
58+
"""
59+
60+
op: str # '+', '-', '*', '/'
61+
left: Expression
62+
right: Expression
63+
64+
65+
class AdvancedAlgebraicSimplifier(RewriteRule):
66+
"""Applies advanced algebraic simplification rules.
67+
68+
Rules implemented:
69+
- x - x = 0
70+
- x + (-x) = 0
71+
- x * x = x^2
72+
- (x + y) - y = x
73+
- Distributive law: a * (b + c) = (a * b) + (a * c)
74+
"""
75+
76+
# Implements additional algebraic identities like subtraction and distribution
77+
78+
def map_BinaryOp(self, model):
79+
"""Apply advanced simplification to binary operations.
80+
81+
Args:
82+
model (BinaryOp): The binary operation to simplify.
83+
84+
Returns:
85+
Expression: The simplified expression or original if no rule applies.
86+
"""
87+
# x - x = 0
88+
# Rule: subtracting identical terms yields zero
89+
if model.op == "-" and self._expressions_equal(model.left, model.right):
90+
return Number(value=0)
91+
92+
# x + (-x) = 0
93+
# Rule: x + (-1 * x) => 0
94+
if (
95+
model.op == "+"
96+
and isinstance(model.right, BinaryOp)
97+
and model.right.op == "*"
98+
and isinstance(model.right.left, Number)
99+
and model.right.left.value == -1
100+
and self._expressions_equal(model.left, model.right.right)
101+
):
102+
return Number(value=0)
103+
104+
# Distributive law: a * (b + c) = (a * b) + (a * c)
105+
# Rule: a * (b + c) => (a*b) + (a*c)
106+
if (
107+
model.op == "*"
108+
and isinstance(model.right, BinaryOp)
109+
and model.right.op in ["+", "-"]
110+
):
111+
# a * (b + c) -> (a * b) + (a * c)
112+
return BinaryOp(
113+
op=model.right.op,
114+
left=BinaryOp(op="*", left=model.left, right=model.right.left),
115+
right=BinaryOp(op="*", left=model.left, right=model.right.right),
116+
)
117+
118+
# (x + y) - y = x
119+
# Rule: (x + y) - y => x
120+
if (
121+
model.op == "-"
122+
and isinstance(model.left, BinaryOp)
123+
and model.left.op == "+"
124+
and self._expressions_equal(model.left.right, model.right)
125+
):
126+
return model.left.left
127+
128+
return model
129+
130+
def _expressions_equal(self, expr1, expr2):
131+
"""Check if two expressions are structurally equal.
132+
133+
Args:
134+
expr1 (Expression): The first expression to compare.
135+
expr2 (Expression): The second expression to compare.
136+
137+
Returns:
138+
bool: True if structurally equal, False otherwise.
139+
"""
140+
# Compare types and recursively compare sub-expressions
141+
if not isinstance(expr1, type(expr2)):
142+
return False
143+
144+
if isinstance(expr1, Number):
145+
return expr1.value == expr2.value
146+
147+
if isinstance(expr1, Variable):
148+
return expr1.name == expr2.name
149+
150+
if isinstance(expr1, BinaryOp):
151+
return (
152+
expr1.op == expr2.op
153+
and self._expressions_equal(expr1.left, expr2.left)
154+
and self._expressions_equal(expr1.right, expr2.right)
155+
)
156+
157+
return False
158+
159+
160+
def print_expr(expr):
161+
"""Convert an expression into a readable string.
162+
163+
Args:
164+
expr (Expression): The expression to format.
165+
166+
Returns:
167+
str: A string representation of the expression.
168+
"""
169+
# Convert AST nodes into parenthesized infix notation
170+
if isinstance(expr, Number):
171+
return str(expr.value)
172+
elif isinstance(expr, Variable):
173+
return expr.name
174+
elif isinstance(expr, BinaryOp):
175+
return f"({print_expr(expr.left)} {expr.op} {print_expr(expr.right)})"
176+
return str(expr)
177+
178+
179+
def main():
180+
"""Main function to demonstrate advanced algebraic simplification."""
181+
# Prepare test cases and run the AdvancedAlgebraicSimplifier
182+
# Create test expressions
183+
test_cases = [
184+
# x - x = 0
185+
BinaryOp(op="-", left=Variable(name="x"), right=Variable(name="x")),
186+
# x + (-1 * x) = 0
187+
BinaryOp(
188+
op="+",
189+
left=Variable(name="x"),
190+
right=BinaryOp(op="*", left=Number(value=-1), right=Variable(name="x")),
191+
),
192+
# a * (b + c) -> (a * b) + (a * c)
193+
BinaryOp(
194+
op="*",
195+
left=Variable(name="a"),
196+
right=BinaryOp(op="+", left=Variable(name="b"), right=Variable(name="c")),
197+
),
198+
# (x + y) - y = x
199+
BinaryOp(
200+
op="-",
201+
left=BinaryOp(op="+", left=Variable(name="x"), right=Variable(name="y")),
202+
right=Variable(name="y"),
203+
),
204+
]
205+
206+
# Create simplifier with Post traversal
207+
simplifier = Post(AdvancedAlgebraicSimplifier())
208+
209+
# Run simplifications
210+
print("Advanced Algebraic Simplifications:")
211+
for i, expr in enumerate(test_cases, 1):
212+
print(f"\nTest Case {i}:")
213+
print(f"Original: {print_expr(expr)}")
214+
result = simplifier(expr)
215+
print(f"Simplified: {print_expr(result)}")
216+
217+
218+
if __name__ == "__main__":
219+
main()

0 commit comments

Comments
 (0)