AttributeError: module 'jax.core' has no attribute 'NamedShape'
I duplicated this Space and chose ZeroGPU in the setup, but I get the following error.
Any solutions would be appreciated.
Traceback (most recent call last):
  File "/home/user/app/app.py", line 15, in <module>
    from whisper_jax import FlaxWhisperPipline
  File "/home/user/app/whisper_jax/__init__.py", line 18, in <module>
    from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
  File "/home/user/app/whisper_jax/modeling_flax_whisper.py", line 57, in <module>
    from whisper_jax import layers
  File "/home/user/app/whisper_jax/layers.py", line 63, in <module>
    def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/deprecations.py", line 55, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'NamedShape'