
Explicit support for Regression, performed major refactoring of tests, removed unused code and updated notebooks to work (again). #248

Merged
Changes from 1 commit
56 commits
dc327f9
simplify mkdocs config
Mar 7, 2024
0935989
update to version 2.1.2
Mar 7, 2024
6c1755f
update precommit
Mar 7, 2024
1a4cecf
remove functionality which is done better in other libraries (Skoreca…
Mar 14, 2024
35a9401
update version & nb
Mar 14, 2024
2ae1b7f
update ruff config
Mar 14, 2024
e785c6a
downgrade ruff
Mar 14, 2024
fa4d975
downgrade ruff p2
Mar 14, 2024
7770889
revert back downgrade - its an image issue
Mar 14, 2024
71c4925
add no-cache option and update readme
Mar 14, 2024
0f474f0
add no-cache option to other as well
Mar 14, 2024
d779c61
remove shap inspector
Mar 15, 2024
54ae4ba
update documentation
Mar 17, 2024
ea72e62
remove image
Mar 17, 2024
b688a60
allow for python version 3.12 and fix the bug for upgrading to shap 0…
Mar 17, 2024
0a79441
Merge branch 'main' into add_compatibility_p312
Mar 17, 2024
3b930fc
Update pre-commit
Mar 17, 2024
92ee361
remove import
Mar 17, 2024
a37a8d8
fix dependency of shap
Mar 17, 2024
1b453b8
fix file
Mar 17, 2024
ce185be
fix for python v 3.8
Mar 17, 2024
aec0e80
removal of leftover references
Mar 17, 2024
90f2794
add explicit state setting
Mar 18, 2024
75d3fd3
another random state found to be added
Mar 18, 2024
4271cf7
Merge branch 'add_compatibility_p312' into set_random_state_explicit
Mar 18, 2024
1286eed
fix tests
Mar 18, 2024
6b71074
fix tests to more consistent standard.
Mar 18, 2024
7d9d466
Merge branch 'main' into set_random_state_explicit
Mar 18, 2024
22e17c2
major test refactor
Mar 19, 2024
1371d5d
Merge remote-tracking branch 'origin/set_random_state_explicit' into …
Mar 19, 2024
fb3a33e
fix many things
Mar 20, 2024
cdc6b88
update readme
Mar 20, 2024
108db45
update cronjob
Mar 20, 2024
351e4f9
update copyright
Mar 20, 2024
cb1b3ef
change version from 3.0.1 to 3.1.0 since the changes are a bit more t…
Mar 20, 2024
b6bf310
change cronjob
Mar 20, 2024
0be6f3f
fix nb run flag
Mar 20, 2024
4a0c9b3
remove debug file
Mar 20, 2024
fe97bf2
Merge branch 'main' into fixes_and_add_explicit_multi_and_regression
Mar 26, 2024
c1285c3
Add explicit state setting (#242)
Mar 28, 2024
77f303f
Update catboost requirement (#254)
dependabot[bot] Mar 28, 2024
bccc06e
rebase master
Mar 28, 2024
18db9d9
update version & nb
Mar 14, 2024
1e9988e
rebase
Mar 28, 2024
b576151
fix tests to more consistent standard.
Mar 18, 2024
0f2b816
major test refactor
Mar 19, 2024
036afa2
fix many things
Mar 28, 2024
3ad20b9
update readme
Mar 20, 2024
085c316
update cronjob
Mar 20, 2024
42577ac
update copyright
Mar 20, 2024
5475447
change version from 3.0.1 to 3.1.0 since the changes are a bit more t…
Mar 20, 2024
38edbc7
change cronjob
Mar 20, 2024
ecf4647
fix nb run flag
Mar 20, 2024
cdda30c
remove debug file
Mar 20, 2024
6d2c39a
rebase master
Mar 28, 2024
37808de
Merge branch 'fixes_and_add_explicit_multi_and_regression' of https:/…
Mar 28, 2024
Add explicit state setting (#242)
Set the random states explicitly. 

Tasks:

- [x] Adjust the code where "sample" does not use random_state
- [x] Adjust the test code for it
- [x] Make sure the tests use it consistently.
- [x] Look if we can remove some unnecessary checks as mentioned in this
issue: #221
Reinier Koops authored Mar 28, 2024
commit c1285c3f9be10d8a326b3f46605fdb924ac94ff8
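The core idea of this commit — forwarding `random_state` into every place data is sampled — can be sketched as follows. `sample_rows` is a hypothetical stand-in for the sampling probatus does when building the SHAP background mask, not the library's actual helper:

```python
import numpy as np

def sample_rows(X, sample_size, random_state=None):
    # Hypothetical stand-in for the sampling used when building the SHAP
    # background mask: forwarding random_state makes the draw reproducible.
    rng = np.random.default_rng(random_state)
    idx = rng.choice(X.shape[0], size=min(sample_size, X.shape[0]), replace=False)
    return X[idx]

X = np.arange(20).reshape(10, 2)
s1 = sample_rows(X, 5, random_state=42)
s2 = sample_rows(X, 5, random_state=42)
assert np.array_equal(s1, s2)  # same seed, identical background sample
```

With `random_state=None` the draw differs between runs, which is exactly the non-reproducibility the task list above sets out to remove.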
30 changes: 15 additions & 15 deletions probatus/feature_elimination/feature_elimination.py
@@ -7,6 +7,7 @@
from sklearn.base import clone, is_classifier, is_regressor
from sklearn.model_selection import check_cv
from sklearn.model_selection._search import BaseSearchCV
from loguru import logger

from probatus.utils import (
BaseFitComputePlotClass,
@@ -156,9 +157,8 @@ def __init__(
Controls verbosity of the output:

- 0 - neither prints nor warnings are shown
- 1 - 50 - only most important warnings
- 51 - 100 - shows other warnings and prints
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 1 - only most important warnings
- 2 - shows all prints and all warnings.

random_state (int, optional):
Random state set at each round of feature elimination. If it is None, the results will not be
@@ -395,7 +395,7 @@ def _get_feature_shap_values_per_fold(
score_val = self.scorer.scorer(clf, X_val, y_val)

# Compute SHAP values
shap_values = shap_calc(clf, X_val, verbose=self.verbose, **shap_kwargs)
shap_values = shap_calc(clf, X_val, verbose=self.verbose, random_state=self.random_state, **shap_kwargs)
return shap_values, score_train, score_val

def fit(
@@ -537,7 +537,7 @@ def fit(
self.min_features_to_select = 0
# This ensures that, if columns_to_keep is provided ,
# the last features remaining are only the columns_to_keep.
if self.verbose > 50:
if self.verbose > 1:
warnings.warn(f"Minimum features to select : {stopping_criteria}")

while len(current_features_set) > stopping_criteria:
@@ -615,8 +615,8 @@ def fit(
val_metric_mean=np.mean(scores_val),
val_metric_std=np.std(scores_val),
)
if self.verbose > 50:
print(
if self.verbose > 1:
logger.info(
f"Round: {round_number}, Current number of features: {len(current_features_set)}, "
f'Current performance: Train {self.report_df.loc[round_number]["train_metric_mean"]} '
f'+/- {self.report_df.loc[round_number]["train_metric_std"]}, CV Validation '
@@ -841,8 +841,8 @@ def _get_best_num_features(self, best_method, standard_error_threshold=1.0):
)

# Log shap_report for users who want to inspect / debug
if self.verbose > 50:
print(shap_report)
if self.verbose > 1:
logger.info(shap_report)

return best_num_features

@@ -1110,10 +1110,9 @@ def __init__(
verbose (int, optional):
Controls verbosity of the output:

- 0 - nether prints nor warnings are shown
- 1 - 50 - only most important warnings
- 51 - 100 - shows other warnings and prints
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 0 - neither prints nor warnings are shown
- 1 - only most important warnings
- 2 - shows all prints and all warnings.

random_state (int, optional):
Random state set at each round of feature elimination. If it is None, the results will not be
@@ -1210,7 +1209,8 @@ def _get_fit_params_lightGBM(
"eval_set": [(X_val, y_val)],
"callbacks": [early_stopping(self.early_stopping_rounds, first_metric_only=True)],
}
if self.verbose >= 100:

if self.verbose >= 2:
fit_params["callbacks"].append(log_evaluation(1))
else:
fit_params["callbacks"].append(log_evaluation(0))
@@ -1505,5 +1505,5 @@ def _get_feature_shap_values_per_fold(
score_val = self.scorer.scorer(clf, X_val, y_val)

# Compute SHAP values
shap_values = shap_calc(clf, X_val, verbose=self.verbose, **shap_kwargs)
shap_values = shap_calc(clf, X_val, verbose=self.verbose, random_state=self.random_state, **shap_kwargs)
return shap_values, score_train, score_val
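The hunks above collapse the old 0–100 verbosity ranges into three discrete levels (0, 1, 2). The gating logic, simplified into two predicates (a sketch, not probatus' actual API), is:

```python
def emits_warnings(verbose: int) -> bool:
    # Level 1 and above: only the most important warnings are shown.
    return verbose > 0

def emits_prints(verbose: int) -> bool:
    # Level 2 and above: all prints (now routed through loguru) and warnings.
    return verbose > 1

# Expected behaviour per level: {level: (warnings shown, prints shown)}
for level, (warn, log) in {0: (False, False), 1: (True, False), 2: (True, True)}.items():
    assert emits_warnings(level) == warn
    assert emits_prints(level) == log
```

This is why every `if self.verbose > 50` in the diff becomes `if self.verbose > 1`, and `>= 100` becomes `>= 2`.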
16 changes: 12 additions & 4 deletions probatus/interpret/model_interpret.py
@@ -80,7 +80,7 @@ class ShapModelInterpreter(BaseFitComputePlotClass):
<img src="../img/model_interpret_sample.png" width="320" />
"""

def __init__(self, clf, scoring="roc_auc", verbose=0):
def __init__(self, clf, scoring="roc_auc", verbose=0, random_state=None):
"""
Initializes the class.

@@ -98,13 +98,17 @@ def __init__(self, clf, scoring="roc_auc", verbose=0):
Controls verbosity of the output:

- 0 - neither prints nor warnings are shown
- 1 - 50 - only most important warnings
- 51 - 100 - shows other warnings and prints
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 1 - only most important warnings
- 2 - shows all prints and all warnings.

random_state (int, optional):
Random state set for the number of samples. If it is None, the results will not be reproducible. For
reproducible results set it to an integer.
"""
self.clf = clf
self.scorer = get_single_scorer(scoring)
self.verbose = verbose
self.random_state = random_state

def fit(
self,
@@ -186,6 +190,7 @@ def fit(
column_names=self.column_names,
class_names=self.class_names,
verbose=self.verbose,
random_state=self.random_state,
**shap_kwargs,
)

@@ -200,6 +205,7 @@ def fit(
column_names=self.column_names,
class_names=self.class_names,
verbose=self.verbose,
random_state=self.random_state,
**shap_kwargs,
)

@@ -212,6 +218,7 @@ def _prep_shap_related_variables(
y,
approximate=False,
verbose=0,
random_state=None,
column_names=None,
class_names=None,
**shap_kwargs,
@@ -228,6 +235,7 @@
X,
approximate=approximate,
verbose=verbose,
random_state=random_state,
return_explainer=True,
**shap_kwargs,
)
21 changes: 16 additions & 5 deletions probatus/interpret/shap_dependence.py
@@ -52,7 +52,7 @@ class DependencePlotter(BaseFitComputePlotClass):
<img src="../img/model_interpret_dep.png"/>
"""

def __init__(self, clf, verbose=0):
def __init__(self, clf, verbose=0, random_state=None):
"""
Initializes the class.

@@ -64,12 +64,16 @@ def __init__(self, clf, verbose=0):
Controls verbosity of the output:

- 0 - neither prints nor warnings are shown
- 1 - 50 - only most important warnings regarding data properties are shown (excluding SHAP warnings)
- 51 - 100 - shows most important warnings, prints of the feature removal process
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 1 - only most important warnings
- 2 - shows all prints and all warnings.

random_state (int, optional):
Random state set for the number of samples. If it is None, the results will not be reproducible. For
reproducible results set it to an integer.
"""
self.clf = clf
self.verbose = verbose
self.random_state = random_state

def __repr__(self):
"""
@@ -113,7 +117,14 @@ def fit(self, X, y, column_names=None, class_names=None, precalc_shap=None, **sh
if self.class_names is None:
self.class_names = ["Negative Class", "Positive Class"]

self.shap_vals_df = shap_to_df(self.clf, self.X, precalc_shap=precalc_shap, verbose=self.verbose, **shap_kwargs)
self.shap_vals_df = shap_to_df(
self.clf,
self.X,
precalc_shap=precalc_shap,
verbose=self.verbose,
random_state=self.random_state,
**shap_kwargs,
)

self.fitted = True
return self
24 changes: 12 additions & 12 deletions probatus/sample_similarity/resemblance_model.py
@@ -21,6 +21,7 @@
import warnings

import matplotlib.pyplot as plt
from loguru import logger
import numpy as np
import pandas as pd
from shap import summary_plot
@@ -76,9 +77,8 @@ class is 'roc_auc'.
Controls verbosity of the output:

- 0 - neither prints nor warnings are shown
- 1 - 50 - only most important warnings
- 51 - 100 - shows other warnings and prints
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 1 - only most important warnings
- 2 - shows all prints and all warnings.

random_state (int, optional):
Random state set at each round of feature elimination. If it is None, the results will not be
@@ -178,8 +178,8 @@ def fit(self, X1, X2, column_names=None, class_names=None):
f"Train {self.scorer.metric_name}: {np.round(self.train_score, 3)},\n"
f"Test {self.scorer.metric_name}: {np.round(self.test_score, 3)}."
)
if self.verbose > 50:
print(f"Finished model training: \n{self.results_text}")
if self.verbose > 1:
logger.info(f"Finished model training: \n{self.results_text}")

if self.verbose > 0:
if self.train_score > self.test_score:
@@ -343,9 +343,8 @@ class is 'roc_auc'.
Controls verbosity of the output:

- 0 - neither prints nor warnings are shown
- 1 - 50 - only most important warnings
- 51 - 100 - shows other warnings and prints
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 1 - only most important warnings
- 2 - shows all prints and all warnings.

random_state (int, optional):
Random state set at each round of feature elimination. If it is None, the results will not be
@@ -572,9 +571,8 @@ class is 'roc_auc'.
Controls verbosity of the output:

- 0 - neither prints nor warnings are shown
- 1 - 50 - only most important warnings
- 51 - 100 - shows other warnings and prints
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 1 - only most important warnings
- 2 - shows all prints and all warnings.

random_state (int, optional):
Random state set at each round of feature elimination. If it is None, the results will not be
@@ -630,7 +628,9 @@ def fit(self, X1, X2, column_names=None, class_names=None, **shap_kwargs):
"""
super().fit(X1=X1, X2=X2, column_names=column_names, class_names=class_names)

self.shap_values_test = shap_calc(self.clf, self.X_test, verbose=self.verbose, **shap_kwargs)
self.shap_values_test = shap_calc(
self.clf, self.X_test, verbose=self.verbose, random_state=self.random_state, **shap_kwargs
)
self.report = calculate_shap_importance(self.shap_values_test, self.column_names)
return self

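The resemblance-model hunks also show the second theme of this commit: bare `print()` calls replaced by `logger.info()` gated on the new verbosity levels. A minimal sketch of that pattern, using stdlib `logging` as a stand-in for loguru (whose `logger.info` API is analogous):

```python
import logging

logger = logging.getLogger("probatus")

def report_training(verbose: int, results_text: str) -> bool:
    # Mirrors the swap in the diff: print() becomes logger.info(),
    # emitted only at verbosity level 2 ("all prints and warnings").
    if verbose > 1:
        logger.info("Finished model training: \n%s", results_text)
        return True
    return False

assert report_training(2, "Train roc_auc: 0.51") is True
assert report_training(1, "Train roc_auc: 0.51") is False
```

Routing output through a logger rather than `print` lets downstream users silence or redirect probatus' diagnostics with standard handler configuration.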
2 changes: 1 addition & 1 deletion probatus/utils/__init__.py
@@ -36,7 +36,7 @@
assure_list_values_allowed,
)
from .plots import plot_distributions_of_feature
from .interface import BaseFitComputeClass, BaseFitComputePlotClass
from .base_class_interface import BaseFitComputeClass, BaseFitComputePlotClass

__all__ = [
"NotFittedError",
11 changes: 5 additions & 6 deletions probatus/utils/arrayfuncs.py
@@ -189,9 +189,9 @@ def preprocess_data(X, X_name=None, column_names=None, verbose=0):
Controls verbosity of the output:

- 0 - neither prints nor warnings are shown
- 1 - 50 - only most important warnings regarding data properties are shown (excluding SHAP warnings)
- 51 - 100 - shows most important warnings, prints of the feature removal process
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 1 - only most important warnings
- 2 - shows all prints and all warnings.


Returns:
(pd.DataFrame):
@@ -255,9 +255,8 @@ def preprocess_labels(y, y_name=None, index=None, verbose=0):
Controls verbosity of the output:

- 0 - neither prints nor warnings are shown
- 1 - 50 - only most important warnings regarding data properties are shown (excluding SHAP warnings)
- 51 - 100 - shows most important warnings, prints of the feature removal process
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 1 - only most important warnings
- 2 - shows all prints and all warnings.

Returns:
(pd.Series):
File renamed without changes.
16 changes: 10 additions & 6 deletions probatus/utils/shap_helpers.py
@@ -33,6 +33,7 @@ def shap_calc(
X,
return_explainer=False,
verbose=0,
random_state=None,
sample_size=100,
approximate=False,
check_additivity=True,
@@ -54,10 +55,13 @@
verbose (int, optional):
Controls verbosity of the output:

- 0 - nether prints nor warnings are shown
- 1 - 50 - only most important warnings
- 51 - 100 - shows other warnings and prints
- above 100 - presents all prints and all warnings (including SHAP warnings).
- 0 - neither prints nor warnings are shown
- 1 - only most important warnings
- 2 - shows all prints and all warnings.

random_state (int, optional):
Random state set for the number of samples. If it is None, the results will not be reproducible. For
reproducible results set it to an integer.

approximate (boolean):
if True uses shap approximations - less accurate, but very fast. It applies to tree-based explainers only.
@@ -82,7 +86,7 @@
)
# Suppress warnings regarding XGboost and Lightgbm models.
with warnings.catch_warnings():
if verbose <= 100:
if verbose <= 1:
warnings.simplefilter("ignore")

# For tree explainers, do not pass masker when feature_perturbation is
Expand All @@ -100,7 +104,7 @@ def shap_calc(
sample_size = int(np.ceil(X.shape[0] * 0.2))
else:
pass
mask = sample(X, sample_size)
mask = sample(X, sample_size, random_state=random_state)
explainer = Explainer(model, masker=mask, **shap_kwargs)

# For tree-explainers allow for using check_additivity and approximate arguments
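The visible part of the `shap_calc` hunk sizes the background mask at `ceil(0.2 * n_rows)` before sampling; the enclosing condition is cut off by the hunk, so the sketch below assumes the 20% fallback applies when the data has fewer rows than the requested `sample_size` (default 100):

```python
import math

def background_sample_size(n_rows: int, sample_size: int = 100) -> int:
    # Approximation of the masker sizing around shap_calc's `sample` call.
    # Assumption: the 20% fallback triggers when the data is smaller than
    # the requested background sample.
    if n_rows < sample_size:
        return math.ceil(n_rows * 0.2)
    return sample_size

assert background_sample_size(1000) == 100  # large data: use the default
assert background_sample_size(50) == 10     # small data: 20% of the rows
```

The resulting sample is then drawn with the new `random_state` argument, so the `Explainer`'s masker, and hence the SHAP values, are reproducible across runs.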
1 change: 1 addition & 0 deletions pyproject.toml
@@ -37,6 +37,7 @@ dependencies = [
"shap>=0.43.0 ; python_version != '3.8'",
"numpy>=1.23.2",
"numba>=0.57.0",
"loguru>=0.7.2",
]

[project.urls]