code reviewed

This commit is contained in:
Aherobo Ovie Victor
2025-07-22 09:46:32 +01:00
parent a9351f2c86
commit 5e07248594
11 changed files with 24 additions and 7 deletions
+124
View File
@@ -0,0 +1,124 @@
import openai
import json
import base64
from PIL import Image
import io
import requests
from typing import List, Dict
import os
class LLMTagSimilarity:
    """Score candidate t-shirt tag images against a query tag with a vision LLM.

    Every image is downloaded, re-encoded as JPEG and sent inline (base64
    data URL) in a single chat-completion request. The model is asked to
    answer with a JSON array of ``{"candidate_index", "similarity_score"}``
    objects, which are then mapped back to the candidate URLs.
    """

    # Seconds to wait for an image download before giving up on that image.
    REQUEST_TIMEOUT = 30

    def __init__(self, api_key: str = None, model: str = "gpt-4o"):
        """
        Initialize LLM-based tag similarity analyzer
        Args:
            api_key: OpenAI API key (if None, will use OPENAI_API_KEY env var)
            model: Model to use for analysis
        """
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY")
        self.client = openai.OpenAI(api_key=api_key)
        self.model = model

    def encode_image_to_base64(self, image_url: str) -> str:
        """Download an image and return it as a base64-encoded JPEG string.

        Args:
            image_url: URL of the image to download.
        Returns:
            The base64 text of the JPEG-encoded image, or None when the
            download or the conversion fails (the error is printed).
        """
        try:
            # A timeout keeps one unresponsive host from hanging the whole run.
            response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
            # JPEG cannot store alpha or palette modes (RGBA, P, LA, ...), so
            # normalise to RGB first instead of failing on PNG tags.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            buffer = io.BytesIO()
            image.save(buffer, format='JPEG')
            return base64.b64encode(buffer.getvalue()).decode()
        except Exception as e:
            # Best effort: callers treat None as "skip this image".
            print(f"Error encoding image {image_url}: {e}")
            return None

    def analyze_tag_similarity(self, query_image_url: str, candidate_images: List[str], max_candidates: int = 30) -> List[Dict]:
        """
        Use LLM to analyze similarity between query tag and candidate tags
        Args:
            query_image_url: URL of the query tag image
            candidate_images: List of candidate tag image URLs
            max_candidates: Maximum number of candidates to analyze
        Returns:
            List of result dicts as returned by the model, each extended with
            an 'original_url' key. Empty when the query image cannot be
            encoded, no candidate can be encoded, the API call fails, or the
            reply contains no parsable JSON array.
        """
        candidates = candidate_images[:max_candidates]
        query_base64 = self.encode_image_to_base64(query_image_url)
        if not query_base64:
            return []

        # Keep each URL paired with its payload: candidates that fail to encode
        # are skipped, so positions in the request no longer line up with
        # positions in `candidates`.
        encoded = []
        for img_url in candidates:
            base64_img = self.encode_image_to_base64(img_url)
            if base64_img:
                encoded.append((img_url, base64_img))
        if not encoded:
            # Nothing to compare against; avoid a pointless API call.
            return []

        prompt = """
        You are an expert in t-shirt tag authentication. Your job is to strictly assess visual authenticity between a query tag and several candidate tags.
        For each candidate, return:
        1. similarity_score (0-100): Based on **visual design fidelity** (not just text). How visually similar the tag looks. not just a color combination but color is very important. for example, a tag could have a whilte background a blue text and a red stripe is not similar at all to the same kind of tag with same white background and same text style and everything but has a red text with blue stripe. Also text colors are very important, two tags can look alike but text color are different... they are not similar. color is simple... for example, sky blue, navy blue is all blue... red, wine, is all the same. then focus on the overall style of the tag.. could be the same brand or name but diffrent style isnt similar. background color and text color is very important. so for example, if a tag background is black, similar tags are tags that are visiually similar and have black background as well. if a tag background is white, similar tags are tags that are visiually similar and have white background and so on.
        Output a JSON array of objects like:
        {
        "candidate_index": 0,
        "similarity_score": 42
        }
        """
        content = [
            {"type": "text", "text": prompt},
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{query_base64}",
                    "detail": "high"
                }
            }
        ]
        for i, (_, base64_img) in enumerate(encoded):
            # Labels are 0-based so they agree with the "candidate_index": 0
            # example in the prompt.
            content.append({"type": "text", "text": f"Candidate {i}:"})
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_img}",
                    "detail": "high"
                }
            })

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": content}],
                max_tokens=2000,
                temperature=0.1
            )
            analysis_text = response.choices[0].message.content or ""
        except Exception as e:
            print(f"Error calling LLM API: {e}")
            return []

        # The model may wrap the array in prose or a code fence; take the
        # outermost [...] span.
        start_idx = analysis_text.find('[')
        end_idx = analysis_text.rfind(']')
        try:
            if start_idx == -1 or end_idx < start_idx:
                raise json.JSONDecodeError("no JSON array found", analysis_text, 0)
            results = json.loads(analysis_text[start_idx:end_idx + 1])
        except json.JSONDecodeError as e:
            print(f"Error parsing LLM response: {e}")
            print(f"Response: {analysis_text}")
            return []
        if not isinstance(results, list):
            print(f"Error parsing LLM response: expected a JSON array")
            print(f"Response: {analysis_text}")
            return []

        annotated = []
        for position, result in enumerate(results):
            if not isinstance(result, dict):
                continue
            # Prefer the index reported by the model; fall back to the position
            # in the reply when it is missing or out of range.
            index = result.get('candidate_index')
            if isinstance(index, bool) or not isinstance(index, int) or not 0 <= index < len(encoded):
                index = position
            if not 0 <= index < len(encoded):
                # More results than candidates: nothing to attach this one to.
                continue
            result['original_url'] = encoded[index][0]
            annotated.append(result)
        return annotated

    @staticmethod
    def _score(result: Dict) -> float:
        """Return the result's similarity score as a float (0.0 if unusable)."""
        try:
            return float(result.get('similarity_score', 0))
        except (TypeError, ValueError):
            return 0.0

    def filter_similar_tags(self, query_image_url: str, candidate_images: List[str], similarity_threshold: float = 70.0) -> List[Dict]:
        """
        Filter candidates based on LLM similarity analysis
        Args:
            query_image_url: URL of the query tag image
            candidate_images: List of candidate tag image URLs
            similarity_threshold: Minimum similarity score to include
        Returns:
            Filtered list of similar tags with scores, highest score first
        """
        analysis_results = self.analyze_tag_similarity(query_image_url, candidate_images)
        filtered_results = [
            result for result in analysis_results
            if self._score(result) >= similarity_threshold
        ]
        filtered_results.sort(key=self._score, reverse=True)
        return filtered_results
