added fastapi endpoints

This commit is contained in:
timothyafolami
2024-08-07 18:27:42 +01:00
parent 228fffefd8
commit ab418262b1
4 changed files with 67 additions and 72 deletions
+2 -1
View File
@@ -26,7 +26,8 @@ def load_data(data_path: str):
save_embedded_data(embed_db) save_embedded_data(embed_db)
logger.info(f"Vector store saved") logger.info(f"Vector store saved")
return "Vector store created and saved" print("Vector store created and saved")
return embed_db
if __name__ == "__main__": if __name__ == "__main__":
+7 -70
View File
@@ -341,6 +341,13 @@
" " " "
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Search"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 83, "execution_count": 83,
@@ -543,76 +550,6 @@
" print(f\"* [SIM={score:3f}] {res.page_content} [{res.metadata}]\")" " print(f\"* [SIM={score:3f}] {res.page_content} [{res.metadata}]\")"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Search"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"Steering assist function/lane centering function\"\n",
"docs = load_db.similarity_search(query)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(docs[0].page_content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(docs[0].metadata['page'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def search(db, query, k=4):\n",
" docs = db.similarity_search(query, k)\n",
" all = \"\"\n",
" pages = []\n",
" for doc in docs:\n",
" all += f\"{doc.page_content}\\n\"\n",
" pages.append(doc.metadata['page'])\n",
" return docs[0].page_content, all, pages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"search_result, all, pages = search(db, \"What is LDA\")\n",
"print( search_result )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pages"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
+56
View File
@@ -0,0 +1,56 @@
# main.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from data_ingestion.utils import search, load_embedded_data
from data_ingestion.data_ingest import load_data
app = FastAPI()
# Initialize global variables for FAISS index and vector store
try:
vector_store = load_embedded_data()
except Exception as e:
vector_store = None
# Define allowed origins for CORS
origins = [
"http://localhost",
"http://localhost:8000",
"http://localhost:3000",
# Add other allowed origins here
]
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=origins, # Allows requests from listed origins
allow_credentials=True,
allow_methods=["*"], # Allows all HTTP methods
allow_headers=["*"], # Allows all headers
)
class SearchRequest(BaseModel):
query: str
@app.post("/load_documents")
def load_documents(directory: str):
global vector_store
# Load documents using the utility function
vector_store = load_data(directory)
return {"status": "Documents loaded successfully"}
@app.post("/search")
def search(request: SearchRequest):
global vector_store
# Perform search using the utility function
results = search(vector_store, request.query)
return {"results": results}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
+1
View File
@@ -11,3 +11,4 @@ langchain-text-splitters
unstructured[all-docs] unstructured[all-docs]
docx2txt docx2txt
docx docx
"fastapi[standard]"