Skip to content

Commit f78ff0d

Browse files
committed
Add basic type inference for tuple channel operators
1 parent 81dc980 commit f78ff0d

File tree

4 files changed

+320
-10
lines changed

4 files changed

+320
-10
lines changed
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Copyright 2024-2025, Seqera Labs
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package nextflow.script.control;
17+
18+
import java.util.Collections;
19+
import java.util.List;
20+
import java.util.Map;
21+
22+
import nextflow.script.types.Bag;
23+
import nextflow.script.types.Channel;
24+
import nextflow.script.types.Tuple;
25+
import nextflow.script.types.Value;
26+
import org.codehaus.groovy.ast.ClassHelper;
27+
import org.codehaus.groovy.ast.ClassNode;
28+
import org.codehaus.groovy.ast.GenericsType;
29+
import org.codehaus.groovy.ast.MethodNode;
30+
import org.codehaus.groovy.ast.expr.Expression;
31+
import org.codehaus.groovy.ast.expr.NamedArgumentListExpression;
32+
33+
import static nextflow.script.ast.ASTUtils.*;
34+
import static nextflow.script.types.TypeCheckingUtils.*;
35+
36+
/**
37+
*
38+
* @author Ben Sherman <bentshermann@gmail.com>
39+
*/
40+
class TupleOpResolver {
41+
42+
private static final ClassNode BAG_TYPE = ClassHelper.makeCached(Bag.class);
43+
private static final ClassNode CHANNEL_TYPE = ClassHelper.makeCached(Channel.class);
44+
private static final ClassNode TUPLE_TYPE = ClassHelper.makeCached(Tuple.class);
45+
private static final ClassNode VALUE_TYPE = ClassHelper.makeCached(Value.class);
46+
47+
/**
48+
* Resolve the return type of dataflow operators that tranform
49+
* tuples, such as `combine`, `groupTuple`, and `join`.
50+
*
51+
* @param lhsType
52+
* @param method
53+
* @param arguments
54+
*/
55+
public ClassNode apply(ClassNode lhsType, MethodNode method, List<Expression> arguments) {
56+
var name = method.getName();
57+
58+
if( "combine".equals(name) )
59+
return applyCombine(lhsType, arguments);
60+
61+
if( "groupTuple".equals(name) )
62+
return applyGroupBy(lhsType, arguments);
63+
64+
if( "join".equals(name) )
65+
return applyJoin(lhsType, arguments);
66+
67+
return ClassHelper.dynamicType();
68+
}
69+
70+
/**
71+
* Resolve the result type of a `combine` operation in terms of the left
72+
* and right operands.
73+
*
74+
* Given arguments of type `(L1, L2, ..., Lm)` and `R`, `combine`
75+
* produces a tuple of type `(L1, L2, ..., Lm, R).
76+
*
77+
* When the `by` option is specified, `combine` produces the same result
78+
* type as `join`.
79+
*
80+
* @param lhsType
81+
* @param arguments
82+
*/
83+
private ClassNode applyCombine(ClassNode lhsType, List<Expression> arguments) {
84+
if( !TUPLE_TYPE.equals(lhsType) )
85+
return ClassHelper.dynamicType();
86+
87+
var namedArgs = namedArgs(arguments);
88+
if( namedArgs.containsKey("by") )
89+
return applyJoin(lhsType, arguments);
90+
91+
var argType = getType(arguments.get(arguments.size() - 1));
92+
var rhsType = dataflowElementType(argType);
93+
94+
var lgts = lhsType.getGenericsTypes();
95+
if( lgts == null || lgts.length == 0 )
96+
return ClassHelper.dynamicType();
97+
98+
var gts = new GenericsType[lgts.length + 1];
99+
for( int i = 0; i < lgts.length; i++ )
100+
gts[i] = lgts[i];
101+
gts[lgts.length] = new GenericsType(rhsType);
102+
103+
return channelTupleType(gts);
104+
}
105+
106+
/**
107+
* Resolve the result type of a `groupTuple` operation.
108+
*
109+
* Given source tuples of type `(K, V1, V2, ..., Vn)`,
110+
* `groupTuple` produces a tuple of type `(K, Bag<V1>, Bag<V2>, ..., Bag<Vn>)`.
111+
*
112+
* @param lhsType
113+
* @param arguments
114+
*/
115+
private ClassNode applyGroupBy(ClassNode lhsType, List<Expression> arguments) {
116+
if( !TUPLE_TYPE.equals(lhsType) )
117+
return ClassHelper.dynamicType();
118+
119+
var namedArgs = namedArgs(arguments);
120+
if( namedArgs.containsKey("by") )
121+
return ClassHelper.dynamicType();
122+
123+
var lgts = lhsType.getGenericsTypes();
124+
if( lgts == null || lgts.length == 0 )
125+
return ClassHelper.dynamicType();
126+
127+
// TODO: group on index specified by `by` option
128+
// TODO: skip if `by` option isn't a single integer
129+
var gts = new GenericsType[lgts.length];
130+
gts[0] = lgts[0];
131+
for( int i = 1; i < lgts.length; i++ ) {
132+
var groupType = makeType(BAG_TYPE, lgts[i].getType());
133+
gts[i] = new GenericsType(groupType);
134+
}
135+
136+
return channelTupleType(gts);
137+
}
138+
139+
/**
140+
* Resolve the result type of a `join` operation in terms of the left
141+
* and right operands.
142+
*
143+
* Given tuples of type `(K, L1, L2, ..., Lm)` and `(K, R1, R2, ..., Rn)`,
144+
* `join` produces a tuple of type `(K, L1, L2, ..., Lm, R1, R2, ..., Rn).
145+
*
146+
* @param lhsType
147+
* @param arguments
148+
*/
149+
private ClassNode applyJoin(ClassNode lhsType, List<Expression> arguments) {
150+
if( !TUPLE_TYPE.equals(lhsType) )
151+
return ClassHelper.dynamicType();
152+
153+
var namedArgs = namedArgs(arguments);
154+
if( namedArgs.containsKey("by") )
155+
return ClassHelper.dynamicType();
156+
157+
var argType = getType(arguments.get(arguments.size() - 1));
158+
var rhsType = dataflowElementType(argType);
159+
if( !TUPLE_TYPE.equals(rhsType) )
160+
return ClassHelper.dynamicType();
161+
162+
var lgts = lhsType.getGenericsTypes();
163+
var rgts = rhsType.getGenericsTypes();
164+
if( lgts == null || lgts.length == 0 || rgts == null || rgts.length == 0 )
165+
return ClassHelper.dynamicType();
166+
167+
// TODO: join on index specified by `by` option
168+
// TODO: skip if `by` option isn't a single integer
169+
var gts = new GenericsType[lgts.length + rgts.length - 1];
170+
for( int i = 0; i < lgts.length; i++ )
171+
gts[i] = lgts[i];
172+
for( int i = 1; i < rgts.length; i++ )
173+
gts[lgts.length + i - 1] = rgts[i];
174+
175+
return channelTupleType(gts);
176+
}
177+
178+
private static Map<String,Expression> namedArgs(List<Expression> args) {
179+
return args.size() > 0 && args.get(0) instanceof NamedArgumentListExpression nale
180+
? Map.ofEntries(
181+
nale.getMapEntryExpressions().stream()
182+
.map((entry) -> {
183+
var name = entry.getKeyExpression().getText();
184+
var value = entry.getValueExpression();
185+
return Map.entry(name, value);
186+
})
187+
.toArray(Map.Entry[]::new)
188+
)
189+
: Collections.emptyMap();
190+
}
191+
192+
private static ClassNode dataflowElementType(ClassNode type) {
193+
if( CHANNEL_TYPE.equals(type) || VALUE_TYPE.equals(type) )
194+
return elementType(type);
195+
return ClassHelper.dynamicType();
196+
}
197+
198+
private static ClassNode channelTupleType(GenericsType[] gts) {
199+
var tupleType = TUPLE_TYPE.getPlainNodeReference();
200+
tupleType.setGenericsTypes(gts);
201+
return makeType(CHANNEL_TYPE, tupleType);
202+
}
203+
204+
}

