이미지 임베딩
TensorFlow Similarity
설치
pip install tensorflow_similarity
준비
from tensorflow_similarity.layers import MetricEmbedding
from tensorflow_similarity.models import SimilarityModel
MNIST 예제 데이터
from tensorflow_similarity.samplers import \
TFDatasetMultiShotMemorySampler
sampler = TFDatasetMultiShotMemorySampler(
dataset_name='mnist', classes_per_batch=10)
모형 정의
base_model = tf.keras.Sequential([
Rescaling(1/255, input_shape=(28, 28, 1)),
Conv2D(64, 3, activation='relu'),
Flatten(),
Dense(64, activation='relu'),
MetricEmbedding(64)])
model = SimilarityModel(base_model.inputs, base_model.outputs)
학습
from tensorflow_similarity.losses import MultiSimilarityLoss
model.compile('adam', loss=MultiSimilarityLoss())
model.fit(sampler, epochs=5)
임베딩 보기
sx, sy = sampler.get_slice(0,1)
model(sx)
인덱싱
sx, sy = sampler.get_slice(0,100)
model.index(x=sx, y=sy, data=sx)
검색할 이미지
qx, qy = sampler.get_slice(3713, 1)
검색
nns = model.single_lookup(qx[0])
CLIP
모델 로딩
OpenAI의 CLIP 모델의 전처리를 위한 프로세서를 생성
from transformers import CLIPProcessor, CLIPModel
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
이미지 로딩
import glob
from PIL import Image
images = []
for path in glob.glob('coco/*.jpg'):
images.append(Image.open(path))
images[0]
zero-shot 분류
전처리
inputs = processor(
text=["a photo of sleeping cats", "a photo of dogs"],
images=images[0],
return_tensors="pt",
padding=True)
모델에 입력
outputs = model(**inputs)
모델 출력에서 로짓 값을 추출
outputs.logits_per_image
tensor([[29.6581, 21.5314]], grad_fn=<TBackward0>)
확률
import torch
torch.softmax(outputs.logits_per_image, dim=-1)
tensor([[9.9970e-01, 2.9545e-04]], grad_fn=<SoftmaxBackward0>)
이미지 유사도
이미지들을 전처리
inputs = processor(images=images, return_tensors="pt")
입력 이미지들의 임베딩을 추출하고 numpy 배열로 변환
embs = model.get_image_features(**inputs)
embs = embs.detach().numpy()
코사인 유사도를 계산
from sklearn.metrics.pairwise import cosine_similarity
sims = cosine_similarity(embs)
이미지 인덱스를 정렬
import numpy as np
np.argsort(sims[0])
array([4, 5, 3, 7, 9, 6, 2, 1, 8, 0], dtype=int64)
images[8]