import itertools
import json
from datetime import datetime
from typing import Dict, List, Mapping, NamedTuple, Optional, Sequence, Tuple
import dagster._check as check
from dagster._annotations import experimental
from dagster._core.errors import DagsterInvalidDefinitionError, DagsterInvalidInvocationError
from dagster._core.storage.tags import (
MULTIDIMENSIONAL_PARTITION_PREFIX,
get_multidimensional_partition_tag,
)
from .partition import (
DefaultPartitionsSubset,
Partition,
PartitionsDefinition,
StaticPartitionsDefinition,
)
# Characters rejected in the partition keys of a StaticPartitionsDefinition that is used
# as a dimension of a MultiPartitionsDefinition. "|" is the separator used in this module
# when joining per-dimension keys into a single multi-partition key string.
# NOTE(review): the other characters are presumably reserved by a serialization or tag
# format defined elsewhere -- confirm.
INVALID_STATIC_PARTITIONS_KEY_CHARACTERS = {"|", ",", "[", "]"}
class PartitionDimensionKey(
    NamedTuple("_PartitionDimensionKey", [("dimension_name", str), ("partition_key", str)])
):
    """
    One (dimension name, partition key) pair of a multi-dimensional partition key.

    Both fields are validated to be strings when the tuple is constructed.
    """

    def __new__(cls, dimension_name: str, partition_key: str):
        # Validate in field order so the first offending argument is the one reported.
        validated_name = check.str_param(dimension_name, "dimension_name")
        validated_key = check.str_param(partition_key, "partition_key")
        return super().__new__(
            cls,
            dimension_name=validated_name,
            partition_key=validated_key,
        )
class MultiPartitionKey(str):
    """
    A multi-dimensional partition key stores the partition key for each dimension.

    Subclasses the string class to keep partition key type as a string.
    Contains additional methods to access the partition key for each dimension.

    The string value is the partition key of each dimension joined by a pipe (|).
    Dimensions are ordered by name, to ensure a consistent string representation.

    Args:
        keys_by_dimension (Mapping[str, str]):
            A mapping of dimension name to the partition key in that dimension.

    Attributes:
        dimension_keys (List[PartitionDimensionKey]):
            The per-dimension keys, sorted by dimension name.
    """

    # Class-level fallback only: __new__ always assigns a fresh list on the instance.
    dimension_keys: List[PartitionDimensionKey] = []

    def __new__(cls, keys_by_dimension: Mapping[str, str]):
        # The reported parameter name must match the real one so validation errors are
        # actionable for the caller.
        check.mapping_param(keys_by_dimension, "keys_by_dimension", key_type=str, value_type=str)

        # Sorting by dimension name makes the string form independent of the mapping's
        # insertion order.
        dimension_keys: List[PartitionDimensionKey] = [
            PartitionDimensionKey(dimension, keys_by_dimension[dimension])
            for dimension in sorted(keys_by_dimension)
        ]

        str_key = super().__new__(
            cls, "|".join(dim_key.partition_key for dim_key in dimension_keys)
        )
        str_key.dimension_keys = dimension_keys
        return str_key

    def __getnewargs__(self):
        # When this instance is pickled, replace the argument to __new__ with the
        # dimension key mapping instead of the string representation.
        return (dict(self.keys_by_dimension),)

    @property
    def keys_by_dimension(self) -> Mapping[str, str]:
        """Return a new dict mapping each dimension name to its partition key."""
        return {dim_key.dimension_name: dim_key.partition_key for dim_key in self.dimension_keys}
class PartitionDimensionDefinition(
    NamedTuple(
        "_PartitionDimensionDefinition",
        [
            ("name", str),
            ("partitions_def", PartitionsDefinition),
        ],
    )
):
    """
    Pairs a dimension name with the PartitionsDefinition for that dimension.

    Both fields are type-checked on construction.
    """

    def __new__(
        cls,
        name: str,
        partitions_def: PartitionsDefinition,
    ):
        checked_name = check.str_param(name, "name")
        checked_def = check.inst_param(partitions_def, "partitions_def", PartitionsDefinition)
        return super().__new__(cls, name=checked_name, partitions_def=checked_def)

    # NOTE: overriding __eq__ without also defining __hash__ leaves instances of this
    # class unhashable (standard Python data-model behaviour).
    def __eq__(self, other):
        # Stricter than plain tuple equality: the other operand must itself be a
        # PartitionDimensionDefinition, not merely an equal-looking tuple.
        if not isinstance(other, PartitionDimensionDefinition):
            return False
        return self.name == other.name and self.partitions_def == other.partitions_def
