Merge pull request #31 from NevesLucas/documentation_dml
Update documentation, add Conda environments
Nikita Misiura authored Sep 12, 2022
2 parents 7079110 + d6ba2f2 commit a6ba2c3
Showing 7 changed files with 130 additions and 65 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
.idea/
70 changes: 45 additions & 25 deletions README.md
@@ -1,3 +1,7 @@
<h1>Update 2</h1>
Updated documentation for using TensorFlow-DirectML on Windows, which gives broad support for any modern GPU with sufficient memory.

<hr>
<h1>Update</h1>

Pushed a new implementation of starnet in TF2.x. The whole implementation is in one file *starnet_v1_TF2.py*.
@@ -127,20 +131,55 @@ I left one training image to show organization of folders my code expects. Insid
Throughout the code all input and output images I use are 8 bits per channel **tif** images.
This code should read some other image formats (like jpeg, 16bit tiff, etc), but I did not check all of them.

<center><h1>Prerequisites</h1></center>
<center><h1>Prerequisites and Installation Guide</h1></center>

For all environments, using conda is strongly encouraged. The installation instructions assume a conda install of either Anaconda Python or Miniconda:
- https://docs.conda.io/en/latest/miniconda.html

## Windows (New!)

On Windows we can now run starnet on the GPU of any modern graphics card (yes, AMD and Intel included)!

### Prerequisites

Windows 10 Version 1709, 64-bit (Build 16299 or higher) or Windows 11 Version 21H2, 64-bit (Build 22000 or higher)

### Installation

Python and Tensorflow, preferably Tensorflow-GPU if you have an NVidia GPU. In this case you will also need CUDA and CuDNN libraries.
Once Anaconda is installed, open an "Anaconda Powershell Prompt" to proceed.

I tested it in Python 3.6.3 (Anaconda) + TensorFlow-GPU 1.4.0
Use the provided environment config file to configure and install all the dependencies:

Environment: I used Win 10 + Cygwin
#### With GPU support (Windows):
```
conda env create -f environment-windows.yml
```
#### With CUDA support (Linux or Windows):
```
conda env create -f environment-lnx-cuda.yml
```
#### CPU only (Mac, Linux, Windows):
```
conda env create -f environment-cpu.yml
```
### Post installation
Initialize the environment with:
```
conda activate starnet
```
And you're ready to go!
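
To confirm that the activated environment actually sees your hardware, a quick check like the one below may help. It is a minimal sketch and not part of starnet: `list_tf_devices` is a hypothetical helper name, and it assumes a TensorFlow 2.x install where `tf.config.list_physical_devices` is available. If TensorFlow cannot be imported, it returns an empty list instead of failing.

```python
def list_tf_devices():
    """Return the names of the devices TensorFlow can see.

    Hypothetical helper for illustration only; returns [] when
    TensorFlow is not importable in the current environment.
    """
    try:
        import tensorflow as tf
    except ImportError:
        return []
    # Each entry is a PhysicalDevice with a .name and a .device_type.
    return [device.name for device in tf.config.list_physical_devices()]


if __name__ == "__main__":
    names = list_tf_devices()
    if not names:
        print("TensorFlow is not installed or reports no devices.")
    for name in names:
        print(name)
```

Run it inside the activated `starnet` environment. With a working GPU setup the list should contain a GPU entry next to the CPU one; a CPU-only install should list only the CPU.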

GPU was NVidia GeForce 840M 2Gb, compute capability 5.0, CUDA version 9.1

Originally tested on:
- Win 10 + Cygwin
- NVidia GeForce 840M 2Gb, compute capability 5.0, CUDA version 9.1

Windows general GPU support tested on:
- Win 10 12H1
- AMD RX 6800-XT 16GB

<center><h1>Usage</h1></center>

Modes of use:
python.exe -u starnet.py transform <input_image> - The most probable use. This command will transform
input image (namely will remove stars) and will
@@ -168,9 +207,6 @@ GPU was NVidia GeForce 840M 2Gb, compute capability 5.0, CUDA version 9.1
By default output will be in './test' sub-folder.





<center><h1>Couple more examples</h1></center>

More examples can be found <a href="https://www.astrobin.com/339099/0/">here</a>.
@@ -200,22 +236,6 @@ This whole thing works as command line program, which means that there is no gra
run it in a console using some text commands (like ones you see above) and it outputs text (and writes image files
of course!).

The way you get on board with all this will depend on your OS and is hard to squeeze instructions into few words but
very briefly:

1. For MacOS/Linux it is pretty simple. You should already have console up with python installed. The only thing you will need
is to use pip to install tensorflow and probably few other packages via:

pip install tensorflow
pip install numpy
pip install <whatever is missing>

2. For Windows it is a bit trickier because console is unusable to say the least. You can use it, but I prefer Cygwin. Next, you
need to install python. I think installing Anaconda for that is by far the best option. After you got up the console of your choice
and installed anaconda you use its native pip to install tensorflow (see above) and you should be ready to go.

<b>The whole installation should not need more than installing few software packages and typing few command lines!</b>

2. Where exactly do I put weights of the network?

All the files you download should be in one folder: all the files with extension .py (starnet.py, train.py, transform.py, etc.) should
13 changes: 13 additions & 0 deletions environment-cpu.yml
@@ -0,0 +1,13 @@
# To run: conda env create -f environment-cpu.yml
name: starnet
channels:
- conda-forge
dependencies:
- python=3.9
- pip
- Pillow
- ipython
- matplotlib
- numpy
- tifffile
- tensorflow-cpu
13 changes: 13 additions & 0 deletions environment-lnx-cuda.yml
@@ -0,0 +1,13 @@
# To run: conda env create -f environment-lnx-cuda.yml
name: starnet
channels:
- conda-forge
dependencies:
- python=3.9
- pip
- Pillow
- ipython
- matplotlib
- numpy
- tifffile
- tensorflow
15 changes: 15 additions & 0 deletions environment-windows.yml
@@ -0,0 +1,15 @@
# To run: conda env create -f environment-windows.yml
name: starnet
channels:
- conda-forge
dependencies:
- python=3.9
- pip
- Pillow
- matplotlib
- numpy
- ipython
- tifffile
- pip:
- tensorflow-cpu
- tensorflow-directml-plugin
6 changes: 3 additions & 3 deletions starnet.py
@@ -41,7 +41,7 @@
learning_rates = [0.000002, 0.0000005] # Learning rates: the first is for generator, the second is for discriminator. Usually they are the same,
# but who knows. In the beginning of training suitable values are about 0.0002 and then can be made smaller
# as the model gets better.
stride = 128 # Stride value for image transformation. The smaller it gets, the less artefacts you get in the final image,
stride = 64 # Stride value for image transformation. The smaller it gets, the less artefacts you get in the final image,
# but the more time it takes to transform an image. 100 looks about optimal for now.

if len(sys.argv) > 1:
@@ -103,8 +103,8 @@
else:
start = time.time()
import transform
transform.transform(image = sys.argv[2],
stride = stride)
transform.transform(imageName = sys.argv[2],
stride = stride)
stop = time.time()
t = float((stop - start) / 60)
if t > 60.0:
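
The stride change from 128 to 64 affects run time because transform.py slides a window over the image in steps of `stride`. The sketch below reproduces that tile arithmetic in isolation, under stated assumptions: `tiling_plan` is a hypothetical name, and the `dh` formula is assumed by symmetry with the `dw` line that is visible in the transform.py diff.

```python
def tiling_plan(h, w, stride):
    """Return (ith, itw, dh, dw) for an image of h x w pixels.

    ith, itw: number of window positions along height and width.
    dh, dw:   rows and columns of padding needed so the padded size
              is an exact multiple of the stride.
    """
    ith = int(h / stride) + 1
    itw = int(w / stride) + 1
    dh = ith * stride - h  # assumed, mirrors the dw line in the diff
    dw = itw * stride - w
    return ith, itw, dh, dw


if __name__ == "__main__":
    for stride in (128, 64):
        ith, itw, dh, dw = tiling_plan(1000, 1500, stride)
        print(stride, ith * itw, dh, dw)
```

For a 1000 x 1500 image this gives 8 x 12 = 96 tiles at stride 128 and 16 x 24 = 384 tiles at stride 64, so halving the stride roughly quadruples the number of network passes.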
77 changes: 40 additions & 37 deletions transform.py
@@ -16,18 +16,17 @@
import tensorflow as tf
from PIL import Image as img
import matplotlib.pyplot as plt
from scipy.misc import toimage
import matplotlib
import sys
import time
import model
import starnet_utils
import tifffile as tiff

WINDOW_SIZE = 256 # Size of the image fed to net. Do not change until you know what you are doing! Default is 256
# and changing this will force you to train the net anew.
def transform(imageName, stride):

def transform(image, stride):

# placeholders for tensorflow
X = tf.placeholder(tf.float32, shape = [None, WINDOW_SIZE, WINDOW_SIZE, 3], name = "X")
Y = tf.placeholder(tf.float32, shape = [None, WINDOW_SIZE, WINDOW_SIZE, 3], name = "Y")
@@ -52,14 +51,23 @@ def transform(image, stride):

# read input image
print("Opening input image...")
input = np.array(img.open(image), dtype = np.float32)
print("Done!")

# rescale to [-1, 1]
input /= 255
# backup to use for mask
backup = np.copy(input)
input = input * 2 - 1
data = tiff.imread(imageName)
if len(data.shape) > 3:
layer = input("Tiff has %d layers, please enter layer to process: " % data.shape[0])
layer = int(layer)
data = data[layer]

input_dtype = data.dtype
if input_dtype == 'uint16':
image = (data / 255.0 / 255.0).astype('float32')
elif input_dtype == 'uint8':
image = (data / 255.0).astype('float32')
else:
raise ValueError('Unknown image dtype:', data.dtype)

if image.shape[2] == 4:
print("Input image has 4 channels. Removing Alpha-Channel")
image = image[:, :, [0, 1, 2]]


# now some tricky magic
@@ -69,7 +77,7 @@ def transform(image, stride):

# get size of the image and calculate numbers of iterations needed to transform it
# given stride and taking into account that we will pad it a bit later (+1 comes from that)
h, w, _ = input.shape
h, w, _ = image.shape
ith = int(h / stride) + 1
itw = int(w / stride) + 1

@@ -78,18 +86,18 @@ def transform(image, stride):
dw = itw * stride - w

# pad image using parts of the image itself and values calculated above
input = np.concatenate((input, input[(h - dh) :, :, :]), axis = 0)
input = np.concatenate((input, input[:, (w - dw) :, :]), axis = 1)
image = np.concatenate((image, image[(h - dh) :, :, :]), axis = 0)
image = np.concatenate((image, image[:, (w - dw) :, :]), axis = 1)

# get image size again and pad to allow offsets on all four sides of the image
h, w, _ = input.shape
input = np.concatenate((input, input[(h - offset) :, :, :]), axis = 0)
input = np.concatenate((input[: offset, :, :], input), axis = 0)
input = np.concatenate((input, input[:, (w - offset) :, :]), axis = 1)
input = np.concatenate((input[:, : offset, :], input), axis = 1)
h, w, _ = image.shape
image = np.concatenate((image, image[(h - offset) :, :, :]), axis = 0)
image = np.concatenate((image[: offset, :, :], image), axis = 0)
image = np.concatenate((image, image[:, (w - offset) :, :]), axis = 1)
image = np.concatenate((image[:, : offset, :], image), axis = 1)

# copy input image to output
output = np.copy(input)
output = np.copy(image)

# helper array just to add fourth dimension to net input
tmp = np.zeros((1, WINDOW_SIZE, WINDOW_SIZE, 3), dtype = np.float)
@@ -103,30 +111,25 @@ def transform(image, stride):
y = stride * j

# write piece of input image to tmp array
tmp[0] = input[x : x + WINDOW_SIZE, y : y + WINDOW_SIZE, :]
tmp[0] = image[x : x + WINDOW_SIZE, y : y + WINDOW_SIZE, :]

# transform
result = sess.run(outputs, feed_dict = {X:tmp})

# write transformed array to output
output[x + offset : x + stride + offset, y + offset: y + stride + offset, :] = result[0, offset : stride + offset, offset : stride + offset, :]
print("Transforming input image... Done!")

# rescale back to [0, 1]
output = (output + 1) / 2
output = np.clip(output, 0, 1)

# leave only necessary part, without pads added earlier
output = output[offset : - (offset + dh), offset : - (offset + dw), :]
output = output[offset:-(offset + dh), offset:-(offset + dw), :]

print("Saving output image...")
toimage(output * 255, cmin = 0, cmax = 255).save('./' + image + '_starless.tif')
print("Done!")

print("Saving mask...")
# mask showing areas that were changed significantly
mask = (((backup * 255).astype(np.int_) - (output * 255).astype(np.int_)) > 25).astype(np.int_)
mask = mask.max(axis = 2, keepdims = True)
mask = np.concatenate((mask, mask, mask), axis = 2)
toimage(mask * 255, cmin = 0, cmax = 255).save('./' + image + '_mask.tif')
print("Done!")

if input_dtype == 'uint8':
tiff.imsave('./' + imageName + '_starless.tif', (output * 255).astype('uint8'))
else:
tiff.imsave('./' + imageName + '_starless.tif', (output * 255 * 255).astype('uint16'))

print("Done!")
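
The dtype handling added to transform.py can be summarised as a pair of scale factors: 255 for 8-bit data and 255 * 255 = 65025 for 16-bit data, applied on load and reversed on save. The sketch below is an illustration of that scheme, not the commit's code: `SCALE`, `scale_for`, `normalize` and `denormalize` are hypothetical names. It uses plain arithmetic, so it works on Python numbers and should also apply elementwise to NumPy arrays (the cast to `float32` and back to an integer dtype is left out here).

```python
# Divisors used in the diff above: 255 for uint8, 255 * 255 for uint16.
SCALE = {"uint8": 255.0, "uint16": 255.0 * 255.0}


def scale_for(dtype_name):
    """Return the divisor for a dtype name such as 'uint8' or 'uint16'."""
    try:
        return SCALE[dtype_name]
    except KeyError:
        raise ValueError("Unknown image dtype: %s" % dtype_name)


def normalize(data, dtype_name):
    """Scale integer pixel data towards the [0, 1] range."""
    return data / scale_for(dtype_name)


def denormalize(data, dtype_name):
    """Reverse normalize(); the caller still has to cast to the integer dtype."""
    return data * scale_for(dtype_name)


if __name__ == "__main__":
    print(normalize(255, "uint8"))
    print(normalize(65535, "uint16"))
```

Note that 65025 is slightly smaller than the 16-bit maximum of 65535, so with this scheme a full-scale 16-bit pixel normalizes to 65535 / 65025, a little above 1.0. Dividing by 65535 instead would map it to exactly 1.0, but that is a different convention from the one in this commit.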
