Skip to content

Commit

Permalink
ENH: use temporary "/vsimem" file handling, simplify _create_dst_data…
Browse files Browse the repository at this point in the history
…source
  • Loading branch information
kmuehlbauer committed Jun 11, 2019
1 parent f5af37e commit a81e454
Showing 1 changed file with 42 additions and 37 deletions.
79 changes: 42 additions & 37 deletions wradlib/zonalstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
import matplotlib.patches as patches
from osgeo import gdal, ogr
import warnings
import tempfile
import os

from .io import open_vector, gdal_create_dataset, write_raster_dataset
from .georef import (numpy_to_ogr, ogr_add_feature, ogr_copy_layer,
Expand Down Expand Up @@ -204,7 +206,9 @@ def _check_src(self, src):
- transforming source grid points/polygons to ogr.geometries
on ogr.layer
"""
ogr_src = gdal_create_dataset('Memory', 'out',
tmpfile = tempfile.NamedTemporaryFile(mode='w+b').name
ogr_src = gdal_create_dataset('ESRI Shapefile',
os.path.join('/vsimem', tmpfile),
gdal_type=gdal.OF_VECTOR)

src = np.array(src)
Expand Down Expand Up @@ -254,10 +258,10 @@ def load_vector(self, filename, source=0, driver='ESRI Shapefile'):
driver : string
driver string
"""

self.ds = gdal_create_dataset('Memory', self._name,
tmpfile = tempfile.NamedTemporaryFile(mode='w+b').name
self.ds = gdal_create_dataset('ESRI Shapefile',
os.path.join('/vsimem', tmpfile),
gdal_type=gdal.OF_VECTOR)

# get input file handles
ds_in, tmp_lyr = open_vector(filename, driver=driver, layer=source)

Expand Down Expand Up @@ -310,7 +314,6 @@ def dump_raster(self, filename, driver='GTiff', attr=None,

band = ds_out.GetRasterBand(1)
band.FlushCache()
print("Rasterize layers")
if attr is not None:
gdal.RasterizeLayer(ds_out, [1], layer, burn_values=[0],
options=["ATTRIBUTE={0}".format(attr),
Expand Down Expand Up @@ -528,8 +531,15 @@ def _create_dst_datasource(self):
ds_mem : object
gdal.Dataset object
"""

# create mem-mapped temp file dataset
tmpfile = tempfile.NamedTemporaryFile(mode='w+b').name
ds_out = gdal_create_dataset('ESRI Shapefile',
os.path.join('/vsimem', tmpfile),
gdal_type=gdal.OF_VECTOR)

# create intermediate mem dataset
ds_mem = gdal_create_dataset('Memory', 'dst',
ds_mem = gdal_create_dataset('Memory', 'out',
gdal_type=gdal.OF_VECTOR)

# get src geometry layer
Expand All @@ -538,43 +548,38 @@ def _create_dst_datasource(self):
src_lyr.SetSpatialFilter(None)
geom_type = src_lyr.GetGeomType()

# create temp Buffer layer (time consuming)
ds_tmp = gdal_create_dataset('Memory', 'tmp',
gdal_type=gdal.OF_VECTOR)
ogr_copy_layer(self.trg.ds, 0, ds_tmp)
tmp_trg_lyr = ds_tmp.GetLayer()
# get trg geometry layer
trg_lyr = self.trg.ds.GetLayerByName('trg')
trg_lyr.ResetReading()
trg_lyr.SetSpatialFilter(None)

for i in range(tmp_trg_lyr.GetFeatureCount()):
feat = tmp_trg_lyr.GetFeature(i)
feat.SetGeometryDirectly(feat.GetGeometryRef().
Buffer(self._buffer))
tmp_trg_lyr.SetFeature(feat)
# buffer handling (time consuming)
if self._buffer > 0:
for i in range(trg_lyr.GetFeatureCount()):
feat = trg_lyr.GetFeature(i)
feat.SetGeometryDirectly(feat.GetGeometryRef().
Buffer(self._buffer))
trg_lyr.SetFeature(feat)

# get target layer, iterate over polygons and calculate intersections
tmp_trg_lyr.ResetReading()
# reset target layer
trg_lyr.ResetReading()

# create tmp dest layer
self.tmp_lyr = ogr_create_layer(ds_mem, 'dst', srs=self._srs,
geom_type=geom_type)

try:
tmp_trg_lyr.Intersection(src_lyr, self.tmp_lyr,
options=['SKIP_FAILURES=YES',
'INPUT_PREFIX=trg_',
'METHOD_PREFIX=src_',
'PROMOTE_TO_MULTI=YES',
'PRETEST_CONTAINMENT=YES'],
callback=gdal.TermProgress)
except RuntimeError:
# Catch RuntimeError that was reported on gdal 1.11.1
# on Windows systems
tmp_trg_lyr.Intersection(src_lyr, self.tmp_lyr,
options=['SKIP_FAILURES=YES',
'INPUT_PREFIX=trg_',
'METHOD_PREFIX=src_',
'PROMOTE_TO_MULTI=YES',
'PRETEST_CONTAINMENT=YES'])

return ds_mem
trg_lyr.Intersection(src_lyr, self.tmp_lyr,
options=['SKIP_FAILURES=YES',
'INPUT_PREFIX=trg_',
'METHOD_PREFIX=src_',
'PROMOTE_TO_MULTI=YES',
'USE_PREPARED_GEOMETRIES=YES',
'PRETEST_CONTAINMENT=YES'],
callback=gdal.TermProgress)

ogr_copy_layer(ds_mem, 0, ds_out)

return ds_out

def dump_vector(self, filename, driver='ESRI Shapefile', remove=True):
"""Output source/target grid points/polygons to ESRI_Shapefile
Expand Down

0 comments on commit a81e454

Please sign in to comment.