code reviewed
This commit is contained in:
@@ -0,0 +1,124 @@
|
||||
import openai
|
||||
import json
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
import requests
|
||||
from typing import List, Dict
|
||||
import os
|
||||
|
||||
class LLMTagSimilarity:
|
||||
def __init__(self, api_key: str = None, model: str = "gpt-4o"):
|
||||
"""
|
||||
Initialize LLM-based tag similarity analyzer
|
||||
Args:
|
||||
api_key: OpenAI API key (if None, will use OPENAI_API_KEY env var)
|
||||
model: Model to use for analysis
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.client = openai.OpenAI(api_key=api_key)
|
||||
self.model = model
|
||||
|
||||
def encode_image_to_base64(self, image_url: str) -> str:
|
||||
"""Convert image URL to base64 for API"""
|
||||
try:
|
||||
response = requests.get(image_url)
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
buffer = io.BytesIO()
|
||||
image.save(buffer, format='JPEG')
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode()
|
||||
return img_str
|
||||
except Exception as e:
|
||||
print(f"Error encoding image {image_url}: {e}")
|
||||
return None
|
||||
|
||||
def analyze_tag_similarity(self, query_image_url: str, candidate_images: List[str], max_candidates: int = 30) -> List[Dict]:
|
||||
"""
|
||||
Use LLM to analyze similarity between query tag and candidate tags
|
||||
Args:
|
||||
query_image_url: URL of the query tag image
|
||||
candidate_images: List of candidate tag image URLs
|
||||
max_candidates: Maximum number of candidates to analyze
|
||||
Returns:
|
||||
List of candidates with similarity scores and explanations
|
||||
"""
|
||||
candidates = candidate_images[:max_candidates]
|
||||
query_base64 = self.encode_image_to_base64(query_image_url)
|
||||
if not query_base64:
|
||||
return []
|
||||
candidate_base64s = []
|
||||
for img_url in candidates:
|
||||
base64_img = self.encode_image_to_base64(img_url)
|
||||
if base64_img:
|
||||
candidate_base64s.append(base64_img)
|
||||
prompt = """
|
||||
You are an expert in t-shirt tag authentication. Your job is to strictly assess visual authenticity between a query tag and several candidate tags.
|
||||
For each candidate, return:
|
||||
1. similarity_score (0-100): Based on **visual design fidelity** (not just text). How visually similar the tag looks. not just a color combination but color is very important. for example, a tag could have a whilte background a blue text and a red stripe is not similar at all to the same kind of tag with same white background and same text style and everything but has a red text with blue stripe. Also text colors are very important, two tags can look alike but text color are different... they are not similar. color is simple... for example, sky blue, navy blue is all blue... red, wine, is all the same. then focus on the overall style of the tag.. could be the same brand or name but diffrent style isnt similar. background color and text color is very important. so for example, if a tag background is black, similar tags are tags that are visiually similar and have black background as well. if a tag background is white, similar tags are tags that are visiually similar and have white background and so on.
|
||||
Output a JSON array of objects like:
|
||||
{
|
||||
"candidate_index": 0,
|
||||
"similarity_score": 42,
|
||||
}
|
||||
"""
|
||||
content = [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{query_base64}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
]
|
||||
for i, base64_img in enumerate(candidate_base64s):
|
||||
content.append({"type": "text", "text": f"Candidate {i+1}:"})
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_img}",
|
||||
"detail": "high"
|
||||
}
|
||||
})
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": content}],
|
||||
max_tokens=2000,
|
||||
temperature=0.1
|
||||
)
|
||||
analysis_text = response.choices[0].message.content
|
||||
try:
|
||||
start_idx = analysis_text.find('[')
|
||||
end_idx = analysis_text.rfind(']') + 1
|
||||
json_str = analysis_text[start_idx:end_idx]
|
||||
results = json.loads(json_str)
|
||||
for i, result in enumerate(results):
|
||||
result['original_url'] = candidates[i]
|
||||
return results
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing LLM response: {e}")
|
||||
print(f"Response: {analysis_text}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Error calling LLM API: {e}")
|
||||
return []
|
||||
|
||||
def filter_similar_tags(self, query_image_url: str, candidate_images: List[str], similarity_threshold: float = 70.0) -> List[Dict]:
|
||||
"""
|
||||
Filter candidates based on LLM similarity analysis
|
||||
Args:
|
||||
query_image_url: URL of the query tag image
|
||||
candidate_images: List of candidate tag image URLs
|
||||
similarity_threshold: Minimum similarity score to include
|
||||
Returns:
|
||||
Filtered list of similar tags with scores
|
||||
"""
|
||||
analysis_results = self.analyze_tag_similarity(query_image_url, candidate_images)
|
||||
filtered_results = [
|
||||
result for result in analysis_results
|
||||
if result.get('similarity_score', 0) >= similarity_threshold
|
||||
]
|
||||
filtered_results.sort(key=lambda x: x.get('similarity_score', 0), reverse=True)
|
||||
return filtered_results
|
||||
@@ -0,0 +1,150 @@
|
||||
import runpod
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
runpod.api_key = os.getenv('RUNPOD_API_KEY')
|
||||
|
||||
class TagIdentification:
|
||||
_instance = None
|
||||
_cache = {}
|
||||
_cache_file = os.path.join(os.path.dirname(__file__), 'tag_identification_cache.json')
|
||||
_is_cache_loaded = False
|
||||
|
||||
def __new__(cls, endpoint_id):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(TagIdentification, cls).__new__(cls)
|
||||
cls._instance.endpoint_id = endpoint_id
|
||||
cls._instance.endpoint = runpod.Endpoint(endpoint_id)
|
||||
cls._instance._load_cache()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, endpoint_id):
|
||||
# Initialization already done in __new__
|
||||
pass
|
||||
|
||||
def _load_cache(self):
|
||||
"""Load the tag identification cache from file."""
|
||||
if self._is_cache_loaded:
|
||||
return
|
||||
|
||||
if os.path.exists(self._cache_file):
|
||||
try:
|
||||
with open(self._cache_file, 'r') as f:
|
||||
self._cache = json.load(f)
|
||||
print(f"Loaded {len(self._cache)} cached tag identifications")
|
||||
except Exception as e:
|
||||
print(f"Error loading cache: {e}")
|
||||
self._cache = {}
|
||||
|
||||
self._is_cache_loaded = True
|
||||
|
||||
def _save_cache(self):
|
||||
"""Save the tag identification cache to file."""
|
||||
try:
|
||||
with open(self._cache_file, 'w') as f:
|
||||
json.dump(self._cache, f)
|
||||
print(f"Saved {len(self._cache)} cached tag identifications")
|
||||
except Exception as e:
|
||||
print(f"Error saving cache: {e}")
|
||||
|
||||
def preload_tags(self, image_urls):
|
||||
"""
|
||||
Preload tag identifications for a list of image URLs.
|
||||
Args:
|
||||
image_urls: List of image URLs to preload
|
||||
"""
|
||||
print(f"Preloading {len(image_urls)} tag identifications...")
|
||||
|
||||
for i, url in enumerate(image_urls):
|
||||
if url in self._cache:
|
||||
continue
|
||||
|
||||
try:
|
||||
result = self._identify_tag_api(url)
|
||||
if result:
|
||||
self._cache[url] = result
|
||||
|
||||
# Save cache periodically
|
||||
if (i + 1) % 10 == 0:
|
||||
self._save_cache()
|
||||
print(f"Preloaded {i+1}/{len(image_urls)} tags")
|
||||
except Exception as e:
|
||||
print(f"Error preloading tag for {url}: {e}")
|
||||
|
||||
# Final save
|
||||
self._save_cache()
|
||||
print(f"Completed preloading {len(image_urls)} tags")
|
||||
|
||||
def identify_tag(self, image_url):
|
||||
"""
|
||||
Identify tag from image URL, using cache if available.
|
||||
Args:
|
||||
image_url: URL of the image to identify
|
||||
Returns:
|
||||
Tag identification result
|
||||
"""
|
||||
# Check cache first
|
||||
if image_url in self._cache:
|
||||
print(f"Cache hit for {image_url}")
|
||||
return self._cache[image_url]
|
||||
|
||||
# Call API if not in cache
|
||||
result = self._identify_tag_api(image_url)
|
||||
|
||||
# Cache the result
|
||||
if result:
|
||||
self._cache[image_url] = result
|
||||
self._save_cache()
|
||||
|
||||
return result
|
||||
|
||||
def _identify_tag_api(self, image_url):
|
||||
"""Make the actual API call to identify the tag."""
|
||||
prompt = """
|
||||
You will tell me which tag it belongs to:
|
||||
1. Alstyle Apparel & Activewear T-Shirt Tags 1995-2006
|
||||
2. Anvil T-Shirt Tags 1989-2007
|
||||
3. Ched and Anvil T-Shirt Tags 1976-1988
|
||||
4. Delta T-Shirt Tags 1988-2014
|
||||
5. Fruit of the Loom 1970-1998
|
||||
6. Giant T-Shirt Tags 1991-1996
|
||||
7. Gildan T-Shirt Tags 1995-2002
|
||||
8. Hanes T-Shirt Tags 1989-1997
|
||||
9. Jerzees T-Shirt Tags 1985-1998
|
||||
10. Oneita T-Shirt Tags 1984-1999
|
||||
11. Screen Stars T-Shirt Tags 1980-1994
|
||||
12. Signal T-Shirt Tags 1977-1994
|
||||
13. Sportswear T-Shirt Tags 1968 – 1990
|
||||
14. Stedman & Hi Cru T-Shirt Tags 1971-1997
|
||||
15. Tennessee River T-Shirt Tags 1984-2010
|
||||
16. Wild Oats T-Shirt Tags 1984-1997
|
||||
17. Winterland T-Shirt Tags 1982-2008
|
||||
18. Others
|
||||
|
||||
Just Give me the Tag only(Don't add anything)
|
||||
"""
|
||||
|
||||
payload = {
|
||||
"input": {
|
||||
"input_image_url": image_url,
|
||||
"vlm_prompt": "What is the tag of the t-shirt? Just give me the name of the tag nothing else",
|
||||
"max_new_tokens": 20
|
||||
}
|
||||
}
|
||||
|
||||
# Send the request to the endpoint and get the response
|
||||
run_request = self.endpoint.run(payload)
|
||||
|
||||
# Check the status of the endpoint run request in a loop until completed or an error occurs
|
||||
while True:
|
||||
status = run_request.status()
|
||||
if status == 'COMPLETED':
|
||||
return run_request.output()
|
||||
elif status == 'FAILED':
|
||||
print("Request failed.")
|
||||
return None
|
||||
time.sleep(1)
|
||||
return None
|
||||
@@ -0,0 +1,35 @@
|
||||
import json
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import os
|
||||
|
||||
def get_best_match(tag_response, tag_guides_path=None, top_n=1):
|
||||
"""
|
||||
Find the best match for a tag_response using cosine similarity on local tag_guides_clean.json.
|
||||
|
||||
Args:
|
||||
tag_response (str): The input tag to be matched.
|
||||
tag_guides_path (str): Path to the local tag_guides_clean.json file.
|
||||
top_n (int): Number of top matches to return (default is 1).
|
||||
|
||||
Returns:
|
||||
list: A list of top_n best matches with similarity scores and matched data.
|
||||
"""
|
||||
if tag_guides_path is None:
|
||||
tag_guides_path = os.path.join(os.path.dirname(__file__), '../data/tag_guides_clean.json')
|
||||
with open(tag_guides_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
names = [item['name'] for item in data['tag_guides']]
|
||||
vectorizer = TfidfVectorizer().fit(names + [tag_response])
|
||||
name_vectors = vectorizer.transform(names)
|
||||
response_vector = vectorizer.transform([tag_response])
|
||||
cosine_similarities = cosine_similarity(response_vector, name_vectors).flatten()
|
||||
top_indices = cosine_similarities.argsort()[-top_n:][::-1]
|
||||
best_matches = []
|
||||
for index in top_indices:
|
||||
best_matches.append({
|
||||
'matched_name': data['tag_guides'][index]['name'],
|
||||
'similarity_score': round(cosine_similarities[index], 4),
|
||||
'matched_data': data['tag_guides'][index]
|
||||
})
|
||||
return best_matches
|
||||
Reference in New Issue
Block a user