323 lines
12 KiB
Python
323 lines
12 KiB
Python
# Custom AI bot with improved error handling and connection methods
|
|
|
|
import asyncio
import logging
import os
import re
import sys
import traceback

import aiohttp
import socketio

from env import WEBUI_URL, TOKEN
from utils import send_message, send_typing
|
|
|
|
# Set up logging.
# LOG_LEVEL (e.g. "INFO", "WARNING") overrides the level; it defaults to DEBUG.
# An unknown level name makes logging.basicConfig raise ValueError at import time.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "DEBUG").upper(),
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        # Explicit UTF-8: without it the file uses the platform's locale
        # encoding, which may be unable to encode emoji / non-Latin chat text.
        logging.FileHandler('bot_debug.log', encoding='utf-8'),
    ],
)

# Logger used throughout this bot.
logger = logging.getLogger('openwebui_bot')
# Get model configuration from environment variables.
# Invalid numeric values (e.g. TEMPERATURE="hot") raise ValueError at import time.


def parse_triggers(raw):
    """Split a comma-separated trigger list into clean trigger strings.

    Surrounding whitespace is removed from each entry and empty entries are
    dropped. An empty trigger is a substring of every message, so keeping it
    would make the bot answer everything.

    Args:
        raw: Comma-separated triggers, e.g. ``"@ai, @bot"``.

    Returns:
        The non-empty, stripped triggers in their original order.

    >>> parse_triggers("@ai, @bot,,")
    ['@ai', '@bot']
    """
    return [trigger.strip() for trigger in raw.split(",") if trigger.strip()]


MODEL_ID = os.getenv("MODEL_ID", "llama3.1")
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful AI assistant.")
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "2048"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
TRIGGERS = parse_triggers(os.getenv("TRIGGERS", "@ai,@bot,@assistant,@chatbot"))
# Accept common truthy spellings (case-insensitive); anything else means False.
RESPOND_TO_ALL = os.getenv("RESPOND_TO_ALL", "false").strip().lower() in {"true", "1", "yes", "on"}
# Asynchronous Socket.IO client shared by every handler in this bot.
# The socketio / engineio built-in loggers are switched off; the bot reports
# through its own `logger` instead.
sio = socketio.AsyncClient(
    logger=False,
    engineio_logger=False,
)
# Event handlers
@sio.on("connect")
async def connect():
    """Handle the Socket.IO ``connect`` event by logging that the bot is connected."""
    logger.info("Connected to OpenWebUI!")
@sio.on("disconnect")
async def disconnect():
    """Handle the Socket.IO ``disconnect`` event by logging that the bot is disconnected."""
    logger.info("Disconnected from OpenWebUI!")
# Function to call the OpenAI-compatible API
async def openai_chat_completion(messages):
    """Request a non-streaming chat completion from OpenWebUI.

    Posts to ``<WEBUI_URL>/api/chat/completions`` with the bot's bearer token
    and the module-level model settings (MODEL_ID, TEMPERATURE, MAX_TOKENS,
    TOP_P).

    Args:
        messages: Chat messages in OpenAI format, e.g.
            ``[{"role": "user", "content": "hi"}]``.

    Returns:
        The decoded JSON body on HTTP 200. On any other status, a dict
        ``{"error": <response text>, "status": <status code>}``. If an
        exception is raised (network error, timeout, invalid JSON, ...), it is
        logged and ``{"error": "Error: <detail>", "status": 500}`` is returned
        instead of propagating.
    """
    payload = {
        "model": MODEL_ID,
        "messages": messages,
        "stream": False,
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
        "top_p": TOP_P,
    }
    # Strip any trailing slash so "http://host/" does not produce "//api/...",
    # which would not match the server's /api route.
    url = f"{WEBUI_URL.rstrip('/')}/api/chat/completions"
    # 5-minute budget for the whole request (connect + send + read).
    timeout = aiohttp.ClientTimeout(total=300)

    try:
        async with aiohttp.ClientSession() as session:
            logger.info(f"Sending request to {url}")
            logger.debug(f"Payload: {payload}")

            async with session.post(
                url,
                headers={"Authorization": f"Bearer {TOKEN}"},
                json=payload,
                timeout=timeout,
            ) as response:
                if response.status == 200:
                    result = await response.json()
                    logger.info("API request successful")
                    return result

                error_text = await response.text()
                logger.error(f"API error: {response.status} - {error_text}")
                return {"error": error_text, "status": response.status}
    except Exception as e:
        # Some exceptions (e.g. timeouts) have an empty str(); fall back to the
        # exception type so the reported error is never just "Error: ".
        detail = str(e) or type(e).__name__
        logger.exception(f"Error in openai_chat_completion: {detail}")
        return {"error": f"Error: {detail}", "status": 500}
