aboutsummaryrefslogtreecommitdiff
blob: a8a3ad4fbfa4ab58c2f365e2f6913441433700b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# Copyright 2018-2021 Gentoo Authors
# Distributed under the terms of the GNU General Public License v2

# Public names of this module; everything else defined here is a private
# (underscore-prefixed) helper of the retry decorator.
__all__ = (
    "RetryError",
    "retry",
)

import functools

from portage.exception import PortageException
from portage.util.futures import asyncio


class RetryError(PortageException):
    """
    Signals that a retried operation never succeeded. The retry machinery
    chains the last underlying failure (if any) as __cause__.
    """

    def __init__(self):
        super().__init__("retry error")


def retry(
    try_max=None,
    try_timeout=None,
    overall_timeout=None,
    delay_func=None,
    reraise=False,
    loop=None,
):
    """
    Create and return a retry decorator. The decorator is intended to
    operate only on a coroutine function.

    The returned decorator is called as decorator(func) and also accepts an
    optional loop keyword argument (decorator(func, loop=...)); when that
    loop is given (truthy) it is used instead of the loop passed here.

    @param try_max: maximum number of tries, or None for no limit
    @type try_max: int or None
    @param try_timeout: number of seconds to wait for a try to succeed
            before cancelling it, which is only effective if func returns
            tasks that support cancellation
    @type try_timeout: float or None
    @param overall_timeout: number of seconds to wait for retries to
            succeed before aborting, which is only effective if func returns
            tasks that support cancellation
    @type overall_timeout: float or None
    @param delay_func: function that takes an int argument corresponding
            to the number of previous tries and returns a number of seconds
            to wait before the next try; None retries immediately
    @type delay_func: callable or None
    @param reraise: Reraise the last exception, instead of RetryError
    @type reraise: bool
    @param loop: event loop
    @type loop: EventLoop
    @return: a decorator; applying it to func yields a callable that takes
            func's arguments and returns a future for the eventual result
    @rtype: callable
    """
    # The settings are bound positionally, in the order _retry_wrapper
    # expects them, ahead of the func (and optional loop) that the decorator
    # receives later.
    return functools.partial(
        _retry_wrapper, loop, try_max, try_timeout, overall_timeout, delay_func, reraise
    )


def _retry_wrapper(
    _loop, try_max, try_timeout, overall_timeout, delay_func, reraise, func, loop=None
):
    """
    Bind the retry settings and func, returning the retrying callable.

    A truthy loop given at decoration time wins over the loop that was
    passed to retry() (_loop).
    """
    effective_loop = loop or _loop
    bound_settings = (
        effective_loop,
        try_max,
        try_timeout,
        overall_timeout,
        delay_func,
        reraise,
        func,
    )
    return functools.partial(_retry, *bound_settings)


def _retry(
    loop,
    try_max,
    try_timeout,
    overall_timeout,
    delay_func,
    reraise,
    func,
    *args,
    **kwargs,
):
    """
    Invoke func(*args, **kwargs) under retry control; this is the callable
    produced by the retry decorator.

    @return: future that receives func's result, or the retry failure
    @rtype: asyncio.Future (or compatible)
    """
    event_loop = asyncio._wrap_loop(loop)
    outcome = event_loop.create_future()
    # Freeze the call arguments so each try is a zero-argument call.
    bound_call = functools.partial(func, *args, **kwargs)
    # _Retry drives itself through loop callbacks; no reference is needed.
    _Retry(
        outcome,
        event_loop,
        try_max,
        try_timeout,
        overall_timeout,
        delay_func,
        reraise,
        bound_call,
    )
    return outcome