@experimental
class MultiPartitionsDefinition(PartitionsDefinition):
    """
    Takes the cross-product of partitions from two partitions definitions.

    For example, with a static partitions definition where the partitions are ["a", "b", "c"]
    and a daily partitions definition, this partitions definition will have the following
    partitions:

        2020-01-01|a
        2020-01-01|b
        2020-01-01|c
        2020-01-02|a
        2020-01-02|b
        ...

    Args:
        partitions_defs (Mapping[str, PartitionsDefinition]):
            A mapping of dimension name to partitions definition. The total set of partitions will
            be the cross-product of the partitions from each PartitionsDefinition.

    Attributes:
        partitions_defs (Sequence[PartitionDimensionDefinition]):
            A sequence of PartitionDimensionDefinition objects, each of which contains a dimension
            name and a PartitionsDefinition. The total set of partitions will be the cross-product
            of the partitions from each PartitionsDefinition. This sequence is ordered by
            dimension name, to ensure consistent ordering of the partitions.

    Raises:
        DagsterInvalidInvocationError: If the mapping does not contain exactly two dimensions.
        DagsterInvalidDefinitionError: If a static dimension has a partition key containing
            one of the characters in INVALID_STATIC_PARTITIONS_KEY_CHARACTERS.
    """

    def __init__(self, partitions_defs: Mapping[str, PartitionsDefinition]):
        dimension_count = len(partitions_defs)
        if dimension_count != 2:
            raise DagsterInvalidInvocationError(
                "Dagster currently only supports multi-partitions definitions with 2 partitions"
                " definitions. Your multi-partitions definition has"
                f" {dimension_count} partitions definitions."
            )

        check.mapping_param(
            partitions_defs, "partitions_defs", key_type=str, value_type=PartitionsDefinition
        )

        # Only static dimensions are screened here: a key containing "|" could not be
        # split back into its per-dimension parts.
        for dim_name, partitions_def in partitions_defs.items():
            if not isinstance(partitions_def, StaticPartitionsDefinition):
                continue
            if any(
                INVALID_STATIC_PARTITIONS_KEY_CHARACTERS & set(key)
                for key in partitions_def.get_partition_keys()
            ):
                raise DagsterInvalidDefinitionError(
                    f"Invalid character in partition key for dimension {dim_name}. "
                    "A multi-partitions definition cannot contain partition keys with "
                    "the following characters: |, [, ], ,"
                )

        # Sorted by dimension name so ordering matches MultiPartitionKey's string layout.
        self._partitions_defs: List[PartitionDimensionDefinition] = sorted(
            [
                PartitionDimensionDefinition(name, partitions_def)
                for name, partitions_def in partitions_defs.items()
            ],
            key=lambda x: x.name,
        )

    @property
    def partition_dimension_names(self) -> List[str]:
        """Return the dimension names, sorted alphabetically."""
        return [dim_def.name for dim_def in self._partitions_defs]

    @property
    def partitions_defs(self) -> Sequence[PartitionDimensionDefinition]:
        """Return the per-dimension definitions, sorted by dimension name."""
        return self._partitions_defs

    def get_partitions(self, current_time: Optional[datetime] = None) -> Sequence[Partition]:
        """
        Return one Partition per element of the cross-product of every dimension's partitions.

        Each returned Partition has a MultiPartitionKey as its name and, as its value, a dict
        mapping dimension name to that dimension's Partition.

        Args:
            current_time (Optional[datetime]): Forwarded unchanged to each dimension's
                ``get_partitions`` call.
        """
        dimension_names = [dim_def.name for dim_def in self._partitions_defs]
        partition_sequences = [
            dim_def.partitions_def.get_partitions(current_time=current_time)
            for dim_def in self._partitions_defs
        ]

        multi_partitions: List[Partition] = []
        for partitions_tuple in itertools.product(*partition_sequences):
            check.invariant(len(partitions_tuple) == len(dimension_names))
            partitions_by_dimension: Dict[str, Partition] = dict(
                zip(dimension_names, partitions_tuple)
            )
            multi_partitions.append(
                Partition(
                    value=partitions_by_dimension,
                    name=MultiPartitionKey(
                        {
                            dimension_name: partition.name
                            for dimension_name, partition in partitions_by_dimension.items()
                        }
                    ),
                )
            )
        return multi_partitions

    def __eq__(self, other):
        return (
            isinstance(other, MultiPartitionsDefinition)
            and self.partitions_defs == other.partitions_defs
        )

    def __hash__(self):
        # Hash on (dimension name, repr) pairs rather than on the dimension definitions
        # themselves, which are not hashed directly here.
        return hash(tuple((dim_def.name, repr(dim_def)) for dim_def in self.partitions_defs))

    def __str__(self) -> str:
        dimension_lines = [
            f"{dim_def.name.capitalize()}: {str(dim_def.partitions_def)}"
            for dim_def in self._partitions_defs
        ]
        return "Multi-partitioned, with dimensions: \n" + " \n".join(dimension_lines)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(dimensions={[str(dim) for dim in self.partitions_defs]})"

    def get_multi_partition_key_from_str(self, partition_key_str: str) -> MultiPartitionKey:
        """
        Given a string representation of a partition key, returns a MultiPartitionKey object.

        The string is split on "|" and the pieces are matched positionally against the
        dimensions in name-sorted order. A check error is raised if the number of pieces
        differs from the number of dimensions, or if a piece is not a partition key of
        its dimension.
        """
        check.str_param(partition_key_str, "partition_key_str")

        partition_key_strs = partition_key_str.split("|")
        check.invariant(
            len(partition_key_strs) == len(self.partitions_defs),
            (
                f"Expected {len(self.partitions_defs)} partition keys in partition key string"
                f" {partition_key_str}, but got {len(partition_key_strs)}"
            ),
        )

        keys_by_dimension: Dict[str, str] = {}
        for dim_def, key in zip(self._partitions_defs, partition_key_strs):
            check.invariant(
                key in dim_def.partitions_def.get_partition_keys(),
                # Report the full dimension name, not just its first character.
                f"Partition key {key} not found in dimension {dim_def.name}",
            )
            keys_by_dimension[dim_def.name] = key

        return MultiPartitionKey(keys_by_dimension)

    def deserialize_subset(self, serialized: str) -> "MultiPartitionsSubset":
        """Build a MultiPartitionsSubset of this definition from its serialized form."""
        return MultiPartitionsSubset.from_serialized(self, serialized)
