Skip to content

Commit

Permalink
Data generation for AE and NDSI unsupervised pretraining
Browse files Browse the repository at this point in the history
  • Loading branch information
n9Mtq4 committed Sep 13, 2020
1 parent 42ebb97 commit acf00e9
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 7 deletions.
91 changes: 91 additions & 0 deletions TileGen/src/main/kotlin/com/n9mtq4/gss/tilegen/AETileGen.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package com.n9mtq4.gss.tilegen

import java.awt.image.BufferedImage
import java.io.File
import javax.imageio.ImageIO

/**
* Created by will on 9/8/20 at 9:36 PM.
*
* @author Will "n9Mtq4" Bresnahan
*/

val DATA_AE_INPUT_DIR = File("../data/ls78_ae")
val TILE_AE_OUTPUT_DIR = File("../data/tiles/ls78_ae")

fun main() {

TILE_AE_OUTPUT_DIR.mkdirs()

// val tileAllImgs = listOf("LC08_L1TP_044006_20150711_20170227_01_T1")
// val tileAllImgs = FULL_IMGS
val tileAllImgs = emptyArray<String>()

val imageDirs = DATA_AE_INPUT_DIR.listFiles { dir, name -> File(dir, name).isDirectory }!!

imageDirs.forEach { processAEImageDir(it, it.name in tileAllImgs) }

}

/**
* Processes a directory of a landsat image.
* Finds tiles from the truth image. Saves the bands and truth tiles in the output directory.
* */
fun processAEImageDir(imageDir: File, tileAll: Boolean) {

val imgName = imageDir.name
val otherBands = imageDir.listFiles()!!

val firstBand = ImageIO.read(otherBands.first())
val allTiles = findAllTileLocs(firstBand, tileAll)
.toList()

val tiles = allTiles
.shuffled()
.take(allTiles.size / 4)

// extract other bands
for (bandFile in otherBands) {

val bandNum = getBandFromName(bandFile.name)

if (bandNum !in INCLUDE_BANDS) continue

val bandImage = ImageIO.read(bandFile)

// if (bandImage.width != groundTruth.width || bandImage.height != groundTruth.height) {
// println("${imageDir.name} B$bandNum doesn't match width and height. Band width = ${bandImage.width}, Band height = ${bandImage.height}, truth width = ${groundTruth.width}, truth height = ${groundTruth.height}")
// continue
// }

val bandDir = File(TILE_AE_OUTPUT_DIR, "B$bandNum")
bandDir.mkdirs()

// try mean and std match from http://www.fmwconcepts.com/imagemagick/index.php matchimage
val adjBandImage = stretchToMinMax(bandImage)
// val adjBandImage = bandImage

tiles
.map { tile -> tile to extractTile(adjBandImage, tile) }
.map { (tile, img) -> tile to listOf(img) }
.forEach { (tile, imgs) -> writeTileGroup(imgName, "", bandDir, tile, imgs) }

}

}

fun findAllTileLocs(firstBand: BufferedImage, tileAll: Boolean) = sequence {

for (y in 0 until (firstBand.height - 2 * TILE_Y_STRIDE) step TILE_Y_STRIDE) {
for (x in 0 until (firstBand.width - 2 * TILE_X_STRIDE) step TILE_X_STRIDE) {

if (tileNotBlack(firstBand, x, y, TILE_WIDTH, TILE_HEIGHT)) {
yield(Tile(x, y, TILE_WIDTH, TILE_HEIGHT))
}

}
}

return@sequence

}
55 changes: 55 additions & 0 deletions TileGen/src/main/kotlin/com/n9mtq4/gss/tilegen/ImgMath.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.n9mtq4.gss.tilegen

import java.awt.image.BufferedImage
import kotlin.math.roundToInt

/**
* Created by will on 9/10/20 at 5:38 PM.
*
* @author Will "n9Mtq4" Bresnahan
*/

inline fun forPixels(width: Int, height: Int, body: (Int, Int) -> Unit) {
for (y in 0 until height) {
for (x in 0 until width) {
body(x, y)
}
}
}

