Skip to content

Commit 255ae64

Browse files
authored
Merge pull request #17 from CortexFoundation/zkh
Merge zkh
2 parents 7a6019d + 6b78365 commit 255ae64

File tree

15 files changed

+2042
-2178
lines changed

15 files changed

+2042
-2178
lines changed

CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ else(MSVC)
7070
endif(MSVC)
7171

7272
# add source group
73-
FILE(GLOB_RECURSE GROUP_INCLUDE "src/cvm/ops/*.h" "src/*.h" "include/*.h")
73+
FILE(GLOB_RECURSE GROUP_INCLUDE "src/*.h" "include/*.h")
7474
assign_source_group("Include" ${GROUP_INCLUDE})
7575

7676
# Source file lists
@@ -82,7 +82,7 @@ if(NOT USE_RTTI)
8282
endif()
8383

8484
message(STATUS "Build with CVM runtime support...")
85-
file(GLOB RUNTIME_CVM_SRCS src/cvm/ops/*.cc src/cvm/*.cc)
85+
file(GLOB RUNTIME_CVM_SRCS src/cvm/*.cc)
8686
if(${USE_CUDA} STREQUAL "ON")
8787
message("use cuda")
8888
project(cvm CUDA)
@@ -95,6 +95,7 @@ if(${USE_CUDA} STREQUAL "ON")
9595
list(APPEND RUNTIME_SRCS ${RUNTIME_CVM_CUDA_SRCS})
9696
list(APPEND CVM_RUNTIME_LINKER_LIBS "cudart")
9797
list(APPEND CVM_RUNTIME_LINKER_LIBS "cuda")
98+
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -std=c++11 --expt-extended-lambda")
9899
else()
99100
file(GLOB CVM_OPS_CPU_SRCS src/cvm/ops/cpu/*.cc)
100101
list(APPEND RUNTIME_CVM_SRCS ${CVM_OPS_CPU_SRCS})

src/cvm/ops/common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#ifndef COMMON_H
2+
#define COMMON_H
3+
4+
#define FORMAT_CORNER 1
5+
#define FORMAT_CENTER 2
6+
7+
#endif

src/cvm/ops/cpu/elemwise.cc

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
#include "ops.h"
2+
3+
namespace cvm {
4+
namespace runtime {
5+
6+
extern double cvm_op_elemwise_cnt;
7+
double cvm_op_clip_cnt = 0;
8+
double cvm_op_cvm_shift_cnt = 0;
9+
10+
typedef std::function<int32_t(int32_t a, int32_t b)> elemwise_func;
11+
12+
inline void elemwise(DLTensor *args0, DLTensor *args1, DLTensor *args2, const elemwise_func& f){
13+
#ifdef CVM_PROFILING
14+
double start = omp_get_wtime();
15+
#endif
16+
17+
int32_t *a = static_cast<int32_t*>(args0->data);
18+
int32_t *b = static_cast<int32_t*>(args1->data);
19+
int32_t *c = static_cast<int32_t*>(args2->data);
20+
21+
#pragma omp parallel for
22+
for(uint64_t i = 0; i < getSize(args0); i++){
23+
c[i] = f(a[i], b[i]);
24+
}
25+
26+
#ifdef CVM_PROFILING
27+
cvm_op_elemwise_cnt += omp_get_wtime() - start;
28+
#endif
29+
}
30+
31+
CVM_REGISTER_GLOBAL("cvm.runtime.cvm.elemwise_add")
32+
.set_body([](CVMArgs args, CVMRetValue *ret)
33+
{
34+
DLTensor *args0 = args[0];
35+
DLTensor *args1 = args[1];
36+
DLTensor *args2 = args[2];
37+
38+
auto f = [](int32_t a, int32_t b) -> int32_t {
39+
return a + b;
40+
};
41+
elemwise(args0, args1, args2, f);
42+
print_to_file(args2, "elemwise_add.txt");
43+
});
44+
45+
CVM_REGISTER_GLOBAL("cvm.runtime.cvm.elemwise_sub")
46+
.set_body([](CVMArgs args, CVMRetValue *ret)
47+
{
48+
DLTensor *args0 = args[0];
49+
DLTensor *args1 = args[1];
50+
DLTensor *args2 = args[2];
51+
52+
auto f = [](int32_t a, int32_t b) -> int32_t {
53+
return a - b;
54+
};
55+
elemwise(args0, args1, args2, f);
56+
});
57+
58+
CVM_REGISTER_GLOBAL("cvm.runtime.cvm.clip")
59+
.set_body([](CVMArgs args, CVMRetValue* rv){
60+
#ifdef CVM_PROFILING
61+
double start = omp_get_wtime();
62+
#endif
63+
DLTensor *x = args[0];
64+
DLTensor *y = args[1];
65+
void *_attr = args[2];
66+
auto *attr = static_cast<cvm::NodeAttrs*>(_attr);
67+
auto& param = cvm::get<cvm::top::ClipParam>(attr->parsed);
68+
int32_t max = param.a_max;
69+
int32_t min = param.a_min;
70+
int32_t *x_data = static_cast<int32_t*>(x->data);
71+
int32_t *y_data = static_cast<int32_t*>(y->data);
72+
#pragma omp parallel for
73+
for (uint64_t i = 0; i < getSize(x); i++) {
74+
y_data[i] = std::max(std::min(max, x_data[i]), min);
75+
}
76+
#ifdef CVM_PROFILING
77+
cvm_op_elemwise_cnt += omp_get_wtime() - start;
78+
#endif
79+
});
80+
81+
CVM_REGISTER_GLOBAL("cvm.runtime.cvm.flatten")
82+
.set_body([](CVMArgs args, CVMRetValue* rv)
83+
{
84+
#ifdef CVM_PROFILING
85+
double start = omp_get_wtime();
86+
#endif
87+
DLTensor *x = args[0];
88+
DLTensor *y = args[1];
89+
int32_t* x_data = static_cast<int32_t*>(x->data);
90+
int32_t* y_data = static_cast<int32_t*>(y->data);
91+
if(x_data != y_data){
92+
memcpy(y_data, x_data, getSize(x)*sizeof(int32_t));
93+
}
94+
95+
#ifdef CVM_PROFILING
96+
cvm_op_elemwise_cnt += omp_get_wtime() - start;
97+
#endif
98+
99+
print_to_file(y, "flatten.txt");
100+
});
101+
102+
CVM_REGISTER_GLOBAL("cvm.runtime.cvm.reshape")
103+
.set_body([](CVMArgs args, CVMRetValue *ret)
104+
{
105+
DLTensor *x = args[0];
106+
DLTensor *y = args[1];
107+
if(x->data == y->data) return;
108+
std::memcpy(y->data, x->data, getSize(x) * sizeof(int32_t));
109+
print_to_file(y, "reshape.txt");
110+
});
111+
112+
CVM_REGISTER_GLOBAL("cvm.runtime.cvm.cvm_clip")
113+
.set_body([](CVMArgs args, CVMRetValue *ret)
114+
{
115+
#ifdef CVM_PROFILING
116+
double start = omp_get_wtime();
117+
#endif
118+
DLTensor *x = args[0];
119+
DLTensor *y = args[1];
120+
int32_t *x_data = static_cast<int32_t*>(x->data);
121+
int32_t *y_data = static_cast<int32_t*>(y->data);
122+
123+
void *_attr = args[2];
124+
auto *attr = static_cast<cvm::NodeAttrs*>(_attr);
125+
auto &param = cvm::get<cvm::top::CVMClipParam>(attr->parsed);
126+
int32_t precision = param.precision;
127+
int32_t min = -(((int64_t)1 << (precision-1))-1);
128+
int32_t max = -min;
129+
130+
#pragma omp parallel for
131+
for(uint64_t i = 0; i < getSize(x); i++){
132+
int32_t tmp = x_data[i];
133+
if (tmp > max) tmp = max;
134+
else if (tmp < min) tmp = min;
135+
y_data[i] = tmp;
136+
}
137+
#ifdef CVM_PROFILING
138+
cvm_op_clip_cnt += omp_get_wtime() - start;
139+
#endif
140+
print_to_file(y, "clip.txt");
141+
}
142+
);
143+
144+
CVM_REGISTER_GLOBAL("cvm.runtime.cvm.cvm_right_shift")
145+
.set_body([](CVMArgs args, CVMRetValue *ret){
146+
DLTensor *a = args[0];
147+
DLTensor *c = args[1];
148+
149+
#ifdef CVM_PROFILING
150+
double start = omp_get_wtime();
151+
#endif
152+
void *_attr = args[2];
153+
auto *attr = static_cast<cvm::NodeAttrs*>(_attr);
154+
auto &param = cvm::get<cvm::top::CVMRightShiftParam>(attr->parsed);
155+
int32_t precision = param.precision;
156+
int32_t b = param.shift_bit;
157+
int32_t* a_data = static_cast<int32_t*>(a->data);
158+
int32_t* c_data = static_cast<int32_t*>(c->data);
159+
int32_t min = -(((int64_t)1 << (precision-1)) - 1);
160+
int32_t max = -min;
161+
auto size = getSize(a);
162+
163+
if (b == 1) {
164+
#pragma omp parallel for
165+
for(uint64_t i = 0; i < size; i++){
166+
int32_t shift_a = (a_data[i] + 1) >> 1;
167+
if (shift_a > max) shift_a = max;
168+
else if (shift_a < min) shift_a = min;
169+
c_data[i] = shift_a;
170+
}
171+
} else {
172+
b -= 1;
173+
#pragma omp parallel
174+
{
175+
#pragma omp for
176+
for(uint64_t i = 0; i < size; i++){
177+
c_data[i] = a_data[i] >> b;
178+
++c_data[i];
179+
c_data[i] >>= 1;
180+
}
181+
#pragma omp for
182+
for(uint64_t i = 0; i < size; i++){
183+
auto& shift_a = c_data[i];
184+
if (shift_a > max) shift_a = max;
185+
else if (shift_a < min) shift_a = min;
186+
}
187+
}
188+
}
189+
190+
#ifdef CVM_PROFILING
191+
cvm_op_cvm_shift_cnt += omp_get_wtime() - start;
192+
#endif
193+
print_to_file(c, "cvm_right_shift.txt");
194+
});
195+
196+
CVM_REGISTER_GLOBAL("cvm.runtime.cvm.cvm_left_shift")
197+
.set_body([](CVMArgs args, CVMRetValue *ret){
198+
DLTensor *a = args[0];
199+
DLTensor *c = args[1];
200+
void *_attr = args[2];
201+
auto *attr = static_cast<cvm::NodeAttrs*>(_attr);
202+
auto &param = cvm::get<cvm::top::CVMLeftShiftParam>(attr->parsed);
203+
int32_t precision = param.precision;
204+
int32_t b = param.shift_bit;std::string str_precision = args[2];
205+
int32_t* a_data = static_cast<int32_t*>(a->data);
206+
int32_t* c_data = static_cast<int32_t*>(c->data);
207+
int32_t min = -(((int64_t)1 << (precision-1)) - 1);
208+
int32_t max = -min;
209+
210+
for(uint64_t i = 0; i < getSize(a); i++){
211+
int32_t shift_a = a_data[i] << b;
212+
c_data[i] = std::max(std::min(shift_a, max), min);
213+
}
214+
});
215+
}
216+
}

0 commit comments

Comments
 (0)