diff --git a/data-processing-lib/spark/src/data_processing_spark/runtime/spark/transform_orchestrator.py b/data-processing-lib/spark/src/data_processing_spark/runtime/spark/transform_orchestrator.py
index 20d12f78e..c279f2b73 100644
--- a/data-processing-lib/spark/src/data_processing_spark/runtime/spark/transform_orchestrator.py
+++ b/data-processing-lib/spark/src/data_processing_spark/runtime/spark/transform_orchestrator.py
@@ -10,10 +10,13 @@
 # limitations under the License.
 ################################################################################
 
+import os
+import socket
 import time
 import traceback
 from datetime import datetime
 
+import yaml
 from data_processing.data_access import DataAccessFactoryBase
 from data_processing.transform import TransformStatistics
 from data_processing.utils import GB, get_logger
@@ -23,11 +26,53 @@
     SparkTransformRuntimeConfiguration,
 )
 from pyspark import SparkConf, SparkContext
+from pyspark.sql import SparkSession
 
 
 logger = get_logger(__name__)
 
 
+def _init_spark(runtime_config: SparkTransformRuntimeConfiguration) -> SparkSession:
+    server_port_https = int(os.getenv("KUBERNETES_SERVICE_PORT_HTTPS", "-1"))
+    if server_port_https == -1:
+        # running locally
+        spark_config = {"spark.driver.host": "127.0.0.1"}
+        return SparkSession.builder.appName(runtime_config.get_name()).config(map=spark_config).getOrCreate()
+    else:
+        # running in Kubernetes, use spark_profile.yml and
+        # environment variables for configuration
+        server_port = os.environ["KUBERNETES_SERVICE_PORT"]
+        master_url = f"k8s://https://kubernetes.default:{server_port}"
+
+        # Read Spark configuration profile
+        config_filepath = os.path.abspath(
+            os.path.join(os.getenv("SPARK_HOME"), "work-dir", "config", "spark_profile.yml")
+        )
+        with open(config_filepath, "r") as config_fp:
+            spark_config = yaml.safe_load(os.path.expandvars(config_fp.read()))
+        spark_config["spark.submit.deployMode"] = "client"
+
+        # configure the executor pods from template
+        executor_pod_template_file = os.path.join(
+            os.getenv("SPARK_HOME"),
+            "work-dir",
+            "src",
+            "templates",
+            "spark-executor-pod-template.yml",
+        )
+        spark_config["spark.kubernetes.executor.podTemplateFile"] = executor_pod_template_file
+        spark_config["spark.kubernetes.container.image.pullPolicy"] = "Always"
+
+        # Pass the driver IP address to the workers for callback
+        myservice_url = socket.gethostbyname(socket.gethostname())
+        spark_config["spark.driver.host"] = myservice_url
+        spark_config["spark.driver.bindAddress"] = "0.0.0.0"
+        spark_config["spark.decommission.enabled"] = True
+        logger.info(f"Launching Spark Session with configuration\n" f"{yaml.dump(spark_config, indent=2)}")
+        app_name = spark_config.get("spark.app.name", "my-spark-app")
+        return SparkSession.builder.master(master_url).appName(app_name).config(map=spark_config).getOrCreate()
+
+
 def orchestrate(
     runtime_config: SparkTransformRuntimeConfiguration,
     execution_configuration: SparkTransformExecutionConfiguration,
@@ -50,8 +95,9 @@ def orchestrate(
         logger.error("No DataAccess instance provided - exiting")
         return 1
     # initialize Spark
-    conf = SparkConf().setAppName(runtime_config.get_name()).set("spark.driver.host", "127.0.0.1")
-    sc = SparkContext(conf=conf)
+    spark_session = _init_spark(runtime_config)
+    sc = spark_session.sparkContext
+    # broadcast
    spark_runtime_config = sc.broadcast(runtime_config)
     daf = sc.broadcast(data_access_factory)
     spark_bcast_params = sc.broadcast(bcast_params)
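
Note on the Kubernetes branch of `_init_spark`: it reads `spark_profile.yml` from `$SPARK_HOME/work-dir/config/` and expands environment variable references in the file before parsing it with `yaml.safe_load`. The sketch below is only an illustration of that expansion/parsing step; the profile keys and values shown (`APP_NAME`, executor count, image, namespace) are hypothetical and not the repository's actual profile.

```python
import os
import yaml

# Hypothetical spark_profile.yml content; the real profile ships with the
# repository and may use different keys. Values may reference environment
# variables, which _init_spark expands with os.path.expandvars().
EXAMPLE_PROFILE = """
spark.app.name: ${APP_NAME}
spark.executor.instances: 2
spark.executor.memory: 4g
spark.kubernetes.namespace: ${KUBERNETES_NAMESPACE}
spark.kubernetes.container.image: ${EXECUTOR_IMAGE}
"""

# Example values for the referenced environment variables
os.environ.setdefault("APP_NAME", "noop-transform")
os.environ.setdefault("KUBERNETES_NAMESPACE", "default")
os.environ.setdefault("EXECUTOR_IMAGE", "quay.io/example/spark-executor:latest")

# Same expansion + parsing steps as the Kubernetes branch of _init_spark
spark_config = yaml.safe_load(os.path.expandvars(EXAMPLE_PROFILE))
print(spark_config["spark.app.name"])            # -> noop-transform
print(spark_config["spark.executor.instances"])  # -> 2
```

The resulting dictionary is what `_init_spark` passes to `SparkSession.builder.config(map=...)` after adding the deploy mode, executor pod template, and driver host/bind settings shown in the diff.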