Source code for poezio.decorators

"""
Module containing various decorators
"""

from __future__ import annotations
from functools import wraps
from inspect import iscoroutinefunction

from typing import (
    cast,
    Any,
    Awaitable,
    Concatenate,
    Callable,
    Protocol,
    ParamSpec,
    TypeVar,
    TYPE_CHECKING,
    overload
)

from slixmpp import JID
from poezio import common

if TYPE_CHECKING:
    from poezio.core.core import Core


P = ParamSpec('P')
C = TypeVar('C', bound=Callable[..., Any])
T = TypeVar('T')
U = TypeVar('U')
V = TypeVar('V')


class HasCore(Protocol):
    core: 'Core'


CoreSelf = TypeVar('CoreSelf', bound=HasCore)


def wrap_before_quoted(
        func: Callable[Concatenate[T, list[str], P], U],
        before: Callable[[str], list[str]]
) -> Callable[Concatenate[T, str, P], U]:
    """
    Wrapper for functions requiring "before" handlers.
    """

    if iscoroutinefunction(func):
        def awrap(self: T, args: str, /, *list_args: P.args, **kwargs: P.kwargs) -> U:
            new_args = before(args)
            return cast(U, func(self, new_args, *list_args, **kwargs))
        return awrap
    else:
        def wrap(self: T, args: str, /, *list_args: P.args, **kwargs: P.kwargs) -> U:
            new_args = before(args)
            return func(self, new_args, *list_args, **kwargs)
        return wrap


def wrap_before_ignored(
        func: Callable[Concatenate[T, P], U],
) -> Callable[Concatenate[T, str, P], U]:
    """
    Wrapper for functions requiring "before" handlers.
    """

    if iscoroutinefunction(func):
        def awrap(self: T, args: str, /, *list_args: P.args, **kwargs: P.kwargs) -> U:
            return cast(U, func(self, *list_args, **kwargs))
        return awrap
    else:
        def wrap(self: T, args: str, /, *list_args: P.args, **kwargs: P.kwargs) -> U:
            return func(self, *list_args, **kwargs)
        return wrap


def wrap_before_cancel_on_false(
        func: Callable[P, U],
        before: Callable[[list[Any], dict[str, Any]], bool],
) -> Callable[P, U | bool]:
    """
    Wrap ``func`` so it only runs when ``before`` approves the call.

    ``before`` receives the call's positional arguments (passed through as
    received, i.e. a tuple) and its keyword arguments. If it returns a falsy
    value, ``func`` is not called and the wrapper returns ``False``.

    When ``func`` is a coroutine function the wrapper is one too (``before``
    then runs when the coroutine is awaited), so the
    ``iscoroutinefunction``-based dispatch of the other decorators in this
    module keeps working.

    :param func: the function to guard.
    :param before: the gate, called with ``(args, kwargs)``.
    :returns: the wrapper, carrying ``func``'s metadata (``functools.wraps``).
    """
    if iscoroutinefunction(func):
        @wraps(func)
        async def awrap(*list_args: P.args, **kwargs: P.kwargs) -> Any:
            if not before(cast(list[Any], list_args), cast(dict[str, Any], kwargs)):
                return False
            return await cast('Awaitable[Any]', func(*list_args, **kwargs))
        # String-form cast: a static typing hint only, nothing evaluated at runtime.
        return cast('Callable[P, U | bool]', awrap)

    @wraps(func)
    def wrap(*list_args: P.args, **kwargs: P.kwargs) -> U | bool:
        if not before(cast(list[Any], list_args), cast(dict[str, Any], kwargs)):
            return False
        return func(*list_args, **kwargs)
    return wrap


@overload
def wrap_after(
        func: Callable[P, Awaitable[V]],
        after: Callable[[V, list[Any], dict[str, Any]], None],
) -> Callable[P, Awaitable[V]]:
    ...


@overload
def wrap_after(
        func: Callable[P, U],
        after: Callable[[U, list[Any], dict[str, Any]], None],
) -> Callable[P, U]:
    ...


