64
64
finfo = get_xp (cp )(_aliases .finfo )
65
65
iinfo = get_xp (cp )(_aliases .iinfo )
66
66
67
- _copy_default = object ()
68
-
69
67
70
68
# asarray also adds the copy keyword, which is not present in numpy 1.0.
71
69
def asarray (
@@ -79,7 +77,7 @@ def asarray(
79
77
* ,
80
78
dtype : Optional [DType ] = None ,
81
79
device : Optional [Device ] = None ,
82
- copy : Optional [bool ] = _copy_default ,
80
+ copy : Optional [bool ] = None ,
83
81
** kwargs ,
84
82
) -> Array :
85
83
"""
@@ -89,25 +87,13 @@ def asarray(
89
87
specification for more details.
90
88
"""
91
89
with cp .cuda .Device (device ):
92
- # cupy is like NumPy 1.26 (except without _CopyMode). See the comments
93
- # in asarray in numpy/_aliases.py.
94
- if copy is not _copy_default :
95
- # A future version of CuPy will change the meaning of copy=False
96
- # to mean no-copy. We don't know for certain what version it will
97
- # be yet, so to avoid breaking that version, we use a different
98
- # default value for copy so asarray(obj) with no copy kwarg will
99
- # always do the copy-if-needed behavior.
100
-
101
- # This will still need to be updated to remove the
102
- # NotImplementedError for copy=False, but at least this won't
103
- # break the default or existing behavior.
104
- if copy is None :
105
- copy = False
106
- elif copy is False :
107
- raise NotImplementedError ("asarray(copy=False) is not yet supported in cupy" )
108
- kwargs ['copy' ] = copy
109
-
110
- return cp .array (obj , dtype = dtype , ** kwargs )
90
+ if copy is None :
91
+ return cp .asarray (obj , dtype = dtype , ** kwargs )
92
+ else :
93
+ res = cp .array (obj , dtype = dtype , copy = copy , ** kwargs )
94
+ if not copy and res is not obj :
95
+ raise ValueError ("Unable to avoid copy while creating an array as requested" )
96
+ return res
111
97
112
98
113
99
def astype (
@@ -138,6 +124,11 @@ def count_nonzero(
138
124
return result
139
125
140
126
127
+ # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg
128
+ def take_along_axis (x : Array , indices : Array , / , * , axis : int = - 1 ):
129
+ return cp .take_along_axis (x , indices , axis = axis )
130
+
131
+
141
132
# These functions are completely new here. If the library already has them
142
133
# (i.e., numpy 2.0), use the library version instead of our wrapper.
143
134
if hasattr (cp , 'vecdot' ):
@@ -159,6 +150,7 @@ def count_nonzero(
159
150
'acos' , 'acosh' , 'asin' , 'asinh' , 'atan' ,
160
151
'atan2' , 'atanh' , 'bitwise_left_shift' ,
161
152
'bitwise_invert' , 'bitwise_right_shift' ,
162
- 'bool' , 'concat' , 'count_nonzero' , 'pow' , 'sign' ]
153
+ 'bool' , 'concat' , 'count_nonzero' , 'pow' , 'sign' ,
154
+ 'take_along_axis' ]
163
155
164
156
_all_ignore = ['cp' , 'get_xp' ]
0 commit comments