From ab418262b185ebb2bc998efcd7bd523de3e22477 Mon Sep 17 00:00:00 2001
From: timothyafolami
Date: Wed, 7 Aug 2024 18:27:42 +0100
Subject: [PATCH] added fastapi endpoints

---
Review notes:
 - main.py: the utils search helper is imported under an alias. The /search
   endpoint is itself named `search`, so it shadowed the helper and called
   itself with the wrong arguments.
 - main.py: a failed startup load is logged instead of silently dropped, and
   /search answers 503 while no vector store is available.
 - requirements.txt: removed the shell quotes around fastapi[standard]; pip
   rejects a quoted requirement line.
 - Line breaks and indentation were restored from a whitespace-collapsed copy
   of this patch. Check the context lines of the data_ingest.py and notebook
   hunks against the tree before applying. The blob id on the main.py index
   line was not recomputed.

 data_ingestion/data_ingest.py |  3 +-
 doc-experiment.ipynb          | 77 ++++-------------------------------
 main.py                       | 75 ++++++++++++++++++++++++++++++++++
 requirements.txt              |  3 +-
 4 files changed, 86 insertions(+), 72 deletions(-)
 create mode 100644 main.py

diff --git a/data_ingestion/data_ingest.py b/data_ingestion/data_ingest.py
index ac48baf7..3515e625 100644
--- a/data_ingestion/data_ingest.py
+++ b/data_ingestion/data_ingest.py
@@ -26,7 +26,8 @@ def load_data(data_path: str):
     save_embedded_data(embed_db)
     logger.info(f"Vector store saved")
 
-    return "Vector store created and saved"
+    print("Vector store created and saved")
+    return embed_db
 
 
 if __name__ == "__main__":
diff --git a/doc-experiment.ipynb b/doc-experiment.ipynb
index 165b4234..c82c2670 100644
--- a/doc-experiment.ipynb
+++ b/doc-experiment.ipynb
@@ -341,6 +341,13 @@
     " "
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Data Search"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 83,
@@ -543,76 +550,6 @@
     "    print(f\"* [SIM={score:3f}] {res.page_content} [{res.metadata}]\")"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Data Search"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "query = \"Steering assist function/lane centering function\"\n",
-    "docs = load_db.similarity_search(query)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "print(docs[0].page_content)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "print(docs[0].metadata['page'])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def search(db, query, k=4):\n",
-    "    docs = db.similarity_search(query, k)\n",
-    "    all = \"\"\n",
-    "    pages = []\n",
-    "    for doc in docs:\n",
-    "        all += f\"{doc.page_content}\\n\"\n",
-    "        pages.append(doc.metadata['page'])\n",
-    "    return docs[0].page_content, all, pages"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "search_result, all, pages = search(db, \"What is LDA\")\n",
-    "print( search_result )"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "pages"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/main.py b/main.py
new file mode 100644
index 00000000..81b9bf71
--- /dev/null
+++ b/main.py
@@ -0,0 +1,75 @@
+# main.py
+"""FastAPI app exposing document-ingestion and vector-search endpoints."""
+import logging
+
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from data_ingestion.utils import load_embedded_data
+# Aliased: the /search endpoint below is itself named `search`, so a plain
+# import would be shadowed and the endpoint would end up calling itself.
+from data_ingestion.utils import search as search_vector_store
+from data_ingestion.data_ingest import load_data
+
+logger = logging.getLogger(__name__)
+
+app = FastAPI()
+
+# Initialize global variables for FAISS index and vector store
+try:
+    vector_store = load_embedded_data()
+except Exception:
+    # Best effort: start without a store so /load_documents can create one,
+    # but log the failure instead of discarding it silently.
+    logger.exception("Could not load the saved vector store at startup")
+    vector_store = None
+
+# Define allowed origins for CORS
+origins = [
+    "http://localhost",
+    "http://localhost:8000",
+    "http://localhost:3000",
+    # Add other allowed origins here
+]
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,  # Allows requests from listed origins
+    allow_credentials=True,
+    allow_methods=["*"],  # Allows all HTTP methods
+    allow_headers=["*"],  # Allows all headers
+)
+
+class SearchRequest(BaseModel):
+    """Request body for POST /search."""
+    query: str
+
+@app.post("/load_documents")
+def load_documents(directory: str):
+    """Ingest `directory` with load_data and keep the returned store."""
+    global vector_store
+
+    # Load documents using the utility function
+    vector_store = load_data(directory)
+
+    return {"status": "Documents loaded successfully"}
+
+@app.post("/search")
+def search(request: SearchRequest):
+    """Search the active vector store for `request.query`."""
+    if vector_store is None:
+        # No store yet: the startup load failed and /load_documents has not
+        # provided one, so report that explicitly instead of crashing.
+        raise HTTPException(
+            status_code=503,
+            detail="Vector store is not loaded; call /load_documents first.",
+        )
+
+    # Perform search using the utility function
+    results = search_vector_store(vector_store, request.query)
+
+    return {"results": results}
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/requirements.txt b/requirements.txt
index a9f1607d..5f3958d6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,5 @@ langchain-huggingface
 langchain-text-splitters
 unstructured[all-docs]
 docx2txt
-docx
\ No newline at end of file
+docx
+fastapi[standard]
\ No newline at end of file