Skip to content

Commit

Permalink
sklearn > 0.23 compatible transform
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Mar 29, 2022
1 parent c937292 commit 53bbbd3
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 72 deletions.
178 changes: 107 additions & 71 deletions ch10/ch10.ipynb

Large diffs are not rendered by default.

69 changes: 68 additions & 1 deletion ch10/ch10.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
from mlxtend.plotting import heatmap
from sklearn.preprocessing import StandardScaler
from distutils.version import LooseVersion
import sklearn
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import RANSACRegressor
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -181,6 +183,10 @@







class LinearRegressionGD(object):

def __init__(self, eta=0.001, n_iter=20):
Expand Down Expand Up @@ -265,9 +271,15 @@ def lin_regplot(X, y, model):





num_rooms_std = sc_x.transform(np.array([[5.0]]))
price_std = lr.predict(num_rooms_std)
print("Price in $1000s: %.3f" % sc_y.inverse_transform(price_std))

if LooseVersion(sklearn.__version__) >= LooseVersion('0.23.0'):
print("Price in $1000s: %.3f" % sc_y.inverse_transform(price_std[:, np.newaxis]).flatten())
else:
print("Price in $1000s: %.3f" % sc_y.inverse_transform(price_std))



Expand Down Expand Up @@ -725,3 +737,58 @@ def lin_regplot(X, y, model):


























































0 comments on commit 53bbbd3

Please sign in to comment.