src/main/java/nextflow/script/control/TypeCheckingVisitorEx.java

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import java.util.Map;
2424
import java.util.stream.IntStream;
2525

26-
import nextflow.lsp.ast.ASTNodeCache;
2726
import nextflow.script.ast.ASTNodeMarker;
2827
import nextflow.script.ast.AssignmentExpression;
2928
import nextflow.script.ast.FeatureFlagNode;
@@ -117,8 +116,18 @@ protected SourceUnit getSourceUnit() {
117116

118117
public void visit() {
119118
var moduleNode = sourceUnit.getAST();
120-
if( moduleNode instanceof ScriptNode sn )
121-
visit(sn);
119+
if( moduleNode instanceof ScriptNode sn ) {
120+
for( var featureFlag : sn.getFeatureFlags() )
121+
visitFeatureFlag(featureFlag);
122+
if( sn.getParams() != null )
123+
visitParams(sn.getParams());
124+
for( var functionNode : sn.getFunctions() )
125+
visitFunction(functionNode);
126+
for( var processNode : sn.getProcesses() )
127+
visitProcess(processNode);
128+
for( var workflowNode : sn.getWorkflows() )
129+
visitWorkflow(workflowNode);
130+
}
122131
}
123132

124133
// script declarations
@@ -478,6 +487,8 @@ public void visitMethodCallExpression(MethodCallExpression node) {
478487
var dummyMethod = resolveGenericReturnType(receiverType, target, arguments);
479488
node.putNodeMetaData(ASTNodeMarker.METHOD_TARGET, dummyMethod);
480489
node.putNodeMetaData(ASTNodeMarker.INFERRED_TYPE, dummyMethod.getReturnType());
490+
491+
checkOperatorCall(node);
481492
}
482493
else if( node.getNodeMetaData(ASTNodeMarker.METHOD_TARGET) instanceof MethodNode mn ) {
483494
var parameters = mn.getParameters();
@@ -550,6 +561,30 @@ private void checkSpreadMethodCall(MethodCallExpression node) {
550561
}
551562
}
552563

564+
/**
565+
* Resolve the return type of operators that transform tuples.
566+
* such as `combine`, `groupTuple`, and `join`.
567+
*
568+
* @param node
569+
*/
570+
private void checkOperatorCall(MethodCallExpression node) {
571+
if( node.isImplicitThis() )
572+
return;
573+
574+
var receiverType = getType(node.getObjectExpression());
575+
if( !CHANNEL_TYPE.equals(receiverType) )
576+
return;
577+
578+
var lhsType = elementType(receiverType);
579+
var method = (MethodNode) node.getNodeMetaData(ASTNodeMarker.METHOD_TARGET);
580+
var arguments = asMethodCallArguments(node);
581+
var resultType = new TupleOpResolver().apply(lhsType, method, arguments);
582+
if( ClassHelper.isDynamicTyped(resultType) )
583+
return;
584+
585+
node.putNodeMetaData(ASTNodeMarker.INFERRED_TYPE, resultType);
586+
}
587+
553588
/**
554589
* Check the arguments of an invalid method call and report appropriate
555590
* errors for each invalid argument.
@@ -691,13 +726,6 @@ private static ClassNode dataflowElementType(ClassNode type) {
691726
return type;
692727
}
693728

694-
private static ClassNode elementType(ClassNode type) {
695-
var gts = type.getGenericsTypes();
696-
if( gts == null || gts.length != 1 )
697-
return ClassHelper.dynamicType();
698-
return gts[0].getType();
699-
}
700-
701729
private static ClassNode processOutputType(ClassNode dataflowType, Statement block) {
702730
var outputs = asBlockStatements(block);
703731
if( outputs.size() == 1 ) {

src/main/java/nextflow/script/types/TypeCheckingUtils.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,4 +859,16 @@ public static ClassNode makeType(ClassNode type, ClassNode... typeArguments) {
859859
cn.setGenericsTypes(gts);
860860
return cn;
861861
}
862+
863+
/**
864+
* Return the element type of a type with one type parameter.
865+
*
866+
* @param type
867+
*/
868+
public static ClassNode elementType(ClassNode type) {
869+
var gts = type.getGenericsTypes();
870+
if( gts == null || gts.length != 1 )
871+
return ClassHelper.dynamicType();
872+
return gts[0].getType();
873+
}
862874
}

