added fastapi endpoints
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
# main.py
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from data_ingestion.utils import search, load_embedded_data
|
||||
from data_ingestion.data_ingest import load_data
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize global variables for FAISS index and vector store
|
||||
try:
|
||||
vector_store = load_embedded_data()
|
||||
except Exception as e:
|
||||
vector_store = None
|
||||
|
||||
# Define allowed origins for CORS
|
||||
origins = [
|
||||
"http://localhost",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:3000",
|
||||
# Add other allowed origins here
|
||||
]
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins, # Allows requests from listed origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all HTTP methods
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
|
||||
@app.post("/load_documents")
|
||||
def load_documents(directory: str):
|
||||
global vector_store
|
||||
|
||||
# Load documents using the utility function
|
||||
vector_store = load_data(directory)
|
||||
|
||||
return {"status": "Documents loaded successfully"}
|
||||
|
||||
@app.post("/search")
|
||||
def search(request: SearchRequest):
|
||||
global vector_store
|
||||
|
||||
# Perform search using the utility function
|
||||
results = search(vector_store, request.query)
|
||||
|
||||
return {"results": results}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user