@app.post(
    "/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"]
)
async def query_companies(request: QueryRequest):
    """
    Query companies using natural language.

    Returns company matches with their investor relationships and sectors.

    Supports queries like:
    - "Show me fintech companies founded in 2020"
    - "Find healthcare companies in San Francisco"
    - "Companies in the AI sector"
    - "Companies that received funding from Sequoia"
    - "European startups founded after 2019"
    """
    # Function-scope import so this handler does not depend on a change to the
    # module's import block.
    import asyncio

    processor = CompanyQueryProcessor()
    # process_query is a plain synchronous call (it invokes the LLM agent and
    # then reads from the database). Awaiting it in a worker thread keeps the
    # event loop free to serve other requests while the agent is running.
    return await asyncio.to_thread(processor.process_query, request.question)
""" try: + folk = get_folk_client() groups_data = folk.get_groups() items = groups_data.get("data", {}).get("items", []) @@ -71,6 +78,7 @@ def sync_investors_to_folk( Returns: Summary of sync operation including successes and errors """ + folk = get_folk_client() # Fetch investors with their team members investors = ( db.query(InvestorTable) diff --git a/app/schemas/__pycache__/router_schemas.cpython-312.pyc b/app/schemas/__pycache__/router_schemas.cpython-312.pyc index d58670d..dbc61b1 100644 Binary files a/app/schemas/__pycache__/router_schemas.cpython-312.pyc and b/app/schemas/__pycache__/router_schemas.cpython-312.pyc differ diff --git a/app/schemas/router_schemas.py b/app/schemas/router_schemas.py index 10eee8c..6f182b3 100644 --- a/app/schemas/router_schemas.py +++ b/app/schemas/router_schemas.py @@ -168,6 +168,7 @@ class InvestorFundData(BaseModel): class Config: from_attributes = True + class InvestorMinimal(BaseModel): """Minimal investor info with just id and name""" @@ -177,6 +178,7 @@ class InvestorMinimal(BaseModel): class Config: from_attributes = True + class CompanySchemaMinimal(BaseModel): id: int name: str @@ -188,9 +190,12 @@ class CompanySchemaMinimal(BaseModel): class Config: from_attributes = True + class CompanyData(BaseModel): # Renamed from CompaniesData for consistency company: CompanySchemaMinimal investors: List[InvestorMinimal] + # members: List[CompanyMemberSchema] = [] + sectors: List[SectorSchema] = [] class Config: from_attributes = True diff --git a/app/services/company_querying.py b/app/services/company_querying.py new file mode 100644 index 0000000..80a2b71 --- /dev/null +++ b/app/services/company_querying.py @@ -0,0 +1,176 @@ +import logging +import os +from typing import List + +from db.db import DATABASE_URL, get_db +from db.models import CompanyTable +from langchain import hub +from langchain_community.agent_toolkits import SQLDatabaseToolkit +from langchain_community.utilities import SQLDatabase +from langchain_openai import 
class CompanyQueryProcessor:
    """Answer natural-language questions about companies.

    An LLM-driven SQL agent is asked to return nothing but company IDs; the
    matching rows are then loaded through the ORM and wrapped in a
    ``PaginatedResponse[CompanyData]``.
    """

    def __init__(self):
        """Build the chat model, the SQL toolkit and the ReAct agent."""
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
        # Extend the stock system prompt so the agent answers with company IDs only.
        system_message_updated = (
            prompt_template.format(dialect="SQLite", top_k=5)
            + "\n\n=== CRITICAL INSTRUCTIONS ==="
            + "\n- Your ONLY task is to run SQL queries and extract company IDs"
            + "\n- When you get SQL results with company IDs, return them EXACTLY as shown"
            + "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs"
            + "\n- Do NOT add any explanations, just list the IDs"
            + "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'"
            + "\n\n=== QUERY GUIDELINES ==="
            + "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'"
            + "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'"
            + "\n3. For location searches: WHERE companies.location LIKE '%location%'"
            + "\n4. For founding year searches: WHERE companies.founded_year >= year"
            + "\n5. For investor-related: JOIN investor_companies table"
        )
        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(self, question: str) -> "PaginatedResponse[CompanyData]":
        """Run the agent on *question* and return the matching companies.

        Args:
            question: The natural language query to process.

        Returns:
            A single-page ``PaginatedResponse`` holding one ``CompanyData``
            per company found.
        """
        response = self.agent.invoke(
            {"messages": [("user", question)]},
            config={"recursion_limit": 50},
        )
        logger.info("%s", response)

        company_ids = self._select_company_ids(response["messages"])
        logger.info("Extracted %d company IDs from results", len(company_ids))

        # Fetch full company data with relationships using the IDs.
        return self._fetch_companies_by_ids(company_ids)

    def _select_company_ids(self, messages) -> List[int]:
        """Pick company IDs out of the agent's message history.

        Only the most recent ``sql_db_query`` tool message is treated as a SQL
        result. The user's question and schema/table-listing tool output are
        ignored on purpose: they routinely contain parentheses and numbers
        (``VARCHAR(255)``, "founded in 2020") that would otherwise be mistaken
        for IDs. If that tool message yields nothing, the final message of the
        conversation is parsed with the lenient rules instead.

        Args:
            messages: Message objects exposing ``type``, ``name`` and
                ``content`` attributes, in conversation order.

        Returns:
            The extracted IDs, or an empty list.
        """
        for message in reversed(messages):
            if getattr(message, "type", None) != "tool":
                continue
            if getattr(message, "name", None) != "sql_db_query":
                continue
            ids = self._extract_company_ids_from_response(
                str(getattr(message, "content", "") or ""), strict=True
            )
            if ids:
                return ids
            # The latest query produced no usable rows; do not fall back to
            # older, possibly exploratory, queries.
            break

        if not messages:
            return []
        final_content = getattr(messages[-1], "content", "") or ""
        return self._extract_company_ids_from_response(str(final_content))

    def _extract_company_ids_from_response(
        self, ai_response: str, *, strict: bool = False
    ) -> List[int]:
        """Extract company IDs from a piece of agent or tool output.

        Args:
            ai_response: Text to scan. Non-string values are converted with
                ``str()`` first.
            strict: When true, only single-column row tuples such as ``(1,)``
                are accepted. When false (the default) and no tuple is found,
                every standalone number below 100000 is taken as an ID.

        Returns:
            IDs in order of first appearance, without duplicates. Empty when
            the text contains ``NO_RESULTS`` (case-insensitive).
        """
        import re

        text = ai_response if isinstance(ai_response, str) else str(ai_response)

        if "NO_RESULTS" in text.upper():
            return []

        # SQL results arrive as row tuples like (1,), (5,), ...
        found = re.findall(r"\((\d+),?\)", text)
        if not found and not strict:
            # Lenient fallback for free-form answers ("IDs: 3, 7"). Very large
            # numbers are dropped as likely tokens or timestamps.
            found = [
                num for num in re.findall(r"\b\d+\b", text) if int(num) < 100000
            ]

        # JOIN-based queries can repeat an ID; keep the first occurrence only.
        return list(dict.fromkeys(int(num) for num in found))

    def _fetch_companies_by_ids(
        self, company_ids: List[int]
    ) -> "PaginatedResponse[CompanyData]":
        """Load the given companies with their relationships.

        Args:
            company_ids: List of company IDs to fetch.

        Returns:
            A single-page response. Items follow the order of *company_ids*;
            IDs with no matching row are silently absent.
        """
        if not company_ids:
            return PaginatedResponse(
                items=[],
                total=0,
                page=1,
                page_size=10,
                total_pages=0,
            )

        # Position of each ID in the request, used to restore the agent's order.
        rank = {}
        for position, company_id in enumerate(company_ids):
            rank.setdefault(company_id, position)

        db_session = next(get_db())
        try:
            companies = (
                db_session.query(CompanyTable)
                .options(
                    selectinload(CompanyTable.investors),
                    selectinload(CompanyTable.members),
                    selectinload(CompanyTable.sectors),
                )
                .filter(CompanyTable.id.in_(list(rank)))
                .all()
            )
            # IN (...) gives no ordering guarantee; put rows back in request order.
            companies.sort(key=lambda company: rank.get(company.id, len(rank)))

            # NOTE(review): the CompanyData schema in this change has its
            # `members` field commented out, so `members=` is presumably
            # ignored by the model - confirm whether members should be exposed.
            company_data_list = [
                CompanyData(
                    company=company,
                    investors=company.investors,
                    members=company.members,
                    sectors=company.sectors,
                )
                for company in companies
            ]

            total_count = len(company_data_list)
            return PaginatedResponse(
                items=company_data_list,
                total=total_count,
                page=1,
                # Same page size as the empty branch above when nothing matched.
                page_size=total_count or 10,
                total_pages=1 if total_count else 0,
            )
        finally:
            db_session.close()
