Changeset View
Changeset View
Standalone View
Standalone View
python_modules/dagster/dagster/grpc/server.py
import math | import math | ||||
import os | import os | ||||
import queue | import queue | ||||
import sys | import sys | ||||
import tempfile | |||||
import threading | import threading | ||||
import time | import time | ||||
import uuid | import uuid | ||||
from collections import namedtuple | from collections import namedtuple | ||||
from concurrent.futures import ThreadPoolExecutor | from concurrent.futures import ThreadPoolExecutor | ||||
from threading import Event as ThreadingEventType | from threading import Event as ThreadingEventType | ||||
from time import sleep | |||||
import grpc | import grpc | ||||
from dagster import check, seven | from dagster import check, seven | ||||
from dagster.core.code_pointer import CodePointer | from dagster.core.code_pointer import CodePointer | ||||
from dagster.core.definitions.reconstructable import ( | from dagster.core.definitions.reconstructable import ( | ||||
ReconstructableRepository, | ReconstructableRepository, | ||||
repository_def_from_target_def, | repository_def_from_target_def, | ||||
) | ) | ||||
from dagster.core.errors import DagsterUserCodeProcessError | |||||
from dagster.core.host_representation.external_data import external_repository_data_from_def | from dagster.core.host_representation.external_data import external_repository_data_from_def | ||||
from dagster.core.host_representation.origin import ExternalPipelineOrigin, ExternalRepositoryOrigin | from dagster.core.host_representation.origin import ExternalPipelineOrigin, ExternalRepositoryOrigin | ||||
from dagster.core.instance import DagsterInstance | from dagster.core.instance import DagsterInstance | ||||
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin | from dagster.core.types.loadable_target_origin import LoadableTargetOrigin | ||||
from dagster.serdes import ( | from dagster.serdes import ( | ||||
deserialize_json_to_dagster_namedtuple, | deserialize_json_to_dagster_namedtuple, | ||||
serialize_dagster_namedtuple, | serialize_dagster_namedtuple, | ||||
whitelist_for_serdes, | whitelist_for_serdes, | ||||
) | ) | ||||
from dagster.serdes.ipc import ( | from dagster.serdes.ipc import IPCErrorMessage, ipc_write_stream, open_ipc_subprocess | ||||
IPCErrorMessage, | |||||
ipc_write_stream, | |||||
open_ipc_subprocess, | |||||
read_unary_response, | |||||
) | |||||
from dagster.seven import multiprocessing | from dagster.seven import multiprocessing | ||||
from dagster.utils import find_free_port, safe_tempfile_path_unmanaged | from dagster.utils import find_free_port, safe_tempfile_path_unmanaged | ||||
from dagster.utils.error import SerializableErrorInfo, serializable_error_info_from_exc_info | from dagster.utils.error import SerializableErrorInfo, serializable_error_info_from_exc_info | ||||
from grpc_health.v1 import health, health_pb2, health_pb2_grpc | from grpc_health.v1 import health, health_pb2, health_pb2_grpc | ||||
from .__generated__ import api_pb2 | from .__generated__ import api_pb2 | ||||
from .__generated__.api_pb2_grpc import DagsterApiServicer, add_DagsterApiServicer_to_server | from .__generated__.api_pb2_grpc import DagsterApiServicer, add_DagsterApiServicer_to_server | ||||
from .impl import ( | from .impl import ( | ||||
Show All 36 Lines | |||||
STREAMING_CHUNK_SIZE = 4000000 | STREAMING_CHUNK_SIZE = 4000000 | ||||
class CouldNotBindGrpcServerToAddress(Exception): | class CouldNotBindGrpcServerToAddress(Exception): | ||||
pass | pass | ||||
class LazyRepositorySymbolsAndCodePointers: | class RepositorySymbolsAndCodePointers: | ||||
"""Enables lazily loading user code at RPC-time so that it doesn't interrupt startup and | |||||
we can gracefully handle user code errors.""" | |||||
def __init__(self, loadable_target_origin): | def __init__(self, loadable_target_origin): | ||||
self._loadable_target_origin = loadable_target_origin | self._loadable_target_origin = loadable_target_origin | ||||
self._loadable_repository_symbols = None | self._loadable_repository_symbols = None | ||||
self._code_pointers_by_repo_name = None | self._code_pointers_by_repo_name = None | ||||
def load(self): | def load(self): | ||||
self._loadable_repository_symbols = load_loadable_repository_symbols( | self._loadable_repository_symbols = load_loadable_repository_symbols( | ||||
self._loadable_target_origin | self._loadable_target_origin | ||||
) | ) | ||||
self._code_pointers_by_repo_name = build_code_pointers_by_repo_name( | self._code_pointers_by_repo_name = build_code_pointers_by_repo_name( | ||||
self._loadable_target_origin, self._loadable_repository_symbols | self._loadable_target_origin, self._loadable_repository_symbols | ||||
) | ) | ||||
@property | @property | ||||
def loadable_repository_symbols(self): | def loadable_repository_symbols(self): | ||||
if self._loadable_repository_symbols is None: | |||||
self.load() | |||||
return self._loadable_repository_symbols | return self._loadable_repository_symbols | ||||
@property | @property | ||||
def code_pointers_by_repo_name(self): | def code_pointers_by_repo_name(self): | ||||
if self._code_pointers_by_repo_name is None: | |||||
self.load() | |||||
return self._code_pointers_by_repo_name | return self._code_pointers_by_repo_name | ||||
def load_loadable_repository_symbols(loadable_target_origin): | def load_loadable_repository_symbols(loadable_target_origin): | ||||
if loadable_target_origin: | if loadable_target_origin: | ||||
loadable_targets = get_loadable_targets( | loadable_targets = get_loadable_targets( | ||||
loadable_target_origin.python_file, | loadable_target_origin.python_file, | ||||
loadable_target_origin.module_name, | loadable_target_origin.module_name, | ||||
▲ Show 20 Lines • Show All 80 Lines • ▼ Show 20 Lines | ): | ||||
# Dict[str, (multiprocessing.Process, DagsterInstance)] | # Dict[str, (multiprocessing.Process, DagsterInstance)] | ||||
self._executions = {} | self._executions = {} | ||||
# Dict[str, multiprocessing.Event] | # Dict[str, multiprocessing.Event] | ||||
self._termination_events = {} | self._termination_events = {} | ||||
self._termination_times = {} | self._termination_times = {} | ||||
self._execution_lock = threading.Lock() | self._execution_lock = threading.Lock() | ||||
self._repository_symbols_and_code_pointers = LazyRepositorySymbolsAndCodePointers( | self._serializable_load_error = None | ||||
self._repository_symbols_and_code_pointers = RepositorySymbolsAndCodePointers( | |||||
loadable_target_origin | loadable_target_origin | ||||
) | ) | ||||
if not lazy_load_user_code: | try: | ||||
self._repository_symbols_and_code_pointers.load() | self._repository_symbols_and_code_pointers.load() | ||||
except Exception: # pylint:disable=broad-except | |||||
if not lazy_load_user_code: | |||||
raise | |||||
self._serializable_load_error = serializable_error_info_from_exc_info(sys.exc_info()) | |||||
self.__last_heartbeat_time = time.time() | self.__last_heartbeat_time = time.time() | ||||
if heartbeat: | if heartbeat: | ||||
self.__heartbeat_thread = threading.Thread( | self.__heartbeat_thread = threading.Thread( | ||||
target=self._heartbeat_thread, | target=self._heartbeat_thread, | ||||
args=(heartbeat_timeout,), | args=(heartbeat_timeout,), | ||||
name="grpc-server-heartbeat", | name="grpc-server-heartbeat", | ||||
) | ) | ||||
▲ Show 20 Lines • Show All 121 Lines • ▼ Show 20 Lines | def ExecutionPlanSnapshot(self, request, _context): | ||||
) | ) | ||||
return api_pb2.ExecutionPlanSnapshotReply( | return api_pb2.ExecutionPlanSnapshotReply( | ||||
serialized_execution_plan_snapshot=serialize_dagster_namedtuple( | serialized_execution_plan_snapshot=serialize_dagster_namedtuple( | ||||
execution_plan_snapshot_or_error | execution_plan_snapshot_or_error | ||||
) | ) | ||||
) | ) | ||||
def ListRepositories(self, request, _context): | def ListRepositories(self, request, _context): | ||||
try: | if self._serializable_load_error: | ||||
return api_pb2.ListRepositoriesReply( | |||||
serialized_list_repositories_response_or_error=serialize_dagster_namedtuple( | |||||
self._serializable_load_error | |||||
) | |||||
) | |||||
response = ListRepositoriesResponse( | response = ListRepositoriesResponse( | ||||
self._repository_symbols_and_code_pointers.loadable_repository_symbols, | self._repository_symbols_and_code_pointers.loadable_repository_symbols, | ||||
executable_path=self._loadable_target_origin.executable_path | executable_path=self._loadable_target_origin.executable_path | ||||
if self._loadable_target_origin | if self._loadable_target_origin | ||||
else None, | else None, | ||||
repository_code_pointer_dict=( | repository_code_pointer_dict=( | ||||
self._repository_symbols_and_code_pointers.code_pointers_by_repo_name | self._repository_symbols_and_code_pointers.code_pointers_by_repo_name | ||||
), | ), | ||||
) | ) | ||||
except Exception: # pylint: disable=broad-except | |||||
response = serializable_error_info_from_exc_info(sys.exc_info()) | |||||
return api_pb2.ListRepositoriesReply( | return api_pb2.ListRepositoriesReply( | ||||
serialized_list_repositories_response_or_error=serialize_dagster_namedtuple(response) | serialized_list_repositories_response_or_error=serialize_dagster_namedtuple(response) | ||||
) | ) | ||||
def ExternalPartitionNames(self, request, _context): | def ExternalPartitionNames(self, request, _context): | ||||
partition_names_args = deserialize_json_to_dagster_namedtuple( | partition_names_args = deserialize_json_to_dagster_namedtuple( | ||||
request.serialized_partition_names_args | request.serialized_partition_names_args | ||||
▲ Show 20 Lines • Show All 297 Lines • ▼ Show 20 Lines | def StartRun(self, request, _context): | ||||
) | ) | ||||
self._termination_events[run_id] = termination_event | self._termination_events[run_id] = termination_event | ||||
success = None | success = None | ||||
message = None | message = None | ||||
serializable_error_info = None | serializable_error_info = None | ||||
while success is None: | while success is None: | ||||
time.sleep(EVENT_QUEUE_POLL_INTERVAL) | sleep(EVENT_QUEUE_POLL_INTERVAL) | ||||
# We use `get_nowait()` instead of `get()` so that we can handle the case where the | # We use `get_nowait()` instead of `get()` so that we can handle the case where the | ||||
# execution process has died unexpectedly -- `get()` would hang forever in that case | # execution process has died unexpectedly -- `get()` would hang forever in that case | ||||
try: | try: | ||||
dagster_event_or_ipc_error_message_or_done = event_queue.get_nowait() | dagster_event_or_ipc_error_message_or_done = event_queue.get_nowait() | ||||
except queue.Empty: | except queue.Empty: | ||||
if not execution_process.is_alive(): | if not execution_process.is_alive(): | ||||
# subprocess died unexpectedly | # subprocess died unexpectedly | ||||
success = False | success = False | ||||
▲ Show 20 Lines • Show All 224 Lines • ▼ Show 20 Lines | def __init__(self, port=None, socket=None): | ||||
+ ( | + ( | ||||
"port {port}".format(port=port) | "port {port}".format(port=port) | ||||
if port is not None | if port is not None | ||||
else "socket {socket}".format(socket=socket) | else "socket {socket}".format(socket=socket) | ||||
) | ) | ||||
) | ) | ||||
def wait_for_grpc_server(server_process, ipc_output_file, timeout=60): | def wait_for_grpc_server(server_process, client, subprocess_args, timeout=60): | ||||
event = read_unary_response(ipc_output_file, timeout=timeout, ipc_process=server_process) | start_time = time.time() | ||||
while True: | |||||
try: | |||||
client.ping("") | |||||
return | |||||
except grpc._channel._InactiveRpcError: # pylint: disable=protected-access | |||||
pass | |||||
if isinstance(event, GrpcServerFailedToBindEvent): | if time.time() - start_time > timeout: | ||||
raise CouldNotBindGrpcServerToAddress() | raise Exception( | ||||
elif isinstance(event, GrpcServerLoadErrorEvent): | f"Timed out waiting for gRPC server to start with arguments: \"{' '.join(subprocess_args)}\"" | ||||
raise DagsterUserCodeProcessError( | |||||
event.error_info.to_string(), user_code_process_error_infos=[event.error_info] | |||||
) | ) | ||||
elif isinstance(event, GrpcServerStartedEvent): | |||||
return True | if server_process.poll() != None: | ||||
else: | |||||
raise Exception( | raise Exception( | ||||
"Received unexpected IPC event from gRPC Server: {event}".format(event=event) | f"gRPC server exited with return code {server_process.returncode} while starting up with the command: \"{' '.join(subprocess_args)}\"" | ||||
) | ) | ||||
sleep(0.1) | |||||
def open_server_process( | def open_server_process( | ||||
port, | port, | ||||
socket, | socket, | ||||
loadable_target_origin=None, | loadable_target_origin=None, | ||||
max_workers=None, | max_workers=None, | ||||
heartbeat=False, | heartbeat=False, | ||||
heartbeat_timeout=30, | heartbeat_timeout=30, | ||||
lazy_load_user_code=False, | |||||
fixed_server_id=None, | fixed_server_id=None, | ||||
): | ): | ||||
check.invariant((port or socket) and not (port and socket), "Set only port or socket") | check.invariant((port or socket) and not (port and socket), "Set only port or socket") | ||||
check.opt_inst_param(loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin) | check.opt_inst_param(loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin) | ||||
check.opt_int_param(max_workers, "max_workers") | check.opt_int_param(max_workers, "max_workers") | ||||
from dagster.core.test_utils import get_mocked_system_timezone | from dagster.core.test_utils import get_mocked_system_timezone | ||||
with tempfile.TemporaryDirectory() as temp_dir: | |||||
output_file = os.path.join( | |||||
temp_dir, "grpc-server-startup-{uuid}".format(uuid=uuid.uuid4().hex) | |||||
) | |||||
mocked_system_timezone = get_mocked_system_timezone() | mocked_system_timezone = get_mocked_system_timezone() | ||||
subprocess_args = ( | subprocess_args = ( | ||||
[ | [ | ||||
loadable_target_origin.executable_path | loadable_target_origin.executable_path | ||||
if loadable_target_origin and loadable_target_origin.executable_path | if loadable_target_origin and loadable_target_origin.executable_path | ||||
else sys.executable, | else sys.executable, | ||||
"-m", | "-m", | ||||
"dagster.grpc", | "dagster.grpc", | ||||
] | ] | ||||
+ ["--lazy-load-user-code"] | |||||
+ (["--port", str(port)] if port else []) | + (["--port", str(port)] if port else []) | ||||
+ (["--socket", socket] if socket else []) | + (["--socket", socket] if socket else []) | ||||
+ (["-n", str(max_workers)] if max_workers else []) | + (["-n", str(max_workers)] if max_workers else []) | ||||
+ (["--heartbeat"] if heartbeat else []) | + (["--heartbeat"] if heartbeat else []) | ||||
+ (["--heartbeat-timeout", str(heartbeat_timeout)] if heartbeat_timeout else []) | + (["--heartbeat-timeout", str(heartbeat_timeout)] if heartbeat_timeout else []) | ||||
+ (["--lazy-load-user-code"] if lazy_load_user_code else []) | |||||
+ (["--ipc-output-file", output_file]) | |||||
+ (["--fixed-server-id", fixed_server_id] if fixed_server_id else []) | + (["--fixed-server-id", fixed_server_id] if fixed_server_id else []) | ||||
+ ( | + (["--override-system-timezone", mocked_system_timezone] if mocked_system_timezone else []) | ||||
["--override-system-timezone", mocked_system_timezone] | |||||
if mocked_system_timezone | |||||
else [] | |||||
) | |||||
) | ) | ||||
if loadable_target_origin: | if loadable_target_origin: | ||||
subprocess_args += loadable_target_origin.get_cli_args() | subprocess_args += loadable_target_origin.get_cli_args() | ||||
server_process = open_ipc_subprocess(subprocess_args) | server_process = open_ipc_subprocess(subprocess_args) | ||||
from dagster.grpc.client import DagsterGrpcClient | |||||
client = DagsterGrpcClient( | |||||
port=port, | |||||
socket=socket, | |||||
host="localhost", | |||||
) | |||||
try: | try: | ||||
wait_for_grpc_server(server_process, output_file) | wait_for_grpc_server(server_process, client, subprocess_args) | ||||
except: | except: | ||||
if server_process.poll() is None: | if server_process.poll() is None: | ||||
server_process.terminate() | server_process.terminate() | ||||
raise | raise | ||||
return server_process | return server_process | ||||
def open_server_process_on_dynamic_port( | def open_server_process_on_dynamic_port( | ||||
max_retries=10, | max_retries=10, | ||||
loadable_target_origin=None, | loadable_target_origin=None, | ||||
max_workers=None, | max_workers=None, | ||||
heartbeat=False, | heartbeat=False, | ||||
heartbeat_timeout=30, | heartbeat_timeout=30, | ||||
lazy_load_user_code=False, | |||||
fixed_server_id=None, | fixed_server_id=None, | ||||
): | ): | ||||
server_process = None | server_process = None | ||||
retries = 0 | retries = 0 | ||||
while server_process is None and retries < max_retries: | while server_process is None and retries < max_retries: | ||||
port = find_free_port() | port = find_free_port() | ||||
try: | try: | ||||
server_process = open_server_process( | server_process = open_server_process( | ||||
port=port, | port=port, | ||||
socket=None, | socket=None, | ||||
loadable_target_origin=loadable_target_origin, | loadable_target_origin=loadable_target_origin, | ||||
max_workers=max_workers, | max_workers=max_workers, | ||||
heartbeat=heartbeat, | heartbeat=heartbeat, | ||||
heartbeat_timeout=heartbeat_timeout, | heartbeat_timeout=heartbeat_timeout, | ||||
lazy_load_user_code=lazy_load_user_code, | |||||
fixed_server_id=fixed_server_id, | fixed_server_id=fixed_server_id, | ||||
) | ) | ||||
except CouldNotBindGrpcServerToAddress: | except CouldNotBindGrpcServerToAddress: | ||||
pass | pass | ||||
retries += 1 | retries += 1 | ||||
return server_process, port | return server_process, port | ||||
class GrpcServerProcess: | class GrpcServerProcess: | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
loadable_target_origin=None, | loadable_target_origin=None, | ||||
force_port=False, | force_port=False, | ||||
max_retries=10, | max_retries=10, | ||||
max_workers=None, | max_workers=None, | ||||
heartbeat=False, | heartbeat=False, | ||||
heartbeat_timeout=30, | heartbeat_timeout=30, | ||||
lazy_load_user_code=False, | |||||
fixed_server_id=None, | fixed_server_id=None, | ||||
): | ): | ||||
self.port = None | self.port = None | ||||
self.socket = None | self.socket = None | ||||
self.server_process = None | self.server_process = None | ||||
self.loadable_target_origin = check.opt_inst_param( | self.loadable_target_origin = check.opt_inst_param( | ||||
loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin | loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin | ||||
) | ) | ||||
check.bool_param(force_port, "force_port") | check.bool_param(force_port, "force_port") | ||||
check.int_param(max_retries, "max_retries") | check.int_param(max_retries, "max_retries") | ||||
check.opt_int_param(max_workers, "max_workers") | check.opt_int_param(max_workers, "max_workers") | ||||
check.bool_param(heartbeat, "heartbeat") | check.bool_param(heartbeat, "heartbeat") | ||||
check.int_param(heartbeat_timeout, "heartbeat_timeout") | check.int_param(heartbeat_timeout, "heartbeat_timeout") | ||||
check.invariant(heartbeat_timeout > 0, "heartbeat_timeout must be greater than 0") | check.invariant(heartbeat_timeout > 0, "heartbeat_timeout must be greater than 0") | ||||
check.bool_param(lazy_load_user_code, "lazy_load_user_code") | |||||
check.opt_str_param(fixed_server_id, "fixed_server_id") | check.opt_str_param(fixed_server_id, "fixed_server_id") | ||||
check.invariant( | check.invariant( | ||||
max_workers is None or max_workers > 1 if heartbeat else True, | max_workers is None or max_workers > 1 if heartbeat else True, | ||||
"max_workers must be greater than 1 or set to None if heartbeat is True. " | "max_workers must be greater than 1 or set to None if heartbeat is True. " | ||||
"If set to None, the server will use the gRPC default.", | "If set to None, the server will use the gRPC default.", | ||||
) | ) | ||||
if seven.IS_WINDOWS or force_port: | if seven.IS_WINDOWS or force_port: | ||||
self.server_process, self.port = open_server_process_on_dynamic_port( | self.server_process, self.port = open_server_process_on_dynamic_port( | ||||
max_retries=max_retries, | max_retries=max_retries, | ||||
loadable_target_origin=loadable_target_origin, | loadable_target_origin=loadable_target_origin, | ||||
max_workers=max_workers, | max_workers=max_workers, | ||||
heartbeat=heartbeat, | heartbeat=heartbeat, | ||||
heartbeat_timeout=heartbeat_timeout, | heartbeat_timeout=heartbeat_timeout, | ||||
lazy_load_user_code=lazy_load_user_code, | |||||
fixed_server_id=fixed_server_id, | fixed_server_id=fixed_server_id, | ||||
) | ) | ||||
else: | else: | ||||
self.socket = safe_tempfile_path_unmanaged() | self.socket = safe_tempfile_path_unmanaged() | ||||
self.server_process = open_server_process( | self.server_process = open_server_process( | ||||
port=None, | port=None, | ||||
socket=self.socket, | socket=self.socket, | ||||
loadable_target_origin=loadable_target_origin, | loadable_target_origin=loadable_target_origin, | ||||
max_workers=max_workers, | max_workers=max_workers, | ||||
heartbeat=heartbeat, | heartbeat=heartbeat, | ||||
heartbeat_timeout=heartbeat_timeout, | heartbeat_timeout=heartbeat_timeout, | ||||
lazy_load_user_code=lazy_load_user_code, | |||||
fixed_server_id=fixed_server_id, | fixed_server_id=fixed_server_id, | ||||
) | ) | ||||
if self.server_process is None: | if self.server_process is None: | ||||
raise CouldNotStartServerProcess(port=self.port, socket=self.socket) | raise CouldNotStartServerProcess(port=self.port, socket=self.socket) | ||||
@property | @property | ||||
def pid(self): | def pid(self): | ||||
Show All 12 Lines |