code reviewed
This commit is contained in:
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import faiss
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPModel
|
||||
from torchvision import transforms
|
||||
|
||||
# Device setup
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Global variable for lazy loading
|
||||
_model = None
|
||||
_preprocess = None
|
||||
|
||||
def get_model():
|
||||
global _model
|
||||
if _model is None:
|
||||
_model = CLIPModel.from_pretrained('openai/clip-vit-base-patch16').to(device)
|
||||
_model.eval()
|
||||
return _model
|
||||
|
||||
def get_preprocess():
|
||||
global _preprocess
|
||||
if _preprocess is None:
|
||||
_preprocess = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
return _preprocess
|
||||
|
||||
def transform_image(image):
|
||||
model = get_model()
|
||||
preprocess = get_preprocess()
|
||||
image = preprocess(image).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
features = model.get_image_features(pixel_values=image)
|
||||
return features.squeeze().cpu().numpy().astype(np.float32)
|
||||
|
||||
def load_index(index_file):
|
||||
return faiss.read_index(index_file)
|
||||
|
||||
def search_similar_images(query_image, index, top_k=5):
|
||||
query_embedding = transform_image(query_image)
|
||||
query_embedding_normalized = query_embedding / np.linalg.norm(query_embedding)
|
||||
distances, indices = index.search(np.array([query_embedding_normalized]), top_k)
|
||||
return distances[0][:top_k], indices[0][:top_k]
|
||||
@@ -0,0 +1,16 @@
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
def download_image(image_url: str, save_path: str) -> None:
|
||||
"""
|
||||
Download an image from a URL and save it to a local path.
|
||||
Args:
|
||||
image_url (str): The URL of the image to download.
|
||||
save_path (str): The local file path to save the image.
|
||||
"""
|
||||
response = requests.get(image_url)
|
||||
response.raise_for_status()
|
||||
image = Image.open(BytesIO(response.content))
|
||||
image.save(save_path)
|
||||
@@ -0,0 +1,58 @@
|
||||
import pandas as pd
|
||||
|
||||
def aggregate_results(matched_names, expert_data, community_data, max_results=30):
|
||||
"""
|
||||
Aggregate similar images, appraisal values, years, and status for matched tag names.
|
||||
Args:
|
||||
matched_names (list): List of tag names to match.
|
||||
expert_data (pd.DataFrame): Expert dataset.
|
||||
community_data (pd.DataFrame): Community dataset.
|
||||
max_results (int): Maximum number of results to return.
|
||||
Returns:
|
||||
dict: Aggregated results with images, appraisal values, years, and status.
|
||||
"""
|
||||
similar_data = []
|
||||
|
||||
for title in matched_names:
|
||||
# Handle expert data (no 'year' column)
|
||||
community_items = community_data[community_data['brand_name'] == title]
|
||||
expert_items = expert_data[expert_data['brand_name'] == title]
|
||||
|
||||
# Process community data (has 'year' column)
|
||||
if not community_items.empty:
|
||||
# Use 'year' column if it exists, otherwise use 'year_start'
|
||||
year_col = 'year' if 'year' in community_items.columns else 'year_start'
|
||||
community_records = community_items[['front_tag', 'appraisal_value', 'key', 'status', year_col]].to_dict('records')
|
||||
# Rename year column to 'year' for consistency
|
||||
for record in community_records:
|
||||
record['year'] = record.pop(year_col) if year_col in record else None
|
||||
similar_data.extend(community_records)
|
||||
|
||||
# Process expert data (no 'year' column)
|
||||
if not expert_items.empty:
|
||||
expert_records = expert_items[['front_tag', 'appraisal_value', 'key', 'status']].to_dict('records')
|
||||
# Add None for year since expert data doesn't have it
|
||||
for record in expert_records:
|
||||
record['year'] = None
|
||||
similar_data.extend(expert_records)
|
||||
|
||||
# Remove duplicates by key, preserving order
|
||||
seen_keys = set()
|
||||
unique_data = []
|
||||
for item in similar_data:
|
||||
if item['key'] not in seen_keys:
|
||||
seen_keys.add(item['key'])
|
||||
unique_data.append(item)
|
||||
|
||||
# Prepare results
|
||||
similar_images = [item['front_tag'] for item in unique_data][:max_results]
|
||||
appraisal_values = [item['appraisal_value'] for item in unique_data][:max_results]
|
||||
years = [item.get('year') for item in unique_data][:max_results]
|
||||
statuses = [item['status'] for item in unique_data][:max_results]
|
||||
|
||||
return {
|
||||
'similar_images': similar_images,
|
||||
'appraisal_values': appraisal_values,
|
||||
'years': years,
|
||||
'statuses': statuses
|
||||
}
|
||||
Reference in New Issue
Block a user