Description
Thank you for your great work. However, when I ran the current code under Python 3.11 and NumPy 1.26, I got very poor results. Strangely, when I run it in an older environment (Python 3.7 and NumPy 1.21), the results are as expected.
I checked the code carefully and found a problem on line 150 of `data_loader.py`:
`result[list(np.indices(arr.shape)) + [arr]] = 1`
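This happens because the one-hot encoding indexes `result` with a plain Python list instead of a tuple. NumPy 1.23 and later no longer treat a list of index arrays as one index per axis (that behavior had been deprecated since NumPy 1.15). The list is instead converted into a single array that indexes only the first axis, so entire rows of `result` are set to 1 and every one-hot value ends up as 1 in the newer environment.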
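To illustrate the cause, here is a minimal sketch with made-up inputs (`arr` holding 4 labels and `num_classes = 3` are assumptions, not values from `data_loader.py`). It builds the index list the same way as line 150, then passes `np.array(index_list)` explicitly to mimic how NumPy 1.23+ reads the list, so the sketch behaves the same on any NumPy version. `tuple(index_list)` shows the older per-axis interpretation.

```python
import numpy as np

# Hypothetical inputs: 4 integer labels, 3 classes.
arr = np.array([0, 2, 1, 2])
num_classes = 3

# Index list built the same way as on line 150.
index_list = list(np.indices(arr.shape)) + [arr]

# NumPy >= 1.23 treats a plain list like a single array index,
# i.e. it only indexes axis 0. Passing np.array(...) reproduces that.
buggy = np.zeros(arr.shape + (num_classes,))
buggy[np.array(index_list)] = 1

# NumPy < 1.23 treated the list as a tuple of per-axis indices.
expected = np.zeros(arr.shape + (num_classes,))
expected[tuple(index_list)] = 1
```

Here `buggy` comes out as all ones, while `expected` has exactly one 1 per row.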
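I recommend changing line 150 of `data_loader.py` to:
`result[np.indices(arr.shape), np.array(arr)] = 1`
This works for 1-D label arrays. For `arr` of any shape, `result[tuple(np.indices(arr.shape)) + (arr,)] = 1` restores the old per-axis behavior.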
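A minimal sketch of a version-independent helper, assuming line 150 sits inside something like the hypothetical `one_hot(arr, num_classes)` below (the real function in `data_loader.py` may look different). It uses the tuple form and checks that it agrees with the suggested fix on 1-D labels.

```python
import numpy as np


def one_hot(arr, num_classes):
    """Hypothetical one-hot helper mirroring line 150 of data_loader.py."""
    arr = np.asarray(arr)
    result = np.zeros(arr.shape + (num_classes,))
    # A tuple supplies one index array per axis on every NumPy version.
    result[tuple(np.indices(arr.shape)) + (arr,)] = 1
    return result


labels = np.array([0, 2, 1, 2])

# The fix suggested above, applied to 1-D labels.
suggested = np.zeros(labels.shape + (3,))
suggested[np.indices(labels.shape), np.array(labels)] = 1

print(one_hot(labels, 3))
```

Because the tuple gives one index array per axis, the helper also handles 2-D inputs such as `np.array([[0, 1], [2, 0]])`, producing an array of shape `(2, 2, 3)`.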
Unfortunately, I was not aware of this problem earlier, and it caused an error in a previous piece of work :(