diff --git a/.gitignore b/.gitignore
index 9ef8c9b..a919d16 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,4 +14,5 @@ wheels/
.env
*.session
*.session-journal
-oauth.json
\ No newline at end of file
+oauth.json
+resp.json
\ No newline at end of file
diff --git a/app/MusicProvider/SpotifyStrategy.py b/app/MusicProvider/SpotifyStrategy.py
index ee9a453..1583119 100644
--- a/app/MusicProvider/SpotifyStrategy.py
+++ b/app/MusicProvider/SpotifyStrategy.py
@@ -1,42 +1,15 @@
import asyncio
-import base64
import time
import aiohttp
from app.config import config
from app.MusicProvider.Strategy import MusicProviderStrategy
-from app.dependencies import get_session, get_session_context
+from app.dependencies import get_session_context
from sqlalchemy import select, update
from app.models import User, Track
-
-
-def convert_track(track: dict):
- if track['type'] != 'track':
- return None
-
- return Track(
- name=track['name'],
- artist=', '.join(x['name'] for x in track['artists']),
- cover_url=track['album']['images'][0]['url'],
- spotify_id=track['id']
- )
-
-
-async def refresh_token(refresh_token):
- token_headers = {
- "Content-Type": "application/x-www-form-urlencoded",
- 'Authorization': 'Basic ' + config.spotify.encoded
- }
- token_data = {
- "grant_type": "refresh_token",
- "refresh_token": refresh_token
- }
- async with aiohttp.ClientSession() as session:
- resp = await session.post("https://accounts.spotify.com/api/token", data=token_data, headers=token_headers)
- resp = await resp.json()
- return resp['access_token'], resp['expires_in']
+from app.MusicProvider.auth import refresh_token, get_oauth_creds
class SpotifyStrategy(MusicProviderStrategy):
@@ -54,11 +27,15 @@ class SpotifyStrategy(MusicProviderStrategy):
if int(time.time()) < user.spotify_auth['refresh_at']:
return user.spotify_auth['access_token']
- token, expires_in = await refresh_token(user.spotify_auth['refresh_token'])
+ token, expires_in = await refresh_token('https://accounts.spotify.com/api/token',
+ user.spotify_auth['refresh_token'],
+ config.spotify.encoded
+ )
async with get_session_context() as session:
await session.execute(
- update(User).where(User.id == self.user_id).values(spotify_access_token=token,
- spotify_refresh_at=int(time.time()) + int(expires_in))
+ update(User).where(User.id == self.user_id).values(spotify_auth=get_oauth_creds(token,
+ user.spotify_auth['refresh_token'],
+ expires_in))
)
await session.commit()
return token
@@ -74,6 +51,18 @@ class SpotifyStrategy(MusicProviderStrategy):
return None
return await resp.json()
+ @staticmethod
+ def convert_track(track: dict):
+ if track['type'] != 'track':
+ return None
+
+ return Track(
+ name=track['name'],
+ artist=', '.join(x['name'] for x in track['artists']),
+ cover_url=track['album']['images'][0]['url'],
+ spotify_id=track['id']
+ )
+
async def get_tracks(self, token) -> list[Track]:
current, recent = await asyncio.gather(
self.request('/me/player/currently-playing', token),
@@ -81,9 +70,9 @@ class SpotifyStrategy(MusicProviderStrategy):
)
tracks = []
if current:
- tracks.append(convert_track(current['item']))
+ tracks.append(self.convert_track(current['item']))
for item in recent['items']:
- tracks.append(convert_track(item['track']))
+ tracks.append(self.convert_track(item['track']))
tracks = [x for x in tracks if x]
tracks = list(dict.fromkeys(tracks))
@@ -97,5 +86,11 @@ class SpotifyStrategy(MusicProviderStrategy):
)
return resp.scalars().first()
+ def song_link(self, track: Track):
+ return f'Spotify | Other'
+
+ def track_id(self, track: Track):
+ return track.spotify_id
+
__all__ = ['SpotifyStrategy']
diff --git a/app/MusicProvider/Strategy.py b/app/MusicProvider/Strategy.py
index 1dbee3a..15a18ba 100644
--- a/app/MusicProvider/Strategy.py
+++ b/app/MusicProvider/Strategy.py
@@ -17,4 +17,12 @@ class MusicProviderStrategy(ABC):
@abstractmethod
async def fetch_track(self, track: Track) -> Track:
+ pass
+
+ @abstractmethod
+ def song_link(self, track: Track):
+ pass
+
+ @abstractmethod
+ def track_id(self, track: Track):
pass
\ No newline at end of file
diff --git a/app/MusicProvider/YMusicStrategy.py b/app/MusicProvider/YMusicStrategy.py
index 105dae2..eebc84b 100644
--- a/app/MusicProvider/YMusicStrategy.py
+++ b/app/MusicProvider/YMusicStrategy.py
@@ -1,7 +1,11 @@
+import time
+
+from app import config
+from app.MusicProvider.auth import refresh_token, get_oauth_creds
from app.dependencies import get_session_context
from app.models import Track, User
from app.MusicProvider.Strategy import MusicProviderStrategy
-from sqlalchemy import select
+from sqlalchemy import select, update
from yandex_music import ClientAsync, TracksList
@@ -13,18 +17,34 @@ class YandexMusicStrategy(MusicProviderStrategy):
)
return resp.scalars().first()
- async def handle_token(self) -> str:
+ async def handle_token(self) -> str | None:
async with get_session_context() as session:
res = await session.execute(select(User).where(User.id == self.user_id))
user: User = res.scalars().first()
if not user:
return None
- return user.ymusic_token
+
+ if int(time.time()) < user.ymusic_auth['refresh_at']:
+ return user.ymusic_auth['access_token']
+
+ token, expires_in = await refresh_token('https://oauth.yandex.com/token',
+ user.ymusic_auth['refresh_token'],
+ config.ymusic.encoded
+ )
+ async with get_session_context() as session:
+ await session.execute(
+ update(User).where(User.id == self.user_id).values(spotify_auth=get_oauth_creds(token,
+ user.ymusic_auth['refresh_token'],
+ expires_in))
+ )
+ await session.commit()
+ return token
async def get_tracks(self, token) -> list[Track]:
client = await ClientAsync(token).init()
liked: TracksList = await client.users_likes_tracks()
tracks = await client.tracks([x.id for x in liked.tracks[:5]])
+ print(tracks[0])
return [
Track(
name=x.title,
@@ -34,3 +54,9 @@ class YandexMusicStrategy(MusicProviderStrategy):
)
for x in tracks
]
+
+ def song_link(self, track: Track):
+ return f'Yandex music | Other'
+
+ def track_id(self, track: Track):
+ return track.ymusic_id
diff --git a/app/MusicProvider/__init__.py b/app/MusicProvider/__init__.py
index 995dafa..64964f7 100644
--- a/app/MusicProvider/__init__.py
+++ b/app/MusicProvider/__init__.py
@@ -2,3 +2,4 @@ from app.MusicProvider.Context import MusicProviderContext
from app.MusicProvider.SpotifyStrategy import SpotifyStrategy
from app.MusicProvider.YMusicStrategy import YandexMusicStrategy
from app.MusicProvider.Strategy import MusicProviderStrategy
+import app.MusicProvider.auth
diff --git a/app/MusicProvider/auth.py b/app/MusicProvider/auth.py
new file mode 100644
index 0000000..30574fd
--- /dev/null
+++ b/app/MusicProvider/auth.py
@@ -0,0 +1,26 @@
+import time
+
+import aiohttp
+
+
+def get_oauth_creds(token, refresh_token, expires_in):
+ return {
+ 'access_token': token,
+ 'refresh_token': refresh_token,
+ 'refresh_at': int(time.time()) + expires_in
+ }
+
+
+async def refresh_token(endpoint, refresh_token, creds):
+ token_headers = {
+ "Content-Type": "application/x-www-form-urlencoded",
+ 'Authorization': 'Basic ' + creds
+ }
+ token_data = {
+ "grant_type": "refresh_token",
+ "refresh_token": refresh_token
+ }
+ async with aiohttp.ClientSession() as session:
+ resp = await session.post(endpoint, data=token_data, headers=token_headers)
+ resp = await resp.json()
+ return resp['access_token'], resp['expires_in']
diff --git a/app/__init__.py b/app/__init__.py
index 1b385ed..eff8e74 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -11,6 +11,7 @@ from telethon.tl.types import (
DocumentAttributeAudio,
InputPeerSelf, InputDocument
)
+from telethon.tl.custom import InlineBuilder
from telethon import functions
from telethon.utils import get_input_document
import urllib.parse
@@ -81,8 +82,8 @@ async def fetch_file(url) -> bytes:
return await response.read()
-def get_track_links(track_id) -> str:
- return f'Spotify | Other'
+def get_songlink(track_id) -> str:
+ return f'Other'
# TODO: make faster and somehow fix cover not displaying in response
@@ -103,7 +104,7 @@ async def update_dummy_file_cover(cover_url: str):
dummy_file = await client.upload_file(res.getvalue(), file_name='empty.mp3')
-async def build_response(e: events.InlineQuery.Event, track: Track):
+async def build_response(track: Track, track_id: str, links: str):
if not track.telegram_id:
dummy_file = await client.upload_file('empty.mp3')
buttons = [Button.inline('Loading', 'loading')]
@@ -114,11 +115,11 @@ async def build_response(e: events.InlineQuery.Event, track: Track):
file_reference=track.telegram_file_reference
)
buttons = None
- return e.builder.document(
+ return await InlineBuilder(client).document(
file=dummy_file,
title=track.name,
description=track.artist,
- id=track.spotify_id,
+ id=track_id,
mime_type='audio/mpeg',
attributes=[
DocumentAttributeAudio(
@@ -129,21 +130,22 @@ async def build_response(e: events.InlineQuery.Event, track: Track):
waveform=None,
)
],
- text=get_track_links(track.spotify_id),
+ text=links,
buttons=buttons
)
@client.on(events.InlineQuery())
async def query_list(e: events.InlineQuery.Event):
- context = MusicProviderContext(SpotifyStrategy(e.sender_id))
- tracks = (await context.get_tracks())[:5]
+ ctx = MusicProviderContext(YandexMusicStrategy(e.sender_id))
+ tracks = (await ctx.get_tracks())[:5]
result = []
for track in tracks:
- track = await context.get_cached_track(track)
- cache[track.spotify_id] = track
- result.append(await build_response(e, track))
+ track = await ctx.get_cached_track(track)
+ music_id = ctx.strategy.track_id(track)
+ cache[music_id] = track
+ result.append(await build_response(track, music_id, ctx.strategy.song_link(track)))
await e.answer(result)
@@ -217,7 +219,7 @@ async def send_track(e: UpdateBotInlineSend):
return
file = await download_track(track)
- await client.edit_message(e.msg_id, file=file, text=get_track_links(e.id))
+ await client.edit_message(e.msg_id, file=file)
async def main():
diff --git a/app/__main__.py b/app/__main__.py
index 1eadcc5..9d782d5 100644
--- a/app/__main__.py
+++ b/app/__main__.py
@@ -1,7 +1,7 @@
import asyncio
-
from app import main
+
async def run():
await main()
diff --git a/app/callback_listener.py b/app/callback_listener.py
index fce901f..5d7fa9a 100644
--- a/app/callback_listener.py
+++ b/app/callback_listener.py
@@ -12,6 +12,7 @@ import jwt
from app.dependencies import get_session
from app.models.user import User
from config import config, OauthCreds
+from app.MusicProvider.auth import get_oauth_creds
client = TelegramClient('nowplaying_callback', config.api_id, config.api_hash)
@@ -65,12 +66,7 @@ def get_decoded_id(string: str):
async def spotify_callback(code: str, state: str, session: AsyncSession = Depends(get_session)):
user_id = get_decoded_id(state)
token, refresh_token, expires_in = await code_to_token(code, 'https://accounts.spotify.com/api/token', config.spotify)
- creds = {
- 'access_token': token,
- 'refresh_token': refresh_token,
- 'refresh_at': int(time.time()) + expires_in
- }
-
+ creds = get_oauth_creds(token, refresh_token, expires_in)
user = await session.get(User, user_id)
if user:
user.spotify_auth = creds
@@ -88,11 +84,7 @@ async def spotify_callback(code: str, state: str, session: AsyncSession = Depend
async def ym_callback(state: str, code: str, cid: str, session: AsyncSession = Depends(get_session)):
user_id = get_decoded_id(state)
token, refresh_token, expires_in = await code_to_token(code, 'https://oauth.yandex.com/token', config.ymusic)
- creds = {
- 'access_token': token,
- 'refresh_token': refresh_token,
- 'refresh_at': int(time.time()) + expires_in
- }
+ creds = get_oauth_creds(token, refresh_token, expires_in)
user = await session.get(User, user_id)
if user:
user.ymusic_auth = creds
diff --git a/app/models/track.py b/app/models/track.py
index 9dd1809..7e07d9c 100644
--- a/app/models/track.py
+++ b/app/models/track.py
@@ -9,7 +9,6 @@ class Track(Base):
__tablename__ = 'tracks'
id: Mapped[int] = mapped_column(primary_key=True)
- telegram_reference: Mapped[Optional[dict]] = mapped_column(JSON)
telegram_id: Mapped[Optional[int]] = mapped_column(BigInteger)
telegram_access_hash: Mapped[Optional[int]] = mapped_column(BigInteger)
telegram_file_reference: Mapped[Optional[bytes]] = mapped_column(LargeBinary)
@@ -24,13 +23,12 @@ class Track(Base):
used_times: Mapped[int] = mapped_column(Integer, default=1)
def __hash__(self):
- return hash(self.spotify_id)
+ return hash(self.spotify_id or self.ymusic_id)
def __eq__(self, other):
if not isinstance(other, Track):
return NotImplemented
- return self.spotify_id == other.spotify_id
+ return (self.spotify_id or self.ymusic_id) == (other.spotify_id or other.spotify_id)
-
-__all__ = ['Track']
\ No newline at end of file
+__all__ = ['Track']