class MultiPartitionsSubset(DefaultPartitionsSubset):
    """A DefaultPartitionsSubset whose members are MultiPartitionKey objects."""

    @staticmethod
    def from_serialized(
        partitions_def: PartitionsDefinition, serialized: str
    ) -> "MultiPartitionsSubset":
        """
        Rebuild a subset from a JSON string.

        Args:
            partitions_def: Must be a MultiPartitionsDefinition; anything else fails a check.
            serialized: JSON text whose decoded value is iterated, each element being
                converted with ``partitions_def.get_multi_partition_key_from_str``.
        """
        if not isinstance(partitions_def, MultiPartitionsDefinition):
            check.failed(
                "Must pass a MultiPartitionsDefinition object to deserialize MultiPartitionsSubset."
            )

        raw_keys = json.loads(serialized)
        multi_partition_keys = {
            partitions_def.get_multi_partition_key_from_str(raw_key) for raw_key in raw_keys
        }
        return MultiPartitionsSubset(subset=multi_partition_keys, partitions_def=partitions_def)
def get_tags_from_multi_partition_key(multi_partition_key: MultiPartitionKey) -> Mapping[str, str]:
    """
    Build the tag mapping for a multi-partition key.

    Each dimension contributes one entry: the tag name returned by
    ``get_multidimensional_partition_tag`` for the dimension name, mapped to that
    dimension's partition key.
    """
    check.inst_param(multi_partition_key, "multi_partition_key", MultiPartitionKey)

    tags: Dict[str, str] = {}
    for dim_key in multi_partition_key.dimension_keys:
        tag_name = get_multidimensional_partition_tag(dim_key.dimension_name)
        tags[tag_name] = dim_key.partition_key
    return tags
def get_multipartition_key_from_tags(tags: Mapping[str, str]) -> str:
    """
    Reconstruct a MultiPartitionKey from a tag mapping.

    Only tags whose name starts with MULTIDIMENSIONAL_PARTITION_PREFIX are used; the
    remainder of the tag name after the prefix is taken as the dimension name and the
    tag value as that dimension's partition key. All other tags are ignored.
    """
    prefix_length = len(MULTIDIMENSIONAL_PARTITION_PREFIX)
    partitions_by_dimension: Dict[str, str] = {
        tag[prefix_length:]: tags[tag]
        for tag in tags
        if tag.startswith(MULTIDIMENSIONAL_PARTITION_PREFIX)
    }
    return MultiPartitionKey(partitions_by_dimension)