Skip to content

Commit 08f25e1

Browse files
committed
update tests to run in paralell. do gymnastics for warnings
1 parent 6326702 commit 08f25e1

File tree

1 file changed

+44
-21
lines changed

1 file changed

+44
-21
lines changed

pyuvsim/tests/test_uvsim.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,7 @@ def test_run_mpierr():
918918
pyuvsim.run_uvdata_uvsim(UVData(), ['beamlist'])
919919

920920

921+
@pytest.mark.parallel(2)
921922
@pytest.mark.parametrize("order", [("bda",), ("baseline", "time"), ("ant2", "time")])
922923
def test_ordering(uvdata_two_redundant_bls_triangle_sources, order):
923924
pytest.importorskip('mpi4py')
@@ -931,41 +932,62 @@ def test_ordering(uvdata_two_redundant_bls_triangle_sources, order):
931932
beam_dict=beam_dict,
932933
catalog=sky_model,
933934
)
934-
assert out_uv.blt_order == order
935-
assert out_uv.blt_order == uvdata_linear.blt_order
935+
rank = pyuvsim.mpi.get_rank()
936+
if rank == 0:
937+
print(rank, out_uv)
938+
assert out_uv.blt_order == order
939+
assert out_uv.blt_order == uvdata_linear.blt_order
936940

937-
uvdata_linear.data_array = out_uv.data_array
941+
uvdata_linear.data_array = out_uv.data_array
938942

939-
uvdata_linear.reorder_blts(order="time", minor_order="baseline", conj_convention="ant1<ant2")
943+
uvdata_linear.reorder_blts(
944+
order="time", minor_order="baseline", conj_convention="ant1<ant2"
945+
)
940946

941-
assert np.allclose(
942-
uvdata_linear.get_data((0, 1)), uvdata_linear.get_data((1, 2))
943-
)
944-
assert not np.allclose(
945-
uvdata_linear.get_data((0, 1)), uvdata_linear.get_data((0, 2))
946-
)
947+
assert np.allclose(
948+
uvdata_linear.get_data((0, 1)), uvdata_linear.get_data((1, 2))
949+
)
950+
assert not np.allclose(
951+
uvdata_linear.get_data((0, 1)), uvdata_linear.get_data((0, 2))
952+
)
947953

948954

955+
@pytest.mark.parallel(2)
949956
@pytest.mark.parametrize("order", [("bda",), ("baseline", "time"), ("ant2", "time")])
950957
def test_order_warning(uvdata_two_redundant_bls_triangle_sources, order):
951958
pytest.importorskip('mpi4py')
959+
# need to get the mpi initialized
960+
# now that simulations require at least 2 PUs
961+
pyuvsim.mpi.start_mpi()
962+
rank = pyuvsim.mpi.get_rank()
952963
uvdata_linear, beam_list, beam_dict, sky_model = uvdata_two_redundant_bls_triangle_sources
953964

954965
uvdata_linear.reorder_blts(*order)
955966
# delete the order like we forgot to set it
956967
uvdata_linear.blt_order = None
957-
with uvtest.check_warnings(
958-
UserWarning, match="The parameter `blt_order` could not be identified."
959-
):
968+
if rank == 0:
969+
with uvtest.check_warnings(
970+
UserWarning, match="The parameter `blt_order` could not be identified."
971+
):
972+
973+
out_uv = pyuvsim.uvsim.run_uvdata_uvsim(
974+
input_uv=uvdata_linear.copy(),
975+
beam_list=beam_list,
976+
beam_dict=beam_dict,
977+
catalog=sky_model,
978+
)
979+
980+
assert out_uv.blt_order == ("time", "baseline")
981+
else:
960982
out_uv = pyuvsim.uvsim.run_uvdata_uvsim(
961983
input_uv=uvdata_linear.copy(),
962984
beam_list=beam_list,
963985
beam_dict=beam_dict,
964986
catalog=sky_model,
965987
)
966-
assert out_uv.blt_order == ("time", "baseline")
967988

968989

990+
@pytest.mark.parallel(2)
969991
def test_nblts_not_square(uvdata_two_redundant_bls_triangle_sources):
970992
pytest.importorskip('mpi4py')
971993
uvdata_linear, beam_list, beam_dict, sky_model = uvdata_two_redundant_bls_triangle_sources
@@ -977,7 +999,7 @@ def test_nblts_not_square(uvdata_two_redundant_bls_triangle_sources):
977999
indices = np.nonzero(
9781000
uvdata_linear.baseline_array == uvdata_linear.antnums_to_baseline(0, 2)
9791001
)[0]
980-
print(uvdata_linear.antnums_to_baseline(0, 2), indices)
1002+
9811003
# discard half of them
9821004
indices = indices[::2]
9831005
blt_inds = np.delete(np.arange(uvdata_linear.Nblts), indices)
@@ -991,12 +1013,13 @@ def test_nblts_not_square(uvdata_two_redundant_bls_triangle_sources):
9911013
beam_dict=beam_dict,
9921014
catalog=sky_model,
9931015
)
994-
995-
assert np.allclose(
996-
out_uv.get_data((0, 1)), out_uv.get_data((1, 2))
997-
)
998-
# make sure (0, 2) has fewer times
999-
assert out_uv.get_data((0, 2)).shape == (out_uv.Ntimes // 2, out_uv.Nfreqs, out_uv.Npols)
1016+
rank = pyuvsim.mpi.get_rank()
1017+
if rank == 0 :
1018+
assert np.allclose(
1019+
out_uv.get_data((0, 1)), out_uv.get_data((1, 2))
1020+
)
1021+
# make sure (0, 2) has fewer times
1022+
assert out_uv.get_data((0, 2)).shape == (out_uv.Ntimes // 2, out_uv.Nfreqs, out_uv.Npols)
10001023

10011024

10021025
def test_tqdm_import_error():

0 commit comments

Comments
 (0)