
Commit dda17cb

Utils to read unigraph-formatted data into sampler v2 input format and a driver program used like graph_sampler.py
PiperOrigin-RevId: 535353135
1 parent fb7c122 commit dda17cb

6 files changed: +1035 −1 lines changed

6 files changed

+1035
-1
lines changed

tensorflow_gnn/experimental/sampler/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@
 # Export.
 create_program = eval_dag.create_program
 save_model = eval_dag.save_model
+Artifacts = eval_dag.Artifacts

 # Sampling layers.
 InMemUniformEdgesSampler = core.InMemUniformEdgesSampler
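The newly exported `Artifacts` is the container of per-stage TF models that `sampler.create_program` returns alongside a sampling program. A condensed sketch of how the new driver below consumes it, mirroring its `save_artifacts` helper (the output path is illustrative):

import os

import tensorflow as tf
from tensorflow_gnn.experimental import sampler


def save_artifacts(artifacts: sampler.Artifacts, artifacts_path: str) -> None:
  # `Artifacts.models` maps a layer id to the TF model backing that stage;
  # each model is saved under its own subdirectory.
  for layer_id, model in artifacts.models.items():
    path = os.path.join(artifacts_path, layer_id)
    tf.io.gfile.makedirs(path)
    sampler.save_model(model, path)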

tensorflow_gnn/experimental/sampler/beam/BUILD

Lines changed: 54 additions & 0 deletions

@@ -127,3 +127,57 @@ pytype_strict_contrib_test(
         "//:expect_tensorflow_installed",
     ],
 )
+
+pytype_strict_library(
+    name = "unigraph_utils",
+    srcs = ["unigraph_utils.py"],
+    srcs_version = "PY3ONLY",
+    deps = [
+        "//third_party/py/apache_beam",
+        "//:expect_numpy_installed",
+        "//:expect_tensorflow_installed",
+        "//tensorflow_gnn",
+        "//tensorflow_gnn/data:unigraph",
+        "//tensorflow_gnn/sampler:sampling_spec_py_proto",
+    ],
+)
+
+py_binary(
+    name = "sampler",
+    srcs = ["sampler.py"],
+    deps = [
+        ":accessors",
+        ":edge_samplers",
+        ":executor_lib",
+        ":unigraph_utils",
+        "//third_party/py/absl:app",
+        "//third_party/py/absl/flags",
+        "//third_party/py/absl/logging",
+        "//third_party/py/apache_beam",
+        "//:expect_tensorflow_installed",
+        "//tensorflow_gnn",
+        "//tensorflow_gnn/data:unigraph",
+        "//tensorflow_gnn/experimental/sampler",
+        "//tensorflow_gnn/experimental/sampler:subgraph_pipeline",
+        "//tensorflow_gnn/proto:graph_schema_py_proto",
+        "//tensorflow_gnn/sampler:sampling_spec_py_proto",
+    ],
+)
+
+pytype_strict_contrib_test(
+    name = "unigraph_utils_test",
+    srcs = ["unigraph_utils_test.py"],
+    data = ["@tensorflow_gnn//testdata/heterogeneous"],
+    python_version = "PY3",
+    srcs_version = "PY3ONLY",
+    deps = [
+        ":unigraph_utils",
+        "//testing/pybase",
+        "//third_party/py/apache_beam",
+        "//:expect_numpy_installed",
+        "//:expect_tensorflow_installed",
+        "//tensorflow_gnn",
+        "//tensorflow_gnn/data:unigraph",
+        "//tensorflow_gnn/utils:test_utils",
+    ],
+)
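The new `unigraph_utils` library supplies the Beam transforms the driver uses to turn unigraph tables into sampler v2 feeds. A minimal sketch of that wiring, assuming the API exactly as used by sampler.py below (the schema and data paths are hypothetical):

import apache_beam as beam

from tensorflow_gnn.data import unigraph
from tensorflow_gnn.experimental.sampler.beam import unigraph_utils

schema = unigraph.read_schema('/tmp/graph/schema.pbtxt')  # hypothetical path
with beam.Pipeline() as root:
  # Read node and edge set tables into the sampler v2 feed format.
  feeds = root | unigraph_utils.ReadAndConvertUnigraph(schema, '/tmp/graph')
  # Without a seed file, derive seeds from the graph's own node ids.
  seeds = unigraph_utils.seeds_from_graph_dict(feeds)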
tensorflow_gnn/experimental/sampler/beam/sampler.py

Lines changed: 274 additions & 0 deletions

# Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs the sampling pipeline defined by the GraphSchema and SamplingSpec.

Closely follows V1.
"""

import os
from typing import Optional

from absl import app
from absl import flags
from absl import logging
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.data import unigraph
from tensorflow_gnn.experimental import sampler
from tensorflow_gnn.experimental.sampler import subgraph_pipeline
from tensorflow_gnn.experimental.sampler.beam import accessors  # pylint: disable=unused-import
from tensorflow_gnn.experimental.sampler.beam import edge_samplers  # pylint: disable=unused-import
from tensorflow_gnn.experimental.sampler.beam import executor_lib
from tensorflow_gnn.experimental.sampler.beam import unigraph_utils
from tensorflow_gnn.proto import graph_schema_pb2
from tensorflow_gnn.sampler import sampling_spec_pb2

from google.protobuf import text_format


_DIRECT_RUNNER = 'DirectRunner'
_DATAFLOW_RUNNER = 'DataflowRunner'


def _get_shape(feature: graph_schema_pb2.Feature) -> tf.TensorShape:
  dim_fn = lambda dim: (None if dim.size == -1 else dim.size)
  dims = [dim_fn(dim) for dim in feature.shape.dim]
  return tf.TensorShape(dims)


def get_sampling_model(
    graph_schema: tfgnn.GraphSchema,
    sampling_spec: sampling_spec_pb2.SamplingSpec,
) -> tf.keras.Model:
  """Constructs a sampling model from a graph schema and sampling spec.

  Args:
    graph_schema: Its `edge_sets` attribute identifies end-point node set
      names.
    sampling_spec: The number of nodes to sample from each edge set. The spec
      defines the structure of the sampled subgraphs, which look like rooted
      trees, possibly densified by adding all pairwise edges between sampled
      nodes.

  Returns:
    A Keras model for sampling.
  """

  def edge_sampler_factory(
      op: sampling_spec_pb2.SamplingOp,
  ) -> sampler.UniformEdgesSampler:
    accessor = sampler.KeyToTfExampleAccessor(
        sampler.InMemStringKeyToBytesAccessor(
            keys_to_values={'b': b'b'}),
        features_spec={
            '#target': tf.TensorSpec([None], tf.string),
        },
    )
    sample_size = op.sample_size
    edge_target_feature_name = '#target'
    return sampler.UniformEdgesSampler(
        outgoing_edges_accessor=accessor,
        sample_size=sample_size,
        edge_target_feature_name=edge_target_feature_name,
        name=f'edges/{op.edge_set_name}',
    )

  def node_features_accessor_factory(
      node_set_name: tfgnn.NodeSetName,
  ) -> sampler.KeyToTfExampleAccessor:
    node_features = graph_schema.node_sets[node_set_name].features
    features_spec = {}
    for name, feature in node_features.items():
      shape = _get_shape(feature)
      dtype = tf.dtypes.as_dtype(feature.dtype)
      features_spec[name] = tf.TensorSpec(shape, dtype)
    accessor = sampler.KeyToTfExampleAccessor(
        sampler.InMemStringKeyToBytesAccessor(
            keys_to_values={'b': b'b'},
            name=f'nodes/{node_set_name}'),
        features_spec=features_spec,
    )
    return accessor

  return subgraph_pipeline.create_sampling_model_from_spec(
      graph_schema,
      sampling_spec,
      edge_sampler_factory=edge_sampler_factory,
      node_features_accessor_factory=node_features_accessor_factory,
  )


def _create_beam_runner(
    runner_name: Optional[str],
) -> beam.runners.PipelineRunner:
  """Creates the appropriate runner."""
  if runner_name == _DIRECT_RUNNER:
    runner = beam.runners.DirectRunner()
  elif runner_name == _DATAFLOW_RUNNER:
    runner = beam.runners.DataflowRunner()
  else:
    runner = None
  return runner


def save_artifacts(artifacts: sampler.Artifacts, artifacts_path: str) -> None:
  for layer_id, model in artifacts.models.items():
    path = os.path.join(artifacts_path, layer_id)
    tf.io.gfile.makedirs(path)
    sampler.save_model(model, path)


def define_flags():
  """Creates commandline flags."""

  flags.DEFINE_string(
      'graph_schema',
      None,
      'Path to a text-formatted GraphSchema proto file or directory '
      'containing one for a graph in Universal Graph Format. This '
      'defines the input graph to be sampled.',
  )

  flags.DEFINE_string(
      'data_path',
      None,
      'Path to data files for node and edge sets. Defaults to the directory '
      'containing graph_schema.',
  )

  flags.DEFINE_string(
      'input_seeds',
      None,
      'Path to an input file with the seed node ids to restrict sampling '
      'over. The file can be in any of the supported unigraph table formats, '
      "and as for node sets, the 'id' column will be used. If the seeds "
      "aren't specified, the full set of nodes from the graph will be used "
      '(optional).',
  )

  flags.DEFINE_string(
      'sampling_spec',
      None,
      'An input file with a text-formatted SamplingSpec proto to use. This is '
      "a required input and to some extent may mirror some of the schema's "
      'structure. See `sampling_spec.proto` for details on the configuration.',
  )

  flags.DEFINE_string(
      'output_samples',
      None,
      'Output file with serialized graph tensor Example protos.',
  )

  runner_choices = [_DIRECT_RUNNER, _DATAFLOW_RUNNER]
  runner_choices.append('flume')
  flags.DEFINE_enum(
      'runner',
      None,
      runner_choices,
      'The underlying runner; if not specified, use the default runner.',
  )

  flags.mark_flags_as_required(
      ['graph_schema', 'sampling_spec', 'output_samples']
  )


def app_main(argv) -> None:
  """Main sampler entrypoint.

  Args:
    argv: List of arguments passed by the flags parser.
  """
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name
  pipeline_args = argv[1:]
  graph_schema: tfgnn.GraphSchema = unigraph.read_schema(FLAGS.graph_schema)

  with tf.io.gfile.GFile(FLAGS.sampling_spec, 'r') as f:
    sampling_spec = text_format.Parse(
        f.read(), sampling_spec_pb2.SamplingSpec()
    )
  # We have a graph schema, which defines the graph, and a sampling spec,
  # which defines how to sample from it, in the V1 format.
  # 1. Define the sampling model as a TF Keras model. Example:
  #      model = get_sampling_model(mag_graph_schema, mag_sampling_spec)
  #      model(tf.ragged.constant([[0], [1]]))
  #      # Returns a GraphTensor for seed papers 0 and 1.
  model = get_sampling_model(graph_schema, sampling_spec)
  # Export the sampling model as a "sampling program". Its `eval_dag` defines
  # the Beam stages to run; the artifacts are TF models for some of those
  # stages.
  program_pb, artifacts = sampler.create_program(model)

  if not FLAGS.data_path:
    data_path = os.path.dirname(FLAGS.graph_schema)
  else:
    data_path = FLAGS.data_path

  output_dir = os.path.dirname(FLAGS.output_samples)
  artifacts_path = os.path.join(output_dir, 'artifacts')
  if tf.io.gfile.exists(artifacts_path):
    raise ValueError(f'{artifacts_path} already exists.')

  tf.io.gfile.makedirs(artifacts_path)
  save_artifacts(artifacts, artifacts_path)

  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = True

  with beam.Pipeline(
      runner=_create_beam_runner(FLAGS.runner), options=pipeline_options
  ) as root:
    feeds = (root
             | unigraph_utils.ReadAndConvertUnigraph(graph_schema, data_path))
    if FLAGS.input_seeds:
      seeds = unigraph_utils.read_seeds(root, FLAGS.input_seeds)
    else:
      seeds = unigraph_utils.seeds_from_graph_dict(feeds)
    inputs = {
        'Input': seeds,
    }
    examples = executor_lib.execute(
        program_pb,
        inputs,
        feeds=feeds,
        artifacts_path=artifacts_path,
    )
    # Results are (example_id, tf.Example) pairs, where the tf.Example holds
    # the serialized graph tensor.
    coder = beam.coders.ProtoCoder(tf.train.Example)
    _ = (
        examples
        | 'DropExampleId' >> beam.Values()
        | 'WriteToTFRecord'
        >> beam.io.WriteToTFRecord(
            os.path.join(output_dir, 'examples.tfrecord'), coder=coder
        )
    )
  logging.info('Pipeline complete')


def main():
  define_flags()
  app.run(
      app_main, flags_parser=lambda argv: flags.FLAGS(argv, known_only=True)
  )


if __name__ == '__main__':
  main()
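As the in-code example suggests, the model returned by `get_sampling_model` can also be called eagerly on a ragged batch of seed ids, outside any Beam pipeline. A small sketch, with `get_sampling_model` as defined above (the schema and spec paths are hypothetical):

import tensorflow as tf
from google.protobuf import text_format
from tensorflow_gnn.data import unigraph
from tensorflow_gnn.sampler import sampling_spec_pb2

schema = unigraph.read_schema('/tmp/mag/schema.pbtxt')
with tf.io.gfile.GFile('/tmp/mag/sampling_spec.pbtxt', 'r') as f:
  spec = text_format.Parse(f.read(), sampling_spec_pb2.SamplingSpec())

model = get_sampling_model(schema, spec)
# Returns a GraphTensor of rooted subgraphs for seed node ids 0 and 1.
graph = model(tf.ragged.constant([[0], [1]]))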
