diff --git a/tensorflow_lite_support/ios/task/text/BUILD b/tensorflow_lite_support/ios/task/text/BUILD new file mode 100644 index 000000000..101824060 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/BUILD @@ -0,0 +1,28 @@ +package( + default_visibility = ["//tensorflow_lite_support:internal"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLTextSearcher", + srcs = [ + "sources/TFLTextSearcher.mm", + ], + hdrs = [ + "sources/TFLTextSearcher.h", + ], + copts = [ + "-ObjC++", + "-std=c++17", + ], + features = ["-layering_check"], + module_name = "TFLTextSearcher", + deps = [ + "//tensorflow_lite_support/cc/task/text:text_searcher", + "//tensorflow_lite_support/ios:TFLCommonUtils", + "//tensorflow_lite_support/ios/task/core:TFLBaseOptionsCppHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLEmbeddingOptionsHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLSearchOptionsHelpers", + "//tensorflow_lite_support/ios/task/processor:TFLSearchResultHelpers", + ], +) diff --git a/tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h b/tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h new file mode 100644 index 000000000..267b322f7 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h @@ -0,0 +1,104 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h" +#import "tensorflow_lite_support/ios/task/processor/sources/TFLEmbeddingOptions.h" +#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchOptions.h" +#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * Options to configure TFLTextSearcher. + */ +NS_SWIFT_NAME(TextSearcherOptions) +@interface TFLTextSearcherOptions : NSObject + +/** + * Base options for configuring the TextSearcher. This specifies the TFLite + * model to use for embedding extraction, as well as hardware acceleration + * options to use as inference time. + */ +@property(nonatomic, copy) TFLBaseOptions *baseOptions; + +/** + * Options controlling the behavior of the embedding model specified in the + * base options. + */ +@property(nonatomic, copy) TFLEmbeddingOptions *embeddingOptions; + +/** + * Options specifying the index to search into and controlling the search behavior. + */ +@property(nonatomic, copy) TFLSearchOptions *searchOptions; + +/** + * Initializes a new `TFLTextSearcherOptions` with the absolute path to the model file + * stored locally on the device, set to the given the model path. + * + * @discussion The external model file must be a single standalone TFLite file. It could be packed + * with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the + * necessary metadata and associated files might result in errors. Check the [documentation] + * (https://www.tensorflow.org/lite/convert/metadata) for each task about the specific requirement. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * + * @return An instance of `TFLTextSearcherOptions` initialized to the given model path. + */ +- (instancetype)initWithModelPath:(NSString *)modelPath; + +@end + +/** + * A TensorFlow Lite Task Text Searcher. + */ +NS_SWIFT_NAME(TextSearcher) +@interface TFLTextSearcher : NSObject + +/** + * Creates a new instance of `TFLTextSearcher` from the given `TFLTextSearcherOptions`. + * + * @param options The options to use for configuring the `TFLTextSearcher`. + * @param error An optional error parameter populated when there is an error in initializing + * the text searcher. + * + * @return A new instance of `TextSearcher` with the given options. `nil` if there is an error + * in initializing the text searcher. + */ ++ (nullable instancetype)textSearcherWithOptions:(TFLTextSearcherOptions *)options + error:(NSError **)error + NS_SWIFT_NAME(searcher(options:)); + ++ (instancetype)new NS_UNAVAILABLE; + +/** + * Performs embedding extraction on the given text, followed by nearest-neighbor search in the + * index. + * + * @param text An string on which embedding extraction is to be performed, followed by + * nearest-neighbor search in the index. + * + * @return A `TFLSearchResult`. `nil` if there is an error encountered during embedding extraction + * and nearest neighbor search. Please see `TFLSearchResult` for more details. + */ +- (nullable TFLSearchResult *)searchWithText:(NSString *)text + error:(NSError **)error NS_SWIFT_NAME(search(text:)); + +- (instancetype)init NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.mm b/tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.mm new file mode 100644 index 000000000..d1cc50286 --- /dev/null +++ b/tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.mm @@ -0,0 +1,114 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import "tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h" +#import "tensorflow_lite_support/ios/sources/TFLCommon.h" +#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h" +#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions+CppHelpers.h" +#import "tensorflow_lite_support/ios/task/processor/sources/TFLEmbeddingOptions+Helpers.h" +#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchOptions+Helpers.h" +#import "tensorflow_lite_support/ios/task/processor/sources/TFLSearchResult+Helpers.h" + +#include "tensorflow_lite_support/cc/task/text/text_searcher.h" + +namespace { +using TextSearcherCpp = ::tflite::task::text::TextSearcher; +using TextSearcherOptionsCpp = ::tflite::task::text::TextSearcherOptions; +using SearchResultCpp = ::tflite::task::processor::SearchResult; +using ::tflite::support::StatusOr; +} // namespace + +@interface TFLTextSearcher () { + /** TextSearcher backed by C++ API */ + std::unique_ptr _cppTextSearcher; +} +@end + +@implementation TFLTextSearcherOptions + +- (instancetype)init { + self = [super init]; + if (self) { + _baseOptions = [[TFLBaseOptions alloc] init]; + _embeddingOptions = [[TFLEmbeddingOptions alloc] init]; + _searchOptions = [[TFLSearchOptions alloc] init]; + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath { + self = [self init]; + if (self) { + _baseOptions.modelFile.filePath = modelPath; + } + return self; +} + +- (TextSearcherOptionsCpp)cppOptions { + TextSearcherOptionsCpp cppOptions = {}; + [self.baseOptions copyToCppOptions:cppOptions.mutable_base_options()]; + [self.embeddingOptions copyToCppOptions:cppOptions.mutable_embedding_options()]; + [self.searchOptions copyToCppOptions:cppOptions.mutable_search_options()]; + + return cppOptions; +} + +@end + +@implementation TFLTextSearcher + +- (nullable instancetype)initWithCppTextSearcherOptions:(TextSearcherOptionsCpp)cppOptions { + self = [super init]; + if (self) { + StatusOr> cppTextSearcher = + TextSearcherCpp::CreateFromOptions(cppOptions); + if (cppTextSearcher.ok()) { + _cppTextSearcher = std::move(cppTextSearcher.value()); + } else { + return nil; + } + } + return self; +} + ++ (nullable instancetype)textSearcherWithOptions:(TFLTextSearcherOptions *)options + error:(NSError **)error { + if (!options) { + [TFLCommonUtils createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description:@"TFLTextSearcherOptions argument cannot be nil."]; + return nil; + } + + TextSearcherOptionsCpp cppOptions = [options cppOptions]; + + return [[TFLTextSearcher alloc] initWithCppTextSearcherOptions:cppOptions]; +} + +- (nullable TFLSearchResult *)searchWithText:(NSString *)text error:(NSError **)error { + if (!text) { + [TFLCommonUtils createCustomError:error + withCode:TFLSupportErrorCodeInvalidArgumentError + description:@"GMLImage argument cannot be nil."]; + return nil; + } + + std::string cppTextToBeSearched = std::string(text.UTF8String, [text lengthOfBytesUsingEncoding:NSUTF8StringEncoding]); + StatusOr cppSearchResultStatus = _cppTextSearcher->Search( + cppTextToBeSearched); + + return [TFLSearchResult searchResultWithCppResult:cppSearchResultStatus error:error]; +} + +@end diff --git a/tensorflow_lite_support/ios/test/task/text/text_searcher/BUILD b/tensorflow_lite_support/ios/test/task/text/text_searcher/BUILD new file mode 100644 index 000000000..3b4197d77 --- /dev/null +++ b/tensorflow_lite_support/ios/test/task/text/text_searcher/BUILD @@ -0,0 +1,31 @@ +load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION") +load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") +load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner") + +package( + default_visibility = ["//visibility:private"], + licenses = ["notice"], # Apache 2.0 +) + +objc_library( + name = "TFLTextSearcherObjcTestLibrary", + testonly = 1, + srcs = ["TFLTextSearcherTests.m"], + data = [ + "//tensorflow_lite_support/cc/test/testdata/task/text:test_searchers", + ], + tags = TFL_DEFAULT_TAGS, + deps = [ + "//tensorflow_lite_support/ios/task/text:TFLTextSearcher", + ], +) + +ios_unit_test( + name = "TFLTextSearcherObjcTest", + minimum_os_version = TFL_MINIMUM_OS_VERSION, + runner = tflite_ios_lab_runner("IOS_LATEST"), + tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS, + deps = [ + ":TFLTextSearcherObjcTestLibrary", + ], +) diff --git a/tensorflow_lite_support/ios/test/task/text/text_searcher/TFLTextSearcherTests.m b/tensorflow_lite_support/ios/test/task/text/text_searcher/TFLTextSearcherTests.m new file mode 100644 index 000000000..282fd2789 --- /dev/null +++ b/tensorflow_lite_support/ios/test/task/text/text_searcher/TFLTextSearcherTests.m @@ -0,0 +1,90 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#import + +#import "tensorflow_lite_support/ios/task/text/sources/TFLTextSearcher.h" + +NS_ASSUME_NONNULL_BEGIN + +#define ValidateSearchResultCount(searchResult, expectedNearestNeighborsCount) \ + XCTAssertEqual(searchResult.nearestNeighbors.count, expectedNearestNeighborsCount); + +#define ValidateNearestNeighbor(nearestNeighbor, expectedMetadata, expectedDistance) \ + XCTAssertEqualObjects(nearestNeighbor.metadata, expectedMetadata); \ + XCTAssertEqualWithAccuracy(nearestNeighbor.distance, expectedDistance, 1e-6); + +@interface TFLTextSearcherTests : XCTestCase +@property(nonatomic, nullable) NSString *modelPath; +@end + +@implementation TFLTextSearcherTests + +- (void)setUp { + [super setUp]; + self.modelPath = + [[NSBundle bundleForClass:self.class] pathForResource:@"regex_searcher" + ofType:@"tflite"]; + XCTAssertNotNil(self.modelPath); +} + +- (TFLTextSearcher *)textSearcherWithSearcherModelPath:(NSString *)modelPath { + TFLTextSearcherOptions *textSearcherOptions = + [[TFLTextSearcherOptions alloc] initWithModelPath:self.modelPath]; + + TFLTextSearcher *textSearcher = [TFLTextSearcher textSearcherWithOptions:textSearcherOptions + error:nil]; + XCTAssertNotNil(textSearcher); + + return textSearcher; +} + +- (void)validateSearchResultForInferenceWithSearchContent:(TFLSearchResult *)searchResult { + ValidateSearchResultCount(searchResult, + 5 // expectedNearestNeighborsCount + ); + + ValidateNearestNeighbor(searchResult.nearestNeighbors[0], + @"The weather was excellent.", // expectedMetadata + 0.889664649963 // expectedDistance + ); + ValidateNearestNeighbor(searchResult.nearestNeighbors[1], + @"The sun was shining on that day.", // expectedMetadata + 0.889667928219 // expectedDistance + ); + ValidateNearestNeighbor(searchResult.nearestNeighbors[2], + @"The cat is chasing after the mouse.", // expectedMetadata + 0.889669716358 // expectedDistance + ); + ValidateNearestNeighbor(searchResult.nearestNeighbors[3], + @"It was a sunny day.", // expectedMetadata + 0.889671087265 // expectedDistance + ); + ValidateNearestNeighbor(searchResult.nearestNeighbors[4], + @"He was very happy with his newly bought car.", // expectedMetadata + 0.889671683311 // expectedDistance + ); +} + +- (void)testSuccessfullInferenceWithSearchContentOnText { + TFLTextSearcher *textSearcher = + [self textSearcherWithSearcherModelPath:self.modelPath]; + + TFLSearchResult *searchResult = [textSearcher searchWithText:@"The weather was excellent." error:nil]; + [self validateSearchResultForInferenceWithSearchContent:searchResult]; +} + +@end + +NS_ASSUME_NONNULL_END diff --git a/tensorflow_lite_support/ios/test/task/vision/image_searcher/TFLImageSearcherTests.m b/tensorflow_lite_support/ios/test/task/vision/image_searcher/TFLImageSearcherTests.m index ff2818faa..f3679ebb3 100644 --- a/tensorflow_lite_support/ios/test/task/vision/image_searcher/TFLImageSearcherTests.m +++ b/tensorflow_lite_support/ios/test/task/vision/image_searcher/TFLImageSearcherTests.m @@ -294,4 +294,4 @@ - (void)testImageSearcherWithEmbedderModelAndInvalidIndexFileFails { @end -NS_ASSUME_NONNULL_END +NS_ASSUME_NONNULL_END \ No newline at end of file