feat: deprecate TRANSFORMERS_CACHE, use HF_HUB_CACHE everywhere (IBM#89)
#### Motivation

`TRANSFORMERS_CACHE` is deprecated (slated for removal in Transformers v5) and `HUGGINGFACE_HUB_CACHE` is legacy. This PR standardizes on `HF_HUB_CACHE` for configuring the model cache. In addition, not all operations/CLI commands were correctly reading `TRANSFORMERS_CACHE`, so both env vars had to be set anyway. After this change, everything should work with only `HF_HUB_CACHE`.

#### Modifications

- The launcher inspects `HF_HUB_CACHE` to determine the model cache path
- `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` are still checked as well, but a deprecation warning is printed
    - if multiple values are present and do not match, an error is raised
- The launcher can resolve the default `HF_HUB_CACHE`, so it does not need to be set (`HF_HOME` or its default can be used instead); a sketch of the lookup order follows this list
- The server CLI checks `TRANSFORMERS_CACHE` and prints a warning if it is set
- The server CLI returns an error if both `TRANSFORMERS_CACHE` and `HF_HUB_CACHE` are set to different values
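
A minimal shell sketch of the lookup order described above (illustrative only; the launcher code in `launcher/src/main.rs` below additionally prints deprecation warnings and errors out when the variables are set to conflicting values):

```shell
# Precedence: HF_HUB_CACHE, then the deprecated TRANSFORMERS_CACHE / HUGGINGFACE_HUB_CACHE,
# then the default derived from HF_HOME (or $HOME/.cache/huggingface) plus "/hub".
model_cache="${HF_HUB_CACHE:-${TRANSFORMERS_CACHE:-${HUGGINGFACE_HUB_CACHE:-${HF_HOME:-$HOME/.cache/huggingface}/hub}}}"
echo "model cache resolved to: ${model_cache}"
```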

---------

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
tjohnson31415 authored May 10, 2024
1 parent ddc56ee commit 2358566
Showing 6 changed files with 115 additions and 47 deletions.
6 changes: 2 additions & 4 deletions Makefile
@@ -85,17 +85,15 @@ check-test-image:
integration-tests: check-test-image ## Run integration tests
mkdir -p /tmp/transformers_cache
docker run --rm -v /tmp/transformers_cache:/transformers_cache \
-e HUGGINGFACE_HUB_CACHE=/transformers_cache \
-e TRANSFORMERS_CACHE=/transformers_cache \
-e HF_HUB_CACHE=/transformers_cache \
-w /usr/src/integration_tests \
$(TEST_IMAGE_NAME) make test

.PHONY: python-tests
python-tests: check-test-image ## Run Python tests
mkdir -p /tmp/transformers_cache
docker run --rm -v /tmp/transformers_cache:/transformers_cache \
-e HUGGINGFACE_HUB_CACHE=/transformers_cache \
-e TRANSFORMERS_CACHE=/transformers_cache \
-e HF_HUB_CACHE=/transformers_cache \
$(TEST_IMAGE_NAME) pytest -sv --ignore=server/tests/test_utils.py server/tests

.PHONY: clean
4 changes: 2 additions & 2 deletions README.md
@@ -71,15 +71,15 @@ cd deployment

### Model configuration

When deploying TGIS, the `MODEL_NAME` environment variable can contain either the full name of a model on the Hugging Face hub (such as `google/flan-ul2`) or an absolute path to a (mounted) model directory inside the container. In the former case, the `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` environment variables should be set to the path of a mounted directory containing a local HF hub model cache, see [this](deployment/base/patches/pvcs/pvc.yaml) kustomize patch as an example.
When deploying TGIS, the `MODEL_NAME` environment variable can contain either the full name of a model on the Hugging Face hub (such as `google/flan-ul2`) or an absolute path to a (mounted) model directory inside the container. In the former case, the `HF_HUB_CACHE` environment variable should be set to the path of a mounted directory containing a local HF hub model cache, see [this](deployment/base/patches/pvcs/pvc.yaml) kustomize patch as an example.
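
As a rough illustration (the `/hub-cache` path here is just an example mount point, not a required location), the hub-model case boils down to:

```shell
# With this configuration, a hub model such as google/flan-ul2 is resolved from the mounted
# cache at $HF_HUB_CACHE/models--google--flan-ul2/snapshots/<revision>/
export MODEL_NAME=google/flan-ul2
export HF_HUB_CACHE=/hub-cache   # path of the mounted model cache inside the container
```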

### Downloading model weights

TGIS will not download model data at runtime. To populate the local HF hub cache with models so that it can be used per above, the image can be run with the following command:
```shell
text-generation-server download-weights model_name
```
where `model_name` is the name of the model on the HF hub. Ensure that it's run with the same mounted directory and `TRANSFORMERS_CACHE` and `HUGGINGFACE_HUB_CACHE` environment variables, and that it has write access to this mounted filesystem.
where `model_name` is the name of the model on the HF hub. Ensure that it's run with the same mounted directory and the `HF_HUB_CACHE` environment variable, and that it has write access to this mounted filesystem.
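
For example, assuming a container runtime is used and `TGIS_IMAGE` stands in for the actual image name, populating a mounted cache could look like:

```shell
docker run --rm -v /path/to/hub-cache:/hub-cache \
    -e HF_HUB_CACHE=/hub-cache \
    "${TGIS_IMAGE}" \
    text-generation-server download-weights google/flan-ul2
```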

This will attempt to download weights in `.safetensors` format, and if those aren't in the HF hub will download pytorch `.bin` weights and then convert them to `.safetensors`.

47 changes: 38 additions & 9 deletions integration_tests/text_generation_tests/test_server.py
@@ -29,7 +29,7 @@ def start_server(
master_port: int,
timeout=30,
model_path=None,
include_cache_env_vars=True,
env=None,
output_special_tokens=False,
):
# Download weights to the cache first
@@ -66,13 +66,12 @@
if output_special_tokens:
args.append("--output-special-tokens")

env = os.environ.copy()
if env is None:
env = os.environ.copy()

env["RUST_BACKTRACE"] = "full"
env["ESTIMATE_MEMORY"] = "manual"
env["PREFIX_STORE_PATH"] = os.path.join(TESTS_DIR, "prompt_prefixes")
if not include_cache_env_vars:
env.pop("TRANSFORMERS_CACHE", None)
env.pop("HUGGING_FACE_HUB_CACHE", None)

# Start the process
process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)
@@ -455,17 +454,21 @@ async def test_time_limit_stopping(server_fixture):

# Test loading when an explicit local path is provided
def test_explicit_path():
# Test with and without providing TRANSFORMERS_CACHE env var
path = glob.glob(f'{os.environ["TRANSFORMERS_CACHE"]}/models--bigscience--mt0-small/snapshots/*')[0]
for include_env_vars in [False, True]:
path = glob.glob(f'{os.environ["HF_HUB_CACHE"]}/models--bigscience--mt0-small/snapshots/*')[0]

# Test with and without providing HF_HUB_CACHE
env_with = os.environ.copy()
env_without = os.environ.copy()
env_without.pop("HF_HUB_CACHE", None)
for env in [env_with, env_without]:
p = start_server(
"bigscience/mt0-small",
".bin,.json,.model",
1,
3000,
29502,
model_path=path,
include_cache_env_vars=include_env_vars,
env=env,
)
try:
async def test_model_info() -> pb2.ModelInfoResponse:
@@ -481,6 +484,32 @@ async def test_model_info() -> pb2.ModelInfoResponse:

assert p.wait(8.0) == 0

# Test loading with only TRANSFORMERS_CACHE set
def test_transformers_cache():
env = os.environ.copy()
env["TRANSFORMERS_CACHE"] = env.pop("HF_HUB_CACHE")
p = start_server(
"bigscience/mt0-small",
".bin,.json,.model",
1,
3000,
29502,
env=env,
)
try:
async def test_model_info() -> pb2.ModelInfoResponse:
async with grpc.aio.insecure_channel('localhost:8033') as channel:
return await gpb2.GenerationServiceStub(channel).ModelInfo(pb2.ModelInfoRequest(model_id="unused"))

result = asyncio.get_event_loop().run_until_complete(test_model_info())
assert result.max_sequence_length == 200
assert result.max_new_tokens == 169
assert result.model_kind == pb2.ModelInfoResponse.ModelKind.ENCODER_DECODER
finally:
p.terminate()

assert p.wait(8.0) == 0


# To avoid errors related to event loop shutdown timing
@pytest.fixture(scope="session")
94 changes: 63 additions & 31 deletions launcher/src/main.rs
@@ -139,7 +139,54 @@ fn main() -> ExitCode {
// Determine number of shards based on command line arg and env vars
let num_shard = find_num_shards(args.num_shard);

let config_path: PathBuf = resolve_config_path(&args.model_name, args.revision.as_deref())
// Determine the model cache path and resolve from possible env vars:
// - HF_HUB_CACHE
// - TRANSFORMERS_CACHE (deprecated)
// - HUGGINGFACE_HUB_CACHE (deprecated)
//
// We allow multiple to be set for compatibility, but then the values must match.

let mut cache_env_var: String = "".to_string();
let mut cache_env_value: String = "".to_string();

if let Ok(t) = env::var("HF_HUB_CACHE") {
cache_env_var = "HF_HUB_CACHE".into();
cache_env_value = t.into();
}

for deprecated_env_var in vec!["TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"] {
match (
env::var(deprecated_env_var),
!cache_env_var.is_empty(),
) {
(Ok(t), false) => {
cache_env_var = deprecated_env_var.into();
cache_env_value = t.into();
},
(Ok(t), true) if t != cache_env_value => panic!(
"{deprecated_env_var} and {cache_env_var} env vars can't be set to different values"
),
(Ok(_), true) => warn!(
"{deprecated_env_var} is deprecated and should not be used. Use HF_HUB_CACHE instead."
),
_ => (),
}
}

// ensure HF_HUB_CACHE is set for downstream usage
// default value to match huggingface_hub
// REF: https://github.com/huggingface/huggingface_hub/blob/5ff2d150d121d04799b78bc08f2343c21b8f07a9/docs/source/en/package_reference/environment_variables.md?plain=1#L32
let cache_path = if !cache_env_value.is_empty() {
PathBuf::from(cache_env_value)
} else if let Ok(hf_home) = env::var("HF_HOME") {
PathBuf::from(hf_home).join("hub")
} else if let Ok(home) = env::var("HOME") {
PathBuf::from(home).join(".cache").join("huggingface").join("hub")
} else {
PathBuf::new()
};

let config_path: PathBuf = resolve_config_path(cache_path.clone(), &args.model_name, args.revision.as_deref())
.expect("Failed to resolve config path")
.into();

@@ -223,15 +270,18 @@ fn main() -> ExitCode {
let (status_sender, status_receiver) = mpsc::channel();

// Start shard processes
let cache_path_string = cache_path.into_os_string();
for rank in 0..num_shard {
let args = args.clone();
let cache_path = cache_path_string.clone();
let deployment_framework = deployment_framework.to_string();
let status_sender = status_sender.clone();
let shutdown = shutdown.clone();
let shutdown_sender = shutdown_sender.clone();
thread::spawn(move || {
shard_manager(
args.model_name,
cache_path,
args.revision,
deployment_framework,
args.dtype.or(args.dtype_str),
@@ -548,6 +598,7 @@ enum ShardStatus {
#[allow(clippy::too_many_arguments)]
fn shard_manager(
model_name: String,
cache_path: OsString,
revision: Option<String>,
deployment_framework: String,
dtype: Option<String>,
@@ -620,19 +671,6 @@ fn shard_manager(
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

// Fix up TRANSFORMERS_CACHE and HUGGINGFACE_HUB_CACHE env vars
match (
env::var("TRANSFORMERS_CACHE"),
env::var("HUGGINGFACE_HUB_CACHE"),
) {
(Ok(t), Err(_)) => env.push(("HUGGINGFACE_HUB_CACHE".into(), t.into())),
(Err(_), Ok(h)) => env.push(("TRANSFORMERS_CACHE".into(), h.into())),
(Ok(t), Ok(h)) if t != h => panic!(
"TRANSFORMERS_CACHE and HUGGINGFACE_HUB_CACHE env vars can't be set to different values"
),
_ => (),
}

if let Some(alloc_conf) = cuda_alloc_conf {
if alloc_conf.is_empty() {
// Remove it from env
@@ -665,6 +703,9 @@ fn shard_manager(
// Ensure offline-only
env.push(("HF_HUB_OFFLINE".into(), "1".into()));

// Ensure that we set the standard cache variable
env.push(("HF_HUB_CACHE".into(), cache_path.into()));

// Start process
info!("Starting shard {rank}");
let mut p = match Command::new("text-generation-server")
@@ -776,18 +817,13 @@ fn write_termination_log(msg: &str) -> Result<(), io::Error> {
Ok(())
}

fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
let cache = env::var("TRANSFORMERS_CACHE")
.or_else(|_| env::var("HUGGINGFACE_HUB_CACHE"))
.ok();
let mut model_dir = cache
.as_ref()
.map(|c| Path::new(&c).join(format!("models--{}", model_name.replace('/', "--"))));
if let Some(ref d) = model_dir {
if !d.try_exists()? {
model_dir = None;
}
}
fn resolve_config_path(cache_path: PathBuf, model_name: &str, revision: Option<&str>) -> Result<String, io::Error> {
let model_hf_cache_dir = cache_path.join(format!("models--{}", model_name.replace('/', "--")));
let model_dir = if model_hf_cache_dir.try_exists()? {
Some(model_hf_cache_dir)
} else {
None
};
if let Some(dir) = model_dir {
let revision = revision.unwrap_or("main");
let ref_path = dir.join("refs").join(revision);
@@ -811,11 +847,7 @@ fn resolve_config_path(model_name: &str, revision: Option<&str>) -> Result<Strin
if try_path.try_exists()? {
Ok(try_path.to_string_lossy().into())
} else {
let message = if cache.is_none() {
format!("Model path {model_name} not found (TRANSFORMERS_CACHE env var not set)")
} else {
format!("Model {model_name} not found in local cache")
};
let message = format!("Model {model_name} not found");
error!(message);
Err(io::Error::new(ErrorKind::NotFound, message))
}
10 changes: 10 additions & 0 deletions server/text_generation_server/cli.py
@@ -252,4 +252,14 @@ def convert_to_fast_tokenizer(


if __name__ == "__main__":

# Use of TRANSFORMERS_CACHE is deprecated
if (tc := os.getenv("TRANSFORMERS_CACHE")) is not None:
print("WARNING: Using TRANSFORMERS_CACHE is deprecated. Use HF_HUB_CACHE instead.")
hc = os.getenv("HF_HUB_CACHE")
if hc is not None and tc != hc:
raise ValueError("Conflicting model cache values between TRANSFORMERS_CACHE and HF_HUB_CACHE")
if hc is None:
os.putenv("HF_HUB_CACHE", tc)

app()
1 change: 0 additions & 1 deletion server/text_generation_server/utils/hub.py
@@ -79,7 +79,6 @@ def get_model_path(model_name: str, revision: Optional[str] = None):
try:
config_path = try_to_load_from_cache(
model_name, config_file,
cache_dir=os.getenv("TRANSFORMERS_CACHE"), # will fall back to HUGGINGFACE_HUB_CACHE
revision=revision,
)
if config_path is not None: