feat: Implement company querying functionality with natural language processing and logging

This commit is contained in:
bolade
2025-10-27 20:12:30 +01:00
parent 1ac755b2d7
commit ff0010019e
7 changed files with 225 additions and 70 deletions
Binary file not shown.
+25 -1
View File
@@ -1,4 +1,6 @@
import io
import logging
import os
import pandas as pd
from db.db import Base, db_dependency, engine
@@ -13,7 +15,8 @@ from routers import (
projects,
report_route,
)
from schemas.router_schemas import InvestmentResponse, PaginatedResponse
from schemas.router_schemas import CompanyData, InvestmentResponse, PaginatedResponse
from services.company_querying import CompanyQueryProcessor
from services.llm_parser import InvestorProcessor
from services.querying import QueryProcessor
@@ -114,6 +117,27 @@ async def query_investors(request: QueryRequest):
return results
@app.post(
    "/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"]
)
async def query_companies(request: QueryRequest):
    """
    Query companies using natural language.

    Returns company matches with their investor relationships and sectors.

    Supports queries like:
    - "Show me fintech companies founded in 2020"
    - "Find healthcare companies in San Francisco"
    - "Companies in the AI sector"
    - "Companies that received funding from Sequoia"
    - "European startups founded after 2019"
    """
    import asyncio

    processor = CompanyQueryProcessor()
    # process_query is synchronous (LLM agent round-trips plus a DB query).
    # Calling it directly inside this coroutine would block the event loop
    # for the whole duration, so it is pushed to a worker thread instead.
    results = await asyncio.to_thread(processor.process_query, request.question)
    return results
# Register the investors, companies and projects routers on the application.
app.include_router(investors.router)
app.include_router(companies.router)
app.include_router(projects.router)
+9 -1
View File
@@ -1,15 +1,21 @@
import os
from typing import List
from db.db import get_db
from db.models import InvestorTable
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from services.crm import folk
from services.crm import FolkAPI
from sqlalchemy.orm import Session, selectinload
# All routes declared on this router are served under /folk and tagged "Folk CRM".
router = APIRouter(prefix="/folk", tags=["Folk CRM"])
def get_folk_client():
    """Build a FolkAPI client from the FOLK_API_KEY environment variable.

    An empty string is passed as the key when the variable is not set.
    """
    folk_api_key = os.environ.get("FOLK_API_KEY", "")
    return FolkAPI(api_key=folk_api_key)
