How do you use a trained model? TensorFlow, Keras version
How to use your trained model - Deep Learning basics with Python, TensorFlow and Keras
Save Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import TensorBoard
import numpy as np
import pickle
import time

# Load the preprocessed features and labels that were saved earlier with pickle
with open("X.pickle", "rb") as pickle_in:
    X = pickle.load(pickle_in)
with open("y.pickle", "rb") as pickle_in:
    y = pickle.load(pickle_in)

X = X / 255.0     # scale pixel values from 0-255 to 0-1
y = np.array(y)   # newer TF versions reject a plain Python list of labels in fit()

# Hyperparameter grid: one model is built and trained per combination
dense_layers = [0]
layer_sizes = [64]
conv_layers = [3]

for dense_layer in dense_layers:
    for layer_size in layer_sizes:
        for conv_layer in conv_layers:
            NAME = "{}-conv-{}-nodes-{}-dense-{}".format(conv_layer, layer_size, dense_layer, int(time.time()))
            print(NAME)

            model = Sequential()
            # First conv block: input_shape is (height, width, channels) taken from X
            model.add(Conv2D(layer_size, (3, 3), input_shape=X.shape[1:]))
            model.add(Activation('relu'))
            model.add(MaxPooling2D(pool_size=(2, 2)))

            # Remaining conv blocks
            for l in range(conv_layer - 1):
                model.add(Conv2D(layer_size, (3, 3)))
                model.add(Activation('relu'))
                model.add(MaxPooling2D(pool_size=(2, 2)))

            model.add(Flatten())
            for _ in range(dense_layer):
                model.add(Dense(layer_size))
                model.add(Activation('relu'))

            # A single sigmoid unit for binary classification (label 0 = Dog, label 1 = Cat)
            model.add(Dense(1))
            model.add(Activation('sigmoid'))

            tensorboard = TensorBoard(log_dir="logs/{}".format(NAME))

            model.compile(loss='binary_crossentropy',
                          optimizer='adam',
                          metrics=['accuracy'],
                          )

            model.fit(X, y,
                      batch_size=32,
                      epochs=10,
                      validation_split=0.3,
                      callbacks=[tensorboard])

# SAVE MODEL: saves the last model trained in the loop (the grid above has only one combination).
# On Keras 3 (TF 2.16+), model.save() requires a .keras extension, e.g. '64x3-CNN.keras'.
model.save('64x3-CNN.model')
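To see what the loop builds for `conv_layer = 3` and `layer_size = 64`, you can trace the spatial size by hand. Each `Conv2D` with a 3x3 kernel uses Keras's default `padding='valid'` and stride 1, so it shrinks each side by 2. Each `MaxPooling2D(pool_size=(2, 2))` then halves the side, rounding down. The sketch below does that arithmetic in plain Python. The 70x70 input is an assumption taken from `IMG_SIZE` in the loading code. `conv_pool_output_side` and `flatten_size` are hypothetical helpers, not Keras functions.

```python
def conv_pool_output_side(side: int, kernel: int = 3, pool: int = 2) -> int:
    """Side length after one Conv2D (padding='valid', stride 1) plus MaxPooling2D (hypothetical helper)."""
    side = side - kernel + 1   # a valid convolution shrinks each side by kernel - 1
    return side // pool        # pooling with stride = pool size rounds down


def flatten_size(img_size: int, conv_layers: int, filters: int) -> int:
    """Number of features Flatten() produces for a square input (hypothetical helper)."""
    side = img_size
    for _ in range(conv_layers):
        side = conv_pool_output_side(side)
    return side * side * filters


if __name__ == "__main__":
    # Assumed input: 70x70 grayscale images, 3 conv blocks with 64 filters each
    print(flatten_size(70, 3, 64))  # 70 -> 34 -> 16 -> 7, so 7 * 7 * 64 = 3136
```

With `dense_layer = 0`, those 3136 features feed straight into the final `Dense(1)`. That layer therefore has 3136 + 1 = 3137 parameters. Calling `model.summary()` on the real model should report the same shapes if the input really is 70x70.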
Load Model
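import cv2
import tensorflow as tf

CATEGORIES = ["Dog", "Cat"]  # list index = label value used when building y.pickle

def prepare(filepath):
    IMG_SIZE = 70  # must match the image size used to build X.pickle
    img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)  # read as a single-channel image
    if img_array is None:  # cv2.imread returns None instead of raising on a bad path
        raise FileNotFoundError(filepath)
    new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
    # Reshape to (batch, height, width, channels) and scale to 0-1 exactly as in training (X/255.0)
    return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1) / 255.0

# LOAD MODEL: use the same path that was passed to model.save() (a .keras file on Keras 3)
model = tf.keras.models.load_model("64x3-CNN.model")

prediction = model.predict(prepare('dog.jpg'))
print(prediction)  # (1, 1) array holding the sigmoid output, i.e. the probability of class 1 ("Cat")
print(CATEGORIES[int(prediction[0][0] > 0.5)])  # Dog; threshold at 0.5 instead of truncating with int()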
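Two steps decide whether a prediction can be trusted, and both can be checked without TensorFlow. First, the input must be scaled and shaped exactly like the training data. Second, the sigmoid output must be turned into a class by thresholding, not by truncating. The sketch below uses NumPy only, with a random array standing in for the grayscale image that `cv2.imread` plus `cv2.resize` would return. `to_model_input` and `to_label` are hypothetical helper names. The 0.5 threshold is the usual default for a single sigmoid unit, not something the original code specifies.

```python
import numpy as np

CATEGORIES = ["Dog", "Cat"]  # label 0 = Dog, label 1 = Cat


def to_model_input(gray_img: np.ndarray) -> np.ndarray:
    """Scale a uint8 grayscale image to 0-1 and add batch and channel axes (hypothetical helper)."""
    h, w = gray_img.shape
    return (gray_img / 255.0).reshape(-1, h, w, 1)


def to_label(prob: float, threshold: float = 0.5) -> str:
    """Map the sigmoid probability of class 1 to a category name (hypothetical helper)."""
    return CATEGORIES[int(prob > threshold)]


if __name__ == "__main__":
    # Stand-in for a resized 70x70 grayscale image
    img = np.random.randint(0, 256, size=(70, 70), dtype=np.uint8)
    batch = to_model_input(img)
    print(batch.shape, batch.min() >= 0.0, batch.max() <= 1.0)  # (1, 70, 70, 1) True True

    # int() truncates toward zero, so a confident "Cat" output of 0.9 would be reported as Dog
    print(CATEGORIES[int(0.9)])  # Dog
    print(to_label(0.9))         # Cat
    print(to_label(0.1))         # Dog
```

Any preprocessing applied to `X` during training (here `X/255.0`) has to be repeated in `prepare()` before calling `model.predict`. Otherwise the model sees pixel values up to 255, which it never saw during training, and its sigmoid output tends to saturate at 0 or 1.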
Source: How to use your trained model - Deep Learning basics with Python, TensorFlow and Keras