def __init__(self, api_key: str):
    """Store the authorization header used for Folk API requests.

    Args:
        api_key: Folk API key. When empty, the ``FOLK_API_KEY`` environment
            variable is used instead.
    """
    # An explicitly supplied key takes precedence; the environment is only a
    # fallback, so callers can target a specific account.
    api_key = api_key or os.environ.get("FOLK_API_KEY", "")
    if not api_key:
        logger.warning(
            "FolkAPI initialized without an API key; "
            "the Authorization header will carry no token"
        )
    self.headers = {"Authorization": f"Bearer {api_key}"}
    # Never write any part of the secret to the logs; record only its presence.
    logger.info("FolkAPI initialized (API key configured: %s)", bool(api_key))
- items = groups.get("data", {}).get("items", []) - if not items: - print("No groups returned by Folk API.") - sys.exit(1) - - # Choose the first group as an example - group_id = items[0].get("id") - if not group_id: - print("No id found for the first group item.") - sys.exit(1) - - # Step 2: Choose a group_id and create a company - company = folk.create_company( - name="2050 Investment Partners", - group_id=group_id, - website="https://2050.com", - linkedin_url="https://linkedin.com/company/2050-investments", - ) - - # Step 3: Add a person to the same group or company - person = folk.create_person( - first_name="John", - last_name="Doe", - email="john@2050.com", - company_id=company.get("data", {}).get("id"), - group_id=group_id, - ) - - print("Created company:", company) - print("Created person:", person) - - -if __name__ == "__main__": - try: - example_flow() - except requests.HTTPError as e: - # Try to include response body for easier debugging if available - resp = getattr(e, "response", None) - if resp is not None: - try: - body = resp.text - except Exception: - body = "" - print("HTTP error while talking to Folk API:", e) - print("Response status:", resp.status_code) - print("Response body:", body) - else: - print("HTTP error while talking to Folk API:", e) - sys.exit(1) - except Exception as e: # pragma: no cover - top-level safety - print("Unexpected error:", e) - sys.exit(1)