Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#define __INFINIOP_API_H__

#include "infiniop/handle.h"
// Unified headers for elementwise operators
#include "infiniop/ops/unary_ops_api.h"
#include "infiniop/ops/binary_ops_api.h"
// Other operators
#include "infiniop/ops/add.h"
#include "infiniop/ops/add_rms_norm.h"
#include "infiniop/ops/attention.h"
Expand Down
22 changes: 2 additions & 20 deletions include/infiniop/ops/add.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
#ifndef __INFINIOP_ADD_API_H__
#define __INFINIOP_ADD_API_H__

#include "../operator_descriptor.h"
#include "binary_op_api.h"

typedef struct InfiniopDescriptor *infiniopAddDescriptor_t;

__C __export infiniStatus_t infiniopCreateAddDescriptor(infiniopHandle_t handle,
infiniopAddDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b);

__C __export infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopAdd(infiniopAddDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
void *stream);

__C __export infiniStatus_t infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc);
BINARY_OP_API_DECLARE(add, Add)

#endif
50 changes: 50 additions & 0 deletions include/infiniop/ops/binary_op_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef __INFINIOP_BINARY_OP_API_H__
#define __INFINIOP_BINARY_OP_API_H__

#include "../operator_descriptor.h"

/**
* @brief Macro to generate the C API header for a binary operator.
*
* This macro generates all the necessary declarations for a binary operator:
* - Descriptor type definition
* - Create descriptor function
* - Get workspace size function
* - Execute operator function
* - Destroy descriptor function
*
* Usage:
* BINARY_OP_API_DECLARE(div, Div)
* BINARY_OP_API_DECLARE(pow, Pow)
*
* @param OP_NAME Lowercase operator name (e.g., div, pow, mod)
* @param OP_NAME_UPPER Uppercase operator name (e.g., Div, Pow, Mod)
*/
#define BINARY_OP_API_DECLARE(OP_NAME, OP_NAME_UPPER) \
\
typedef struct InfiniopDescriptor *infiniop##OP_NAME_UPPER##Descriptor_t; \
\
__C __export infiniStatus_t infiniopCreate##OP_NAME_UPPER##Descriptor( \
infiniopHandle_t handle, \
infiniop##OP_NAME_UPPER##Descriptor_t *desc_ptr, \
infiniopTensorDescriptor_t c, \
infiniopTensorDescriptor_t a, \
infiniopTensorDescriptor_t b); \
\
__C __export infiniStatus_t infiniopGet##OP_NAME_UPPER##WorkspaceSize( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
size_t *size); \
\
__C __export infiniStatus_t infiniop##OP_NAME_UPPER( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
void *workspace, \
size_t workspace_size, \
void *c, \
const void *a, \
const void *b, \
void *stream); \
\
__C __export infiniStatus_t infiniopDestroy##OP_NAME_UPPER##Descriptor( \
infiniop##OP_NAME_UPPER##Descriptor_t desc);

#endif // __INFINIOP_BINARY_OP_API_H__
23 changes: 23 additions & 0 deletions include/infiniop/ops/binary_ops_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef __INFINIOP_BINARY_OPS_API_H__
#define __INFINIOP_BINARY_OPS_API_H__

#include "binary_op_api.h"

/**
* @brief Unified API declarations for all binary operators.
*
* This header contains API declarations for all binary operators in a single file,
* eliminating the need for individual header files for each operator.
*
* All binary operator APIs are declared here:
* - div, pow, mod, max, min
*/

// Declare all binary operator APIs
BINARY_OP_API_DECLARE(div, Div)
BINARY_OP_API_DECLARE(pow, Pow)
BINARY_OP_API_DECLARE(mod, Mod)
BINARY_OP_API_DECLARE(max, Max)
BINARY_OP_API_DECLARE(min, Min)

#endif // __INFINIOP_BINARY_OPS_API_H__
22 changes: 2 additions & 20 deletions include/infiniop/ops/mul.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
#ifndef __INFINIOP_MUL_API_H__
#define __INFINIOP_MUL_API_H__

#include "../operator_descriptor.h"
#include "binary_op_api.h"

typedef struct InfiniopDescriptor *infiniopMulDescriptor_t;

__C __export infiniStatus_t infiniopCreateMulDescriptor(infiniopHandle_t handle,
infiniopMulDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b);

__C __export infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopMul(infiniopMulDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
void *stream);

__C __export infiniStatus_t infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc);
BINARY_OP_API_DECLARE(mul, Mul)

