incremental backup
This commit is contained in:
parent
03712e07d8
commit
e0ea8c5387
4
model.py
4
model.py
@ -15,7 +15,7 @@ from loss import CTCLoss
|
|||||||
# tf.config.experimental.set_memory_growth(i, True)
|
# tf.config.experimental.set_memory_growth(i, True)
|
||||||
|
|
||||||
|
|
||||||
def model(input_dim, output_dim, rnn_layers = 3, rnn_units = 72):
|
def model(input_dim, output_dim, rnn_layers = 5, rnn_units = 128):
|
||||||
li = kl.Input((None, input_dim))
|
li = kl.Input((None, input_dim))
|
||||||
l1 = kl.Reshape((-1, input_dim, 1))(li)
|
l1 = kl.Reshape((-1, input_dim, 1))(li)
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ def model(input_dim, output_dim, rnn_layers = 3, rnn_units = 72):
|
|||||||
lo = kl.Dense(output_dim + 1, activation = 'softmax')(lc2)
|
lo = kl.Dense(output_dim + 1, activation = 'softmax')(lc2)
|
||||||
|
|
||||||
m = keras.Model(li, lo)
|
m = keras.Model(li, lo)
|
||||||
m.compile(optimizer = ko.Lion(0.0004),
|
m.compile(optimizer = ko.Adam(0.0001),
|
||||||
loss = CTCLoss)
|
loss = CTCLoss)
|
||||||
|
|
||||||
return m
|
return m
|
||||||
|
|||||||
8
train.py
8
train.py
@ -22,13 +22,13 @@ valid_ds = to_dataset(valid_data, batch_size = bs)
|
|||||||
m = model(input_dim = fft_length // 2 + 1,
|
m = model(input_dim = fft_length // 2 + 1,
|
||||||
output_dim = char_to_num.vocabulary_size())
|
output_dim = char_to_num.vocabulary_size())
|
||||||
|
|
||||||
m.load_weights('model23-latest.keras')
|
m.load_weights('model40-latest.keras')
|
||||||
ckpt1 = kc.ModelCheckpoint('model24-latest.keras',
|
ckpt1 = kc.ModelCheckpoint('model41-latest.keras',
|
||||||
monitor = 'val_loss',
|
monitor = 'val_loss',
|
||||||
save_best_only = False,
|
save_best_only = False,
|
||||||
verbose = 1)
|
verbose = 1)
|
||||||
|
|
||||||
ckpt2 = kc.ModelCheckpoint('model24-best.keras',
|
ckpt2 = kc.ModelCheckpoint('model41-best.keras',
|
||||||
monitor = 'val_loss',
|
monitor = 'val_loss',
|
||||||
save_best_only = True,
|
save_best_only = True,
|
||||||
verbose = 1)
|
verbose = 1)
|
||||||
@ -36,6 +36,6 @@ ckpt2 = kc.ModelCheckpoint('model24-best.keras',
|
|||||||
ce1 = ce(valid_ds, m)
|
ce1 = ce(valid_ds, m)
|
||||||
|
|
||||||
m.fit(train_ds,
|
m.fit(train_ds,
|
||||||
epochs = 80,
|
epochs = 40,
|
||||||
validation_data = valid_ds,
|
validation_data = valid_ds,
|
||||||
callbacks = [ckpt1, ckpt2, ce1])
|
callbacks = [ckpt1, ckpt2, ce1])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user