Skip to content

Commit

Permalink
update the readme
Browse files Browse the repository at this point in the history
  • Loading branch information
ayumiymk committed Jul 13, 2019
1 parent 3d27396 commit 28f3730
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 2 deletions.
35 changes: 33 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,33 @@
# aster.pytorch
ASTER in Pytorch
# ASTER: Attentional Scene Text Recognizer with Flexible Rectification

This repository implements ASTER in PyTorch. The original software can be found [here](https://github.com/bgshih/aster).

## Train
```
bash scripts/stn_att_rec.sh
```

## Test
```
bash scripts/main_test_all.sh
```

## Reproduced results

| | IIIT5k | SVT | IC03 | IC13 | SVTP | CUTE |
|:-------------:|:------:|:----:|:-----:|:-----:|:-----:|:-----:|
| ASTER (L2R) | 92.67 | - | 93.72 | 90.74 | 78.76 | 76.39 |
| ASTER.Pytorch | 93.2 | 89.2 | 92.2 | 91 | 81.2 | 81.9 |

At present, the bidirectional attention decoder proposed in ASTER is not included in my implementation.

You can use this code as a starting point for your next text recognition research project.


## Data preparation

We give an example of how to construct your own dataset. For details, please refer to `tools/create_svtp_lmdb.py`.


IMPORTANT NOTICE: Although this software is licensed under MIT, our intention is to make it free for academic research purposes. If you are going to use it in a product, we suggest you [contact us](mailto:[email protected]) regarding possible patent issues.
98 changes: 98 additions & 0 deletions lib/tools/create_svtp_lmdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re

def checkImageIsValid(imageBin):
    """Return True if imageBin holds decodable image data with nonzero area.

    Args:
        imageBin: raw encoded image bytes (e.g. the content of a JPEG file),
            or None.

    Returns:
        bool: False for None input, undecodable data, or a zero-area image.
    """
    if imageBin is None:
        return False
    # np.frombuffer replaces the deprecated np.fromstring for binary input.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode returns None for corrupt/undecodable bytes; the original
    # code would raise AttributeError on img.shape in that case.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True


def writeCache(env, cache):
    """Write every entry of *cache* into the LMDB environment *env*.

    All puts happen inside a single write transaction, which commits when
    the ``with`` block exits normally. Keys (str) are encoded to bytes;
    values are stored as given.
    """
    with env.begin(write=True) as txn:
        for key in cache:
            txn.put(key.encode(), cache[key])


def _is_difficult(word):
assert isinstance(word, str)
return not re.match('^[\w]+$', word)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """Create an LMDB dataset for text-recognition training.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image paths
        labelList     : list of corresponding ground-truth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if True, check the validity of every image

    Samples with empty labels, missing files, or (when checkValid) invalid
    image data are skipped; keys are renumbered contiguously from 1.
    """
    assert len(imagePathList) == len(labelList)
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)  # 1 TB virtual map size
    cache = {}
    cnt = 1
    for i in range(nSamples):
        imagePath = imagePathList[i]
        label = labelList[i]
        if len(label) == 0:
            continue
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            # BUG FIX: lmdb values must be bytes; the joined lexicon string
            # was previously stored unencoded and txn.put would fail.
            cache[lexiconKey] = ' '.join(lexiconList[i]).encode()
        # Flush in batches of 1000 to bound memory usage.
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples).encode()
    writeCache(env, cache)
    # Close the environment so the final transaction is flushed to disk.
    env.close()
    print('Created dataset with %d samples' % nSamples)

if __name__ == "__main__":
data_dir = '/data/mkyang/datasets/English/benchmark/svtp/'
lmdb_output_path = '/data/mkyang/datasets/English/benchmark_lmdbs_new/svt_p_645'
gt_file = os.path.join(data_dir, 'gt.txt')
image_dir = data_dir
with open(gt_file, 'r') as f:
lines = [line.strip('\n') for line in f.readlines()]

imagePathList, labelList = [], []
for i, line in enumerate(lines):
splits = line.split(' ')
image_name = splits[0]
gt_text = splits[1]
print(image_name, gt_text)
imagePathList.append(os.path.join(image_dir, image_name))
labelList.append(gt_text)

createDataset(lmdb_output_path, imagePathList, labelList)

0 comments on commit 28f3730

Please sign in to comment.