228 lines
6.4 KiB
Python
228 lines
6.4 KiB
Python
|
|
"""
Service for chat functionality.

Defines ChatService, which creates chats, stores messages and manages
team-chat membership through the shared ``db.session`` from app.database.db.
"""
|
||
|
|
|
||
|
|
from typing import List, Dict, Any, Optional
|
||
|
|
from app.database.db import db
|
||
|
|
from app.models.chat import Chat, Message, TeamChatMember
|
||
|
|
from app.models.user import User
|
||
|
|
|
||
|
|
class ChatService:
    """Create chats, record messages and manage team-chat membership.

    Methods that write to the database commit the shared ``db.session``
    themselves, so callers do not need to commit afterwards.
    """

    def create_chat(self, user_id: int, title: Optional[str] = None,
                    is_team_chat: bool = False, model_name: Optional[str] = None) -> Chat:
        """
        Create a new chat.

        For team chats the creator is enrolled as a member in the same
        transaction, so a team chat is never committed without its creator.

        Args:
            user_id: ID of the user creating the chat.
            title: Optional title for the chat.
            is_team_chat: Whether this is a team chat.
            model_name: Name of the model to use for this chat. A falsy value
                (None or "") falls back to ``Config().DEFAULT_MODEL``.

        Returns:
            Created chat.
        """
        # Imported lazily, presumably to avoid a circular import at module
        # load time. NOTE(review): confirm before hoisting to the top.
        from app.config.config import Config

        chat = Chat(
            user_id=user_id,
            title=title,
            is_team_chat=is_team_chat,
            model_name=model_name or Config().DEFAULT_MODEL,
        )
        db.session.add(chat)

        if is_team_chat:
            # Flush (without committing) so chat.id is assigned, then add the
            # creator's membership; both rows are committed together below.
            db.session.flush()
            db.session.add(TeamChatMember(chat_id=chat.id, user_id=user_id))

        db.session.commit()
        return chat

    def get_chat(self, chat_id: int) -> Optional[Chat]:
        """
        Get a chat by ID.

        Args:
            chat_id: ID of the chat.

        Returns:
            Chat if found, None otherwise.
        """
        return Chat.query.get(chat_id)

    def get_user_chats(self, user_id: int) -> List[Chat]:
        """
        Get all chats visible to a user.

        These are the user's own private chats plus every chat the user is
        a team member of. Results are ordered by ``updated_at`` descending,
        with ties broken by descending ``id``. Filtering and ordering happen
        in a single database query, so each chat appears at most once.

        Args:
            user_id: ID of the user.

        Returns:
            List of chats.
        """
        # IDs of chats where the user is a team member (used as a subquery,
        # not materialised in Python).
        member_chat_ids = db.session.query(TeamChatMember.chat_id).filter_by(user_id=user_id)

        owns_private_chat = (Chat.user_id == user_id) & (Chat.is_team_chat == False)  # noqa: E712
        is_team_member = Chat.id.in_(member_chat_ids)

        return (
            Chat.query
            .filter(owns_private_chat | is_team_member)
            .order_by(Chat.updated_at.desc(), Chat.id.desc())
            .all()
        )

    def add_message(self, chat_id: int, content: str,
                    is_user_message: bool = True, user_id: Optional[int] = None) -> Message:
        """
        Add a message to a chat and bump the chat's ``updated_at``.

        Args:
            chat_id: ID of the chat.
            content: Message content.
            is_user_message: Whether this is a user message (vs. bot message).
            user_id: ID of the user sending the message. It is stored only
                for user messages and discarded for bot messages. It is not
                validated here.

        Returns:
            Created message.
        """
        message = Message(
            chat_id=chat_id,
            content=content,
            is_user_message=is_user_message,
            user_id=user_id if is_user_message else None,
        )
        db.session.add(message)

        # created_at is not passed to the constructor, so it is presumably
        # filled in by a column default when the row is flushed. Flush
        # explicitly so the value exists before it is copied. Relying on the
        # autoflush of the Chat lookup below is not enough: a lookup answered
        # from the session's identity map does not flush, and updated_at
        # would then be overwritten with None.
        db.session.flush()

        chat = Chat.query.get(chat_id)
        if chat is not None and message.created_at is not None:
            chat.updated_at = message.created_at

        db.session.commit()
        return message

    def get_chat_messages(self, chat_id: int) -> List[Message]:
        """
        Get all messages for a chat, oldest first.

        Args:
            chat_id: ID of the chat.

        Returns:
            List of messages ordered by ``created_at`` ascending.
        """
        return Message.query.filter_by(chat_id=chat_id).order_by(Message.created_at).all()

    def add_team_member(self, chat_id: int, user_id: int) -> Optional[TeamChatMember]:
        """
        Add a user to a team chat.

        If the user is already a member, the existing membership is returned
        and nothing is written.

        Args:
            chat_id: ID of the team chat.
            user_id: ID of the user to add.

        Returns:
            The new or already-existing membership. Returns None if the chat
            does not exist or is not a team chat.
        """
        chat = Chat.query.get(chat_id)
        if not chat or not chat.is_team_chat:
            return None

        existing_member = TeamChatMember.query.filter_by(
            chat_id=chat_id,
            user_id=user_id,
        ).first()
        if existing_member:
            return existing_member

        member = TeamChatMember(chat_id=chat_id, user_id=user_id)
        db.session.add(member)
        db.session.commit()
        return member

    def get_team_members(self, chat_id: int) -> List[User]:
        """
        Get all members of a team chat.

        Args:
            chat_id: ID of the team chat.

        Returns:
            List of users. The list is empty if the chat has no members or
            does not exist.
        """
        # Used as an IN-subquery so the lookup is a single round-trip.
        member_ids = db.session.query(TeamChatMember.user_id).filter_by(chat_id=chat_id)
        return User.query.filter(User.id.in_(member_ids)).all()

    def remove_team_member(self, chat_id: int, user_id: int) -> bool:
        """
        Remove a user from a team chat.

        Args:
            chat_id: ID of the team chat.
            user_id: ID of the user to remove.

        Returns:
            True if a membership was found and deleted, False otherwise.
        """
        member = TeamChatMember.query.filter_by(
            chat_id=chat_id,
            user_id=user_id,
        ).first()
        if not member:
            return False

        db.session.delete(member)
        db.session.commit()
        return True

    def delete_chat(self, chat_id: int) -> bool:
        """
        Delete a chat.

        NOTE(review): related messages and team memberships are removed only
        if the Chat model configures cascading deletes. That is not visible
        here; confirm in app.models.chat.

        Args:
            chat_id: ID of the chat to delete.

        Returns:
            True if deletion was successful. False if the chat does not
            exist or the delete/commit failed. On failure the session is
            rolled back and the error is logged with its traceback.
        """
        import logging

        chat = Chat.query.get(chat_id)
        if chat is None:
            return False

        try:
            db.session.delete(chat)
            db.session.commit()
        except Exception:
            # Deliberate best-effort boundary: report failure via the return
            # value, but keep the traceback and leave the session usable.
            logging.getLogger(__name__).exception("Error deleting chat %s", chat_id)
            db.session.rollback()
            return False
        return True
|