feat: Implement company querying functionality with natural language processing and logging

This commit is contained in:
bolade
2025-10-27 20:12:30 +01:00
parent 1ac755b2d7
commit ff0010019e
7 changed files with 225 additions and 70 deletions
Binary file not shown.
+25 -1
View File
@@ -1,4 +1,6 @@
import io import io
import logging
import os
import pandas as pd import pandas as pd
from db.db import Base, db_dependency, engine from db.db import Base, db_dependency, engine
@@ -13,7 +15,8 @@ from routers import (
projects, projects,
report_route, report_route,
) )
from schemas.router_schemas import InvestmentResponse, PaginatedResponse from schemas.router_schemas import CompanyData, InvestmentResponse, PaginatedResponse
from services.company_querying import CompanyQueryProcessor
from services.llm_parser import InvestorProcessor from services.llm_parser import InvestorProcessor
from services.querying import QueryProcessor from services.querying import QueryProcessor
@@ -114,6 +117,27 @@ async def query_investors(request: QueryRequest):
return results return results
@app.post(
    "/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"]
)
async def query_companies(request: QueryRequest):
    """
    Query companies using natural language.

    Returns the matching companies together with their investors and sectors.

    Supports queries like:
    - "Show me fintech companies founded in 2020"
    - "Find healthcare companies in San Francisco"
    - "Companies in the AI sector"
    - "Companies that received funding from Sequoia"
    - "European startups founded after 2019"
    """
    # Imported locally so the handler does not depend on a module-level import.
    from fastapi.concurrency import run_in_threadpool

    processor = CompanyQueryProcessor()
    # process_query is a synchronous call (agent invocation followed by a
    # database read). Running it in a worker thread keeps the event loop free
    # to serve other requests while the agent is working.
    results = await run_in_threadpool(processor.process_query, request.question)
    return results
app.include_router(investors.router) app.include_router(investors.router)
app.include_router(companies.router) app.include_router(companies.router)
app.include_router(projects.router) app.include_router(projects.router)
+9 -1
View File
@@ -1,15 +1,21 @@
import os
from typing import List from typing import List
from db.db import get_db from db.db import get_db
from db.models import InvestorTable from db.models import InvestorTable
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from services.crm import folk from services.crm import FolkAPI
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
router = APIRouter(prefix="/folk", tags=["Folk CRM"]) router = APIRouter(prefix="/folk", tags=["Folk CRM"])
def get_folk_client():
    """Build a Folk API client from the ``FOLK_API_KEY`` environment variable.

    An empty string is passed as the key when the variable is not set.
    """
    folk_api_key = os.environ.get("FOLK_API_KEY", "")
    client = FolkAPI(api_key=folk_api_key)
    return client
