Skip to content

Commit

Permalink
Stride error corrected
Browse files Browse the repository at this point in the history
  • Loading branch information
U-DESKTOP-A6FGL28\nekit committed Aug 23, 2020
1 parent c170b3e commit 867a279
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
10 changes: 5 additions & 5 deletions starnet_v1_TF2.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,15 @@ def transform(self, name):

output = copy.deepcopy(image)

h, w, _ = image.shape
ith = int(h / self.stride)
itw = int(w / self.stride)
for i in range(ith):
for j in range(itw):
tile = np.expand_dims(image[self.stride*i:self.stride*i+self.window_size, self.stride*j:self.stride*j+self.window_size, :], axis = 0)
x = self.stride * i
y = self.stride * j

tile = np.expand_dims(image[x:x+self.window_size, y:y+self.window_size, :], axis = 0)
tile = (self.G(tile)[0] + 1) / 2
tile = tile[offset:offset+self.stride, offset:offset+self.stride, :]
output[self.stride*i+offset:self.stride*(i+1)+offset, self.stride*j+offset:self.stride*(j+1)+offset, :] = tile
output[x+offset:self.stride*(i+1)+offset, y+offset:self.stride*(j+1)+offset, :] = tile

output = np.clip(output, 0, 1)

Expand Down
13 changes: 8 additions & 5 deletions starnet_v1_TF2_transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from PIL import Image as img\n",
"import logging\n",
"tf.get_logger().setLevel(logging.ERROR)\n",
"from starnet_v1_TF2 import StarNet"
"from starnet_v1_TF2 import StarNet\n",
"\n",
"%reload_ext autoreload\n",
"%autoreload 2"
]
},
{
Expand All @@ -22,11 +25,11 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"starnet = StarNet(mode = 'RGB', window_size = 512)\n",
"starnet = StarNet(mode = 'RGB', window_size = 512, stride = 128)\n",
"starnet.load_model('./weights', './history')"
]
},
Expand All @@ -39,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down

0 comments on commit 867a279

Please sign in to comment.