commit d998a04d47bb4094afa27d7d69809522607e1735 Author: hasslesstech Date: Mon Dec 8 17:10:47 2025 +0200 initial commit (pre ctc) diff --git a/cc.py b/cc.py new file mode 100644 index 0000000..b4edecf --- /dev/null +++ b/cc.py @@ -0,0 +1,43 @@ +import tensorflow as tf +from tensorflow import keras +from jiwer import wer + +import numpy as np + +from preprocessing import * + +def decode_batch_predictions(pred): + input_len = np.ones(pred.shape[0]) * pred.shape[1] + results = keras.backend.ctc_decode(pred, + input_length = input_len, + greedy = True)[0][0] + output_text = [] + for result in results: + result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8") + output_text.append(result) + return output_text + +class ce(keras.callbacks.Callback): + def __init__(self, dataset, model): + super().__init__() + self.dataset = dataset + self.____model = model + + def on_epoch_end(self, epoch, logs = None): + predictions = [] + targets = [] + for batch in self.dataset: + X, y = batch + batch_predictions = self.____model.predict(X, verbose = 0) + batch_predictions = decode_batch_predictions(batch_predictions) + predictions.extend(batch_predictions) + for label in y: + label = ( + tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8") + ) + targets.append(label) + wer_score = wer(targets, predictions) + print(f"Word Error Rate: {wer_score:.4f}") + for i in np.random.randint(0, len(predictions), 10): + print(f"Target : {targets[i]}") + print(f"Prediction: {predictions[i]}") diff --git a/loss.py b/loss.py new file mode 100644 index 0000000..682b166 --- /dev/null +++ b/loss.py @@ -0,0 +1,13 @@ +import tensorflow as tf +from tensorflow import keras + +def CTCLoss(y_true, y_pred): + batch_len = tf.cast(tf.shape(y_true)[0], dtype = "int64") + input_length = tf.cast(tf.shape(y_pred)[1], dtype = "int64") + label_length = tf.cast(tf.shape(y_true)[1], dtype = "int64") + + input_length = input_length * tf.ones(shape = (batch_len, 1), dtype = "int64") + label_length = label_length * tf.ones(shape = (batch_len, 1), dtype = "int64") + + loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length) + return loss diff --git a/model.py b/model.py new file mode 100644 index 0000000..8324f4f --- /dev/null +++ b/model.py @@ -0,0 +1,58 @@ +from tensorflow.keras import layers as kl +from tensorflow.keras import models as km +from tensorflow.keras import losses as ks +from tensorflow.keras import optimizers as ko +from tensorflow.keras import callbacks as kc + +from tensorflow import keras + +from loss import CTCLoss + + +def model(input_dim, output_dim, rnn_layers = 3, rnn_units = 72): + li = kl.Input((None, input_dim)) + l1 = kl.Reshape((-1, input_dim, 1))(li) + + l2 = kl.Conv2D(32, + kernel_size = [11, 41], + strides = [2, 2], + padding = 'same', + use_bias = False)(l1) + l3 = kl.BatchNormalization()(l2) + l4 = kl.ReLU()(l3) + + l5 = kl.Conv2D(32, + kernel_size = [11, 21], + strides = [1, 2], + padding = 'same', + use_bias = False)(l4) + l6 = kl.BatchNormalization()(l5) + l7 = kl.ReLU()(l6) + + lb = kl.Reshape((-1, l7.shape[-2] * l7.shape[-1]))(l7) + + for i in range(rnn_layers): + i += 1 + + r = kl.GRU(rnn_units, + activation = 'tanh', + recurrent_activation = 'sigmoid', + use_bias = True, + return_sequences = True, + reset_after = True) + + lb = kl.Bidirectional(r, + merge_mode = 'concat')(lb) + + if i < rnn_layers: + lb = kl.Dropout(rate=0.5)(lb) + + lc1 = kl.Dense(rnn_units * 2, activation = 'relu')(lb) + lc2 = kl.Dropout(0.5)(lc1) + lo = kl.Dense(output_dim + 1, activation = 'softmax')(lc2) + + m = keras.Model(li, lo) + m.compile(optimizer = ko.Lion(0.0004), + loss = CTCLoss) + + return m diff --git a/preprocessing.py b/preprocessing.py new file mode 100644 index 0000000..7f892f7 --- /dev/null +++ b/preprocessing.py @@ -0,0 +1,45 @@ +from tensorflow import keras +import tensorflow as tf + +characters = list("abcdefghijklmnopqrstuvwxyz'?! ") + +char_to_num = keras.layers.StringLookup(vocabulary = characters, + oov_token = "") + +num_to_char = keras.layers.StringLookup(vocabulary = char_to_num.get_vocabulary(), + oov_token = "", + invert = True) + +frame_length = 256 +frame_step = 160 +fft_length = 384 + +wavs = '/mnt/tmpfs1/LJSpeech-1.1/wavs/' + +def encode_single_sample(wav, label): + file = tf.io.read_file(wavs + wav + ".wav") + audio, _ = tf.audio.decode_wav(file) + audio = tf.squeeze(audio, axis = -1) + audio = tf.cast(audio, tf.float32) + spectrogram = tf.signal.stft(audio, + frame_length = frame_length, + frame_step = frame_step, + fft_length = fft_length) + spectrogram = tf.abs(spectrogram) + spectrogram = tf.math.pow(spectrogram, 0.5) + means = tf.math.reduce_mean(spectrogram, 1, keepdims=True) + stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True) + spectrogram = (spectrogram - means) / (stddevs + 1e-10) + label = tf.strings.lower(label) + label = tf.strings.unicode_split(label, input_encoding="UTF-8") + label = char_to_num(label) + return spectrogram, label + +def to_dataset(df, batch_size = 32): + ds = tf.data.Dataset.from_tensor_slices((list(df["file_name"]), + list(df["normalized_transcription"]))) + ds = ds.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE) \ + .padded_batch(batch_size) \ + .prefetch(buffer_size=tf.data.AUTOTUNE) + + return ds diff --git a/train.py b/train.py new file mode 100644 index 0000000..295ce61 --- /dev/null +++ b/train.py @@ -0,0 +1,41 @@ +from model import * +from preprocessing import * + +import pandas as pd + +from cc import ce + +bs = 32 + +data = pd.read_csv("/mnt/tmpfs1/LJSpeech-1.1/metadata.csv", + sep = '|', + header = None, + quoting = 3, + names = ['file_name', 'i', 'normalized_transcription']) +s = int(len(data) // 2 * 0.90) +train_data = data[:s] +valid_data = data[s:len(data) // 2] + +train_ds = to_dataset(train_data, batch_size = bs) +valid_ds = to_dataset(valid_data, batch_size = bs) + +m = model(input_dim = fft_length // 2 + 1, + output_dim = char_to_num.vocabulary_size()) + +m.load_weights('model20-latest.keras') +ckpt1 = kc.ModelCheckpoint('model21-latest.keras', + monitor = 'val_loss', + save_best_only = False, + verbose = 1) + +ckpt2 = kc.ModelCheckpoint('model21-best.keras', + monitor = 'val_loss', + save_best_only = True, + verbose = 1) + +ce1 = ce(valid_ds, m) + +m.fit(train_ds, + epochs = 8, + validation_data = valid_ds, + callbacks = [ckpt1, ckpt2, ce1])