26 lines
637 B
Python
26 lines
637 B
Python
import tensorflow as tf
|
|
|
|
# Build the training and validation datasets in one call: subset='both'
# makes image_dataset_from_directory return a (train, validation) pair.
# Labels are inferred from the directory structure and one-hot encoded
# (label_mode='categorical'); images are loaded as RGB, resized to 300x300
# and grouped into batches of 16.
ds_train, ds_valid = tf.keras.preprocessing.image_dataset_from_directory(
    '/mnt/tmpfs1/ds-mini-1',
    labels='inferred',
    label_mode='categorical',
    color_mode='rgb',
    batch_size=16,
    image_size=(300, 300),
    # Shuffle before splitting. Without shuffling the files are indexed in
    # sorted (class-by-class) order and the validation split is taken from
    # the tail of that list, so the 5% validation subset would contain only
    # the last class(es) and training batches would be grouped by class.
    shuffle=True,
    # A fixed seed is required by Keras when shuffling together with
    # validation_split; it keeps the train/validation partition reproducible.
    seed=1337,
    validation_split=0.05,
    subset='both',
    verbose=True,
)
|
|
|
|
from m import *
|
|
|
|
# Checkpoint callback: writes the model to "model3.model.keras", tracking
# 'val_accuracy' and keeping only the best-scoring version
# (save_best_only=True).
# NOTE(review): `kc` comes from the star import of `m` — presumably an alias
# for keras.callbacks; confirm.
ckpt = kc.ModelCheckpoint(
    "model3.model.keras",
    monitor='val_accuracy',
    save_best_only=True,
)
|
|
|
|
# Train for 9 epochs on ds_train, evaluating on ds_valid and running the
# checkpoint callback; the value returned by fit() is kept in `h`.
# NOTE(review): `mod` comes from the star import of `m` — presumably a
# compiled Keras model that reports an 'accuracy' metric; confirm.
h = mod.fit(
    ds_train,
    epochs=9,
    validation_data=ds_valid,
    callbacks=[ckpt],
)
|