Skip to content

Commit 3e4eda0

Browse files
committed
Merge remote-tracking branch 'remotes/origin/wlt' into ryt_tune
2 parents 7af61da + d9f942a commit 3e4eda0

File tree

3 files changed

+169
-10
lines changed

3 files changed

+169
-10
lines changed

kernel/libcvm.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package kernel
22

33
/*
44
#cgo LDFLAGS: -ldl -lstdc++
5-
#cgo CFLAGS: -I../../../infernet/include -O2
5+
#cgo CFLAGS: -I../../../cvm-runtime/include -O2
66
#cgo CFLAGS: -Wall -Wno-unused-result -Wno-unknown-pragmas -Wno-unused-variable
77
88
#include "dlopen.h"

kernel/model.go

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ import (
55
"unsafe"
66
)
77

8+
const (
9+
CVM_VERSION_ONE = 0
10+
CVM_VERSION_TWO = 1
11+
)
12+
813
type Model struct {
914
model unsafe.Pointer
1015
lib *LibCVM
@@ -56,20 +61,33 @@ func (m *Model) GetInputLength() uint64 {
5661
return m.input_size
5762
}
5863

59-
func (m *Model) Predict(data []byte) ([]byte, int) {
64+
func (m *Model) Predict(data []byte, cvmVersion int) ([]byte, int) {
6065
var (
6166
output []byte
6267
status int
6368
err error
6469
)
65-
if len(data) != int(m.input_size) {
66-
log.Warn("input length not matched",
67-
"input length", len(data), "expected", m.input_size)
68-
return nil, ERROR_LOGIC
69-
}
70-
if data, err = ToAlignedData(data, int(m.input_byte)); err != nil {
71-
log.Warn("input ToAlignedData invalid", "error", err)
72-
return nil, ERROR_LOGIC
70+
if cvmVersion == CVM_VERSION_ONE {
71+
if len(data) < int(m.input_size) {
72+
log.Warn("input length less than input size",
73+
"input length", len(data), "expected", m.input_size)
74+
return nil, ERROR_LOGIC
75+
}
76+
if data, err = ToAlignedData(data[:m.input_size], int(m.input_byte)); err != nil {
77+
log.Warn("input ToAlignedData invalid", "error", err)
78+
return nil, ERROR_LOGIC
79+
}
80+
} else {
81+
// TODO(ryt): test it in istanbuer version code
82+
if len(data) != int(m.input_size) {
83+
log.Warn("input length not matched",
84+
"input length", len(data), "expected", m.input_size)
85+
return nil, ERROR_LOGIC
86+
}
87+
if data, err = ToAlignedData(data, int(m.input_byte)); err != nil {
88+
log.Warn("input ToAlignedData invalid", "error", err)
89+
return nil, ERROR_LOGIC
90+
}
7391
}
7492
if output, status = m.lib.Inference(m.model, data); status != SUCCEED {
7593
return nil, status

tests/golang/test_infer_main.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package main
2+
3+
/*
4+
#include <stdio.h>
5+
#include <stdlib.h>
6+
7+
enum CVMStatus {
8+
SUCCEED = 0,
9+
ERROR_LOGIC = -1,
10+
ERROR_RUNTIME = -2
11+
};
12+
13+
void myprint_int(long long *num) {
14+
printf("%lld\n", *num);
15+
*num = 4;
16+
}
17+
void myprint(char *s) {
18+
printf("%s\n", s);
19+
}
20+
21+
void new_arr(void **arr) {
22+
void *mal = malloc(sizeof(int) * 1);
23+
int *nums = (int*)mal;
24+
*nums = 1000;
25+
printf("%p %d\n", mal, *nums);
26+
*arr = mal;
27+
}
28+
29+
void myprint_void(void *arr) {
30+
int *nums = (int*)arr;
31+
printf("%p %d\n", arr, *nums);
32+
}
33+
34+
*/
35+
import "C"
36+
import (
37+
"fmt"
38+
"io/ioutil"
39+
_ "io/ioutil"
40+
"os"
41+
_ "reflect"
42+
_ "runtime"
43+
"unsafe"
44+
45+
"github.com/CortexFoundation/CortexTheseus/cvm-runtime/kernel"
46+
// "github.com/CortexFoundation/CortexTheseus/inference/synapse"
47+
"github.com/CortexFoundation/CortexTheseus/log"
48+
)
49+
50+
func test() {
51+
cs := C.CString("Hello from stdio")
52+
C.myprint(cs)
53+
C.free(unsafe.Pointer(cs))
54+
55+
var num C.longlong
56+
num = 3
57+
C.myprint_int(&num)
58+
fmt.Println(int64(num))
59+
60+
var s1 C.enum_CVMStatus
61+
s1 = C.ERROR_LOGIC
62+
s2 := C.ERROR_LOGIC
63+
fmt.Println(int(s1) == int(s2))
64+
65+
var arr unsafe.Pointer
66+
C.new_arr(&arr)
67+
C.myprint_void(arr)
68+
}
69+
70+
func main() {
71+
// Set log
72+
log.Root().SetHandler(log.LvlFilterHandler(log.Lvl(5), log.StreamHandler(os.Stdout, log.TerminalFormat(true))))
73+
74+
var (
75+
lib *kernel.LibCVM
76+
net *kernel.Model
77+
res []byte
78+
status int
79+
)
80+
81+
device := "cpu"
82+
deviceType := 0
83+
if device == "cuda" {
84+
deviceType = 1
85+
}
86+
// lib, status = kernel.LibOpen("./libcvm_runtime_" + device + ".so")
87+
// lib, status = kernel.LibOpen("./libcvm_runtime_cpu.so")
88+
lib, status = kernel.LibOpen("./build/libcvm_runtime.so")
89+
if status != kernel.SUCCEED {
90+
fmt.Printf("open library error: %d\n", status)
91+
return
92+
}
93+
94+
root := "/data/std_out/log2"
95+
// root := "/home/serving/ctxc_data/cpu/3145ad19228c1cd2d051314e72f26c1ce77b7f02/data"
96+
modelCfg, sErr := ioutil.ReadFile(root + "/symbol")
97+
if sErr != nil {
98+
fmt.Println(sErr)
99+
return
100+
}
101+
modelBin, pErr := ioutil.ReadFile(root + "/params")
102+
if pErr != nil {
103+
fmt.Println(pErr)
104+
return
105+
}
106+
// modelCfg := []byte("{}")
107+
// modelBin := []byte("dkjflsiejflsdkj")
108+
net, status = kernel.New(lib, modelCfg, modelBin, deviceType, 0)
109+
if status != kernel.SUCCEED {
110+
fmt.Printf("Failed LoadModel: %d\n", status)
111+
return
112+
}
113+
input_size := net.GetInputLength()
114+
fmt.Printf("Succeed LoadModel: %p ops=%s size=%s input_size=%s\n",
115+
&net, net.Ops(), net.Size(), input_size)
116+
117+
var data []byte = make([]byte, input_size)
118+
// cvmVersion := synapse.CVMVersion(cvm.chainConfig, cvm.BlockNumber)
119+
var cvmVersion int = 1
120+
res, status = net.Predict(data, cvmVersion)
121+
if status != kernel.SUCCEED {
122+
fmt.Printf("Failed Predict: %d\n", status)
123+
return
124+
}
125+
fmt.Printf("Succeed Predict: %v\n", res)
126+
127+
status = net.Free()
128+
if status != kernel.SUCCEED {
129+
fmt.Printf("Failed Free model: %d\n", status)
130+
return
131+
}
132+
fmt.Printf("Succeed Free model\n")
133+
134+
var gas uint64
135+
gas, status = kernel.GetModelGasFromGraphFile(lib, modelCfg)
136+
if status != kernel.SUCCEED {
137+
fmt.Printf("Failed get model gas from file: %s\n", status)
138+
return
139+
}
140+
fmt.Printf("Succeed get model gas from file: %s\n", int(gas))
141+
}

0 commit comments

Comments
 (0)