2024-08-07 18:55:56 +01:00
|
|
|
import logging
import os
import sys

# Make the parent directory importable (presumably where the project-local
# modules `utils` and `data_ingest` live), independent of the working directory.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from utils import search, load_embedded_data
from data_ingest import load_data
|
2024-08-07 18:27:42 +01:00
|
|
|
|
|
|
|
|
app = FastAPI()

logger = logging.getLogger(__name__)

# Module-level vector store shared by the endpoints below. Try to load the
# previously embedded data at startup; /load_documents can replace it later.
try:
    vector_store = load_embedded_data()
except Exception:
    # Best-effort on purpose: the app must still start without a store, but
    # record the failure (with traceback) instead of silently discarding it.
    logger.exception("Failed to load embedded vector store; starting without one")
    vector_store = None
|
|
|
|
|
|
|
|
|
|
# Browser origins allowed to call this API cross-origin: localhost on the
# default port, 8000 and 3000. Append further origins to this list as needed.
origins = [
    "http://localhost",
    "http://localhost:8000",
    "http://localhost:3000",
]

# Install CORS handling: the origins above may send credentialed requests
# using any HTTP method and any request header.
_cors_settings = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
|
|
|
|
|
|
|
class SearchRequest(BaseModel):
    """JSON payload accepted by the /search endpoint."""

    # Free-text query passed to the vector-store search.
    query: str
|
|
|
|
|
|
2024-08-07 18:55:56 +01:00
|
|
|
@app.get("/load_documents")
@app.post("/load_documents")
def load_documents(directory: str):
    """Ingest documents from ``directory`` and make them the active vector store.

    The module-level ``vector_store`` is replaced with the result of
    ``load_data(directory)``, so later /search calls use the new data. If
    ``load_data`` raises, the assignment never happens and the previous store
    stays in place.

    This call changes server state, so it is also registered for POST. GET
    requests may be retried, cached or prefetched by clients. GET is kept so
    existing callers keep working.

    Args:
        directory: Query parameter naming the directory to ingest.

    Returns:
        A status dict confirming the load.
    """
    # SECURITY(review): ``directory`` comes straight from the client and is
    # handed to load_data unchecked, so presumably any path readable by the
    # server process can be ingested. Restrict it to an allowed base directory
    # before exposing this endpoint beyond trusted/local callers.
    global vector_store
    vector_store = load_data(directory)
    return {"status": "Documents loaded successfully"}
|
|
|
|
|
|
2024-08-07 18:55:56 +01:00
|
|
|
@app.get("/search")
@app.post("/search")
def search(request: SearchRequest):
    """Search the currently loaded vector store for ``request.query``.

    The model argument is read from the JSON request body. The original GET
    route is kept. POST is also registered because many HTTP clients and
    browsers cannot attach a body to a GET request.

    Args:
        request: Body carrying the ``query`` text.

    Returns:
        ``{"results": ...}`` holding whatever the utils search helper returns.

    Raises:
        HTTPException: 503 when no vector store is loaded (e.g. startup
            loading failed and /load_documents has not been called).
    """
    # This endpoint's name shadows the ``search`` helper imported from utils
    # at module level, so calling ``search(...)`` here would invoke the
    # endpoint itself. Bind the utility under a distinct local name instead.
    from utils import search as search_vector_store

    if vector_store is None:
        raise HTTPException(status_code=503, detail="Vector store is not loaded")

    results = search_vector_store(vector_store, request.query)
    return {"results": results}
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve the app with uvicorn when this module is executed directly,
    # listening on all network interfaces at port 8000.
    from uvicorn import run as serve

    serve(app, port=8000, host="0.0.0.0")
|