TensorFlow Hub :: 컴퓨터 비전 - mindscale
Skip to content

TensorFlow Hub

텐서플로 허브

사전 학습된 모델 저장소 https://tfhub.dev

설치

pip install tensorflow_hub

불러오기

import tensorflow_hub as hub

MobileNet 불러오기

레이어 가져오기

model_url = (
    "https://tfhub.dev/google/tf2-preview/"                
    "mobilenet_v2/classification/4")

모델 만들기

import tensorflow as tf

model = tf.keras.Sequential([   
    tf.keras.layers.Rescaling(1/255),    
    hub.KerasLayer(model_url)
])

레이블 다운로드

다운로드

label_file = tf.keras.utils.get_file(origin=    
  'https://storage.googleapis.com/download.tensorflow.org/' +        
  'data/ImageNetLabels.txt')

불러오기

labels = open(label_file).read().splitlines()

예측

size = 224, 224
validation_dataset = tf.keras.utils.image_dataset_from_directory(
    'cats_and_dogs_filtered/validation',
    shuffle=False,
    image_size=size)
Found 1000 files belonging to 2 classes.
predicted = model.predict(validation_dataset)
32/32 [==============================] - 8s 226ms/step

확률이 가장 높은 클래스 출력

for i in predicted.argmax(axis=1):
      print(labels[i])
Egyptian cat
carton
paper towel
tabby
Siamese cat

...

출력층을 교체하여 전이학습

  • 앞의 모형은 기존에 학습된 1001가지 카테고리로 분류
  • 수행하려는 과제에 맞춰 새로운 출력층으로 교체
  • 모형이 classification → feature_vector로 달라지는데 유의
from tensorflow.keras.layers import *

model_url = ('https://tfhub.dev/google/imagenet/'             
             'mobilenet_v2_100_224/feature_vector/5')
model = tf.keras.Sequential([
    Rescaling(1/255),
    hub.KerasLayer(model_url, trainable=False),      
    Dense(1, activation='sigmoid')
])

미세조정

설정

model.compile(  
    optimizer=tf.keras.optimizers.Adam(),      
    loss=tf.keras.losses.binary_crossentropy,  
    metrics=['accuracy'])
train_dataset = tf.keras.utils.image_dataset_from_directory(
    'cats_and_dogs_filtered/train',
    shuffle=True,
    image_size=size)
Found 2000 files belonging to 2 classes.

학습

model.fit(train_dataset, validation_data=validation_dataset)
63/63 [==============================] - 25s 359ms/step - loss: 0.2028 - accuracy: 0.9340 - val_loss: 0.0706 - val_accuracy: 0.9830
<keras.callbacks.History at 0x1a324e45d80>