221 changes: 192 additions & 29 deletions sdks/python/apache_beam/yaml/examples/testing/examples_test.py
@@ -117,31 +117,39 @@ def _fn(row):
@beam.ptransform.ptransform_fn
def test_kafka_read(
pcoll,
format,
topic,
bootstrap_servers,
auto_offset_reset_config,
consumer_config):
topic: Optional[str] = None,
format: Optional[str] = None,
schema: Optional[Any] = None,
bootstrap_servers: Optional[str] = None,
auto_offset_reset_config: Optional[str] = None,
consumer_config: Optional[Any] = None):
"""
This PTransform simulates the behavior of the ReadFromKafka transform
with the RAW format by simply using some fixed sample text data and
encode it to raw bytes.
Mocks the ReadFromKafka transform for testing purposes.

This PTransform simulates the behavior of the ReadFromKafka transform by
reading from predefined in-memory data based on the Kafka topic argument.

Args:
Collaborator: needs schema arg doc string

pcoll: The input PCollection.
format: The format of the Kafka messages (e.g., 'RAW').
topic: The name of the Kafka topic to read from.
format: The format of the Kafka messages (e.g., 'RAW').
schema: The schema of the Kafka messages.
bootstrap_servers: A list of Kafka bootstrap servers to connect to.
auto_offset_reset_config: A configuration for the auto offset reset
consumer_config: A dictionary containing additional consumer configurations
auto_offset_reset_config: A configuration for the auto offset reset.
consumer_config: A map for additional consumer configuration parameters.

Returns:
A PCollection containing the sample text data in bytes.
A PCollection containing the sample data.
"""

return (
pcoll | beam.Create(input_data.text_data().split('\n'))
| beam.Map(lambda element: beam.Row(payload=element.encode('utf-8'))))
if topic == 'test-topic':
kafka_byte_messages = KAFKA_TOPICS['test-topic']
return (
pcoll
| beam.Create([msg.decode('utf-8') for msg in kafka_byte_messages])
| beam.Map(lambda element: beam.Row(payload=element.encode('utf-8'))))

return None
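
As a rough illustration of how this mock behaves, here is a minimal sketch that applies it directly (the test suite normally wires it in through TEST_PROVIDERS instead; the tiny pipeline below is only for demonstration):

import apache_beam as beam

with beam.Pipeline() as p:
    payloads = (
        p
        | test_kafka_read(topic='test-topic', format='RAW')
        | beam.Map(lambda row: row.payload.decode('utf-8')))
    # Each element is one line of the sample text registered under
    # KAFKA_TOPICS['test-topic']; any other topic makes the mock return None.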


@beam.ptransform.ptransform_fn
@@ -155,17 +163,70 @@ def test_pubsub_read(
attributes_map: Optional[str] = None,
id_attribute: Optional[str] = None,
timestamp_attribute: Optional[str] = None):
"""
Mocks the ReadFromPubSub transform for testing purposes.

This PTransform simulates the behavior of the ReadFromPubSub transform by
reading from predefined in-memory data based on the Pub/Sub topic argument.
Args:
pcoll: The input PCollection.
topic: The name of the Pub/Sub topic to read from.
subscription: The name of the Pub/Sub subscription to read from.
format: The format of the Pub/Sub messages (e.g., 'JSON').
schema: The schema of the Pub/Sub messages.
attributes: A list of attributes to include in the output.
attributes_map: A string representing a mapping of attributes.
id_attribute: The attribute to use as the ID for the message.
timestamp_attribute: The attribute to use as the timestamp for the message.

Returns:
A PCollection containing the sample data.
"""

if topic == 'test-topic':
pubsub_messages = PUBSUB_TOPICS['test-topic']
return (
pcoll
| beam.Create([json.loads(msg.data) for msg in pubsub_messages])
| beam.Map(lambda element: beam.Row(**element)))
elif topic == 'taxi-ride-topic':
pubsub_messages = PUBSUB_TOPICS['taxi-ride-topic']
schema = input_data.TaxiRideEventSchema
return (
pcoll
| beam.Create([json.loads(msg.data) for msg in pubsub_messages])
|
beam.Map(lambda element: beam.Row(**element)).with_output_types(schema))

return None


@beam.ptransform.ptransform_fn
def test_run_inference_taxi_fare(pcoll, inference_tag, model_handler):
"""
This PTransform simulates the behavior of the RunInference transform.

Args:
pcoll: The input PCollection.
inference_tag: The tag to use for the returned inference.
model_handler: A configuration for the respective ML model handler.

Returns:
A PCollection containing the enriched data.
"""
def _fn(row):
input = row._asdict()

pubsub_messages = input_data.pubsub_messages_data()
row = {inference_tag: PredictionResult(input, 10.0), **input}

return (
pcoll
| beam.Create([json.loads(msg.data) for msg in pubsub_messages])
| beam.Map(lambda element: beam.Row(**element)))
return beam.Row(**row)

schema = _format_predicition_result_ouput(pcoll, inference_tag)
return pcoll | beam.Map(_fn).with_output_types(schema)
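
For reference, a small sketch of the enrichment this mock applies to each element; the field names below are illustrative and not taken from the example pipeline:

import apache_beam as beam
from apache_beam.ml.inference.base import PredictionResult

# Mirrors _fn above: the original fields are preserved and a fixed
# PredictionResult with a 10.0 fare is attached under the inference tag.
ride = beam.Row(ride_id='1', passenger_count=1)
as_dict = ride._asdict()
enriched = beam.Row(predicted_fare=PredictionResult(as_dict, 10.0), **as_dict)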


@beam.ptransform.ptransform_fn
def test_run_inference(pcoll, inference_tag, model_handler):
def test_run_inference_youtube_comments(pcoll, inference_tag, model_handler):
"""
This PTransform simulates the behavior of the RunInference transform.

@@ -193,24 +254,28 @@ def _fn(row):

return beam.Row(**row)

schema = _format_predicition_result_ouput(pcoll, inference_tag)
return pcoll | beam.Map(_fn).with_output_types(schema)


def _format_predicition_result_ouput(pcoll, inference_tag):
user_type = RowTypeConstraint.from_user_type(pcoll.element_type.user_type)
user_schema_fields = [(name, type(typ) if not isinstance(typ, type) else typ)
for (name,
typ) in user_type._fields] if user_type else []
inference_output_type = RowTypeConstraint.from_fields([
('example', Any), ('inference', Any), ('model_id', Optional[str])
])
schema = RowTypeConstraint.from_fields(
return RowTypeConstraint.from_fields(
user_schema_fields + [(str(inference_tag), inference_output_type)])

return pcoll | beam.Map(_fn).with_output_types(schema)
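
A minimal sketch of the row type this helper produces; the _Ride user type and the 'predicted_fare' tag below are hypothetical stand-ins for the element type of the incoming PCollection and the inference tag:

from typing import Any, NamedTuple, Optional

from apache_beam.typehints.row_type import RowTypeConstraint

class _Ride(NamedTuple):
    ride_id: str
    passenger_count: int

user_type = RowTypeConstraint.from_user_type(_Ride)
inference_output_type = RowTypeConstraint.from_fields(
    [('example', Any), ('inference', Any), ('model_id', Optional[str])])
# The helper keeps the user fields and appends one nested row field keyed by
# the inference tag: here ride_id, passenger_count and predicted_fare.
schema = RowTypeConstraint.from_fields(
    list(user_type._fields) + [('predicted_fare', inference_output_type)])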


TEST_PROVIDERS = {
'TestEnrichment': test_enrichment,
'TestReadFromKafka': test_kafka_read,
'TestReadFromPubSub': test_pubsub_read,
'TestRunInference': test_run_inference
'TestRunInferenceYouTubeComments': test_run_inference_youtube_comments,
'TestRunInferenceTaxiFare': test_run_inference_taxi_fare,
}
"""
Transforms not requiring inputs.
@@ -305,6 +370,11 @@ def _python_deps_involved(spec_filename):
substr in spec_filename
for substr in ['deps', 'streaming_sentiment_analysis'])

def _java_deps_involved(spec_filename):
return any(
substr in spec_filename
for substr in ['java_deps', 'streaming_taxifare_prediction'])

if _python_deps_involved(pipeline_spec_file):
test_yaml_example = pytest.mark.no_xdist(test_yaml_example)
test_yaml_example = unittest.skipIf(
@@ -319,7 +389,7 @@ def _python_deps_involved(spec_filename):
'Github actions environment issue.')(
test_yaml_example)

if 'java_deps' in pipeline_spec_file:
if _java_deps_involved(pipeline_spec_file):
test_yaml_example = pytest.mark.xlang_sql_expansion_service(
test_yaml_example)
test_yaml_example = unittest.skipIf(
@@ -500,6 +570,7 @@ def _kafka_test_preprocessor(
for transform in pipeline.get('transforms', []):
if transform.get('type', '') == 'ReadFromKafka':
transform['type'] = 'TestReadFromKafka'
transform['config']['topic'] = 'test-topic'

return test_spec

@@ -525,8 +596,7 @@
'test_oracle_to_bigquery_yaml',
'test_mysql_to_bigquery_yaml',
'test_spanner_to_bigquery_yaml',
'test_streaming_sentiment_analysis_yaml',
'test_enrich_spanner_with_bigquery_yaml'
'test_streaming_sentiment_analysis_yaml'
])
def _io_write_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
@@ -740,6 +810,7 @@ def _pubsub_io_read_test_preprocessor(
for transform in pipeline.get('transforms', []):
if transform.get('type', '') == 'ReadFromPubSub':
transform['type'] = 'TestReadFromPubSub'
transform['config']['topic'] = 'test-topic'

return test_spec

@@ -904,7 +975,90 @@ def _streaming_sentiment_analysis_test_preprocessor(
if pipeline := test_spec.get('pipeline', None):
for transform in pipeline.get('transforms', []):
if transform.get('type', '') == 'RunInference':
transform['type'] = 'TestRunInference'
transform['type'] = 'TestRunInferenceYouTubeComments'

return test_spec


@YamlExamplesTestSuite.register_test_preprocessor(
'test_streaming_taxifare_prediction_yaml')
def _streaming_taxifare_prediction_test_preprocessor(
test_spec: dict, expected: List[str], env: TestEnvironment):
"""
Preprocessor for tests that involve the streaming taxi fare prediction
example.

This preprocessor replaces several IO transforms and the RunInference
transform. This allows the test to verify the pipeline's correctness
without relying on external data sources or the model hosted on Vertex AI.
It also turns this non-linear pipeline into a linear pipeline by replacing
the ReadFromKafka and WriteToKafka transforms with MapToFields and linking
the two disconnected pipeline components together. The pipeline logic,
however, remains the same and is still being tested accordingly.

Args:
test_spec: The dictionary representation of the YAML pipeline specification.
expected: A list of strings representing the expected output of the
pipeline.
env: The TestEnvironment object providing utilities for creating temporary
files.

Returns:
The modified test_spec dictionary with several involved IO transforms and
the RunInference transform replaced.
"""

if pipeline := test_spec.get('pipeline', None):
for transform in pipeline.get('transforms', []):
if transform.get('type', '') == 'ReadFromPubSub':
transform['type'] = 'TestReadFromPubSub'
transform['config']['topic'] = 'taxi-ride-topic'

elif transform.get('type', '') == 'WriteToKafka':
transform['type'] = 'MapToFields'
transform['config'] = {
k: v
for (k, v) in transform.get('config', {}).items()
if k.startswith('__')
}
transform['config']['fields'] = {
'ride_id': 'ride_id',
'pickup_longitude': 'pickup_longitude',
'pickup_latitude': 'pickup_latitude',
'pickup_datetime': 'pickup_datetime',
'dropoff_longitude': 'dropoff_longitude',
'dropoff_latitude': 'dropoff_latitude',
'passenger_count': 'passenger_count',
}

elif transform.get('type', '') == 'ReadFromKafka':
transform['type'] = 'MapToFields'
transform['config'] = {
k: v
for (k, v) in transform.get('config', {}).items()
if k.startswith('__')
}
transform['input'] = 'WriteKafka'
transform['config']['fields'] = {
'ride_id': 'ride_id',
'pickup_longitude': 'pickup_longitude',
'pickup_latitude': 'pickup_latitude',
'pickup_datetime': 'pickup_datetime',
'dropoff_longitude': 'dropoff_longitude',
'dropoff_latitude': 'dropoff_latitude',
'passenger_count': 'passenger_count',
}

elif transform.get('type', '') == 'WriteToBigQuery':
transform['type'] = 'LogForTesting'
transform['config'] = {
k: v
for (k, v) in transform.get('config', {}).items()
if (k.startswith('__') or k == 'error_handling')
}

elif transform.get('type', '') == 'RunInference':
transform['type'] = 'TestRunInferenceTaxiFare'

return test_spec
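
To make the rewiring concrete, here is a hypothetical WriteToKafka entry before and after preprocessing; the topic and server values are illustrative, while the 'WriteKafka' name matches the input that the rewritten ReadFromKafka is pointed at above:

write_kafka_before = {
    'type': 'WriteToKafka',
    'name': 'WriteKafka',
    'config': {'topic': 'taxi-rides', 'bootstrap_servers': 'localhost:9092'},
}

# After preprocessing, non-dunder config keys are dropped and the transform
# becomes a pass-through MapToFields emitting the taxi ride fields, so the
# downstream half of the pipeline can consume its output directly.
write_kafka_after = {
    'type': 'MapToFields',
    'name': 'WriteKafka',
    'config': {
        'fields': {
            'ride_id': 'ride_id',
            'pickup_datetime': 'pickup_datetime',
            # ... plus the remaining passthrough fields listed above ...
        }
    },
}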

Expand All @@ -915,6 +1069,15 @@ def _streaming_sentiment_analysis_test_preprocessor(
'youtube-comments.csv': input_data.youtube_comments_csv()
}

KAFKA_TOPICS = {
'test-topic': input_data.kafka_messages_data(),
}

PUBSUB_TOPICS = {
'test-topic': input_data.pubsub_messages_data(),
'taxi-ride-topic': input_data.pubsub_taxi_ride_events_data()
}

INPUT_TABLES = {
('shipment-test', 'shipment', 'shipments'): input_data.shipments_data(),
('orders-test', 'order-database', 'orders'): input_data.
67 changes: 67 additions & 0 deletions sdks/python/apache_beam/yaml/examples/testing/input_data.py
@@ -16,6 +16,8 @@
# limitations under the License.
#

import typing

from apache_beam.io.gcp.pubsub import PubsubMessage

# This file contains the input data to be requested by the example tests, if
@@ -186,3 +188,68 @@ def pubsub_messages_data():
PubsubMessage(data=b"{\"label\": \"37c\", \"rank\": 3}", attributes={}),
PubsubMessage(data=b"{\"label\": \"37d\", \"rank\": 2}", attributes={}),
]


def pubsub_taxi_ride_events_data():
"""
Provides a list of PubsubMessage objects for testing taxi ride events.
"""
return [
PubsubMessage(
data=b"{\"ride_id\": \"1\", \"longitude\": 11.0, \"latitude\": -11.0,"
b"\"passenger_count\": 1, \"meter_reading\": 100.0, \"timestamp\": "
b"\"2025-01-01T00:29:00.00000-04:00\", \"ride_status\": \"pickup\"}",
attributes={}),
PubsubMessage(
data=b"{\"ride_id\": \"2\", \"longitude\": 22.0, \"latitude\": -22.0,"
b"\"passenger_count\": 2, \"meter_reading\": 100.0, \"timestamp\": "
b"\"2025-01-01T00:30:00.00000-04:00\", \"ride_status\": \"pickup\"}",
attributes={}),
PubsubMessage(
data=b"{\"ride_id\": \"1\", \"longitude\": 13.0, \"latitude\": -13.0,"
b"\"passenger_count\": 1, \"meter_reading\": 100.0, \"timestamp\": "
b"\"2025-01-01T00:31:00.00000-04:00\", \"ride_status\": \"enroute\"}",
attributes={}),
PubsubMessage(
data=b"{\"ride_id\": \"2\", \"longitude\": 24.0, \"latitude\": -24.0,"
b"\"passenger_count\": 2, \"meter_reading\": 100.0, \"timestamp\": "
b"\"2025-01-01T00:32:00.00000-04:00\", \"ride_status\": \"enroute\"}",
attributes={}),
PubsubMessage(
data=b"{\"ride_id\": \"3\", \"longitude\": 33.0, \"latitude\": -33.0,"
b"\"passenger_count\": 3, \"meter_reading\": 100.0, \"timestamp\": "
b"\"2025-01-01T00:35:00.00000-04:00\", \"ride_status\": \"enroute\"}",
attributes={}),
PubsubMessage(
data=b"{\"ride_id\": \"4\", \"longitude\": 44.0, \"latitude\": -44.0,"
b"\"passenger_count\": 4, \"meter_reading\": 100.0, \"timestamp\": "
b"\"2025-01-01T00:35:00.00000-04:00\", \"ride_status\": \"dropoff\"}",
attributes={}),
PubsubMessage(
data=b"{\"ride_id\": \"1\", \"longitude\": 15.0, \"latitude\": -15.0,"
b"\"passenger_count\": 1, \"meter_reading\": 100.0, \"timestamp\": "
b"\"2025-01-01T00:33:00.00000-04:00\", \"ride_status\": \"dropoff\"}",
attributes={}),
PubsubMessage(
data=b"{\"ride_id\": \"2\", \"longitude\": 26.0, \"latitude\": -26.0,"
b"\"passenger_count\": 2, \"meter_reading\": 100.0, \"timestamp\": "
b"\"2025-01-01T00:34:00.00000-04:00\", \"ride_status\": \"dropoff\"}",
attributes={}),
]


def kafka_messages_data():
"""
Provides a list of Kafka messages for testing.
"""
return [data.encode('utf-8') for data in text_data().split('\n')]


class TaxiRideEventSchema(typing.NamedTuple):
ride_id: str
longitude: float
latitude: float
passenger_count: int
meter_reading: float
timestamp: str
ride_status: str
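
As a quick sanity check tying the schema to the sample data, here is a sketch of how one of the messages above maps onto it (this mirrors what the TestReadFromPubSub mock does in examples_test.py):

import json

import apache_beam as beam

msg = pubsub_taxi_ride_events_data()[0]
# The JSON keys match the TaxiRideEventSchema fields one-to-one.
row = beam.Row(**json.loads(msg.data))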