Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
SebGue committed Jan 6, 2021
2 parents 3785391 + eb09655 commit cb28ce5
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 1 deletion.
16 changes: 16 additions & 0 deletions SciDataTool/Classes/DataND.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@
except ImportError as error:
get_axes = error

try:
from ..Methods.DataND.get_data_along import get_data_along
except ImportError as error:
get_data_along = error


from numpy import array, array_equal
from ._check import InitUnKnowClassError
Expand Down Expand Up @@ -317,6 +322,17 @@ class DataND(Data):
)
else:
get_axes = get_axes
# cf Methods.DataND.get_data_along
if isinstance(get_data_along, ImportError):
get_data_along = property(
fget=lambda x: raise_(
ImportError(
"Can't use DataND method get_data_along: " + str(get_data_along)
)
)
)
else:
get_data_along = get_data_along
# save and copy methods are available in all object
save = save
copy = copy
Expand Down
1 change: 1 addition & 0 deletions SciDataTool/Generator/ClassesRef/DataND.csv
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ is_real,,To indicate if the signal is real (use only positive frequencies),,bool
,,,,,,,,,,,_set_values,,,,
,,,,,,,,,,,has_period,,,,
,,,,,,,,,,,get_axes,,,,
,,,,,,,,,,,get_data_along,,,,
69 changes: 69 additions & 0 deletions SciDataTool/Methods/DataND/get_data_along.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
from SciDataTool import Data1D
from SciDataTool.Functions import AxisError, axes_dict, rev_axes_dict


def get_data_along(self, *args, unit="SI", is_norm=False, axis_data=[]):
"""Returns the sliced or interpolated version of the data, using conversions and symmetries if needed.
Parameters
----------
self: Data
a Data object
*args: list of strings
List of axes requested by the user, their units and values (optional)
unit: str
Unit requested by the user ("SI" by default)
is_norm: bool
Boolean indicating if the field must be normalized (False by default)
axis_data: list
list of ndarray corresponding to user-input data
Returns
-------
a DataND object
"""

# Dynamic import to avoid loop
module = __import__("SciDataTool.Classes.DataND", fromlist=["DataND"])
DataND = getattr(module, "DataND")

results = self.get_along(*args)
values = results.pop(self.symbol)
del results["axes_dict_other"]
del results["axes_list"]
Axes = []
for axis_name in results.keys():
if len(results[axis_name]) > 1:
for axis in self.axes:
if axis.name == axis_name:
name = axis.name
is_components = axis.is_components
axis_values = results[axis_name]
unit = axis.unit
elif axis_name in axes_dict:
if axes_dict[axis_name][0] == axis.name:
name = axis_name
is_components = axis.is_components
axis_values = results[axis_name]
unit = axes_dict[axis_name][2]
elif axis_name in rev_axes_dict:
if rev_axes_dict[axis_name][0] == axis.name:
name = axis_name
is_components = axis.is_components
axis_values = results[axis_name]
unit = rev_axes_dict[axis_name][2]
Axes.append(
Data1D(
name=name,
unit=unit,
values=axis_values,
is_components=is_components,
)
)
return DataND(
name=self.name,
unit=self.unit,
symbol=self.symbol,
axes=Axes,
values=values,
is_real=self.is_real,
)
84 changes: 83 additions & 1 deletion Tests/Validation/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,88 @@
from numpy.fft import rfftn, irfftn, fftshift, fftn, ifftshift, ifftn


@pytest.mark.validation
def test_fft2_remove_periodicity():
f = 50
Nt_tot = 16
Na_tot = 20
time = np.linspace(0, 1 / (2 * f), Nt_tot, endpoint=False)
Time = Data1D(name="time", unit="s", values=time, symmetries={"antiperiod": 4})
angle = np.linspace(0, 2 * np.pi, Na_tot, endpoint=False)
Angle = Data1D(name="angle", unit="rad", values=angle, symmetries={"period": 4})

field = np.random.random((Nt_tot, Na_tot))

Field = DataTime(
name="field",
symbol="X",
axes=[Time, Angle],
values=field,
unit="m",
)

angle_new = Angle.get_values(is_oneperiod=True, is_antiperiod=False)

time_new = Time.get_values(
is_oneperiod=False,
is_antiperiod=False,
)

# Load magnetic flux
field_new = Field.get_along(
"time=axis_data",
"angle=axis_data",
axis_data={"time": time_new, "angle": angle_new},
)["X"]

Time2 = Data1D(name="time", unit="s", values=time_new, symmetries={"period": 2})
Angle2 = Data1D(name="angle", unit="rad", values=angle_new)

Field_new = DataTime(
name="field",
symbol="X",
axes=[Time2, Angle2],
values=field_new,
unit="m",
)

result_fft = Field_new.get_along("freqs", "wavenumber")
X_test = result_fft["X"]
freqs = result_fft["freqs"]
wavenumber = result_fft["wavenumber"]

Nr = len(wavenumber)
Nf = len(freqs)

# Check the FFT2 reconstruction of the new object
field_ift = np.zeros((len(time_new), len(angle_new)))
Xangle, Xtime = np.meshgrid(angle_new, time_new)

for ir in range(Nr):
r = wavenumber[ir]
for ifrq in range(Nf):
fit = freqs[ifrq]
field_ift = field_ift + abs(X_test[ifrq, ir]) * np.cos(
(2 * np.pi * fit * Xtime + r * Xangle + np.angle(X_test[ifrq, ir]))
)

assert_array_almost_equal(field_ift, field_new)

# Compare with the initial field
field_ift = np.zeros((len(time), len(angle)))
Xangle, Xtime = np.meshgrid(angle, time)

for ir in range(Nr):
r = wavenumber[ir]
for ifrq in range(Nf):
fit = freqs[ifrq]
field_ift = field_ift + abs(X_test[ifrq, ir]) * np.cos(
(2 * np.pi * fit * Xtime + r * Xangle + np.angle(X_test[ifrq, ir]))
)

assert_array_almost_equal(field_ift, field)


@pytest.mark.validation
def test_compare_rfft_fft_irfft_ifft():

Expand Down Expand Up @@ -731,4 +813,4 @@ def test_fft2_anti_period_random():
Field_FT = Field_FT.freq_to_time()
assert_array_almost_equal(
Field_FT.get_along("angle", "time")["X"], Field.get_along("angle", "time")["X"]
)
)

0 comments on commit cb28ce5

Please sign in to comment.