Files
2025-03-06 17:32:42 +00:00

94 lines
3.0 KiB
Python

import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
import openai
from openai import OpenAI
from langchain_community.document_loaders.csv_loader import CSVLoader
from pathlib import Path
from langchain_openai import ChatOpenAI,OpenAIEmbeddings
import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from scripts.rag import get_calculations, response_agent,setup_rag_system
import os
from typing import Optional
from fastapi import FastAPI, HTTPException, Security, Depends
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
# Load environment variables from a .env file. This runs first so that the
# values are available to everything constructed below, including API_KEY.
load_dotenv()

# Initialize the OpenAI chat model and the embedding function.
llm = ChatOpenAI(model="gpt-4o-mini")  # Use the appropriate model
embedding_function = OpenAIEmbeddings()

# Set up the RAG system once, at import time; the request handler reuses
# this single chain object.
file_path = r'docs/UpdateAI_training-company-data.md'
rag_chain = setup_rag_system(file_path, llm, embedding_function)

# Secret the API-key dependency compares incoming bearer tokens against.
# os.getenv returns None when the variable is not set.
API_KEY = os.getenv("API_KEY")

# Initialize the FastAPI app (a single instance, with its OpenAPI metadata).
app = FastAPI(
    title="Update stack AI API",
    description="API For fire Update stack",
    version="1.0.0"
)

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive policy. If browser clients never send credentials, consider
# allow_credentials=False or an explicit origin list -- confirm with the
# front-end owners before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Setup API key authentication
# auto_error=False lets get_api_key build its own 401 response when the
# header is missing, instead of FastAPI raising a default error.
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)


async def get_api_key(api_key_header: str = Security(api_key_header)) -> str:
    """Validate the ``Authorization: Bearer <token>`` header against API_KEY.

    Args:
        api_key_header: Raw ``Authorization`` header value supplied by the
            ``APIKeyHeader`` security scheme; falsy when the header is absent.

    Returns:
        The bearer token, after it has been checked against ``API_KEY``.

    Raises:
        HTTPException: 401 if the header is missing or does not start with
            ``"Bearer "``, if the token does not match ``API_KEY``, or if
            ``API_KEY`` is unset/empty on the server.
    """
    import hmac  # function-scope import keeps this block self-contained

    scheme_prefix = 'Bearer '
    # RFC 7235: a 401 response should carry a WWW-Authenticate challenge.
    challenge = {"WWW-Authenticate": "Bearer"}

    if not api_key_header or not api_key_header.startswith(scheme_prefix):
        raise HTTPException(
            status_code=401,
            detail={"error": "Unauthorized", "message": "API key is missing or invalid."},
            headers=challenge,
        )

    # Everything after the scheme is the token. Taking the whole remainder
    # (rather than only the first space-separated word) means a header with
    # trailing junk after the key is rejected instead of silently accepted.
    token = api_key_header[len(scheme_prefix):]

    # Fail closed when the server key is unset or empty: otherwise an empty
    # API_KEY would be matched by the header "Bearer " with no token.
    # compare_digest is constant-time, so response timing does not reveal how
    # much of the key matched. Both sides are encoded to bytes because
    # compare_digest rejects non-ASCII str arguments.
    is_match = (
        bool(API_KEY)
        and bool(token)
        and hmac.compare_digest(token.encode("utf-8"), API_KEY.encode("utf-8"))
    )
    if not is_match:
        raise HTTPException(
            status_code=401,
            detail={"error": "Unauthorized", "message": "API key does not match."},
            headers=challenge,
        )
    return token
# Request model for the question-answering endpoint
class QuestionRequest(BaseModel):
    """JSON request body carrying the question to be answered."""

    # The question text as sent by the client.
    question: str
@app.post("/update-stack/v1/ask")
def ask_question(request: QuestionRequest,
                 api_key: str = Depends(get_api_key)):
    """Answer a question using the module-level RAG chain.

    Args:
        request: Request body holding the ``question`` string.
        api_key: Token returned by the ``get_api_key`` dependency. It is not
            used in the body; declaring it is what enforces authentication.

    Returns:
        A dict ``{"answer": <final response>}`` where the value is whatever
        ``response_agent`` returns.

    Raises:
        HTTPException: Re-raised unchanged if one is raised while answering;
            otherwise 500 with the error text for any other failure.
    """
    try:
        # Use the RAG system to get the answer
        calculation_data = get_calculations(rag_chain, request.question)
        final_response = response_agent(request.question, calculation_data)
    except HTTPException:
        # A deliberate HTTP error keeps its own status code instead of being
        # rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        import logging  # function-scope import keeps this block self-contained

        # Record the full traceback server-side; the response only has str(e).
        logging.getLogger(__name__).exception("Failed to answer question")
        # NOTE(review): str(e) may expose internal details to the client;
        # consider a generic message once no client depends on this text.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"answer": final_response}
if __name__ == "__main__":
    # uvicorn is imported here because it is only needed when this file is
    # run directly as a script.
    import uvicorn

    # Listen on all interfaces, port 5079.
    server_options = {"host": "0.0.0.0", "port": 5079}
    uvicorn.run(app, **server_options)