import faulthandler
import logging
import multiprocessing
import multiprocessing.connection
import os
import sys
import tempfile
import threading
import time
import traceback
import types
import unittest
from enum import Enum
from functools import wraps
from typing import Any, Dict, List, Union
import torch
logger = logging.getLogger(__name__)
def skip_if_no_gpu(func):
"""Skips if the world size exceeds the number of GPUs."""
@wraps(func)
def wrapper(*args, **kwargs):
if not torch.cuda.is_available():
sys.exit()
world_size = int(os.getenv("WORLD_SIZE"))
if torch.cuda.device_count() < world_size:
sys.exit()
return func(*args, **kwargs)
return wrapper
def skip_if_lt_x_gpu(n: int):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if torch.cuda.is_available() and torch.cuda.device_count() >= n:
return func(*args, **kwargs)
sys.exit()
return wrapper
return decorator
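# Usage sketch (hypothetical test names, not part of this module): the GPU
# skip decorators above are typically applied to individual test methods of a
# distributed test case, with WORLD_SIZE set in the environment for
# skip_if_no_gpu:
#
#     class MyDistributedTest(MultiProcessTestCase):
#         @skip_if_no_gpu
#         def test_allreduce_uses_all_ranks(self):
#             ...
#
#         @skip_if_lt_x_gpu(2)
#         def test_needs_two_gpus(self):
#             ...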
def with_nccl_blocking_wait(func):
"""
Convenience decorator to set/unset TORCH_NCCL_BLOCKING_WAIT flag. Note that use of
this decorator will override the setting of TORCH_NCCL_ASYNC_ERROR_HANDLING for
the particular test. After the test, both TORCH_NCCL_BLOCKING_WAIT and
TORCH_NCCL_ASYNC_ERROR_HANDLING will be restored to their original values.
"""
@wraps(func)
def wrapper(*args, **kwargs):
# Save and unset TORCH_NCCL_ASYNC_ERROR_HANDLING
try:
cached_nccl_async_error_handling: Union[str, None] = os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]
del os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"]
except KeyError:
# TORCH_NCCL_ASYNC_ERROR_HANDLING was unset
cached_nccl_async_error_handling = None
# Save val of TORCH_NCCL_BLOCKING_WAIT and set it
try:
cached_nccl_blocking_wait: Union[str, None] = os.environ["TORCH_NCCL_BLOCKING_WAIT"]
except KeyError:
cached_nccl_blocking_wait = None
finally:
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
try:
ret = func(*args, **kwargs)
return ret
finally:
# restore old values
if cached_nccl_async_error_handling is not None:
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = cached_nccl_async_error_handling
if cached_nccl_blocking_wait is not None:
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = cached_nccl_blocking_wait
return wrapper
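# Usage sketch (hypothetical test; assumes the test body sets up a NCCL-backed
# process group). The decorator forces TORCH_NCCL_BLOCKING_WAIT=1 and clears
# TORCH_NCCL_ASYNC_ERROR_HANDLING only for the duration of the call, restoring
# both afterwards:
#
#     class MyNcclTest(MultiProcessTestCase):
#         @with_nccl_blocking_wait
#         def test_collective_times_out_synchronously(self):
#             ...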
DEFAULT_WORLD_SIZE: int = 4
DEFAULT_TIME_OUT: int = 300
class MultiProcessTestCase(unittest.TestCase):
MAIN_PROCESS_RANK: int = -1
TEST_ERROR_EXIT_CODE: int = 10
@property
def world_size(self) -> int:
return DEFAULT_WORLD_SIZE
def join_or_run(self, fn):
@wraps(fn)
def wrapper(self):
if self.rank == self.MAIN_PROCESS_RANK:
self._join_processes(fn)
else:
fn()
return types.MethodType(wrapper, self)
# The main process spawns N subprocesses that run the test
    def __init__(self, method_name: str = "runTest") -> None:
        super().__init__(method_name)
        fn = getattr(self, method_name)
        setattr(self, method_name, self.join_or_run(fn))
def setUp(self) -> None:
super().setUp()
self.skip_return_code_checks = []
self.processes: List[multiprocessing.Process] = []
self.rank: int = self.MAIN_PROCESS_RANK
self.file_name: str = tempfile.NamedTemporaryFile(delete=False).name
self.pid_to_pipe: Dict[int, multiprocessing.connection.Connection] = {}
def tearDown(self) -> None:
super().tearDown()
for p in self.processes:
p.terminate()
self.processes = []
def _current_test_name(self) -> str:
return self.id().split(".")[-1]
def _start_processes(self, proc: Any) -> None:
self.processes = []
for rank in range(self.world_size):
parent_conn, child_conn = torch.multiprocessing.Pipe()
process = proc(
target=self.__class__._run,
name=f"process {rank}",
args=(rank, self._current_test_name(), self.file_name, child_conn),
)
process.start()
logger.info(f"Started process {rank} with pid {process.pid}")
self.pid_to_pipe[process.pid] = parent_conn
self.processes.append(process)
def _spawn_processes(self) -> None:
proc = torch.multiprocessing.get_context("spawn").Process
self._start_processes(proc)
class Event(Enum):
GET_TRACEBACK = 1
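    # The event listener thread runs inside each spawned worker. It blocks on the
    # parent pipe and, when the parent requests Event.GET_TRACEBACK, dumps the
    # worker's current stack with faulthandler and sends the text back; it exits
    # once the signal pipe becomes readable (written by run_test on completion).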
@staticmethod
def _event_listener(parent_pipe, signal_pipe, rank: int):
logger.info(f"Starting event listener thread for rank {rank}")
while True:
ready_pipes = multiprocessing.connection.wait([parent_pipe, signal_pipe])
if parent_pipe in ready_pipes:
if parent_pipe.closed:
logger.info(f"Pipe closed for process {rank}, stopping event listener thread")
return
event = parent_pipe.recv()
logger.info(f"Received event {event} on process {rank}")
if event == MultiProcessTestCase.Event.GET_TRACEBACK:
with tempfile.NamedTemporaryFile(mode="r+") as tmp_file:
faulthandler.dump_traceback(tmp_file)
tmp_file.flush()
tmp_file.seek(0)
parent_pipe.send(tmp_file.read())
logger.info(f"Process {rank} sent traceback")
if signal_pipe in ready_pipes:
return
@classmethod
def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
self = cls(test_name)
self.rank = rank
self.file_name = file_name
self.run_test(test_name, parent_pipe)
def run_test(self, test_name: str, parent_pipe) -> None:
        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(duplex=False)
event_listener_thread = threading.Thread(
target=MultiProcessTestCase._event_listener,
args=(parent_pipe, signal_recv_pipe, self.rank),
daemon=True,
)
event_listener_thread.start()
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
try:
getattr(self, test_name)()
except unittest.SkipTest as se:
logger.info(f"Process {self.rank} skipping test {test_name} for following reason: {se}")
sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
except Exception as e:
logger.error(f"Caught exception: \n{traceback.format_exc()} exiting process {self.rank}")
# Send error to parent process
parent_pipe.send(traceback.format_exc())
sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
finally:
            if signal_send_pipe is not None:
                signal_send_pipe.send(None)
assert event_listener_thread is not None
event_listener_thread.join()
# Close pipe after done
parent_pipe.close()
def _get_timeout_process_traceback(self) -> None:
pipes = []
for i, process in enumerate(self.processes):
if process.exitcode is None:
pipe = self.pid_to_pipe[process.pid]
try:
pipe.send(MultiProcessTestCase.Event.GET_TRACEBACK)
pipes.append((i, pipe))
except ConnectionError as e:
logger.error(f"Encountered error while trying to get traceback for process {i}: {e}")
# Wait for results
for rank, pipe in pipes:
try:
# Wait
if pipe.poll(5):
if pipe.closed:
logger.info(f"Pipe closed for process {rank}, cannot retrieve traceback")
continue
                    tb = pipe.recv()
                    logger.error(f"Process {rank} timed out with traceback: \n\n{tb}")
                else:
                    logger.error(f"Could not retrieve traceback for timed out process {rank}")
except ConnectionError as e:
logger.error(f"Encountered error while trying to get traceback for process {rank}: {e}")
def _join_processes(self, fn) -> None:
timeout = DEFAULT_TIME_OUT
start_time = time.time()
subprocess_error = False
try:
while True:
# Check if any subprocess exited with an error early
for i, p in enumerate(self.processes):
if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE:
logger.error(f"Process {i} terminated with exit code {p.exitcode}, terminating remaining processes")
active_children = torch.multiprocessing.active_children()
for ac in active_children:
ac.terminate()
subprocess_error = True
break
if subprocess_error:
break
                # All processes have joined cleanly if they all have a valid exitcode
if all(p.exitcode is not None for p in self.processes):
break
# Check if we should time out the test
elapsed = time.time() - start_time
if elapsed > timeout:
self._get_timeout_process_traceback()
logger.error(f"Time out after {timeout} seconds and killing subprocesses")
for p in self.processes:
p.terminate()
break
# Sleep to avoid busy polling
time.sleep(0.1)
elapsed_time = time.time() - start_time
if fn in self.skip_return_code_checks:
self._check_no_test_errors(elapsed_time)
else:
self._check_return_codes(elapsed_time)
finally:
# Close all pipes
for pipe in self.pid_to_pipe.values():
pipe.close()
def _check_no_test_errors(self, elapsed_time) -> None:
"""Checks that we didn't have any errors thrown in the child processes."""
for i, p in enumerate(self.processes):
if p.exitcode is None:
raise RuntimeError(f"Process {i} timed out after {elapsed_time} seconds")
self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
def _check_return_codes(self, elapsed_time) -> None:
"""
Checks that the return codes of all spawned processes match, and skips
tests if they returned a return code indicating a skipping condition.
"""
# If no processes are spawned, there is nothing to check
if not self.processes:
logger.warning("Note: no subprocesses were spawned, test was likely skipped.")
return
first_process = self.processes[0]
# TODO: Enhance
errored_processes = [(i, p) for i, p in enumerate(self.processes) if p.exitcode == MultiProcessTestCase.TEST_ERROR_EXIT_CODE]
if errored_processes:
error = ""
for i, proc in errored_processes:
error_msg = self.pid_to_pipe[proc.pid].recv()
error += f"Process {i} exited with error code {MultiProcessTestCase.TEST_ERROR_EXIT_CODE} and exception:\n{error_msg}\n"
raise RuntimeError(error)
# If no process exited uncleanly, check timeouts and then exit each process cleanly
for i, p in enumerate(self.processes):
if p.exitcode is None:
raise RuntimeError(f"Process {i} terminated or timed out after {elapsed_time} seconds")
self.assertEqual(
p.exitcode,
first_process.exitcode,
msg=f"Expect process {i} exit code to match process 0 exit code of {first_process.exitcode} but got {p.exitcode}",
)
self.assertEqual(
first_process.exitcode,
0,
msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
)
@property
def is_master(self) -> bool:
return self.rank == 0
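# Minimal usage sketch (hypothetical subclass, not part of this module). A typical
# subclass spawns its workers in setUp(); the test body then runs once per spawned
# rank, while the parent process (rank MAIN_PROCESS_RANK) joins the workers and
# checks their exit codes via _join_processes/_check_return_codes:
#
#     class ExampleTwoRankTest(MultiProcessTestCase):
#         @property
#         def world_size(self) -> int:
#             return 2
#
#         def setUp(self) -> None:
#             super().setUp()
#             self._spawn_processes()
#
#         def test_ranks_are_assigned(self) -> None:
#             # Runs in each spawned process with self.rank in [0, world_size).
#             self.assertTrue(0 <= self.rank < self.world_size)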