incremental backup
This commit is contained in:
parent
d998a04d47
commit
03712e07d8
8
loss.py
8
loss.py
@ -9,5 +9,13 @@ def CTCLoss(y_true, y_pred):
|
||||
input_length = input_length * tf.ones(shape = (batch_len, 1), dtype = "int64")
|
||||
label_length = label_length * tf.ones(shape = (batch_len, 1), dtype = "int64")
|
||||
|
||||
#y_pred = tf.math.log_softmax(y_pred)
|
||||
|
||||
#y_true = tf.cast(y_true, dtype = "int64")
|
||||
#y_true_sparse = tf.keras.backend.ctc_label_dense_to_sparse(y_true, label_length)
|
||||
|
||||
#print(y_true_sparse)
|
||||
|
||||
loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
|
||||
#loss = tf.nn.ctc_loss(y_true_sparse, y_pred, input_length, label_length)
|
||||
return loss
|
||||
|
||||
6
model.py
6
model.py
@ -1,3 +1,5 @@
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.keras import layers as kl
|
||||
from tensorflow.keras import models as km
|
||||
from tensorflow.keras import losses as ks
|
||||
@ -8,6 +10,10 @@ from tensorflow import keras
|
||||
|
||||
from loss import CTCLoss
|
||||
|
||||
#g = tf.config.experimental.list_physical_devices('GPU')
|
||||
#for i in g:
|
||||
# tf.config.experimental.set_memory_growth(i, True)
|
||||
|
||||
|
||||
def model(input_dim, output_dim, rnn_layers = 3, rnn_units = 72):
|
||||
li = kl.Input((None, input_dim))
|
||||
|
||||
12
train.py
12
train.py
@ -12,9 +12,9 @@ data = pd.read_csv("/mnt/tmpfs1/LJSpeech-1.1/metadata.csv",
|
||||
header = None,
|
||||
quoting = 3,
|
||||
names = ['file_name', 'i', 'normalized_transcription'])
|
||||
s = int(len(data) // 2 * 0.90)
|
||||
s = int(len(data) * 0.90)
|
||||
train_data = data[:s]
|
||||
valid_data = data[s:len(data) // 2]
|
||||
valid_data = data[s:]
|
||||
|
||||
train_ds = to_dataset(train_data, batch_size = bs)
|
||||
valid_ds = to_dataset(valid_data, batch_size = bs)
|
||||
@ -22,13 +22,13 @@ valid_ds = to_dataset(valid_data, batch_size = bs)
|
||||
m = model(input_dim = fft_length // 2 + 1,
|
||||
output_dim = char_to_num.vocabulary_size())
|
||||
|
||||
m.load_weights('model20-latest.keras')
|
||||
ckpt1 = kc.ModelCheckpoint('model21-latest.keras',
|
||||
m.load_weights('model23-latest.keras')
|
||||
ckpt1 = kc.ModelCheckpoint('model24-latest.keras',
|
||||
monitor = 'val_loss',
|
||||
save_best_only = False,
|
||||
verbose = 1)
|
||||
|
||||
ckpt2 = kc.ModelCheckpoint('model21-best.keras',
|
||||
ckpt2 = kc.ModelCheckpoint('model24-best.keras',
|
||||
monitor = 'val_loss',
|
||||
save_best_only = True,
|
||||
verbose = 1)
|
||||
@ -36,6 +36,6 @@ ckpt2 = kc.ModelCheckpoint('model21-best.keras',
|
||||
ce1 = ce(valid_ds, m)
|
||||
|
||||
m.fit(train_ds,
|
||||
epochs = 8,
|
||||
epochs = 80,
|
||||
validation_data = valid_ds,
|
||||
callbacks = [ckpt1, ckpt2, ce1])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user