def wrap_after(
        func: Callable[P, Awaitable[V]] | Callable[P, U],
        after: Callable[[Any, list[Any], dict[str, Any]], None],
) -> Callable[P, Awaitable[V]] | Callable[P,  U]:
    """
    Wrap ``func`` so that ``after`` is called once ``func`` has returned.

    ``after`` receives the result, the call's positional arguments (passed
    through as received, i.e. a tuple) and its keyword arguments. Its own
    return value is ignored; the wrapper returns ``func``'s result. If
    ``func`` raises, ``after`` is not called.

    When ``func`` is a coroutine function the wrapper is one too, and
    ``after`` receives the awaited result.

    :param func: the function to wrap.
    :param after: the hook to run after each successful call.
    :returns: the wrapper, carrying ``func``'s metadata (``functools.wraps``).
    """
    if iscoroutinefunction(func):
        @wraps(func)
        async def awrap(*list_args: P.args, **kwargs: P.kwargs) -> Any:
            result = await cast('Awaitable[Any]', func(*list_args, **kwargs))
            after(result, cast(list[Any], list_args), cast(dict[str, Any], kwargs))
            return result
        return awrap

    @wraps(func)
    def wrap(*list_args: P.args, **kwargs: P.kwargs) -> Any:
        result = func(*list_args, **kwargs)
        after(result, cast(list[Any], list_args), cast(dict[str, Any], kwargs))
        return result
    return wrap


class RefreshWrapper:
    """Decorators that refresh or update the screen after a call.

    ``core`` is ``None`` until assigned; while unset, decorated functions
    still run but no screen call is made.
    """
    core: 'Core | None'

    def __init__(self) -> None:
        self.core = None

    def _run_after(
            self,
            func: C,
            action: Callable[['Core'], Any],
            when_truthy_only: bool = False,
    ) -> C:
        """Wrap ``func`` so that ``action(self.core)`` runs after each call.

        ``action`` is skipped while ``self.core`` is ``None``, and also when
        ``when_truthy_only`` is set and the call returned a falsy value.
        """
        def after(result: Any, args: list[Any], kwargs: dict[str, Any]) -> Any:
            core = self.core
            if core is None:
                return result
            if when_truthy_only and not result:
                return result
            action(core)
            return result

        return cast(C, wrap_after(func, after=after))

    def conditional(self, func: C) -> C:
        """Refresh the window after ``func``, only if it returned a truthy value."""
        return self._run_after(
            func,
            lambda core: core.refresh_window(),  # pylint: disable=no-member
            when_truthy_only=True,
        )

    def always(self, func: C) -> C:
        """Refresh the window after every call of ``func``."""
        return self._run_after(
            func,
            lambda core: core.refresh_window(),  # pylint: disable=no-member
        )

    def update(self, func: C) -> C:
        """Only update the screen (no full refresh) after every call of ``func``."""
        return self._run_after(
            func,
            lambda core: core.doupdate(),  # pylint: disable=no-member
        )


refresh_wrapper = RefreshWrapper()


[docs] class CommandArgParser: """Modify the string argument of the function into a list of strings containing the right number of extracted arguments, or None if we don’t have enough. """ @staticmethod def raw(func: T) -> T: """Just call the function with a single string, which is the original string untouched """ return func @staticmethod def ignored(func: Callable[Concatenate[T, P], U]) -> Callable[Concatenate[T, str, P], U]: """ Call the function without textual arguments """ return wrap_before_ignored(func) @staticmethod def quoted(mandatory: int, optional: int = 0, defaults: list[Any] | None = None, ignore_trailing_arguments: bool = False) -> Callable[[Callable[Concatenate[T, list[str], P], U]], Callable[Concatenate[T, str, P], U]]: """The function receives a list with a number of arguments that is between the numbers `mandatory` and `optional`. If the string doesn’t contain at least `mandatory` arguments, we return None because the given arguments are invalid. If there are any remaining arguments after `mandatory` and `optional` arguments have been found (and “ignore_trailing_arguments" is not True), we append them to the last argument of the list. An argument is a string (with or without whitespaces) between two quotes ("), or a whitespace separated word (if not inside quotes). The argument `defaults` is a list of strings that are used when an optional argument is missing. For example if we accept one optional argument and none is provided, but we have one value in the `defaults` list, we use that string inplace. The `defaults` list can only replace missing optional arguments, not mandatory ones. And it should not contain more than `mandatory` values. 
Also you cannot Example: This method needs at least one argument, and accepts up to 3 arguments >> @command_args_parser.quoted(1, 2, ['default for first arg'], False) >> def f(args): >> print(args) >> f('coucou les amis') # We have one mandatory and two optional ['coucou', 'les', 'amis'] >> f('"coucou les amis" "PROUT PROUT"') # One mandator and only one optional, # no default for the second ['coucou les amis', 'PROUT PROUT'] >> f('') # Not enough args for mandatory number None >> f('"coucou les potes"') # One mandatory, and use the default value # for the first optional ['coucou les potes, 'default for first arg'] >> f('"un et demi" deux trois quatre cinq six') # We have three trailing arguments ['un et demi', 'deux', 'trois quatre cinq six'] """ default_args_outer = defaults or [] def first(func: Callable[Concatenate[T, list[str], P], U]) -> Callable[Concatenate[T, str, P], U]: def before(params: str) -> list[str]: default_args = default_args_outer cmdargs = params if cmdargs and cmdargs.strip(): split_args = common.shell_split(cmdargs) else: split_args = [] if len(split_args) < mandatory: return [] res, split_args = split_args[:mandatory], split_args[ mandatory:] if optional == -1: opt_args = split_args[:] else: opt_args = split_args[:optional] if opt_args: res += opt_args split_args = split_args[len(opt_args):] default_args = default_args[len(opt_args):] res += default_args if split_args and res and not ignore_trailing_arguments: res[-1] += " " + " ".join(split_args) return res wrap = wrap_before_quoted(func, before=before) return wrap return first
command_args_parser = CommandArgParser()


