forked from Eomys/SciDataTool
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[WP] Add script omp.py + test_omp for SMV
- Loading branch information
1 parent
e00b9b2
commit 5266b91
Showing
3 changed files
with
156 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from logging import getLogger | ||
|
||
import numpy as np | ||
from numpy import ndarray, concatenate, identity, arange | ||
from sklearn.linear_model import orthogonal_mp | ||
from scipy.fft import idct, idst | ||
|
||
|
||
def comp_dictionary(n: int, M: ndarray) -> ndarray: | ||
""" | ||
Construct the dictionary on which the signal is decomposed | ||
Parameter | ||
--------- | ||
M: index of the grid corresponding to the observations of the signal | ||
n: length of the grid on which the signal is undersampled | ||
Returns | ||
dictionary: concatenation of the DST and DCT's matrix | ||
""" | ||
|
||
DCT = idct(identity(n), type=2, norm='ortho', axis=0) | ||
DCT = DCT[M] | ||
DST = idst(identity(n), type=2, norm='ortho', axis=0) | ||
DST = DST[M] | ||
|
||
dictionary = concatenate([DCT,DST],axis=1) | ||
|
||
return dictionary | ||
|
||
|
||
def omp(Y: ndarray, M: ndarray, n: int, n_coefs: int=None) -> ndarray: | ||
""" | ||
Given Y of shape (len(M),n_targets), recover n_targets signals with joint sparsity of | ||
length len(M). | ||
Each signal - column of Y is the signal's observation on the support M | ||
Parameter | ||
--------- | ||
Y: ndarray (len(M),n_targets) matrix of the n_targets joint sparse signals | ||
M: index of the grid corresponding to the observations of the signals | ||
n: length of the grid on which the signal is undersampled | ||
n_coefs: passed to n_nonzero_coefs, a parameter of orthogonal_mp. | ||
If None set to 10% of n. | ||
Returns: | ||
Y_full: ndarray (n,n_targets) matrix of the recovered signals | ||
""" | ||
|
||
dictionary = comp_dictionary(n,M) | ||
|
||
sparse_decomposition = orthogonal_mp(X=dictionary,y=Y,n_nonzero_coefs=n_coefs) | ||
|
||
dictionary = comp_dictionary(n,arange(n)) | ||
|
||
Y_full = dictionary @ sparse_decomposition | ||
|
||
return Y_full | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pytest | ||
|
||
from math import floor | ||
from SciDataTool import Data1D, DataTime, DataFreq, DataND | ||
|
||
from SciDataTool.Functions.omp import omp | ||
|
||
@pytest.mark.validation | ||
def test_omp_SMV(): | ||
""" | ||
Test the recovery of a sparse undersampled signal in the SMV situation | ||
""" | ||
|
||
def f_1d(x: np.ndarray) -> np.ndarray: | ||
""" | ||
Create a 1D function with the following Fourier transform coefficients: | ||
- 2 at 0 Hz | ||
- 3 at 5 Hz | ||
- 4 at 12 Hz | ||
- 1 at 20 Hz | ||
""" | ||
|
||
return ( | ||
2 | ||
+ 3 * np.sin(5 * 2 * np.pi * x) | ||
+ 4 * np.sin(12 * 2 * np.pi * x) | ||
+ 1 * np.sin(20 * 2 * np.pi * x) | ||
) | ||
|
||
# Define a time vector | ||
n = 1000 | ||
time = Data1D(name="time", unit="s", values=np.linspace(0, 1, n)) | ||
|
||
# Compute the signal the signal | ||
signal = f_1d(time.values) | ||
field = DataTime( | ||
name="field", | ||
symbol="X", | ||
axes=[time], | ||
values=signal, | ||
unit="m" | ||
) | ||
|
||
# fix seed to avoid problem due to random non uniform sampling | ||
np.random.seed(90) | ||
|
||
# Randomly choose observations of the signal | ||
# a subset M of the time-grid | ||
K = 0.90 | ||
m = floor(K*n) | ||
M = np.random.choice(n,m, replace=False) | ||
M.sort() | ||
M = np.asarray(M) | ||
|
||
# Undersample the signal | ||
Y = signal[M] | ||
|
||
# recover the signal with the OMP | ||
Y_full = omp(Y,M,n,n_coefs=8*2) | ||
|
||
# Check that the result match the signal | ||
np.testing.assert_allclose( | ||
Y_full, | ||
signal, | ||
rtol=1e-1, | ||
atol=1.5*1e-1, | ||
) | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters