Skip to content

Commit

Permalink
[Course] Draw data plot
Browse files Browse the repository at this point in the history
Type : feat.

Abstract:
* Draw data plot
  • Loading branch information
JackFunfia committed May 24, 2023
1 parent b9ec514 commit eca7da5
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 5 deletions.
15 changes: 10 additions & 5 deletions ch02/perceptron.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,24 @@

class Perceptron:
def __init__(self, eta=0.01, n_iter=50, random_state=1):
self.eta = eta
self.error_ = None
self.w_ = None
self.eta = max(0.0, min(eta, 1.0))
self.n_iter = n_iter
self.random_state = random_state

def fit(self, X, y):
rgen = np.random.RandomState(self.random_state)
self.w_ = rgen.normal(loc=0.0, scale=0.01, size=1 + X.shape[1])
self.w_ = rgen.normal(loc=0.0, scale=0.01, size=1 + X.Shape[1])
self.error_ = []

for _ in range(self.n_iter):
errors = 0
for xi, target in zip(X, y):
zip_data = zip(X, y)
for xi, target in zip_data:
print(f"{xi}, {target}")
update = self.eta * (target - self.predict(xi))
print(f"update {update}")
self.w_[1:] += update * xi
self.w_[0] += update
errors += int(update != 0)
Expand All @@ -30,5 +35,5 @@ def net_input(self, X):
return res

def predict(self, X):
res = np.Where(self.net_input(X) >= 0.0, 1, -1)
return res
res = np.where(self.net_input(X) >= 0.0, 1, -1)
return res
49 changes: 49 additions & 0 deletions ch02/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import matplotlib.pyplot as plt
import pandas as pd
import requests
import io
import os
from perceptron import Perceptron
PATH = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"

with requests.get(PATH) as response:
raw_data = response.text

# s = os.path.join("https://archive.ics.uci.edu", "ml",
# "machine-learning-databases",
# "iris", "iris.data")
s = os.path.join(PATH)

df = pd.read_csv(s)
# print(df.tail())
# string_data = io.StringIO(raw_data)
#
# df = pd.read_csv(string_data, header=None, encoding="utf-8")
#
# print(df.tail())

import matplotlib.pyplot as plot
import numpy as np

y = df.iloc[:, 4].values
y = np.where(y == "Iris-setosa", -1, 1)

X = df.iloc[:, [0, 2, 4]].values
setosa_array = X[(X[:, 2] == "Iris-setosa")]
virginica_array = X[(X[:, 2] == "Iris-virginica")]
versicolor_array = X[(X[:, 2] == "Iris-versicolor")]

plt.scatter(setosa_array[:, 0], setosa_array[:, 1],
color="red", marker="o", label="setosa")
plt.scatter(virginica_array[:, 0], virginica_array[:, 1],
color="blue", marker="x", label="virginica")
plt.scatter(versicolor_array[:, 0], versicolor_array[:, 1],
color="green", marker="^", label="versicolor")

plt.xlabel("sepal length [cm]")
plt.ylabel("petal length [cm]")
plt.legend(loc="upper left")
plt.show()

ppn = Perceptron(eta=0.1, n_iter=10)
# ppn.fit(X, y)

0 comments on commit eca7da5

Please sign in to comment.