Skip to content

Commit

Permalink
rename tot_batch_length to mini_batch_length for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Jun 7, 2018
1 parent 3e02dfd commit 7497726
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
26 changes: 19 additions & 7 deletions code/ch16/ch16.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1631,19 +1631,19 @@
],
"source": [
"def reshape_data(sequence, batch_size, num_steps):\n",
" tot_batch_length = batch_size * num_steps\n",
" num_batches = int(len(sequence) / tot_batch_length)\n",
" if num_batches*tot_batch_length + 1 > len(sequence):\n",
" mini_batch_length = batch_size * num_steps\n",
" num_batches = int(len(sequence) / mini_batch_length)\n",
" if num_batches*minitot_batch_length + 1 > len(sequence):\n",
" num_batches = num_batches - 1\n",
" ## Truncate the sequence at the end to get rid of \n",
" ## remaining charcaters that do not make a full batch\n",
" x = sequence[0 : num_batches*tot_batch_length]\n",
" y = sequence[1 : num_batches*tot_batch_length + 1]\n",
" x = sequence[0 : num_batches*mini_batch_length]\n",
" y = sequence[1 : num_batches*mini_batch_length + 1]\n",
" ## Split x & y into a list batches of sequences: \n",
" x_batch_splits = np.split(x, batch_size)\n",
" y_batch_splits = np.split(y, batch_size)\n",
" ## Stack the batches together\n",
" ## batch_size x tot_batch_length\n",
" ## batch_size x mini_batch_length\n",
" x = np.stack(x_batch_splits)\n",
" y = np.stack(y_batch_splits)\n",
" \n",
Expand Down Expand Up @@ -2866,7 +2866,19 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
"version": "3.6.4"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
Expand Down
12 changes: 6 additions & 6 deletions code/ch16/ch16.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,19 +495,19 @@ def predict(self, X_data, return_proba=False):


def reshape_data(sequence, batch_size, num_steps):
tot_batch_length = batch_size * num_steps
num_batches = int(len(sequence) / tot_batch_length)
if num_batches*tot_batch_length + 1 > len(sequence):
mini_batch_length = batch_size * num_steps
num_batches = int(len(sequence) / mini_batch_length)
if num_batches*mini_batch_length + 1 > len(sequence):
num_batches = num_batches - 1
## Truncate the sequence at the end to get rid of
## remaining charcaters that do not make a full batch
x = sequence[0 : num_batches*tot_batch_length]
y = sequence[1 : num_batches*tot_batch_length + 1]
x = sequence[0 : num_batches*mini_batch_length]
y = sequence[1 : num_batches*mini_batch_length + 1]
## Split x & y into a list batches of sequences:
x_batch_splits = np.split(x, batch_size)
y_batch_splits = np.split(y, batch_size)
## Stack the batches together
## batch_size x tot_batch_length
## batch_size x mini_batch_length
x = np.stack(x_batch_splits)
y = np.stack(y_batch_splits)

Expand Down

0 comments on commit 7497726

Please sign in to comment.