+150
View File
@@ -0,0 +1,150 @@
import runpod
import time
import os
import json
from dotenv import load_dotenv
# Load settings from a local .env file (python-dotenv) before reading the key below.
load_dotenv()
# Configure the runpod client from RUNPOD_API_KEY; os.getenv yields None when the
# variable is unset.
runpod.api_key = os.getenv('RUNPOD_API_KEY')
class TagIdentification:
    """Singleton that identifies t-shirt tags through a RunPod endpoint.

    Results are cached per image URL in memory and persisted to a JSON file
    next to this module. Only the first construction creates the endpoint
    client; later calls return that same instance and their ``endpoint_id``
    argument is ignored.
    """

    _instance = None
    _cache = {}
    _cache_file = os.path.join(os.path.dirname(__file__), 'tag_identification_cache.json')
    _is_cache_loaded = False
    # Seconds to sleep between job status checks.
    _poll_interval = 1
    # Upper bound, in seconds, on how long one job is polled before giving up.
    _poll_timeout = 600
    # Job states after which the job can no longer complete.
    _FAILURE_STATUSES = ('FAILED', 'CANCELLED', 'TIMED_OUT')

    def __new__(cls, endpoint_id):
        if cls._instance is None:
            cls._instance = super(TagIdentification, cls).__new__(cls)
            cls._instance.endpoint_id = endpoint_id
            cls._instance.endpoint = runpod.Endpoint(endpoint_id)
            cls._instance._load_cache()
        return cls._instance

    def __init__(self, endpoint_id):
        # Initialization already done in __new__
        pass

    def _load_cache(self):
        """Load the tag identification cache from file."""
        if self._is_cache_loaded:
            return
        if os.path.exists(self._cache_file):
            try:
                with open(self._cache_file, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                # The cache is used as a URL -> result mapping; anything else
                # (e.g. a list) would break the `in` / item lookups later on.
                if not isinstance(loaded, dict):
                    raise ValueError("cache file does not contain a JSON object")
                self._cache = loaded
                print(f"Loaded {len(self._cache)} cached tag identifications")
            except Exception as e:
                print(f"Error loading cache: {e}")
                self._cache = {}
        self._is_cache_loaded = True

    def _save_cache(self):
        """Save the tag identification cache to file."""
        tmp_path = f"{self._cache_file}.tmp"
        try:
            # Write to a sibling file and swap it in, so an interrupted write
            # cannot leave a truncated cache behind.
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self._cache, f)
            os.replace(tmp_path, self._cache_file)
            print(f"Saved {len(self._cache)} cached tag identifications")
        except Exception as e:
            print(f"Error saving cache: {e}")

    def preload_tags(self, image_urls):
        """
        Preload tag identifications for a list of image URLs.
        Args:
            image_urls: List of image URLs to preload
        """
        print(f"Preloading {len(image_urls)} tag identifications...")
        for i, url in enumerate(image_urls):
            if url in self._cache:
                continue
            try:
                result = self._identify_tag_api(url)
                if result:
                    self._cache[url] = result
                    # Save cache periodically
                    if (i + 1) % 10 == 0:
                        self._save_cache()
                        print(f"Preloaded {i+1}/{len(image_urls)} tags")
            except Exception as e:
                print(f"Error preloading tag for {url}: {e}")
        # Final save
        self._save_cache()
        print(f"Completed preloading {len(image_urls)} tags")

    def identify_tag(self, image_url):
        """
        Identify tag from image URL, using cache if available.
        Args:
            image_url: URL of the image to identify
        Returns:
            Tag identification result, or None when the endpoint job did not
            complete
        """
        # Check cache first
        if image_url in self._cache:
            print(f"Cache hit for {image_url}")
            return self._cache[image_url]
        # Call API if not in cache
        result = self._identify_tag_api(image_url)
        # Cache the result (falsy results are not cached, so they are retried)
        if result:
            self._cache[image_url] = result
            self._save_cache()
        return result

    def _identify_tag_api(self, image_url):
        """Make the actual API call to identify the tag.

        Submits the job to the endpoint and polls its status until it
        completes, reaches a failure state, or `_poll_timeout` seconds pass.

        Returns:
            The job output on completion, otherwise None.
        """
        # NOTE(review): this detailed prompt is defined but never sent; the
        # payload below uses the shorter hard-coded "vlm_prompt". Confirm which
        # one is intended before wiring it in or deleting it.
        prompt = """
        You will tell me which tag it belongs to:
        1. Alstyle Apparel & Activewear T-Shirt Tags 1995-2006
        2. Anvil T-Shirt Tags 1989-2007
        3. Ched and Anvil T-Shirt Tags 1976-1988
        4. Delta T-Shirt Tags 1988-2014
        5. Fruit of the Loom 1970-1998
        6. Giant T-Shirt Tags 1991-1996
        7. Gildan T-Shirt Tags 1995-2002
        8. Hanes T-Shirt Tags 1989-1997
        9. Jerzees T-Shirt Tags 1985-1998
        10. Oneita T-Shirt Tags 1984-1999
        11. Screen Stars T-Shirt Tags 1980-1994
        12. Signal T-Shirt Tags 1977-1994
        13. Sportswear T-Shirt Tags 1968 1990
        14. Stedman & Hi Cru T-Shirt Tags 1971-1997
        15. Tennessee River T-Shirt Tags 1984-2010
        16. Wild Oats T-Shirt Tags 1984-1997
        17. Winterland T-Shirt Tags 1982-2008
        18. Others
        Just Give me the Tag only(Don't add anything)
        """
        payload = {
            "input": {
                "input_image_url": image_url,
                "vlm_prompt": "What is the tag of the t-shirt? Just give me the name of the tag nothing else",
                "max_new_tokens": 20
            }
        }
        # Send the request to the endpoint and get the response
        run_request = self.endpoint.run(payload)
        # Poll until the job finishes, ends in a failure state, or the deadline
        # passes; without the last two exits a cancelled or stuck job would
        # keep this loop spinning forever.
        deadline = time.monotonic() + self._poll_timeout
        while True:
            status = run_request.status()
            if status == 'COMPLETED':
                return run_request.output()
            if status in self._FAILURE_STATUSES:
                if status == 'FAILED':
                    print("Request failed.")
                else:
                    print(f"Request ended with status {status}.")
                return None
            if time.monotonic() >= deadline:
                print(f"Request did not complete within {self._poll_timeout} seconds.")
                return None
            time.sleep(self._poll_interval)
