@@ -52,7 +52,8 @@ def __init__(
5252 :param reduction: The reduction to be used to reduce the aggregated
5353 result of the modules in ``networks`` to the desired output
5454 dimension. Available reductions include: sum: ``+``, product: ``*``,
55- mean: ``mean``, min: ``min``, max: ``max``. Default is ``+``.
55+ mean: ``mean``, min: ``min``, max: ``max``, identity: "id".
56+ Default is ``+``.
5657 :type reduction: str or Callable
5758 :param bool scale: If ``True``, the final output is scaled before being
5859 returned in the forward pass. Default is ``True``.
@@ -122,19 +123,6 @@ def __init__(
122123 check_consistency (scale , bool )
123124 check_consistency (translation , bool )
124125
125- # check trunk branch nets consistency
126- shapes = []
127- for key , value in networks .items ():
128- check_consistency (value , (str , int ))
129- check_consistency (key , torch .nn .Module )
130- input_ = torch .rand (10 , len (value ))
131- shapes .append (key (input_ ).shape [- 1 ])
132-
133- if not all (map (lambda x : x == shapes [0 ], shapes )):
134- raise ValueError (
135- "The passed networks have not the same output dimension."
136- )
137-
138126 # assign trunk and branch net with their input indeces
139127 self .models = torch .nn .ModuleList (networks .keys ())
140128 self ._indeces = networks .values ()
@@ -171,6 +159,7 @@ def _symbol_functions(**kwargs):
171159 "mean" : partial (torch .mean , ** kwargs ),
172160 "min" : lambda x : torch .min (x , ** kwargs ).values ,
173161 "max" : lambda x : torch .max (x , ** kwargs ).values ,
162+ "id" : lambda x : x ,
174163 }
175164
176165 def _init_aggregator (self , aggregator ):
@@ -181,7 +170,7 @@ def _init_aggregator(self, aggregator):
181170 :type aggregator: str or Callable
182171 :raises ValueError: If the aggregator is not supported.
183172 """
184- aggregator_funcs = self ._symbol_functions (dim = 2 )
173+ aggregator_funcs = self ._symbol_functions (dim = - 1 )
185174 if aggregator in aggregator_funcs :
186175 aggregator_func = aggregator_funcs [aggregator ]
187176 elif isinstance (aggregator , nn .Module ) or is_function (aggregator ):
@@ -264,13 +253,9 @@ def forward(self, x):
264253 # reduce
265254 output_ = self ._reduction (aggregated )
266255 if self ._reduction_type in self ._symbol_functions (dim = - 1 ):
267- output_ = output_ .reshape (- 1 , 1 )
268-
269- # scale and translate
270- output_ *= self ._scale
271- output_ += self ._trasl
256+ output_ = output_ .reshape (* output_ .shape , 1 )
272257
273- return output_
258+ return self . _scale * output_ + self . _trasl
274259
275260 @property
276261 def aggregator (self ):
0 commit comments