Skip to content

Commit

Permalink
[BC] fixing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
helene-t committed Jun 1, 2022
1 parent 70d4ca7 commit 340a96a
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 32 deletions.
50 changes: 29 additions & 21 deletions SciDataTool/Functions/fft_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,32 +383,40 @@ def comp_ifftn(values, axes_requ_list, is_real=True, axes_list=[]):
axis_obj = axes_list[
[axis.name for axis in axes_list].index(axis.name)
]
operation = axis.name + "_to_" + fft_dict[axis.name]
else:
axis_obj = axes_list[
[axis.name for axis in axes_list].index(fft_dict[axis.name])
]
operation = None
freqs = axis_obj.get_values(
is_smallestperiod=True,
operation=operation,
freqs = comp_fft_freqs(
axis.input_data, axis.name == "time", is_real
)
if len(axis.corr_values) == 1 or (
len(axis.corr_values) > 1
and (
not is_uniform(axis.corr_values)
or (
len(freqs) != len(axis.corr_values)
and not isin(freqs, axis.corr_values).all()
)
or (
len(freqs) == len(axis.corr_values)
and not allclose(
freqs,
axis.corr_values,
rtol=1e-5,
atol=1e-8,
equal_nan=False,
if "period" in axis_obj.symmetries:
if axis.name != "time":
freqs = freqs * axis_obj.symmetries["period"]
elif "antiperiod" in axis_obj.symmetries:
if axis.name != "time":
freqs = freqs * axis_obj.symmetries["antiperiod"] / 2
# If already one non uniform axis, use NUDFT
if (
axes_dict_non_uniform
or len(axis.corr_values) == 1
or (
len(axis.corr_values) > 1
and (
not is_uniform(axis.corr_values)
or (
len(freqs) != len(axis.corr_values)
and not isin(freqs, axis.corr_values).all()
)
or (
len(freqs) == len(axis.corr_values)
and not allclose(
freqs,
axis.corr_values,
rtol=1e-5,
atol=1e-8,
equal_nan=False,
)
)
)
)
Expand Down
7 changes: 4 additions & 3 deletions SciDataTool/Methods/DataND/export_along.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,11 @@ def export_along(

# Rest of file: first axis + matrix
if len(Ydatas) == 1:
# Transpose if 1D array
field = np.array(Ydatas[0]).T
field = np.array(Ydatas[0])
else:
field = np.array(Ydatas).T
field = np.array(Ydatas)
if field.shape[0] != len(Xdata[0]):
field = field.T
matrix = format_matrix(
np.column_stack((np.array(Xdata[0]).T, field)).astype("str"),
CHAR_LIST,
Expand Down
12 changes: 9 additions & 3 deletions Tests/Validation/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_export_2D():
Field.export_along("time", "angle{°}", save_path=save_validation_path)
assert isfile(join(save_validation_path, "B_r_Data.csv"))
Field.export_along(
"time=1",
"angle{°}",
"time=1",
save_path=save_validation_path,
file_name="B_r_Data_sliced",
)
Expand All @@ -46,14 +46,19 @@ def test_export_3D():
)

Field.export_along(
"time=1",
"angle{°}",
"time=1",
save_path=save_validation_path,
file_name="B_r_Data3D_sliced",
)
assert isfile(join(save_validation_path, "B_r_Data3D_sliced.csv"))
Field.export_along(
"time", "angle{°}", "z", save_path=save_validation_path, is_multiple_files=True
"time",
"angle{°}",
"z",
save_path=save_validation_path,
is_multiple_files=True,
is_2D=False,
)
assert isfile(join(save_validation_path, "B_r_Data_z0.0.csv"))
assert isfile(join(save_validation_path, "B_r_Data_z1.0.csv"))
Expand All @@ -62,6 +67,7 @@ def test_export_3D():
save_path=save_validation_path,
is_multiple_files=True,
file_name="B_r_withoutargs",
is_2D=False,
)
assert isfile(join(save_validation_path, "B_r_withoutargs_z0.0.csv"))
assert isfile(join(save_validation_path, "B_r_withoutargs_z1.0.csv"))
Expand Down
6 changes: 3 additions & 3 deletions Tests/Validation/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ def test_fft2_interp():


@pytest.mark.validation
def test_fft1d_non_uniform(per_a=2, is_apera=True, is_add_zero_freq=True):
def test_fft1d_non_uniform(per_a=2, is_apera=True, is_add_zero_freq=False):
"""check non uniform fft1d
TODO: solve bug for a single frequency vector"""
# %%
Expand Down Expand Up @@ -962,6 +962,6 @@ def test_fft1d_non_uniform(per_a=2, is_apera=True, is_add_zero_freq=True):

if __name__ == "__main__":
# test_ifft2d_period()
# test_fft1d_non_uniform(is_add_zero_freq=True)
test_fft1d_non_uniform(is_add_zero_freq=False)
# test_fft1d_non_uniform(is_add_zero_freq=False)
test_fft2_interp()
# test_fft2_interp()
4 changes: 2 additions & 2 deletions Tests/Validation/test_filter_spectral_leakage.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,6 @@ def test_filter_spectral_leakage_vectorfield():


if __name__ == "__main__":
test_filter_spectral_leakage_1d()
test_filter_spectral_leakage_2d()
# test_filter_spectral_leakage_1d()
# test_filter_spectral_leakage_2d()
test_filter_spectral_leakage_vectorfield()
4 changes: 4 additions & 0 deletions Tests/Validation/test_ndft.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,7 @@ def f_1d(x: np.ndarray) -> np.ndarray:
result_inudft["X"].real, f_1d(time_vect_non_unif), decimal=0
)
assert np.allclose(result_inudft["X"].real, f_1d(time_vect_non_unif), rtol=1e-1)


if __name__ == "__main__":
test_nudft_2d()

0 comments on commit 340a96a

Please sign in to comment.