Add default service selection
This commit is contained in:
parent
b840513545
commit
2e7971863c
5 changed files with 80 additions and 10 deletions
34
alembic/versions/9f96b664be50_add_default_field.py
Normal file
34
alembic/versions/9f96b664be50_add_default_field.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
"""add default field
|
||||
|
||||
Revision ID: 9f96b664be50
|
||||
Revises: 59bdc35f510c
|
||||
Create Date: 2025-04-10 21:33:30.353105
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9f96b664be50'
|
||||
down_revision: Union[str, None] = '59bdc35f510c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('tracks', 'telegram_reference')
|
||||
op.add_column('users', sa.Column('default', sa.String(), nullable=False))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('users', 'default')
|
||||
op.add_column('tracks', sa.Column('telegram_reference', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True))
|
||||
# ### end Alembic commands ###
|
|
@ -18,14 +18,14 @@ import urllib.parse
|
|||
from mutagen.id3 import ID3, APIC
|
||||
import logging
|
||||
from cachetools import LRUCache
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
import jwt
|
||||
|
||||
|
||||
from app.MusicProvider import MusicProviderContext, SpotifyStrategy, YandexMusicStrategy
|
||||
from app.config import config
|
||||
from app.dependencies import get_session, get_session_context
|
||||
from app.models import Track
|
||||
from app.models import Track, User
|
||||
from app.youtube_api import name_to_youtube, download_youtube
|
||||
|
||||
logging.basicConfig(
|
||||
|
@ -40,6 +40,11 @@ client.parse_mode = 'html'
|
|||
cache = LRUCache(maxsize=100)
|
||||
|
||||
|
||||
async def get_user(user_id):
|
||||
async with get_session_context() as session:
|
||||
return await session.scalar(select(User).where(User.id == user_id))
|
||||
|
||||
|
||||
def get_spotify_link(user_id) -> str:
|
||||
params = {
|
||||
'client_id': config.spotify.client_id,
|
||||
|
@ -75,6 +80,30 @@ async def start(e: events.NewMessage.Event):
|
|||
buttons=buttons)
|
||||
|
||||
|
||||
@client.on(events.NewMessage(pattern='/default'))
|
||||
async def change_default(e: events.NewMessage.Event):
|
||||
user = await get_user(e.chat_id)
|
||||
if not user:
|
||||
return await e.respond('Please link your account first')
|
||||
buttons = []
|
||||
if user.spotify_auth:
|
||||
buttons.append(Button.inline('Spotify', 'default_spotify'))
|
||||
if user.ymusic_auth:
|
||||
buttons.append(Button.inline('Yandex music', 'default_ymusic'))
|
||||
|
||||
await e.respond('Select service you want to use as default', buttons=buttons)
|
||||
|
||||
|
||||
@client.on(events.CallbackQuery(pattern='default_*'))
|
||||
async def set_default(e: events.CallbackQuery.Event):
|
||||
async with get_session_context() as session:
|
||||
await session.execute(
|
||||
update(User).where(User.id == e.sender_id).values(default=str(e.data).split('_')[1])
|
||||
)
|
||||
await session.commit()
|
||||
await e.respond('Default service updated')
|
||||
|
||||
|
||||
async def fetch_file(url) -> bytes:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
|
@ -82,10 +111,6 @@ async def fetch_file(url) -> bytes:
|
|||
return await response.read()
|
||||
|
||||
|
||||
def get_songlink(track_id) -> str:
|
||||
return f'<a href="https://song.link/s/{track_id}">Other</a>'
|
||||
|
||||
|
||||
# TODO: make faster and somehow fix cover not displaying in response
|
||||
async def update_dummy_file_cover(cover_url: str):
|
||||
cover = await fetch_file(cover_url)
|
||||
|
@ -137,7 +162,13 @@ async def build_response(track: Track, track_id: str, links: str):
|
|||
|
||||
@client.on(events.InlineQuery())
|
||||
async def query_list(e: events.InlineQuery.Event):
|
||||
ctx = MusicProviderContext(YandexMusicStrategy(e.sender_id))
|
||||
user = await get_user(e.sender_id)
|
||||
if not user:
|
||||
return await e.answer(switch_pm='Link account first', switch_pm_param='link')
|
||||
if user.default == 'spotify' and user.spotify_auth:
|
||||
ctx = MusicProviderContext(SpotifyStrategy(e.sender_id))
|
||||
else:
|
||||
ctx = MusicProviderContext(YandexMusicStrategy(e.sender_id))
|
||||
tracks = (await ctx.get_tracks())[:5]
|
||||
result = []
|
||||
|
||||
|
|
|
@ -70,9 +70,11 @@ async def spotify_callback(code: str, state: str, session: AsyncSession = Depend
|
|||
user = await session.get(User, user_id)
|
||||
if user:
|
||||
user.spotify_auth = creds
|
||||
user.default = 'spotify'
|
||||
else:
|
||||
user = User(id=user_id,
|
||||
spotify_auth=creds
|
||||
spotify_auth=creds,
|
||||
default='spotify'
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
@ -88,9 +90,11 @@ async def ym_callback(state: str, code: str, cid: str, session: AsyncSession = D
|
|||
user = await session.get(User, user_id)
|
||||
if user:
|
||||
user.ymusic_auth = creds
|
||||
user.default = 'ymusic'
|
||||
else:
|
||||
user = User(id=user_id,
|
||||
ymusic_auth=creds
|
||||
ymusic_auth=creds,
|
||||
default='ymusic'
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
|
|
@ -12,5 +12,7 @@ class User(Base):
|
|||
spotify_auth: Mapped[dict] = mapped_column(JSON, default={})
|
||||
ymusic_auth: Mapped[dict] = mapped_column(JSON, default={})
|
||||
|
||||
default: Mapped[str]
|
||||
|
||||
|
||||
__all__ = ['User']
|
||||
|
|
|
@ -13,7 +13,6 @@ ytmusic = YTMusic('oauth.json', oauth_credentials=OAuthCredentials(client_id=con
|
|||
|
||||
def name_to_youtube(name: str):
|
||||
results = ytmusic.search(name, 'songs', limit=5)
|
||||
print(results)
|
||||
return results[0]['videoId']
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue