
jax.experimental importing error Python 3.9

I'd seen a similar error importing from JAX reported several years ago (https://github.com/google/jax/issues/372), but that thread implied an update would fix it. I just installed JAX and am trying to get it set up in a Jupyter notebook. Could you let me know what might be going wrong?

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Input In [1], in <cell line: 4>()
      1 ########## JAX ON MNIST #####################
      2 # Import some additional JAX and dataloader helpers
      3 from jax.scipy.special import logsumexp
----> 4 from jax.experimental import optimizers
      6 import torch
      7 from torchvision import datasets, transforms

ImportError: cannot import name 'optimizers' from 'jax.experimental' (/Users/XXX/opt/anaconda3/lib/python3.9/site-packages/jax/experimental/__init__.py)

I saw that the similar error was reported in 2019, and the thread implied that a version change would fix it, but I don't know where to go from here.

According to the CHANGELOG:

jax 0.3.16

  • Deprecations:
    • Removed jax.experimental.optimizers; it has long been a deprecated alias of jax.example_libraries.optimizers.

So if you're using JAX 0.3.16 or newer, you should import the module from its new location:

from jax.example_libraries import optimizers
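If the same notebook has to run on both older and newer JAX installs, one option is to try the new location first and fall back to the old alias. The sketch below does that and then runs a few steps of plain SGD on a toy quadratic loss (the `loss` function and the target value 3.0 are purely illustrative, not from the question) using the `(init, update, get_params)` triple that `optimizers.sgd` returns. Treat it as a minimal sketch, assuming a JAX version where at least one of the two import paths exists.

```python
import jax
import jax.numpy as jnp

# Prefer the current location; fall back to the deprecated alias only on
# older JAX releases that predate jax.example_libraries.
try:
    from jax.example_libraries import optimizers
except ImportError:
    from jax.experimental import optimizers


def loss(params):
    # Hypothetical toy loss with its minimum at params == 3.0.
    return jnp.sum((params - 3.0) ** 2)


# optimizers.sgd returns an Optimizer triple: (init_fn, update_fn, params_fn).
opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
opt_state = opt_init(jnp.zeros(2))


@jax.jit
def step(i, opt_state):
    # Unpack parameters, compute gradients, and apply one SGD update.
    params = get_params(opt_state)
    grads = jax.grad(loss)(params)
    return opt_update(i, grads, opt_state)


for i in range(100):
    opt_state = step(i, opt_state)

final_params = get_params(opt_state)
print(final_params)
```

This is the same pattern the old `jax.experimental.optimizers` code used, so existing training loops (like the MNIST example in the question) should only need the import line changed.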

But as noted in the jax.example_libraries.optimizers documentation, this is not well-supported code, and you'll probably have a better experience with a dedicated optimization library such as Optax or JAXopt.
