incremental backup
This commit is contained in:
parent
d998a04d47
commit
03712e07d8
8
loss.py
8
loss.py
@ -9,5 +9,13 @@ def CTCLoss(y_true, y_pred):
|
|||||||
input_length = input_length * tf.ones(shape = (batch_len, 1), dtype = "int64")
|
input_length = input_length * tf.ones(shape = (batch_len, 1), dtype = "int64")
|
||||||
label_length = label_length * tf.ones(shape = (batch_len, 1), dtype = "int64")
|
label_length = label_length * tf.ones(shape = (batch_len, 1), dtype = "int64")
|
||||||
|
|
||||||
|
#y_pred = tf.math.log_softmax(y_pred)
|
||||||
|
|
||||||
|
#y_true = tf.cast(y_true, dtype = "int64")
|
||||||
|
#y_true_sparse = tf.keras.backend.ctc_label_dense_to_sparse(y_true, label_length)
|
||||||
|
|
||||||
|
#print(y_true_sparse)
|
||||||
|
|
||||||
loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
|
loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
|
||||||
|
#loss = tf.nn.ctc_loss(y_true_sparse, y_pred, input_length, label_length)
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
6
model.py
6
model.py
@ -1,3 +1,5 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
from tensorflow.keras import layers as kl
|
from tensorflow.keras import layers as kl
|
||||||
from tensorflow.keras import models as km
|
from tensorflow.keras import models as km
|
||||||
from tensorflow.keras import losses as ks
|
from tensorflow.keras import losses as ks
|
||||||
@ -8,6 +10,10 @@ from tensorflow import keras
|
|||||||
|
|
||||||
from loss import CTCLoss
|
from loss import CTCLoss
|
||||||
|
|
||||||
|
#g = tf.config.experimental.list_physical_devices('GPU')
|
||||||
|
#for i in g:
|
||||||
|
# tf.config.experimental.set_memory_growth(i, True)
|
||||||
|
|
||||||
|
|
||||||
def model(input_dim, output_dim, rnn_layers = 3, rnn_units = 72):
|
def model(input_dim, output_dim, rnn_layers = 3, rnn_units = 72):
|
||||||
li = kl.Input((None, input_dim))
|
li = kl.Input((None, input_dim))
|
||||||
|
|||||||
@ -38,7 +38,7 @@ def encode_single_sample(wav, label):
|
|||||||
def to_dataset(df, batch_size = 32):
|
def to_dataset(df, batch_size = 32):
|
||||||
ds = tf.data.Dataset.from_tensor_slices((list(df["file_name"]),
|
ds = tf.data.Dataset.from_tensor_slices((list(df["file_name"]),
|
||||||
list(df["normalized_transcription"])))
|
list(df["normalized_transcription"])))
|
||||||
ds = ds.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE) \
|
ds = ds.map(encode_single_sample, num_parallel_calls = tf.data.AUTOTUNE) \
|
||||||
.padded_batch(batch_size) \
|
.padded_batch(batch_size) \
|
||||||
.prefetch(buffer_size=tf.data.AUTOTUNE)
|
.prefetch(buffer_size=tf.data.AUTOTUNE)
|
||||||
|
|
||||||
|
|||||||
12
train.py
12
train.py
@ -12,9 +12,9 @@ data = pd.read_csv("/mnt/tmpfs1/LJSpeech-1.1/metadata.csv",
|
|||||||
header = None,
|
header = None,
|
||||||
quoting = 3,
|
quoting = 3,
|
||||||
names = ['file_name', 'i', 'normalized_transcription'])
|
names = ['file_name', 'i', 'normalized_transcription'])
|
||||||
s = int(len(data) // 2 * 0.90)
|
s = int(len(data) * 0.90)
|
||||||
train_data = data[:s]
|
train_data = data[:s]
|
||||||
valid_data = data[s:len(data) // 2]
|
valid_data = data[s:]
|
||||||
|
|
||||||
train_ds = to_dataset(train_data, batch_size = bs)
|
train_ds = to_dataset(train_data, batch_size = bs)
|
||||||
valid_ds = to_dataset(valid_data, batch_size = bs)
|
valid_ds = to_dataset(valid_data, batch_size = bs)
|
||||||
@ -22,13 +22,13 @@ valid_ds = to_dataset(valid_data, batch_size = bs)
|
|||||||
m = model(input_dim = fft_length // 2 + 1,
|
m = model(input_dim = fft_length // 2 + 1,
|
||||||
output_dim = char_to_num.vocabulary_size())
|
output_dim = char_to_num.vocabulary_size())
|
||||||
|
|
||||||
m.load_weights('model20-latest.keras')
|
m.load_weights('model23-latest.keras')
|
||||||
ckpt1 = kc.ModelCheckpoint('model21-latest.keras',
|
ckpt1 = kc.ModelCheckpoint('model24-latest.keras',
|
||||||
monitor = 'val_loss',
|
monitor = 'val_loss',
|
||||||
save_best_only = False,
|
save_best_only = False,
|
||||||
verbose = 1)
|
verbose = 1)
|
||||||
|
|
||||||
ckpt2 = kc.ModelCheckpoint('model21-best.keras',
|
ckpt2 = kc.ModelCheckpoint('model24-best.keras',
|
||||||
monitor = 'val_loss',
|
monitor = 'val_loss',
|
||||||
save_best_only = True,
|
save_best_only = True,
|
||||||
verbose = 1)
|
verbose = 1)
|
||||||
@ -36,6 +36,6 @@ ckpt2 = kc.ModelCheckpoint('model21-best.keras',
|
|||||||
ce1 = ce(valid_ds, m)
|
ce1 = ce(valid_ds, m)
|
||||||
|
|
||||||
m.fit(train_ds,
|
m.fit(train_ds,
|
||||||
epochs = 8,
|
epochs = 80,
|
||||||
validation_data = valid_ds,
|
validation_data = valid_ds,
|
||||||
callbacks = [ckpt1, ckpt2, ce1])
|
callbacks = [ckpt1, ckpt2, ce1])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user