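# neuro-lab8/train.py
# Python, 42 lines, 1.1 KiB (as exported from the repository's blame view).
# Blame history: first added 2025-12-08 17:10:47 +02:00;
# later edits 2025-12-09 08:09:49 +02:00 and 2025-12-11 09:32:21 +02:00.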
from model import *
from preprocessing import *
import pandas as pd
from cc import ce
import csv  # csv.QUOTE_NONE for reading the metadata file below

# --- Run configuration -------------------------------------------------------
# Training run index: weights are resumed from run RUN - 1 and checkpoints are
# written as run RUN. Bump this single number for the next run instead of
# editing three file names by hand.
RUN = 41
EPOCHS = 40
TRAIN_FRACTION = 0.90
METADATA_PATH = "/mnt/tmpfs1/LJSpeech-1.1/metadata.csv"

bs = 32  # batch size for both the training and the validation dataset

# '|'-separated file with no header row. Quoting is disabled so quote
# characters in the text columns are read as literal text
# (csv.QUOTE_NONE == 3, the value that used to be hard-coded here).
data = pd.read_csv(METADATA_PATH,
                   sep='|',
                   header=None,
                   quoting=csv.QUOTE_NONE,
                   names=['file_name', 'i', 'normalized_transcription'])

# Deterministic, unshuffled split: the first TRAIN_FRACTION of the rows are
# used for training, the remainder for validation. Keep it stable across
# runs: the weights below are resumed from a previous run, and a different
# split would presumably move rows that run trained on into validation.
s = int(len(data) * TRAIN_FRACTION)
train_data = data.iloc[:s]
valid_data = data.iloc[s:]

# to_dataset, model, fft_length and char_to_num are not defined in this file;
# they arrive through the star imports of model / preprocessing.
train_ds = to_dataset(train_data, batch_size=bs)
valid_ds = to_dataset(valid_data, batch_size=bs)

# input_dim = fft_length // 2 + 1, presumably the number of one-sided (rFFT)
# frequency bins per frame — confirm against preprocessing.
m = model(input_dim=fft_length // 2 + 1,
          output_dim=char_to_num.vocabulary_size())
m.load_weights(f'model{RUN - 1}-latest.keras')

# NOTE(review): kc is not defined in this file — assumed to be keras.callbacks
# re-exported by one of the star imports; TODO confirm.
# ckpt1 overwrites the "latest" snapshot after every epoch (monitor only takes
# effect with save_best_only=True; kept for symmetry with ckpt2).
# ckpt2 keeps only the snapshot with the best (lowest) val_loss seen so far.
ckpt1 = kc.ModelCheckpoint(f'model{RUN}-latest.keras',
                           monitor='val_loss',
                           save_best_only=False,
                           verbose=1)
ckpt2 = kc.ModelCheckpoint(f'model{RUN}-best.keras',
                           monitor='val_loss',
                           save_best_only=True,
                           verbose=1)
# Custom callback from cc, given the validation data and the model; presumably
# reports a validation metric each epoch — verify in cc.py.
ce1 = ce(valid_ds, m)

m.fit(train_ds,
      epochs=EPOCHS,
      validation_data=valid_ds,
      callbacks=[ckpt1, ckpt2, ce1])