#endif
22 changes: 2 additions & 20 deletions include/infiniop/ops/sub.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
#ifndef __INFINIOP_SUB_API_H__
#define __INFINIOP_SUB_API_H__

#include "../operator_descriptor.h"
#include "binary_op_api.h"

typedef struct InfiniopDescriptor *infiniopSubDescriptor_t;

__C __export infiniStatus_t infiniopCreateSubDescriptor(infiniopHandle_t handle,
infiniopSubDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c,
infiniopTensorDescriptor_t a,
infiniopTensorDescriptor_t b);

__C __export infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopSub(infiniopSubDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
void *stream);

__C __export infiniStatus_t infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc);
BINARY_OP_API_DECLARE(sub, Sub)

#endif
48 changes: 48 additions & 0 deletions include/infiniop/ops/unary_op_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef __INFINIOP_UNARY_OP_API_H__
#define __INFINIOP_UNARY_OP_API_H__

#include "../operator_descriptor.h"

/**
* @brief Macro to generate the C API header for a unary operator.
*
* This macro generates all the necessary declarations for a unary operator:
* - Descriptor type definition
* - Create descriptor function
* - Get workspace size function
* - Execute operator function
* - Destroy descriptor function
*
* Usage:
* UNARY_OP_API_DECLARE(abs, Abs)
* UNARY_OP_API_DECLARE(log, Log)
*
* @param OP_NAME Lowercase operator name (e.g., abs, log, sin)
* @param OP_NAME_UPPER Uppercase operator name (e.g., Abs, Log, Sin)
*/
#define UNARY_OP_API_DECLARE(OP_NAME, OP_NAME_UPPER) \
\
typedef struct InfiniopDescriptor *infiniop##OP_NAME_UPPER##Descriptor_t; \
\
__C __export infiniStatus_t infiniopCreate##OP_NAME_UPPER##Descriptor( \
infiniopHandle_t handle, \
infiniop##OP_NAME_UPPER##Descriptor_t *desc_ptr, \
infiniopTensorDescriptor_t y, \
infiniopTensorDescriptor_t x); \
\
__C __export infiniStatus_t infiniopGet##OP_NAME_UPPER##WorkspaceSize( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
size_t *size); \
\
__C __export infiniStatus_t infiniop##OP_NAME_UPPER( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
void *workspace, \
size_t workspace_size, \
void *y, \
const void *x, \
void *stream); \
\
__C __export infiniStatus_t infiniopDestroy##OP_NAME_UPPER##Descriptor( \
infiniop##OP_NAME_UPPER##Descriptor_t desc);

#endif // __INFINIOP_UNARY_OP_API_H__
39 changes: 39 additions & 0 deletions include/infiniop/ops/unary_ops_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef __INFINIOP_UNARY_OPS_API_H__
#define __INFINIOP_UNARY_OPS_API_H__

#include "unary_op_api.h"

/**
* @brief Unified API declarations for all unary operators.
*
* This header contains API declarations for all unary operators in a single file,
* eliminating the need for individual header files for each operator.
*
* All unary operator APIs are declared here:
* - abs, log, sqrt, reciprocal, neg, round, sinh, sign, tan
* - acosh, asinh, cos, atanh, asin, floor, cosh, erf, atan, acos, ceil
*/

// Declare all unary operator APIs
UNARY_OP_API_DECLARE(abs, Abs)
UNARY_OP_API_DECLARE(log, Log)
UNARY_OP_API_DECLARE(sqrt, Sqrt)
UNARY_OP_API_DECLARE(reciprocal, Reciprocal)
UNARY_OP_API_DECLARE(neg, Neg)
UNARY_OP_API_DECLARE(round, Round)
UNARY_OP_API_DECLARE(sinh, Sinh)
UNARY_OP_API_DECLARE(sign, Sign)
UNARY_OP_API_DECLARE(tan, Tan)
UNARY_OP_API_DECLARE(acosh, Acosh)
UNARY_OP_API_DECLARE(asinh, Asinh)
UNARY_OP_API_DECLARE(cos, Cos)
UNARY_OP_API_DECLARE(atanh, Atanh)
UNARY_OP_API_DECLARE(asin, Asin)
UNARY_OP_API_DECLARE(floor, Floor)
UNARY_OP_API_DECLARE(cosh, Cosh)
UNARY_OP_API_DECLARE(erf, Erf)
UNARY_OP_API_DECLARE(atan, Atan)
UNARY_OP_API_DECLARE(acos, Acos)
UNARY_OP_API_DECLARE(ceil, Ceil)

#endif // __INFINIOP_UNARY_OPS_API_H__
Loading