Files
ds_zagres_ai/app/services/chat_service.py
T
2025-05-09 15:41:16 +01:00

228 lines
6.4 KiB
Python

"""
Service for chat functionality.
"""
import logging
from typing import List, Dict, Any, Optional

from app.database.db import db
from app.models.chat import Chat, Message, TeamChatMember
from app.models.user import User
logger = logging.getLogger(__name__)


class ChatService:
    """Service for chat functionality.

    Thin layer over ``db.session`` for creating chats, posting messages and
    managing team-chat membership. Mutating methods commit their own
    transaction; ``delete_chat`` rolls back instead if its commit fails.
    """

    def create_chat(self, user_id: int, title: Optional[str] = None,
                    is_team_chat: bool = False, model_name: Optional[str] = None) -> Chat:
        """
        Create a new chat.

        For a team chat the creator is added as a member in the same
        transaction, so the chat is never committed without that membership.

        Args:
            user_id: ID of the user creating the chat.
            title: Optional title for the chat.
            is_team_chat: Whether this is a team chat.
            model_name: Name of the model to use for this chat. When falsy,
                ``Config().DEFAULT_MODEL`` is used.

        Returns:
            Created chat.
        """
        # Function-scope import kept from the original (possibly to avoid an
        # import cycle -- TODO confirm).
        from app.config.config import Config

        chat = Chat(
            user_id=user_id,
            title=title,
            is_team_chat=is_team_chat,
            model_name=model_name or Config().DEFAULT_MODEL,
        )
        db.session.add(chat)
        if is_team_chat:
            # Flush so the database assigns chat.id before the membership row
            # references it; the single commit below persists both rows together.
            db.session.flush()
            db.session.add(TeamChatMember(chat_id=chat.id, user_id=user_id))
        db.session.commit()
        return chat

    def get_chat(self, chat_id: int) -> Optional[Chat]:
        """
        Get a chat by ID.

        Args:
            chat_id: ID of the chat.

        Returns:
            Chat if found, None otherwise.
        """
        # NOTE(review): Query.get() is a legacy API in SQLAlchemy 2.x
        # (db.session.get(Chat, chat_id) is the replacement); kept here because
        # the installed SQLAlchemy version is not visible from this file.
        return Chat.query.get(chat_id)

    def get_user_chats(self, user_id: int) -> List[Chat]:
        """
        Get all chats visible to a user, most recently updated first.

        The result is the user's own private chats plus every team chat the
        user is a member of. Chats whose ``updated_at`` is unset are listed
        last rather than making the sort fail.

        Args:
            user_id: ID of the user.

        Returns:
            List of chats.
        """
        private_chats = Chat.query.filter_by(
            user_id=user_id,
            is_team_chat=False,
        ).order_by(Chat.updated_at.desc()).all()

        team_chat_ids = [
            chat_id
            for (chat_id,) in db.session.query(TeamChatMember.chat_id).filter_by(user_id=user_id).all()
        ]
        team_chats: List[Chat] = []
        if team_chat_ids:  # an empty IN () can never match; skip the round trip
            team_chats = Chat.query.filter(
                Chat.id.in_(team_chat_ids)
            ).order_by(Chat.updated_at.desc()).all()

        all_chats = private_chats + team_chats
        # Sort on (has_timestamp, timestamp): comparing None with a datetime
        # raises TypeError, so chats without a timestamp are grouped at the end.
        all_chats.sort(
            key=lambda chat: (chat.updated_at is not None, chat.updated_at),
            reverse=True,
        )
        return all_chats

    def add_message(self, chat_id: int, content: str,
                    is_user_message: bool = True, user_id: Optional[int] = None) -> Message:
        """
        Add a message to a chat and bump the chat's ``updated_at``.

        Args:
            chat_id: ID of the chat.
            content: Message content.
            is_user_message: Whether this is a user message (vs. bot message).
            user_id: ID of the user sending the message (required for user
                messages; ignored and stored as None for bot messages).

        Returns:
            Created message.
        """
        message = Message(
            chat_id=chat_id,
            content=content,
            is_user_message=is_user_message,
            user_id=user_id if is_user_message else None,
        )
        db.session.add(message)
        # created_at is presumably filled by a column default when the INSERT
        # runs (TODO confirm against the Message model). Flush explicitly so
        # that has happened before we copy it, instead of relying on an
        # incidental autoflush that does not occur when the chat is already
        # loaded in the session.
        db.session.flush()

        chat = Chat.query.get(chat_id)
        if chat is not None and message.created_at is not None:
            # Never overwrite a valid activity timestamp with None.
            chat.updated_at = message.created_at
        db.session.commit()
        return message

    def get_chat_messages(self, chat_id: int) -> List[Message]:
        """
        Get all messages for a chat, oldest first.

        Args:
            chat_id: ID of the chat.

        Returns:
            List of messages ordered by ``created_at`` ascending.
        """
        # NOTE(review): messages sharing the same created_at come back in
        # database-defined order; consider a primary-key tiebreaker if the
        # Message model has one.
        return Message.query.filter_by(chat_id=chat_id).order_by(Message.created_at).all()

    def add_team_member(self, chat_id: int, user_id: int) -> Optional[TeamChatMember]:
        """
        Add a user to a team chat.

        Idempotent from the caller's point of view: if the user is already a
        member, the existing membership is returned and nothing is written.

        Args:
            chat_id: ID of the team chat.
            user_id: ID of the user to add.

        Returns:
            The new or existing team chat member, or None if the chat does not
            exist or is not a team chat.
        """
        chat = Chat.query.get(chat_id)
        if chat is None or not chat.is_team_chat:
            return None

        existing_member = TeamChatMember.query.filter_by(
            chat_id=chat_id,
            user_id=user_id,
        ).first()
        if existing_member:
            return existing_member

        # NOTE(review): check-then-insert; concurrent calls could both insert
        # unless (chat_id, user_id) has a unique constraint -- confirm in the model.
        member = TeamChatMember(
            chat_id=chat_id,
            user_id=user_id,
        )
        db.session.add(member)
        db.session.commit()
        return member

    def get_team_members(self, chat_id: int) -> List[User]:
        """
        Get all members of a team chat.

        Args:
            chat_id: ID of the team chat.

        Returns:
            List of users (empty if the chat has no members).
        """
        member_ids = [
            user_id
            for (user_id,) in db.session.query(TeamChatMember.user_id).filter_by(chat_id=chat_id).all()
        ]
        if not member_ids:  # no members -> no need to query users
            return []
        return User.query.filter(User.id.in_(member_ids)).all()

    def remove_team_member(self, chat_id: int, user_id: int) -> bool:
        """
        Remove a user from a team chat.

        Args:
            chat_id: ID of the team chat.
            user_id: ID of the user to remove.

        Returns:
            True if a membership was removed, False if none existed.
        """
        member = TeamChatMember.query.filter_by(
            chat_id=chat_id,
            user_id=user_id,
        ).first()
        if not member:
            return False

        db.session.delete(member)
        db.session.commit()
        return True

    def delete_chat(self, chat_id: int) -> bool:
        """
        Delete a chat and all its messages.

        Removal of the chat's messages and memberships depends on the Chat
        model's relationship cascade / foreign-key settings, which are not
        visible here (TODO confirm).

        Args:
            chat_id: ID of the chat to delete.

        Returns:
            True if deletion was successful, False if the chat does not exist
            or the delete/commit failed (the session is rolled back then).
        """
        chat = Chat.query.get(chat_id)
        if chat is None:
            return False

        try:
            db.session.delete(chat)
            db.session.commit()
        except Exception:
            # Best-effort API: report failure via the return value rather than
            # raising, but keep the traceback in the logs and leave the session
            # usable for the rest of the request.
            logger.exception("Error deleting chat %s", chat_id)
            db.session.rollback()
            return False
        return True