
I am switching from a PyTorch version that is about three years old to stable PyTorch 1.9 on CentOS 7 (GPU-based). Without any change to the original paper's code, I get the following error. Is there a quick fix for this?

(fashcomp) [jalal@goku fashion-compatibility]$     python main.py --name test_baseline --learned --l2_embed --datadir ../../../data/fashion/
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  warnings.warn("The use of the transforms.Scale transform is deprecated, " +
  + Number of params: 3191808
Traceback (most recent call last):
  File "main.py", line 322, in <module>
    main()    
  File "main.py", line 167, in main
    train(train_loader, tnet, criterion, optimizer, epoch)
  File "main.py", line 194, in train
    for batch_idx, (img1, desc1, has_text1, img2, desc2, has_text2, img3, desc3, has_text3, condition) in enumerate(train_loader):
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/_utils.py", line 425, in reraise
    raise self.exc_type(msg)
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "mtrand.pyx", line 905, in numpy.random.mtrand.RandomState.choice
TypeError: 'dict_keys' object cannot be interpreted as an integer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/scratch3/research/code/fashion/fashion-compatibility/polyvore_outfits.py", line 338, in __getitem__
    neg_im = self.sample_negative(outfit_id, pos_im, item_type)
  File "/scratch3/research/code/fashion/fashion-compatibility/polyvore_outfits.py", line 235, in sample_negative
    choice = np.random.choice(candidate_sets)
  File "mtrand.pyx", line 907, in numpy.random.mtrand.RandomState.choice
ValueError: a must be 1-dimensional or an integer

and here is the output of pip freeze:

$ pip freeze
absl-py==0.13.0
argon2-cffi==20.1.0
attrs==21.2.0
backcall==0.2.0
bleach==4.1.0
cachetools==4.2.2
certifi==2021.5.30
cffi==1.14.6
charset-normalizer==2.0.4
cycler==0.10.0
debugpy==1.4.1
decorator==5.0.9
defusedxml==0.7.1
entrypoints==0.3
google-auth==1.35.0
google-auth-oauthlib==0.4.5
grpcio==1.39.0
h5py==3.3.0
idna==3.2
importlib==1.0.4
ipykernel==6.2.0
ipython==7.26.0
ipython-genutils==0.2.0
ipywidgets==7.6.3
jedi==0.18.0
Jinja2==3.0.1
joblib==1.0.1
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==7.0.1
jupyter-console==6.4.0
jupyter-core==4.7.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
kiwisolver==1.3.1
Markdown==3.3.4
MarkupSafe==2.0.1
matplotlib==3.4.3
matplotlib-inline==0.1.2
mistune==0.8.4
nbclient==0.5.4
nbconvert==6.1.0
nbformat==5.1.3
nest-asyncio==1.5.1
notebook==6.4.3
numpy==1.21.2
oauthlib==3.1.1
packaging==21.0
pandas==1.3.2
pandocfilters==1.4.3
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.3.1
prometheus-client==0.11.0
prompt-toolkit==3.0.20
protobuf==3.17.3
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
Pygments==2.10.0
pyparsing==2.4.7
pyrsistent==0.18.0
python-dateutil==2.8.2
pytz==2021.1
pyzmq==22.2.1
qtconsole==5.1.1
QtPy==1.10.0
requests==2.26.0
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-learn==0.24.2
scipy==1.7.1
Send2Trash==1.8.0
six==1.16.0
sklearn==0.0
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
terminado==0.11.1
testpath==0.5.0
threadpoolctl==2.2.0
torch==1.9.0
torch-tb-profiler==0.2.1
torchaudio==0.9.0
torchvision==0.10.0
tornado==6.1
traitlets==5.0.5
typing-extensions==3.10.0.0
urllib3==1.26.6
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==2.0.1
widgetsnbextension==3.5.1

Link to issue on the repo: https://github.com/mvasil/fashion-compatibility/issues/25


Answers


  1. The problem is in these lines of the file polyvore_outfits.py:

            [...]
            candidate_sets = self.category2ims[item_type].keys()
            attempts = 0
            while item_out == item_id and attempts < 100:
                choice = np.random.choice(candidate_sets)
                [...]
    

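    candidate_sets is the object returned by the dict.keys() method. In Python 2 this was a list, but in Python 3 it is a dict_keys view object. The choice function in NumPy's random module accepts a 1-D array-like (such as a list) or an integer, but it cannot turn a dict_keys object into a 1-D array, so it raises "a must be 1-dimensional or an integer".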

    A simple fix is to explicitly convert candidate_sets into a list, either when it is created,

            candidate_sets = list(self.category2ims[item_type].keys())
    

    or before passing it to np.random.choice:

                choice = np.random.choice(list(candidate_sets))
    
  2. You should convert your dict_keys to a list as explained in the comments above:

    np.random.choice(list(candidate_sets))
    

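    The underlying cause is most likely the move from Python 2 to Python 3, where dict.keys() returns a view instead of a list, rather than a change in the NumPy version.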
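A minimal sketch that reproduces the failure outside the DataLoader and checks the list() fix, assuming NumPy is installed. The dictionary category_ims is a hypothetical stand-in for self.category2ims[item_type]; it is not taken from the repository.

```python
import numpy as np

# Hypothetical stand-in for self.category2ims[item_type]: set id -> items.
category_ims = {"set_a": ["img1"], "set_b": ["img2"], "set_c": ["img3"]}

keys_view = category_ims.keys()  # a dict_keys view in Python 3, not a list

# np.random.choice converts its first argument to an array; a dict_keys view
# becomes a 0-d object array, which NumPy rejects with a ValueError.
try:
    np.random.choice(keys_view)
    failed = False
except ValueError as err:
    failed = True
    print("ValueError:", err)

# Converting the view to a list gives NumPy a proper 1-D array-like.
candidate_sets = list(keys_view)
choice = np.random.choice(candidate_sets)
print("picked:", choice)
```

Converting once when candidate_sets is created is slightly cheaper than calling list() on every iteration of the retry loop, since the keys do not change between attempts.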