def deny_anonymous(func: C) -> C:
    """Decorator to disable commands when using an anonymous account.

    When ``core.xmpp.anon`` is set, an information message is shown, the
    command is not run and the wrapper returns ``False``.
    """

    def before(args: list[Any], kwargs: dict[str, Any]) -> bool:
        # The first positional argument must expose ``.core``
        # (presumably the command's ``self``).
        core = args[0].core
        if core.xmpp.anon:
            core.information(
                'This command is not available for anonymous accounts.',
                'Info'
            )
            return False
        return True
    return cast(C, wrap_before_cancel_on_false(func, before=before))


def command_jid_from_context(
        func: Callable[[CoreSelf, JID | None], U]
) -> Callable[[CoreSelf, list[str]], U | None]:
    """Decorator to get the ConversationsTab's JID, unless specified.

    The wrapper takes the parsed argument list of a command. If its first
    element is non-empty, it is parsed as a JID. Otherwise the JID comes
    from the current tab: the selected contact/resource of the roster tab,
    or the bare JID of a conversation tab.

    If the JID is invalid or none can be found, a single error message is
    shown and the wrapper returns ``None`` without calling ``func``. When
    ``func`` is a coroutine function, the wrapper is one too.
    """
    # Imported lazily; presumably to avoid import cycles with poezio.tabs.
    from slixmpp import JID, InvalidJID
    from poezio import tabs
    from poezio.contact import Contact, Resource

    def jid_from_current_tab(self: CoreSelf) -> JID | None:
        """Guess the target JID from the current tab, or None."""
        current_tab = self.core.tabs.current_tab
        jid = None
        if isinstance(current_tab, tabs.RosterInfoTab):
            roster_tab = self.core.tabs.by_name_and_class(
                'Roster', tabs.RosterInfoTab,
            )
            if roster_tab is not None:
                item = roster_tab.selected_row
                if isinstance(item, Contact):
                    jid = item.bare_jid
                elif isinstance(item, Resource):
                    jid = JID(item.jid)
        chattabs = (
            tabs.ConversationTab,
            tabs.StaticConversationTab,
            tabs.DynamicConversationTab,
        )
        if isinstance(current_tab, chattabs):
            jid = JID(current_tab.jid.bare)
        return jid

    def resolve(self: CoreSelf, args: list[str]) -> JID | None:
        """Return the target JID, or None once an error has been reported."""
        if not hasattr(self, 'core'):
            # Nowhere to report an error: just skip the command.
            return None
        if args and args[0]:
            try:
                # args[0] is the whole first argument, not its first character.
                return JID(args[0])
            except InvalidJID:
                self.core.information(f'Invalid JID {args[0]}', 'Error')
                return None
        jid = jid_from_current_tab(self)
        if jid is None:
            self.core.information(f'No specified JID to {func.__name__}', 'Error')
        return jid

    if iscoroutinefunction(func):
        @wraps(func)
        async def awrap(self: CoreSelf, args: list[str]) -> Any:
            jid = resolve(self, args)
            if jid is None:
                return None
            return await cast('Awaitable[Any]', func(self, jid))
        # String-form cast: a static typing hint only, nothing evaluated at runtime.
        return cast('Callable[[CoreSelf, list[str]], U | None]', awrap)

    @wraps(func)
    def wrap(self: CoreSelf, args: list[str]) -> U | None:
        jid = resolve(self, args)
        if jid is None:
            return None
        return func(self, jid)
    return wrap