# Helper function to send typing indicators while waiting for a response
async def send_typing_until_complete(channel_id, coro, interval=1.0):
    """Run *coro* as a task, sending typing indicators until it completes.

    One indicator is sent immediately, then another roughly every *interval*
    seconds while the task is still running. Waiting uses ``asyncio.wait`` on
    the task rather than a fixed sleep. The result is therefore returned as
    soon as the task finishes, not up to *interval* seconds later.

    Args:
        channel_id: Channel that receives the typing indicators.
        coro: Coroutine to run, e.g. an ``openai_chat_completion(...)`` call.
        interval: Seconds between indicators; should be positive, otherwise
            indicators are sent back-to-back.

    Returns:
        Whatever *coro* returns; an exception raised by *coro* propagates.

    Raises:
        Any exception from ``send_typing``, after the task has been cancelled.
    """
    task = asyncio.create_task(coro)
    try:
        while not task.done():
            await send_typing(sio, channel_id)
            # Returns early when the task finishes; otherwise after `interval`.
            await asyncio.wait({task}, timeout=interval)
        return await task
    finally:
        # Leaving early (send_typing failed, or this coroutine was itself
        # cancelled) must not orphan the task. CancelledError is not an
        # Exception subclass on Python 3.8+, so `except Exception` would miss it.
        if not task.done():
            task.cancel()
# Define a function to handle channel events


def _match_trigger(message):
    """Locate the first configured trigger in *message*, ignoring case.

    Triggers are tried in ``TRIGGERS`` order and matched as plain substrings.
    The first trigger that occurs anywhere in the message wins, at its first
    occurrence.

    Returns:
        ``(trigger, start, end)`` with ``start``/``end`` indexing into
        *message* itself, or ``None`` if no trigger occurs.
    """
    for trigger in TRIGGERS:
        if not trigger:
            # "" is a substring of every message; never treat it as a match.
            continue
        # Search the original text instead of a lowercased copy: str.lower()
        # can change the length (e.g. "\u0130".lower() has two characters),
        # which would shift indices taken from the copy.
        found = re.search(re.escape(trigger), message, re.IGNORECASE)
        if found:
            return trigger, found.start(), found.end()
    return None


def _strip_trigger(message, start, end):
    """Return *message* with the trigger text ``message[start:end]`` removed.

    The text after the trigger is ``.strip()``-ed; the text before it is kept
    verbatim. If nothing but whitespace remains, a default greeting prompt is
    returned instead.
    """
    remaining = message[:start] + message[end:].strip()
    if not remaining.strip():
        return "Hello, how can I help you?"
    return remaining


