Skip to content

Commit 277be74

Browse files
committed
fix bug #4
1 parent e014b5c commit 277be74

File tree

4 files changed

+59
-9
lines changed

4 files changed

+59
-9
lines changed

README-zh.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ mvn clean package -DskipTests
4040

4141
你也可以直接在发布页下载打包好了最新版本 [发布页](https://github.com/aaronshan/hive-third-functions/releases).
4242

43-
> 当前最新的版本是 `2.1.2`
43+
> 当前最新的版本是 `2.1.3`
4444
4545
## 函数
4646

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ It will generate hive-third-functions-${version}-shaded.jar in target directory.
4040

4141
You can also directly download file from [release page](https://github.com/aaronshan/hive-third-functions/releases).
4242

43-
> current latest version is `2.1.2`
43+
> current latest version is `2.1.3`
4444
4545
## Functions
4646

src/main/java/cc/shanruifeng/functions/array/UDFArrayIntersect.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import cc.shanruifeng.functions.fastuitl.ints.IntArrays;
44
import java.util.ArrayList;
5+
import java.util.Arrays;
6+
57
import org.apache.hadoop.hive.ql.exec.Description;
68
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
79
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
@@ -139,18 +141,14 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException {
139141
} else if (compareValue < 0) {
140142
leftCurrentPosition++;
141143
} else {
142-
result.add(converter.convert(leftArrayElement));
144+
result.add(converter.convert(leftArrayOI.getListElement(leftArray, leftPositions[leftCurrentPosition])));
143145
leftCurrentPosition++;
144146
rightCurrentPosition++;
145147

146-
Object leftArrayElementTmp1 = leftArrayOI.getListElement(leftArray, leftPositions[leftBasePosition]);
147-
Object leftArrayElementTmp2 = leftArrayOI.getListElement(leftArray, leftPositions[leftCurrentPosition]);
148-
Object rightArrayElementTmp1 = rightArrayOI.getListElement(rightArray, rightPositions[rightBasePosition]);
149-
Object rightArrayElementTmp2 = rightArrayOI.getListElement(rightArray, rightPositions[rightCurrentPosition]);
150-
while (leftCurrentPosition < leftArrayLength && ObjectInspectorUtils.compare(leftArrayElementTmp1, leftArrayElementOI, leftArrayElementTmp2, leftArrayElementOI) == 0) {
148+
while (leftCurrentPosition < leftArrayLength && compare(leftArrayOI, leftArray, leftBasePosition, leftCurrentPosition) == 0) {
151149
leftCurrentPosition++;
152150
}
153-
while (rightCurrentPosition < rightArrayLength && ObjectInspectorUtils.compare(rightArrayElementTmp1, rightArrayElementOI, rightArrayElementTmp2, rightArrayElementOI) == 0) {
151+
while (rightCurrentPosition < rightArrayLength && compare(rightArrayOI, rightArray, rightBasePosition, rightCurrentPosition) == 0) {
154152
rightCurrentPosition++;
155153
}
156154
}
@@ -159,6 +157,13 @@ public Object evaluate(DeferredObject[] arguments) throws HiveException {
159157
return result;
160158
}
161159

160+
private int compare(ListObjectInspector arrayOI, Object array, int position1, int position2) {
161+
ObjectInspector arrayElementOI = arrayOI.getListElementObjectInspector();
162+
Object arrayElementTmp1 = arrayOI.getListElement(array, leftPositions[position1]);
163+
Object arrayElementTmp2 = arrayOI.getListElement(array, leftPositions[position2]);
164+
return ObjectInspectorUtils.compare(arrayElementTmp1, arrayElementOI, arrayElementTmp2, arrayElementOI);
165+
}
166+
162167
@Override
163168
public String getDisplayString(String[] strings) {
164169
assert (strings.length == ARG_COUNT);
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package cc.shanruifeng.functions.array;
2+
3+
import com.google.common.collect.ImmutableList;
4+
import com.google.common.collect.Iterables;
5+
import org.apache.hadoop.hive.ql.metadata.HiveException;
6+
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
7+
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
8+
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
9+
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
10+
import org.junit.Test;
11+
12+
import java.util.ArrayList;
13+
import java.util.List;
14+
15+
import static org.junit.Assert.*;
16+
17+
/**
18+
* @author ruifeng.shan
19+
* @date 2018-07-18
20+
* @time 13:00
21+
*/
22+
public class UDFArrayIntersectTest {
23+
@Test
24+
public void testArrayIntersect() throws HiveException {
25+
UDFArrayIntersect udf = new UDFArrayIntersect();
26+
27+
ObjectInspector leftArrayOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
28+
ObjectInspector rightArrayOI = ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
29+
ObjectInspector[] arguments = {leftArrayOI, rightArrayOI};
30+
31+
udf.initialize(arguments);
32+
33+
assertTrue(Iterables.elementsEqual(ImmutableList.of(1,2,5), evaluate(ImmutableList.of(0,1,2,3,4,5), ImmutableList.of(1,1,2,2,5,5), udf)));
34+
assertTrue(Iterables.elementsEqual(ImmutableList.of(1,2,3,4), evaluate(ImmutableList.of(0,1,2,3,4,4), ImmutableList.of(1,1,2,2,3,4), udf)));
35+
assertTrue(Iterables.elementsEqual(ImmutableList.of(1,2,3,4), evaluate(ImmutableList.of(0,1,1,2,3,4,4), ImmutableList.of(1,1,2,2,3,4), udf)));
36+
}
37+
38+
private ArrayList<Object> evaluate(List<Integer> leftArray, List<Integer> rightArray, UDFArrayIntersect udf) throws HiveException {
39+
GenericUDF.DeferredObject leftArrayObj = new GenericUDF.DeferredJavaObject(leftArray);
40+
GenericUDF.DeferredObject rightArrayObj = new GenericUDF.DeferredJavaObject(rightArray);
41+
GenericUDF.DeferredObject[] args = {leftArrayObj, rightArrayObj};
42+
ArrayList<Object> output = (ArrayList<Object>) udf.evaluate(args);
43+
return output;
44+
}
45+
}

0 commit comments

Comments
 (0)