added fastapi endpoints
This commit is contained in:
@@ -26,7 +26,8 @@ def load_data(data_path: str):
|
|||||||
save_embedded_data(embed_db)
|
save_embedded_data(embed_db)
|
||||||
logger.info(f"Vector store saved")
|
logger.info(f"Vector store saved")
|
||||||
|
|
||||||
return "Vector store created and saved"
|
print("Vector store created and saved")
|
||||||
|
return embed_db
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
+7
-70
@@ -341,6 +341,13 @@
|
|||||||
" "
|
" "
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Data Search"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 83,
|
"execution_count": 83,
|
||||||
@@ -543,76 +550,6 @@
|
|||||||
" print(f\"* [SIM={score:3f}] {res.page_content} [{res.metadata}]\")"
|
" print(f\"* [SIM={score:3f}] {res.page_content} [{res.metadata}]\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Data Search"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"query = \"Steering assist function/lane centering function\"\n",
|
|
||||||
"docs = load_db.similarity_search(query)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"print(docs[0].page_content)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"print(docs[0].metadata['page'])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def search(db, query, k=4):\n",
|
|
||||||
" docs = db.similarity_search(query, k)\n",
|
|
||||||
" all = \"\"\n",
|
|
||||||
" pages = []\n",
|
|
||||||
" for doc in docs:\n",
|
|
||||||
" all += f\"{doc.page_content}\\n\"\n",
|
|
||||||
" pages.append(doc.metadata['page'])\n",
|
|
||||||
" return docs[0].page_content, all, pages"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"search_result, all, pages = search(db, \"What is LDA\")\n",
|
|
||||||
"print( search_result )"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"pages"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
# main.py
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from data_ingestion.utils import search, load_embedded_data
|
||||||
|
from data_ingestion.data_ingest import load_data
|
||||||
|
|
||||||
|
app = FastAPI()

# Imported here so this block stays self-contained; it is only used to report
# a failed start-up load of the vector store.
import logging

# Module-level vector store shared by the endpoints below. Loading is
# best-effort: if it fails the service still starts, with no store loaded,
# but the failure is logged with its traceback instead of being discarded.
try:
    vector_store = load_embedded_data()
except Exception:
    logging.getLogger(__name__).warning(
        "Could not load embedded data; starting without a vector store",
        exc_info=True,
    )
    vector_store = None
|
||||||
|
|
||||||
|
# Origins that are permitted to make cross-origin requests to this API.
origins = [
    "http://localhost",
    "http://localhost:8000",
    "http://localhost:3000",
    # Add other allowed origins here
]

# Register the CORS middleware: requests are accepted from the origins listed
# above, with credentials enabled and every HTTP method and header allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||||
|
|
||||||
|
class SearchRequest(BaseModel):
    """Request body accepted by the ``/search`` endpoint."""

    # Text of the query to look up.
    query: str
|
||||||
|
|
||||||
|
@app.post("/load_documents")
def load_documents(directory: str):
    """Build the vector store from ``directory`` and make it the active one.

    Args:
        directory: Path handed to ``load_data``.
            NOTE(review): this value comes straight from the client;
            consider restricting it to an allow-listed base directory.

    Returns:
        A dict with a ``status`` message on success.

    Raises:
        HTTPException: 400 when ``load_data`` reports that the path does
            not exist or is not a directory.
    """
    global vector_store

    try:
        loaded_store = load_data(directory)
    except (FileNotFoundError, NotADirectoryError) as err:
        # Report a bad path as a client error rather than an opaque 500.
        raise HTTPException(
            status_code=400,
            detail=f"Cannot load documents from {directory!r}",
        ) from err

    # Rebind only after a successful load so that a failed request leaves
    # the previously loaded store in place.
    vector_store = loaded_store
    return {"status": "Documents loaded successfully"}
|
||||||
|
|
||||||
|
@app.post("/search")
def search(request: SearchRequest):
    """Run ``request.query`` against the currently loaded vector store.

    Args:
        request: Request body carrying the query text.

    Returns:
        A dict whose ``results`` entry holds whatever the
        ``data_ingestion.utils.search`` helper returns for the query.

    Raises:
        HTTPException: 503 when no vector store has been loaded yet.
    """
    # This endpoint is itself named ``search``, so it rebinds the module-level
    # name imported from ``data_ingestion.utils``. Import the helper under an
    # alias here; calling plain ``search(...)`` would re-enter this endpoint
    # with two positional arguments and raise TypeError.
    from data_ingestion.utils import search as search_vector_store

    # The store is None when the start-up load failed and /load_documents
    # has not been called since.
    if vector_store is None:
        raise HTTPException(
            status_code=503,
            detail="Vector store is not loaded; call /load_documents first",
        )

    results = search_vector_store(vector_store, request.query)
    return {"results": results}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    import uvicorn

    # Serve the app on all network interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
@@ -11,3 +11,4 @@ langchain-text-splitters
|
|||||||
unstructured[all-docs]
|
unstructured[all-docs]
|
||||||
docx2txt
|
docx2txt
|
||||||
docx
|
docx
|
||||||
|
fastapi[standard]
|
||||||
Reference in New Issue
Block a user