Skip to content

Commit 24d8eec

Browse files
authored
Merge pull request #510 from joker-star-l/tf-scala-like-api
Add scala like api for deep learning
2 parents d56e987 + 141bcac commit 24d8eec

File tree

4 files changed

+294
-2
lines changed

4 files changed

+294
-2
lines changed

wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,18 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
588588
predictOperator
589589
}
590590

591+
def predictJava[ThatOut: ClassTag, Result: ClassTag](
592+
that: DataQuanta[ThatOut]
593+
): DataQuanta[Result] = {
594+
val predictOperator = new PredictOperator(
595+
implicitly[ClassTag[ThatOut]].runtimeClass,
596+
implicitly[ClassTag[Result]].runtimeClass
597+
)
598+
this.connectTo(predictOperator, 0)
599+
that.connectTo(predictOperator, 1)
600+
predictOperator
601+
}
602+
591603
def dlTraining[ThatOut: ClassTag](
592604
model: DLModel,
593605
option: DLTrainingOperator.Option,
@@ -616,6 +628,24 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
616628
dlTrainingOperator
617629
}
618630

631+
632+
def dlTrainingJava[ThatOut: ClassTag](
633+
model: DLModel,
634+
option: DLTrainingOperator.Option,
635+
that: DataQuanta[ThatOut]
636+
): DataQuanta[DLModel] = {
637+
val dlTrainingOperator = new DLTrainingOperator(
638+
model,
639+
option,
640+
implicitly[ClassTag[Out]].runtimeClass,
641+
implicitly[ClassTag[ThatOut]].runtimeClass
642+
)
643+
644+
this.connectTo(dlTrainingOperator, 0)
645+
that.connectTo(dlTrainingOperator, 1)
646+
dlTrainingOperator
647+
}
648+
619649
/**
620650
* Feeds this and a further instance into a [[CoGroupOperator]].
621651
*

wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ import java.util.{Collection => JavaCollection}
2727
import org.apache.wayang.api.graph.{Edge, EdgeDataQuantaBuilder, EdgeDataQuantaBuilderDecorator}
2828
import org.apache.wayang.api.util.{DataQuantaBuilderCache, TypeTrap}
2929
import org.apache.wayang.basic.data.{Record, Tuple2 => RT2}
30-
import org.apache.wayang.basic.operators.{GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator}
30+
import org.apache.wayang.basic.model.{DLModel, Model}
31+
import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator}
3132
import org.apache.wayang.commons.util.profiledb.model.Experiment
3233
import org.apache.wayang.core.function.FunctionDescriptor.{SerializableBiFunction, SerializableBinaryOperator, SerializableFunction, SerializableIntUnaryOperator, SerializablePredicate}
3334
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval
@@ -273,6 +274,30 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
273274
thatKeyUdf: SerializableFunction[ThatOut, Key]) =
274275
new JoinDataQuantaBuilder(this, that, thisKeyUdf, thatKeyUdf)
275276

277+
/**
278+
* Feed the built [[DataQuanta]] of this and the given instance into a
279+
* [[org.apache.wayang.basic.operators.DLTrainingOperator]].
280+
*
281+
* @param that the other [[DataQuantaBuilder]] to join with
282+
* @param model model for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
283+
* @param option option for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
284+
* @return a [[DLTrainingDataQuantaBuilder]]
285+
*/
286+
def dlTraining[ThatOut](that: DataQuantaBuilder[_, ThatOut],
287+
model: DLModel,
288+
option: DLTrainingOperator.Option) =
289+
new DLTrainingDataQuantaBuilder(this, that, model, option)
290+
291+
/**
292+
* Feed the built [[DataQuanta]] of this and the given instance into a
293+
* [[org.apache.wayang.basic.operators.PredictOperator]].
294+
*
295+
* @param that the other [[DataQuantaBuilder]] to join with
296+
* @return a [[PredictDataQuantaBuilder]]
297+
*/
298+
def predict[ThatOut, Result](that: DataQuantaBuilder[_, ThatOut], resultType: Class[Result]) =
299+
new PredictDataQuantaBuilder(this.asInstanceOf[DataQuantaBuilder[_, Model]], that, resultType)
300+
276301
/**
277302
* Feed the built [[DataQuanta]] of this and the given instance into a
278303
* [[org.apache.wayang.basic.operators.CoGroupOperator]].
@@ -1336,6 +1361,53 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_
13361361

13371362
}
13381363

1364+
/**
1365+
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.DLTrainingOperator]]s.
1366+
*
1367+
* @param inputDataQuanta0 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
1368+
* @param inputDataQuanta1 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
1369+
* @param model model for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
1370+
* @param option option for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
1371+
*/
1372+
class DLTrainingDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[_, In0],
1373+
inputDataQuanta1: DataQuantaBuilder[_, In1],
1374+
model: DLModel,
1375+
option: DLTrainingOperator.Option)
1376+
(implicit javaPlanBuilder: JavaPlanBuilder)
1377+
extends BasicDataQuantaBuilder[DLTrainingDataQuantaBuilder[In0, In1], DLModel] {
1378+
1379+
// Since we are currently not looking at type parameters, we can statically determine the output type.
1380+
locally {
1381+
this.outputTypeTrap.dataSetType = dataSetType[DLModel]
1382+
}
1383+
1384+
override protected def build =
1385+
inputDataQuanta0.dataQuanta()
1386+
.dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag)
1387+
}
1388+
1389+
/**
1390+
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.PredictOperator]]s.
1391+
*
1392+
* @param inputDataQuanta0 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
1393+
* @param inputDataQuanta1 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
1394+
*/
1395+
class PredictDataQuantaBuilder[In1, Out](inputDataQuanta0: DataQuantaBuilder[_, Model],
1396+
inputDataQuanta1: DataQuantaBuilder[_, In1],
1397+
outClass: Class[Out])
1398+
(implicit javaPlanBuilder: JavaPlanBuilder)
1399+
extends BasicDataQuantaBuilder[PredictDataQuantaBuilder[In1, Out], Out] {
1400+
1401+
// Since we are currently not looking at type parameters, we can statically determine the output type.
1402+
locally {
1403+
this.outputTypeTrap.dataSetType = dataSetType[Out]
1404+
}
1405+
1406+
override protected def build =
1407+
inputDataQuanta0.dataQuanta().
1408+
predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass))
1409+
}
1410+
13391411
/**
13401412
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.CoGroupOperator]]s.
13411413
*

wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public class TensorflowIrisIT {
5757
"Iris-virginica", 2
5858
);
5959

60-
@Ignore
60+
@Test
6161
public void test() {
6262
final Tuple<Operator, Operator> trainSource = fileOperation(TRAIN_PATH, true);
6363
final Tuple<Operator, Operator> testSource = fileOperation(TEST_PATH, false);
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.tests;
20+
21+
import org.apache.wayang.api.*;
22+
import org.apache.wayang.basic.model.DLModel;
23+
import org.apache.wayang.basic.model.op.*;
24+
import org.apache.wayang.basic.model.op.nn.CrossEntropyLoss;
25+
import org.apache.wayang.basic.model.op.nn.Linear;
26+
import org.apache.wayang.basic.model.op.nn.Sigmoid;
27+
import org.apache.wayang.basic.model.optimizer.Adam;
28+
import org.apache.wayang.basic.model.optimizer.Optimizer;
29+
import org.apache.wayang.basic.operators.DLTrainingOperator;
30+
import org.apache.wayang.core.api.WayangContext;
31+
import org.apache.wayang.core.util.Tuple;
32+
import org.apache.wayang.java.Java;
33+
import org.apache.wayang.tensorflow.Tensorflow;
34+
import org.junit.Test;
35+
36+
import java.net.URI;
37+
import java.net.URISyntaxException;
38+
import java.util.ArrayList;
39+
import java.util.List;
40+
import java.util.Map;
41+
import java.util.Random;
42+
43+
/**
44+
* Test the Tensorflow integration with Wayang.
45+
* Note: this test fails on M1 Macs because of Tensorflow-Java incompatibility.
46+
*/
47+
public class TensorflowIrisScalaLikeApiIT {
48+
49+
public static URI TRAIN_PATH = createUri("/iris_train.csv");
50+
public static URI TEST_PATH = createUri("/iris_test.csv");
51+
52+
public static Map<String, Integer> LABEL_MAP = Map.of(
53+
"Iris-setosa", 0,
54+
"Iris-versicolor", 1,
55+
"Iris-virginica", 2
56+
);
57+
58+
@Test
59+
public void test() {
60+
WayangContext wayangContext = new WayangContext()
61+
.with(Java.basicPlugin())
62+
.with(Tensorflow.plugin());
63+
64+
JavaPlanBuilder plan = new JavaPlanBuilder(wayangContext);
65+
66+
final Tuple<DataQuantaBuilder<?, float[]>, DataQuantaBuilder<?, Integer>> trainSource =
67+
fileOperation(plan, TRAIN_PATH, true);
68+
final Tuple<DataQuantaBuilder<?, float[]>, DataQuantaBuilder<?, Integer>> testSource =
69+
fileOperation(plan, TEST_PATH, false);
70+
71+
/* training features */
72+
DataQuantaBuilder<?, float[]> trainXSource = trainSource.field0;
73+
74+
/* training labels */
75+
DataQuantaBuilder<?, Integer> trainYSource = trainSource.field1;
76+
77+
/* test features */
78+
DataQuantaBuilder<?, float[]> testXSource = testSource.field0;
79+
80+
/* test labels */
81+
DataQuantaBuilder<?, Integer> testYSource = testSource.field1;
82+
83+
/* model */
84+
Op l1 = new Linear(4, 32, true);
85+
Op s1 = new Sigmoid();
86+
Op l2 = new Linear(32, 3, true);
87+
s1.with(l1.with(new Input(Input.Type.FEATURES)));
88+
l2.with(s1);
89+
90+
DLModel model = new DLModel(l2);
91+
92+
/* training options */
93+
// 1. loss function
94+
Op criterion = new CrossEntropyLoss(3);
95+
criterion.with(
96+
new Input(Input.Type.PREDICTED, Op.DType.FLOAT32),
97+
new Input(Input.Type.LABEL, Op.DType.INT32)
98+
);
99+
100+
// 2. accuracy calculation function
101+
Op acc = new Mean(0);
102+
acc.with(new Cast(Op.DType.FLOAT32).with(new Eq().with(
103+
new ArgMax(1).with(new Input(Input.Type.PREDICTED, Op.DType.FLOAT32)),
104+
new Input(Input.Type.LABEL, Op.DType.INT32)
105+
)));
106+
107+
// 3. optimizer with learning rate
108+
Optimizer optimizer = new Adam(0.1f);
109+
110+
// 4. batch size
111+
int batchSize = 45;
112+
113+
// 5. epoch
114+
int epoch = 10;
115+
116+
DLTrainingOperator.Option option = new DLTrainingOperator.Option(criterion, optimizer, batchSize, epoch);
117+
option.setAccuracyCalculation(acc);
118+
119+
/* training operator */
120+
DLTrainingDataQuantaBuilder<float[], Integer> trainingOperator =
121+
trainXSource.dlTraining(trainYSource, model, option);
122+
123+
/* predict operator */
124+
PredictDataQuantaBuilder<float[], float[]> predictOperator =
125+
trainingOperator.predict(testXSource, float[].class);
126+
127+
/* map to label */
128+
MapDataQuantaBuilder<float[], Integer> mapOperator = predictOperator.map(array -> {
129+
int maxIdx = 0;
130+
float maxVal = array[0];
131+
for (int i = 1; i < array.length; i++) {
132+
if (array[i] > maxVal) {
133+
maxIdx = i;
134+
maxVal = array[i];
135+
}
136+
}
137+
return maxIdx;
138+
});
139+
140+
/* sink */
141+
List<Integer> predicted = new ArrayList<>(mapOperator.collect());
142+
// fixme: Currently, wayang's scala-like api only supports a single collect,
143+
// so it is not possible to collect multiple result lists in a single plan.
144+
// List<Integer> groundTruth = new ArrayList<>(testYSource.collect());
145+
146+
System.out.println("predicted: " + predicted);
147+
// System.out.println("ground truth: " + groundTruth);
148+
149+
// float success = 0;
150+
// for (int i = 0; i < predicted.size(); i++) {
151+
// if (predicted.get(i).equals(groundTruth.get(i))) {
152+
// success += 1;
153+
// }
154+
// }
155+
// System.out.println("test accuracy: " + success / predicted.size());
156+
}
157+
158+
public static Tuple<DataQuantaBuilder<?, float[]>, DataQuantaBuilder<?, Integer>>
159+
fileOperation(JavaPlanBuilder plan, URI uri, boolean random) {
160+
DataQuantaBuilder<?, String> textFileSource = plan.readTextFile(uri.toString());
161+
162+
if (random) {
163+
Random r = new Random();
164+
textFileSource = textFileSource.sort(e -> r.nextInt());
165+
}
166+
167+
MapDataQuantaBuilder<String, Tuple<float[], Integer>> mapXY = textFileSource.map(line -> {
168+
String[] parts = line.split(",");
169+
float[] x = new float[parts.length - 1];
170+
for (int i = 0; i < x.length; i++) {
171+
x[i] = Float.parseFloat(parts[i]);
172+
}
173+
int y = LABEL_MAP.get(parts[parts.length - 1]);
174+
return new Tuple<>(x, y);
175+
});
176+
177+
MapDataQuantaBuilder<Tuple<float[], Integer>, float[]> mapX = mapXY.map(tuple -> tuple.field0);
178+
MapDataQuantaBuilder<Tuple<float[], Integer>, Integer> mapY = mapXY.map(tuple -> tuple.field1);
179+
180+
return new Tuple<>(mapX, mapY);
181+
}
182+
183+
public static URI createUri(String resourcePath) {
184+
try {
185+
return TensorflowIrisScalaLikeApiIT.class.getResource(resourcePath).toURI();
186+
} catch (URISyntaxException e) {
187+
throw new IllegalArgumentException("Illegal URI.", e);
188+
}
189+
}
190+
}

0 commit comments

Comments
 (0)