class GroupResponse(BaseModel): class GroupResponse(BaseModel):
id: str id: str
name: str name: str
@@ -44,6 +50,7 @@ def get_folk_groups():
to sync investors to Folk. to sync investors to Folk.
""" """
try: try:
folk = get_folk_client()
groups_data = folk.get_groups() groups_data = folk.get_groups()
items = groups_data.get("data", {}).get("items", []) items = groups_data.get("data", {}).get("items", [])
@@ -71,6 +78,7 @@ def sync_investors_to_folk(
Returns: Returns:
Summary of sync operation including successes and errors Summary of sync operation including successes and errors
""" """
folk = get_folk_client()
# Fetch investors with their team members # Fetch investors with their team members
investors = ( investors = (
db.query(InvestorTable) db.query(InvestorTable)
+5
View File
@@ -168,6 +168,7 @@ class InvestorFundData(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class InvestorMinimal(BaseModel): class InvestorMinimal(BaseModel):
"""Minimal investor info with just id and name""" """Minimal investor info with just id and name"""
@@ -177,6 +178,7 @@ class InvestorMinimal(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class CompanySchemaMinimal(BaseModel): class CompanySchemaMinimal(BaseModel):
id: int id: int
name: str name: str
@@ -188,9 +190,12 @@ class CompanySchemaMinimal(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class CompanyData(BaseModel):  # Renamed from CompaniesData for consistency
    """A company bundled with its related investors and sectors."""

    company: CompanySchemaMinimal
    investors: List[InvestorMinimal]
    # NOTE(review): a `members` field is intentionally not declared yet, so
    # team members are not part of this response model. Presumably any
    # `members=` value passed by callers is dropped under the default model
    # configuration; confirm before relying on it.
    sectors: List[SectorSchema] = []

    class Config:
        from_attributes = True
+176
View File
@@ -0,0 +1,176 @@
import logging
import os
from typing import List
from db.db import DATABASE_URL, get_db
from db.models import CompanyTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from schemas.router_schemas import CompanyData, PaginatedResponse
from sqlalchemy.orm import selectinload
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Pull the shared SQL-agent system prompt from the LangChain hub.
# NOTE(review): this runs when the module is imported, so importing it
# presumably requires access to the hub; confirm this is acceptable.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# Wrap the application database so the agent's SQL tools can query it.
# NOTE(review): the agent prompt is formatted with dialect="SQLite"; confirm
# that DATABASE_URL actually points at a SQLite database.
db = SQLDatabase.from_uri(DATABASE_URL)
class CompanyQueryProcessor:
    """Answer natural-language questions about companies with an LLM SQL agent.

    The agent is instructed to produce company IDs only. Those IDs are then
    used to load the full company records, with their relationships, through
    the ORM.
    """

    # Name of the toolkit tool that executes SQL; only its output holds rows.
    _SQL_QUERY_TOOL = "sql_db_query"
    # Matches single-column result tuples such as "(1,)".
    _ROW_ID_PATTERN = r"\((\d+),?\)"
    # Loose-mode numbers at or above this value are assumed not to be IDs
    # (e.g. token counts or timestamps).
    _MAX_LOOSE_ID = 100000
    # Page size reported when there is nothing to return.
    _EMPTY_PAGE_SIZE = 10

    def __init__(self, top_k: int = 5):
        """Build the LLM client, the SQL toolkit and the ReAct agent.

        Args:
            top_k: Value substituted for ``top_k`` in the hub system prompt.
                Presumably the row limit the agent applies to its queries;
                confirm against the prompt text. Defaults to 5.
        """
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )

        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)

        # Extend the generic SQL-agent prompt so the agent answers with
        # company IDs only.
        system_message_updated = (
            prompt_template.format(dialect="SQLite", top_k=top_k)
            + "\n\n=== CRITICAL INSTRUCTIONS ==="
            + "\n- Your ONLY task is to run SQL queries and extract company IDs"
            + "\n- When you get SQL results with company IDs, return them EXACTLY as shown"
            + "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs"
            + "\n- Do NOT add any explanations, just list the IDs"
            + "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'"
            + "\n\n=== QUERY GUIDELINES ==="
            + "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'"
            + "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'"
            + "\n3. For location searches: WHERE companies.location LIKE '%location%'"
            + "\n4. For founding year searches: WHERE companies.founded_year >= year"
            + "\n5. For investor-related: JOIN investor_companies table"
        )

        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(self, question: str) -> "PaginatedResponse[CompanyData]":
        """Run the agent on ``question`` and return the matching companies.

        Args:
            question: The natural language query to process.

        Returns:
            A single-page response holding every company whose ID the agent
            produced.
        """
        response = self.agent.invoke(
            {"messages": [("user", question)]},
            config={"recursion_limit": 50},
        )
        logger.info("%s", response)

        company_ids = self._collect_company_ids(response["messages"])
        logger.info("Extracted %d company IDs from results", len(company_ids))

        # Fetch full company data with relationships using the IDs.
        return self._fetch_companies_by_ids(company_ids)

    def _collect_company_ids(self, messages) -> List[int]:
        """Pick the company IDs out of the agent's message history.

        Only the most recent output of the SQL query tool is treated as result
        rows. The user question, schema dumps and other tool output are
        ignored, because they can contain parenthesised numbers (for example
        ``VARCHAR(255)``) that are not IDs. If that output yields nothing, the
        final message is parsed instead.

        Args:
            messages: The agent's messages, oldest first.
        """
        for message in reversed(messages):
            is_query_result = (
                getattr(message, "type", None) == "tool"
                and getattr(message, "name", None) == self._SQL_QUERY_TOOL
            )
            if not is_query_result:
                continue
            row_ids = self._extract_company_ids_from_response(
                str(message.content), allow_loose=False
            )
            if row_ids:
                return row_ids
            # Only the latest query counts; older ones may be exploratory.
            break

        if not messages:
            return []
        # Content may be a list of blocks rather than text, hence str().
        return self._extract_company_ids_from_response(str(messages[-1].content))

    def _extract_company_ids_from_response(
        self, ai_response: str, allow_loose: bool = True
    ) -> List[int]:
        """Extract company IDs from a piece of agent or tool output.

        Args:
            ai_response: Text to scan.
            allow_loose: When no ``(id,)`` tuples are present, fall back to
                every standalone number below 100000. Disable this for raw
                tool output, where stray numbers are not IDs.

        Returns:
            The IDs in order of first appearance, without duplicates. Empty
            when the text contains ``NO_RESULTS`` (case-insensitive).
        """
        import re

        text = str(ai_response)
        if "NO_RESULTS" in text.upper():
            return []

        found = re.findall(self._ROW_ID_PATTERN, text)
        if not found and allow_loose:
            found = [
                number
                for number in re.findall(r"\b\d+\b", text)
                if int(number) < self._MAX_LOOSE_ID
            ]

        # dict.fromkeys removes repeats (e.g. from JOINs) and keeps order.
        return list(dict.fromkeys(int(number) for number in found))

    def _fetch_companies_by_ids(
        self, company_ids: List[int]
    ) -> "PaginatedResponse[CompanyData]":
        """Fetch companies with their relationships using company IDs.

        Args:
            company_ids: List of company IDs to fetch.

        Returns:
            One page containing all companies found. An empty page is returned
            when no IDs are given or none of them exist.
        """
        if not company_ids:
            return self._empty_page()

        db_session = next(get_db())
        try:
            companies = (
                db_session.query(CompanyTable)
                .options(
                    selectinload(CompanyTable.investors),
                    selectinload(CompanyTable.members),
                    selectinload(CompanyTable.sectors),
                )
                .filter(CompanyTable.id.in_(company_ids))
                .all()
            )

            # Build the response objects while the session is still open so
            # relationship access cannot hit a closed session.
            company_data_list = [
                CompanyData(
                    company=company,
                    investors=company.investors,
                    members=company.members,
                    sectors=company.sectors,
                )
                for company in companies
            ]
        finally:
            db_session.close()

        if not company_data_list:
            # IDs that match no rows get the same page as "no IDs at all".
            return self._empty_page()

        total_count = len(company_data_list)
        return PaginatedResponse(
            items=company_data_list,
            total=total_count,
            page=1,
            page_size=total_count,
            total_pages=1,
        )

    def _empty_page(self) -> "PaginatedResponse[CompanyData]":
        """Return the response used when there are no companies to show."""
        return PaginatedResponse(
            items=[],
            total=0,
            page=1,
            page_size=self._EMPTY_PAGE_SIZE,
            total_pages=0,
        )
+10 -68
View File
@@ -1,14 +1,24 @@
import logging
import os import os
import sys import sys
import requests import requests
# Configure logging to emit INFO-and-above records to a stream handler, using
# a "time - logger name - level - message" layout.
# NOTE(review): this call runs when the module is imported, so it presumably
# affects logging for the whole process; confirm that is intended.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class FolkAPI: class FolkAPI:
BASE_URL = "https://api.folk.app/v1" BASE_URL = "https://api.folk.app/v1"
def __init__(self, api_key: str):
    """Prepare the authorization header sent with Folk API requests.

    Args:
        api_key: Fallback key, used only when the ``FOLK_API_KEY``
            environment variable is not set.
    """
    # The environment variable takes precedence over the argument.
    api_key = os.environ.get("FOLK_API_KEY", api_key)
    self.headers = {"Authorization": f"Bearer {api_key}"}
    if api_key:
        # Never write any part of the credential to the logs.
        logger.info("FolkAPI initialized with an API key")
    else:
        logger.warning("FolkAPI initialized without an API key (FOLK_API_KEY is not set)")
def get_groups(self): def get_groups(self):
"""Fetch all groups from Folk.""" """Fetch all groups from Folk."""
@@ -190,71 +200,3 @@ class FolkAPI:
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
# The API key is read from the FOLK_API_KEY environment variable. Credentials
# must not be committed to source control, so the fallback is empty.
# NOTE(review): a real key was previously hard-coded here; it should be
# treated as leaked and rotated.
DEFAULT_API_KEY = ""
api_key = os.environ.get("FOLK_API_KEY", DEFAULT_API_KEY)
folk = FolkAPI(api_key=api_key)
def example_flow():
    """Walk through a sample Folk session: list groups, add a company, add a person.

    Exits the process with status 1 when no usable group is returned.
    """
    # Step 1: list the groups and show the raw payload.
    groups_payload = folk.get_groups()
    print(groups_payload)

    # Groups are nested under payload["data"]["items"]; tolerate missing keys.
    group_items = groups_payload.get("data", {}).get("items", [])
    if not group_items:
        print("No groups returned by Folk API.")
        sys.exit(1)

    # The first group serves as the example target.
    first_group = group_items[0]
    group_id = first_group.get("id")
    if not group_id:
        print("No id found for the first group item.")
        sys.exit(1)

    # Step 2: create a company inside that group.
    company = folk.create_company(
        name="2050 Investment Partners",
        group_id=group_id,
        website="https://2050.com",
        linkedin_url="https://linkedin.com/company/2050-investments",
    )

    # Step 3: create a person linked to the new company and the same group.
    new_company_id = company.get("data", {}).get("id")
    person = folk.create_person(
        first_name="John",
        last_name="Doe",
        email="john@2050.com",
        company_id=new_company_id,
        group_id=group_id,
    )

    print("Created company:", company)
    print("Created person:", person)
if __name__ == "__main__":
    # Run the example and turn any failure into a readable message plus a
    # non-zero exit status.
    try:
        example_flow()
    except requests.HTTPError as http_err:
        print("HTTP error while talking to Folk API:", http_err)
        # Include the response details, when present, to ease debugging.
        failed_response = getattr(http_err, "response", None)
        if failed_response is not None:
            try:
                response_body = failed_response.text
            except Exception:
                response_body = "<unreadable response body>"
            print("Response status:", failed_response.status_code)
            print("Response body:", response_body)
        sys.exit(1)
    except Exception as err:  # pragma: no cover - top-level safety
        print("Unexpected error:", err)
        sys.exit(1)