2025-12-08 17:10:47 +02:00
|
|
|
import tensorflow as tf
|
|
|
|
|
from tensorflow import keras
|
|
|
|
|
|
|
|
|
|
def CTCLoss(y_true, y_pred):
    """Compute the CTC (Connectionist Temporal Classification) loss as a Keras loss.

    Thin wrapper around ``keras.backend.ctc_batch_cost`` that builds the
    per-sample length tensors it needs from the runtime shapes of the inputs.

    Args:
        y_true: Dense label tensor. Axis 0 is used as the batch axis. The size
            of axis 1 is used as the label length of *every* sample.
        y_pred: Prediction tensor. The size of axis 1 is used as the input
            (time-step) length of *every* sample. Presumably per-frame class
            probabilities shaped (batch, time, num_classes); TODO confirm
            against the model output.

    Returns:
        The tensor returned by ``keras.backend.ctc_batch_cost``.
    """
    # NOTE(review): every sample is given the full padded width of y_true as
    # its label length. If labels are padded to a common width, the padding
    # symbols are counted as part of each label. Confirm this is intended.
    batch_size = tf.cast(tf.shape(y_true)[0], dtype="int64")

    def broadcast_to_batch(length):
        # Replicate a scalar length into a (batch_size, 1) int64 column, one
        # entry per sample, as passed to ctc_batch_cost below.
        return length * tf.ones(shape=(batch_size, 1), dtype="int64")

    time_steps = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    label_width = tf.cast(tf.shape(y_true)[1], dtype="int64")

    return keras.backend.ctc_batch_cost(
        y_true,
        y_pred,
        broadcast_to_batch(time_steps),
        broadcast_to_batch(label_width),
    )
|