diff --git a/docs/content/docs/libs/state_processor_api.md b/docs/content/docs/libs/state_processor_api.md
index 57f35a72f13b3..82c34c11e2344 100644
--- a/docs/content/docs/libs/state_processor_api.md
+++ b/docs/content/docs/libs/state_processor_api.md
@@ -586,13 +586,6 @@ public class StatefulFunction extends KeyedProcessFunction
     }
 }
 
-
-public class AvroSavepointTypeInformationFactory implements SavepointTypeInformationFactory {
-    @Override
-    public TypeInformation<?> getTypeInformation() {
-        return new AvroTypeInfo<>(AvroRecord.class);
-    }
-}
 ```
 
 Then it can be read by querying a table created using the following SQL statement:
@@ -609,8 +602,7 @@ CREATE TABLE state_table (
   'connector' = 'savepoint',
   'state.backend.type' = 'rocksdb',
   'state.path' = '/root/dir/of/checkpoint-data/chk-1',
-  'operator.uid' = 'my-uid',
-  'fields.MyAvroState.value-type-factory' = 'org.apache.flink.state.table.AvroSavepointTypeInformationFactory'
+  'operator.uid' = 'my-uid'
 );
 ```
 
diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointLoader.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointLoader.java
index 1b2cdedefaf56..e4ea285c0e876 100644
--- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointLoader.java
+++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointLoader.java
@@ -19,15 +19,27 @@
 package org.apache.flink.state.api.runtime;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.Checkpoints;
+import org.apache.flink.runtime.checkpoint.OperatorState;
 import org.apache.flink.runtime.checkpoint.metadata.CheckpointMetadata;
 import org.apache.flink.runtime.state.CompletedCheckpointStorageLocation;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.state.api.OperatorIdentifier;
 
 import java.io.DataInputStream;
 import java.io.IOException;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
-/** Utility class for loading {@link CheckpointMetadata} metadata. */
+/** Utility class for loading savepoint metadata and operator state information. */
 @Internal
 public final class SavepointLoader {
     private SavepointLoader() {}
@@ -55,4 +67,70 @@ public static CheckpointMetadata loadSavepointMetadata(String savepointPath)
                             stream, Thread.currentThread().getContextClassLoader(), savepointPath);
         }
     }
+
+    /**
+     * Loads all state metadata for an operator in a single I/O operation.
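+     *
+     * <p>A minimal usage sketch (the savepoint path and operator UID below are placeholders):
+     *
+     * <pre>{@code
+     * Map<String, StateMetaInfoSnapshot> stateMetadata =
+     *         SavepointLoader.loadOperatorStateMetadata(
+     *                 "/path/to/savepoint", OperatorIdentifier.forUid("my-uid"));
+     * stateMetadata.keySet().forEach(System.out::println); // prints the registered state names
+     * }</pre>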
+     *
+     * @param savepointPath Path to the savepoint directory
+     * @param operatorIdentifier Operator UID or hash
+     * @return Map from state name to StateMetaInfoSnapshot
+     * @throws IOException If reading fails
+     */
+    public static Map<String, StateMetaInfoSnapshot> loadOperatorStateMetadata(
+            String savepointPath, OperatorIdentifier operatorIdentifier) throws IOException {
+
+        CheckpointMetadata checkpointMetadata = loadSavepointMetadata(savepointPath);
+
+        OperatorState operatorState =
+                checkpointMetadata.getOperatorStates().stream()
+                        .filter(
+                                state ->
+                                        operatorIdentifier
+                                                .getOperatorId()
+                                                .equals(state.getOperatorID()))
+                        .findFirst()
+                        .orElseThrow(
+                                () ->
+                                        new IllegalArgumentException(
+                                                "Operator "
+                                                        + operatorIdentifier
+                                                        + " not found in savepoint"));
+
+        KeyedStateHandle keyedStateHandle =
+                operatorState.getStates().stream()
+                        .flatMap(s -> s.getManagedKeyedState().stream())
+                        .findFirst()
+                        .orElseThrow(
+                                () ->
+                                        new IllegalArgumentException(
+                                                "No keyed state found for operator "
+                                                        + operatorIdentifier));
+
+        KeyedBackendSerializationProxy<?> proxy = readSerializationProxy(keyedStateHandle);
+        return proxy.getStateMetaInfoSnapshots().stream()
+                .collect(Collectors.toMap(StateMetaInfoSnapshot::getName, Function.identity()));
+    }
+
+    private static KeyedBackendSerializationProxy<?> readSerializationProxy(
+            KeyedStateHandle stateHandle) throws IOException {
+
+        StreamStateHandle streamStateHandle;
+        if (stateHandle instanceof KeyGroupsStateHandle) {
+            streamStateHandle = ((KeyGroupsStateHandle) stateHandle).getDelegateStateHandle();
+        } else {
+            throw new IllegalArgumentException(
+                    "Unsupported KeyedStateHandle type: " + stateHandle.getClass());
+        }
+
+        try (FSDataInputStream inputStream = streamStateHandle.openInputStream()) {
+            DataInputViewStreamWrapper inputView = new DataInputViewStreamWrapper(inputStream);
+
+            KeyedBackendSerializationProxy<?> proxy =
+                    new KeyedBackendSerializationProxy<>(
+                            Thread.currentThread().getContextClassLoader());
+            proxy.read(inputView);
+
+            return proxy;
+        }
+    }
 }
diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointDataStreamScanProvider.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointDataStreamScanProvider.java
index 5393e4fa01a34..fd24b0f2448a2 100644
--- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointDataStreamScanProvider.java
+++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointDataStreamScanProvider.java
@@ -22,6 +22,7 @@
 import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.StateBackendOptions;
@@ -91,32 +92,33 @@ public DataStream<RowData> produceDataStream(
 
         // Get value state descriptors
         for (StateValueColumnConfiguration columnConfig : keyValueProjections.f1) {
-            TypeInformation<?> valueTypeInfo = columnConfig.getValueTypeInfo();
+            TypeSerializer<?> valueTypeSerializer = columnConfig.getValueTypeSerializer();
             switch (columnConfig.getStateType()) {
                 case VALUE:
                     columnConfig.setStateDescriptor(
                             new ValueStateDescriptor<>(
-                                    columnConfig.getStateName(), valueTypeInfo));
+                                    columnConfig.getStateName(), valueTypeSerializer));
                     break;
                 case LIST:
columnConfig.setStateDescriptor(
                             new ListStateDescriptor<>(
-                                    columnConfig.getStateName(), valueTypeInfo));
+                                    columnConfig.getStateName(), valueTypeSerializer));
                     break;
                 case MAP:
-                    TypeInformation<?> mapKeyTypeInfo = columnConfig.getMapKeyTypeInfo();
-                    if (mapKeyTypeInfo == null) {
+                    TypeSerializer<?> mapKeyTypeSerializer =
+                            columnConfig.getMapKeyTypeSerializer();
+                    if (mapKeyTypeSerializer == null) {
                         throw new ConfigurationException(
-                                "Map key type information is required for map state");
+                                "Map key type serializer is required for map state");
                     }
                     columnConfig.setStateDescriptor(
                             new MapStateDescriptor<>(
                                     columnConfig.getStateName(),
-                                    mapKeyTypeInfo,
-                                    valueTypeInfo));
+                                    mapKeyTypeSerializer,
+                                    valueTypeSerializer));
                     break;
                 default:
diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointDynamicTableSourceFactory.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointDynamicTableSourceFactory.java
index 7ac206d008110..9c507109c24ea 100644
--- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointDynamicTableSourceFactory.java
+++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointDynamicTableSourceFactory.java
@@ -18,12 +18,16 @@
 
 package org.apache.flink.state.table;
 
+import org.apache.flink.api.common.serialization.SerializerConfig;
+import org.apache.flink.api.common.serialization.SerializerConfigImpl;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.common.typeinfo.utils.TypeUtils;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.ConfigOption;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.state.api.OperatorIdentifier;
+import org.apache.flink.state.api.runtime.SavepointLoader;
 import org.apache.flink.table.api.ValidationException;
 import org.apache.flink.table.catalog.ResolvedCatalogTable;
 import org.apache.flink.table.catalog.ResolvedSchema;
@@ -31,21 +35,19 @@
 import org.apache.flink.table.factories.DynamicTableSourceFactory;
 import org.apache.flink.table.factories.FactoryUtil;
 import org.apache.flink.table.types.DataType;
-import org.apache.flink.table.types.logical.ArrayType;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
-import org.apache.flink.table.types.logical.MapType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
 import org.apache.flink.util.Preconditions;
 
-import javax.annotation.Nullable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-import java.math.BigDecimal;
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
-import java.util.Optional;
+import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
@@ -73,15 +75,27 @@
 /** Dynamic source factory for {@link SavepointDynamicTableSource}. 
*/ public class SavepointDynamicTableSourceFactory implements DynamicTableSourceFactory { + + private static final Logger LOG = + LoggerFactory.getLogger(SavepointDynamicTableSourceFactory.class); + @Override public DynamicTableSource createDynamicTableSource(Context context) { Configuration options = new Configuration(); context.getCatalogTable().getOptions().forEach(options::setString); + SerializerConfig serializerConfig = new SerializerConfigImpl(options); final String stateBackendType = options.getOptional(STATE_BACKEND_TYPE).orElse(null); final String statePath = options.get(STATE_PATH); final OperatorIdentifier operatorIdentifier = getOperatorIdentifier(options); + final Map preloadedStateMetadata = + preloadStateMetadata(statePath, operatorIdentifier); + + // Create resolver with preloaded metadata + SavepointTypeInfoResolver typeResolver = + new SavepointTypeInfoResolver(preloadedStateMetadata, serializerConfig); + final Tuple2 keyValueProjections = createKeyValueProjections(context.getCatalogTable()); @@ -106,115 +120,21 @@ public DynamicTableSource createDynamicTableSource(Context context) { optionalOptions.add(keyTypeInfoFactoryOption); TypeInformation keyTypeInfo = - getTypeInfo(options, keyFormatOption, keyTypeInfoFactoryOption, keyRowField, true); + typeResolver.resolveKeyType( + options, keyFormatOption, keyTypeInfoFactoryOption, keyRowField); final Tuple2> keyValueConfigProjections = Tuple2.of( keyValueProjections.f0, Arrays.stream(keyValueProjections.f1) .mapToObj( - columnIndex -> { - RowType.RowField valueRowField = - rowType.getFields().get(columnIndex); - - ConfigOption stateNameOption = - key(String.format( - "%s.%s.%s", - FIELDS, - valueRowField.getName(), - STATE_NAME)) - .stringType() - .noDefaultValue(); - optionalOptions.add(stateNameOption); - - ConfigOption - stateTypeOption = - key(String.format( - "%s.%s.%s", - FIELDS, - valueRowField.getName(), - STATE_TYPE)) - .enumType( - SavepointConnectorOptions - .StateType - .class) - .noDefaultValue(); - optionalOptions.add(stateTypeOption); - - ConfigOption mapKeyFormatOption = - key(String.format( - "%s.%s.%s", - FIELDS, - valueRowField.getName(), - KEY_CLASS)) - .stringType() - .noDefaultValue(); - optionalOptions.add(mapKeyFormatOption); - - ConfigOption mapKeyTypeInfoFactoryOption = - key(String.format( - "%s.%s.%s", - FIELDS, - valueRowField.getName(), - KEY_TYPE_FACTORY)) - .stringType() - .noDefaultValue(); - optionalOptions.add(mapKeyTypeInfoFactoryOption); - - ConfigOption valueFormatOption = - key(String.format( - "%s.%s.%s", - FIELDS, - valueRowField.getName(), - VALUE_CLASS)) - .stringType() - .noDefaultValue(); - optionalOptions.add(valueFormatOption); - - ConfigOption valueTypeInfoFactoryOption = - key(String.format( - "%s.%s.%s", - FIELDS, - valueRowField.getName(), - VALUE_TYPE_FACTORY)) - .stringType() - .noDefaultValue(); - optionalOptions.add(valueTypeInfoFactoryOption); - - LogicalType valueLogicalType = valueRowField.getType(); - - SavepointConnectorOptions.StateType stateType = - options.getOptional(stateTypeOption) - .orElseGet( - () -> - inferStateType( - valueLogicalType)); - - TypeInformation mapKeyTypeInfo = - getTypeInfo( - options, - keyFormatOption, - mapKeyTypeInfoFactoryOption, - valueRowField, - stateType.equals( - SavepointConnectorOptions - .StateType.MAP)); - - TypeInformation valueTypeInfo = - getTypeInfo( - options, - valueFormatOption, - valueTypeInfoFactoryOption, - valueRowField, - true); - return new StateValueColumnConfiguration( - columnIndex, - 
options.getOptional(stateNameOption) - .orElse(valueRowField.getName()), - stateType, - mapKeyTypeInfo, - valueTypeInfo); - }) + columnIndex -> + createStateColumnConfiguration( + columnIndex, + rowType, + options, + optionalOptions, + typeResolver)) .collect(Collectors.toList())); FactoryUtil.validateFactoryOptions(requiredOptions, optionalOptions, options); @@ -234,6 +154,83 @@ public DynamicTableSource createDynamicTableSource(Context context) { rowType); } + private StateValueColumnConfiguration createStateColumnConfiguration( + int columnIndex, + RowType rowType, + Configuration options, + Set> optionalOptions, + SavepointTypeInfoResolver typeResolver) { + + RowType.RowField valueRowField = rowType.getFields().get(columnIndex); + + ConfigOption stateNameOption = + key(String.format("%s.%s.%s", FIELDS, valueRowField.getName(), STATE_NAME)) + .stringType() + .noDefaultValue(); + optionalOptions.add(stateNameOption); + + ConfigOption stateTypeOption = + key(String.format("%s.%s.%s", FIELDS, valueRowField.getName(), STATE_TYPE)) + .enumType(SavepointConnectorOptions.StateType.class) + .noDefaultValue(); + optionalOptions.add(stateTypeOption); + + ConfigOption mapKeyFormatOption = + key(String.format("%s.%s.%s", FIELDS, valueRowField.getName(), KEY_CLASS)) + .stringType() + .noDefaultValue(); + optionalOptions.add(mapKeyFormatOption); + + ConfigOption mapKeyTypeInfoFactoryOption = + key(String.format("%s.%s.%s", FIELDS, valueRowField.getName(), KEY_TYPE_FACTORY)) + .stringType() + .noDefaultValue(); + optionalOptions.add(mapKeyTypeInfoFactoryOption); + + ConfigOption valueFormatOption = + key(String.format("%s.%s.%s", FIELDS, valueRowField.getName(), VALUE_CLASS)) + .stringType() + .noDefaultValue(); + optionalOptions.add(valueFormatOption); + + ConfigOption valueTypeInfoFactoryOption = + key(String.format("%s.%s.%s", FIELDS, valueRowField.getName(), VALUE_TYPE_FACTORY)) + .stringType() + .noDefaultValue(); + optionalOptions.add(valueTypeInfoFactoryOption); + + LogicalType valueLogicalType = valueRowField.getType(); + + SavepointConnectorOptions.StateType stateType = + options.getOptional(stateTypeOption) + .orElseGet(() -> inferStateType(valueLogicalType)); + + TypeSerializer mapKeyTypeSerializer = + typeResolver.resolveSerializer( + options, + mapKeyFormatOption, + mapKeyTypeInfoFactoryOption, + valueRowField, + stateType.equals(SavepointConnectorOptions.StateType.MAP), + SavepointTypeInfoResolver.InferenceContext.MAP_KEY); + + TypeSerializer valueTypeSerializer = + typeResolver.resolveSerializer( + options, + valueFormatOption, + valueTypeInfoFactoryOption, + valueRowField, + true, + SavepointTypeInfoResolver.InferenceContext.VALUE); + + return new StateValueColumnConfiguration( + columnIndex, + options.getOptional(stateNameOption).orElse(valueRowField.getName()), + stateType, + mapKeyTypeSerializer, + valueTypeSerializer); + } + private Tuple2 createKeyValueProjections(ResolvedCatalogTable catalogTable) { ResolvedSchema schema = catalogTable.getResolvedSchema(); if (schema.getPrimaryKey().isEmpty()) { @@ -271,46 +268,6 @@ private int[] createValueFormatProjection(DataType physicalDataType, int keyProj return physicalFields.filter(pos -> keyProjection != pos).toArray(); } - private TypeInformation getTypeInfo( - Configuration options, - ConfigOption classOption, - ConfigOption typeInfoFactoryOption, - RowType.RowField rowField, - boolean inferStateType) { - Optional clazz = options.getOptional(classOption); - Optional typeInfoFactory = options.getOptional(typeInfoFactoryOption); - if 
(clazz.isPresent() && typeInfoFactory.isPresent()) { - throw new IllegalArgumentException( - "Either " - + classOption.key() - + " or " - + typeInfoFactoryOption.key() - + " can be specified for column " - + rowField.getName() - + "."); - } - try { - if (clazz.isPresent()) { - return TypeInformation.of(Class.forName(clazz.get())); - } else if (typeInfoFactory.isPresent()) { - SavepointTypeInformationFactory savepointTypeInformationFactory = - (SavepointTypeInformationFactory) - TypeUtils.getInstance(typeInfoFactory.get(), new Object[0]); - return savepointTypeInformationFactory.getTypeInformation(); - } else { - if (inferStateType) { - String inferredValueFormat = - inferStateValueFormat(rowField.getName(), rowField.getType()); - return TypeInformation.of(Class.forName(inferredValueFormat)); - } else { - return null; - } - } - } catch (ReflectiveOperationException e) { - throw new RuntimeException(e); - } - } - private SavepointConnectorOptions.StateType inferStateType(LogicalType logicalType) { switch (logicalType.getTypeRoot()) { case ARRAY: @@ -324,82 +281,25 @@ private SavepointConnectorOptions.StateType inferStateType(LogicalType logicalTy } } - @Nullable - private String inferStateMapKeyFormat(String columnName, LogicalType logicalType) { - return logicalType.is(LogicalTypeRoot.MAP) - ? inferStateValueFormat(columnName, ((MapType) logicalType).getKeyType()) - : null; - } - - private String inferStateValueFormat(String columnName, LogicalType logicalType) { - switch (logicalType.getTypeRoot()) { - case CHAR: - case VARCHAR: - return String.class.getName(); - - case BOOLEAN: - return Boolean.class.getName(); - - case BINARY: - case VARBINARY: - return byte[].class.getName(); - - case DECIMAL: - return BigDecimal.class.getName(); - - case TINYINT: - return Byte.class.getName(); - - case SMALLINT: - return Short.class.getName(); - - case INTEGER: - return Integer.class.getName(); - - case BIGINT: - return Long.class.getName(); - - case FLOAT: - return Float.class.getName(); - - case DOUBLE: - return Double.class.getName(); - - case DATE: - return Integer.class.getName(); - - case INTERVAL_YEAR_MONTH: - case INTERVAL_DAY_TIME: - return Long.class.getName(); - - case ARRAY: - return inferStateValueFormat( - columnName, ((ArrayType) logicalType).getElementType()); - - case MAP: - return inferStateValueFormat(columnName, ((MapType) logicalType).getValueType()); - - case NULL: - return null; - - case ROW: - case MULTISET: - case TIME_WITHOUT_TIME_ZONE: - case TIMESTAMP_WITHOUT_TIME_ZONE: - case TIMESTAMP_WITH_TIME_ZONE: - case TIMESTAMP_WITH_LOCAL_TIME_ZONE: - case DISTINCT_TYPE: - case STRUCTURED_TYPE: - case RAW: - case SYMBOL: - case UNRESOLVED: - case DESCRIPTOR: - default: - throw new UnsupportedOperationException( - String.format( - "Unable to infer state format for SQL type: %s in column: %s. " - + "Please override the type with the following config parameter: %s.%s.%s", - logicalType, columnName, FIELDS, columnName, VALUE_CLASS)); + /** + * Preloads all state metadata for an operator in a single I/O operation. 
+     *
+     * @param savepointPath Path to the savepoint
+     * @param operatorIdentifier Operator UID or hash
+     * @return Map from state name to StateMetaInfoSnapshot
+     */
+    private Map<String, StateMetaInfoSnapshot> preloadStateMetadata(
+            String savepointPath, OperatorIdentifier operatorIdentifier) {
+        try {
+            return SavepointLoader.loadOperatorStateMetadata(savepointPath, operatorIdentifier);
+        } catch (Exception e) {
+            throw new RuntimeException(
+                    String.format(
+                            "Failed to load state metadata from savepoint '%s' for operator '%s'. "
+                                    + "Ensure the savepoint path is valid and the operator exists in the savepoint. "
+                                    + "Original error: %s",
+                            savepointPath, operatorIdentifier, e.getMessage()),
+                    e);
+        }
+    }
diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointTypeInfoResolver.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointTypeInfoResolver.java
new file mode 100644
index 0000000000000..19832c36d9e65
--- /dev/null
+++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/SavepointTypeInfoResolver.java
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.table;
+
+import org.apache.flink.api.common.serialization.SerializerConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.utils.TypeUtils;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.configuration.ConfigOption;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.table.types.logical.ArrayType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
+import org.apache.flink.table.types.logical.RowType;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.math.BigDecimal;
+import java.util.Map;
+import java.util.Optional;
+
+import static org.apache.flink.state.table.SavepointConnectorOptions.FIELDS;
+import static org.apache.flink.state.table.SavepointConnectorOptions.VALUE_CLASS;
+
+/** Resolves TypeInformation and TypeSerializers from savepoint metadata and configuration. */
+public class SavepointTypeInfoResolver {
+
+    private static final Logger LOG = LoggerFactory.getLogger(SavepointTypeInfoResolver.class);
+
+    /** Context for type inference to determine what aspect of the type we need. */
+    public enum InferenceContext {
+        /** Inferring the key type of keyed state (always primitive). */
+        KEY,
+        /** Inferring the key type of a MAP state. 
*/
+        MAP_KEY,
+        /** Inferring the value type (behavior depends on logical type). */
+        VALUE
+    }
+
+    private final Map<String, StateMetaInfoSnapshot> preloadedStateMetadata;
+    private final SerializerConfig serializerConfig;
+
+    public SavepointTypeInfoResolver(
+            Map<String, StateMetaInfoSnapshot> preloadedStateMetadata,
+            SerializerConfig serializerConfig) {
+        this.preloadedStateMetadata = preloadedStateMetadata;
+        this.serializerConfig = serializerConfig;
+    }
+
+    /**
+     * Resolves TypeInformation for keyed state keys (primitive types only).
+     *
+     * <p>

This is a simplified version of type resolution specifically for key types, which are
+     * always primitive and don't require complex metadata inference.
+     *
+     * @param options Configuration containing table options
+     * @param classOption Config option for explicit class specification
+     * @param typeInfoFactoryOption Config option for type factory specification
+     * @param rowField The row field containing name and LogicalType
+     * @return The resolved TypeInformation for the key
+     * @throws IllegalArgumentException If both class and factory options are specified
+     * @throws RuntimeException If type instantiation fails
+     */
+    public TypeInformation<?> resolveKeyType(
+            Configuration options,
+            ConfigOption<String> classOption,
+            ConfigOption<String> typeInfoFactoryOption,
+            RowType.RowField rowField) {
+        try {
+            // Priority 1: Explicit configuration (backward compatibility)
+            TypeInformation<?> explicitTypeInfo =
+                    getExplicitTypeInfo(options, classOption, typeInfoFactoryOption);
+            if (explicitTypeInfo != null) {
+                return explicitTypeInfo;
+            }
+
+            // Priority 2: Simple primitive type inference from LogicalType
+            LogicalType logicalType = rowField.getType();
+            String columnName = rowField.getName();
+            return TypeInformation.of(getPrimitiveClass(logicalType, columnName));
+        } catch (ReflectiveOperationException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Resolves TypeSerializer for a table field using a three-tier priority system with direct
+     * serializer extraction for metadata inference.
+     *
+     * <p>

Three-Tier Priority System (Serializer-First):
+     *
+     * <ol>
+     *   <li>Priority 1: Explicit configuration (highest priority). Uses the user-specified class
+     *       name or type factory from the table options, then converts it to a serializer.
+     *   <li>Priority 2: Metadata inference. Directly extracts serializers from the preloaded
+     *       savepoint metadata (no TypeInformation conversion).
+     *   <li>Priority 3: LogicalType fallback (lowest priority). Infers TypeInformation from the
+     *       table schema's LogicalType, then converts it to a serializer.
+     * </ol>
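+     *
+     * <p>For example, the factory resolves a value column's serializer with a call like the
+     * following (a sketch mirroring {@code SavepointDynamicTableSourceFactory}; the variable
+     * names are taken from {@code createStateColumnConfiguration}):
+     *
+     * <pre>{@code
+     * TypeSerializer<?> valueTypeSerializer =
+     *         typeResolver.resolveSerializer(
+     *                 options,
+     *                 valueFormatOption,
+     *                 valueTypeInfoFactoryOption,
+     *                 valueRowField,
+     *                 true,
+     *                 SavepointTypeInfoResolver.InferenceContext.VALUE);
+     * }</pre>
+     *
+     * <p>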
This approach eliminates TypeInformation extraction complexity for metadata inference,
+     * making it work with any serializer type (Avro, custom types, etc.).
+     *
+     * @param options Configuration containing table options
+     * @param classOption Config option for explicit class specification
+     * @param typeInfoFactoryOption Config option for type factory specification
+     * @param rowField The table field containing name and LogicalType
+     * @param inferStateType Whether to enable automatic type inference. If false, returns null
+     *     when no explicit configuration is provided.
+     * @param context The inference context determining what type aspect to extract.
+     * @return The resolved TypeSerializer, or null if inferStateType is false and no explicit
+     *     configuration is provided.
+     * @throws IllegalArgumentException If both class and factory options are specified
+     * @throws RuntimeException If serializer creation fails
+     */
+    public TypeSerializer<?> resolveSerializer(
+            Configuration options,
+            ConfigOption<String> classOption,
+            ConfigOption<String> typeInfoFactoryOption,
+            RowType.RowField rowField,
+            boolean inferStateType,
+            InferenceContext context) {
+        try {
+            // Priority 1: Explicit configuration (backward compatibility)
+            TypeInformation<?> explicitTypeInfo =
+                    getExplicitTypeInfo(options, classOption, typeInfoFactoryOption);
+            if (explicitTypeInfo != null) {
+                return explicitTypeInfo.createSerializer(serializerConfig);
+            }
+
+            if (inferStateType) {
+                // Priority 2: Direct serializer extraction from metadata
+                Optional<TypeSerializer<?>> metadataSerializer =
+                        getSerializerFromMetadata(rowField, context);
+                if (metadataSerializer.isPresent()) {
+                    LOG.info(
+                            "Using serializer directly from metadata for state '{}' with context {}: {}",
+                            rowField.getName(),
+                            context,
+                            metadataSerializer.get().getClass().getSimpleName());
+                    return metadataSerializer.get();
+                }
+
+                // Priority 3: Fallback to LogicalType-based inference
+                TypeInformation<?> fallbackTypeInfo = inferTypeFromLogicalType(rowField, context);
+                return fallbackTypeInfo.createSerializer(serializerConfig);
+            } else {
+                return null;
+            }
+        } catch (Exception e) {
+            throw new RuntimeException(
+                    "Failed to resolve serializer for field " + rowField.getName(), e);
+        }
+    }
+
+    /**
+     * Extracts explicit TypeInformation from user configuration (Priority 1).
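+     * The recognized options are, e.g., {@code fields.<column>.value-class} and
+     * {@code fields.<column>.value-type-factory}.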
+     *
+     * @param options Configuration containing table options
+     * @param classOption Config option for explicit class specification
+     * @param typeInfoFactoryOption Config option for type factory specification
+     * @return The explicit TypeInformation if specified, null otherwise
+     * @throws IllegalArgumentException If both class and factory options are specified
+     * @throws ReflectiveOperationException If type instantiation fails
+     */
+    private TypeInformation<?> getExplicitTypeInfo(
+            Configuration options,
+            ConfigOption<String> classOption,
+            ConfigOption<String> typeInfoFactoryOption)
+            throws ReflectiveOperationException {
+
+        Optional<String> clazz = options.getOptional(classOption);
+        Optional<String> typeInfoFactory = options.getOptional(typeInfoFactoryOption);
+
+        if (clazz.isPresent() && typeInfoFactory.isPresent()) {
+            throw new IllegalArgumentException(
+                    "Either "
+                            + classOption.key()
+                            + " or "
+                            + typeInfoFactoryOption.key()
+                            + " can be specified, not both.");
+        }
+
+        if (clazz.isPresent()) {
+            return TypeInformation.of(Class.forName(clazz.get()));
+        } else if (typeInfoFactory.isPresent()) {
+            SavepointTypeInformationFactory savepointTypeInformationFactory =
+                    (SavepointTypeInformationFactory)
+                            TypeUtils.getInstance(typeInfoFactory.get(), new Object[0]);
+            return savepointTypeInformationFactory.getTypeInformation();
+        }
+
+        return null;
+    }
+
+    /**
+     * Directly extracts TypeSerializer from preloaded metadata (Priority 2).
+     *
+     * <p>

This method performs no I/O and no TypeInformation conversion. It directly extracts the
+     * serializer that was used to write the state data.
+     *
+     * @param rowField The row field to extract the serializer for
+     * @param context The inference context determining what serializer to extract
+     * @return The serializer if found in metadata, empty otherwise
+     */
+    private Optional<TypeSerializer<?>> getSerializerFromMetadata(
+            RowType.RowField rowField, InferenceContext context) {
+        try {
+            // Get state name for this field (defaults to field name)
+            String stateName = rowField.getName();
+
+            // Look up from preloaded metadata (no I/O)
+            StateMetaInfoSnapshot stateMetaInfo = preloadedStateMetadata.get(stateName);
+
+            if (stateMetaInfo == null) {
+                LOG.debug("State '{}' not found in preloaded metadata", stateName);
+                return Optional.empty();
+            }
+
+            // Extract appropriate serializer based on context
+            TypeSerializerSnapshot<?> serializerSnapshot = null;
+            switch (context) {
+                case KEY:
+                    serializerSnapshot =
+                            stateMetaInfo.getTypeSerializerSnapshot(
+                                    StateMetaInfoSnapshot.CommonSerializerKeys.KEY_SERIALIZER);
+                    break;
+                case MAP_KEY:
+                    // For MAP_KEY, we need the key serializer from the value serializer
+                    // (which is MapSerializer)
+                    TypeSerializerSnapshot<?> valueSnapshot =
+                            stateMetaInfo.getTypeSerializerSnapshot(
+                                    StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER);
+                    if (valueSnapshot != null) {
+                        TypeSerializer<?> valueSerializer = valueSnapshot.restoreSerializer();
+                        if (valueSerializer instanceof MapSerializer) {
+                            serializerSnapshot =
+                                    ((MapSerializer<?, ?>) valueSerializer)
+                                            .getKeySerializer()
+                                            .snapshotConfiguration();
+                        }
+                    }
+                    break;
+                case VALUE:
+                    serializerSnapshot =
+                            stateMetaInfo.getTypeSerializerSnapshot(
+                                    StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER);
+                    break;
+            }
+
+            if (serializerSnapshot == null) {
+                LOG.debug(
+                        "No serializer snapshot found for state '{}' with context {}",
+                        stateName,
+                        context);
+                return Optional.empty();
+            }
+
+            // Restore serializer from snapshot
+            TypeSerializer<?> serializer = serializerSnapshot.restoreSerializer();
+
+            // For VALUE context with complex types, extract the appropriate sub-serializer
+            if (context == InferenceContext.VALUE) {
+                return extractValueSerializerForLogicalType(serializer, rowField.getType());
+            }
+
+            return Optional.of(serializer);
+
+        } catch (Exception e) {
+            LOG.warn(
+                    "Failed to extract serializer from metadata for field '{}': {}",
+                    rowField.getName(),
+                    e.getMessage());
+            return Optional.empty();
+        }
+    }
+
+    /**
+     * Extracts the appropriate value serializer based on LogicalType for VALUE context.
+     *
+     * @param fullSerializer The complete serializer from metadata
+     * @param logicalType The LogicalType from the table schema
+     * @return The appropriate value serializer
+     */
+    private Optional<TypeSerializer<?>> extractValueSerializerForLogicalType(
+            TypeSerializer<?> fullSerializer, LogicalType logicalType) {
+
+        switch (logicalType.getTypeRoot()) {
+            case ARRAY:
+                // ARRAY logical type → LIST state → extract element serializer
+                if (fullSerializer
+                        instanceof org.apache.flink.api.common.typeutils.base.ListSerializer) {
+                    org.apache.flink.api.common.typeutils.base.ListSerializer<?> listSerializer =
+                            (org.apache.flink.api.common.typeutils.base.ListSerializer<?>)
+                                    fullSerializer;
+                    return Optional.of(listSerializer.getElementSerializer());
+                }
+                LOG.debug(
+                        "Expected ListSerializer for ARRAY logical type but got: {}",
+                        fullSerializer.getClass());
+                return Optional.empty();
+
+            case MAP:
+                // MAP logical type → MAP state → extract value serializer
+                if (fullSerializer instanceof MapSerializer) {
+                    return Optional.of(((MapSerializer<?, ?>) fullSerializer).getValueSerializer());
+                }
+                LOG.debug(
+                        "Expected MapSerializer for MAP logical type but got: {}",
+                        fullSerializer.getClass());
+                return Optional.empty();
+
+            default:
+                // Primitive logical type → VALUE state → use serializer as-is
+                return Optional.of(fullSerializer);
+        }
+    }
+
+    /**
+     * Fallback inference using LogicalType when metadata extraction fails.
+     *
+     * @param rowField The row field to infer the type for
+     * @param context The inference context
+     * @return The inferred TypeInformation
+     */
+    private TypeInformation<?> inferTypeFromLogicalType(
+            RowType.RowField rowField, InferenceContext context) {
+
+        LogicalType logicalType = rowField.getType();
+        String columnName = rowField.getName();
+
+        try {
+            switch (context) {
+                case KEY:
+                    // Keys are always primitive
+                    return TypeInformation.of(getPrimitiveClass(logicalType, columnName));
+
+                case MAP_KEY:
+                    // Extract key type from MAP logical type
+                    if (logicalType instanceof MapType) {
+                        LogicalType keyType = ((MapType) logicalType).getKeyType();
+                        return TypeInformation.of(getPrimitiveClass(keyType, columnName));
+                    }
+                    throw new UnsupportedOperationException(
+                            "MAP_KEY context requires MAP logical type, but got: " + logicalType);
+
+                case VALUE:
+                    return inferValueTypeFromLogicalType(logicalType, columnName);
+
+                default:
+                    throw new UnsupportedOperationException("Unknown context: " + context);
+            }
+        } catch (ClassNotFoundException e) {
+            throw new RuntimeException("Failed to infer type for context " + context, e);
+        }
+    }
+
+    /**
+     * Infers value type from LogicalType for VALUE context fallback.
+     *
+     * @param logicalType The LogicalType
+     * @param columnName The column name for error messages
+     * @return The inferred TypeInformation
+     */
+    private TypeInformation<?> inferValueTypeFromLogicalType(
+            LogicalType logicalType, String columnName) throws ClassNotFoundException {
+
+        switch (logicalType.getTypeRoot()) {
+            case ARRAY:
+                // ARRAY logical type → LIST state → return element type
+                ArrayType arrayType = (ArrayType) logicalType;
+                return TypeInformation.of(
+                        getPrimitiveClass(arrayType.getElementType(), columnName));
+
+            case MAP:
+                // MAP logical type → MAP state → return value type
+                MapType mapType = (MapType) logicalType;
+                return TypeInformation.of(getPrimitiveClass(mapType.getValueType(), columnName));
+
+            default:
+                // Primitive logical type → VALUE state → return primitive type
+                return TypeInformation.of(getPrimitiveClass(logicalType, columnName));
+        }
+    }
+
+    /**
+     * Maps LogicalType to primitive Java class.
+     *
+     * @param logicalType The LogicalType to map
+     * @param columnName The column name for error messages
+     * @return The corresponding Java class
+     */
+    private Class<?> getPrimitiveClass(LogicalType logicalType, String columnName)
+            throws ClassNotFoundException {
+        String className = inferTypeInfoClassFromLogicalType(columnName, logicalType);
+        return Class.forName(className);
+    }
+
+    private String inferTypeInfoClassFromLogicalType(String columnName, LogicalType logicalType) {
+        switch (logicalType.getTypeRoot()) {
+            case CHAR:
+            case VARCHAR:
+                return String.class.getName();
+
+            case BOOLEAN:
+                return Boolean.class.getName();
+
+            case BINARY:
+            case VARBINARY:
+                return byte[].class.getName();
+
+            case DECIMAL:
+                return BigDecimal.class.getName();
+
+            case TINYINT:
+                return Byte.class.getName();
+
+            case SMALLINT:
+                return Short.class.getName();
+
+            case INTEGER:
+                return Integer.class.getName();
+
+            case BIGINT:
+                return Long.class.getName();
+
+            case FLOAT:
+                return Float.class.getName();
+
+            case DOUBLE:
+                return Double.class.getName();
+
+            case DATE:
+                return Integer.class.getName();
+
+            case INTERVAL_YEAR_MONTH:
+            case INTERVAL_DAY_TIME:
+                return Long.class.getName();
+
+            case ARRAY:
+                return inferTypeInfoClassFromLogicalType(
+                        columnName, ((ArrayType) logicalType).getElementType());
+
+            case MAP:
+                return inferTypeInfoClassFromLogicalType(
+                        columnName, ((MapType) logicalType).getValueType());
+
+            case NULL:
+                return null;
+
+            case ROW:
+            case MULTISET:
+            case TIME_WITHOUT_TIME_ZONE:
+            case TIMESTAMP_WITHOUT_TIME_ZONE:
+            case TIMESTAMP_WITH_TIME_ZONE:
+            case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
+            case DISTINCT_TYPE:
+            case STRUCTURED_TYPE:
+            case RAW:
+            case SYMBOL:
+            case UNRESOLVED:
+            case DESCRIPTOR:
+            default:
+                throw new UnsupportedOperationException(
+                        String.format(
+                                "Unable to infer state format for SQL type: %s in column: %s. 
" + + "Please override the type with the following config parameter: %s.%s.%s", + logicalType, columnName, FIELDS, columnName, VALUE_CLASS)); + } + } +} diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/StateValueColumnConfiguration.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/StateValueColumnConfiguration.java index fa622bb0a10a9..865077717fc7c 100644 --- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/StateValueColumnConfiguration.java +++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/table/StateValueColumnConfiguration.java @@ -19,7 +19,7 @@ package org.apache.flink.state.table; import org.apache.flink.api.common.state.StateDescriptor; -import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeSerializer; import javax.annotation.Nullable; @@ -31,21 +31,21 @@ public class StateValueColumnConfiguration implements Serializable { private final int columnIndex; private final String stateName; private final SavepointConnectorOptions.StateType stateType; - @Nullable private final TypeInformation mapKeyTypeInfo; - @Nullable private final TypeInformation valueTypeInfo; + @Nullable private final TypeSerializer mapKeyTypeSerializer; + @Nullable private final TypeSerializer valueTypeSerializer; @Nullable private StateDescriptor stateDescriptor; public StateValueColumnConfiguration( int columnIndex, final String stateName, final SavepointConnectorOptions.StateType stateType, - @Nullable final TypeInformation mapKeyTypeInfo, - final TypeInformation valueTypeInfo) { + @Nullable final TypeSerializer mapKeyTypeSerializer, + final TypeSerializer valueTypeSerializer) { this.columnIndex = columnIndex; this.stateName = stateName; this.stateType = stateType; - this.mapKeyTypeInfo = mapKeyTypeInfo; - this.valueTypeInfo = valueTypeInfo; + this.mapKeyTypeSerializer = mapKeyTypeSerializer; + this.valueTypeSerializer = valueTypeSerializer; } public int getColumnIndex() { @@ -61,12 +61,12 @@ public SavepointConnectorOptions.StateType getStateType() { } @Nullable - public TypeInformation getMapKeyTypeInfo() { - return mapKeyTypeInfo; + public TypeSerializer getMapKeyTypeSerializer() { + return mapKeyTypeSerializer; } - public TypeInformation getValueTypeInfo() { - return valueTypeInfo; + public TypeSerializer getValueTypeSerializer() { + return valueTypeSerializer; } public void setStateDescriptor(StateDescriptor stateDescriptor) { diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/GenericAvroSavepointTypeInformationFactory.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/GenericAvroSavepointTypeInformationFactory.java deleted file mode 100644 index b73e3273e1f97..0000000000000 --- a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/GenericAvroSavepointTypeInformationFactory.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.state.table; - -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.formats.avro.typeutils.GenericRecordAvroTypeInfo; - -import com.example.state.writer.job.schema.avro.AvroRecord; - -/** {@link SavepointTypeInformationFactory} for generic avro record. */ -public class GenericAvroSavepointTypeInformationFactory implements SavepointTypeInformationFactory { - @Override - public TypeInformation getTypeInformation() { - return new GenericRecordAvroTypeInfo(AvroRecord.getClassSchema()); - } -} diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointDynamicTableSourceTest.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointDynamicTableSourceTest.java index 9afeb85a66040..abbc7653c6f07 100644 --- a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointDynamicTableSourceTest.java +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointDynamicTableSourceTest.java @@ -53,14 +53,13 @@ public void testReadKeyedState() throws Exception { + " KeyedPrimitiveValue bigint,\n" + " KeyedPojoValue ROW,\n" + " KeyedPrimitiveValueList ARRAY,\n" - + " KeyedPrimitiveValueMap MAP,\n" + + " KeyedPrimitiveValueMap MAP,\n" + " PRIMARY KEY (k) NOT ENFORCED\n" + ")\n" + "with (\n" + " 'connector' = 'savepoint',\n" + " 'state.path' = 'src/test/resources/table-state',\n" - + " 'operator.uid' = 'keyed-state-process-uid',\n" - + " 'fields.KeyedPojoValue.value-class' = 'com.example.state.writer.job.schema.PojoData'\n" + + " 'operator.uid' = 'keyed-state-process-uid'\n" + ")"; tEnv.executeSql(sql); Table table = tEnv.sqlQuery("SELECT * FROM state_table"); @@ -108,20 +107,21 @@ public void testReadKeyedState() throws Exception { } // Check map state - Set>> mapValues = + Set>> mapValues = result.stream() .map( r -> Tuple2.of( (Long) r.getField("k"), - (Map) + (Map) r.getField("KeyedPrimitiveValueMap"))) .flatMap(l -> Set.of(l).stream()) .collect(Collectors.toSet()); assertThat(mapValues.size()).isEqualTo(10); - for (Tuple2> tuple2 : mapValues) { + for (Tuple2> tuple2 : mapValues) { assertThat(tuple2.f1.size()).isEqualTo(1); - assertThat(tuple2.f0).isEqualTo(tuple2.f1.get(tuple2.f0)); + String expectedKey = String.valueOf(tuple2.f0); + assertThat(tuple2.f1.get(expectedKey)).isEqualTo(tuple2.f0); } } @@ -142,8 +142,7 @@ public void testReadKeyedStateWithNullValues() throws Exception { + "with (\n" + " 'connector' = 'savepoint',\n" + " 'state.path' = 'src/test/resources/table-state-nulls',\n" - + " 'operator.uid' = 'keyed-state-process-uid-null',\n" - + " 'fields.total.value-class' = 'com.example.state.writer.job.schema.PojoData'\n" + + " 'operator.uid' = 'keyed-state-process-uid-null'\n" + ")"; tEnv.executeSql(sql); Table table = tEnv.sqlQuery("SELECT * FROM state_table"); @@ -185,9 +184,7 @@ public void testReadAvroKeyedState() throws Exception { + "with (\n" + " 'connector' = 'savepoint',\n" + " 'state.path' = 'src/test/resources/table-state-avro',\n" - + " 
'operator.uid' = 'keyed-state-process-uid',\n" - + " 'fields.KeyedSpecificAvroValue.value-type-factory' = 'org.apache.flink.state.table.SpecificAvroSavepointTypeInformationFactory',\n" - + " 'fields.KeyedGenericAvroValue.value-type-factory' = 'org.apache.flink.state.table.GenericAvroSavepointTypeInformationFactory'\n" + + " 'operator.uid' = 'keyed-state-process-uid'\n" + ")"; tEnv.executeSql(sql); Table table = tEnv.sqlQuery("SELECT * FROM state_table"); diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointMetadataDynamicTableSourceTest.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointMetadataDynamicTableSourceTest.java index 128166d8f57cc..6a009820ba4df 100644 --- a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointMetadataDynamicTableSourceTest.java +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointMetadataDynamicTableSourceTest.java @@ -53,9 +53,9 @@ public void testReadMetadata() throws Exception { Iterator it = result.iterator(); assertThat(it.next().toString()) .isEqualTo( - "+I[2, Source: broadcast-source, broadcast-source-uid, 3a6f51704798c4f418be51bfb6813b77, 1, 128, 0, 0, 0]"); + "+I[10, Source: broadcast-source, broadcast-source-uid, 3a6f51704798c4f418be51bfb6813b77, 1, 128, 0, 0, 0]"); assertThat(it.next().toString()) .isEqualTo( - "+I[2, keyed-broadcast-process, keyed-broadcast-process-uid, 413c1d6f88ee8627fe4b8bc533b4cf1b, 2, 128, 2, 0, 4548]"); + "+I[10, keyed-broadcast-process, keyed-broadcast-process-uid, 413c1d6f88ee8627fe4b8bc533b4cf1b, 2, 128, 2, 0, 4548]"); } } diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointTypeInformationFactoryTest.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointTypeInformationFactoryTest.java new file mode 100644 index 0000000000000..280b0b4c86f4b --- /dev/null +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SavepointTypeInformationFactoryTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.state.table; + +import org.apache.flink.api.common.RuntimeExecutionMode; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.configuration.ExecutionOptions.RUNTIME_MODE; +import static org.assertj.core.api.Assertions.assertThat; + +/** Unit tests for the SavepointTypeInformationFactory. */ +public class SavepointTypeInformationFactoryTest { + + public static class TestLongTypeInformationFactory implements SavepointTypeInformationFactory { + private static volatile boolean wasCalled = false; + + public static boolean wasFactoryCalled() { + return wasCalled; + } + + public static void resetCallTracker() { + wasCalled = false; + } + + @Override + public TypeInformation getTypeInformation() { + wasCalled = true; + return TypeInformation.of(Long.class); + } + } + + public static class TestStringTypeInformationFactory + implements SavepointTypeInformationFactory { + @Override + public TypeInformation getTypeInformation() { + return TypeInformation.of(String.class); + } + } + + @Test + public void testSavepointTypeInformationFactoryEndToEnd() throws Exception { + TestLongTypeInformationFactory.resetCallTracker(); + + Configuration config = new Configuration(); + config.set(RUNTIME_MODE, RuntimeExecutionMode.BATCH); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + final String sql = + "CREATE TABLE state_table (\n" + + " k bigint,\n" + + " KeyedPrimitiveValue bigint,\n" + + " PRIMARY KEY (k) NOT ENFORCED\n" + + ")\n" + + "with (\n" + + " 'connector' = 'savepoint',\n" + + " 'state.path' = 'src/test/resources/table-state',\n" + + " 'operator.uid' = 'keyed-state-process-uid',\n" + + " 'fields.KeyedPrimitiveValue.value-type-factory' = '" + + TestLongTypeInformationFactory.class.getName() + + "'\n" + + ")"; + + tEnv.executeSql(sql); + Table table = tEnv.sqlQuery("SELECT k, KeyedPrimitiveValue FROM state_table"); + List result = tEnv.toDataStream(table).executeAndCollect(100); + + assertThat(TestLongTypeInformationFactory.wasFactoryCalled()) + .as( + "Factory getTypeInformation() method must be called - this proves factory is used instead of metadata inference") + .isTrue(); + + assertThat(result.size()).isEqualTo(10); + + Set keys = + result.stream().map(r -> (Long) r.getField("k")).collect(Collectors.toSet()); + assertThat(keys).hasSize(10); + assertThat(keys).containsExactlyInAnyOrder(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L); + + Set primitiveValues = + result.stream() + .map(r -> (Long) r.getField("KeyedPrimitiveValue")) + .collect(Collectors.toSet()); + assertThat(primitiveValues).hasSize(1); + assertThat(primitiveValues.iterator().next()).isEqualTo(1L); + } + + @Test + public void testBasicFactoryFunctionality() { + TestLongTypeInformationFactory.resetCallTracker(); + + TestLongTypeInformationFactory longFactory = new TestLongTypeInformationFactory(); + TypeInformation longTypeInfo = longFactory.getTypeInformation(); + + assertThat(longTypeInfo).isEqualTo(TypeInformation.of(Long.class)); + 
assertThat(TestLongTypeInformationFactory.wasFactoryCalled()).isTrue(); + + TestStringTypeInformationFactory stringFactory = new TestStringTypeInformationFactory(); + TypeInformation stringTypeInfo = stringFactory.getTypeInformation(); + + assertThat(stringTypeInfo).isEqualTo(TypeInformation.of(String.class)); + } +} diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SpecificAvroSavepointTypeInformationFactory.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SpecificAvroSavepointTypeInformationFactory.java deleted file mode 100644 index 8e9e459aff0a9..0000000000000 --- a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/table/SpecificAvroSavepointTypeInformationFactory.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.state.table; - -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.formats.avro.typeutils.AvroTypeInfo; - -import com.example.state.writer.job.schema.avro.AvroRecord; - -/** {@link SavepointTypeInformationFactory} for specific avro record. */ -public class SpecificAvroSavepointTypeInformationFactory - implements SavepointTypeInformationFactory { - @Override - public TypeInformation getTypeInformation() { - return new AvroTypeInfo<>(AvroRecord.class); - } -} diff --git a/flink-libraries/flink-state-processing-api/src/test/resources/table-state/_metadata b/flink-libraries/flink-state-processing-api/src/test/resources/table-state/_metadata index dc9d5acbbd2ec..1bff4e03d5f59 100644 Binary files a/flink-libraries/flink-state-processing-api/src/test/resources/table-state/_metadata and b/flink-libraries/flink-state-processing-api/src/test/resources/table-state/_metadata differ
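Taken together, these changes let the savepoint connector resolve state serializers directly from the savepoint's own metadata, so the per-field `value-class` and `value-type-factory` options become optional overrides. A minimal end-to-end sketch of the resulting usage, assuming a savepoint written by an operator with UID `keyed-state-process-uid` (the table layout mirrors `SavepointDynamicTableSourceTest`; the state path is a placeholder):

```java
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;

public class SavepointTableReadExample {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // No 'fields.<column>.value-class' or 'fields.<column>.value-type-factory' options:
        // the connector infers the serializers from the savepoint's state metadata.
        tEnv.executeSql(
                "CREATE TABLE state_table (\n"
                        + "  k BIGINT,\n"
                        + "  KeyedPrimitiveValue BIGINT,\n"
                        + "  PRIMARY KEY (k) NOT ENFORCED\n"
                        + ") WITH (\n"
                        + "  'connector' = 'savepoint',\n"
                        + "  'state.path' = '/path/to/savepoint',\n"
                        + "  'operator.uid' = 'keyed-state-process-uid'\n"
                        + ")");

        // Print the restored keyed state.
        tEnv.sqlQuery("SELECT * FROM state_table").execute().print();
    }
}
```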