103 lines
2.2 KiB
Python
103 lines
2.2 KiB
Python
import tensorflow as tf
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
from PIL import Image, ImageTk
|
|
import tkinter as tk
|
|
import math
|
|
import time
|
|
|
|
# Mouse-button flags toggled by the event handlers inside classify_live():
# 1 while the corresponding button is held down, 0 otherwise.
__put_active = __take_active = 0

# Most recently rendered frame (PIL image and its Tk PhotoImage). Kept at
# module level so the PhotoImage outlives the function that created it.
__img = __cw = None
|
|
|
|
def classify_live(m, label):
    """Open a drawing window and show the model's live prediction for it.

    Setup:

    * Compile ``m`` with Adam / categorical cross-entropy / accuracy.
    * Load the weights from ``save-{label}.weights.h5``.
    * Open a Tk window with a 28x28 canvas, upscaled 18x (504x504,
      nearest-neighbour), and a label showing ``argmax`` of the prediction.

    Interaction:

    * Left-drag paints pixels.
    * Button 2 (and button 3, the right button on X11/Windows) drags erase.
    * The "Clear" button blanks the canvas.

    The image is redrawn every 50 ms and the prediction refreshed every
    300 ms. Blocks in ``mainloop()`` until the window is closed.

    Args:
        m: Keras-style model with ``compile``, ``load_weights`` and
            ``predict``. Presumably it accepts a ``(batch, 784)`` input,
            since that is what is fed below; TODO confirm against the model
            definition.
        label: suffix used to build the weights filename.
    """
    grid = 28            # canvas side length, in model pixels
    cell = 18            # on-screen pixels per model pixel
    view = grid * cell   # 504: on-screen side length of the canvas

    m.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
    m.load_weights(f"save-{label}.weights.h5")

    r = tk.Tk()
    r.title("Draw!")

    # 1.0 = painted, 0.0 = blank; mutated in place by the handlers below.
    canvas = np.zeros([grid, grid])

    def render():
        """Rebuild the module-level frame globals from ``canvas``; return the PhotoImage."""
        global __img, __cw
        __img = Image.fromarray(np.uint8(canvas * 255), "L")
        # Tk holds no Python reference to a PhotoImage, so it is stored in a
        # global; a local would be garbage-collected and the label would go blank.
        __cw = ImageTk.PhotoImage(
            __img.resize(size=(view, view), resample=Image.NEAREST))
        return __cw

    l = tk.Label(r, image=render())
    lt = tk.Label(r, text="", font=("Liberation Sans", 48))
    lt.pack()
    l.pack()

    def clear_array():
        canvas[:] = 0

    def change_pix(ev):
        """Paint or erase the canvas cell under the pointer, per the button flags."""
        col = ev.x // cell
        row = ev.y // cell
        # While a button is held, Tk keeps sending <Motion> to the label even
        # after the pointer leaves it. Coordinates can then be negative (which
        # would silently wrap to the opposite edge) or past the last cell.
        if not (0 <= row < grid and 0 <= col < grid):
            return
        if __put_active:
            canvas[row, col] = 1
        elif __take_active:
            # Same (row, col) order as painting; previously this was transposed.
            canvas[row, col] = 0

    def mouse_down(ev):
        global __put_active, __take_active
        if ev.num == 1:
            __put_active = 1
        elif ev.num in (2, 3):
            __take_active = 1
        # Edit the pixel under the cursor immediately, so a plain click
        # (no movement) also paints/erases.
        change_pix(ev)

    def mouse_up(ev):
        global __put_active, __take_active
        if ev.num == 1:
            __put_active = 0
        elif ev.num in (2, 3):
            __take_active = 0

    def update_img():
        l.configure(image=render())
        r.after(50, update_img)

    def update_pred():
        # Flatten the canvas into a single 784-wide row (a batch of one).
        pred = m.predict(canvas.reshape(-1, grid * grid), verbose=0)
        lt.configure(text=int(np.argmax(pred)))
        r.after(300, update_pred)

    l.bind("<Motion>", change_pix)
    l.bind("<ButtonPress>", mouse_down)
    l.bind("<ButtonRelease>", mouse_up)

    tk.Button(r, text="Clear", command=clear_array).pack()

    r.after(300, update_pred)
    r.after(50, update_img)
    r.mainloop()
|