TypeError when using alphageometry with JAX bfloat16 dtype #2

Open
jiwei08 opened this issue May 8, 2024 · 3 comments

jiwei08 commented May 8, 2024

Hi,

I wanted to report an issue I encountered while working with the alphageometry library. I appreciate the improvements made to the original alphageometry project and have been trying out the new code.

However, when I tried to run it, I hit a TypeError. The error message was:

TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in astype
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

Steps to Reproduce:

  • Run run.sh with PROB=napoleon2 and MODEL=alphageometry

Anaconda Environment:

  • Operating System: Ubuntu 22.04
  • Python Version: 3.10.9
  • JAX Version: 0.4.6
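A minimal sketch for collecting these details programmatically, so a report also records the jaxlib, flax and tensorflow versions (all of which appear in the log) and which Python prefix is active (an Anaconda env or a virtualenv). The package list and the `environment_report` helper are hypothetical illustrations, not part of AG4Masses:

```python
"""Sketch: collect environment details for a bug report.

Assumption: the package list below is chosen from names that appear in the
traceback; it is not an official AG4Masses requirements list.
"""
import platform
import sys
from importlib import metadata

PACKAGES = ("jax", "jaxlib", "flax", "tensorflow")


def environment_report(packages=PACKAGES):
    """Return OS, Python version, interpreter prefix and package versions."""
    report = {
        "Operating System": platform.platform(),
        "Python Version": platform.python_version(),
        # sys.prefix shows whether an Anaconda env or a virtualenv is active.
        "Python Prefix": sys.prefix,
    }
    for name in packages:
        try:
            report[f"{name} version"] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[f"{name} version"] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```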

Thank you for your time in addressing this issue. Any help is appreciated!

jiwei08 commented May 8, 2024

The full output, including the traceback, is as follows:

++ HOME_DIR=/home1/newhome/lijiwei/ag4masses
++ TESTDIR=/home1/newhome/lijiwei/ag4masses/ag4mtest
++ AG4MDIR=/home1/newhome/lijiwei/ag4masses/ag4masses
++ AGLIB=/home1/newhome/lijiwei/ag4masses/aglib
++ AGDIR=/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry
++ export PYTHONPATH=:/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry:/home1/newhome/lijiwei/ag4masses/aglib
++ PYTHONPATH=:/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry:/home1/newhome/lijiwei/ag4masses/aglib
++ OUTFILE=/home1/newhome/lijiwei/ag4masses/ag4mtest/ag.out
++ ERRFILE=/home1/newhome/lijiwei/ag4masses/ag4mtest/ag.err
++ exec
+++ tee /home1/newhome/lijiwei/ag4masses/ag4mtest/ag.err
++ BATCH_SIZE=8
++ BEAM_SIZE=32
++ DEPTH=8
++ NWORKERS=1
++ PROB_FILE=/home1/newhome/lijiwei/ag4masses/ag4masses/data/ag4m_problems.txt
++ PROB=napoleon2
++ MODEL=alphageometry
++ DATA=/home1/newhome/lijiwei/ag4masses/aglib/ag_ckpt_vocab
++ MELIAD_PATH=/home1/newhome/lijiwei/ag4masses/aglib/meliad
++ export PYTHONPATH=:/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry:/home1/newhome/lijiwei/ag4masses/aglib:/home1/newhome/lijiwei/ag4masses/aglib/meliad
++ PYTHONPATH=:/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry:/home1/newhome/lijiwei/ag4masses/aglib:/home1/newhome/lijiwei/ag4masses/aglib/meliad
++ DDAR_ARGS=(--defs_file=$AGDIR/defs.txt --rules_file=$AGDIR/rules.txt)
++ SEARCH_ARGS=(--beam_size=$BEAM_SIZE --search_depth=$DEPTH)
++ LM_ARGS=(--ckpt_path=$DATA --vocab_path=$DATA/geometry.757.model --gin_search_paths=$MELIAD_PATH/transformer/configs,$AGDIR --gin_file=base_htrans.gin --gin_file=size/medium_150M.gin --gin_file=options/positions_t5.gin --gin_file=options/lr_cosine_decay.gin --gin_file=options/seq_1024_nocache.gin --gin_file=geometry_150M_generate.gin --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True --gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE --gin_param=TransformerTaskConfig.sequence_length=128 --gin_param=Trainer.restore_state_variables=False)
++ true ==========================================
++ python -m alphageometry --alsologtostderr --problems_file=/home1/newhome/lijiwei/ag4masses/ag4masses/data/ag4m_problems.txt --problem_name=napoleon2 --mode=alphageometry --defs_file=/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/defs.txt --rules_file=/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/rules.txt --beam_size=32 --search_depth=8 --ckpt_path=/home1/newhome/lijiwei/ag4masses/aglib/ag_ckpt_vocab --vocab_path=/home1/newhome/lijiwei/ag4masses/aglib/ag_ckpt_vocab/geometry.757.model --gin_search_paths=/home1/newhome/lijiwei/ag4masses/aglib/meliad/transformer/configs,/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry --gin_file=base_htrans.gin --gin_file=size/medium_150M.gin --gin_file=options/positions_t5.gin --gin_file=options/lr_cosine_decay.gin --gin_file=options/seq_1024_nocache.gin --gin_file=geometry_150M_generate.gin --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True --gin_param=TransformerTaskConfig.batch_size=8 --gin_param=TransformerTaskConfig.sequence_length=128 --gin_param=Trainer.restore_state_variables=False --out_file=/home1/newhome/lijiwei/ag4masses/ag4mtest/ag.out --n_workers=1
/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
2024-05-08 16:59:48.237274: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/struct.py:132: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(data_clz, keypaths)
I0508 16:59:50.363534 140267185493824 graph.py:498] napoleon2
I0508 16:59:50.363727 140267185493824 graph.py:499] a b c = triangle a b c; d = s_angle a b d 30, s_angle b a d 150; e = s_angle b c e 30, s_angle c b e 150; f = s_angle c a f 30, s_angle a c f 150 ? cong e f e d
I0508 16:59:50.422343 140267185493824 ddar.py:60] Depth 1/1000 time = 0.03681206703186035
I0508 16:59:50.481287 140267185493824 ddar.py:60] Depth 2/1000 time = 0.05869913101196289
I0508 16:59:50.552018 140267185493824 ddar.py:60] Depth 3/1000 time = 0.07051420211791992
I0508 16:59:50.624904 140267185493824 ddar.py:60] Depth 4/1000 time = 0.07238030433654785
I0508 16:59:50.628416 140267185493824 ddar.py:130] Nothing added, breaking
I0508 16:59:50.628540 140267185493824 alphageometry.py:231] DD+AR failed to solve the problem.
I0508 16:59:50.628618 140267185493824 alphageometry.py:528] Worker initializing. PID=725283
I0508 16:59:50.628970 140267185493824 inference_utils.py:69] Parsing gin configuration.
I0508 16:59:50.629003 140267185493824 inference_utils.py:71] Added Gin search path /home1/newhome/lijiwei/ag4masses/aglib/meliad/transformer/configs
I0508 16:59:50.629115 140267185493824 inference_utils.py:71] Added Gin search path /home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry
I0508 16:59:50.629145 140267185493824 inference_utils.py:74] Loading Gin config file base_htrans.gin
I0508 16:59:50.629173 140267185493824 inference_utils.py:74] Loading Gin config file size/medium_150M.gin
I0508 16:59:50.629199 140267185493824 inference_utils.py:74] Loading Gin config file options/positions_t5.gin
I0508 16:59:50.629225 140267185493824 inference_utils.py:74] Loading Gin config file options/lr_cosine_decay.gin
I0508 16:59:50.629251 140267185493824 inference_utils.py:74] Loading Gin config file options/seq_1024_nocache.gin
I0508 16:59:50.629276 140267185493824 inference_utils.py:74] Loading Gin config file geometry_150M_generate.gin
I0508 16:59:50.629301 140267185493824 inference_utils.py:76] Overriding Gin param DecoderOnlyLanguageModelGenerate.output_token_losses=True
I0508 16:59:50.629328 140267185493824 inference_utils.py:76] Overriding Gin param TransformerTaskConfig.batch_size=8
I0508 16:59:50.629353 140267185493824 inference_utils.py:76] Overriding Gin param TransformerTaskConfig.sequence_length=128
I0508 16:59:50.629378 140267185493824 inference_utils.py:76] Overriding Gin param Trainer.restore_state_variables=False
I0508 16:59:50.629465 140267185493824 resource_reader.py:50] system_path_file_exists:base_htrans.gin
E0508 16:59:50.629693 140267185493824 resource_reader.py:55] Path not found: base_htrans.gin
I0508 16:59:50.629993 140267185493824 resource_reader.py:50] system_path_file_exists:trainer_configuration.gin
E0508 16:59:50.630172 140267185493824 resource_reader.py:55] Path not found: trainer_configuration.gin
I0508 16:59:50.636854 140267185493824 resource_reader.py:50] system_path_file_exists:size/medium_150M.gin
E0508 16:59:50.637141 140267185493824 resource_reader.py:55] Path not found: size/medium_150M.gin
I0508 16:59:50.637547 140267185493824 resource_reader.py:50] system_path_file_exists:options/positions_t5.gin
E0508 16:59:50.637714 140267185493824 resource_reader.py:55] Path not found: options/positions_t5.gin
I0508 16:59:50.638046 140267185493824 resource_reader.py:50] system_path_file_exists:options/lr_cosine_decay.gin
E0508 16:59:50.638195 140267185493824 resource_reader.py:55] Path not found: options/lr_cosine_decay.gin
I0508 16:59:50.638675 140267185493824 resource_reader.py:50] system_path_file_exists:options/seq_1024_nocache.gin
E0508 16:59:50.638825 140267185493824 resource_reader.py:55] Path not found: options/seq_1024_nocache.gin
I0508 16:59:50.639230 140267185493824 resource_reader.py:50] system_path_file_exists:geometry_150M_generate.gin
E0508 16:59:50.639382 140267185493824 resource_reader.py:55] Path not found: geometry_150M_generate.gin
I0508 16:59:50.639428 140267185493824 resource_reader.py:50] system_path_file_exists:/home1/newhome/lijiwei/ag4masses/aglib/meliad/transformer/configs/geometry_150M_generate.gin
E0508 16:59:50.639469 140267185493824 resource_reader.py:55] Path not found: /home1/newhome/lijiwei/ag4masses/aglib/meliad/transformer/configs/geometry_150M_generate.gin
I0508 16:59:50.643476 140267185493824 training_loop.py:334] ==== Training loop: initializing model ====
I0508 16:59:50.654237 140267185493824 xla_bridge.py:166] Remote TPU is not linked into jax; skipping remote TPU.
I0508 16:59:50.654428 140267185493824 xla_bridge.py:413] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0508 16:59:50.654519 140267185493824 xla_bridge.py:413] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0508 16:59:50.654574 140267185493824 xla_bridge.py:413] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0508 16:59:50.655371 140267185493824 xla_bridge.py:413] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I0508 16:59:50.655545 140267185493824 xla_bridge.py:413] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
W0508 16:59:50.655667 140267185493824 xla_bridge.py:420] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0508 16:59:50.655745 140267185493824 training_loop.py:335] Process 0 of 1
I0508 16:59:50.655793 140267185493824 training_loop.py:336] Local device count = 1
I0508 16:59:50.655842 140267185493824 training_loop.py:337] Number of replicas = 1
I0508 16:59:50.655878 140267185493824 training_loop.py:339] Using random number seed 42
I0508 16:59:50.945881 140267185493824 training_loop.py:359] Initializing the model.
Traceback (most recent call last):
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 752, in <module>
    app.run(main)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 739, in main
    run_alphageometry(
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 652, in run_alphageometry
    bqsearch_init()
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 529, in bqsearch_init
    model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 213, in get_lm
    return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/lm_inference.py", line 62, in __init__
    (tstate, _, imodel, prngs) = trainer.initialize_model()
  File "/home1/newhome/lijiwei/ag4masses/aglib/meliad/training_loop.py", line 367, in initialize_model
    variables = model_init_fn(init_rngs, imodel.get_fake_input())
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/pjit.py", line 235, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/api.py", line 442, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/pjit.py", line 515, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/pjit.py", line 967, in _pjit_jaxpr
    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/pjit.py", line 925, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2029, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2046, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/linear_util.py", line 165, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/linen/module.py", line 1266, in init
    _, v_out = self.init_with_output(
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/linen/module.py", line 1222, in init_with_output
    return init_with_output(
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/core/scope.py", line 896, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/core/scope.py", line 864, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/linen/module.py", line 1640, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/linen/transforms.py", line 1331, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/linen/module.py", line 353, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/linen/module.py", line 652, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/models.py", line 68, in __call__
    self.decoder(
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/linen/transforms.py", line 1331, in wrapped_fn
    return prewrapped_fn(self, *args, **kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/linen/module.py", line 353, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/flax/linen/module.py", line 652, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home1/newhome/lijiwei/ag4masses/aglib/meliad/transformer/decoder_stack.py", line 274, in __call__
    embeddings = embeddings.astype(self.dtype)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4952, in _astype
    dtypes.check_user_dtype_supported(dtype, "astype")
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/dtypes.py", line 706, in check_user_dtype_supported
    raise TypeError(msg)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in astype

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 752, in <module>
    app.run(main)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 739, in main
    run_alphageometry(
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 652, in run_alphageometry
    bqsearch_init()
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 529, in bqsearch_init
    model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/alphageometry.py", line 213, in get_lm
    return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/lm_inference.py", line 62, in __init__
    (tstate, _, imodel, prngs) = trainer.initialize_model()
  File "/home1/newhome/lijiwei/ag4masses/aglib/meliad/training_loop.py", line 367, in initialize_model
    variables = model_init_fn(init_rngs, imodel.get_fake_input())
  File "/home1/newhome/lijiwei/ag4masses/ag4masses/alphageometry/models.py", line 68, in __call__
    self.decoder(
  File "/home1/newhome/lijiwei/ag4masses/aglib/meliad/transformer/decoder_stack.py", line 274, in __call__
    embeddings = embeddings.astype(self.dtype)
  File "/home1/newhome/lijiwei/anaconda3/envs/alphageometry/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4952, in _astype
    dtypes.check_user_dtype_supported(dtype, "astype")
TypeError: JAX only supports number and bool dtypes, got dtype bfloat16 in astype
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

tpgh24 commented May 18, 2024

Were you able to run ag4masses successfully with MODEL=alphageometry for any problem?

tpgh24 commented May 18, 2024

This seems to be caused by a Python package version issue. Since you are using Anaconda, try using virtualenv as described in the AG4Masses setup instructions.
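After reinstalling, a quick way to check whether the new environment can perform the cast that fails in meliad's `decoder_stack.py` (`embeddings.astype(self.dtype)` with a bfloat16 dtype) might be the sketch below. It assumes that casting a traced array to `jnp.bfloat16` inside `jax.jit` hits the same dtype check as the Flax model initialization in the traceback; the thread does not confirm the root cause, and `bfloat16_astype_check` is a hypothetical helper:

```python
"""Sketch: test whether a jitted bfloat16 astype works in this environment.

Assumption: this exercises the same dtype check that raised
"JAX only supports number and bool dtypes, got dtype bfloat16 in astype".
"""


def bfloat16_astype_check():
    """Return True if the cast works, False on the reported TypeError,
    or None if JAX cannot be imported at all."""
    try:
        import jax
        import jax.numpy as jnp
    except Exception:  # missing or broken JAX/jaxlib install
        return None

    @jax.jit
    def cast(x):
        # Inside jit, x is a tracer, as in the failing model init.
        return x.astype(jnp.bfloat16)

    try:
        out = cast(jnp.ones((2, 3), dtype=jnp.float32))
    except TypeError:
        return False
    return bool(out.dtype == jnp.bfloat16)


if __name__ == "__main__":
    result = bfloat16_astype_check()
    if result is None:
        print("JAX is not importable in this environment.")
    elif result:
        print("bfloat16 astype under jit works.")
    else:
        print("bfloat16 astype fails; package versions may be mismatched.")
```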
