3232#include  "ompi/mca/part/persist_aggregated/part_persist_aggregated_sendreq.h" 
3333#include  "ompi/mca/part/persist_aggregated/part_persist_aggregated_recvreq.h" 
3434
35+ #include  "ompi/mca/part/persist_aggregated/schemes/part_persist_aggregated_scheme_regular.h" 
36+ 
3537static  int  mca_part_persist_aggregated_progress (void );
3638static  int  mca_part_persist_aggregated_precv_init (void  * , size_t , size_t , ompi_datatype_t  * , int , int , struct  ompi_communicator_t  * , struct  ompi_info_t  * , struct  ompi_request_t  * * );
3739static  int  mca_part_persist_aggregated_psend_init (const  void * , size_t , size_t , ompi_datatype_t * , int , int , ompi_communicator_t * , struct  ompi_info_t  * , ompi_request_t * * );
@@ -49,6 +51,73 @@ ompi_part_persist_aggregated_t ompi_part_persist_aggregated = {
4951    }
5052};
5153
54+ /** 
55+  * @brief selects an internal partitioning based on the user-provided partitioning 
56+  * and the mca parameters for minimal partition size and maximal partition count. 
57+  * 
58+  * More precisely, given a partitioning into p partitions of size s, computes 
59+  * an internal partitioning into p' partitions of size s' (apart from the last one, 
60+  * which has potentially different size r * s): 
61+  *      p * s = (p' - 1) * s' + r * s 
62+  * where 
63+  *      s' >= s 
64+  *      p' <= p 
65+  *      0 < r * s <= s' 
66+  * and 
67+  *      s' <= max_message_count 
68+  *      p' >= min_message_size 
69+  * (given by mca parameters). 
70+  * 
71+  * @param[in]  partitions           number of user-provided partitions 
72+  * @param[in]  count                size of user-provided partitions in elements 
73+  * @param[out] internal_partitions  number of internal partitions 
74+  * @param[out] factor               number of public partitions corresponding to each internal 
75+  * partitions other than the last one 
76+  * @param[out] last_size            number of public partitions corresponding to the last internal 
77+  * partition 
78+  */ 
79+ static  inline  void  
80+ part_persist_aggregated_select_internal_partitioning (size_t  partitions ,
81+                                                      size_t  part_size ,
82+                                                      size_t  * internal_partitions ,
83+                                                      size_t  * factor ,
84+                                                      size_t  * remainder )
85+ {
86+     size_t  buffer_size  =  partitions  *  part_size ;
87+     size_t  min_part_size  =  ompi_part_persist_aggregated .min_message_size ;
88+     size_t  max_part_count  =  ompi_part_persist_aggregated .max_message_count ;
89+ 
90+     // check if max_part_count imposes higher lower bound on partition size 
91+     if  (max_part_count  >  0  &&  (buffer_size  / max_part_count ) >  min_part_size ) {
92+         min_part_size  =  buffer_size  / max_part_count ;
93+     }
94+ 
95+     // cannot have partitions larger than buffer size 
96+     if  (min_part_size  >  buffer_size ) {
97+         min_part_size  =  buffer_size ;
98+     }
99+ 
100+     if  (part_size  <  min_part_size ) { 
101+         // have to use larger partititions 
102+         // solve p = (p' - 1) * a + r for a (factor) and r (remainder) 
103+         * factor  =  min_part_size  / part_size ;
104+         * internal_partitions  =  partitions  / * factor ;
105+         * remainder  =  partitions  % (* internal_partitions );
106+ 
107+         if  (* remainder  ==  0 ) {  // size of last partition must be set 
108+             * remainder  =  * factor ;
109+         } else  {                
110+             // number of partitions was floored, so add 1 for last (smaller) partition 
111+             * internal_partitions  +=  1 ;
112+         }
113+     } else  { 
114+         // can keep original partitioning 
115+         * internal_partitions  =  partitions ;
116+         * factor  =  1 ;
117+         * remainder  =  1 ;
118+     }
119+ }
120+ 
52121/** 
53122 * This is a helper function that frees a request. This requires ompi_part_persist_aggregated.lock be held before calling. 
54123 */ 
@@ -59,6 +128,12 @@ mca_part_persist_aggregated_free_req(struct mca_part_persist_aggregated_request_
59128    size_t  i ;
60129    opal_list_remove_item (ompi_part_persist_aggregated .progress_list , (opal_list_item_t * )req -> progress_elem );
61130    OBJ_RELEASE (req -> progress_elem );
131+  
132+     // if on sender side, free aggregation state 
133+     if  (MCA_PART_PERSIST_AGGREGATED_REQUEST_PSEND  ==  req -> req_type ) {
134+         mca_part_persist_aggregated_psend_request_t  * sendreq  =  (mca_part_persist_aggregated_psend_request_t  * ) req ;
135+         part_persist_aggregate_regular_free (& sendreq -> aggregation_state );
136+     }
62137
63138    for (i  =  0 ; i  <  req -> real_parts ; i ++ ) {
64139        ompi_request_free (& (req -> persist_reqs [i ]));
@@ -187,17 +262,21 @@ mca_part_persist_aggregated_progress(void)
187262
188263                    /* Set up persistent sends */ 
189264                    req -> persist_reqs  =  (ompi_request_t * * ) malloc (sizeof (ompi_request_t * )* (req -> real_parts ));
190-                     for (i  =  0 ; i  <  req -> real_parts ; i ++ ) {
265+                     for (i  =  0 ; i  <  req -> real_parts   -   1 ; i ++ ) {
191266                         void  * buf  =  ((void * ) (((char * )req -> req_addr ) +  (bytes  *  i )));
192267                         err  =  MCA_PML_CALL (isend_init (buf , req -> real_count , req -> req_datatype , req -> world_peer , req -> my_send_tag + i , MCA_PML_BASE_SEND_STANDARD , ompi_part_persist_aggregated .part_comm , & (req -> persist_reqs [i ])));
193268                    }
269+                     // last transfer partition can have different size 
270+                     void  * buf  =  ((void * ) (((char * )req -> req_addr ) +  (bytes  *  i )));
271+                     err  =  MCA_PML_CALL (isend_init (buf , req -> real_remainder , req -> req_datatype , req -> world_peer , req -> my_send_tag + i , MCA_PML_BASE_SEND_STANDARD , ompi_part_persist_aggregated .part_comm , & (req -> persist_reqs [i ])));
194272                } else  {
195273                    /* parse message */ 
196-                     req -> world_peer    =  req -> setup_info [1 ].world_rank ;
197-                     req -> my_send_tag   =  req -> setup_info [1 ].start_tag ;
198-                     req -> my_recv_tag   =  req -> setup_info [1 ].setup_tag ;
199-                     req -> real_parts    =  req -> setup_info [1 ].num_parts ;
200-                     req -> real_count    =  req -> setup_info [1 ].count ;
274+                     req -> world_peer      =  req -> setup_info [1 ].world_rank ;
275+                     req -> my_send_tag     =  req -> setup_info [1 ].start_tag ;
276+                     req -> my_recv_tag     =  req -> setup_info [1 ].setup_tag ;
277+                     req -> real_parts      =  req -> setup_info [1 ].num_parts ;
278+                     req -> real_count      =  req -> setup_info [1 ].count ;
279+                     req -> real_remainder  =  req -> setup_info [1 ].remainder ;
201280
202281                    err  =  opal_datatype_type_size (& (req -> req_datatype -> super ), & dt_size_ );
203282                    if (OMPI_SUCCESS  !=  err ) return  OMPI_ERROR ;
@@ -207,10 +286,14 @@ mca_part_persist_aggregated_progress(void)
207286                    /* Set up persistent sends */ 
208287                    req -> persist_reqs  =  (ompi_request_t * * ) malloc (sizeof (ompi_request_t * )* (req -> real_parts ));
209288                    req -> flags  =  (int * ) calloc (req -> real_parts ,sizeof (int ));
210-                     for (i  =  0 ; i  <  req -> real_parts ; i ++ ) {
289+                     for (i  =  0 ; i  <  req -> real_parts   -   1 ; i ++ ) {
211290                         void  * buf  =  ((void * ) (((char * )req -> req_addr ) +  (bytes  *  i )));
212291                         err  =  MCA_PML_CALL (irecv_init (buf , req -> real_count , req -> req_datatype , req -> world_peer , req -> my_send_tag + i , ompi_part_persist_aggregated .part_comm , & (req -> persist_reqs [i ])));
213292                    }
293+                     // last transfer partition can have different size 
294+                     void  * buf  =  ((void * ) (((char * )req -> req_addr ) +  (bytes  *  i )));
295+                     err  =  MCA_PML_CALL (irecv_init (buf , req -> real_remainder , req -> req_datatype , req -> world_peer , req -> my_send_tag + i , ompi_part_persist_aggregated .part_comm , & (req -> persist_reqs [i ])));
296+ 
214297                    err  =  req -> persist_reqs [0 ]-> req_start (req -> real_parts , (& (req -> persist_reqs [0 ])));
215298
216299                    /* Send back a message */ 
@@ -373,19 +456,26 @@ mca_part_persist_aggregated_psend_init(const void* buf,
373456    dt_size  =  (dt_size_  >  (size_t ) INT_MAX ) ? MPI_UNDEFINED  : (int ) dt_size_ ;
374457    req -> req_bytes  =  parts  *  count  *  dt_size ;
375458
459+     // select internal partitioning (i.e. real_parts) here 
460+     size_t  factor , remaining_partitions ;
461+     part_persist_aggregated_select_internal_partitioning (parts , count , & req -> real_parts , & factor , & remaining_partitions );
462+ 
463+     req -> real_remainder  =  remaining_partitions  *  count ;     // convert to number of elements 
464+     req -> real_count  =  factor  *  count ;
465+     req -> setup_info [0 ].num_parts  =  req -> real_parts ;         // setup info has to contain internal partitioning 
466+     req -> setup_info [0 ].count      =  req -> real_count ;
467+     req -> setup_info [0 ].remainder  =  req -> real_remainder ;
468+     opal_output_verbose (5 , ompi_part_base_framework .framework_output , "mapped given %lu*%lu partitioning to internal partitioning of %lu*%lu + %lu\n" , parts , count , req -> real_parts  -  1 , req -> real_count , req -> real_remainder );
376469
470+     // init aggregation state 
471+     part_persist_aggregate_regular_init (& sendreq -> aggregation_state , req -> real_parts , factor , remaining_partitions );
377472
378473    /* non-blocking send set-up data */ 
379474    req -> setup_info [0 ].world_rank  =  ompi_comm_rank (& ompi_mpi_comm_world .comm );
380475    req -> setup_info [0 ].start_tag  =  ompi_part_persist_aggregated .next_send_tag ; ompi_part_persist_aggregated .next_send_tag  +=  parts ;
381476    req -> my_send_tag  =  req -> setup_info [0 ].start_tag ;
382477    req -> setup_info [0 ].setup_tag  =  ompi_part_persist_aggregated .next_recv_tag ; ompi_part_persist_aggregated .next_recv_tag ++ ;
383478    req -> my_recv_tag  =  req -> setup_info [0 ].setup_tag ;
384-     req -> setup_info [0 ].num_parts  =  parts ;
385-     req -> real_parts  =  parts ;
386-     req -> setup_info [0 ].count  =  count ;
387-     req -> real_count  =  count ;
388- 
389479
390480    req -> flags  =  (int * ) calloc (req -> real_parts , sizeof (int ));
391481
@@ -428,6 +518,13 @@ mca_part_persist_aggregated_start(size_t count, ompi_request_t** requests)
428518
429519    for (size_t  i  =  0 ; i  <  _count  &&  OMPI_SUCCESS  ==  err ; i ++ ) {
430520        mca_part_persist_aggregated_request_t  * req  =  (mca_part_persist_aggregated_request_t  * )(requests [i ]);
521+ 
522+         // reset aggregation state here 
523+         if  (MCA_PART_PERSIST_AGGREGATED_REQUEST_PSEND  ==  req -> req_type ) {
524+             mca_part_persist_aggregated_psend_request_t  * sendreq  =  (mca_part_persist_aggregated_psend_request_t  * )(req );
525+             part_persist_aggregate_regular_reset (& sendreq -> aggregation_state );
526+         }
527+ 
431528        /* First use is a special case, to support lazy initialization */ 
432529        if (false ==  req -> first_send )
433530        {
@@ -470,19 +567,31 @@ mca_part_persist_aggregated_pready(size_t min_part,
470567    size_t  i ;
471568
472569    mca_part_persist_aggregated_request_t  * req  =  (mca_part_persist_aggregated_request_t  * )(request );
570+     int  flag_value ;
473571    if (true ==  req -> initialized )
474572    {
475-         err  =  req -> persist_reqs [min_part ]-> req_start (max_part - min_part + 1 , (& (req -> persist_reqs [min_part ])));
476-         for (i  =  min_part ; i  <= max_part  &&  OMPI_SUCCESS  ==  err ; i ++ ) {
477-             req -> flags [i ] =  0 ; /* Mark partition as ready for testing */ 
478-         }
573+         flag_value  =  0 ;     /* Mark partition as ready for testing */ 
479574    }
480575    else 
481576    {
482-         for (i  =  min_part ; i  <= max_part  &&  OMPI_SUCCESS  ==  err ; i ++ ) {
483-             req -> flags [i ] =  -2 ; /* Mark partition as queued */ 
577+         flag_value  =  -2 ;    /* Mark partition as queued */ 
578+     }
579+ 
580+     mca_part_persist_aggregated_psend_request_t  * sendreq  =  (mca_part_persist_aggregated_psend_request_t  * )(request );
581+     int  internal_part_ready ;
582+     for (i  =  min_part ; i  <= max_part  &&  OMPI_SUCCESS  ==  err ; i ++ ) {
583+         part_persist_aggregate_regular_pready (& sendreq -> aggregation_state , i , & internal_part_ready );
584+ 
585+         if  (-1  !=  internal_part_ready ) {
586+             // transfer partition is ready 
587+             if (true ==  req -> initialized ) {
588+                 err  =  req -> persist_reqs [internal_part_ready ]-> req_start (1 , (& (req -> persist_reqs [internal_part_ready ])));
589+             }
590+     
591+             req -> flags [internal_part_ready ] =  flag_value ;
484592        }
485593    }
594+ 
486595    return  err ;
487596}
488597
0 commit comments