diff --git a/tsl/profiler/lib/BUILD b/tsl/profiler/lib/BUILD index a3c07948d..8aa45db7b 100644 --- a/tsl/profiler/lib/BUILD +++ b/tsl/profiler/lib/BUILD @@ -207,6 +207,8 @@ cc_library( "@xla//xla/tsl/profiler:internal", ]), deps = [ + "//tensorflow/core/profiler/utils:xplane_builder", + "//tensorflow/core/profiler/utils:xplane_utils", "//tsl/platform:thread_annotations", "//tsl/profiler/protobuf:profiler_options_proto_cc", "//tsl/profiler/protobuf:xplane_proto_cc", @@ -216,6 +218,7 @@ cc_library( "@xla//xla/tsl/platform:errors", "@xla//xla/tsl/platform:logging", "@xla//xla/tsl/platform:types", + "@xla//xla/tsl/profiler/utils:xplane_schema", ] + if_not_android([ ":profiler_collection", ":profiler_factory", diff --git a/tsl/profiler/lib/profiler_session.cc b/tsl/profiler/lib/profiler_session.cc index ea677e5fa..bcec25fb1 100644 --- a/tsl/profiler/lib/profiler_session.cc +++ b/tsl/profiler/lib/profiler_session.cc @@ -23,6 +23,9 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "third_party/tensorflow/core/profiler/utils/xplane_builder.h" +#include "third_party/tensorflow/core/profiler/utils/xplane_utils.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" @@ -41,6 +44,7 @@ namespace { using tensorflow::ProfileOptions; using tensorflow::profiler::XSpace; +using ::tsl::profiler::kProfileOptionsName; ProfileOptions GetOptions(const ProfileOptions& opts) { if (opts.version()) return opts; @@ -49,6 +53,14 @@ ProfileOptions GetOptions(const ProfileOptions& opts) { return options; } +void SetProfileOptionsIntoSpace(const ProfileOptions& options, XSpace* space) { + tensorflow::profiler::XPlaneBuilder xplane( + tensorflow::profiler::FindOrAddMutablePlaneWithName( + space, tsl::profiler::kMetadataPlaneName)); + xplane.AddStatValue(*xplane.GetOrCreateStatMetadata(kProfileOptionsName), + options.DebugString()); +} + }; // namespace /*static*/ std::unique_ptr ProfilerSession::Create( @@ -84,6 +96,7 @@ absl::Status ProfilerSession::CollectData(XSpace* space) { TF_RETURN_IF_ERROR(CollectDataInternal(space)); profiler::PostProcessSingleHostXSpace(space, start_time_ns_, stop_time_ns_); #endif + SetProfileOptionsIntoSpace(options_, space); return absl::OkStatus(); }