src/test/groovy/nextflow/script/types/TypeCheckingTest.groovy

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,31 @@ class TypeCheckingTest extends Specification {
643643
type = getType(exp)
644644
then:
645645
TypesEx.getName(type) == 'Value<String>'
646+
647+
when:
648+
exp = parseExpression(
649+
'''\
650+
nextflow.preview.types = true
651+
652+
process hello {
653+
input:
654+
target: String
655+
656+
output:
657+
tuple(target, "Hello, $target!")
658+
659+
exec:
660+
true
661+
}
662+
663+
workflow {
664+
hello( 'World' )
665+
}
666+
'''
667+
)
668+
type = getType(exp)
669+
then:
670+
TypesEx.getName(type) == 'Value<Tuple<String, String>>'
646671
}
647672

648673
def 'should resolve tuple type' () {
@@ -705,4 +730,45 @@ class TypeCheckingTest extends Specification {
705730
"files('*.txt')*.toUriString()" | null
706731
}
707732

733+
def 'should resolve a `combine` operation' () {
734+
when:
735+
def exp = parseExpression(
736+
'''\
737+
left = channel.of( tuple(42, 'hello') )
738+
right = channel.of( true )
739+
left.combine(right)
740+
'''
741+
)
742+
def type = getType(exp)
743+
then:
744+
TypesEx.getName(type) == 'Channel<Tuple<Integer, String, Boolean>>'
745+
}
746+
747+
def 'should resolve a `groupTuple` operation' () {
748+
when:
749+
def exp = parseExpression(
750+
'''\
751+
left = channel.of( tuple(42, 'hello'), tuple(42, 'goodbye') )
752+
left.groupTuple()
753+
'''
754+
)
755+
def type = getType(exp)
756+
then:
757+
TypesEx.getName(type) == 'Channel<Tuple<Integer, Bag<String>>>'
758+
}
759+
760+
def 'should resolve a `join` operation' () {
761+
when:
762+
def exp = parseExpression(
763+
'''\
764+
left = channel.of( tuple(42, 'hello') )
765+
right = channel.of( tuple(42, true) )
766+
left.join(right)
767+
'''
768+
)
769+
def type = getType(exp)
770+
then:
771+
TypesEx.getName(type) == 'Channel<Tuple<Integer, String, Boolean>>'
772+
}
773+
708774
}

0 commit comments

Comments
 (0)