 * limitations under the License.
 */

- #include <legate.h>
-
- #include <cudf/binaryop.hpp>
- #include <cudf/types.hpp>
- #include <cudf/unary.hpp>
-
#include <arrow/compute/api.h>
#include <legate_dataframe/binaryop.hpp>
#include <legate_dataframe/core/column.hpp>

namespace legate::dataframe::task {

- cudf::binary_operator arrow_to_cudf_binary_op(std::string op, legate::Type output_type)
+ /* static */ void BinaryOpColColTask::cpu_variant(legate::TaskContext context)
{
-   // Arrow binary operators taken from the below list,
-   // where an equivalent cudf binary operator exists.
-   // https://arrow.apache.org/docs/cpp/compute.html#element-wise-scalar-functions
-   // https://docs.rapids.ai/api/libcudf/stable/group__transformation__binaryops
-   std::unordered_map<std::string, cudf::binary_operator> arrow_to_cudf_ops = {
-     {"add", cudf::binary_operator::ADD},
-     {"divide", cudf::binary_operator::DIV},
-     {"multiply", cudf::binary_operator::MUL},
-     {"power", cudf::binary_operator::POW},
-     {"subtract", cudf::binary_operator::SUB},
-     {"bit_wise_and", cudf::binary_operator::BITWISE_AND},
-     {"bit_wise_or", cudf::binary_operator::BITWISE_OR},
-     {"bit_wise_xor", cudf::binary_operator::BITWISE_XOR},
-     {"shift_left", cudf::binary_operator::SHIFT_LEFT},
-     {"shift_right", cudf::binary_operator::SHIFT_RIGHT},
-     {"logb", cudf::binary_operator::LOG_BASE},
-     {"atan2", cudf::binary_operator::ATAN2},
-     {"equal", cudf::binary_operator::EQUAL},
-     {"greater", cudf::binary_operator::GREATER},
-     {"greater_equal", cudf::binary_operator::GREATER_EQUAL},
-     {"less", cudf::binary_operator::LESS},
-     {"less_equal", cudf::binary_operator::LESS_EQUAL},
-     {"not_equal", cudf::binary_operator::NOT_EQUAL},
-     // logical operators:
-     {"and", cudf::binary_operator::LOGICAL_AND},
-     {"or", cudf::binary_operator::LOGICAL_OR},
-     {"and_kleene", cudf::binary_operator::NULL_LOGICAL_AND},
-     {"or_kleene", cudf::binary_operator::NULL_LOGICAL_OR},
-   };
-
-   // Cudf has a special case for powers with integers
-   // https://github.com/rapidsai/cudf/issues/10178#issuecomment-3004143727
-   if (op == "power" && output_type.to_string().find("int") != std::string::npos) {
-     return cudf::binary_operator::INT_POW;
+   TaskContext ctx{context};
+   auto op = argument::get_next_scalar<std::string>(ctx);
+   const auto lhs = argument::get_next_input<PhysicalColumn>(ctx);
+   const auto rhs = argument::get_next_input<PhysicalColumn>(ctx);
+   auto output = argument::get_next_output<PhysicalColumn>(ctx);
+
+   std::vector<arrow::Datum> args(2);
+   if (lhs.num_rows() == 1) {
+     auto scalar = ARROW_RESULT(lhs.arrow_array_view()->GetScalar(0));
+     args[0] = scalar;
+   } else {
+     args[0] = lhs.arrow_array_view();
+   }
+   if (rhs.num_rows() == 1) {
+     auto scalar = ARROW_RESULT(rhs.arrow_array_view()->GetScalar(0));
+     args[1] = scalar;
+   } else {
+     args[1] = rhs.arrow_array_view();
  }

-   if (arrow_to_cudf_ops.find(op) != arrow_to_cudf_ops.end()) { return arrow_to_cudf_ops[op]; }
-   throw std::invalid_argument("Could not find cudf binary operator matching: " + op);
-   return cudf::binary_operator::INVALID_BINARY;
- }
-
- class BinaryOpColColTask : public Task<BinaryOpColColTask, OpCode::BinaryOpColCol> {
-  public:
-   static void cpu_variant(legate::TaskContext context)
-   {
-     TaskContext ctx{context};
-     auto op = argument::get_next_scalar<std::string>(ctx);
-     const auto lhs = argument::get_next_input<PhysicalColumn>(ctx);
-     const auto rhs = argument::get_next_input<PhysicalColumn>(ctx);
-     auto output = argument::get_next_output<PhysicalColumn>(ctx);
-
-     std::vector<arrow::Datum> args(2);
-     if (lhs.num_rows() == 1) {
-       auto scalar = ARROW_RESULT(lhs.arrow_array_view()->GetScalar(0));
-       args[0] = scalar;
-     } else {
-       args[0] = lhs.arrow_array_view();
-     }
-     if (rhs.num_rows() == 1) {
-       auto scalar = ARROW_RESULT(rhs.arrow_array_view()->GetScalar(0));
-       args[1] = scalar;
-     } else {
-       args[1] = rhs.arrow_array_view();
-     }
-
-     if (output.cudf_type().id() == cudf::type_id::BOOL8 &&
-         (op == "and" || op == "or" || op == "and_kleene" || op == "or_kleene")) {
-       // arrow doesn't seem to cast for the user for logical ops.
-       args[0] = ARROW_RESULT(arrow::compute::Cast(args[0], arrow::boolean()));
-       args[1] = ARROW_RESULT(arrow::compute::Cast(args[1], arrow::boolean()));
-     }
-
-     // Result may be scalar or array
-     auto datum_result = ARROW_RESULT(arrow::compute::CallFunction(op, args));
+   if (output.cudf_type().id() == cudf::type_id::BOOL8 &&
+       (op == "and" || op == "or" || op == "and_kleene" || op == "or_kleene")) {
+     // arrow doesn't seem to cast for the user for logical ops.
+     args[0] = ARROW_RESULT(arrow::compute::Cast(args[0], arrow::boolean()));
+     args[1] = ARROW_RESULT(arrow::compute::Cast(args[1], arrow::boolean()));
+   }

-     // Coerce the output type if necessary
-     auto arrow_result_type = to_arrow_type(output.cudf_type().id());
-     if (datum_result.type() != arrow_result_type) {
-       auto coerced_result = ARROW_RESULT(arrow::compute::Cast(
-         datum_result, arrow_result_type, arrow::compute::CastOptions::Unsafe()));
-       datum_result = std::move(coerced_result);
-     }
+   // Result may be scalar or array
+   auto datum_result = ARROW_RESULT(arrow::compute::CallFunction(op, args));

-     if (datum_result.is_scalar()) {
-       auto as_array = ARROW_RESULT(arrow::MakeArrayFromScalar(*datum_result.scalar(), 1));
-       if (get_prefer_eager_allocations()) {
-         output.copy_into(std::move(as_array));
-       } else {
-         output.move_into(std::move(as_array));
-       }
-     } else {
-       if (get_prefer_eager_allocations()) {
-         output.copy_into(std::move(datum_result.make_array()));
-       } else {
-         output.move_into(std::move(datum_result.make_array()));
-       }
-     }
+   // Coerce the output type if necessary
+   auto arrow_result_type = to_arrow_type(output.cudf_type().id());
+   if (datum_result.type() != arrow_result_type) {
+     auto coerced_result = ARROW_RESULT(
+       arrow::compute::Cast(datum_result, arrow_result_type, arrow::compute::CastOptions::Unsafe()));
+     datum_result = std::move(coerced_result);
  }

-   static void gpu_variant(legate::TaskContext context)
-   {
-     TaskContext ctx{context};
-     auto arrow_op = argument::get_next_scalar<std::string>(ctx);
-     const auto lhs = argument::get_next_input<PhysicalColumn>(ctx);
-     const auto rhs = argument::get_next_input<PhysicalColumn>(ctx);
-     auto output = argument::get_next_output<PhysicalColumn>(ctx);
-     auto op = arrow_to_cudf_binary_op(arrow_op, output.type());
-
-     std::unique_ptr<cudf::column> ret;
-     /*
-      * If one (not both) are length 1, use scalars as cudf doesn't allow
-      * broadcast binary operations.
-      */
-     if (lhs.num_rows() == 1 && rhs.num_rows() != 1) {
-       auto lhs_scalar = lhs.cudf_scalar();
-       ret = cudf::binary_operation(
-         *lhs_scalar, rhs.column_view(), op, output.cudf_type(), ctx.stream(), ctx.mr());
-     } else if (rhs.num_rows() == 1 && lhs.num_rows() != 1) {
-       auto rhs_scalar = rhs.cudf_scalar();
-       ret = cudf::binary_operation(
-         lhs.column_view(), *rhs_scalar, op, output.cudf_type(), ctx.stream(), ctx.mr());
+   if (datum_result.is_scalar()) {
+     auto as_array = ARROW_RESULT(arrow::MakeArrayFromScalar(*datum_result.scalar(), 1));
+     if (get_prefer_eager_allocations()) {
+       output.copy_into(std::move(as_array));
    } else {
-       ret = cudf::binary_operation(
-         lhs.column_view(), rhs.column_view(), op, output.cudf_type(), ctx.stream(), ctx.mr());
+       output.move_into(std::move(as_array));
    }
+   } else {
    if (get_prefer_eager_allocations()) {
-       output.copy_into(std::move(ret));
+       output.copy_into(std::move(datum_result.make_array()));
    } else {
-       output.move_into(std::move(ret));
+       output.move_into(std::move(datum_result.make_array()));
    }
  }
- };
+ }

} // namespace legate::dataframe::task
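For reference, here is a minimal standalone sketch (not part of the patch) of the Arrow compute pattern the new `cpu_variant` relies on: inputs are wrapped as `arrow::Datum` values (arrays or scalars, so length-1 columns broadcast), dispatched by operator name through `arrow::compute::CallFunction`, and the result is cast to the requested output type. The helper name `add_scalar` and the `"add"`/`float64` choices are illustrative only.

```cpp
#include <arrow/api.h>
#include <arrow/compute/api.h>

// Illustrative helper (hypothetical name): add a scalar to an array the same way
// the task above does, i.e. via Datum arguments and CallFunction.
arrow::Result<std::shared_ptr<arrow::Array>> add_scalar(
  const std::shared_ptr<arrow::Array>& values, double addend)
{
  // Scalars broadcast against arrays, mirroring the num_rows() == 1 branches above.
  std::vector<arrow::Datum> args{arrow::Datum{values},
                                 arrow::Datum{std::make_shared<arrow::DoubleScalar>(addend)}};
  ARROW_ASSIGN_OR_RAISE(auto result, arrow::compute::CallFunction("add", args));
  // Coerce to a fixed output type, as the task does with CastOptions::Unsafe().
  ARROW_ASSIGN_OR_RAISE(
    result, arrow::compute::Cast(result, arrow::float64(), arrow::compute::CastOptions::Unsafe()));
  return result.make_array();
}
```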
@@ -181,15 +96,16 @@ namespace legate::dataframe {
LogicalColumn binary_operation(const LogicalColumn& lhs,
                               const LogicalColumn& rhs,
                               std::string op,
-                              cudf::data_type output_type)
+                              std::shared_ptr<arrow::DataType> output_type)
{
  auto runtime = legate::Runtime::get_runtime();

  // Check if the op is valid before we enter the task
  // This allows us to throw nicely
  if (runtime->get_machine().count(legate::mapping::TaskTarget::GPU) > 0) {
-     // Throws if op doesn't exist
-     task::arrow_to_cudf_binary_op(op, to_legate_type(output_type.id()));
+     if (task::cudf_supported_binary_ops.count(op) == 0) {
+       throw std::invalid_argument("Unsupported binary operator: " + op);
+     }
  } else {
    auto result = arrow::compute::GetFunctionRegistry()->GetFunction(op);
    if (!result.ok()) {
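The CPU branch above validates the operator name against Arrow's function registry before the task is launched. A hedged standalone sketch of that check follows; the GPU-side set `task::cudf_supported_binary_ops` is defined elsewhere in the project and is not reproduced here, and the function name `validate_arrow_op` is illustrative.

```cpp
#include <arrow/compute/api.h>
#include <stdexcept>
#include <string>

// Minimal sketch of the CPU-side validity check: look the operator name up in
// Arrow's global function registry and fail early if it is not registered.
void validate_arrow_op(const std::string& op)
{
  auto function = arrow::compute::GetFunctionRegistry()->GetFunction(op);
  if (!function.ok()) {
    throw std::invalid_argument("Unsupported binary operator: " + op);
  }
}
```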
@@ -201,7 +117,7 @@ LogicalColumn binary_operation(const LogicalColumn& lhs,
  auto scalar_result = lhs.is_scalar() && rhs.is_scalar();
  std::optional<size_t> size{};
  if (get_prefer_eager_allocations()) { size = lhs.is_scalar() ? rhs.num_rows() : lhs.num_rows(); }
-   auto ret = LogicalColumn::empty_like(std::move(output_type), nullable, scalar_result, size);
+   auto ret = LogicalColumn::empty_like(output_type, nullable, scalar_result, size);
  legate::AutoTask task =
    runtime->create_task(get_library(), task::BinaryOpColColTask::TASK_CONFIG.task_id());
  argument::add_next_scalar(task, op);