Skip to content

Commit 19834b1

Browse files
committed
Add mutable union and difference operations on sets
1 parent 5bfcb1a commit 19834b1

File tree

3 files changed

+142
-0
lines changed

3 files changed

+142
-0
lines changed

docs/new_sets_doc.md

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,29 @@ Returns the union of several sets.
246246
The set union of all sets in `*args`.
247247

248248

249+
<a id="sets.mutable_union"></a>
250+
251+
## sets.mutable_union
252+
253+
<pre>
254+
sets.mutable_union(<a href="#sets.mutable_union-a">a</a>, <a href="#sets.mutable_union-b">b</a>)
255+
</pre>
256+
257+
Modify set `a` adding elements from `b` to it.
258+
259+
**PARAMETERS**
260+
261+
262+
| Name | Description | Default Value |
263+
| :------------- | :------------- | :------------- |
264+
| <a id="sets.mutable_union-a"></a>a | A set, as returned by <code>sets.make()</code>. | none |
265+
| <a id="sets.mutable_union-b"></a>b | A set, as returned by <code>sets.make()</code>. | none |
266+
267+
**RETURNS**
268+
269+
The set `a` with all elements appearing in `b` added to it.
270+
271+
249272
<a id="sets.difference"></a>
250273

251274
## sets.difference
@@ -269,6 +292,29 @@ Returns the elements in `a` that are not in `b`.
269292
A set containing the elements that are in `a` but not in `b`.
270293

271294

295+
<a id="sets.mutable_difference"></a>
296+
297+
## sets.mutable_difference
298+
299+
<pre>
300+
sets.mutable_difference(<a href="#sets.mutable_difference-a">a</a>, <a href="#sets.mutable_difference-b">b</a>)
301+
</pre>
302+
303+
Modify set `a` removing elements from `b` from it.
304+
305+
**PARAMETERS**
306+
307+
308+
| Name | Description | Default Value |
309+
| :------------- | :------------- | :------------- |
310+
| <a id="sets.mutable_difference-a"></a>a | A set, as returned by <code>sets.make()</code>. | none |
311+
| <a id="sets.mutable_difference-b"></a>b | A set, as returned by <code>sets.make()</code>. | none |
312+
313+
**RETURNS**
314+
315+
The set `a` with all elements appearing in `b` removed from it.
316+
317+
272318
<a id="sets.length"></a>
273319

274320
## sets.length