class _Retry:
    """
    Drive the retries for a single call of a retry-decorated function.

    The first try is started from the constructor; afterwards the instance
    advances purely through event loop callbacks. The final outcome (the
    first successful result, or an exception once tries are exhausted or
    the overall timeout expires) is delivered through the future passed to
    the constructor. Cancelling that future cancels the in-flight try.
    """

    def __init__(
        self,
        future,
        loop,
        try_max,
        try_timeout,
        overall_timeout,
        delay_func,
        reraise,
        func,
    ):
        """
        @param future: future that receives the final result or exception
        @param loop: event loop used to schedule tries and timeouts
        @param try_max: maximum number of tries, or None for no limit
        @param try_timeout: seconds before an individual try is cancelled,
                or None
        @param overall_timeout: seconds before all retrying is aborted,
                or None
        @param delay_func: callable taking the number of tries so far and
                returning seconds to wait before the next try, or None to
                retry immediately
        @param reraise: set the last try's exception on future instead of
                a RetryError
        @param func: zero-argument callable whose return value is passed to
                asyncio.ensure_future for each try
        """
        self._future = future
        self._loop = loop
        self._try_max = try_max
        self._try_timeout = try_timeout
        self._delay_func = delay_func
        self._reraise = reraise
        self._func = func

        self._try_timeout_handle = None
        self._overall_timeout_handle = None
        self._overall_timeout_expired = False
        self._tries = 0
        # Either the task of the running try, or the TimerHandle of a pending
        # delay between tries; None while neither exists.
        self._current_task = None
        # Future of the most recently completed try, used by _retry_error.
        self._previous_result = None

        future.add_done_callback(self._cancel_callback)
        if overall_timeout is not None:
            self._overall_timeout_handle = loop.call_later(
                overall_timeout, self._overall_timeout_callback
            )
        self._begin_try()

    def _cancel_callback(self, future):
        """
        Done callback of the outer future.

        Once the outer future has completed in any way, a still-pending
        overall timeout could only try to set an exception on a finished
        future, so it is cancelled. If the caller cancelled the future, the
        in-flight try (or pending delay) is cancelled as well.
        """
        self._cancel_overall_timeout()
        if future.cancelled() and self._current_task is not None:
            self._current_task.cancel()

    def _cancel_overall_timeout(self):
        """Cancel the pending overall timeout handle, if there is one."""
        if self._overall_timeout_handle is not None:
            self._overall_timeout_handle.cancel()
            self._overall_timeout_handle = None

    def _try_timeout_callback(self):
        """Cancel the current try once try_timeout has elapsed."""
        self._try_timeout_handle = None
        self._current_task.cancel()

    def _overall_timeout_callback(self):
        """Abort retrying once overall_timeout has elapsed."""
        self._overall_timeout_handle = None
        if self._future.done():
            # The outcome was already delivered (or the future cancelled)
            # before this handle could be cancelled; nothing to abort.
            return
        self._overall_timeout_expired = True
        if self._current_task is not None:
            self._current_task.cancel()
        self._retry_error()

    def _begin_try(self):
        """Start the next try and arm its per-try timeout, if any."""
        self._tries += 1
        self._current_task = asyncio.ensure_future(self._func(), loop=self._loop)
        self._current_task.add_done_callback(self._try_done)
        if self._try_timeout is not None:
            self._try_timeout_handle = self._loop.call_later(
                self._try_timeout, self._try_timeout_callback
            )

    def _try_done(self, future):
        """
        Done callback of a try: deliver success, give up, or schedule the
        next try (optionally after a delay_func delay).
        """
        self._current_task = None

        if self._try_timeout_handle is not None:
            self._try_timeout_handle.cancel()
            self._try_timeout_handle = None

        if not future.cancelled():
            # consume exception, so that the event loop
            # exception handler does not report it
            future.exception()

        if self._overall_timeout_expired:
            # _overall_timeout_callback has already failed the outer future.
            return

        try:
            if self._future.cancelled():
                return

            self._previous_result = future
            if not (future.cancelled() or future.exception() is not None):
                # success
                self._future.set_result(future.result())
                return
        finally:
            if self._future.done():
                self._cancel_overall_timeout()

        if self._try_max is not None and self._tries >= self._try_max:
            self._retry_error()
            return

        if self._delay_func is not None:
            delay = self._delay_func(self._tries)
            self._current_task = self._loop.call_later(delay, self._delay_done)
            return

        self._begin_try()

    def _delay_done(self):
        """Start the next try once the delay between tries has elapsed."""
        self._current_task = None

        if self._future.cancelled() or self._overall_timeout_expired:
            return

        self._begin_try()

    def _retry_error(self):
        """
        Fail the outer future with a RetryError (or, with reraise, the
        underlying exception itself). The cause is the last try's exception,
        or asyncio.TimeoutError if no try has completed yet or the last one
        was cancelled.
        """
        if self._previous_result is None or self._previous_result.cancelled():
            cause = asyncio.TimeoutError()
        else:
            cause = self._previous_result.exception()

        if self._reraise:
            e = cause
        else:
            e = RetryError()
            e.__cause__ = cause

        # The outcome is final now; make sure a pending overall timeout can
        # not fire later and touch the finished future.
        self._cancel_overall_timeout()
        self._future.set_exception(e)