Driver broadcast of large spark config variables
Signed-off-by: Constantin M Adam <cmadam@us.ibm.com>
cmadam committed Oct 9, 2024
1 parent e581989 commit dc7466c
Showing 3 changed files with 35 additions and 6 deletions.
@@ -10,6 +10,9 @@
# limitations under the License.
################################################################################

from typing import Any

from data_processing.data_access import DataAccessFactoryBase
from data_processing.runtime import TransformRuntimeConfiguration
from data_processing.transform import TransformConfiguration
from data_processing_spark.runtime.spark import DefaultSparkTransformRuntime
@@ -29,6 +32,15 @@ def __init__(
        super().__init__(transform_config=transform_config)
        self.runtime_class = runtime_class

    def get_bcast_params(self, data_access_factory: DataAccessFactoryBase) -> dict[str, Any]:
        """Allows retrieving very large configuration parameters and broadcasting them to
        all the workers, for example the list of document IDs to remove for fuzzy dedup,
        or the list of blocked web domains for block listing. This function is called
        after Spark initialization and before spark_context.parallelize().
        :param data_access_factory - creates the data_access object used to download the large config parameter
        :return: dictionary of parameters to broadcast to the workers
        """
        return {}

    def create_transform_runtime(self) -> DefaultSparkTransformRuntime:
        """
        Create transform runtime with the parameters captured during apply_input_params()
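
A minimal sketch of how a concrete configuration could override the new get_bcast_params() hook added above; the subclass name, the local JSON file, and the "doc_ids_to_remove" key are illustrative assumptions, not part of this commit:

# Illustrative sketch only: a hypothetical fuzzy-dedup configuration that loads a
# large doc-ID list on the driver and returns it for broadcast to the workers.
import json
from typing import Any

from data_processing.data_access import DataAccessFactoryBase
from data_processing_spark.runtime.spark import SparkTransformRuntimeConfiguration


class FuzzyDedupSparkRuntimeConfiguration(SparkTransformRuntimeConfiguration):
    def get_bcast_params(self, data_access_factory: DataAccessFactoryBase) -> dict[str, Any]:
        # Runs once on the driver. A real implementation would use
        # data_access_factory.create_data_access() to fetch the list from storage;
        # a local JSON file (assumed path) stands in for that here.
        with open("docs_to_remove.json", "r") as f:
            return {"doc_ids_to_remove": json.load(f)}
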
@@ -18,9 +18,9 @@
from data_processing.transform import TransformStatistics
from data_processing.utils import GB, get_logger
from data_processing_spark.runtime.spark import (
    SparkTransformExecutionConfiguration,
    SparkTransformFileProcessor,
    SparkTransformRuntimeConfiguration,
    SparkTransformExecutionConfiguration,
)
from pyspark import SparkConf, SparkContext

@@ -45,6 +45,7 @@ def orchestrate(
logger.info(f"orchestrator started at {start_ts}")
# create data access
data_access = data_access_factory.create_data_access()
bcast_params = runtime_config.get_bcast_params(data_access_factory)
if data_access is None:
logger.error("No DataAccess instance provided - exiting")
return 1
@@ -53,6 +54,7 @@
    sc = SparkContext(conf=conf)
    spark_runtime_config = sc.broadcast(runtime_config)
    daf = sc.broadcast(data_access_factory)
    spark_bcast_params = sc.broadcast(bcast_params)

    def process_partition(iterator):
        """
@@ -63,6 +65,7 @@ def process_partition(iterator):
        # local statistics dictionary
        statistics = TransformStatistics()
        # create transformer runtime
        bcast_params = spark_bcast_params.value
        d_access_factory = daf.value
        runtime_conf = spark_runtime_config.value
        runtime = runtime_conf.create_transform_runtime()
@@ -77,8 +80,11 @@ def process_partition(iterator):
logger.debug(f"partition {f}")
# add additional parameters
transform_params = (
runtime.get_transform_config(partition=int(f[1]), data_access_factory=d_access_factory,
statistics=statistics))
runtime.get_transform_config(
partition=int(f[1]), data_access_factory=d_access_factory, statistics=statistics
)
| bcast_params
)
# create transform with partition number
file_processor.create_transform(transform_params)
first = False
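
The changes above follow the standard Spark broadcast pattern: compute the large value once on the driver, wrap it with sc.broadcast(), and have each partition read .value and merge it into its own parameters with the dict-union operator (|, Python 3.9+). A stand-alone toy sketch of that pattern, with stand-in names and data:

# Stand-alone toy sketch of the broadcast-and-merge pattern; names and data are stand-ins.
from pyspark import SparkConf, SparkContext

sc = SparkContext(conf=SparkConf().setAppName("bcast_sketch").setMaster("local[2]"))
large_params = {"doc_ids_to_remove": ["doc-1", "doc-2"]}  # stand-in for get_bcast_params()
spark_bcast_params = sc.broadcast(large_params)           # shipped to each executor once


def process_partition(iterator):
    bcast_params = spark_bcast_params.value               # read the broadcast value on the worker
    for item in iterator:
        # dict union: per-partition parameters plus the broadcast parameters
        yield {"partition": item} | bcast_params


print(sc.parallelize(range(3), 3).mapPartitions(process_partition).collect())
sc.stop()
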
@@ -128,7 +134,7 @@ def process_partition(iterator):
        memory = 0.0
        for i in range(executors.size()):
            memory += executors.toList().apply(i)._2()._1()
        resources = {"cpus": cpus, "gpus": 0, "memory": round(memory/GB, 2), "object_store": 0}
        resources = {"cpus": cpus, "gpus": 0, "memory": round(memory / GB, 2), "object_store": 0}
        input_params = runtime_config.get_transform_metadata() | execution_configuration.get_input_params()
        metadata = {
            "pipeline": execution_configuration.pipeline_id,
@@ -143,7 +149,8 @@
"execution_stats": {
"num partitions": num_partitions,
"execution time, min": round((time.time() - start_time) / 60, 3),
} | resources,
}
| resources,
"job_output_stats": stats,
}
logger.debug(f"Saving job metadata: {metadata}.")
@@ -37,16 +37,26 @@ def get_transform_config(
        config/params provided to this instance's initializer. This may include the addition
        of new configuration data such as ray shared memory, new actors, etc, that might be needed and
        expected by the transform in its initializer and/or transform() methods.
        :param partition - the partition assigned to this worker, needed by transforms like doc_id
        :param data_access_factory - data access factory class being used by the RayOrchestrator.
        :param statistics - reference to statistics actor
        :return: dictionary of transform init params
        """
        return self.params

    def get_bcast_params(self, data_access_factory: DataAccessFactoryBase) -> dict[str, Any]:
        """Allows retrieving very large configuration parameters and broadcasting them to
        all the workers, for example the list of document IDs to remove for fuzzy dedup,
        or the list of blocked web domains for block listing. This function is called
        after Spark initialization and before spark_context.parallelize().
        :param data_access_factory - creates the data_access object used to download the large config parameter
        :return: dictionary of parameters to broadcast to the workers
        """
        return {}

    def compute_execution_stats(self, stats: TransformStatistics) -> None:
        """
        Update/augment the given statistics object with runtime-specific additions/modifications.
        :param stats: output of statistics as aggregated across all calls to all transforms.
        :return: job execution statistics. These are generally reported as metadata by the Ray Orchestrator.
        """
        pass
        pass
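
Whatever get_bcast_params() returns is merged into the per-partition transform parameters (see the orchestrator change above), so a transform can read the broadcast value out of its init config dict. An illustrative, assumed consumer sketch; the transform class, the "doc_ids_to_remove" key, and the "document_id" column name are not from this commit:

# Illustrative sketch, not part of this commit: a transform consuming a broadcast parameter.
from typing import Any

import pyarrow as pa

from data_processing.transform import AbstractTableTransform


class FilterBroadcastDocIdsTransform(AbstractTableTransform):
    def __init__(self, config: dict[str, Any]):
        super().__init__(config)
        # the broadcast value arrives merged into the per-partition transform params
        self.doc_ids_to_remove = set(config.get("doc_ids_to_remove", []))

    def transform(self, table: pa.Table, file_name: str = None) -> tuple[list[pa.Table], dict[str, Any]]:
        # keep only rows whose document id is not in the broadcast removal list
        mask = [doc_id not in self.doc_ids_to_remove for doc_id in table.column("document_id").to_pylist()]
        return [table.filter(mask)], {"removed_rows": mask.count(False)}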
