aboutsummaryrefslogtreecommitdiff
blob: eb426a5c025b166d20c7de64bc04a24597f45356 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright 2020-2024 Gentoo Authors
# Distributed under the terms of the GNU General Public License v2

import functools
import multiprocessing
import sys

import portage
from portage import os
from portage.tests import TestCase
from portage.util._async.AsyncFunction import AsyncFunction
from portage.util.futures import asyncio
from portage.util.futures._asyncio.streams import _writer
from portage.util.futures.unix_events import _set_nonblocking


class AsyncFunctionTestCase(TestCase):
    """
    Tests for AsyncFunction: stdin inheritance by the child process and
    correctness of the portage.getpid() cache across fork/spawn.
    """

    @staticmethod
    def _read_from_stdin(pw):
        """
        Child-process target: read everything from stdin and return it.

        @param pw: file descriptor of the pipe's write end when it was
            inherited by a forked child, or None when nothing was inherited
            (spawn start method). It is closed first because a pipe reader
            only sees EOF once every copy of the write end is closed, and
            this process would otherwise hold one itself.
        @return: the concatenation of all lines read from sys.stdin.
        """
        if pw is not None:
            os.close(pw)
        return "".join(sys.stdin)

    async def _testAsyncFunctionStdin(self, loop):
        """
        Verify that an AsyncFunction started with background=False inherits
        the parent's stdin and can read data written to it.

        Stdin is temporarily redirected to the read end of a pipe while the
        child is started, then restored. The test string is then written to
        the pipe's write end, and the child's result must match it.
        """
        test_string = "1\n2\n3\n"
        pr, pw = multiprocessing.Pipe(duplex=False)
        try:
            stdin_backup = os.dup(portage._get_stdin().fileno())
            try:
                # Redirect inside the try so stdin_backup is always restored
                # and closed, even if the redirection itself fails.
                os.dup2(pr.fileno(), portage._get_stdin().fileno())
                # The read end now lives on as stdin; drop the extra copy.
                pr.close()
                reader = AsyncFunction(
                    # Should automatically inherit stdin as fd_pipes[0]
                    # when background is False, for things like
                    # emerge --sync --ask (bug 916116).
                    background=False,
                    scheduler=loop,
                    target=self._read_from_stdin,
                    args=(
                        (
                            pw.fileno()
                            if multiprocessing.get_start_method() == "fork"
                            else None
                        ),
                    ),
                )
                reader.start()
            finally:
                os.dup2(stdin_backup, portage._get_stdin().fileno())
                os.close(stdin_backup)

            _set_nonblocking(pw.fileno())
            with open(
                pw.fileno(), mode="wb", buffering=0, closefd=False
            ) as pipe_write:
                await _writer(pipe_write, test_string.encode("utf_8"))
        finally:
            # Always release both pipe ends, even if starting the child or
            # writing fails. Closing the write end is also what delivers EOF
            # to the child's stdin, so it must happen before waiting on it.
            # Closing an already-closed Connection is harmless.
            pw.close()
            pr.close()

        self.assertEqual((await reader.async_wait()), os.EX_OK)
        self.assertEqual(reader.result, test_string)

    def testAsyncFunctionStdin(self):
        """Run the stdin-inheritance check on portage's event loop."""
        loop = asyncio._wrap_loop()
        loop.run_until_complete(self._testAsyncFunctionStdin(loop=loop))

    def testAsyncFunctionStdinSpawn(self):
        """
        Repeat testAsyncFunctionStdin with the multiprocessing start method
        forced to "spawn", restoring the original method afterwards.
        """
        orig_start_method = multiprocessing.get_start_method()
        if orig_start_method == "spawn":
            self.skipTest("multiprocessing start method is already spawn")
        # NOTE: An attempt was made to use multiprocessing.get_context("spawn")
        # here, but it caused the python process to terminate unexpectedly
        # during a send_handle call.
        multiprocessing.set_start_method("spawn", force=True)
        try:
            self.testAsyncFunctionStdin()
        finally:
            multiprocessing.set_start_method(orig_start_method, force=True)

    @staticmethod
    def _test_getpid_fork(preexec_fn=None):
        """
        Verify that portage.getpid() cache is updated in a forked child process.

        @param preexec_fn: optional callable invoked before the AsyncFunction
            is created (used by test_spawn_getpid to switch start method).
        @return: True if the child's portage.getpid() equals its real pid.
        """
        if preexec_fn is not None:
            preexec_fn()
        loop = asyncio._wrap_loop()
        proc = AsyncFunction(scheduler=loop, target=portage.getpid)
        proc.start()
        proc.wait()
        return proc.pid == proc.result

    def test_getpid_fork(self):
        """Check the portage.getpid() cache after a single fork."""
        self.assertTrue(self._test_getpid_fork())

    def test_spawn_getpid(self):
        """
        Test portage.getpid() with multiprocessing spawn start method.

        The start method is switched inside the outer AsyncFunction's child
        (via preexec_fn), so this process's start method is not modified.
        """
        loop = asyncio._wrap_loop()
        proc = AsyncFunction(
            scheduler=loop,
            target=self._test_getpid_fork,
            kwargs=dict(
                preexec_fn=functools.partial(
                    multiprocessing.set_start_method, "spawn", force=True
                )
            ),
        )
        proc.start()
        self.assertEqual(proc.wait(), os.EX_OK)
        self.assertTrue(proc.result)

    def test_getpid_double_fork(self):
        """
        Verify that portage.getpid() cache is updated correctly after
        two forks.
        """
        loop = asyncio._wrap_loop()
        proc = AsyncFunction(scheduler=loop, target=self._test_getpid_fork)
        proc.start()
        self.assertEqual(proc.wait(), os.EX_OK)
        self.assertTrue(proc.result)