class GroupResponse(BaseModel):
id: str
name: str
@@ -44,6 +50,7 @@ def get_folk_groups():
to sync investors to Folk.
"""
try:
folk = get_folk_client()
groups_data = folk.get_groups()
items = groups_data.get("data", {}).get("items", [])
@@ -71,6 +78,7 @@ def sync_investors_to_folk(
Returns:
Summary of sync operation including successes and errors
"""
folk = get_folk_client()
# Fetch investors with their team members
investors = (
db.query(InvestorTable)
+5
View File
@@ -168,6 +168,7 @@ class InvestorFundData(BaseModel):
class Config:
from_attributes = True
class InvestorMinimal(BaseModel):
"""Minimal investor info with just id and name"""
@@ -177,6 +178,7 @@ class InvestorMinimal(BaseModel):
class Config:
from_attributes = True
class CompanySchemaMinimal(BaseModel):
id: int
name: str
@@ -188,9 +190,12 @@ class CompanySchemaMinimal(BaseModel):
class Config:
from_attributes = True
class CompanyData(BaseModel):  # Renamed from CompaniesData for consistency
    """Response item bundling a company with its investors and sectors."""

    # Core company fields.
    company: CompanySchemaMinimal
    # Investors related to the company, in their minimal form.
    investors: List[InvestorMinimal]
    # Team members are currently not exposed; field kept for reference.
    # members: List[CompanyMemberSchema] = []
    # Sectors the company belongs to; defaults to an empty list.
    sectors: List[SectorSchema] = []

    class Config:
        # Lets Pydantic populate the model from object attributes
        # (e.g. ORM instances) rather than only from dicts.
        from_attributes = True
+176
View File
@@ -0,0 +1,176 @@
import logging
import os
from typing import List
from db.db import DATABASE_URL, get_db
from db.models import CompanyTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from schemas.router_schemas import CompanyData, PaginatedResponse
from sqlalchemy.orm import selectinload
logger = logging.getLogger(__name__)

# Fetch the shared SQL-agent system prompt from the LangChain hub.
# NOTE(review): this runs at import time — presumably needs network access; confirm.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
# Wrap the application database (DATABASE_URL) for use by the agent's SQL tools.
# NOTE(review): the prompt below is formatted with dialect "SQLite" — confirm
# DATABASE_URL always points at a SQLite database.
db = SQLDatabase.from_uri(DATABASE_URL)
class CompanyQueryProcessor:
    """Answer natural-language company questions through an LLM SQL agent.

    The agent is instructed to return only company IDs. Those IDs are then
    used to load the full company records (with their relationships) through
    the ORM, so the response shape never depends on what the LLM writes.
    """

    def __init__(self):
        """Create the LLM client, the SQL toolkit and the ReAct agent."""
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )

        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)

        # Extend the stock system prompt so the agent returns company IDs only.
        system_message_updated = (
            prompt_template.format(dialect="SQLite", top_k=5)
            + "\n\n=== CRITICAL INSTRUCTIONS ==="
            + "\n- Your ONLY task is to run SQL queries and extract company IDs"
            + "\n- When you get SQL results with company IDs, return them EXACTLY as shown"
            + "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs"
            + "\n- Do NOT add any explanations, just list the IDs"
            + "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'"
            + "\n\n=== QUERY GUIDELINES ==="
            + "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'"
            + "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'"
            + "\n3. For location searches: WHERE companies.location LIKE '%location%'"
            + "\n4. For founding year searches: WHERE companies.founded_year >= year"
            + "\n5. For investor-related: JOIN investor_companies table"
        )

        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
        """Run the agent on a question and return the matching companies.

        Args:
            question: The natural language query to process.

        Returns:
            A single-page PaginatedResponse holding the matched companies.
        """
        # The LLM handles all database interaction needed to find company IDs.
        response = self.agent.invoke(
            {"messages": [("user", question)]},
            config={"recursion_limit": 50},
        )
        logger.info("%s", response)

        messages = response["messages"]

        # Prefer raw SQL results from the tools over the model's own wording.
        company_ids = self._collect_ids_from_tool_messages(messages)
        if company_ids:
            logger.info("Extracted %d company IDs from results", len(company_ids))
        else:
            # No usable tool output: fall back to the final AI message.
            # Content may be a list of content blocks, so coerce to text.
            final_message_content = str(messages[-1].content)
            logger.info("AI Response: \n%s", final_message_content)
            company_ids = self._extract_company_ids_from_response(
                final_message_content
            )

        # Fetch full company data with relationships using the IDs.
        return self._fetch_companies_by_ids(company_ids)

    def _collect_ids_from_tool_messages(self, messages) -> List[int]:
        """Return IDs from the most recent tool message holding ID tuples.

        Only messages whose ``type`` is ``"tool"`` are inspected, and only
        strict single-column tuples such as ``(1,)`` are accepted. This keeps
        the user's question, schema dumps (e.g. ``VARCHAR(255)``) and other
        incidental numbers from being mistaken for company IDs.

        Args:
            messages: Message objects from the agent run, oldest first.

        Returns:
            The extracted IDs, or an empty list when no tool message has any.
        """
        # Walk backwards: the last query the agent ran is the one its answer
        # is based on; earlier results may be exploratory.
        for message in reversed(messages):
            if getattr(message, "type", None) != "tool":
                continue
            content = getattr(message, "content", None)
            if not content:
                continue
            found = self._extract_company_ids_from_response(
                str(content), allow_fallback=False
            )
            if found:
                return found
        return []

    def _extract_company_ids_from_response(
        self, ai_response: str, allow_fallback: bool = True
    ) -> List[int]:
        """Extract company IDs from a response text.

        Args:
            ai_response: Text to scan (an AI message or a raw SQL result).
            allow_fallback: When True (default), accept ``(1)`` as well as
                ``(1,)`` and, if no parenthesised number is present, fall
                back to every standalone number below 100000. When False,
                only strict one-element tuples ``(1,)`` are accepted.

        Returns:
            The IDs in order of first appearance, without duplicates. Empty
            when the text contains ``NO_RESULTS`` (case-insensitive).
        """
        import re

        text = str(ai_response)

        if "NO_RESULTS" in text.upper():
            return []

        if allow_fallback:
            # The response contains tuples like (1,), (5,), etc.
            matches = re.findall(r"\((\d+),?\)", text)
            if not matches:
                # Fallback: take all numbers, dropping very large ones that
                # are more likely tokens or timestamps than IDs.
                matches = [
                    num for num in re.findall(r"\b\d+\b", text) if int(num) < 100000
                ]
        else:
            # The trailing comma is mandatory here so that "(255)" inside
            # e.g. a column type is not treated as a result row.
            matches = re.findall(r"\((\d+),\)", text)

        # JOIN-based queries can repeat an ID; keep the first occurrence only.
        return list(dict.fromkeys(int(match) for match in matches))

    def _fetch_companies_by_ids(
        self, company_ids: List[int]
    ) -> PaginatedResponse[CompanyData]:
        """Fetch companies and their relationships for the given IDs.

        Args:
            company_ids: List of company IDs to fetch.

        Returns:
            A single-page PaginatedResponse. With no IDs, the page is empty
            and ``page_size`` is 10.
        """
        if not company_ids:
            return PaginatedResponse(
                items=[],
                total=0,
                page=1,
                page_size=10,
                total_pages=0,
            )

        # get_db is consumed as an iterator; the session is closed in finally.
        db_session = next(get_db())
        try:
            # Load relationships up front to avoid per-company lazy loads.
            companies = (
                db_session.query(CompanyTable)
                .options(
                    selectinload(CompanyTable.investors),
                    selectinload(CompanyTable.members),
                    selectinload(CompanyTable.sectors),
                )
                .filter(CompanyTable.id.in_(company_ids))
                .all()
            )

            # NOTE(review): CompanyData presumably has no active `members`
            # field, in which case the value passed below is dropped by the
            # schema — confirm whether it should be exposed or removed.
            company_data_list = [
                CompanyData(
                    company=company,
                    investors=company.investors,
                    members=company.members,
                    sectors=company.sectors,
                )
                for company in companies
            ]

            total_count = len(company_data_list)
            return PaginatedResponse(
                items=company_data_list,
                total=total_count,
                page=1,
                page_size=total_count,
                total_pages=1 if total_count > 0 else 0,
            )
        finally:
            db_session.close()
