ІО-23 Шмуляр Олег 2025-12-06 13:50:19 +02:00
parent bddef39f9c
commit 9c3e92b4f8
3 changed files with 111 additions and 3 deletions

main.py (new file, 98 additions)

@@ -0,0 +1,98 @@
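# CNN + LSTM binary sentiment classifier trained on the Yelp Review Polarity dataset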
from tensorflow.keras import layers as kl
from tensorflow.keras import models as km
from tensorflow.keras import losses as ks
from tensorflow.keras import optimizers as ko
from tensorflow.keras import callbacks as kc
from tensorflow.keras.preprocessing.text import Tokenizer as kT
from tensorflow.keras.utils import pad_sequences as kps
import re
import numpy as np
import pandas as pd
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from tqdm import tqdm
tqdm.pandas()
print("I")
t = pd.read_csv("yelp_review_polarity_csv/train.csv",
                header = None,
                names = ['c', 'r'])
print("R")
y = t['c'] - 1
r = t['r']
#nltk.download("stopwords")
#nltk.download("punkt_tab")
#nltk.download("wordnet")
def fr(r):
    # lowercase and keep word characters only
    r = r.lower()
    r = " ".join(re.findall(r'\w+', r))
    # strip any leftover punctuation and line breaks
    for i in ['\n', '\r', ',', '.', '-', ';', ':', '\'', '"']:
        r = r.replace(i, "")
    # drop English stopwords and lemmatize the remaining tokens as verbs
    sw = set(stopwords.words("english"))
    l = WordNetLemmatizer()
    return " ".join([l.lemmatize(i.strip(), pos = 'v') for i in word_tokenize(r) if i.strip() not in sw])
r = r.progress_apply(fr)
#print(r)
print("A")
tk = kT(num_words = 6000)
tk.fit_on_texts(r)
print("F")
#print(tk.word_index)
s = tk.texts_to_sequences(r)
#print(s)
print("T")
ts = kps(s, maxlen = 100)
print("P")
m = km.Sequential([
    kl.Input(shape = (None, ), dtype = 'int32'),
    kl.Embedding(6000, 96),
    kl.Dropout(0.2),
    kl.Conv1D(128, 5, activation = 'relu'),
    kl.LSTM(128, return_sequences = True),
    kl.LSTM(64),
    kl.Dense(64),
    kl.Dropout(0.5),
    kl.Dense(1, activation = 'sigmoid')
])
m.compile(optimizer = ko.Lion(learning_rate = 0.0005),
          loss = 'binary_crossentropy',
          metrics = ['accuracy'])
#m.summary()
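# save only the checkpoint with the best validation accuracy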
ckpt = kc.ModelCheckpoint('model1.keras',
                          monitor = 'val_accuracy',
                          save_best_only = True,
                          verbose = 1)
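# train for 3 epochs with 10% of the data held out for validation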
history = m.fit(ts,
                y,
                epochs = 3,
                batch_size = 256,
                validation_split = 0.1,
                callbacks = [ckpt])

View File

@@ -19,7 +19,7 @@ t = pd.read_csv("yelp_review_polarity_csv/train.csv",
 print("R")
-y = t['c'] - 1
+y = (t['c'] - 1)
 r = t['r']
 r = r.progress_apply(fr)

View File

@@ -3,7 +3,7 @@ import nltk
 from nltk.corpus import stopwords
 from nltk.tokenize import word_tokenize
 from nltk.stem import WordNetLemmatizer
-from spellchecker import SpellChecker as sc
+#from spellchecker import SpellChecker as sc
 nltk.download("stopwords")
 nltk.download("punkt_tab")
@@ -20,4 +20,14 @@ def fr(r):
     sw = set(stopwords.words("english"))
     l = WordNetLemmatizer()
-    return " ".join([l.lemmatize(i.strip(), pos = 'v') for i in word_tokenize(r) if i.strip() not in sw])
+    #c = sc()
+    r = [i.strip() for i in word_tokenize(r) if i.strip() not in sw]
+    # spellcheck
+    #for k, i in enumerate(r):
+    #    w = c.correction(i)
+    #    if w:
+    #        r[k] = w
+    return " ".join([l.lemmatize(i, pos = 'v') for i in r])