# Anton_wireframe/app/services/querying.py
import logging
import re
from typing import List, Optional

import chromadb
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from sqlalchemy.orm import selectinload

from db.models import InvestorTable
from py_schemas import InvestorData, InvestorList
from settings import settings
# Stock LangChain SQL-agent system prompt, pulled from the LangChain Hub when
# this module is imported.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# LangChain SQLDatabase wrapper over the SQLite file "investors.db" (a path
# relative to the process working directory); used by the SQL toolkit.
db = SQLDatabase.from_uri("sqlite:///investors.db")

# System prompt for the ReAct agent: the hub template rendered for SQLite with
# top_k=5, plus an instruction to consult both the SQL and the vector store.
system_message = (
    prompt_template.format(dialect="SQLite", top_k=5)
    + "\n Get answers from the Sql database and the vector database"
)
class QueryProcessor:
    """Answer natural-language investor questions with an LLM agent.

    The agent (a LangGraph ReAct agent) can call the LangChain SQL toolkit
    tools over the module-level ``db`` plus :meth:`query_vector_database`.
    The investors finally returned are always loaded through ``sql_session``,
    narrowed by keyword filters parsed from the question and by any
    ``id: N`` references found in the agent's final answer.
    """

    _logger = logging.getLogger(__name__)

    # Funding-stage keywords in priority order; the first match against the
    # lower-cased question wins. Leading word boundaries stop matches inside
    # unrelated words; hyphenated spellings ("late-stage") are accepted.
    _STAGE_PATTERNS = (
        (re.compile(r"\bseed"), "SEED"),
        (re.compile(r"\bseries[\s-]a\b"), "SERIES_A"),
        (re.compile(r"\bseries[\s-]b\b"), "SERIES_B"),
        (re.compile(r"\bseries[\s-]c\b"), "SERIES_C"),
        (re.compile(r"\bgrowth"), "GROWTH"),
        (re.compile(r"\blate[\s-]stage"), "LATE_STAGE"),
    )

    # The bare abbreviation "US" is matched case-sensitively on the original
    # question: once lower-cased it is indistinguishable from the pronoun
    # ("show us ...") and would also match inside "business", "focus", etc.
    _US_ABBREVIATION = re.compile(r"\bUS\b")

    # Geography keywords matched against the lower-cased question, in
    # priority order (US spellings first).
    _GEOGRAPHY_PATTERNS = (
        (re.compile(r"\busa\b|\bu\.s\.|\bunited states"), "US"),
        (re.compile(r"\beurope"), "Europe"),
        (re.compile(r"\basia"), "Asia"),
        (re.compile(r"\bsilicon valley|\bbay area"), "Silicon Valley"),
    )

    # Sector keywords in priority order. "ai" needs boundaries on both sides
    # so that words like "retail" or "said" do not match.
    _SECTOR_PATTERNS = (
        (re.compile(r"\bfintech"), "fintech"),
        (re.compile(r"\bhealthcare"), "healthcare"),
        (re.compile(r"\bsaas"), "saas"),
        (re.compile(r"\bai\b"), "ai"),
        (re.compile(r"\bbiotech"), "biotech"),
        (re.compile(r"\bconsumer"), "consumer"),
        (re.compile(r"\benterprise"), "enterprise"),
        (re.compile(r"\bcrypto"), "crypto"),
        (re.compile(r"\bblockchain"), "blockchain"),
    )

    # An amount immediately followed by its unit; the trailing \b keeps
    # "5 months" from being read as "5 m(illion)".
    _CHECK_SIZE_PATTERN = re.compile(
        r"\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*(million|thousand|m|k)\b"
    )
    _CHECK_SIZE_MULTIPLIERS = {
        "million": 1_000_000,
        "m": 1_000_000,
        "thousand": 1_000,
        "k": 1_000,
    }

    # Investor IDs as the agent is expected to write them, e.g. "ID: 42".
    _INVESTOR_ID_PATTERN = re.compile(r"\bid:\s*(\d+)")

    def __init__(
        self,
        sql_session: Optional[object] = None,
        vector_db_client: Optional[object] = None,
    ):
        """Build the LLM, the SQL toolkit, the agent and the vector collection.

        Args:
            sql_session: SQLAlchemy session used to load investors. When
                ``None``, ``query_sql_database`` returns ``None`` and
                ``process_query`` returns an empty ``InvestorList``.
            vector_db_client: Chroma client to use. When ``None``, a
                ``chromadb.PersistentClient`` rooted at ``./chroma_db`` is
                created.
        """
        self.sql_session = sql_session
        self.llm = ChatOpenAI(
            api_key=settings.OPENROUTER_API_KEY,
            base_url="https://openrouter.ai/api/v1",
            model="google/gemini-2.5-flash-lite",
            temperature=0.3,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
        # The bound method becomes an agent tool; it only touches
        # self.collection when invoked, and that is assigned below.
        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools() + [self.query_vector_database],
            prompt=system_message,
        )
        # An injected client takes precedence over the default on-disk store.
        self.vector_db_client = (
            vector_db_client
            if vector_db_client is not None
            else chromadb.PersistentClient(path="./chroma_db")
        )
        self.collection = self.vector_db_client.get_or_create_collection(
            name="investor_descriptions",
            metadata={
                "description": "Investor descriptions and investment thesis focus"
            },
        )

    def query_sql_database(self, query: str) -> "Optional[InvestorList]":
        """Query the SQL database for investor information."""
        # SECURITY: `query` is executed verbatim - never pass untrusted text.
        if not self.sql_session:
            return None
        from sqlalchemy import text

        # SQLAlchemy 2.x rejects bare strings in Session.execute; wrap them in
        # text(). Already-built statements (e.g. select(...)) pass through.
        statement = text(query) if isinstance(query, str) else query
        result = self.sql_session.execute(statement)
        # NOTE(review): scalars() yields only the first column of each row,
        # whereas process_query builds InvestorList from InvestorData items -
        # confirm which item type InvestorList actually expects here.
        investors = result.scalars().all()
        return InvestorList(investors=investors)

    def query_vector_database(self, query: str) -> Optional[dict]:
        """Query the vector database for investor information."""
        # This method is registered as an agent tool in __init__, and LangChain
        # derives the tool description from the docstring above - keep it short.
        if not self.vector_db_client:
            return None
        self._logger.debug("Vector store queried: %r", query)
        # Chroma expects a list of query texts; ask for the 3 nearest documents.
        results = self.collection.query(query_texts=[query], n_results=3)
        self._logger.debug("Vector store results: %s", results)
        # The raw Chroma result mapping (keys such as 'ids', 'documents',
        # 'metadatas', 'distances') is returned unchanged.
        return results

    def process_query(self, question: str) -> "InvestorList":
        """Answer *question* and return matching investors with relationships.

        Keyword filters are parsed from the question, the agent is invoked,
        and any ``id: N`` references in its final message restrict the
        result to those investors.
        """
        filters = self._extract_filters_from_query(question)
        response = self.agent.invoke({"messages": [("user", question)]})
        messages = response.get("messages")
        ai_response = messages[-1].content if messages else ""
        investor_ids = self._extract_investor_info_from_response(ai_response)
        return self._fetch_investors_with_relationships(investor_ids, filters)

    def _extract_investor_info_from_response(self, ai_response: str) -> List[int]:
        """Return investor IDs written as ``id: N`` (any case) in *ai_response*.

        Returns an empty list when no IDs are present or when the content is
        not a string (LangChain message content may be a list of blocks).
        """
        if not isinstance(ai_response, str):
            return []
        return [
            int(digits)
            for digits in self._INVESTOR_ID_PATTERN.findall(ai_response.lower())
        ]

    @staticmethod
    def _first_match(patterns, text: str) -> Optional[str]:
        """Return the value paired with the first pattern found in *text*, else None."""
        for pattern, value in patterns:
            if pattern.search(text):
                return value
        return None

    def _extract_filters_from_query(self, question: str) -> dict:
        """Extract keyword filters from a natural-language question.

        Returns a dict with any of:
            ``stage``: an ``InvestmentStage`` attribute name such as "SEED".
            ``geography``: "US", "Europe", "Asia" or "Silicon Valley".
            ``sector``: the matched sector keyword, e.g. "fintech".
            ``min_check_size``: integer amount from the first "<number><unit>"
                mention, e.g. "$500k" -> 500000.
        Keywords are matched on word boundaries, so "business" does not imply
        the US and "retail" does not imply AI.
        """
        from decimal import Decimal

        question_lower = question.lower()
        filters = {}

        stage = self._first_match(self._STAGE_PATTERNS, question_lower)
        if stage is not None:
            filters["stage"] = stage

        if self._US_ABBREVIATION.search(question):
            filters["geography"] = "US"
        else:
            geography = self._first_match(self._GEOGRAPHY_PATTERNS, question_lower)
            if geography is not None:
                filters["geography"] = geography

        sector = self._first_match(self._SECTOR_PATTERNS, question_lower)
        if sector is not None:
            filters["sector"] = sector

        amount_match = self._CHECK_SIZE_PATTERN.search(question_lower)
        if amount_match:
            number, unit = amount_match.groups()
            # Scale by the unit attached to this number (not by any "m"/"k"
            # elsewhere in the text); Decimal avoids float products landing
            # just below the exact value before int() truncates.
            scaled = Decimal(number.replace(",", "")) * self._CHECK_SIZE_MULTIPLIERS[unit]
            filters["min_check_size"] = int(scaled)

        return filters

    def _fetch_investors_with_relationships(
        self,
        investor_ids: Optional[List[int]] = None,
        filters: Optional[dict] = None,
    ) -> "InvestorList":
        """Load investors with portfolio companies, team members and sectors.

        Args:
            investor_ids: When non-empty, restrict results to these IDs.
            filters: Optional keys ``stage``, ``geography``, ``sector``,
                ``min_check_size``, ``max_check_size``, ``min_aum``,
                ``max_aum``. An unknown ``stage`` name is logged and ignored.

        Returns:
            An ``InvestorList`` of ``InvestorData``. It is empty when there is
            no SQL session, and capped at 10 investors when neither IDs nor
            filters are given.
        """
        if not self.sql_session:
            return InvestorList(investors=[])

        # Imported locally, matching the original structure of this method.
        from db.models import SectorTable

        # Eager-load every relationship copied into InvestorData below.
        query = self.sql_session.query(InvestorTable).options(
            selectinload(InvestorTable.portfolio_companies),
            selectinload(InvestorTable.team_members),
            selectinload(InvestorTable.sectors),
        )

        filters = filters or {}
        if "stage" in filters:
            from db.models import InvestmentStage

            stage_enum = getattr(InvestmentStage, filters["stage"], None)
            if stage_enum is None:
                # Filters come from keyword heuristics; a name the enum lacks
                # should not abort the whole query.
                self._logger.warning(
                    "Ignoring unknown investment stage %r", filters["stage"]
                )
            else:
                query = query.filter(InvestorTable.stage_focus == stage_enum)
        if "geography" in filters:
            query = query.filter(
                InvestorTable.geographic_focus.ilike(f"%{filters['geography']}%")
            )
        if "min_check_size" in filters:
            query = query.filter(
                InvestorTable.check_size_lower >= filters["min_check_size"]
            )
        if "max_check_size" in filters:
            query = query.filter(
                InvestorTable.check_size_upper <= filters["max_check_size"]
            )
        if "min_aum" in filters:
            query = query.filter(InvestorTable.aum >= filters["min_aum"])
        if "max_aum" in filters:
            query = query.filter(InvestorTable.aum <= filters["max_aum"])
        if "sector" in filters:
            query = query.join(InvestorTable.sectors).filter(
                SectorTable.name.ilike(f"%{filters['sector']}%")
            )

        if investor_ids:
            query = query.filter(InvestorTable.id.in_(investor_ids))
        elif not filters:
            # Nothing to narrow by: cap the result to keep the response small.
            query = query.limit(10)

        return InvestorList(
            investors=[
                InvestorData(
                    investor=investor,
                    portfolio_companies=investor.portfolio_companies,
                    team_members=investor.team_members,
                    sectors=investor.sectors,
                )
                for investor in query.all()
            ]
        )