# Anton_wireframe/app/services/openrouter.py
# (Python source; 179 lines, 6.4 KiB as exported from the repository viewer)

import asyncio
import inspect
from typing import List, Optional

import chromadb
import pandas as pd
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

from db.tables import InvestorTable
from pydantic_schemas import Investor, InvestorList
from settings import settings

# Add these imports for your databases
# from sqlalchemy.ext.asyncio import AsyncSession
# from your_vector_db import VectorDBClient
class InvestorProcessor:
    """Extract structured investor records from tabular (CSV) data via an LLM.

    Rows are rendered to text in batches, wrapped in an extraction prompt and
    sent to an OpenRouter-hosted chat model whose output is parsed into
    ``InvestorList``. The extracted investors are then stored in an optional
    SQL session and in a Chroma collection.
    """

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
    ):
        """Configure the prompt, the LLM client and the storage backends.

        Args:
            sql_session: Session used by ``_save_to_sql``. When ``None``, SQL
                persistence is skipped. Both a synchronous session and one
                whose ``commit``/``rollback`` return awaitables are supported.
            vector_db_client: Chroma client used to store investor
                descriptions. When ``None``, a ``chromadb.PersistentClient``
                rooted at ``./chroma_db`` is created.
        """
        self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a list of structured records.
Given the following CSV data rows:
{question}
For each row, extract and structure the following fields:
- name: The investor's full name
- aum: Assets under management (as integer, use 0 if not available)
- check_size: Investment check size (as string)
- sector_focus: Sector focus (as string)
- stage_focus: Investment stage focus (as string)
- region: Geographic region (as string)
- investment_thesis: Investment thesis (as string)
- investor_description: Description of the investor (as string)
Important:
- If a field is not available in the data, use appropriate default values (empty string for text fields, 0 for numbers)
- Ensure all text fields are properly escaped and contain no control characters
- Return clean, valid JSON only
Return the data as a structured list of investors."""
        self.prompt = PromptTemplate(
            template=self.template, input_variables=["question"]
        )
        self.llm = ChatOpenAI(
            api_key=settings.OPENROUTER_API_KEY,
            base_url="https://openrouter.ai/api/v1",
            # OpenRouter's free-tier variants use the ":free" suffix; the
            # previous ":fre" is not a valid model id.
            model="openai/gpt-oss-120b:free",
            temperature=0,
        )
        self.structured_llm = self.llm.with_structured_output(InvestorList)
        self.sql_session = sql_session
        # Honour an injected client. Previously the argument was always
        # overwritten, so callers could never supply their own client.
        if vector_db_client is None:
            vector_db_client = chromadb.PersistentClient(path="./chroma_db")
        self.vector_db_client = vector_db_client
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus"
            },
        )

    @staticmethod
    def _clean_value(value: object) -> str:
        """Stringify *value* and remove characters that break prompt/JSON text.

        Newlines, carriage returns and tabs become spaces. Every other
        character below U+0020 (the ASCII control range) is dropped.
        """
        text = str(value).replace("\n", " ").replace("\r", " ").replace("\t", " ")
        return "".join(ch for ch in text if ord(ch) >= 32)

    @classmethod
    def _format_batch(cls, batch: pd.DataFrame) -> str:
        """Render *batch* as one ``Row N: col: value, ...`` line per row.

        NA cells are omitted. ``N`` is the row's index label + 1. Because
        ``process_csv`` slices with ``iloc``, a default RangeIndex on the
        source frame makes ``N`` the 1-based row number in that frame.
        """
        lines = []
        for idx, row in batch.iterrows():
            fields = ", ".join(
                f"{key}: {cls._clean_value(value)}"
                for key, value in row.items()
                if pd.notna(value)
            )
            lines.append(f"Row {idx + 1}: {fields}\n")
        return "".join(lines)

    async def _process_batch(self, batch: pd.DataFrame, batch_idx: int) -> List:
        """Send one batch to the LLM and return the investors it extracted.

        The rendered rows are substituted into ``self.prompt``, so the model
        actually receives the extraction instructions. Previously the raw rows
        were sent on their own and the template was never used. Any error
        from the LLM call is printed and an empty list is returned, so one
        failing batch does not abort the whole run.

        Args:
            batch: Rows to extract.
            batch_idx: 0-based batch number, used only for progress messages.
        """
        batch_str = self._format_batch(batch)
        try:
            print(f"Processing batch {batch_idx + 1}...")
            prompt_text = self.prompt.format(question=batch_str)
            batch_results = await self.structured_llm.ainvoke(prompt_text)
            # Guard against a missing list so callers can always extend().
            return list(batch_results.investor_list or [])
        except Exception as e:
            print(f"Error processing batch {batch_idx + 1}: {e}")
            return []

    @staticmethod
    async def _maybe_await(result: object) -> object:
        """Await *result* when it is awaitable; otherwise return it unchanged.

        This lets the SQL path work with a synchronous session as well as an
        async one whose ``commit``/``rollback`` return coroutines.
        """
        if inspect.isawaitable(result):
            return await result
        return result

    async def _save_to_sql(self, investors: List["Investor"]) -> None:
        """Insert *investors* as ``InvestorTable`` rows and commit them.

        Only the profile fields listed below are stored here. Description and
        thesis are written by ``_save_to_vector_db``. This is a no-op when no
        session was configured. On failure the transaction is rolled back and
        the error is re-raised, so the session is not left in a failed state.
        """
        if self.sql_session is None:
            return
        try:
            for investor in investors:
                self.sql_session.add(
                    InvestorTable(
                        name=investor.name,
                        aum=investor.aum,
                        check_size=investor.check_size,
                        sector_focus=investor.sector_focus,
                        stage_focus=investor.stage_focus,
                        region=investor.region,
                    )
                )
            await self._maybe_await(self.sql_session.commit())
        except Exception:
            await self._maybe_await(self.sql_session.rollback())
            raise

    async def _save_to_vector_db(self, investors: List["Investor"]) -> None:
        """Add one document per investor (description + thesis) to the collection.

        Each document carries ``{"name": investor.name}`` as metadata. This is
        a no-op when there is no client or nothing to store.
        """
        if self.vector_db_client is None:
            return
        documents = []
        metadatas = []
        ids = []
        for i, investor in enumerate(investors):
            documents.append(
                f"{investor.investor_description}\nInvestment Thesis: {investor.investment_thesis}"
            )
            metadatas.append({"name": investor.name})
            # NOTE(review): ids are positional within this call, so a later
            # run can reuse an id for a different batch; confirm whether
            # `collection.upsert` or a stable key is wanted instead.
            ids.append(f"investor_{i}_{investor.name.replace(' ', '_')}")
        if documents:
            self.collection.add(documents=documents, metadatas=metadatas, ids=ids)

    async def process_csv(
        self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
    ) -> List:
        """Extract investors from *df* in concurrent batches and persist them.

        The frame is cut into ``batch_size``-row slices, and at most
        ``max_concurrent`` LLM calls run at once. Results are collected in
        batch order and then saved to SQL and the vector store.

        Args:
            df: Source rows.
            batch_size: Rows per LLM request (must be >= 1).
            max_concurrent: Maximum simultaneous LLM requests (must be >= 1).

        Returns:
            The extracted investors, in batch order.

        Raises:
            ValueError: If ``batch_size`` or ``max_concurrent`` is below 1. A
                zero-sized semaphore would otherwise block forever.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if max_concurrent < 1:
            raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_with_semaphore(batch: pd.DataFrame, batch_idx: int) -> List:
            async with semaphore:
                return await self._process_batch(batch, batch_idx)

        batch_results = await asyncio.gather(
            *(
                process_with_semaphore(df.iloc[start : start + batch_size], batch_idx)
                for batch_idx, start in enumerate(range(0, len(df), batch_size))
            ),
            return_exceptions=True,
        )

        results = []
        for batch_idx, batch_result in enumerate(batch_results):
            # gather() returns exceptions (including CancelledError, a
            # BaseException) as values. Report them instead of dropping the
            # lost batch silently.
            if isinstance(batch_result, BaseException):
                print(f"Error processing batch {batch_idx + 1}: {batch_result}")
                continue
            results.extend(batch_result)

        if results:
            await self._save_to_sql(results)
            await self._save_to_vector_db(results)
        return results