+35
View File
@@ -0,0 +1,35 @@
import json
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import os
def get_best_match(tag_response, tag_guides_path=None, top_n=1):
    """
    Find the best match for a tag_response using cosine similarity on local tag_guides_clean.json.

    Guide names and the response are vectorized with TF-IDF fitted on both,
    then ranked by cosine similarity to the response.

    Args:
        tag_response (str): The input tag to be matched.
        tag_guides_path (str): Path to the local tag_guides_clean.json file.
            Defaults to ../data/tag_guides_clean.json relative to this module.
        top_n (int): Number of top matches to return (default is 1).
    Returns:
        list: Up to top_n best matches, highest similarity first, each with
        'matched_name', 'similarity_score' and 'matched_data'. Empty when
        tag_response is None, top_n is not positive, or the guide list is empty.
    """
    # Guard first: a slice of [-0:] selects the WHOLE array, so top_n=0 would
    # otherwise return every guide instead of none.
    if tag_response is None or top_n <= 0:
        return []
    if tag_guides_path is None:
        tag_guides_path = os.path.join(os.path.dirname(__file__), '../data/tag_guides_clean.json')
    with open(tag_guides_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    guides = data['tag_guides']
    if not guides:
        # Nothing to rank; return before handing the vectorizer an empty corpus.
        return []
    names = [item['name'] for item in guides]
    # Fit on the names plus the response so the response's terms are in the vocabulary.
    vectorizer = TfidfVectorizer().fit(names + [tag_response])
    name_vectors = vectorizer.transform(names)
    response_vector = vectorizer.transform([tag_response])
    cosine_similarities = cosine_similarity(response_vector, name_vectors).flatten()
    # argsort is ascending: take the last top_n indices and reverse them.
    top_indices = cosine_similarities.argsort()[-top_n:][::-1]
    return [
        {
            'matched_name': guides[index]['name'],
            # Plain float rather than a numpy scalar, for clean serialization.
            'similarity_score': round(float(cosine_similarities[index]), 4),
            'matched_data': guides[index],
        }
        for index in top_indices
    ]