I'd seen earlier reports of errors importing from JAX from several years ago ( https://github.com/google/jax/issues/372 ), but that thread implied that updating would fix it. I just installed JAX and am trying to get it set up in a Jupyter notebook. Could you let me know what might be going wrong?
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Input In [1], in <cell line: 4>()
1 ########## JAX ON MNIST #####################
2 # Import some additional JAX and dataloader helpers
3 from jax.scipy.special import logsumexp
----> 4 from jax.experimental import optimizers
6 import torch
7 from torchvision import datasets, transforms
ImportError: cannot import name 'optimizers' from 'jax.experimental' (/Users/XXX/opt/anaconda3/lib/python3.9/site-packages/jax/experimental/__init__.py)
The similar error I found was from 2019, and the discussion there implied that a version difference was the cause, but I don't know where to go from there.
According to the CHANGELOG:
jax 0.3.16
- Deprecations:
  - Removed jax.experimental.optimizers; it has long been a deprecated alias of jax.example_libraries.optimizers.
So it sounds like if you're using JAX version 0.3.16 or newer, you should do
from jax.example_libraries import optimizers
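If the same notebook has to run on both older and newer JAX installs, one option is to try the new import path first and fall back to the old one. Below is a minimal sketch of that, plus a short training loop showing the (init, update, get_params) triple these optimizers return. The toy least-squares problem, the loss function, and the step size are hypothetical, chosen only for illustration:

```python
import jax
import jax.numpy as jnp

# Prefer the new location; jax.experimental.optimizers was removed in JAX 0.3.16.
try:
    from jax.example_libraries import optimizers
except ImportError:  # older JAX releases that predate jax.example_libraries
    from jax.experimental import optimizers

# Hypothetical toy problem: find w so that w * x matches y = 3 * x.
x = jnp.arange(1.0, 5.0)
y = 3.0 * x

def loss(w):
    return jnp.mean((w * x - y) ** 2)

# optimizers.sgd returns an (init_fn, update_fn, get_params) triple.
opt_init, opt_update, get_params = optimizers.sgd(step_size=0.01)
opt_state = opt_init(jnp.array(0.0))

for step in range(200):
    grads = jax.grad(loss)(get_params(opt_state))
    opt_state = opt_update(step, grads, opt_state)

w_final = get_params(opt_state)
print(float(w_final))  # should end up very close to 3.0
```

Note that the optimizer state is opaque here: you always go through get_params to read the current parameters, and opt_update takes the step index as its first argument.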
But as noted in the jax.example_libraries.optimizers documentation, this is not well-supported code, and you'll probably have a better experience with something like Optax or JAXopt.