+10 -68
View File
@@ -1,14 +1,24 @@
import logging
import os
import sys
import requests
# Configure root logging at import time: INFO level, timestamped records,
# written through a StreamHandler.
# NOTE(review): this runs whenever the module is imported — confirm it does
# not conflict with logging configured by the hosting application.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
class FolkAPI:
BASE_URL = "https://api.folk.app/v1"
def __init__(self, api_key: str):
    """Prepare the authorization header used for Folk API requests.

    Args:
        api_key: Folk API key. When the FOLK_API_KEY environment variable
            is set, its value takes precedence over this argument.
    """
    api_key = os.environ.get("FOLK_API_KEY", api_key)
    self.headers = {"Authorization": f"Bearer {api_key}"}
    # Never write any part of the secret to the logs; only report whether
    # a key is available at all.
    logger.info(
        "FolkAPI initialized (API key %s)", "configured" if api_key else "missing"
    )
def get_groups(self):
"""Fetch all groups from Folk."""
@@ -190,71 +200,3 @@ class FolkAPI:
response.raise_for_status()
return response.json()
# The API key comes from the FOLK_API_KEY environment variable only. No
# credential is embedded in source: a key committed to the repository must be
# treated as leaked and rotated. DEFAULT_API_KEY is kept (empty) so existing
# references to the name keep working.
DEFAULT_API_KEY = ""
api_key = os.environ.get("FOLK_API_KEY", DEFAULT_API_KEY)
folk = FolkAPI(api_key=api_key)
def example_flow():
    """Demo run: list groups, then create a company and a person in the first one.

    Exits the process with status 1 when the API returns no groups or the
    first group has no id.
    """
    # Step 1: fetch the groups and show the raw payload.
    groups = folk.get_groups()
    print(groups)

    # Groups are nested under groups["data"]["items"]; tolerate missing keys.
    group_items = groups.get("data", {}).get("items", [])
    if not group_items:
        print("No groups returned by Folk API.")
        sys.exit(1)

    # The first group serves as the example target.
    group_id = group_items[0].get("id")
    if not group_id:
        print("No id found for the first group item.")
        sys.exit(1)

    # Step 2: create a company inside that group.
    company = folk.create_company(
        name="2050 Investment Partners",
        group_id=group_id,
        website="https://2050.com",
        linkedin_url="https://linkedin.com/company/2050-investments",
    )

    # Step 3: attach a person to the same group and the new company.
    new_company_id = company.get("data", {}).get("id")
    person = folk.create_person(
        first_name="John",
        last_name="Doe",
        email="john@2050.com",
        company_id=new_company_id,
        group_id=group_id,
    )

    print("Created company:", company)
    print("Created person:", person)
if __name__ == "__main__":
    # Script entry point: run the demo and turn failures into exit status 1.
    try:
        example_flow()
    except requests.HTTPError as http_err:
        # Include the response status and body, when present, to ease debugging.
        failed_response = getattr(http_err, "response", None)
        if failed_response is None:
            print("HTTP error while talking to Folk API:", http_err)
        else:
            try:
                response_body = failed_response.text
            except Exception:
                response_body = "<unreadable response body>"
            print("HTTP error while talking to Folk API:", http_err)
            print("Response status:", failed_response.status_code)
            print("Response body:", response_body)
        sys.exit(1)
    except Exception as unexpected:  # pragma: no cover - top-level safety
        print("Unexpected error:", unexpected)
        sys.exit(1)