Skip to content

Commit 4aa3d34

Browse files
committed
deeponet tutorial
1 parent fca3db7 commit 4aa3d34

File tree

9 files changed

+497
-45
lines changed

9 files changed

+497
-45
lines changed

pina/model/deeponet.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(
5252
:param reduction: The reduction to be used to reduce the aggregated
5353
result of the modules in ``networks`` to the desired output
5454
dimension. Available reductions include: sum: ``+``, product: ``*``,
55-
mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
55+
mean: ``mean``, min: ``min``, max: ``max``, identity: "id".
56+
Default is ``+``.
5657
:type reduction: str or Callable
5758
:param bool scale: If ``True``, the final output is scaled before being
5859
returned in the forward pass. Default is ``True``.
@@ -122,19 +123,6 @@ def __init__(
122123
check_consistency(scale, bool)
123124
check_consistency(translation, bool)
124125

125-
# check trunk branch nets consistency
126-
shapes = []
127-
for key, value in networks.items():
128-
check_consistency(value, (str, int))
129-
check_consistency(key, torch.nn.Module)
130-
input_ = torch.rand(10, len(value))
131-
shapes.append(key(input_).shape[-1])
132-
133-
if not all(map(lambda x: x == shapes[0], shapes)):
134-
raise ValueError(
135-
"The passed networks have not the same output dimension."
136-
)
137-
138126
# assign trunk and branch net with their input indeces
139127
self.models = torch.nn.ModuleList(networks.keys())
140128
self._indeces = networks.values()
@@ -171,6 +159,7 @@ def _symbol_functions(**kwargs):
171159
"mean": partial(torch.mean, **kwargs),
172160
"min": lambda x: torch.min(x, **kwargs).values,
173161
"max": lambda x: torch.max(x, **kwargs).values,
162+
"id": lambda x: x,
174163
}
175164

176165
def _init_aggregator(self, aggregator):
@@ -181,7 +170,7 @@ def _init_aggregator(self, aggregator):
181170
:type aggregator: str or Callable
182171
:raises ValueError: If the aggregator is not supported.
183172
"""
184-
aggregator_funcs = self._symbol_functions(dim=2)
173+
aggregator_funcs = self._symbol_functions(dim=-1)
185174
if aggregator in aggregator_funcs:
186175
aggregator_func = aggregator_funcs[aggregator]
187176
elif isinstance(aggregator, nn.Module) or is_function(aggregator):
@@ -264,13 +253,9 @@ def forward(self, x):
264253
# reduce
265254
output_ = self._reduction(aggregated)
266255
if self._reduction_type in self._symbol_functions(dim=-1):
267-
output_ = output_.reshape(-1, 1)
268-
269-
# scale and translate
270-
output_ *= self._scale
271-
output_ += self._trasl
256+
output_ = output_.reshape(*output_.shape, 1)
272257

273-
return output_
258+
return self._scale * output_ + self._trasl
274259

275260
@property
276261
def aggregator(self):

tests/test_model/test_deeponet.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
data = torch.rand((20, 3))
1010
input_vars = ["a", "b", "c"]
1111
input_ = LabelTensor(data, input_vars)
12-
symbol_funcs_red = DeepONet._symbol_functions(dim=-1)
12+
symbol_funcs_red = DeepONet._symbol_functions()
1313
output_dims = [1, 5, 10, 20]
1414

1515

@@ -26,20 +26,6 @@ def test_constructor():
2626
)
2727

2828

29-
def test_constructor_fails_when_invalid_inner_layer_size():
30-
branch_net = FeedForward(input_dimensions=1, output_dimensions=10)
31-
trunk_net = FeedForward(input_dimensions=2, output_dimensions=8)
32-
with pytest.raises(ValueError):
33-
DeepONet(
34-
branch_net=branch_net,
35-
trunk_net=trunk_net,
36-
input_indeces_branch_net=["a"],
37-
input_indeces_trunk_net=["b", "c"],
38-
reduction="+",
39-
aggregator="*",
40-
)
41-
42-
4329
def test_forward_extract_str():
4430
branch_net = FeedForward(input_dimensions=1, output_dimensions=10)
4531
trunk_net = FeedForward(input_dimensions=2, output_dimensions=10)

tests/test_model/test_mionet.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,6 @@ def test_constructor():
1818
MIONet(networks=networks, reduction="+", aggregator="*")
1919

2020

21-
def test_constructor_fails_when_invalid_inner_layer_size():
22-
branch_net1 = FeedForward(input_dimensions=1, output_dimensions=10)
23-
branch_net2 = FeedForward(input_dimensions=2, output_dimensions=10)
24-
trunk_net = FeedForward(input_dimensions=1, output_dimensions=12)
25-
networks = {branch_net1: ["x"], branch_net2: ["x", "y"], trunk_net: ["z"]}
26-
with pytest.raises(ValueError):
27-
MIONet(networks=networks, reduction="+", aggregator="*")
28-
29-
3021
def test_forward_extract_str():
3122
branch_net1 = FeedForward(input_dimensions=1, output_dimensions=10)
3223
branch_net2 = FeedForward(input_dimensions=1, output_dimensions=10)

tutorials/static/deeponet.png

351 KB
Loading
392 KB
Binary file not shown.
40.8 KB
Binary file not shown.
392 KB
Binary file not shown.
40.8 KB
Binary file not shown.

tutorials/tutorial24/tutorial.ipynb

Lines changed: 490 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)