Skip to content

Commit

Permalink
Merge pull request #20 from saforem2/frontier/slurm-support
Browse files Browse the repository at this point in the history
Add support for SLURM scheduler on Frontier @ OLCF
  • Loading branch information
saforem2 authored Aug 17, 2024
2 parents 01d272a + 09030cb commit fadd835
Show file tree
Hide file tree
Showing 10 changed files with 2,295 additions and 1,257 deletions.
122 changes: 60 additions & 62 deletions src/ezpz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,30 @@
ezpz/__init__.py
"""

# from __future__ import absolute_import, annotations, division, print_function
from __future__ import absolute_import, annotations, division, print_function
from mpi4py import MPI
import logging
import logging.config
import os
import re

# import socket
from typing import Any, Optional
from typing import Union
from typing import Optional
# from typing import Union

import numpy as np
# import numpy as np
# import rich
# from rich.console import Console
# from rich.logging import RichHandler

from ezpz import dist
from ezpz import log

# from ezpz import plot
from ezpz.plot import tplot, tplot_dict
from ezpz import profile
from ezpz import configs
# from ezpz import utils
from ezpz.configs import (
BACKENDS,
BIN_DIR,
Expand Down Expand Up @@ -82,6 +85,8 @@
timeit,
timeitlogit,
)
from ezpz.utils import grab_tensor

# from ezpz.jobs import loadjobenv, savejobenv
# from ezpz import jobs
from ezpz.log import get_file_logger, get_logger
Expand All @@ -96,6 +101,13 @@
to_bool,
)
from ezpz.log.handler import FluidLogRender, RichHandler
from ezpz.history import (
format_pair,
summarize_dict,
StopWatch,
History,
BaseHistory,
)
from ezpz.log.style import (
BEAT_TIME,
COLORS,
Expand Down Expand Up @@ -126,50 +138,56 @@
# return module


try:
log_config = logging.config.dictConfig(get_logging_config())
except Exception:
pass
TERM = os.environ.get("TERM", None)
PLAIN = os.environ.get(
"NO_COLOR",
os.environ.get(
"NOCOLOR",
os.environ.get(
"COLOR", os.environ.get("COLORS", os.environ.get("DUMB", False))
),
),
)
if not PLAIN and TERM not in ["dumb", "unknown"]:
try:
log_config = logging.config.dictConfig(get_logging_config())
except Exception:
pass
else:
print("Disabling color from logs!")

logger = logging.getLogger(__name__)
logging.getLogger("sh").setLevel("WARNING")


ScalarLike = Union[int, float, bool, np.floating]
os.environ["PYTHONIOENCODING"] = "utf-8"
# noqa: E402
RANK = int(MPI.COMM_WORLD.Get_rank())
WORLD_SIZE = int(MPI.COMM_WORLD.Get_size())

LOG_LEVEL: str = os.environ.get("LOG_LEVEL", "INFO").upper()
LOG_FROM_ALL_RANKS = os.environ.get(
"LOG_FROM_ALL_RANKS",
os.environ.get(
"LOG_FROM_ALL_RANK",
False
)
"LOG_FROM_ALL_RANKS", os.environ.get("LOG_FROM_ALL_RANK", False)
)
if LOG_FROM_ALL_RANKS:
if RANK == 0:
logger.info("LOGGING FROM ALL RANKS! BE SURE YOU WANT TO DO THIS !!!")
logger.setLevel(LOG_LEVEL)
# logger.info("Setting logging level to 'INFO' on 'RANK == 0'")
# logger.info("Setting logging level to 'CRITICAL' on all others 'RANK != 0'")
# logger.info(
# " ".join(
# [
# "To disable this behavior,",
# "and log from ALL ranks (not recommended),",
# "set: 'export LOG_FROM_ALL_RANKS=1' ",
# "in your environment, and re-run.",
# ]
# )
# )
else:
if RANK == 0:
logger.info("Setting logging level to 'INFO' on 'RANK == 0'")
logger.info(
"Setting logging level to 'CRITICAL' on all others 'RANK != 0'"
)
logger.info(
' ' .join(
[
"To disable this behavior,",
"and log from ALL ranks (not recommended),",
"set: 'export LOG_FROM_ALL_RANKS=1' ",
"in your environment, and re-run."
]
)
)
logger.setLevel(LOG_LEVEL) if RANK == 0 else logger.setLevel("CRITICAL")


__all__ = [
"BACKENDS",
"BEAT_TIME",
Expand Down Expand Up @@ -198,14 +216,19 @@
"SCHEDULERS",
"STYLES",
"UTILS",
"BaseHistory",
"History",
"PyInstrumentProfiler",
"StopWatch",
"TrainConfig",
"add_columns",
"build_layout",
"check",
"cleanup",
"command_exists",
"configs",
"dist",
"format_pair",
"flatten_dict",
"get_console",
"get_context_manager",
Expand Down Expand Up @@ -258,11 +281,13 @@
"setup_torch_distributed",
"setup_wandb",
"should_do_markup",
"summarize_dict",
"timeit",
"tplot",
"tplot_dict",
"timeitlogit",
"to_bool",
"utils",
]


Expand All @@ -285,40 +310,12 @@ def get_console_from_logger(logger: logging.Logger) -> Console:

for handler in logger.handlers:
if isinstance(handler, (RichHandler, EnrichHandler)):
return handler.console
return handler.console # type: ignore
from ezpz.log.console import get_console

return get_console()


def grab_tensor(x: Any) -> "Union[np.ndarray, ScalarLike, None]":
    """Convert ``x`` to a NumPy array, passing ``None`` and scalars through.

    Args:
        x: ``None``, a Python / NumPy scalar, a ``numpy.ndarray``, a
            ``torch.Tensor``, any object exposing a callable ``numpy()``
            method, or a list whose elements are scalars, arrays, or
            torch / tensorflow tensors. For lists, only the first element
            is inspected to decide how the list is converted.

    Returns:
        ``None`` or the scalar itself when given one, the same array when
        given a ``numpy.ndarray``, and a ``numpy.ndarray`` otherwise. Lists
        of arrays / tensors are stacked; lists of scalars (and the empty
        list) are converted with ``numpy.asarray``.

    Raises:
        ValueError: If ``x`` (or the first element of a list) has a type
            that cannot be converted.
    """
    scalar_types = (int, float, bool, np.floating)

    # Cheap pass-through cases first, so they never pay for a framework import.
    if x is None:
        return None
    if isinstance(x, scalar_types):
        return x
    if isinstance(x, np.ndarray):
        return x

    # torch is only needed to recognise tensors; other inputs still work
    # when it is not installed.
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()

    if isinstance(x, list):
        if not x:
            # Nothing to inspect: previously `x[0]` raised IndexError here.
            return np.asarray(x)
        head = x[0]
        if torch is not None and isinstance(head, torch.Tensor):
            return grab_tensor(torch.stack(x))
        if isinstance(head, np.ndarray):
            return np.stack(x)
        if isinstance(head, scalar_types):
            # Plain numbers need no framework (used to require tensorflow).
            return np.asarray(x)
        # Only fall back to tensorflow once every other option is exhausted,
        # and treat it as optional.
        try:
            import tensorflow as tf  # type:ignore
        except ImportError:
            tf = None
        if tf is not None and isinstance(head, tf.Tensor):
            return grab_tensor(tf.stack(x))
        raise ValueError(
            "Unable to convert list with elements of type "
            f"{type(head).__name__!r} to a numpy array"
        )

    # Duck-typed fallback, e.g. eager tensorflow tensors.
    if callable(getattr(x, "numpy", None)):
        return x.numpy()

    raise ValueError(
        f"Unable to convert object of type {type(x).__name__!r} to a numpy array"
    )


def get_rich_logger(name: Optional[str] = None, level: str = "INFO") -> logging.Logger:
from ezpz.log.handler import RichHandler

Expand Down Expand Up @@ -381,10 +378,11 @@ def get_enrich_logging_config_as_yaml(name: str = "enrich", level: str = "INFO")


def get_logger_new(
name: str,
level: str = "INFO",
name: str,
level: str = "INFO",
):
import yaml

config = yaml.safe_load(
get_enrich_logging_config_as_yaml(name=name, level=level),
)
Expand All @@ -394,5 +392,5 @@ def get_logger_new(
return log


if __name__ == '__main__':
if __name__ == "__main__":
pass
Loading

0 comments on commit fadd835

Please sign in to comment.