ba0ed169ce
- Added InvestorProcessor class for processing CSV data in batches and saving to SQL and vector databases.
- Introduced QueryProcessor class for querying investor information from SQL and vector databases.
- Integrated OpenAI's ChatGPT for structured output generation.
- Implemented data cleaning and control character removal in CSV processing.
- Added asynchronous processing capabilities for batch handling.
- Established connection to ChromaDB for vector storage of investor descriptions.
- Defined structured output schemas using Pydantic for investor data validation.
- Enhanced settings management for API key and database configurations.
179 lines
6.4 KiB
Python
179 lines
6.4 KiB
Python
import asyncio
|
|
from typing import List, Optional
|
|
|
|
import chromadb
|
|
import pandas as pd
|
|
from db.tables import InvestorTable
|
|
from langchain_core.prompts import PromptTemplate
|
|
from langchain_openai import ChatOpenAI
|
|
from pydantic_schemas import Investor, InvestorList
|
|
from settings import settings
|
|
|
|
# Add these imports for your databases
|
|
# from sqlalchemy.ext.asyncio import AsyncSession
|
|
# from your_vector_db import VectorDBClient
|
|
|
|
|
|
class InvestorProcessor:
    """Turn raw investor CSV rows into ``Investor`` records via an LLM and persist them.

    Rows are sent to the model in batches (see ``process_csv``). The parsed
    records are written to an optional SQL session (as ``InvestorTable`` rows)
    and to the ChromaDB collection ``investor_descriptions``.
    """

    # OpenRouter model id used when the caller does not pass ``model``;
    # the ":free" suffix selects OpenRouter's free variant of the model.
    _DEFAULT_MODEL = "openai/gpt-oss-120b:free"

    # Prompt sent for every batch; ``{question}`` receives the serialized rows.
    _PROMPT_TEMPLATE = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a list of structured records.

Given the following CSV data rows:
{question}

For each row, extract and structure the following fields:
- name: The investor's full name
- aum: Assets under management (as integer, use 0 if not available)
- check_size: Investment check size (as string)
- sector_focus: Sector focus (as string)
- stage_focus: Investment stage focus (as string)
- region: Geographic region (as string)
- investment_thesis: Investment thesis (as string)
- investor_description: Description of the investor (as string)

Important:
- If a field is not available in the data, use appropriate default values (empty string for text fields, 0 for numbers)
- Ensure all text fields are properly escaped and contain no control characters
- Return clean, valid JSON only

