summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/plot_OTDA_classes.py1
-rw-r--r--ot/datasets.py6
2 files changed, 4 insertions, 3 deletions
diff --git a/examples/plot_OTDA_classes.py b/examples/plot_OTDA_classes.py
index 999be53..c00cef6 100644
--- a/examples/plot_OTDA_classes.py
+++ b/examples/plot_OTDA_classes.py
@@ -11,6 +11,7 @@ import ot
+
#%% parameters
n=150 # nb samples in source and target datasets
diff --git a/ot/datasets.py b/ot/datasets.py
index 5c1ef78..7816833 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -124,8 +124,8 @@ def get_data_classif(dataset,n,nz=.5,theta=0,**kwargs):
else:
- x=0
- y=0
+ x=np.array(0)
+ y=np.array(0)
print("unknown dataset")
- return x,y \ No newline at end of file
+ return x,y.astype(int) \ No newline at end of file