Skip to content

Commit

Permalink
updating character based model
Browse files Browse the repository at this point in the history
  • Loading branch information
nasavish committed Oct 24, 2018
1 parent 28b4bca commit 55aa248
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
5 changes: 5 additions & 0 deletions character_based/info_{now.strftime('%Y-%m-%d_%H-%M')}.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{'seq_length': 100}model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
32 changes: 28 additions & 4 deletions seuss.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,50 @@


# In[11]:

import datetime
now = datetime.datetime.now()

# define the LSTM model
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(LSTM(400, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(200))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')

with open("character_based/info_{now.strftime('%Y-%m-%d_%H-%M')}.txt", 'w+') as f:
pstr = "{'seq_length': " + str(seq_length) + '}'
modelstr =
"""
model = Sequential()
model.add(LSTM(400, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(200))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
"""
f.write(pstr)
f.write(modelstr)



print("model compiled")
# In[12]:


# define the checkpoint
filepath="weights-improvement-{epoch:02d}-{loss:.4f}.hdf5"

filepath=f"character_based/wi-{{epoch:02d}}-{{loss:.4f}}_{now.strftime('%Y-%m-%d_%H-%M')}.h5"
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]


# In[13]:

print("nuke launching")
model.fit(X, y, epochs=1000, batch_size=128, callbacks=callbacks_list, verbose=1)
history = model.fit(X, y, epochs=500, batch_size=128, callbacks=callbacks_list, verbose=1)
loss_history = history.history
with open("character_based/loss_history_{now.strftime('%Y-%m-%d_%H-%M')}.txt", 'w+') as f:
f.write(str(loss_history))

0 comments on commit 55aa248

Please sign in to comment.