From 2e7971863c359532c7e7f4b6cd9e565ee8ce10e2 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 10 Apr 2025 21:52:03 +0300 Subject: [PATCH] Add default service selection --- .../9f96b664be50_add_default_field.py | 34 ++++++++++++++ app/__init__.py | 45 ++++++++++++++++--- app/callback_listener.py | 8 +++- app/models/user.py | 2 + app/youtube_api.py | 1 - 5 files changed, 80 insertions(+), 10 deletions(-) create mode 100644 alembic/versions/9f96b664be50_add_default_field.py diff --git a/alembic/versions/9f96b664be50_add_default_field.py b/alembic/versions/9f96b664be50_add_default_field.py new file mode 100644 index 0000000..1262fd8 --- /dev/null +++ b/alembic/versions/9f96b664be50_add_default_field.py @@ -0,0 +1,34 @@ +"""add default field + +Revision ID: 9f96b664be50 +Revises: 59bdc35f510c +Create Date: 2025-04-10 21:33:30.353105 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '9f96b664be50' +down_revision: Union[str, None] = '59bdc35f510c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('tracks', 'telegram_reference') + op.add_column('users', sa.Column('default', sa.String(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('users', 'default') + op.add_column('tracks', sa.Column('telegram_reference', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True)) + # ### end Alembic commands ### diff --git a/app/__init__.py b/app/__init__.py index eff8e74..31223fc 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -18,14 +18,14 @@ import urllib.parse from mutagen.id3 import ID3, APIC import logging from cachetools import LRUCache -from sqlalchemy import select +from sqlalchemy import select, update import jwt from app.MusicProvider import MusicProviderContext, SpotifyStrategy, YandexMusicStrategy from app.config import config from app.dependencies import get_session, get_session_context -from app.models import Track +from app.models import Track, User from app.youtube_api import name_to_youtube, download_youtube logging.basicConfig( @@ -40,6 +40,11 @@ client.parse_mode = 'html' cache = LRUCache(maxsize=100) +async def get_user(user_id): + async with get_session_context() as session: + return await session.scalar(select(User).where(User.id == user_id)) + + def get_spotify_link(user_id) -> str: params = { 'client_id': config.spotify.client_id, @@ -75,6 +80,30 @@ async def start(e: events.NewMessage.Event): buttons=buttons) +@client.on(events.NewMessage(pattern='/default')) +async def change_default(e: events.NewMessage.Event): + user = await get_user(e.chat_id) + if not user: + return await e.respond('Please link your account first') + buttons = [] + if user.spotify_auth: + buttons.append(Button.inline('Spotify', 'default_spotify')) + if user.ymusic_auth: + buttons.append(Button.inline('Yandex music', 'default_ymusic')) + + await e.respond('Select service you want to use as default', buttons=buttons) + + +@client.on(events.CallbackQuery(pattern='default_*')) +async def set_default(e: events.CallbackQuery.Event): + async with get_session_context() as session: + await session.execute( + update(User).where(User.id == e.sender_id).values(default=str(e.data).split('_')[1]) + ) + await session.commit() + await e.respond('Default service updated') + + async def fetch_file(url) -> bytes: async with aiohttp.ClientSession() as session: async with session.get(url) as response: @@ -82,10 +111,6 @@ async def fetch_file(url) -> bytes: return await response.read() -def get_songlink(track_id) -> str: - return f'Other' - - # TODO: make faster and somehow fix cover not displaying in response async def update_dummy_file_cover(cover_url: str): cover = await fetch_file(cover_url) @@ -137,7 +162,13 @@ async def build_response(track: Track, track_id: str, links: str): @client.on(events.InlineQuery()) async def query_list(e: events.InlineQuery.Event): - ctx = MusicProviderContext(YandexMusicStrategy(e.sender_id)) + user = await get_user(e.sender_id) + if not user: + return await e.answer(switch_pm='Link account first', switch_pm_param='link') + if user.default == 'spotify' and user.spotify_auth: + ctx = MusicProviderContext(SpotifyStrategy(e.sender_id)) + else: + ctx = MusicProviderContext(YandexMusicStrategy(e.sender_id)) tracks = (await ctx.get_tracks())[:5] result = [] diff --git a/app/callback_listener.py b/app/callback_listener.py index 5d7fa9a..ac77daf 100644 --- a/app/callback_listener.py +++ b/app/callback_listener.py @@ -70,9 +70,11 @@ async def spotify_callback(code: str, state: str, session: AsyncSession = Depend user = await session.get(User, user_id) if user: user.spotify_auth = creds + user.default = 'spotify' else: user = User(id=user_id, - spotify_auth=creds + spotify_auth=creds, + default='spotify' ) session.add(user) await session.commit() @@ -88,9 +90,11 @@ async def ym_callback(state: str, code: str, cid: str, session: AsyncSession = D user = await session.get(User, user_id) if user: user.ymusic_auth = creds + user.default = 'ymusic' else: user = User(id=user_id, - ymusic_auth=creds + ymusic_auth=creds, + default='ymusic' ) session.add(user) await session.commit() diff --git a/app/models/user.py b/app/models/user.py index f38f146..f5d5b21 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -12,5 +12,7 @@ class User(Base): spotify_auth: Mapped[dict] = mapped_column(JSON, default={}) ymusic_auth: Mapped[dict] = mapped_column(JSON, default={}) + default: Mapped[str] + __all__ = ['User'] diff --git a/app/youtube_api.py b/app/youtube_api.py index 67296e8..41cdbf6 100644 --- a/app/youtube_api.py +++ b/app/youtube_api.py @@ -13,7 +13,6 @@ ytmusic = YTMusic('oauth.json', oauth_credentials=OAuthCredentials(client_id=con def name_to_youtube(name: str): results = ytmusic.search(name, 'songs', limit=5) - print(results) return results[0]['videoId']