@@ -121,19 +121,65 @@ def test_set_paramfile_changeparams_scalar_errors_given_list(self):
121121 with self .assertRaisesRegex (RuntimeError , "Incorrect N dims" ):
122122 sp .main ()
123123
124- def test_set_paramfile_changeparam_1d_errors_given_scalar (self ):
125- """Test that set_paramfile errors if given a scalar for a 1-d parameter"""
124+ def test_set_paramfile_changeparam_1d_given_scalar (self ):
125+ """
126+ Test that set_paramfile works correctly if given a scalar for a 1-d parameter. We want it
127+ to set all members of the 1d array to the given scalar.
128+ """
126129 output_path = os .path .join (self .tempdir , "output.nc" )
127130 sys .argv = [
128131 "set_paramfile" ,
129132 "-i" ,
130133 PARAMFILE ,
131134 "-o" ,
132135 output_path ,
133- "xl=0.724 " ,
136+ "mxmat=1987 " ,
134137 ]
135- with self .assertRaisesRegex (RuntimeError , "Incorrect N dims" ):
136- sp .main ()
138+ sp .main ()
139+ self .assertTrue (os .path .exists (output_path ))
140+ ds_in = open_paramfile (PARAMFILE )
141+ ds_out = open_paramfile (output_path )
142+
143+ for var in ds_in .variables :
144+ # Check that all variables/coords are equal except the ones we changed, which should be
145+ # set to what we asked
146+ if var == "mxmat" :
147+ self .assertTrue (np .all (ds_out [var ].values == 1987 ))
148+ else :
149+ self .assertTrue (are_paramfile_dataarrays_identical (ds_in [var ], ds_out [var ]))
150+
151+ def test_set_paramfile_changeparam_1d_given_scalar_and_pftlist (self ):
152+ """
153+ Test that set_paramfile works correctly if given a scalar for a 1-d parameter. We want it
154+ to set all members of the 1d array to the given scalar. As
155+ test_set_paramfile_changeparam_1d_given_scalar, but here we give a pft list.
156+ """
157+ output_path = os .path .join (self .tempdir , "output.nc" )
158+ sys .argv = [
159+ "set_paramfile" ,
160+ "-i" ,
161+ PARAMFILE ,
162+ "-o" ,
163+ output_path ,
164+ "-p" ,
165+ "temperate_corn,irrigated_temperate_corn" ,
166+ "mxmat=1987" ,
167+ ]
168+ sp .main ()
169+ self .assertTrue (os .path .exists (output_path ))
170+ ds_in = open_paramfile (PARAMFILE )
171+ ds_out = open_paramfile (output_path )
172+
173+ for var in ds_in .variables :
174+ # Check that all variables/coords are equal except the ones we changed, which should be
175+ # set to what we asked
176+ if var == "mxmat" :
177+ # First, check that they weren't 1987 before
178+ self .assertFalse (np .any (ds_in [var ].values [17 :18 ] == 1987 ))
179+ # Now check that they are 1987
180+ self .assertTrue (np .all (ds_out [var ].values [17 :18 ] == 1987 ))
181+ else :
182+ self .assertTrue (are_paramfile_dataarrays_identical (ds_in [var ], ds_out [var ]))
137183
138184 def test_set_paramfile_changeparams_scalar_double (self ):
139185 """Test that set_paramfile can copy to a new file with some scalar double params changed"""
0 commit comments