Return the data as a structured list of investors."""

    # str.translate table: tab/LF/CR become spaces; every other C0 control
    # character, DEL and the C1 control range (0x80-0x9F) are removed.
    _CONTROL_CHAR_TABLE = {
        **{code: None for code in range(0x20)},
        **{code: None for code in range(0x7F, 0xA0)},
        **{ord(ch): " " for ch in "\t\n\r"},
    }

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
        *,
        model: str = _DEFAULT_MODEL,
        chroma_path: str = "./chroma_db",
    ):
        """Configure the LLM, the prompt and the storage backends.

        Args:
            sql_session: Optional SQL session; when ``None`` nothing is written
                to SQL. Its ``commit``/``rollback`` may be sync or async.
            vector_db_client: Optional ChromaDB client. When ``None`` a
                ``chromadb.PersistentClient`` at ``chroma_path`` is created.
                Any client passed here must provide
                ``get_or_create_collection``.
            model: OpenRouter model id for extraction.
            chroma_path: Directory of the fallback persistent Chroma store.
        """
        self.template = self._PROMPT_TEMPLATE
        self.prompt = PromptTemplate(
            template=self.template, input_variables=["question"]
        )

        self.llm = ChatOpenAI(
            api_key=settings.OPENROUTER_API_KEY,
            base_url="https://openrouter.ai/api/v1",
            model=model,
            temperature=0,
        )
        self.structured_llm = self.llm.with_structured_output(InvestorList)

        self.sql_session = sql_session

        # Honour an injected client; only fall back to a local on-disk store.
        if vector_db_client is None:
            vector_db_client = chromadb.PersistentClient(path=chroma_path)
        self.vector_db_client = vector_db_client
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus"
            },
        )

    @classmethod
    def _clean_value(cls, value: object) -> str:
        """Stringify *value* with tabs/line breaks turned into spaces and other control chars removed."""
        return str(value).translate(cls._CONTROL_CHAR_TABLE)

    def _build_batch_text(self, batch: pd.DataFrame) -> str:
        """Serialize *batch* as one ``Row <n>: col: value, ...`` line per row.

        Missing values (``pd.notna`` false) are omitted. Integer index labels
        are shown 1-based; any other index label is shown as-is.
        """
        lines = []
        for idx, row in batch.iterrows():
            fields = ", ".join(
                f"{key}: {self._clean_value(value)}"
                for key, value in row.items()
                if pd.notna(value)
            )
            label = idx + 1 if pd.api.types.is_integer(idx) else idx
            lines.append(f"Row {label}: {fields}\n")
        return "".join(lines)

    async def _process_batch(
        self, batch: pd.DataFrame, batch_idx: int
    ) -> List[Investor]:
        """Run one batch through the prompt + structured LLM.

        Best-effort: any error (serialization, network, parsing) is printed
        and an empty list is returned, so one bad batch does not abort the run.
        """
        try:
            batch_str = self._build_batch_text(batch)
            print(f"Processing batch {batch_idx + 1}...")
            # Format through the prompt so the model gets the extraction
            # instructions, not just the raw rows.
            batch_results = await self.structured_llm.ainvoke(
                self.prompt.format(question=batch_str)
            )
            return batch_results.investor_list
        except Exception as e:
            print(f"Error processing batch {batch_idx + 1}: {e}")
            return []

    @staticmethod
    async def _maybe_await(value):
        """Await *value* if it is awaitable (async session APIs), else return it unchanged."""
        import inspect

        if inspect.isawaitable(value):
            return await value
        return value

    async def _save_to_sql(self, investors: List[Investor]) -> None:
        """Insert *investors* as ``InvestorTable`` rows and commit.

        ``investment_thesis`` and ``investor_description`` are not written
        here; they are stored by ``_save_to_vector_db``. On failure the
        session is rolled back and the exception re-raised.
        """
        if self.sql_session is None:
            return

        try:
            for investor in investors:
                self.sql_session.add(
                    InvestorTable(
                        name=investor.name,
                        aum=investor.aum,
                        check_size=investor.check_size,
                        sector_focus=investor.sector_focus,
                        stage_focus=investor.stage_focus,
                        region=investor.region,
                    )
                )
            # An async session (e.g. SQLAlchemy AsyncSession) returns a
            # coroutine from commit(); it must be awaited to take effect.
            await self._maybe_await(self.sql_session.commit())
        except Exception:
            await self._maybe_await(self.sql_session.rollback())
            raise

    async def _save_to_vector_db(self, investors: List[Investor]) -> None:
        """Add each investor's description + thesis as a document to the Chroma collection.

        Document ids are ``investor_<position>_<name with spaces -> _>``,
        where position is the index within *investors*.
        """
        if self.vector_db_client is None or not investors:
            return

        documents = []
        metadatas = []
        ids = []
        for position, investor in enumerate(investors):
            documents.append(
                f"{investor.investor_description}\nInvestment Thesis: {investor.investment_thesis}"
            )
            metadatas.append({"name": investor.name})
            ids.append(f"investor_{position}_{investor.name.replace(' ', '_')}")

        self.collection.add(documents=documents, metadatas=metadatas, ids=ids)

    async def process_csv(
        self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
    ) -> List[Investor]:
        """Extract investors from *df* in concurrent batches and save them.

        Args:
            df: Raw CSV rows.
            batch_size: Rows per LLM call; must be >= 1.
            max_concurrent: Maximum batches in flight at once; must be >= 1.

        Returns:
            All successfully extracted investors, in batch order. Failed
            batches are reported and contribute nothing.

        Raises:
            ValueError: If ``batch_size`` or ``max_concurrent`` is < 1
                (a zero-permit semaphore would otherwise hang forever).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if max_concurrent < 1:
            raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_with_semaphore(batch: pd.DataFrame, batch_idx: int):
            async with semaphore:
                return await self._process_batch(batch, batch_idx)

        batch_results = await asyncio.gather(
            *(
                process_with_semaphore(df.iloc[start : start + batch_size], batch_idx)
                for batch_idx, start in enumerate(range(0, len(df), batch_size))
            ),
            return_exceptions=True,
        )

        results = []
        for batch_idx, batch_result in enumerate(batch_results):
            # BaseException also covers a returned CancelledError.
            if isinstance(batch_result, BaseException):
                print(f"Error processing batch {batch_idx + 1}: {batch_result}")
                continue
            results.extend(batch_result)

        if results:
            await self._save_to_sql(results)
            await self._save_to_vector_db(results)

        return results
|