lib/new_sets.bzl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,19 @@ def _union(*args):
189189
"""
190190
return struct(_values = dicts.add(*[s._values for s in args]))
191191

192+
def _mutable_union(a, b):
193+
"""Modify set `a` adding elements from `b` to it.
194+
195+
Args:
196+
a: A set, as returned by `sets.make()`.
197+
b: A set, as returned by `sets.make()`.
198+
199+
Returns:
200+
The set `a` with all elements appearing in `b` added to it.
201+
"""
202+
a._values.update(b._values)
203+
return a
204+
192205
def _difference(a, b):
193206
"""Returns the elements in `a` that are not in `b`.
194207
@@ -201,6 +214,21 @@ def _difference(a, b):
201214
"""
202215
return struct(_values = {e: None for e in a._values.keys() if e not in b._values})
203216

217+
def _mutable_difference(a, b):
218+
"""Modify set `a` removing elements from `b` from it.
219+
220+
Args:
221+
a: A set, as returned by `sets.make()`.
222+
b: A set, as returned by `sets.make()`.
223+
224+
Returns:
225+
The set `a` with all elements appearing in `b` removed from it.
226+
"""
227+
for item in b._values.keys():
228+
if item in a._values:
229+
a._values.pop(item)
230+
return a
231+
204232
def _length(s):
205233
"""Returns the number of elements in a set.
206234
@@ -234,7 +262,9 @@ sets = struct(
234262
disjoint = _disjoint,
235263
intersection = _intersection,
236264
union = _union,
265+
mutable_union = _mutable_union,
237266
difference = _difference,
267+
mutable_difference = _mutable_difference,
238268
length = _length,
239269
remove = _remove,
240270
repr = _repr,

tests/new_sets_tests.bzl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,38 @@ def _union_test(ctx):
114114

115115
union_test = unittest.make(_union_test)
116116

117+
def _mutable_union_test(ctx):
118+
"""Unit tests for sets.union."""
119+
env = unittest.begin(ctx)
120+
121+
s = sets.make()
122+
s = sets.mutable_union(s, sets.make())
123+
asserts.new_set_equals(env, sets.make(), s)
124+
s = sets.make()
125+
s = sets.mutable_union(s, sets.make([1]))
126+
asserts.new_set_equals(env, sets.make([1]), s)
127+
s = sets.make([1])
128+
s = sets.mutable_union(s, sets.make())
129+
asserts.new_set_equals(env, sets.make([1]), s)
130+
s = sets.make([1])
131+
s = sets.mutable_union(s, sets.make([1]))
132+
asserts.new_set_equals(env, sets.make([1]), s)
133+
s = sets.make([1])
134+
s = sets.mutable_union(s, sets.make([1, 2]))
135+
asserts.new_set_equals(env, sets.make([1, 2]), s)
136+
s = sets.make([1])
137+
s = sets.mutable_union(s, sets.make([2]))
138+
asserts.new_set_equals(env, sets.make([1, 2]), s)
139+
140+
# If passing a list, verify that duplicate elements are ignored.
141+
s = sets.make([1, 1])
142+
s = sets.mutable_union(s, sets.make([1, 2]))
143+
asserts.new_set_equals(env, sets.make([1, 2]), s)
144+
145+
return unittest.end(env)
146+
147+
mutable_union_test = unittest.make(_mutable_union_test)
148+
117149
def _difference_test(ctx):
118150
"""Unit tests for sets.difference."""
119151
env = unittest.begin(ctx)
@@ -132,6 +164,38 @@ def _difference_test(ctx):
132164

133165
difference_test = unittest.make(_difference_test)
134166

167+
def _mutable_difference_test(ctx):
168+
"""Unit tests for sets.difference."""
169+
env = unittest.begin(ctx)
170+
171+
s = sets.make()
172+
s = sets.mutable_difference(s, sets.make())
173+
asserts.new_set_equals(env, sets.make(), s)
174+
s = sets.make()
175+
s = sets.mutable_difference(s, sets.make([1]))
176+
asserts.new_set_equals(env, sets.make(), s)
177+
s = sets.make([1])
178+
s = sets.mutable_difference(s, sets.make())
179+
asserts.new_set_equals(env, sets.make([1]), s)
180+
s = sets.make([1])
181+
s = sets.mutable_difference(s, sets.make([1]))
182+
asserts.new_set_equals(env, sets.make(), s)
183+
s = sets.make([1])
184+
s = sets.mutable_difference(s, sets.make([1, 2]))
185+
asserts.new_set_equals(env, sets.make(), s)
186+
s = sets.make([1])
187+
s = sets.mutable_difference(s, sets.make([2]))
188+
asserts.new_set_equals(env, sets.make([1]), s)
189+
190+
# If passing a list, verify that duplicate elements are ignored.
191+
s = sets.make([1, 2])
192+
s = sets.mutable_difference(s, sets.make([1, 1]))
193+
asserts.new_set_equals(env, sets.make([2]), s)
194+
195+
return unittest.end(env)
196+
197+
mutable_difference_test = unittest.make(_mutable_difference_test)
198+
135199
def _to_list_test(ctx):
136200
"""Unit tests for sets.to_list."""
137201
env = unittest.begin(ctx)
@@ -257,7 +321,9 @@ def new_sets_test_suite():
257321
is_equal_test,
258322
is_subset_test,
259323
difference_test,
324+
mutable_difference_test,
260325
union_test,
326+
mutable_union_test,
261327
to_list_test,
262328
make_test,
263329
copy_test,

0 commit comments

Comments
 (0)