Skip to content

Commit 794dcda

Browse files
committed
[FLINK-37902] batch support for ml_predict
1 parent 321b17f commit 794dcda

File tree

26 files changed

+3096
-1247
lines changed

26 files changed

+3096
-1247
lines changed

flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/MLPredictTypeStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ private static Optional<List<DataType>> inferMLPredictInputTypes(
157157
return Optional.empty();
158158
}
159159

160-
// Config map validation is done in StreamPhysicalMLPredictTableFunctionRule since
160+
// Config map validation is done in PhysicalMLPredictTableFunctionRule since
161161
// we are not able to get map literal here.
162162
return Optional.of(callContext.getArgumentDataTypes());
163163
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.plan.nodes.exec.batch;
20+
21+
import org.apache.flink.FlinkVersion;
22+
import org.apache.flink.configuration.ReadableConfig;
23+
import org.apache.flink.table.data.RowData;
24+
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext;
25+
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata;
26+
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
27+
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
28+
import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecMLPredictTableFunction;
29+
import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec;
30+
import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
31+
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil;
32+
import org.apache.flink.table.types.logical.RowType;
33+
34+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
35+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
36+
37+
import javax.annotation.Nullable;
38+
39+
import java.util.Collections;
40+
import java.util.List;
41+
42+
/** Batch {@link ExecNode} for {@code ML_PREDICT}. */
43+
@ExecNodeMetadata(
44+
name = "batch-exec-ml-predict-table-function",
45+
version = 1,
46+
consumedOptions = {
47+
"table.exec.async-ml-predict.max-concurrent-operations",
48+
"table.exec.async-ml-predict.timeout",
49+
"table.exec.async-ml-predict.output-mode"
50+
},
51+
producedTransformations = CommonExecMLPredictTableFunction.ML_PREDICT_TRANSFORMATION,
52+
minPlanVersion = FlinkVersion.v2_3,
53+
minStateVersion = FlinkVersion.v2_3)
54+
public class BatchExecMLPredictTableFunction extends CommonExecMLPredictTableFunction
55+
implements SingleTransformationTranslator<RowData>, BatchExecNode<RowData> {
56+
57+
public BatchExecMLPredictTableFunction(
58+
ReadableConfig persistedConfig,
59+
MLPredictSpec mlPredictSpec,
60+
ModelSpec modelSpec,
61+
@Nullable FunctionCallUtil.AsyncOptions asyncOptions,
62+
InputProperty inputProperty,
63+
RowType outputType,
64+
String description) {
65+
this(
66+
ExecNodeContext.newNodeId(),
67+
ExecNodeContext.newContext(BatchExecMLPredictTableFunction.class),
68+
ExecNodeContext.newPersistedConfig(
69+
BatchExecMLPredictTableFunction.class, persistedConfig),
70+
mlPredictSpec,
71+
modelSpec,
72+
asyncOptions,
73+
Collections.singletonList(inputProperty),
74+
outputType,
75+
description);
76+
}
77+
78+
@JsonCreator
79+
public BatchExecMLPredictTableFunction(
80+
@JsonProperty(FIELD_NAME_ID) int id,
81+
@JsonProperty(FIELD_NAME_TYPE) ExecNodeContext context,
82+
@JsonProperty(FIELD_NAME_CONFIGURATION) ReadableConfig persistedConfig,
83+
@JsonProperty(FIELD_NAME_ML_PREDICT_SPEC) MLPredictSpec mlPredictSpec,
84+
@JsonProperty(FIELD_NAME_MODEL_SPEC) ModelSpec modelSpec,
85+
@JsonProperty(FIELD_NAME_ASYNC_OPTIONS) @Nullable
86+
FunctionCallUtil.AsyncOptions asyncOptions,
87+
@JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List<InputProperty> inputProperties,
88+
@JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType,
89+
@JsonProperty(FIELD_NAME_DESCRIPTION) String description) {
90+
super(
91+
id,
92+
context,
93+
persistedConfig,
94+
mlPredictSpec,
95+
modelSpec,
96+
asyncOptions,
97+
inputProperties,
98+
outputType,
99+
description);
100+
}
101+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.plan.nodes.exec.common;
20+
21+
import org.apache.flink.api.common.functions.FlatMapFunction;
22+
import org.apache.flink.api.dag.Transformation;
23+
import org.apache.flink.configuration.Configuration;
24+
import org.apache.flink.configuration.PipelineOptions;
25+
import org.apache.flink.configuration.ReadableConfig;
26+
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
27+
import org.apache.flink.streaming.api.operators.ProcessOperator;
28+
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
29+
import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory;
30+
import org.apache.flink.table.api.TableException;
31+
import org.apache.flink.table.catalog.DataTypeFactory;
32+
import org.apache.flink.table.data.RowData;
33+
import org.apache.flink.table.functions.AsyncPredictFunction;
34+
import org.apache.flink.table.functions.PredictFunction;
35+
import org.apache.flink.table.functions.UserDefinedFunction;
36+
import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
37+
import org.apache.flink.table.ml.ModelProvider;
38+
import org.apache.flink.table.ml.PredictRuntimeProvider;
39+
import org.apache.flink.table.planner.calcite.FlinkContext;
40+
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
41+
import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
42+
import org.apache.flink.table.planner.codegen.MLPredictCodeGenerator;
43+
import org.apache.flink.table.planner.delegation.PlannerBase;
44+
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
45+
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
46+
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext;
47+
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
48+
import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec;
49+
import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
50+
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
51+
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil;
52+
import org.apache.flink.table.runtime.collector.ListenableCollector;
53+
import org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext;
54+
import org.apache.flink.table.runtime.generated.GeneratedCollector;
55+
import org.apache.flink.table.runtime.generated.GeneratedFunction;
56+
import org.apache.flink.table.runtime.operators.ml.AsyncMLPredictRunner;
57+
import org.apache.flink.table.runtime.operators.ml.MLPredictRunner;
58+
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
59+
import org.apache.flink.table.types.logical.RowType;
60+
import org.apache.flink.util.Preconditions;
61+
62+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
63+
64+
import javax.annotation.Nullable;
65+
66+
import java.util.List;
67+
68+
/** Common ExecNode for {@code ML_PREDICT}. */
69+
public abstract class CommonExecMLPredictTableFunction extends ExecNodeBase<RowData> {
70+
71+
public static final String ML_PREDICT_TRANSFORMATION = "ml-predict-table-function";
72+
73+
protected static final String FIELD_NAME_ML_PREDICT_SPEC = "mlPredictSpec";
74+
protected static final String FIELD_NAME_MODEL_SPEC = "modelSpec";
75+
protected static final String FIELD_NAME_ASYNC_OPTIONS = "asyncOptions";
76+
77+
@JsonProperty(FIELD_NAME_ML_PREDICT_SPEC)
78+
protected final MLPredictSpec mlPredictSpec;
79+
80+
@JsonProperty(FIELD_NAME_MODEL_SPEC)
81+
protected final ModelSpec modelSpec;
82+
83+
@JsonProperty(FIELD_NAME_ASYNC_OPTIONS)
84+
protected final @Nullable FunctionCallUtil.AsyncOptions asyncOptions;
85+
86+
protected CommonExecMLPredictTableFunction(
87+
int id,
88+
ExecNodeContext context,
89+
ReadableConfig persistedConfig,
90+
MLPredictSpec mlPredictSpec,
91+
ModelSpec modelSpec,
92+
@Nullable FunctionCallUtil.AsyncOptions asyncOptions,
93+
List<InputProperty> inputProperties,
94+
RowType outputType,
95+
String description) {
96+
super(id, context, persistedConfig, inputProperties, outputType, description);
97+
this.mlPredictSpec = mlPredictSpec;
98+
this.modelSpec = modelSpec;
99+
this.asyncOptions = asyncOptions;
100+
}
101+
102+
@Override
103+
protected Transformation<RowData> translateToPlanInternal(
104+
PlannerBase planner, ExecNodeConfig config) {
105+
Transformation<RowData> inputTransformation =
106+
(Transformation<RowData>) getInputEdges().get(0).translateToPlan(planner);
107+
108+
ModelProvider provider = modelSpec.getModelProvider(planner.getFlinkContext());
109+
boolean async = asyncOptions != null;
110+
UserDefinedFunction predictFunction = findModelFunction(provider, async);
111+
FlinkContext context = planner.getFlinkContext();
112+
DataTypeFactory dataTypeFactory = context.getCatalogManager().getDataTypeFactory();
113+
114+
RowType inputType = (RowType) getInputEdges().get(0).getOutputType();
115+
RowType modelOutputType =
116+
(RowType)
117+
modelSpec
118+
.getContextResolvedModel()
119+
.getResolvedModel()
120+
.getResolvedOutputSchema()
121+
.toPhysicalRowDataType()
122+
.getLogicalType();
123+
return async
124+
? createAsyncModelPredict(
125+
inputTransformation,
126+
config,
127+
planner.getFlinkContext().getClassLoader(),
128+
dataTypeFactory,
129+
inputType,
130+
modelOutputType,
131+
(RowType) getOutputType(),
132+
(AsyncPredictFunction) predictFunction)
133+
: createModelPredict(
134+
inputTransformation,
135+
config,
136+
planner.getFlinkContext().getClassLoader(),
137+
dataTypeFactory,
138+
inputType,
139+
modelOutputType,
140+
(RowType) getOutputType(),
141+
(PredictFunction) predictFunction);
142+
}
143+
144+
private Transformation<RowData> createModelPredict(
145+
Transformation<RowData> inputTransformation,
146+
ExecNodeConfig config,
147+
ClassLoader classLoader,
148+
DataTypeFactory dataTypeFactory,
149+
RowType inputRowType,
150+
RowType modelOutputType,
151+
RowType resultRowType,
152+
PredictFunction predictFunction) {
153+
GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher =
154+
MLPredictCodeGenerator.generateSyncPredictFunction(
155+
config,
156+
classLoader,
157+
dataTypeFactory,
158+
inputRowType,
159+
modelOutputType,
160+
resultRowType,
161+
mlPredictSpec.getFeatures(),
162+
predictFunction,
163+
modelSpec.getContextResolvedModel().getIdentifier().asSummaryString(),
164+
config.get(PipelineOptions.OBJECT_REUSE));
165+
GeneratedCollector<ListenableCollector<RowData>> generatedCollector =
166+
MLPredictCodeGenerator.generateCollector(
167+
new CodeGeneratorContext(config, classLoader),
168+
inputRowType,
169+
modelOutputType,
170+
(RowType) getOutputType());
171+
MLPredictRunner mlPredictRunner = new MLPredictRunner(generatedFetcher, generatedCollector);
172+
SimpleOperatorFactory<RowData> operatorFactory =
173+
SimpleOperatorFactory.of(new ProcessOperator<>(mlPredictRunner));
174+
return ExecNodeUtil.createOneInputTransformation(
175+
inputTransformation,
176+
createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),
177+
operatorFactory,
178+
InternalTypeInfo.of(getOutputType()),
179+
inputTransformation.getParallelism(),
180+
false);
181+
}
182+
183+
@SuppressWarnings("unchecked")
184+
private Transformation<RowData> createAsyncModelPredict(
185+
Transformation<RowData> inputTransformation,
186+
ExecNodeConfig config,
187+
ClassLoader classLoader,
188+
DataTypeFactory dataTypeFactory,
189+
RowType inputRowType,
190+
RowType modelOutputType,
191+
RowType resultRowType,
192+
AsyncPredictFunction asyncPredictFunction) {
193+
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
194+
generatedFuncWithType =
195+
MLPredictCodeGenerator.generateAsyncPredictFunction(
196+
config,
197+
classLoader,
198+
dataTypeFactory,
199+
inputRowType,
200+
modelOutputType,
201+
resultRowType,
202+
mlPredictSpec.getFeatures(),
203+
asyncPredictFunction,
204+
modelSpec
205+
.getContextResolvedModel()
206+
.getIdentifier()
207+
.asSummaryString());
208+
AsyncFunction<RowData, RowData> asyncFunc =
209+
new AsyncMLPredictRunner(
210+
(GeneratedFunction) generatedFuncWithType.tableFunc(),
211+
Preconditions.checkNotNull(asyncOptions).asyncBufferCapacity);
212+
return ExecNodeUtil.createOneInputTransformation(
213+
inputTransformation,
214+
createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),
215+
new AsyncWaitOperatorFactory<>(
216+
asyncFunc,
217+
asyncOptions.asyncTimeout,
218+
asyncOptions.asyncBufferCapacity,
219+
asyncOptions.asyncOutputMode),
220+
InternalTypeInfo.of(getOutputType()),
221+
inputTransformation.getParallelism(),
222+
false);
223+
}
224+
225+
private UserDefinedFunction findModelFunction(ModelProvider provider, boolean async) {
226+
ModelPredictRuntimeProviderContext context =
227+
new ModelPredictRuntimeProviderContext(
228+
modelSpec.getContextResolvedModel().getResolvedModel(),
229+
Configuration.fromMap(mlPredictSpec.getRuntimeConfig()));
230+
if (async) {
231+
if (provider instanceof AsyncPredictRuntimeProvider) {
232+
return ((AsyncPredictRuntimeProvider) provider).createAsyncPredictFunction(context);
233+
}
234+
} else {
235+
if (provider instanceof PredictRuntimeProvider) {
236+
return ((PredictRuntimeProvider) provider).createPredictFunction(context);
237+
}
238+
}
239+
240+
throw new TableException(
241+
"Required "
242+
+ (async ? "async" : "sync")
243+
+ " model function by planner, but ModelProvider "
244+
+ "does not offer a valid model function.");
245+
}
246+
}

0 commit comments

Comments
 (0)