inline fun imgMath(i1: BufferedImage, i2: BufferedImage, band: Int = 0, func: (Int, Int) -> Double): Array<DoubleArray> {

val outputRaster = Array(i1.width) { DoubleArray(i1.height) { 0.0 } }

forPixels(i1.width, i1.height) { x, y ->

val i1Val = i1.raster.getSample(x, y, band)
val i2Val = i2.raster.getSample(x, y, band)

outputRaster[x][y] = func(i1Val, i2Val)

}

return outputRaster

}

fun scaleRasterTo255(raster: Array<DoubleArray>, min: Double? = null, max: Double? = null) {

val minVal = min ?: d2RasterMin(raster)
val maxVal = max ?: d2RasterMax(raster)
val scale = 255.0 / (maxVal - minVal)

forPixels(raster.size, raster[0].size) { x, y ->
raster[x][y] = ((raster[x][y] - minVal) * scale).coerceIn(0.0, 255.0)
}

}

fun d2RasterMax(raster: Array<DoubleArray>): Double {
return raster.map { it.max()!! }.max()!!
}

fun d2RasterMin(raster: Array<DoubleArray>): Double {
return raster.map { it.min()!! }.min()!!
}
51 changes: 51 additions & 0 deletions TileGen/src/main/kotlin/com/n9mtq4/gss/tilegen/NDSITileGen.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.n9mtq4.gss.tilegen

import java.io.File

/**
* Created by will on 9/10/20 at 9:04 PM.
*
* @author Will "n9Mtq4" Bresnahan
*/
fun main() {

val inputDir = File("../data/ls8_pretrain")
val outputDir = File("../data/tiles/ls8_pretrain1")


val bandDirs = arrayOf(1, 2, 3, 4, 5, 6, 7)
.map { File(outputDir, "B$it") }
bandDirs.forEach { it.mkdirs() }


val imgDirs = inputDir.listFiles()!!

imgDirs.forEach { ndsiProcImgDir(it, outputDir) }

}

fun ndsiProcImgDir(imgDir: File, outputDir: File) {

val imgFiles = imgDir.listFiles()!!.toList()

val mbImg = SatImg(imgFiles)

val tiles = findAllTileLocs(mbImg.firstBand(), true)
.toList()

mbImg.bands.forEach { band, img ->
tiles.map { tile -> tile to extractTile(img, tile) }
.map { (tile, img) -> tile to listOf(img) }
.forEach { (tile, imgs) -> writeTileGroup(imgDir.name, "", File(outputDir, "B$band"), tile, imgs) }
}

val ndsi = mbImg.ndsi()

val truthDir = File(outputDir, "truth")
truthDir.mkdirs()
tiles
.map { tile -> tile to extractTile(ndsi, tile) }
.map { (tile, img) -> tile to listOf(img) }
.forEach { (tile, imgs) -> writeTileGroup(imgDir.name, "", truthDir, tile, imgs) }

}
101 changes: 101 additions & 0 deletions TileGen/src/main/kotlin/com/n9mtq4/gss/tilegen/SatImg.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package com.n9mtq4.gss.tilegen

import java.awt.image.BufferedImage
import java.io.File
import javax.imageio.ImageIO

