Made improvements to parsing
This commit is contained in:
@@ -338,7 +338,7 @@ When `--use-llm` is enabled:
|
|||||||
OPENROUTER_API_KEY=your_openrouter_api_key_here
|
OPENROUTER_API_KEY=your_openrouter_api_key_here
|
||||||
|
|
||||||
# Database Configuration (optional, defaults to SQLite)
|
# Database Configuration (optional, defaults to SQLite)
|
||||||
DATABASE_URL=sqlite:///investors_2.db
|
DATABASE_URL=sqlite:///investors.db
|
||||||
|
|
||||||
# FastAPI Configuration
|
# FastAPI Configuration
|
||||||
API_HOST=localhost
|
API_HOST=localhost
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
+1
-1
@@ -6,7 +6,7 @@ from db.db import db_dependency, init_database
|
|||||||
from fastapi import FastAPI, File, UploadFile
|
from fastapi import FastAPI, File, UploadFile
|
||||||
from py_schemas import InvestorList
|
from py_schemas import InvestorList
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from services.openrouter import InvestorProcessor
|
from services.openrouter_v2 import InvestorProcessor
|
||||||
from services.querying import QueryProcessor
|
from services.querying import QueryProcessor
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,290 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
import pandas as pd
|
||||||
|
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from py_schemas import InvestorData
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class InvestorOutput(BaseModel):
    """Envelope model handed to the LLM as its structured-output schema.

    The model is asked to return exactly one investor record, wrapped in the
    ``investor_data`` field.
    """

    # The single extracted investor record (investor plus related entities).
    investor_data: InvestorData
