Skip to content

Commit

Permalink
[WP] Add script omp.py + test_omp for SMV
Browse files Browse the repository at this point in the history
  • Loading branch information
GauthierGar committed Aug 3, 2021
1 parent e00b9b2 commit 5266b91
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 1 deletion.
75 changes: 75 additions & 0 deletions SciDataTool/Functions/omp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from logging import getLogger

import numpy as np
from numpy import ndarray, concatenate, identity, arange
from sklearn.linear_model import orthogonal_mp
from scipy.fft import idct, idst


def comp_dictionary(n: int, M: ndarray) -> ndarray:
"""
Construct the dictionary on which the signal is decomposed
Parameter
---------
M: index of the grid corresponding to the observations of the signal
n: length of the grid on which the signal is undersampled
Returns
dictionary: concatenation of the DST and DCT's matrix
"""

DCT = idct(identity(n), type=2, norm='ortho', axis=0)
DCT = DCT[M]
DST = idst(identity(n), type=2, norm='ortho', axis=0)
DST = DST[M]

dictionary = concatenate([DCT,DST],axis=1)

return dictionary


def omp(Y: ndarray, M: ndarray, n: int, n_coefs: int=None) -> ndarray:
"""
Given Y of shape (len(M),n_targets), recover n_targets signals with joint sparsity of
length len(M).
Each signal - column of Y is the signal's observation on the support M
Parameter
---------
Y: ndarray (len(M),n_targets) matrix of the n_targets joint sparse signals
M: index of the grid corresponding to the observations of the signals
n: length of the grid on which the signal is undersampled
n_coefs: passed to n_nonzero_coefs, a parameter of orthogonal_mp.
If None set to 10% of n.
Returns:
Y_full: ndarray (n,n_targets) matrix of the recovered signals
"""

dictionary = comp_dictionary(n,M)

sparse_decomposition = orthogonal_mp(X=dictionary,y=Y,n_nonzero_coefs=n_coefs)

dictionary = comp_dictionary(n,arange(n))

Y_full = dictionary @ sparse_decomposition

return Y_full













80 changes: 80 additions & 0 deletions Tests/Validation/test_omp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest

from math import floor
from SciDataTool import Data1D, DataTime, DataFreq, DataND

from SciDataTool.Functions.omp import omp

@pytest.mark.validation
def test_omp_SMV():
"""
Test the recovery of a sparse undersampled signal in the SMV situation
"""

def f_1d(x: np.ndarray) -> np.ndarray:
"""
Create a 1D function with the following Fourier transform coefficients:
- 2 at 0 Hz
- 3 at 5 Hz
- 4 at 12 Hz
- 1 at 20 Hz
"""

return (
2
+ 3 * np.sin(5 * 2 * np.pi * x)
+ 4 * np.sin(12 * 2 * np.pi * x)
+ 1 * np.sin(20 * 2 * np.pi * x)
)

# Define a time vector
n = 1000
time = Data1D(name="time", unit="s", values=np.linspace(0, 1, n))

# Compute the signal the signal
signal = f_1d(time.values)
field = DataTime(
name="field",
symbol="X",
axes=[time],
values=signal,
unit="m"
)

# fix seed to avoid problem due to random non uniform sampling
np.random.seed(90)

# Randomly choose observations of the signal
# a subset M of the time-grid
K = 0.90
m = floor(K*n)
M = np.random.choice(n,m, replace=False)
M.sort()
M = np.asarray(M)

# Undersample the signal
Y = signal[M]

# recover the signal with the OMP
Y_full = omp(Y,M,n,n_coefs=8*2)

# Check that the result match the signal
np.testing.assert_allclose(
Y_full,
signal,
rtol=1e-1,
atol=1.5*1e-1,
)











2 changes: 1 addition & 1 deletion Tutorials/tuto_4_Omp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"source": [
"# cvxpy has an Apache license but the ECOS solver (GPL 3.0) is one of it's requirement.\r\n",
"# import cvxpy as cp\r\n",
Expand Down

0 comments on commit 5266b91

Please sign in to comment.