/**
* Created by will on 9/10/20 at 5:18 PM.
*
* @author Will "n9Mtq4" Bresnahan
*/
class SatImg(val bandsFiles: List<File>, val landSat: Int = 8) {

val bands = HashMap<Int, BufferedImage>()

init {

bandsFiles
.map { getBandFromName(it.name) to ImageIO.read(it) }
.forEach { (bandNum, img) -> bands[bandNum] = img }

}

fun firstBand() = bands.values.first()

operator fun get(bandNum: Int): BufferedImage = bands[bandNum]!!

fun ndsi(): BufferedImage {

val greenBand = when (landSat) {
8 -> 3
else -> -1
}
val swirBand = when(landSat) {
8 -> 6
else -> -1
}

val greenImg = this[greenBand]
val swirImg = this[swirBand]

// calculate ndsi
val ndsiRaster = imgMath(greenImg, swirImg) { green, swir ->
(green - swir).toDouble() / (green + swir + 1).toDouble()
}

scaleRasterTo255(ndsiRaster, min = -0.4, max = 1.0)

// set ndwi pixel values to that of the raster
val ndsi = BufferedImage(greenImg.width, greenImg.height, greenImg.type)
forPixels(ndsi.width, ndsi.height) { x, y ->
ndsi.raster.setSample(x, y, 0, ndsiRaster[x][y].toInt())
// ndsi.raster.setSample(x, y, 1, ndsiRaster[x][y].toInt())
// ndsi.raster.setSample(x, y, 2, ndsiRaster[x][y].toInt())
}

return ndsi

}

fun ndwi(): BufferedImage {

val greenBand = when (landSat) {
8 -> 3
else -> -1
}
val nirBand = when (landSat) {
8 -> 5
7 -> 4
else -> -1
}
val swirBand = when(landSat) {
8 -> 6
7 -> 5
else -> -1
}

val greenImg = this[greenBand]
val nirImg = this[nirBand]
val swirImg = this[swirBand]

// calculate ndwi
val ndwiRaster = imgMath(greenImg, nirImg) { green, nir ->
(green - nir).toDouble() / (green + nir + 1).toDouble()
}

scaleRasterTo255(ndwiRaster, min = -0.3, max = 0.2)

// set ndwi pixel values to that of the raster
val ndwi = BufferedImage(nirImg.width, nirImg.height, nirImg.type)
forPixels(ndwi.width, ndwi.height) { x, y ->
ndwi.raster.setSample(x, y, 0, ndwiRaster[x][y].toInt())
// ndwi.raster.setSample(x, y, 1, ndwiRaster[x][y].toInt())
// ndwi.raster.setSample(x, y, 2, ndwiRaster[x][y].toInt())
}

return ndwi

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import kotlin.math.roundToInt
*
* @author Will "n9Mtq4" Bresnahan
*/
fun stretchToMinMax(img: BufferedImage): BufferedImage {
fun stretchToMinMax(img: BufferedImage, minThresh: Int = 3): BufferedImage {

val newImg = BufferedImage(img.width, img.height, BufferedImage.TYPE_BYTE_GRAY)

Expand All @@ -20,7 +20,7 @@ fun stretchToMinMax(img: BufferedImage): BufferedImage {

val grey = img.raster.getSample(x, y, 0)
@Suppress("ConvertTwoComparisonsToRangeCheck")
if (grey < min && grey >= 3) min = grey // min will always be 0 with black rotation border, so make it at least 3
if (grey < min && grey >= minThresh) min = grey // min will always be 0 with black rotation border, so make it at least 3
if (grey > max) max = grey

}
Expand Down
10 changes: 5 additions & 5 deletions TileGen/src/main/kotlin/com/n9mtq4/gss/tilegen/TileGen.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ const val TILE_Y_STRIDE = TILE_HEIGHT / 2

const val NO_TRANSFORMS = false

val INCLUDE_BANDS = intArrayOf(1, 2, 3, 4, 5, 7)
val INCLUDE_BANDS = intArrayOf(1, 2, 3, 4, 5, 6, 7)

val DATA_INPUT_DIR = File("../data/ls78")
val TILE_OUTPUT_DIR = File("../data/tiles/ls78")
val DATA_INPUT_DIR = File("../data/ls8")
val TILE_OUTPUT_DIR = File("../data/tiles/ls8")

val FULL_IMGS = arrayOf(
"LC08_L1TP_045005_20190814_20190820_01_T1",
Expand Down Expand Up @@ -92,8 +92,8 @@ fun processImageDir(imageDir: File, tileAll: Boolean) {
bandDir.mkdirs()

// try mean and std match from http://www.fmwconcepts.com/imagemagick/index.php matchimage
val adjBandImage = stretchToMinMax(bandImage)
// val adjBandImage = bandImage
// val adjBandImage = stretchToMinMax(bandImage)
val adjBandImage = bandImage

tiles
.map { tile -> tile to extractTile(adjBandImage, tile) }
Expand Down

0 comments on commit acf00e9

Please sign in to comment.