|
||||||
|
|
||||||
|
|
||||||
|
class InvestorProcessor:
    """Extract investor records from CSV rows with an LLM and persist them.

    Each row of a DataFrame is flattened to text, sent to an OpenRouter-hosted
    chat model configured for structured output (``InvestorOutput``), and the
    resulting records are written to the SQL session (if one was supplied) and
    to a Chroma collection named ``investor_descriptions``.
    """

    # Translation table for cell clean-up: newline, carriage return and tab
    # become a space; every other ASCII control character (< 32) is removed.
    _CONTROL_CHAR_TABLE = {
        **dict.fromkeys(range(32)),
        **dict.fromkeys(map(ord, "\n\r\t"), " "),
    }

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
    ):
        """Set up the prompt, the structured-output LLM and the storage handles.

        Args:
            sql_session: Session used by ``_save_to_sql``; when falsy, SQL
                persistence is skipped.
            vector_db_client: Client exposing ``get_or_create_collection``
                (presumably a chromadb client). When ``None``, a
                ``chromadb.PersistentClient`` rooted at ``./chroma_db`` is
                created.
        """
        self.template = """You are an expert data extraction assistant. Extract investor information from the provided CSV data and return it as a structured record.

Given the following CSV data row:
{question}

Extract and structure the following fields for the investor:
- name: The investor's full name
- description: Description of the investor
- aum: Assets under management (as integer, use 0 if not available)
- check_size_lower: Lower bound of investment check size (as integer)
- check_size_upper: Upper bound of investment check size (as integer)
- geographic_focus: Geographic region focus
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
- number_of_investments: Number of investments made (default 0)

Also extract related data:
- portfolio_companies: List of companies they've invested in
- team_members: List of team members with name, role, email
- sectors: List of sectors they focus on

Important:
- If a field is not available, use appropriate defaults
- stage_focus must be one of the valid enum values
- Return clean, valid JSON only

Return the data as a single comprehensive investor data record."""

        self.prompt = PromptTemplate(
            template=self.template, input_variables=["question"]
        )

        self.llm = ChatOpenAI(
            api_key=settings.OPENROUTER_API_KEY,
            base_url="https://openrouter.ai/api/v1",
            model="google/gemini-2.5-flash-lite",
            temperature=0,
        )

        self.structured_llm = self.llm.with_structured_output(InvestorOutput)
        self.sql_session = sql_session

        # Honour an injected client; previously it was always overwritten by a
        # fresh PersistentClient, which made the parameter meaningless.
        if vector_db_client is not None:
            self.vector_db_client = vector_db_client
        else:
            self.vector_db_client = chromadb.PersistentClient(path="./chroma_db")

        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus"
            },
        )

    @classmethod
    def _clean_value(cls, value: object) -> str:
        """Return ``str(value)`` with ASCII control characters neutralised."""
        return str(value).translate(cls._CONTROL_CHAR_TABLE)

    async def _process_row(
        self, row: pd.Series, row_idx: int
    ) -> Optional["InvestorData"]:
        """Send one CSV row through the LLM and return the extracted record.

        Args:
            row: The row to process; cells for which ``pd.notna`` is false
                are omitted from the text sent to the model.
            row_idx: Zero-based row number, used only in progress messages.

        Returns:
            The ``investor_data`` of the structured result, or ``None`` when
            it is empty or any error occurred (errors are printed, not raised).
        """
        cleaned_row = {
            key: self._clean_value(value)
            for key, value in row.items()
            if pd.notna(value)
        }
        row_str = ", ".join(f"{key}: {value}" for key, value in cleaned_row.items())

        try:
            print(f"Processing row {row_idx + 1}...")
            # Fill the instruction template with the row; without this step
            # the model would receive the raw row and none of the guidance.
            llm_input = self.prompt.format(question=row_str)
            result = await self.structured_llm.ainvoke(llm_input)
            return result.investor_data or None
        except Exception as e:
            # Best effort: a failing row must not abort the whole batch.
            print(f"Error processing row {row_idx + 1}: {e}")
            return None

    async def _save_to_sql(self, investor_data_list: List["InvestorData"]) -> None:
        """Save investors and their related rows to the SQL database.

        Sectors and companies are looked up by name and reused when present;
        team members are de-duplicated by email. Everything is committed in a
        single transaction.

        Raises:
            Exception: Any error raised while saving is re-raised after the
                session has been rolled back.
        """
        if not self.sql_session:
            return

        session = self.sql_session
        try:
            for investor_data in investor_data_list:
                investor = investor_data.investor
                db_investor = InvestorTable(
                    name=investor.name,
                    description=investor.description,
                    aum=investor.aum,
                    check_size_lower=investor.check_size_lower,
                    check_size_upper=investor.check_size_upper,
                    geographic_focus=investor.geographic_focus,
                    stage_focus=investor.stage_focus,
                    number_of_investments=investor.number_of_investments,
                )
                session.add(db_investor)
                session.flush()  # Get the ID

                # Sectors: reuse an existing row by name, otherwise create it.
                for sector_data in investor_data.sectors:
                    db_sector = (
                        session.query(SectorTable)
                        .filter(SectorTable.name == sector_data.name)
                        .first()
                    )
                    if not db_sector:
                        db_sector = SectorTable(name=sector_data.name)
                        session.add(db_sector)
                        session.flush()
                    if db_sector not in db_investor.sectors:
                        db_investor.sectors.append(db_sector)

                # Portfolio companies: same get-or-create pattern, by name.
                for company_data in investor_data.portfolio_companies:
                    db_company = (
                        session.query(CompanyTable)
                        .filter(CompanyTable.name == company_data.name)
                        .first()
                    )
                    if not db_company:
                        db_company = CompanyTable(
                            name=company_data.name,
                            industry=company_data.industry,
                            location=company_data.location,
                            founded_year=company_data.founded_year,
                            website=company_data.website,
                        )
                        session.add(db_company)
                        session.flush()
                    if db_company not in db_investor.portfolio_companies:
                        db_investor.portfolio_companies.append(db_company)

                # Team members: de-duplicate by email, but only when an email
                # is present. Comparing against a missing email would match
                # every stored member without one and silently drop this one.
                for team_member_data in investor_data.team_members:
                    existing_member = None
                    if team_member_data.email:
                        existing_member = (
                            session.query(InvestorTeamMember)
                            .filter(
                                InvestorTeamMember.email == team_member_data.email
                            )
                            .first()
                        )

                    if not existing_member:
                        session.add(
                            InvestorTeamMember(
                                name=team_member_data.name,
                                role=team_member_data.role,
                                email=team_member_data.email,
                                investor_id=db_investor.id,
                            )
                        )

            session.commit()
            print(f"Successfully saved {len(investor_data_list)} investors to database")

        except Exception as e:
            session.rollback()
            print(f"Error saving to SQL database: {e}")
            raise

    async def _save_to_vector_db(self, investor_data_list: List["InvestorData"]) -> None:
        """Save one summary document per investor to the vector collection.

        Errors raised by the collection's ``add`` call are printed and
        swallowed; nothing is returned.
        """
        if not self.vector_db_client:
            return

        documents = []
        metadatas = []
        ids = []

        for i, investor_data in enumerate(investor_data_list):
            investor = investor_data.investor
            sectors = ", ".join(s.name for s in investor_data.sectors)
            companies = ", ".join(c.name for c in investor_data.portfolio_companies)

            doc_text = "\n".join(
                [
                    f"Investor: {investor.name}",
                    f"Description: {investor.description or 'N/A'}",
                    f"AUM: ${investor.aum:,}",
                    f"Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}",
                    f"Geographic Focus: {investor.geographic_focus}",
                    f"Stage Focus: {investor.stage_focus.value}",
                    f"Sectors: {sectors}",
                    f"Portfolio Companies: {companies}",
                ]
            ).strip()

            documents.append(doc_text)
            metadatas.append(
                {
                    "name": investor.name,
                    "stage_focus": investor.stage_focus.value,
                    "geographic_focus": investor.geographic_focus,
                    "aum": investor.aum,
                }
            )
            # The id combines the position in this batch with a sanitised name.
            ids.append(
                f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
            )

        if documents:
            try:
                self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
                print(
                    f"Successfully saved {len(documents)} investors to vector database"
                )
            except Exception as e:
                print(f"Error saving to vector database: {e}")

    async def process_csv(
        self, df: pd.DataFrame, max_concurrent: int = 10
    ) -> List["InvestorData"]:
        """Process every row of ``df`` concurrently and save the results.

        Args:
            df: Frame whose rows each describe one investor.
            max_concurrent: Upper bound on rows being processed by the LLM at
                the same time.

        Returns:
            The successfully extracted records, in row order. Rows that failed
            or produced nothing are left out.

        Raises:
            ValueError: If ``max_concurrent`` is smaller than 1 (a zero-sized
                semaphore would block forever).
        """
        if max_concurrent < 1:
            raise ValueError("max_concurrent must be at least 1")

        # Create semaphore for concurrency control
        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_row_with_semaphore(row, position):
            async with semaphore:
                return await self._process_row(row, position)

        # Number rows by position rather than by index label, so frames with
        # a non-integer or non-contiguous index are still reported correctly.
        row_results = await asyncio.gather(
            *(
                process_row_with_semaphore(row, position)
                for position, (_, row) in enumerate(df.iterrows())
            ),
            return_exceptions=True,
        )

        # Keep only real records; drop None and anything gather captured.
        results = [
            row_result
            for row_result in row_results
            if row_result is not None and not isinstance(row_result, BaseException)
        ]

        # Save to databases
        if results:
            print(f"Successfully processed {len(results)} investors")
            await self._save_to_sql(results)
            await self._save_to_vector_db(results)

        return results
|
||||||
@@ -14,7 +14,7 @@ from sqlalchemy.orm import selectinload
|
|||||||
# Connect to SQLite
|
# Connect to SQLite
|
||||||
|
|
||||||
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
||||||
db = SQLDatabase.from_uri("sqlite:///investors_2.db")
|
db = SQLDatabase.from_uri("sqlite:///investors.db")
|
||||||
system_message = (
|
system_message = (
|
||||||
prompt_template.format(dialect="SQLite", top_k=5)
|
prompt_template.format(dialect="SQLite", top_k=5)
|
||||||
+ "\n Get answers from the Sql database and the vector database"
|
+ "\n Get answers from the Sql database and the vector database"
|
||||||
|
|||||||
Reference in New Issue
Block a user