Files
ds_task_scan/backend/services/image_similarity.py
T
Aherobo Ovie Victor 5e07248594 code reviewed
2025-07-22 09:46:32 +01:00

48 lines
1.4 KiB
Python

import os
import numpy as np
import faiss
import torch
from PIL import Image
from transformers import CLIPModel
from torchvision import transforms
# Compute device: the default CUDA device when CUDA is available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Lazily initialised singletons. They stay None until get_model() /
# get_preprocess() build them on first use, so importing this module does not
# load the model.
_model = None
_preprocess = None
def get_model():
    """Return the shared CLIP model, loading it on the first call.

    The ``openai/clip-vit-base-patch16`` checkpoint is loaded with
    ``CLIPModel.from_pretrained``, moved to the module-level ``device`` and
    switched to eval mode. Subsequent calls return the cached instance.

    NOTE(review): the lazy initialisation is not guarded by a lock; if this is
    called from several threads at once, each may load its own copy.
    """
    global _model
    if _model is not None:
        return _model
    _model = CLIPModel.from_pretrained('openai/clip-vit-base-patch16').to(device)
    _model.eval()  # inference only: disables dropout etc.
    return _model
def get_preprocess():
    """Return the shared image-preprocessing pipeline, building it on first use.

    Steps: resize to a fixed 224x224 square (aspect ratio is not preserved),
    convert to a float tensor (in [0, 1] for 8-bit images), then apply
    ``(x - 0.5) / 0.5`` per channel, i.e. map to [-1, 1].

    NOTE(review): these mean/std values are not the statistics published for
    the OpenAI CLIP checkpoints. Confirm they match the pipeline that built
    the FAISS index before changing them; changing them means re-embedding
    the index.
    """
    global _preprocess
    if _preprocess is not None:
        return _preprocess
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    _preprocess = transforms.Compose(steps)
    return _preprocess
def transform_image(image):
    """Embed a single image with the lazily loaded CLIP model.

    Args:
        image: A PIL image. Images in a non-RGB mode (e.g. ``"L"``,
            ``"RGBA"``, ``"P"``) are converted to RGB first, because the
            preprocessing pipeline normalises exactly three channels.

    Returns:
        A 1-D ``float32`` NumPy array with the image features returned by
        ``get_image_features`` for this image (not L2-normalised).
    """
    model = get_model()
    preprocess = get_preprocess()
    # Normalize() uses 3-channel mean/std; 1-channel (grayscale) or 4-channel
    # (RGBA) tensors would raise inside the transform, so coerce to RGB here.
    if isinstance(image, Image.Image) and image.mode != 'RGB':
        image = image.convert('RGB')
    pixel_values = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.get_image_features(pixel_values=pixel_values)
    # Take the single batch row explicitly; squeeze() would also drop any
    # other size-1 dimension.
    return features[0].cpu().numpy().astype(np.float32)
def load_index(index_file):
    """Read a FAISS index from disk.

    Args:
        index_file: Path to a serialised FAISS index, as ``str``, ``bytes`` or
            any ``os.PathLike`` (e.g. ``pathlib.Path``).

    Returns:
        The deserialised FAISS index object.
    """
    # faiss.read_index is a SWIG binding expecting a C string; os.fspath lets
    # callers pass pathlib.Path objects while leaving str/bytes unchanged.
    return faiss.read_index(os.fspath(index_file))
def search_similar_images(query_image, index, top_k=5):
    """Find the ``top_k`` indexed images most similar to ``query_image``.

    The query is embedded with ``transform_image`` and L2-normalised, then
    looked up with ``index.search``.

    Args:
        query_image: Image accepted by ``transform_image`` (a PIL image).
        index: A FAISS index holding the gallery embeddings (presumably
            L2-normalised the same way - confirm against the index builder).
        top_k: Maximum number of neighbours to return.

    Returns:
        ``(distances, indices)``: two 1-D arrays for the single query, as
        produced by ``index.search`` (inner products or squared L2 distances,
        depending on the index type). Slots that FAISS pads with label ``-1``
        (fewer than ``top_k`` vectors available) are removed, so both arrays
        may be shorter than ``top_k``.
    """
    query_embedding = transform_image(query_image)
    norm = np.linalg.norm(query_embedding)
    # Guard against a degenerate all-zero embedding, which would yield NaNs.
    if norm > 0:
        query_embedding = query_embedding / norm
    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim).
    query_batch = np.ascontiguousarray(query_embedding[np.newaxis, :], dtype=np.float32)
    distances, indices = index.search(query_batch, top_k)
    # FAISS marks missing neighbours with -1; a caller indexing a Python list
    # with -1 would silently get the last element, so drop those slots.
    found = indices[0] != -1
    return distances[0][found], indices[0][found]