update training method
This commit is contained in:
parent
347f6a2483
commit
2b6d093c0d
25
train_l_v2.py
Normal file
25
train_l_v2.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
ds_train, ds_valid = tf.keras.preprocessing.image_dataset_from_directory(
|
||||||
|
'/mnt/tmpfs1/ds-mini-1',
|
||||||
|
labels = 'inferred',
|
||||||
|
label_mode = 'categorical',
|
||||||
|
color_mode = 'rgb',
|
||||||
|
batch_size = 16,
|
||||||
|
image_size = (300, 300),
|
||||||
|
shuffle = False,
|
||||||
|
validation_split = 0.05,
|
||||||
|
subset = 'both',
|
||||||
|
verbose = True
|
||||||
|
)
|
||||||
|
|
||||||
|
from m import *
|
||||||
|
|
||||||
|
ckpt = kc.ModelCheckpoint("model3.model.keras",
|
||||||
|
monitor = 'val_accuracy',
|
||||||
|
save_best_only = True)
|
||||||
|
|
||||||
|
h = mod.fit(ds_train,
|
||||||
|
epochs = 9,
|
||||||
|
validation_data = ds_valid,
|
||||||
|
callbacks = [ckpt])
|
||||||
Loading…
x
Reference in New Issue
Block a user