44 lines
1.5 KiB
Python
Raw Permalink Normal View History

2025-12-08 17:10:47 +02:00
import tensorflow as tf
from tensorflow import keras
from jiwer import wer
import numpy as np
from preprocessing import *
def decode_batch_predictions(pred):
    """Greedy-CTC-decode a batch of model outputs into text strings.

    Args:
        pred: Batch of model outputs. ``pred.shape[0]`` is used as the number
            of samples and ``pred.shape[1]`` as the input length passed to the
            CTC decoder for every sample.

    Returns:
        A list with one decoded ``str`` per sample in the batch.
    """
    # Every sample is decoded over the full ``pred.shape[1]`` length.
    sample_count = pred.shape[0]
    steps_per_sample = pred.shape[1]
    input_len = np.ones(sample_count) * steps_per_sample

    decode_output = keras.backend.ctc_decode(
        pred, input_length=input_len, greedy=True
    )
    # First element holds the decoded sequences; keep the first (only the
    # greedy) path.
    best_paths = decode_output[0][0]

    # NOTE(review): ``num_to_char`` is not defined in this file; presumably it
    # comes from the ``preprocessing`` star import -- confirm.
    return [
        tf.strings.reduce_join(num_to_char(path)).numpy().decode("utf-8")
        for path in best_paths
    ]
class ce(keras.callbacks.Callback):
    """Keras callback that prints the word error rate after every epoch.

    At the end of each epoch the whole ``dataset`` is run through the model,
    the outputs are decoded with ``decode_batch_predictions``, and the word
    error rate between decoded predictions and decoded labels is printed
    together with a few randomly chosen target/prediction pairs.
    """

    def __init__(self, dataset, model):
        """Store the evaluation dataset and the model used for predictions.

        Args:
            dataset: Iterable yielding ``(X, y)`` batches; it is iterated in
                full at the end of every epoch.
            model: Object whose ``predict(X, verbose=0)`` output is passed to
                ``decode_batch_predictions``.
        """
        super().__init__()
        self.dataset = dataset
        # NOTE(review): the unusual attribute name presumably avoids clashing
        # with the ``model`` attribute managed by the Keras base class --
        # confirm before renaming.
        self.____model = model

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on ``self.dataset`` and print WER plus sample outputs.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dictionary supplied by Keras (unused).
        """
        predictions = []
        targets = []
        for X, y in self.dataset:
            batch_predictions = self.____model.predict(X, verbose=0)
            predictions.extend(decode_batch_predictions(batch_predictions))
            # Turn each numeric label sequence back into its text form.
            targets.extend(
                tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
                for label in y
            )

        # With no samples, ``np.random.randint(0, 0, ...)`` below would raise
        # ValueError and abort training; skip the report instead.
        if not predictions:
            print("Word Error Rate: skipped (dataset yielded no samples)")
            return

        wer_score = wer(targets, predictions)
        print(f"Word Error Rate: {wer_score:.4f}")
        # Show 10 randomly drawn examples (indices may repeat).
        for i in np.random.randint(0, len(predictions), 10):
            print(f"Target : {targets[i]}")
            print(f"Prediction: {predictions[i]}")