Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
U-DESKTOP-A6FGL28\nekit committed Aug 17, 2020
1 parent 8a9e3d5 commit a8af5d0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 36 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
<h1>Update</h1>

Pushed a new implementation os starnet in TF2.x. The whole implementation is in one file starnet_v1_TF2.py.

I also created a few Jupyter notebooks for ease of use:

1. starnet_v1_TF2_transform.ipynb - loads and transforms an image.
2. starnet_v1_TF2.ipynb - more detailed example that loads a model and shows how to train it (really simple as well I think).


**StarNet** is a neural network that can remove stars from images in one simple step leaving only background.

More technically it is a convolutional residual net with encoder-decoder architecture and with L1, Adversarial and Perceptual losses.
Expand Down
32 changes: 0 additions & 32 deletions starnet_v1_TF2.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,6 @@ def _ramp(self, x):
return tf.clip_by_value(x, 0, 1)

def _augmentator(self, o, s):
# rotate
if np.random.rand() < 0.33:
r = np.random.randint(360)
pad = int(self.window_size / 2)
bc = np.random.rand()

if np.random.rand() < 0.50:
resample = img.BICUBIC
elif np.random.rand() < 0.75:
resample = img.BILINEAR
else:
resample = img.NEAREST

Xtmp = copy.copy(o)
Ytmp = copy.copy(s)

if np.random.rand() < 0.90:
Xtmp = np.pad(Xtmp, ((pad, pad), (pad, pad), (0, 0)), mode = 'reflect')
Ytmp = np.pad(Ytmp, ((pad, pad), (pad, pad), (0, 0)), mode = 'reflect')
else:
Xtmp = np.pad(Xtmp, ((pad, pad), (pad, pad), (0, 0)), mode = 'constant', constant_values = bc)
Ytmp = np.pad(Ytmp, ((pad, pad), (pad, pad), (0, 0)), mode = 'constant', constant_values = bc)

Xtmp = img.fromarray(np.uint8(Xtmp * 255))
Ytmp = img.fromarray(np.uint8(Ytmp * 255))

Xtmp = Xtmp.rotate(r, resample = resample)
Ytmp = Ytmp.rotate(r, resample = resample)

o = np.array(Xtmp)[pad:-pad, pad:-pad] / 255
s = np.array(Ytmp)[pad:-pad, pad:-pad] / 255

# flip horizontally
if np.random.rand() < 0.50:
o = np.flip(o, axis = 1)
Expand Down
10 changes: 6 additions & 4 deletions wherearemyweights.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
To download pre-trained weights, follow this link (~500Mb)
To download pre-trained weights, follow the links below:

https://www.dropbox.com/s/atcs42ox4n99w96/starnet_weights.zip?dl=0
For TF1.x implementation:

And extract all the content into root folder (folder with this file).
https://www.dropbox.com/s/atcs42ox4n99w96/starnet_weights.zip?dl=0

For TF2.x implementation:

https://www.dropbox.com/s/lcgn5gvnxpo27s5/starnet_weights2.zip?dl=0

Or you might also use lfs (if you know what that means, otherwise ignore this).
And extract all the content into root folder (folder with this file).

0 comments on commit a8af5d0

Please sign in to comment.