# neuro-lab7/main.py (75 lines, 1.6 KiB, Python)
# blame: 2025-12-06 13:50:19 +02:00
from tensorflow.keras import layers as kl
from tensorflow.keras import models as km
from tensorflow.keras import losses as ks
from tensorflow.keras import optimizers as ko
from tensorflow.keras import callbacks as kc
from tensorflow.keras.preprocessing.text import Tokenizer as kT
from tensorflow.keras.utils import pad_sequences as kps
import numpy as np
import pandas as pd
# blame: 2025-12-06 16:36:55 +02:00
import pickle
print("I")
# Load the preprocessed training set. The CSV has no header row, so the three
# columns are named here: 'i' (presumably a row id; not used in this script),
# 'c' (the label) and 'r' (the text that gets tokenized).
t = pd.read_csv("prepped_train.csv",
                header = None,
                names = ['i', 'c', 'r'])
print("R")

# Labels. The disabled training code fits a sigmoid / binary_crossentropy
# model on these, so they are expected to be 0/1 -- verify against whatever
# produced prepped_train.csv.
y = t['c']

# Text column. read_csv returns empty cells as NaN, and NaN.astype(str) is the
# literal string "nan", which the tokenizer would then learn as a real word.
# Map missing text to "" first; astype(str) still coerces any non-string cell.
r = t['r'].fillna('').astype(str)
# Build the word index over every text row. num_words=6000 limits which of the
# most frequent words are kept when texts are later converted to sequences, so
# it has to agree with the model's Embedding vocabulary size (the disabled
# inline model uses Embedding(6000, ...); model.py presumably matches -- verify).
# NOTE(review): the tokenizer is fit on all rows, including the 10% that the
# disabled training code holds out via validation_split=0.1, so validation
# rows also shape the vocabulary. Confirm that is acceptable.
tk = kT(num_words = 6000)
tk.fit_on_texts(r)

# Persist the fitted tokenizer, presumably so a separate script can rebuild
# the exact same word -> index mapping. Only ever unpickle this file from a
# trusted source: pickle.load can execute arbitrary code.
with open('tokenizer.pickle', mode = 'wb') as tokenizer_file:
    pickler = pickle.Pickler(tokenizer_file, protocol = pickle.HIGHEST_PROTOCOL)
    pickler.dump(tk)
print("F")
# Encode each text as a list of word indices from the fitted vocabulary.
# No oov_token was configured on the tokenizer, so words outside the kept
# vocabulary are simply dropped from the sequence.
s = tk.texts_to_sequences(r)
print("T")
# Bring every sequence to exactly 100 tokens: shorter ones are left-padded
# with 0, longer ones keep only their LAST 100 tokens (truncating='pre').
# These are pad_sequences' defaults, written out so the choice is visible.
ts = kps(s,
         maxlen = 100,
         padding = 'pre',
         truncating = 'pre')
print("P")
# --- Superseded inline model definition (disabled) ---------------------------
# This block has no runtime effect; it was originally disabled by wrapping it
# in a bare ''' string. The disabled training section instead obtains the model
# via `from model import m` (presumably model.py). Kept for reference only;
# it may have drifted from the live definition.
# NOTE(review): kl.Dense(64) has no activation, so it is linear. Dropout is the
# identity at inference, so Dense(64) composes with the final Dense(1) into a
# single affine map: extra parameters, no extra expressiveness.
# activation='relu' was probably intended -- check model.py for the same issue.
#
# m = km.Sequential([
#     kl.Input(shape = (None, ), dtype = 'int32'),
#     kl.Embedding(6000, 96),
#     kl.Dropout(0.2),
#     kl.Conv1D(128, 5, activation = 'relu'),
#     kl.LSTM(128, return_sequences = True),
#     kl.LSTM(64),
#     kl.Dense(64),
#     kl.Dropout(0.5),
#     kl.Dense(1, activation = 'sigmoid')
# ])
# m.compile(optimizer = ko.Lion(learning_rate = 0.0005),
#           loss = 'binary_crossentropy',
#           metrics = ['accuracy'])
# --- Fine-tuning (disabled) ---------------------------------------------------
# This block has no runtime effect; it was originally disabled by wrapping it
# in a bare ''' string. To re-enable, uncomment it. As written, it:
#   - imports the model from the local `model` module,
#   - starts from the weights in model1.keras,
#   - trains for 15 epochs on the padded sequences `ts` / labels `y`,
#   - saves the epoch with the best val_accuracy to model2.keras.
# NOTE(review): Keras takes validation_split rows from the END of the data,
# before shuffling. If prepped_train.csv is ordered by label, the validation
# set will be skewed; shuffle or split explicitly first.
# NOTE(review): with save_best_only=True and no initial_value_threshold,
# the checkpoint has no baseline. Epoch 1 is therefore always written, even if
# it scores worse than model1.keras. Consider seeding
# initial_value_threshold with model1's val_accuracy.
#
# from model import m
#
# ckpt = kc.ModelCheckpoint('model2.keras',
#                           monitor = 'val_accuracy',
#                           save_best_only = True,
#                           verbose = 1)
# m.load_weights("model1.keras")
# history = m.fit(ts,
#                 y,
#                 epochs = 15,
#                 batch_size = 1024,
#                 validation_split = 0.1,
#                 callbacks = [ckpt])