Source code for dagster._core.definitions.dependency

from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from typing import (
    TYPE_CHECKING,
    AbstractSet,
    Any,
    DefaultDict,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
    cast,
)

import dagster._check as check
from dagster._annotations import PublicAttr, public
from dagster._core.definitions.policy import RetryPolicy
from dagster._core.errors import DagsterInvalidDefinitionError
from dagster._serdes.serdes import (
    DefaultNamedTupleSerializer,
    WhitelistMap,
    register_serdes_tuple_fallbacks,
    whitelist_for_serdes,
)
from dagster._utils import frozentags

from .hook_definition import HookDefinition
from .input import FanInInputPointer, InputDefinition, InputMapping, InputPointer
from .output import OutputDefinition
from .utils import DEFAULT_OUTPUT, struct_to_string, validate_tags

if TYPE_CHECKING:
    from dagster._core.definitions.op_definition import OpDefinition

    from .asset_layer import AssetLayer
    from .composition import MappedInputPlaceholder
    from .graph_definition import GraphDefinition
    from .node_definition import NodeDefinition
    from .resource_requirement import ResourceRequirement


class NodeInvocation(
    NamedTuple(
        "Node",
        [
            ("name", PublicAttr[str]),
            ("alias", PublicAttr[Optional[str]]),
            ("tags", PublicAttr[Mapping[str, Any]]),
            ("hook_defs", PublicAttr[AbstractSet[HookDefinition]]),
            ("retry_policy", PublicAttr[Optional[RetryPolicy]]),
        ],
    )
):
    """Identifies an instance of a node in a graph dependency structure.

    Args:
        name (str): Name of the node of which this is an instance.
        alias (Optional[str]): Name specific to this instance of the node. Necessary when there are
            multiple instances of the same node.
        tags (Optional[Dict[str, Any]]): Optional tags to extend or override those set on the
            node definition.
        hook_defs (Optional[AbstractSet[HookDefinition]]): A set of hook definitions applied to the
            node instance.
        retry_policy (Optional[RetryPolicy]): The retry policy to apply to this node instance.

    Examples:
        In general, users should prefer not to construct this class directly or use the
        :py:class:`JobDefinition` API that requires instances of this class. Instead, use the
        :py:func:`@job <job>` API:

        .. code-block:: python

            from dagster import job

            @job
            def my_job():
                other_name = some_op.alias('other_name')
                some_graph(other_name(some_op))

    """

    def __new__(
        cls,
        name: str,
        alias: Optional[str] = None,
        tags: Optional[Mapping[str, str]] = None,
        hook_defs: Optional[AbstractSet[HookDefinition]] = None,
        retry_policy: Optional[RetryPolicy] = None,
    ):
        return super().__new__(
            cls,
            name=check.str_param(name, "name"),
            alias=check.opt_str_param(alias, "alias"),
            tags=frozentags(check.opt_mapping_param(tags, "tags", value_type=str, key_type=str)),
            hook_defs=frozenset(
                check.opt_set_param(hook_defs, "hook_defs", of_type=HookDefinition)
            ),
            retry_policy=check.opt_inst_param(retry_policy, "retry_policy", RetryPolicy),
        )
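
# Illustrative sketch (not part of the original module): how a NodeInvocation is
# typically consumed as a key in a GraphDefinition dependency dict. The op and graph
# names below ("return_one", "add_one", "adder_graph", "adder") are hypothetical.
def _example_node_invocation_usage() -> None:
    from dagster import DependencyDefinition, GraphDefinition, NodeInvocation, op

    @op
    def return_one() -> int:
        return 1

    @op
    def add_one(num: int) -> int:
        return num + 1

    # "add_one" is invoked under the alias "adder" and fed by "return_one".
    graph_def = GraphDefinition(
        name="adder_graph",
        node_defs=[return_one, add_one],
        dependencies={
            NodeInvocation("add_one", alias="adder"): {
                "num": DependencyDefinition("return_one")
            }
        },
    )
    assert graph_def.to_job().execute_in_process().output_for_node("adder") == 2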

class Node(ABC):
    """
    Node invocation within a graph. Identified by its name inside the graph.
    """

    name: str
    definition: "NodeDefinition"
    graph_definition: "GraphDefinition"
    _additional_tags: Mapping[str, str]
    _hook_defs: AbstractSet[HookDefinition]
    _retry_policy: Optional[RetryPolicy]
    _inputs: Mapping[str, "NodeInput"]
    _outputs: Mapping[str, "NodeOutput"]

    def __init__(
        self,
        name: str,
        definition: "NodeDefinition",
        graph_definition: "GraphDefinition",
        tags: Optional[Mapping[str, str]] = None,
        hook_defs: Optional[AbstractSet[HookDefinition]] = None,
        retry_policy: Optional[RetryPolicy] = None,
    ):
        from .graph_definition import GraphDefinition
        from .node_definition import NodeDefinition

        self.name = check.str_param(name, "name")
        self.definition = check.inst_param(definition, "definition", NodeDefinition)
        self.graph_definition = check.inst_param(
            graph_definition,
            "graph_definition",
            GraphDefinition,
        )
        self._additional_tags = validate_tags(tags)
        self._hook_defs = check.opt_set_param(hook_defs, "hook_defs", of_type=HookDefinition)
        self._retry_policy = check.opt_inst_param(retry_policy, "retry_policy", RetryPolicy)

        self._inputs = {
            name: NodeInput(self, input_def)
            for name, input_def in self.definition.input_dict.items()
        }
        self._outputs = {
            name: NodeOutput(self, output_def)
            for name, output_def in self.definition.output_dict.items()
        }

    def inputs(self) -> Iterable["NodeInput"]:
        return self._inputs.values()

    def outputs(self) -> Iterable["NodeOutput"]:
        return self._outputs.values()

    def get_input(self, name: str) -> "NodeInput":
        check.str_param(name, "name")
        return self._inputs[name]

    def get_output(self, name: str) -> "NodeOutput":
        check.str_param(name, "name")
        return self._outputs[name]

    def has_input(self, name: str) -> bool:
        return self.definition.has_input(name)

    def input_def_named(self, name: str) -> InputDefinition:
        return self.definition.input_def_named(name)

    def has_output(self, name: str) -> bool:
        return self.definition.has_output(name)

    def output_def_named(self, name: str) -> OutputDefinition:
        return self.definition.output_def_named(name)

    @property
    def input_dict(self) -> Mapping[str, InputDefinition]:
        return self.definition.input_dict

    @property
    def output_dict(self) -> Mapping[str, OutputDefinition]:
        return self.definition.output_dict

    @property
    def tags(self) -> frozentags:
        # Type-ignore temporarily pending assessment of right data structure for `tags`
        return self.definition.tags.updated_with(self._additional_tags)  # type: ignore

    def container_maps_input(self, input_name: str) -> bool:
        return (
            self.graph_definition.input_mapping_for_pointer(InputPointer(self.name, input_name))
            is not None
        )

    def container_mapped_input(self, input_name: str) -> InputMapping:
        mapping = self.graph_definition.input_mapping_for_pointer(
            InputPointer(self.name, input_name)
        )
        if mapping is None:
            check.failed(
                f"container does not map input {input_name}, check container_maps_input first"
            )
        return mapping

    def container_maps_fan_in_input(self, input_name: str, fan_in_index: int) -> bool:
        return (
            self.graph_definition.input_mapping_for_pointer(
                FanInInputPointer(self.name, input_name, fan_in_index)
            )
            is not None
        )

    def container_mapped_fan_in_input(self, input_name: str, fan_in_index: int) -> InputMapping:
        mapping = self.graph_definition.input_mapping_for_pointer(
            FanInInputPointer(self.name, input_name, fan_in_index)
        )
        if mapping is None:
            check.failed(
                f"container does not map fan-in {input_name} idx {fan_in_index}, check "
                "container_maps_fan_in_input first"
            )
        return mapping

    @property
    def hook_defs(self) -> AbstractSet[HookDefinition]:
        return self._hook_defs

    @property
    def retry_policy(self) -> Optional[RetryPolicy]:
        return self._retry_policy

    @abstractmethod
    def describe_node(self) -> str:
        ...

    @abstractmethod
    def get_resource_requirements(
        self,
        outer_container: "GraphDefinition",
        parent_handle: Optional["NodeHandle"] = None,
        asset_layer: Optional["AssetLayer"] = None,
    ) -> Iterator["ResourceRequirement"]:
        ...


class GraphNode(Node):
    definition: "GraphDefinition"

    def __init__(
        self,
        name: str,
        definition: "GraphDefinition",
        graph_definition: "GraphDefinition",
        tags: Optional[Mapping[str, str]] = None,
        hook_defs: Optional[AbstractSet[HookDefinition]] = None,
        retry_policy: Optional[RetryPolicy] = None,
    ):
        from .graph_definition import GraphDefinition

        check.inst_param(definition, "definition", GraphDefinition)
        super().__init__(name, definition, graph_definition, tags, hook_defs, retry_policy)

    def get_resource_requirements(
        self,
        outer_container: "GraphDefinition",
        parent_handle: Optional["NodeHandle"] = None,
        asset_layer: Optional["AssetLayer"] = None,
    ) -> Iterator["ResourceRequirement"]:
        cur_node_handle = NodeHandle(self.name, parent_handle)

        for node in self.definition.node_dict.values():
            yield from node.get_resource_requirements(
                asset_layer=asset_layer,
                outer_container=self.definition,
                parent_handle=cur_node_handle,
            )

    def describe_node(self) -> str:
        return f"graph '{self.name}'"


class OpNode(Node):
    definition: "OpDefinition"

    def __init__(
        self,
        name: str,
        definition: "OpDefinition",
        graph_definition: "GraphDefinition",
        tags: Optional[Mapping[str, str]] = None,
        hook_defs: Optional[AbstractSet[HookDefinition]] = None,
        retry_policy: Optional[RetryPolicy] = None,
    ):
        from .op_definition import OpDefinition

        check.inst_param(definition, "definition", OpDefinition)
        super().__init__(name, definition, graph_definition, tags, hook_defs, retry_policy)

    def get_resource_requirements(
        self,
        outer_container: "GraphDefinition",
        parent_handle: Optional["NodeHandle"] = None,
        asset_layer: Optional["AssetLayer"] = None,
    ) -> Iterator["ResourceRequirement"]:
        from .resource_requirement import InputManagerRequirement

        cur_node_handle = NodeHandle(self.name, parent_handle)

        for requirement in self.definition.get_resource_requirements(
            (cur_node_handle, asset_layer)
        ):
            # If requirement is a root input manager requirement, but the corresponding node
            # has an upstream output, then ignore the requirement.
            if (
                isinstance(requirement, InputManagerRequirement)
                and outer_container.dependency_structure.has_deps(
                    NodeInput(self, self.definition.input_def_named(requirement.input_name))
                )
                and requirement.root_input
            ):
                continue
            yield requirement
        for hook_def in self.hook_defs:
            yield from hook_def.get_resource_requirements(self.describe_node())

    def describe_node(self) -> str:
        return f"op '{self.name}'"


class NodeHandleSerializer(DefaultNamedTupleSerializer):
    @classmethod
    def value_to_storage_dict(
        cls,
        value: NamedTuple,
        whitelist_map: WhitelistMap,
        descent_path: str,
    ) -> Dict[str, Any]:
        storage = super().value_to_storage_dict(
            value,
            whitelist_map,
            descent_path,
        )
        # persist using legacy name SolidHandle
        storage["__class__"] = "SolidHandle"
        return storage


@whitelist_for_serdes(serializer=NodeHandleSerializer)
class NodeHandle(
    # mypy does not yet support recursive types
    # NamedTuple("_NodeHandle", [("name", str), ("parent", Optional["NodeHandle"])])
    NamedTuple("_NodeHandle", [("name", str), ("parent", Any)])
):
    """
    A structured object to identify nodes in the potentially recursive graph structure.
    """

    def __new__(cls, name: str, parent: Optional["NodeHandle"]):
        return super(NodeHandle, cls).__new__(
            cls,
            check.str_param(name, "name"),
            check.opt_inst_param(parent, "parent", NodeHandle),
        )

    def __str__(self):
        return self.to_string()

    @property
    def root(self):
        if self.parent:
            return self.parent.root
        else:
            return self

    @property
    def path(self) -> Sequence[str]:
        """Return a list representation of the handle.

        Inverse of NodeHandle.from_path.

        Returns:
            List[str]:
        """
        path = []
        cur = self
        while cur:
            path.append(cur.name)
            cur = cur.parent
        path.reverse()
        return path

    def to_string(self) -> str:
        """Return a unique string representation of the handle.

        Inverse of NodeHandle.from_string.
        """
        return self.parent.to_string() + "." + self.name if self.parent else self.name

    def is_or_descends_from(self, handle: "NodeHandle") -> bool:
        """Check if the handle is or descends from another handle.

        Args:
            handle (NodeHandle): The handle to check against.

        Returns:
            bool:
        """
        check.inst_param(handle, "handle", NodeHandle)

        for idx in range(len(handle.path)):
            if idx >= len(self.path):
                return False

            if self.path[idx] != handle.path[idx]:
                return False
        return True

    def pop(self, ancestor: "NodeHandle") -> Optional["NodeHandle"]:
        """Return a copy of the handle with some of its ancestors pruned.

        Args:
            ancestor (NodeHandle): Handle to an ancestor of the current handle.

        Returns:
            NodeHandle:

        Example:

        .. code-block:: python

            handle = NodeHandle('baz', NodeHandle('bar', NodeHandle('foo', None)))
            ancestor = NodeHandle('bar', NodeHandle('foo', None))
            assert handle.pop(ancestor) == NodeHandle('baz', None)
        """

        check.inst_param(ancestor, "ancestor", NodeHandle)
        check.invariant(
            self.is_or_descends_from(ancestor),
            "Handle {handle} does not descend from {ancestor}".format(
                handle=self.to_string(), ancestor=ancestor.to_string()
            ),
        )

        return NodeHandle.from_path(self.path[len(ancestor.path) :])

    def with_ancestor(self, ancestor: Optional["NodeHandle"]) -> "NodeHandle":
        """Returns a copy of the handle with an ancestor grafted on.

        Args:
            ancestor (NodeHandle): Handle to the new ancestor.

        Returns:
            NodeHandle:

        Example:

        .. code-block:: python

            handle = NodeHandle('baz', NodeHandle('bar', NodeHandle('foo', None)))
            ancestor = NodeHandle('quux', None)
            assert handle.with_ancestor(ancestor) == NodeHandle(
                'baz', NodeHandle('bar', NodeHandle('foo', NodeHandle('quux', None)))
            )
        """
        check.opt_inst_param(ancestor, "ancestor", NodeHandle)

        return NodeHandle.from_path([*(ancestor.path if ancestor else []), *self.path])

    @staticmethod
    def from_path(path: Sequence[str]) -> "NodeHandle":
        check.sequence_param(path, "path", of_type=str)

        cur: Optional["NodeHandle"] = None
        _path = list(path)
        while len(_path) > 0:
            cur = NodeHandle(name=_path.pop(0), parent=cur)

        if cur is None:
            check.failed(f"Invalid handle path {path}")

        return cur

    @staticmethod
    def from_string(handle_str: str) -> "NodeHandle":
        check.str_param(handle_str, "handle_str")

        path = handle_str.split(".")
        return NodeHandle.from_path(path)

    @classmethod
    def from_dict(cls, dict_repr: Dict[str, Any]) -> Optional["NodeHandle"]:
        """This method makes it possible to load a potentially nested NodeHandle after a
        roundtrip through json.loads(json.dumps(NodeHandle._asdict())).
        """
        check.dict_param(dict_repr, "dict_repr", key_type=str)
        check.invariant(
            "name" in dict_repr, "Dict representation of NodeHandle must have a 'name' key"
        )
        check.invariant(
            "parent" in dict_repr, "Dict representation of NodeHandle must have a 'parent' key"
        )

        if isinstance(dict_repr["parent"], (list, tuple)):
            dict_repr["parent"] = NodeHandle.from_dict(
                {
                    "name": dict_repr["parent"][0],
                    "parent": dict_repr["parent"][1],
                }
            )

        return NodeHandle(**{k: dict_repr[k] for k in ["name", "parent"]})


class NodeInputHandle(
    NamedTuple("_NodeInputHandle", [("node_handle", NodeHandle), ("input_name", str)])
):
    """
    A structured object to uniquely identify inputs in the potentially recursive graph structure.
    """


class NodeOutputHandle(
    NamedTuple("_NodeOutputHandle", [("node_handle", NodeHandle), ("output_name", str)])
):
    """
    A structured object to uniquely identify outputs in the potentially recursive graph structure.
    """


# previous name for NodeHandle was SolidHandle
register_serdes_tuple_fallbacks({"SolidHandle": NodeHandle})


class NodeInput(NamedTuple("_NodeInput", [("node", Node), ("input_def", InputDefinition)])):
    def __new__(cls, node: Node, input_def: InputDefinition):
        return super(NodeInput, cls).__new__(
            cls,
            check.inst_param(node, "node", Node),
            check.inst_param(input_def, "input_def", InputDefinition),
        )

    def _inner_str(self) -> str:
        return struct_to_string(
            "NodeInput",
            node_name=self.node.name,
            input_name=self.input_def.name,
        )

    def __str__(self):
        return self._inner_str()

    def __repr__(self):
        return self._inner_str()

    def __hash__(self):
        return hash((self.node.name, self.input_def.name))

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, NodeInput)
            and self.node.name == other.node.name
            and self.input_def.name == other.input_def.name
        )

    @property
    def node_name(self) -> str:
        return self.node.name

    @property
    def solid_name(self) -> str:
        return self.node.name

    @property
    def input_name(self) -> str:
        return self.input_def.name


class NodeOutput(NamedTuple("_NodeOutput", [("node", Node), ("output_def", OutputDefinition)])):
    def __new__(cls, node: Node, output_def: OutputDefinition):
        return super(NodeOutput, cls).__new__(
            cls,
            check.inst_param(node, "node", Node),
            check.inst_param(output_def, "output_def", OutputDefinition),
        )

    def _inner_str(self) -> str:
        return struct_to_string(
            "NodeOutput",
            node_name=self.node.name,
            output_name=self.output_def.name,
        )

    def __str__(self):
        return self._inner_str()

    def __repr__(self):
        return self._inner_str()

    def __hash__(self):
        return hash((self.node.name, self.output_def.name))

    def __eq__(self, other: Any):
        return self.node.name == other.node.name and self.output_def.name == other.output_def.name

    def describe(self) -> str:
        return f"{self.node_name}:{self.output_def.name}"

    @property
    def node_name(self) -> str:
        return self.node.name

    @property
    def is_dynamic(self) -> bool:
        return self.output_def.is_dynamic


class DependencyType(Enum):
    DIRECT = "DIRECT"
    FAN_IN = "FAN_IN"
    DYNAMIC_COLLECT = "DYNAMIC_COLLECT"


class IDependencyDefinition(ABC):  # pylint: disable=no-init
    @abstractmethod
    def get_node_dependencies(self) -> Sequence["DependencyDefinition"]:
        pass

    @abstractmethod
    def is_fan_in(self) -> bool:
        """The result passed to the corresponding input will be a List made from different node
        outputs.
        """

class DependencyDefinition(
    NamedTuple(
        "_DependencyDefinition", [("node", str), ("output", str), ("description", Optional[str])]
    ),
    IDependencyDefinition,
):
    """Represents an edge in the DAG of nodes (ops or graphs) forming a job.

    This object is used at the leaves of a dictionary structure that represents the complete
    dependency structure of a job whose keys represent the dependent node and dependent input, so
    this object only contains information about the dependee.

    Concretely, if the input named 'input' of my_downstream_op depends on the output named
    'result' of my_upstream_op, and the input named 'input' of my_other_downstream_op depends on
    the output named 'result' of my_upstream_graph, the structure will look as follows:

    .. code-block:: python

        dependency_structure = {
            'my_downstream_op': {
                'input': DependencyDefinition('my_upstream_op', 'result')
            },
            'my_other_downstream_op': {
                'input': DependencyDefinition('my_upstream_graph', 'result')
            },
        }

    In general, users should prefer not to construct this class directly or use the
    :py:class:`JobDefinition` API that requires instances of this class. Instead, use the
    :py:func:`@job <job>` API:

    .. code-block:: python

        @job
        def the_job():
            node_b(node_a())

    Args:
        solid (str): (legacy) The name of the solid that is depended on, that is, from which the
            value passed between the two nodes originates.
        output (Optional[str]): The name of the output that is depended on. (default: "result")
        description (Optional[str]): Human-readable description of this dependency.
        node (str): The name of the node (op or graph) that is depended on, that is, from which
            the value passed between the two nodes originates.
    """

    def __new__(
        cls,
        solid: Optional[str] = None,
        output: str = DEFAULT_OUTPUT,
        description: Optional[str] = None,
        node: Optional[str] = None,
    ):
        if solid and node:
            raise DagsterInvalidDefinitionError(
                "Both ``node`` and legacy ``solid`` arguments provided to DependencyDefinition."
                " Please use one or the other."
            )

        if not solid and not node:
            raise DagsterInvalidDefinitionError(
                "Expected node parameter to be str for DependencyDefinition"
            )

        node = node or solid
        return super(DependencyDefinition, cls).__new__(
            cls,
            check.str_param(node, "node"),
            check.str_param(output, "output"),
            check.opt_str_param(description, "description"),
        )

    def get_node_dependencies(self) -> Sequence["DependencyDefinition"]:
        return [self]

    def is_fan_in(self) -> bool:
        return False

    def get_op_dependencies(self) -> Sequence["DependencyDefinition"]:
        return [self]
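
# Illustrative sketch (not part of the original module): the plain dictionary form that
# DependencyDefinition is designed for, wiring one op's output into another op's input via
# GraphDefinition. The op names ("emit", "double") and graph name are hypothetical.
def _example_dependency_definition_usage() -> None:
    from dagster import DependencyDefinition, GraphDefinition, op

    @op
    def emit() -> int:
        return 21

    @op
    def double(x: int) -> int:
        return x * 2

    graph_def = GraphDefinition(
        name="doubling_graph",
        node_defs=[emit, double],
        dependencies={"double": {"x": DependencyDefinition("emit")}},
    )
    # Equivalent to writing `double(emit())` inside an @graph/@job body.
    assert graph_def.to_job().execute_in_process().output_for_node("double") == 42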

class MultiDependencyDefinition(
    NamedTuple(
        "_MultiDependencyDefinition",
        [
            (
                "dependencies",
                PublicAttr[Sequence[Union[DependencyDefinition, Type["MappedInputPlaceholder"]]]],
            )
        ],
    ),
    IDependencyDefinition,
):
    """Represents a fan-in edge in the DAG of op instances forming a job.

    This object is used only when an input of type ``List[T]`` is assembled by fanning-in multiple
    upstream outputs of type ``T``.

    This object is used at the leaves of a dictionary structure that represents the complete
    dependency structure of a job or pipeline whose keys represent the dependent ops or graphs and
    dependent input, so this object only contains information about the dependee.

    Concretely, if the input named 'input' of op_c depends on the outputs named 'result' of
    op_a and op_b, this structure will look as follows:

    .. code-block:: python

        dependency_structure = {
            'op_c': {
                'input': MultiDependencyDefinition(
                    [
                        DependencyDefinition('op_a', 'result'),
                        DependencyDefinition('op_b', 'result')
                    ]
                )
            }
        }

    In general, users should prefer not to construct this class directly or use the
    :py:class:`JobDefinition` API that requires instances of this class. Instead, use the
    :py:func:`@job <job>` API:

    .. code-block:: python

        @job
        def the_job():
            op_c(op_a(), op_b())

    Args:
        dependencies (List[Union[DependencyDefinition, Type[MappedInputPlaceholder]]]): List of
            upstream dependencies fanned in to this input.
    """

    def __new__(
        cls,
        dependencies: Sequence[Union[DependencyDefinition, Type["MappedInputPlaceholder"]]],
    ):
        from .composition import MappedInputPlaceholder

        deps = check.sequence_param(dependencies, "dependencies")
        seen = {}
        for dep in deps:
            if isinstance(dep, DependencyDefinition):
                key = dep.node + ":" + dep.output
                if key in seen:
                    raise DagsterInvalidDefinitionError(
                        f'Duplicate dependencies on node "{dep.node}" output "{dep.output}" '
                        "used in the same MultiDependencyDefinition."
                    )
                seen[key] = True
            elif dep is MappedInputPlaceholder:
                pass
            else:
                check.failed("Unexpected dependencies entry {}".format(dep))

        return super(MultiDependencyDefinition, cls).__new__(cls, deps)

    @public
    def get_node_dependencies(self) -> Sequence[DependencyDefinition]:
        return [dep for dep in self.dependencies if isinstance(dep, DependencyDefinition)]

    @public
    def is_fan_in(self) -> bool:
        return True

    @public
    def get_dependencies_and_mappings(
        self,
    ) -> Sequence[Union[DependencyDefinition, Type["MappedInputPlaceholder"]]]:
        return self.dependencies
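
# Illustrative sketch (not part of the original module): a fan-in dependency where a List[int]
# input collects two upstream int outputs, as described in the MultiDependencyDefinition
# docstring. The op names ("one", "two", "total") and graph name are hypothetical.
def _example_multi_dependency_usage() -> None:
    from dagster import DependencyDefinition, GraphDefinition, MultiDependencyDefinition, op

    @op
    def one() -> int:
        return 1

    @op
    def two() -> int:
        return 2

    @op
    def total(nums: List[int]) -> int:
        return sum(nums)

    graph_def = GraphDefinition(
        name="fan_in_graph",
        node_defs=[one, two, total],
        dependencies={
            "total": {
                "nums": MultiDependencyDefinition(
                    [DependencyDefinition("one"), DependencyDefinition("two")]
                )
            }
        },
    )
    assert graph_def.to_job().execute_in_process().output_for_node("total") == 3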

class DynamicCollectDependencyDefinition(
    NamedTuple("_DynamicCollectDependencyDefinition", [("solid_name", str), ("output_name", str)]),
    IDependencyDefinition,
):
    def get_node_dependencies(self) -> Sequence[DependencyDefinition]:
        return [DependencyDefinition(self.solid_name, self.output_name)]

    def is_fan_in(self) -> bool:
        return True


DepTypeAndOutputs = Tuple[
    DependencyType,
    Union[NodeOutput, List[Union[NodeOutput, Type["MappedInputPlaceholder"]]]],
]

InputToOutputMap = Dict[NodeInput, DepTypeAndOutputs]


def _create_handle_dict(
    solid_dict: Mapping[str, Node],
    dep_dict: Mapping[str, Mapping[str, IDependencyDefinition]],
) -> InputToOutputMap:
    from .composition import MappedInputPlaceholder

    check.mapping_param(solid_dict, "solid_dict", key_type=str, value_type=Node)
    check.two_dim_mapping_param(dep_dict, "dep_dict", value_type=IDependencyDefinition)

    handle_dict: InputToOutputMap = {}

    for solid_name, input_dict in dep_dict.items():
        from_solid = solid_dict[solid_name]
        for input_name, dep_def in input_dict.items():
            if isinstance(dep_def, MultiDependencyDefinition):
                handles: List[Union[NodeOutput, Type[MappedInputPlaceholder]]] = []
                for inner_dep in dep_def.get_dependencies_and_mappings():
                    if isinstance(inner_dep, DependencyDefinition):
                        handles.append(solid_dict[inner_dep.node].get_output(inner_dep.output))
                    elif inner_dep is MappedInputPlaceholder:
                        handles.append(inner_dep)
                    else:
                        check.failed(
                            "Unexpected MultiDependencyDefinition dependencies type {}".format(
                                inner_dep
                            )
                        )

                handle_dict[from_solid.get_input(input_name)] = (DependencyType.FAN_IN, handles)

            elif isinstance(dep_def, DependencyDefinition):
                handle_dict[from_solid.get_input(input_name)] = (
                    DependencyType.DIRECT,
                    solid_dict[dep_def.node].get_output(dep_def.output),
                )
            elif isinstance(dep_def, DynamicCollectDependencyDefinition):
                handle_dict[from_solid.get_input(input_name)] = (
                    DependencyType.DYNAMIC_COLLECT,
                    solid_dict[dep_def.solid_name].get_output(dep_def.output_name),
                )
            else:
                check.failed(f"Unknown dependency type {dep_def}")

    return handle_dict


class DependencyStructure:
    @staticmethod
    def from_definitions(solids: Mapping[str, Node], dep_dict: Mapping[str, Any]):
        return DependencyStructure(list(dep_dict.keys()), _create_handle_dict(solids, dep_dict))

    _node_input_index: DefaultDict[str, Dict[NodeInput, List[NodeOutput]]]
    _node_output_index: Dict[str, DefaultDict[NodeOutput, List[NodeInput]]]
    _dynamic_fan_out_index: Dict[str, NodeOutput]
    _collect_index: Dict[str, Set[NodeOutput]]

    def __init__(self, node_names: Sequence[str], input_to_output_map: InputToOutputMap):
        self._node_names = node_names
        self._input_to_output_map = input_to_output_map

        # Building up a couple indexes here so that one can look up all the upstream output
        # handles or downstream input handles in O(1). Without this, this can become O(N^2)
        # where N is solid count during the GraphQL query in particular

        # solid_name => input_handle => list[output_handle]
        self._node_input_index = defaultdict(dict)

        # solid_name => output_handle => list[input_handle]
        self._node_output_index = defaultdict(lambda: defaultdict(list))

        # solid_name => dynamic output_handle that this solid will dupe for
        self._dynamic_fan_out_index = {}

        # solid_name => set of dynamic output_handle this collects over
        self._collect_index = defaultdict(set)

        for node_input, (dep_type, node_output_or_list) in self._input_to_output_map.items():
            if dep_type == DependencyType.FAN_IN:
                node_output_list = []
                for node_output in node_output_or_list:
                    if not isinstance(node_output, NodeOutput):
                        continue

                    if node_output.is_dynamic:
                        raise DagsterInvalidDefinitionError(
                            "Currently, items in a fan-in dependency cannot be downstream of"
                            " dynamic outputs. Problematic dependency on dynamic output"
                            f' "{node_output.describe()}".'
                        )
                    if self._dynamic_fan_out_index.get(node_output.node_name):
                        raise DagsterInvalidDefinitionError(
                            "Currently, items in a fan-in dependency cannot be downstream of"
                            " dynamic outputs. Problematic dependency on output"
                            f' "{node_output.describe()}", downstream of'
                            f' "{self._dynamic_fan_out_index[node_output.node_name].describe()}".'
                        )

                    node_output_list.append(node_output)
            elif dep_type == DependencyType.DIRECT:
                node_output = cast(NodeOutput, node_output_or_list)

                if node_output.is_dynamic:
                    self._validate_and_set_fan_out(node_input, node_output)

                if self._dynamic_fan_out_index.get(node_output.node_name):
                    self._validate_and_set_fan_out(
                        node_input, self._dynamic_fan_out_index[node_output.node_name]
                    )

                node_output_list = [node_output]
            elif dep_type == DependencyType.DYNAMIC_COLLECT:
                node_output = cast(NodeOutput, node_output_or_list)

                if node_output.is_dynamic:
                    self._validate_and_set_collect(node_input, node_output)

                elif self._dynamic_fan_out_index.get(node_output.node_name):
                    self._validate_and_set_collect(
                        node_input,
                        self._dynamic_fan_out_index[node_output.node_name],
                    )
                else:
                    check.failed(
                        f"Unexpected dynamic fan in dep created {node_output} -> {node_input}"
                    )

                node_output_list = [node_output]
            else:
                check.failed(f"Unexpected dep type {dep_type}")

            self._node_input_index[node_input.node.name][node_input] = node_output_list
            for node_output in node_output_list:
                self._node_output_index[node_output.node.name][node_output].append(node_input)

    def _validate_and_set_fan_out(self, node_input: NodeInput, node_output: NodeOutput) -> None:
        """Helper function for populating _dynamic_fan_out_index."""
        if not node_input.node.definition.input_supports_dynamic_output_dep(node_input.input_name):
            raise DagsterInvalidDefinitionError(
                f"{node_input.node.describe_node()} cannot be downstream of dynamic output"
                f' "{node_output.describe()}" since input "{node_input.input_name}" maps to a node'
                " that is already downstream of another dynamic output. Nodes cannot be downstream"
                " of more than one dynamic output"
            )

        if self._collect_index.get(node_input.node_name):
            raise DagsterInvalidDefinitionError(
                f"{node_input.node.describe_node()} cannot be both downstream of dynamic output "
                f"{node_output.describe()} and collect over dynamic output "
                f"{list(self._collect_index[node_input.node_name])[0].describe()}."
            )

        if self._dynamic_fan_out_index.get(node_input.node_name) is None:
            self._dynamic_fan_out_index[node_input.node_name] = node_output
            return

        if self._dynamic_fan_out_index[node_input.node_name] != node_output:
            raise DagsterInvalidDefinitionError(
                f"{node_input.node.describe_node()} cannot be downstream of more than one dynamic"
                f' output. It is downstream of both "{node_output.describe()}" and'
                f' "{self._dynamic_fan_out_index[node_input.node_name].describe()}"'
            )

    def _validate_and_set_collect(
        self,
        node_input: NodeInput,
        node_output: NodeOutput,
    ) -> None:
        if self._dynamic_fan_out_index.get(node_input.node_name):
            raise DagsterInvalidDefinitionError(
                f"{node_input.node.describe_node()} cannot both collect over dynamic output "
                f"{node_output.describe()} and be downstream of the dynamic output "
                f"{self._dynamic_fan_out_index[node_input.node_name].describe()}."
            )

        self._collect_index[node_input.node_name].add(node_output)

        # if the output is already fanned out
        if self._dynamic_fan_out_index.get(node_output.node_name):
            raise DagsterInvalidDefinitionError(
                f"{node_input.node.describe_node()} cannot be downstream of more than one dynamic"
                f' output. It is downstream of both "{node_output.describe()}" and'
                f' "{self._dynamic_fan_out_index[node_output.node_name].describe()}"'
            )

    def all_upstream_outputs_from_node(self, node_name: str) -> Sequence[NodeOutput]:
        check.str_param(node_name, "node_name")

        # flatten out all outputs that feed into the inputs of this solid
        return [
            output_handle
            for output_handle_list in self._node_input_index[node_name].values()
            for output_handle in output_handle_list
        ]

    def input_to_upstream_outputs_for_node(
        self, node_name: str
    ) -> Mapping[NodeInput, Sequence[NodeOutput]]:
        """
        Returns a Dict[NodeInput, List[NodeOutput]] that encodes where all the inputs are
        sourced from upstream. Usually the List[NodeOutput] will be a list of one, except for the
        multi-dependency case.
        """
        check.str_param(node_name, "node_name")
        return self._node_input_index[node_name]

    def output_to_downstream_inputs_for_node(
        self, node_name: str
    ) -> Mapping[NodeOutput, Sequence[NodeInput]]:
        """
        Returns a Dict[NodeOutput, List[NodeInput]] that represents all the downstream inputs for
        each output in the dictionary.
        """
        check.str_param(node_name, "node_name")
        return self._node_output_index[node_name]

    def has_direct_dep(self, node_input: NodeInput) -> bool:
        check.inst_param(node_input, "node_input", NodeInput)
        if node_input not in self._input_to_output_map:
            return False
        dep_type, _ = self._input_to_output_map[node_input]
        return dep_type == DependencyType.DIRECT

    def get_direct_dep(self, node_input: NodeInput) -> NodeOutput:
        check.inst_param(node_input, "node_input", NodeInput)
        dep_type, dep = self._input_to_output_map[node_input]
        check.invariant(
            dep_type == DependencyType.DIRECT,
            f"Cannot call get_direct_dep when dep is not singular, got {dep_type}",
        )
        return cast(NodeOutput, dep)

    def has_fan_in_deps(self, node_input: NodeInput) -> bool:
        check.inst_param(node_input, "node_input", NodeInput)
        if node_input not in self._input_to_output_map:
            return False
        dep_type, _ = self._input_to_output_map[node_input]
        return dep_type == DependencyType.FAN_IN

    def get_fan_in_deps(
        self, node_input: NodeInput
    ) -> Sequence[Union[NodeOutput, Type["MappedInputPlaceholder"]]]:
        check.inst_param(node_input, "node_input", NodeInput)
        dep_type, deps = self._input_to_output_map[node_input]
        check.invariant(
            dep_type == DependencyType.FAN_IN,
            f"Cannot call get_fan_in_deps when dep is not fan in, got {dep_type}",
        )
        return cast(List[Union[NodeOutput, Type["MappedInputPlaceholder"]]], deps)

    def has_dynamic_fan_in_dep(self, node_input: NodeInput) -> bool:
        check.inst_param(node_input, "node_input", NodeInput)
        if node_input not in self._input_to_output_map:
            return False
        dep_type, _ = self._input_to_output_map[node_input]
        return dep_type == DependencyType.DYNAMIC_COLLECT

    def get_dynamic_fan_in_dep(self, node_input: NodeInput) -> NodeOutput:
        check.inst_param(node_input, "node_input", NodeInput)
        dep_type, dep = self._input_to_output_map[node_input]
        check.invariant(
            dep_type == DependencyType.DYNAMIC_COLLECT,
            f"Cannot call get_dynamic_fan_in_dep when dep is not dynamic collect, got {dep_type}",
        )
        return cast(NodeOutput, dep)

    def has_deps(self, node_input: NodeInput) -> bool:
        check.inst_param(node_input, "node_input", NodeInput)
        return node_input in self._input_to_output_map

    def get_deps_list(self, node_input: NodeInput) -> Sequence[NodeOutput]:
        check.inst_param(node_input, "node_input", NodeInput)
        check.invariant(self.has_deps(node_input))
        dep_type, handle_or_list = self._input_to_output_map[node_input]
        if dep_type == DependencyType.DIRECT:
            return [cast(NodeOutput, handle_or_list)]
        elif dep_type == DependencyType.DYNAMIC_COLLECT:
            return [cast(NodeOutput, handle_or_list)]
        elif dep_type == DependencyType.FAN_IN:
            return [handle for handle in handle_or_list if isinstance(handle, NodeOutput)]
        else:
            check.failed(f"Unexpected dep type {dep_type}")

    def inputs(self) -> Sequence[NodeInput]:
        return list(self._input_to_output_map.keys())

    def get_upstream_dynamic_output_for_node(self, node_name: str) -> Optional[NodeOutput]:
        return self._dynamic_fan_out_index.get(node_name)

    def get_dependency_type(self, node_input: NodeInput) -> Optional[DependencyType]:
        result = self._input_to_output_map.get(node_input)
        if result is None:
            return None
        dep_type, _ = result
        return dep_type

    def is_dynamic_mapped(self, node_name: str) -> bool:
        return node_name in self._dynamic_fan_out_index

    def has_dynamic_downstreams(self, node_name: str) -> bool:
        for node_output in self._dynamic_fan_out_index.values():
            if node_output.node_name == node_name:
                return True
        return False