[REFACTOR] Enhancements and Code Cleanup #168

Open · wants to merge 8 commits into base: python · Changes from 4 commits
2 changes: 1 addition & 1 deletion WebStreamer/__init__.py
@@ -3,8 +3,8 @@


import time
from .vars import Var
from WebStreamer.bot.clients import StreamBot
from .vars import Var

__version__ = "2.2.4"
StartTime = time.time()
10 changes: 5 additions & 5 deletions WebStreamer/__main__.py
@@ -4,13 +4,13 @@
import sys
import asyncio
import logging
from .vars import Var
from aiohttp import web
from pyrogram import idle
from WebStreamer import utils
from WebStreamer import StreamBot
from WebStreamer.server import web_server
from WebStreamer.bot.clients import initialize_clients
from .vars import Var


logging.basicConfig(
@@ -43,10 +43,10 @@ async def start_services():
await server.setup()
await web.TCPSite(server, Var.BIND_ADDRESS, Var.PORT).start()
logging.info("Service Started")
logging.info("bot =>> {}".format(bot_info.first_name))
logging.info("bot =>> %s", bot_info.first_name)
if bot_info.dc_id:
logging.info("DC ID =>> {}".format(str(bot_info.dc_id)))
logging.info("URL =>> {}".format(Var.URL))
logging.info("DC ID =>> %d", bot_info.dc_id)
logging.info("URL =>> %s", Var.URL)
await idle()

async def cleanup():
@@ -63,4 +63,4 @@ async def cleanup():
finally:
loop.run_until_complete(cleanup())
loop.stop()
logging.info("Stopped Services")
logging.info("Stopped Services")
9 changes: 5 additions & 4 deletions WebStreamer/bot/__init__.py
@@ -4,16 +4,17 @@

import os
import os.path
from ..vars import Var
import logging
from typing import Dict
from pyrogram import Client
from ..vars import Var

logger = logging.getLogger("bot")

sessions_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sessions")
if Var.USE_SESSION_FILE:
logger.info("Using session files")
logger.info("Session folder path: {}".format(sessions_dir))
logger.info("Session folder path: %s", sessions_dir)
if not os.path.isdir(sessions_dir):
os.makedirs(sessions_dir)

@@ -29,5 +30,5 @@
in_memory=not Var.USE_SESSION_FILE,
)

multi_clients = {}
work_loads = {}
multi_clients: Dict[int, Client] = {}
work_loads: Dict[int, int] = {}
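
The module-level registries gain type annotations: `multi_clients` maps a numeric client ID to a pyrogram `Client`, and `work_loads` maps the same ID to a request counter. A small, self-contained sketch of why the annotations help, using a stand-in class instead of pyrogram (all names hypothetical):

```python
from typing import Dict

class FakeClient:
    """Stand-in for pyrogram.Client in this illustration."""
    def __init__(self, name: str) -> None:
        self.name = name

# Typed module-level registries: client ID -> client, client ID -> load.
multi_clients: Dict[int, FakeClient] = {}
work_loads: Dict[int, int] = {}

multi_clients[0] = FakeClient("primary")
work_loads[0] = 0

# A static checker such as mypy would reject this line,
# because work_loads values are declared as int:
# work_loads[0] = "busy"
```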
12 changes: 6 additions & 6 deletions WebStreamer/bot/clients.py
@@ -4,8 +4,8 @@
import asyncio
import logging
from os import environ
from ..vars import Var
from pyrogram import Client
from ..vars import Var
from . import multi_clients, work_loads, sessions_dir, StreamBot

logger = logging.getLogger("multi_client")
@@ -24,10 +24,10 @@ async def initialize_clients():
if not all_tokens:
logger.info("No additional clients found, using default client")
return

async def start_client(client_id, token):
try:
logger.info(f"Starting - Client {client_id}")
logger.info("Starting - Client %s", client_id)
if client_id == len(all_tokens):
await asyncio.sleep(2)
print("This will take some time, please wait...")
@@ -44,12 +44,12 @@ async def start_client(client_id, token):
work_loads[client_id] = 0
return client_id, client
except Exception:
logger.error(f"Failed starting Client - {client_id} Error:", exc_info=True)
logger.error("Failed starting Client - %s Error:", client_id, exc_info=True)

clients = await asyncio.gather(*[start_client(i, token) for i, token in all_tokens.items()])
multi_clients.update(dict(clients))
if len(multi_clients) != 1:
Var.MULTI_CLIENT = True
logger.info("Multi-client mode enabled")
else:
logger.info("No additional clients were initialized, using default client")
logger.info("No additional clients were initialized, using default client")
2 changes: 1 addition & 1 deletion WebStreamer/bot/plugins/start.py
@@ -4,7 +4,7 @@
from pyrogram import filters
from pyrogram.types import Message

from WebStreamer.vars import Var
from WebStreamer.vars import Var
from WebStreamer.bot import StreamBot

@StreamBot.on_message(filters.command(["start", "help"]) & filters.private)
29 changes: 12 additions & 17 deletions WebStreamer/bot/plugins/stream.py
@@ -1,14 +1,12 @@
# This file is a part of TG-FileStreamBot
# Coding : Jyothis Jayanth [@EverythingSuckz]

import logging
from pyrogram import filters, errors
from WebStreamer.vars import Var
from urllib.parse import quote_plus
from WebStreamer.bot import StreamBot, logger
from WebStreamer.utils import get_hash, get_name
from pyrogram.enums.parse_mode import ParseMode
from pyrogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton
from WebStreamer.vars import Var
from WebStreamer.bot import StreamBot, logger
from WebStreamer.utils import get_hash, get_mimetype


@StreamBot.on_message(
@@ -30,25 +28,22 @@ async def media_receive_handler(_, m: Message):
return await m.reply("You are not <b>allowed to use</b> this <a href='https://github.com/EverythingSuckz/TG-FileStreamBot'>bot</a>.", quote=True)
log_msg = await m.forward(chat_id=Var.BIN_CHANNEL)
file_hash = get_hash(log_msg, Var.HASH_LENGTH)
stream_link = f"{Var.URL}{log_msg.id}/{quote_plus(get_name(m))}?hash={file_hash}"
short_link = f"{Var.URL}{file_hash}{log_msg.id}"
logger.info(f"Generated link: {stream_link} for {m.from_user.first_name}")
mimetype = get_mimetype(log_msg)
stream_link = f"{Var.URL}{log_msg.id}?hash={file_hash}"
logger.info("Generated link: %s for %s", stream_link, m.from_user.first_name)
markup = [InlineKeyboardButton("Download", url=stream_link+"&d=true")]
if set(mimetype.split("/")) & {"video","audio","pdf"}:
markup.append(InlineKeyboardButton("Stream", url=stream_link))
try:
await m.reply_text(
text="<code>{}</code>\n(<a href='{}'>shortened</a>)".format(
stream_link, short_link
),
text=f"<code>{stream_link}</code>",
quote=True,
parse_mode=ParseMode.HTML,
reply_markup=InlineKeyboardMarkup(
[[InlineKeyboardButton("Open", url=stream_link)]]
),
reply_markup=InlineKeyboardMarkup([markup]),
)
except errors.ButtonUrlInvalid:
await m.reply_text(
text="<code>{}</code>\n\nshortened: {})".format(
stream_link, short_link
),
text=f"<code>{stream_link}</code>)",
quote=True,
parse_mode=ParseMode.HTML,
)
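
The handler now emits a single canonical link of the form `{URL}{message_id}?hash={hash}` (the separate shortened link is gone), always attaches a Download button pointing at `&d=true`, and adds a Stream button only when the detected mime type looks playable in a browser. A small sketch of that decision logic with hypothetical values, using plain tuples instead of `InlineKeyboardButton`:

```python
from urllib.parse import urlencode

def build_buttons(base_url: str, message_id: int, file_hash: str, mimetype: str):
    """Return (stream_link, button_rows) for a stored file.

    Hypothetical helper mirroring the logic in the diff: Download is
    always present, Stream only for mime types a browser can render.
    """
    stream_link = f"{base_url}{message_id}?{urlencode({'hash': file_hash})}"
    row = [("Download", stream_link + "&d=true")]
    # e.g. "video/mp4" -> {"video", "mp4"}; any overlap with the allow
    # set suggests the file can be streamed inline.
    if set(mimetype.split("/")) & {"video", "audio", "pdf"}:
        row.append(("Stream", stream_link))
    return stream_link, [row]

link, rows = build_buttons("https://example.invalid/", 42, "a1b2c3", "video/mp4")
print(link)   # https://example.invalid/42?hash=a1b2c3
print(rows)   # [[('Download', '...&d=true'), ('Stream', '...')]]
```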
2 changes: 1 addition & 1 deletion WebStreamer/server/exceptions.py
@@ -3,4 +3,4 @@ class InvalidHash(Exception):
message = "Invalid hash"

class FIleNotFound(Exception):
message = "File not found"
message = "File not found"
38 changes: 15 additions & 23 deletions WebStreamer/server/stream_routes.py
@@ -1,11 +1,9 @@
# Taken from megadlbot_oss <https://github.com/eyaadh/megadlbot_oss/blob/master/mega/webserver/routes.py>
# Thanks to Eyaadh <https://github.com/eyaadh>

import re
import time
import math
import logging
import secrets
import mimetypes
from aiohttp import web
from aiohttp.http_exceptions import BadStatusLine
@@ -37,17 +35,11 @@ async def root_route_handler(_):
)


@routes.get(r"/{path:\S+}", allow_head=True)
@routes.get(r"/{message_id:\d+}", allow_head=True)
async def stream_handler(request: web.Request):
try:
path = request.match_info["path"]
match = re.search(r"^([0-9a-f]{%s})(\d+)$" % (Var.HASH_LENGTH), path)
if match:
secure_hash = match.group(1)
message_id = int(match.group(2))
else:
message_id = int(re.search(r"(\d+)(?:\/\S+)?", path).group(1))
secure_hash = request.rel_url.query.get("hash")
message_id = int(request.match_info["message_id"])
secure_hash = request.rel_url.query.get("hash")
return await media_streamer(request, message_id, secure_hash)
except InvalidHash as e:
raise web.HTTPForbidden(text=e.message)
@@ -63,29 +55,29 @@ async def stream_handler(request: web.Request):

async def media_streamer(request: web.Request, message_id: int, secure_hash: str):
range_header = request.headers.get("Range", 0)

index = min(work_loads, key=work_loads.get)
faster_client = multi_clients[index]

if Var.MULTI_CLIENT:
logger.info(f"Client {index} is now serving {request.remote}")
logger.info("Client %d is now serving %s", index, request.remote)

if faster_client in class_cache:
tg_connect = class_cache[faster_client]
logger.debug(f"Using cached ByteStreamer object for client {index}")
logger.debug("Using cached ByteStreamer object for client %d", index)
else:
logger.debug(f"Creating new ByteStreamer object for client {index}")
logger.debug("Creating new ByteStreamer object for client %d", index)
tg_connect = utils.ByteStreamer(faster_client)
class_cache[faster_client] = tg_connect
logger.debug("before calling get_file_properties")
file_id = await tg_connect.get_file_properties(message_id)
logger.debug("after calling get_file_properties")


if utils.get_hash(file_id.unique_id, Var.HASH_LENGTH) != secure_hash:
logger.debug(f"Invalid hash for message with ID {message_id}")
logger.debug("Invalid hash for message with ID %d", message_id)
raise InvalidHash

file_size = file_id.file_size

if range_header:
@@ -117,13 +109,13 @@ async def media_streamer(request: web.Request, message_id: int, secure_hash: str
)
mime_type = file_id.mime_type
file_name = utils.get_name(file_id)
disposition = "attachment"
disposition = "inline"

if not mime_type:
mime_type = mimetypes.guess_type(file_name)[0] or "application/octet-stream"

if "video/" in mime_type or "audio/" in mime_type or "/html" in mime_type:
disposition = "inline"
if request.rel_url.query.get("d") == "true":
disposition = "attachment"

return web.Response(
status=206 if range_header else 200,
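
`stream_routes.py` replaces the catch-all `/{path:\S+}` route and its regex parsing with a numeric `/{message_id:\d+}` route plus a `hash` query parameter, and flips the default `Content-Disposition` to `inline`, with `?d=true` forcing a download. A stripped-down aiohttp sketch of that routing and header logic (assumed names, placeholder body instead of the real Telegram stream):

```python
import mimetypes
from aiohttp import web

routes = web.RouteTableDef()

@routes.get(r"/{message_id:\d+}", allow_head=True)
async def stream_handler(request: web.Request) -> web.Response:
    message_id = int(request.match_info["message_id"])  # digits guaranteed by the route
    secure_hash = request.rel_url.query.get("hash")     # e.g. ?hash=a1b2c3
    if not secure_hash:
        raise web.HTTPForbidden(text="Invalid hash")

    # Placeholder payload; the real handler streams the Telegram file.
    file_name = f"file_{message_id}.mp4"
    mime_type = mimetypes.guess_type(file_name)[0] or "application/octet-stream"

    # Serve inline by default; ?d=true forces a download dialog.
    disposition = "attachment" if request.rel_url.query.get("d") == "true" else "inline"

    return web.Response(
        body=b"",
        content_type=mime_type,
        headers={"Content-Disposition": f'{disposition}; filename="{file_name}"'},
    )

app = web.Application()
app.add_routes(routes)
# web.run_app(app, port=8080)
```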
4 changes: 2 additions & 2 deletions WebStreamer/utils/__init__.py
@@ -3,5 +3,5 @@

from .keepalive import ping_server
from .time_format import get_readable_time
from .file_properties import get_hash, get_name
from .custom_dl import ByteStreamer
from .file_properties import get_hash, get_name, get_mimetype
from .custom_dl import ByteStreamer
42 changes: 19 additions & 23 deletions WebStreamer/utils/custom_dl.py
@@ -1,15 +1,14 @@
import math
import asyncio
import logging
from WebStreamer import Var
from typing import Dict, Union
from WebStreamer.bot import work_loads
from typing import AsyncGenerator, Dict, Union
from pyrogram import Client, utils, raw
from .file_properties import get_file_ids
from pyrogram.session import Session, Auth
from pyrogram.errors import AuthBytesInvalid
from WebStreamer.server.exceptions import FIleNotFound
from pyrogram.file_id import FileId, FileType, ThumbnailSource
from WebStreamer.server.exceptions import FIleNotFound
from WebStreamer import Var
from WebStreamer.bot import work_loads
from .file_properties import get_file_ids

logger = logging.getLogger("streamer")

@@ -42,21 +41,21 @@ async def get_file_properties(self, message_id: int) -> FileId:
"""
if message_id not in self.cached_file_ids:
await self.generate_file_properties(message_id)
logger.debug(f"Cached file properties for message with ID {message_id}")
logger.debug("Cached file properties for message with ID %d", message_id)
return self.cached_file_ids[message_id]

async def generate_file_properties(self, message_id: int) -> FileId:
"""
Generates the properties of a media file on a specific message.
Returns the properties in a FileId class.
"""
file_id = await get_file_ids(self.client, Var.BIN_CHANNEL, message_id)
logger.debug(f"Generated file ID and Unique ID for message with ID {message_id}")
logger.debug("Generated file ID and Unique ID for message with ID %d", message_id)
if not file_id:
logger.debug(f"Message with ID {message_id} not found")
logger.debug("Message with ID %d not found", message_id)
raise FIleNotFound
self.cached_file_ids[message_id] = file_id
logger.debug(f"Cached media message with ID {message_id}")
logger.debug("Cached media message with ID %d", message_id)
return self.cached_file_ids[message_id]

async def generate_media_session(self, client: Client, file_id: FileId) -> Session:
@@ -93,9 +92,7 @@ async def generate_media_session(self, client: Client, file_id: FileId) -> Sessi
)
break
except AuthBytesInvalid:
logger.debug(
f"Invalid authorization bytes for DC {file_id.dc_id}"
)
logger.debug("Invalid authorization bytes for DC %d", file_id.dc_id)
continue
else:
await media_session.stop()
@@ -109,10 +106,10 @@ async def generate_media_session(self, client: Client, file_id: FileId) -> Sessi
is_media=True,
)
await media_session.start()
logger.debug(f"Created media session for DC {file_id.dc_id}")
logger.debug("Created media session for DC %d", file_id.dc_id)
client.media_sessions[file_id.dc_id] = media_session
else:
logger.debug(f"Using cached media session for DC {file_id.dc_id}")
logger.debug("Using cached media session for DC %d", file_id.dc_id)
return media_session


@@ -141,9 +138,8 @@ async def get_location(file_id: FileId) -> Union[raw.types.InputPhotoFileLocatio

location = raw.types.InputPeerPhotoFileLocation(
peer=peer,
volume_id=file_id.volume_id,
local_id=file_id.local_id,
big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG,
photo_id=file_id.media_id,
big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG
)
elif file_type == FileType.PHOTO:
location = raw.types.InputPhotoFileLocation(
@@ -170,15 +166,15 @@ async def yield_file(
last_part_cut: int,
part_count: int,
chunk_size: int,
) -> Union[str, None]:
) -> AsyncGenerator[bytes, None]:
"""
Custom generator that yields the bytes of the media file.
Modded from <https://github.com/eyaadh/megadlbot_oss/blob/master/mega/telegram/utils/custom_download.py#L20>
Thanks to Eyaadh <https://github.com/eyaadh>
"""
client = self.client
work_loads[index] += 1
logger.debug(f"Starting to yielding file with client {index}.")
logger.debug("Starting to yielding file with client %d.", index)
media_session = await self.generate_media_session(client, file_id)

current_part = 1
@@ -218,10 +214,10 @@ async def yield_file(
except (TimeoutError, AttributeError):
pass
finally:
logger.debug(f"Finished yielding file with {current_part} parts.")
logger.debug("Finished yielding file with %d parts.", current_part)
work_loads[index] -= 1


async def clean_cache(self) -> None:
"""
function to clean the cache to reduce memory usage
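
In `custom_dl.py` the `yield_file` coroutine is now annotated as `AsyncGenerator[bytes, None]`, which matches what it actually does: stream the file in fixed-size chunks, trimming the first and last chunks so a Range request receives exactly the requested bytes. A minimal, self-contained sketch of that shape (in-memory bytes instead of Telegram calls; all names hypothetical):

```python
import asyncio
from typing import AsyncGenerator

async def yield_file(data: bytes, first_cut: int, last_cut: int,
                     part_count: int, chunk_size: int) -> AsyncGenerator[bytes, None]:
    """Yield chunk_size slices of data, trimming the first and last parts.

    Mirrors the generator shape in the diff: the first chunk is sliced
    from first_cut and the last chunk is cut at last_cut, so an HTTP
    Range request gets only the bytes it asked for.
    """
    for part in range(part_count):
        chunk = data[part * chunk_size:(part + 1) * chunk_size]
        if part_count == 1:
            yield chunk[first_cut:last_cut]
        elif part == 0:
            yield chunk[first_cut:]
        elif part == part_count - 1:
            yield chunk[:last_cut]
        else:
            yield chunk
        await asyncio.sleep(0)  # cooperative yield, stand-in for network I/O

async def main() -> None:
    body = b"".join(bytes([i]) * 4 for i in range(4))  # 16 bytes, 4 parts
    async for piece in yield_file(body, first_cut=2, last_cut=3,
                                  part_count=4, chunk_size=4):
        print(piece)

asyncio.run(main())
```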