async def _reply(channel_id, prompt):
    """Ask the model about *prompt* and post its answer, or an error, to *channel_id*.

    Failures are logged and reported to the channel. An exception raised while
    posting that error report propagates to the caller.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        # Typing indicators start immediately and continue until the API call returns.
        response = await send_typing_until_complete(
            channel_id, openai_chat_completion(messages)
        )
        if response.get("choices"):
            completion = response["choices"][0]["message"]["content"]
            # Add a robot emoji to the response
            await send_message(channel_id, f"🤖 {completion}")
        else:
            error_message = response.get("error", "I'm sorry, I couldn't generate a response.")
            await send_message(channel_id, f"🤖 Error: {error_message}")
    except Exception as e:
        logger.exception(f"Error generating response: {str(e)}")
        await send_message(
            channel_id,
            "🤖 Something went wrong while processing your request."
        )


def events(user_id):
    """Register the ``channel-events`` handler for the bot whose user ID is *user_id*.

    Events authored by *user_id* are ignored. Other ``"message"`` events are
    answered when ``RESPOND_TO_ALL`` is set, or when the text contains one of
    ``TRIGGERS``; the trigger is then removed before prompting the model.
    python-socketio keeps one handler per event name, so calling this again
    replaces the previously registered handler.
    """

    @sio.on("channel-events")
    async def channel_events(data):
        try:
            logger.debug(f"Received channel event: {data}")

            # Ignore events from the bot itself
            user = data.get("user") or {}
            if user.get("id") == user_id:
                logger.debug(f"Ignoring message from self (bot ID: {user_id})")
                return

            # Only process message events
            event = data.get("data") or {}
            if event.get("type") != "message":
                return

            message_content = (event.get("data") or {}).get("content") or ""
            channel_id = data["channel_id"]
            sender_name = user.get("name")
            logger.info(f"Message in channel: {sender_name}: {message_content}")

            if not message_content.strip():
                logger.debug("Empty message, skipping")
                return

            if RESPOND_TO_ALL:
                # Answer everything, passing the text through unchanged.
                prompt = message_content
            else:
                match = _match_trigger(message_content)
                if match is None:
                    logger.debug("No trigger detected, skipping message")
                    return
                trigger, start, end = match
                logger.info(f"Trigger detected: {trigger}")
                prompt = _strip_trigger(message_content, start, end)

            await _reply(channel_id, prompt)
        except Exception as e:
            logger.exception(f"Error processing channel event: {str(e)}")
# Define an async function for the main workflow

# Socket.IO endpoint variants, tried in this order on every connection attempt.
_CONNECTION_METHODS = (
    {
        "socketio_path": "/ws/socket.io",
        "transports": ["websocket"],
        "description": "Standard WebSocket connection",
    },
    {
        "socketio_path": "/socket.io",
        "transports": ["websocket"],
        "description": "Alternative socket.io path",
    },
    {
        "socketio_path": "/ws/socket.io",
        "transports": ["polling", "websocket"],
        "description": "Polling transport",
    },
    {
        "socketio_path": "/socket.io",
        "transports": ["polling", "websocket"],
        "description": "Alternative path with polling",
    },
)

# Bot ID used when the "user-join" acknowledgement does not include one.
# It matches no real user, so with it the bot cannot recognise its own messages.
_DEFAULT_BOT_ID = "bot-default-id"


async def _connect_with_fallbacks(base_url):
    """Try each entry of ``_CONNECTION_METHODS`` once against *base_url*.

    Returns:
        True as soon as one ``sio.connect`` call succeeds, False if every
        method fails. Individual failures are logged, not raised.
    """
    for method in _CONNECTION_METHODS:
        description = method["description"]
        logger.info(f"Trying {description}...")
        try:
            await sio.connect(
                base_url,
                socketio_path=method["socketio_path"],
                transports=method["transports"],
            )
        except Exception as conn_error:
            logger.error(f"{description} failed: {str(conn_error)}")
        else:
            logger.info(f"Connection successful using {description}!")
            return True
    return False


async def _join_and_get_bot_id(timeout=30):
    """Authenticate with "user-join" and return the bot's user ID from the ack.

    ``sio.call`` waits for the server's acknowledgement. The channel handler
    can therefore be registered once, with the real ID, instead of first
    running under a placeholder ID until an ``emit`` callback fires.

    Args:
        timeout: Seconds to wait for the acknowledgement.

    Returns:
        The ``"id"`` field of the acknowledgement payload, or
        ``_DEFAULT_BOT_ID`` when the call fails, times out, or returns no
        usable data.
    """
    logger.info("Authenticating with the server...")
    try:
        data = await sio.call("user-join", {"auth": {"token": TOKEN}}, timeout=timeout)
    except Exception as e:
        logger.warning(f"No usable acknowledgement for user-join: {e!r}")
        return _DEFAULT_BOT_ID

    logger.info(f"Join callback received: {data}")
    if isinstance(data, dict) and "id" in data:
        return data["id"]
    if data is None:
        logger.warning("No callback data received")
    else:
        logger.warning(f"Invalid callback data: {data}")
    return _DEFAULT_BOT_ID


async def main(max_retries=5, retry_delay=5):
    """Connect to OpenWebUI, authenticate the bot, and serve channel events.

    Each attempt tries every entry of ``_CONNECTION_METHODS``. After a fully
    failed attempt it sleeps ``retry_delay`` seconds, for at most
    ``max_retries`` attempts in total. Once connected, the bot sends
    "user-join" with ``TOKEN``, registers the channel-event handler via
    ``events()``, and awaits ``sio.wait()``.

    NOTE(review): "user-join" is sent only once, here. If python-socketio
    reconnects automatically, the new session is presumably not re-joined;
    confirm against the server before relying on long-lived connections.

    Args:
        max_retries: Maximum number of connection attempts.
        retry_delay: Seconds to sleep between failed attempts.
    """
    # Ensure the URL is properly formatted (no trailing slash).
    base_url = WEBUI_URL.rstrip('/')

    for attempt in range(1, max_retries + 1):
        logger.info(f"Connecting to {base_url}... (Attempt {attempt}/{max_retries})")
        if await _connect_with_fallbacks(base_url):
            logger.info("Connection established!")
            break
        logger.error("Failed to connect: All connection methods failed")
        if attempt < max_retries:
            logger.info(f"Retrying in {retry_delay} seconds...")
            await asyncio.sleep(retry_delay)
    else:
        # No attempt succeeded (or max_retries < 1).
        logger.error("Maximum connection attempts reached. Exiting.")
        return

    try:
        bot_id = await _join_and_get_bot_id()
        logger.info(f"Bot connected with ID: {bot_id}")
        events(bot_id)  # Attach the channel-event handler

        logger.info("Waiting for events...")
        # Blocks until the connection with the server ends.
        await sio.wait()
    except Exception as e:
        logger.exception(f"Error in main loop: {str(e)}")
# Graceful shutdown
async def shutdown():
    """Disconnect the Socket.IO client if it is still connected, logging progress."""
    logger.info("Shutting down bot...")
    still_connected = sio.connected
    if still_connected:
        await sio.disconnect()
    logger.info("Bot shutdown complete.")
if __name__ == "__main__":
    logger.info("Starting custom AI bot...")
    logger.info(f"OpenWebUI URL: {WEBUI_URL}")
    logger.info(f"Model: {MODEL_ID}")
    logger.info(f"Triggers: {TRIGGERS}")
    logger.info(f"Respond to all: {RESPOND_TO_ALL}")

    async def _run_bot():
        """Run main() and always shut down on the same event loop.

        The Socket.IO transports belong to the loop that asyncio.run() creates.
        A second asyncio.run(shutdown()) would disconnect from a fresh loop
        after that one is closed, so cleanup happens here in ``finally``. On
        Ctrl+C, asyncio.run cancels this task, which runs the ``finally``,
        before KeyboardInterrupt reaches the caller.
        """
        try:
            await main()
        finally:
            try:
                await shutdown()
            except Exception as e:
                logger.error(f"Error during shutdown: {str(e)}")

    try:
        asyncio.run(_run_bot())
    except KeyboardInterrupt:
        logger.info("Bot stopped by user")
    except Exception as e:
        logger.error(f"Error running bot: {str(e)}")
        logger.error(traceback.format_exc())
|