diff --git a/FirebaseAI/Sources/AILog.swift b/FirebaseAI/Sources/AILog.swift index 52b44bf7c01..345451bf07f 100644 --- a/FirebaseAI/Sources/AILog.swift +++ b/FirebaseAI/Sources/AILog.swift @@ -87,6 +87,7 @@ enum AILog { case generateContentResponseEmptyCandidates = 4003 case invalidWebsocketURL = 4004 case duplicateLiveSessionSetupComplete = 4005 + case malformedURL = 4006 // SDK Debugging case loadRequestStreamResponseLine = 5000 @@ -138,6 +139,17 @@ enum AILog { log(level: .debug, code: code, message) } + @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) + static func makeInternalError(message: String, code: MessageCode) -> GenerateContentError { + let error = GenerateContentError.internalError(underlying: NSError( + domain: "\(Constants.baseErrorDomain).Internal", + code: code.rawValue, + userInfo: [NSLocalizedDescriptionKey: message] + )) + AILog.error(code: code, message) + return error + } + /// Returns `true` if additional logging has been enabled via a launch argument. static func additionalLoggingEnabled() -> Bool { return ProcessInfo.processInfo.arguments.contains(enableArgumentKey) diff --git a/FirebaseAI/Sources/Chat.swift b/FirebaseAI/Sources/Chat.swift index 80e908a8f57..99c6fb13367 100644 --- a/FirebaseAI/Sources/Chat.swift +++ b/FirebaseAI/Sources/Chat.swift @@ -19,35 +19,21 @@ import Foundation @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public final class Chat: Sendable { private let model: GenerativeModel + private let _history: History - /// Initializes a new chat representing a 1:1 conversation between model and user. init(model: GenerativeModel, history: [ModelContent]) { self.model = model - self.history = history + _history = History(history: history) } - private let historyLock = NSLock() - private nonisolated(unsafe) var _history: [ModelContent] = [] /// The previous content from the chat that has been successfully sent and received from the /// model. This will be provided to the model for each message sent as context for the discussion. public var history: [ModelContent] { get { - historyLock.withLock { _history } + return _history.history } set { - historyLock.withLock { _history = newValue } - } - } - - private func appendHistory(contentsOf: [ModelContent]) { - historyLock.withLock { - _history.append(contentsOf: contentsOf) - } - } - - private func appendHistory(_ newElement: ModelContent) { - historyLock.withLock { - _history.append(newElement) + _history.history = newValue } } @@ -87,8 +73,8 @@ public final class Chat: Sendable { let toAdd = ModelContent(role: "model", parts: reply.parts) // Append the request and successful result to history, then return the value. - appendHistory(contentsOf: newContent) - appendHistory(toAdd) + _history.append(contentsOf: newContent) + _history.append(toAdd) return result } @@ -136,63 +122,16 @@ public final class Chat: Sendable { } // Save the request. - appendHistory(contentsOf: newContent) + _history.append(contentsOf: newContent) // Aggregate the content to add it to the history before we finish. 
- let aggregated = self.aggregatedChunks(aggregatedContent) - self.appendHistory(aggregated) + let aggregated = self._history.aggregatedChunks(aggregatedContent) + self._history.append(aggregated) continuation.finish() } } } - private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent { - var parts: [InternalPart] = [] - var combinedText = "" - var combinedThoughts = "" - - func flush() { - if !combinedThoughts.isEmpty { - parts.append(InternalPart(.text(combinedThoughts), isThought: true, thoughtSignature: nil)) - combinedThoughts = "" - } - if !combinedText.isEmpty { - parts.append(InternalPart(.text(combinedText), isThought: nil, thoughtSignature: nil)) - combinedText = "" - } - } - - // Loop through all the parts, aggregating the text. - for part in chunks.flatMap({ $0.internalParts }) { - // Only text parts may be combined. - if case let .text(text) = part.data, part.thoughtSignature == nil { - // Thought summaries must not be combined with regular text. - if part.isThought ?? false { - // If we were combining regular text, flush it before handling "thoughts". - if !combinedText.isEmpty { - flush() - } - combinedThoughts += text - } else { - // If we were combining "thoughts", flush it before handling regular text. - if !combinedThoughts.isEmpty { - flush() - } - combinedText += text - } - } else { - // This is a non-combinable part (not text), flush any pending text. - flush() - parts.append(part) - } - } - - // Flush any remaining text. - flush() - - return ModelContent(role: "model", parts: parts) - } - /// Populates the `role` field with `user` if it doesn't exist. Required in chat sessions. private func populateContentRole(_ content: ModelContent) -> ModelContent { if content.role != nil { diff --git a/FirebaseAI/Sources/FirebaseAI.swift b/FirebaseAI/Sources/FirebaseAI.swift index 354c16b79ab..40cf38590cf 100644 --- a/FirebaseAI/Sources/FirebaseAI.swift +++ b/FirebaseAI/Sources/FirebaseAI.swift @@ -135,6 +135,28 @@ public final class FirebaseAI: Sendable { ) } + /// Initializes a new `TemplateGenerativeModel`. + /// + /// - Returns: A new `TemplateGenerativeModel` instance. + public func templateGenerativeModel() -> TemplateGenerativeModel { + return TemplateGenerativeModel( + generativeAIService: GenerativeAIService(firebaseInfo: firebaseInfo, + urlSession: GenAIURLSession.default), + apiConfig: apiConfig + ) + } + + /// Initializes a new `TemplateImagenModel`. + /// + /// - Returns: A new `TemplateImagenModel` instance. + public func templateImagenModel() -> TemplateImagenModel { + return TemplateImagenModel( + generativeAIService: GenerativeAIService(firebaseInfo: firebaseInfo, + urlSession: GenAIURLSession.default), + apiConfig: apiConfig + ) + } + /// **[Public Preview]** Initializes a ``LiveGenerativeModel`` with the given parameters. /// /// - Note: Refer to [the Firebase docs on the Live diff --git a/FirebaseAI/Sources/GenerateContentRequest.swift b/FirebaseAI/Sources/GenerateContentRequest.swift index 21acd502a75..bc4e9797760 100644 --- a/FirebaseAI/Sources/GenerateContentRequest.swift +++ b/FirebaseAI/Sources/GenerateContentRequest.swift @@ -73,15 +73,23 @@ extension GenerateContentRequest { extension GenerateContentRequest: GenerativeAIRequest { typealias Response = GenerateContentResponse - var url: URL { + func getURL() throws -> URL { let modelURL = "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model)" + let urlString: String switch apiMethod { case .generateContent: - return URL(string: "\(modelURL):\(apiMethod.rawValue)")! 
+ urlString = "\(modelURL):\(apiMethod.rawValue)" case .streamGenerateContent: - return URL(string: "\(modelURL):\(apiMethod.rawValue)?alt=sse")! + urlString = "\(modelURL):\(apiMethod.rawValue)?alt=sse" case .countTokens: - fatalError("\(Self.self) should be a property of \(CountTokensRequest.self).") + throw AILog.makeInternalError( + message: "\(Self.self) should be a property of \(CountTokensRequest.self).", + code: .malformedURL + ) } + guard let url = URL(string: urlString) else { + throw AILog.makeInternalError(message: "Malformed URL: \(urlString)", code: .malformedURL) + } + return url } } diff --git a/FirebaseAI/Sources/GenerativeAIRequest.swift b/FirebaseAI/Sources/GenerativeAIRequest.swift index 148e989db40..192de607137 100644 --- a/FirebaseAI/Sources/GenerativeAIRequest.swift +++ b/FirebaseAI/Sources/GenerativeAIRequest.swift @@ -18,7 +18,7 @@ import Foundation protocol GenerativeAIRequest: Sendable, Encodable { associatedtype Response: Sendable, Decodable - var url: URL { get } + func getURL() throws -> URL var options: RequestOptions { get } } diff --git a/FirebaseAI/Sources/GenerativeAIService.swift b/FirebaseAI/Sources/GenerativeAIService.swift index a17364f8cb6..ed385f942a0 100644 --- a/FirebaseAI/Sources/GenerativeAIService.swift +++ b/FirebaseAI/Sources/GenerativeAIService.swift @@ -26,7 +26,7 @@ struct GenerativeAIService { /// The Firebase SDK version in the format `fire/`. static let firebaseVersionTag = "fire/\(FirebaseVersion())" - private let firebaseInfo: FirebaseInfo + let firebaseInfo: FirebaseInfo private let urlSession: URLSession @@ -167,7 +167,7 @@ struct GenerativeAIService { // MARK: - Private Helpers private func urlRequest(request: T) async throws -> URLRequest { - var urlRequest = URLRequest(url: request.url) + var urlRequest = try URLRequest(url: request.getURL()) urlRequest.httpMethod = "POST" urlRequest.setValue(firebaseInfo.apiKey, forHTTPHeaderField: "x-goog-api-key") urlRequest.setValue( diff --git a/FirebaseAI/Sources/History.swift b/FirebaseAI/Sources/History.swift new file mode 100644 index 00000000000..827f7df5b46 --- /dev/null +++ b/FirebaseAI/Sources/History.swift @@ -0,0 +1,94 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class History: Sendable { + private let historyLock = NSLock() + private nonisolated(unsafe) var _history: [ModelContent] = [] + /// The previous content from the chat that has been successfully sent and received from the + /// model. This will be provided to the model for each message sent as context for the discussion. 
+ public var history: [ModelContent] { + get { + historyLock.withLock { _history } + } + set { + historyLock.withLock { _history = newValue } + } + } + + init(history: [ModelContent]) { + self.history = history + } + + func append(contentsOf: [ModelContent]) { + historyLock.withLock { + _history.append(contentsOf: contentsOf) + } + } + + func append(_ newElement: ModelContent) { + historyLock.withLock { + _history.append(newElement) + } + } + + func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent { + var parts: [InternalPart] = [] + var combinedText = "" + var combinedThoughts = "" + + func flush() { + if !combinedThoughts.isEmpty { + parts.append(InternalPart(.text(combinedThoughts), isThought: true, thoughtSignature: nil)) + combinedThoughts = "" + } + if !combinedText.isEmpty { + parts.append(InternalPart(.text(combinedText), isThought: nil, thoughtSignature: nil)) + combinedText = "" + } + } + + // Loop through all the parts, aggregating the text. + for part in chunks.flatMap({ $0.internalParts }) { + // Only text parts may be combined. + if case let .text(text) = part.data, part.thoughtSignature == nil { + // Thought summaries must not be combined with regular text. + if part.isThought ?? false { + // If we were combining regular text, flush it before handling "thoughts". + if !combinedText.isEmpty { + flush() + } + combinedThoughts += text + } else { + // If we were combining "thoughts", flush it before handling regular text. + if !combinedThoughts.isEmpty { + flush() + } + combinedText += text + } + } else { + // This is a non-combinable part (not text), flush any pending text. + flush() + parts.append(part) + } + } + + // Flush any remaining text. + flush() + + return ModelContent(role: "model", parts: parts) + } +} diff --git a/FirebaseAI/Sources/TemplateChatSession.swift b/FirebaseAI/Sources/TemplateChatSession.swift new file mode 100644 index 00000000000..abba669a1dd --- /dev/null +++ b/FirebaseAI/Sources/TemplateChatSession.swift @@ -0,0 +1,176 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +// TODO: Restore `public` to class and methods when determined to be releaseable. + +/// A chat session that allows for conversation with a model. +/// +/// **Public Preview**: This API is a public preview and may be subject to change. +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class TemplateChatSession: Sendable { + private let model: TemplateGenerativeModel + private let templateID: String + private let _history: History + + init(model: TemplateGenerativeModel, templateID: String, history: [ModelContent]) { + self.model = model + self.templateID = templateID + _history = History(history: history) + } + + public var history: [ModelContent] { + get { + return _history.history + } + set { + _history.history = newValue + } + } + + /// Sends a message to the model and returns the response. 
+ /// + /// **Public Preview**: This API is a public preview and may be subject to change. + /// + /// - Parameters: + /// - content: The message to send to the model. + /// - inputs: A dictionary of variables to substitute into the template. + /// - options: The ``RequestOptions`` for the request, currently used to override default + /// request timeout. + /// - Returns: The content generated by the model. + /// - Throws: A ``GenerateContentError`` if the request failed. + func sendMessage(_ content: [ModelContent], + inputs: [String: Any], + options: RequestOptions = RequestOptions()) async throws + -> GenerateContentResponse { + let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) } + let newContent = content.map(populateContentRole) + let response = try await model.generateContentWithHistory( + history: _history.history + newContent, + template: templateID, + inputs: templateInputs, + options: options + ) + _history.append(contentsOf: newContent) + if let modelResponse = response.candidates.first { + _history.append(modelResponse.content) + } + return response + } + + /// Sends a message to the model and returns the response. + /// + /// **Public Preview**: This API is a public preview and may be subject to change. + /// + /// - Parameters: + /// - message: The message to send to the model. + /// - inputs: A dictionary of variables to substitute into the template. + /// - options: The ``RequestOptions`` for the request, currently used to override default + /// request timeout. + /// - Returns: The content generated by the model. + /// - Throws: A ``GenerateContentError`` if the request failed. + func sendMessage(_ message: any PartsRepresentable, + inputs: [String: Any], + options: RequestOptions = RequestOptions()) async throws + -> GenerateContentResponse { + return try await sendMessage([ModelContent(parts: message.partsValue)], + inputs: inputs, + options: options) + } + + /// Sends a message to the model and returns the response as a stream of + /// `GenerateContentResponse`s. + /// + /// **Public Preview**: This API is a public preview and may be subject to change. + /// + /// - Parameters: + /// - content: The message to send to the model. + /// - inputs: A dictionary of variables to substitute into the template. + /// - options: The ``RequestOptions`` for the request, currently used to override default + /// request timeout. + /// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects. + /// - Throws: A ``GenerateContentError`` if the request failed. + func sendMessageStream(_ content: [ModelContent], + inputs: [String: Any], + options: RequestOptions = RequestOptions()) throws + -> AsyncThrowingStream { + let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) } + let newContent = content.map(populateContentRole) + let stream = try model.generateContentStreamWithHistory( + history: _history.history + newContent, + template: templateID, + inputs: templateInputs, + options: options + ) + return AsyncThrowingStream { continuation in + Task { + var aggregatedContent: [ModelContent] = [] + + do { + for try await chunk in stream { + // Capture any content that's streaming. This should be populated if there's no error. + if let chunkContent = chunk.candidates.first?.content { + aggregatedContent.append(chunkContent) + } + + // Pass along the chunk. + continuation.yield(chunk) + } + } catch { + // Rethrow the error that the underlying stream threw. Don't add anything to history. 
+ continuation.finish(throwing: error) + return + } + + // Save the request. + _history.append(contentsOf: newContent) + + // Aggregate the content to add it to the history before we finish. + let aggregated = _history.aggregatedChunks(aggregatedContent) + _history.append(aggregated) + continuation.finish() + } + } + } + + /// Sends a message to the model and returns the response as a stream of + /// `GenerateContentResponse`s. + /// + /// **Public Preview**: This API is a public preview and may be subject to change. + /// + /// - Parameters: + /// - message: The message to send to the model. + /// - inputs: A dictionary of variables to substitute into the template. + /// - options: The ``RequestOptions`` for the request, currently used to override default + /// request timeout. + /// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects. + /// - Throws: A ``GenerateContentError`` if the request failed. + func sendMessageStream(_ message: any PartsRepresentable, + inputs: [String: Any], + options: RequestOptions = RequestOptions()) throws + -> AsyncThrowingStream { + return try sendMessageStream([ModelContent(parts: message.partsValue)], + inputs: inputs, + options: options) + } + + private func populateContentRole(_ content: ModelContent) -> ModelContent { + if content.role != nil { + return content + } else { + return ModelContent(role: "user", parts: content.parts) + } + } +} diff --git a/FirebaseAI/Sources/TemplateGenerateContentRequest.swift b/FirebaseAI/Sources/TemplateGenerateContentRequest.swift new file mode 100644 index 00000000000..20ba84b3571 --- /dev/null +++ b/FirebaseAI/Sources/TemplateGenerateContentRequest.swift @@ -0,0 +1,63 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import Foundation + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct TemplateGenerateContentRequest: Sendable { + let template: String + let inputs: [String: TemplateInput] + let history: [ModelContent] + let projectID: String + let stream: Bool + let apiConfig: APIConfig + let options: RequestOptions +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension TemplateGenerateContentRequest: Encodable { + enum CodingKeys: String, CodingKey { + case inputs + case history + } + + func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(inputs, forKey: .inputs) + try container.encode(history, forKey: .history) + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension TemplateGenerateContentRequest: GenerativeAIRequest { + typealias Response = GenerateContentResponse + + func getURL() throws -> URL { + var urlString = + "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/projects/\(projectID)" + if case let .vertexAI(_, location) = apiConfig.service { + urlString += "/locations/\(location)" + } + + if stream { + urlString += "/templates/\(template):templateStreamGenerateContent?alt=sse" + } else { + urlString += "/templates/\(template):templateGenerateContent" + } + guard let url = URL(string: urlString) else { + throw AILog.makeInternalError(message: "Malformed URL: \(urlString)", code: .malformedURL) + } + return url + } +} diff --git a/FirebaseAI/Sources/TemplateGenerativeModel.swift b/FirebaseAI/Sources/TemplateGenerativeModel.swift new file mode 100644 index 00000000000..bf727021c0f --- /dev/null +++ b/FirebaseAI/Sources/TemplateGenerativeModel.swift @@ -0,0 +1,141 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +/// A type that represents a remote multimodal model (like Gemini), with the ability to generate +/// content based on various input types. +/// +/// **Public Preview**: This API is a public preview and may be subject to change. +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public final class TemplateGenerativeModel: Sendable { + let generativeAIService: GenerativeAIService + let apiConfig: APIConfig + + init(generativeAIService: GenerativeAIService, apiConfig: APIConfig) { + self.generativeAIService = generativeAIService + self.apiConfig = apiConfig + } + + /// Generates content from a prompt template and inputs. + /// + /// **Public Preview**: This API is a public preview and may be subject to change. + /// + /// - Parameters: + /// - templateID: The ID of the prompt template to use. + /// - inputs: A dictionary of variables to substitute into the template. + /// - options: The ``RequestOptions`` for the request, currently used to override default + /// request timeout. + /// - Returns: The content generated by the model. 
+  /// - Throws: A ``GenerateContentError`` if the request failed.
+  public func generateContent(templateID: String,
+                              inputs: [String: Any],
+                              options: RequestOptions = RequestOptions()) async throws
+    -> GenerateContentResponse {
+    let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
+    return try await generateContentWithHistory(
+      history: [],
+      template: templateID,
+      inputs: templateInputs,
+      options: options
+    )
+  }
+
+  /// Generates content from a prompt template, inputs, and history.
+  ///
+  /// - Parameters:
+  ///   - history: The conversation history to use.
+  ///   - template: The prompt template to use.
+  ///   - inputs: A dictionary of variables to substitute into the template.
+  /// - Returns: The content generated by the model.
+  /// - Throws: A ``GenerateContentError`` if the request failed.
+  func generateContentWithHistory(history: [ModelContent], template: String,
+                                  inputs: [String: TemplateInput],
+                                  options: RequestOptions = RequestOptions()) async throws
+    -> GenerateContentResponse {
+    let request = TemplateGenerateContentRequest(
+      template: template,
+      inputs: inputs,
+      history: history,
+      projectID: generativeAIService.firebaseInfo.projectID,
+      stream: false,
+      apiConfig: apiConfig,
+      options: options
+    )
+    let response: GenerateContentResponse = try await generativeAIService
+      .loadRequest(request: request)
+    return response
+  }
+
+  /// Generates content from a prompt template and inputs, with streaming responses.
+  ///
+  /// **Public Preview**: This API is a public preview and may be subject to change.
+  ///
+  /// - Parameters:
+  ///   - templateID: The ID of the prompt template to use.
+  ///   - inputs: A dictionary of variables to substitute into the template.
+  ///   - options: The ``RequestOptions`` for the request, currently used to override default
+  ///     request timeout.
+  /// - Returns: An `AsyncThrowingStream` that yields `GenerateContentResponse` objects.
+  /// - Throws: A ``GenerateContentError`` if the request failed.
+  public func generateContentStream(templateID: String,
+                                    inputs: [String: Any],
+                                    options: RequestOptions = RequestOptions()) throws
+    -> AsyncThrowingStream {
+    let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
+    let request = TemplateGenerateContentRequest(
+      template: templateID,
+      inputs: templateInputs,
+      history: [],
+      projectID: generativeAIService.firebaseInfo.projectID,
+      stream: true,
+      apiConfig: apiConfig,
+      options: options
+    )
+    return generativeAIService.loadRequestStream(request: request)
+  }
+
+  func generateContentStreamWithHistory(history: [ModelContent], template: String,
+                                        inputs: [String: TemplateInput],
+                                        options: RequestOptions = RequestOptions()) throws
+    -> AsyncThrowingStream {
+    let request = TemplateGenerateContentRequest(
+      template: template,
+      inputs: inputs,
+      history: history,
+      projectID: generativeAIService.firebaseInfo.projectID,
+      stream: true,
+      apiConfig: apiConfig,
+      options: options
+    )
+    return generativeAIService.loadRequestStream(request: request)
+  }
+
+  // TODO: Restore `public` when determined to be releasable, along with the contents of TemplateChatSession.
+
+  /// Creates a new chat conversation using this model with the provided history and template.
+  ///
+  /// - Parameters:
+  ///   - templateID: The ID of the prompt template to use.
+  ///   - history: The conversation history to use.
+  /// - Returns: A new ``TemplateChatSession`` instance.
+ func startChat(templateID: String, + history: [ModelContent] = []) -> TemplateChatSession { + return TemplateChatSession( + model: self, + templateID: templateID, + history: history + ) + } +} diff --git a/FirebaseAI/Sources/TemplateImagenGenerationRequest.swift b/FirebaseAI/Sources/TemplateImagenGenerationRequest.swift new file mode 100644 index 00000000000..c155b66fe55 --- /dev/null +++ b/FirebaseAI/Sources/TemplateImagenGenerationRequest.swift @@ -0,0 +1,67 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +enum ImageAPIMethod: String { + case generateImages = "templatePredict" +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +struct TemplateImagenGenerationRequest: Sendable { + typealias Response = ImagenGenerationResponse + + let template: String + let inputs: [String: TemplateInput] + let projectID: String + let apiConfig: APIConfig + let options: RequestOptions + + init(template: String, inputs: [String: TemplateInput], projectID: String, + apiConfig: APIConfig, options: RequestOptions) { + self.template = template + self.inputs = inputs + self.projectID = projectID + self.apiConfig = apiConfig + self.options = options + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension TemplateImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodable { + func getURL() throws -> URL { + var urlString = + "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/projects/\(projectID)" + if case let .vertexAI(_, location) = apiConfig.service { + urlString += "/locations/\(location)" + } + urlString += "/templates/\(template):\(ImageAPIMethod.generateImages.rawValue)" + guard let url = URL(string: urlString) else { + throw AILog.makeInternalError(message: "Malformed URL: \(urlString)", code: .malformedURL) + } + return url + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension TemplateImagenGenerationRequest: Encodable { + enum CodingKeys: String, CodingKey { + case inputs + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(inputs, forKey: .inputs) + } +} diff --git a/FirebaseAI/Sources/TemplateImagenModel.swift b/FirebaseAI/Sources/TemplateImagenModel.swift new file mode 100644 index 00000000000..794965364bd --- /dev/null +++ b/FirebaseAI/Sources/TemplateImagenModel.swift @@ -0,0 +1,56 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+/// A type that represents a remote image generation model (like Imagen), with the ability to
+/// generate
+/// images based on various input types.
+///
+/// **Public Preview**: This API is a public preview and may be subject to change.
+@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
+public final class TemplateImagenModel: Sendable {
+  let generativeAIService: GenerativeAIService
+  let apiConfig: APIConfig
+
+  init(generativeAIService: GenerativeAIService, apiConfig: APIConfig) {
+    self.generativeAIService = generativeAIService
+    self.apiConfig = apiConfig
+  }
+
+  /// Generates images from a prompt template and inputs.
+  ///
+  /// - Parameters:
+  ///   - templateID: The ID of the prompt template to use.
+  ///   - inputs: A dictionary of variables to substitute into the template.
+  ///   - options: The ``RequestOptions`` for the request, currently used to override default
+  ///     request timeout.
+  /// - Returns: The images generated by the model.
+  /// - Throws: An error if the request failed.
+  public func generateImages(templateID: String,
+                             inputs: [String: Any],
+                             options: RequestOptions = RequestOptions()) async throws
+    -> ImagenGenerationResponse {
+    let templateInputs = try inputs.mapValues { try TemplateInput(value: $0) }
+    let projectID = generativeAIService.firebaseInfo.projectID
+    let request = TemplateImagenGenerationRequest(
+      template: templateID,
+      inputs: templateInputs,
+      projectID: projectID,
+      apiConfig: apiConfig,
+      options: options
+    )
+    return try await generativeAIService.loadRequest(request: request)
+  }
+}
diff --git a/FirebaseAI/Sources/TemplateInput.swift b/FirebaseAI/Sources/TemplateInput.swift
new file mode 100644
index 00000000000..606150ed824
--- /dev/null
+++ b/FirebaseAI/Sources/TemplateInput.swift
@@ -0,0 +1,66 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +import Foundation + +enum TemplateInput: Encodable, Sendable { + case string(String) + case int(Int) + case double(Double) + case bool(Bool) + case array([TemplateInput]) + case dictionary([String: TemplateInput]) + + init(value: Any) throws { + switch value { + case let value as String: + self = .string(value) + case let value as Int: + self = .int(value) + case let value as Double: + self = .double(value) + case let value as Float: + self = .double(Double(value)) + case let value as Bool: + self = .bool(value) + case let value as [Any]: + self = try .array(value.map { try TemplateInput(value: $0) }) + case let value as [String: Any]: + self = try .dictionary(value.mapValues { try TemplateInput(value: $0) }) + default: + throw EncodingError.invalidValue( + value, + EncodingError.Context(codingPath: [], debugDescription: "Invalid value") + ) + } + } + + func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case let .string(value): + try container.encode(value) + case let .int(value): + try container.encode(value) + case let .double(value): + try container.encode(value) + case let .bool(value): + try container.encode(value) + case let .array(value): + try container.encode(value) + case let .dictionary(value): + try container.encode(value) + } + } +} diff --git a/FirebaseAI/Sources/Types/Internal/Imagen/ImagenGenerationRequest.swift b/FirebaseAI/Sources/Types/Internal/Imagen/ImagenGenerationRequest.swift index ffb0e8bcf57..9f5a76137d3 100644 --- a/FirebaseAI/Sources/Types/Internal/Imagen/ImagenGenerationRequest.swift +++ b/FirebaseAI/Sources/Types/Internal/Imagen/ImagenGenerationRequest.swift @@ -39,9 +39,13 @@ struct ImagenGenerationRequest: Sendable { extension ImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodable { typealias Response = ImagenGenerationResponse - var url: URL { - return URL(string: - "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):predict")! + func getURL() throws -> URL { + let urlString = + "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):predict" + guard let url = URL(string: urlString) else { + throw AILog.makeInternalError(message: "Malformed URL: \(urlString)", code: .malformedURL) + } + return url } } diff --git a/FirebaseAI/Sources/Types/Internal/Requests/CountTokensRequest.swift b/FirebaseAI/Sources/Types/Internal/Requests/CountTokensRequest.swift index f282b02096b..be3e09c3060 100644 --- a/FirebaseAI/Sources/Types/Internal/Requests/CountTokensRequest.swift +++ b/FirebaseAI/Sources/Types/Internal/Requests/CountTokensRequest.swift @@ -29,10 +29,14 @@ extension CountTokensRequest: GenerativeAIRequest { var apiConfig: APIConfig { generateContentRequest.apiConfig } - var url: URL { + func getURL() throws -> URL { let version = apiConfig.version.rawValue let endpoint = apiConfig.service.endpoint.rawValue - return URL(string: "\(endpoint)/\(version)/\(modelResourceName):countTokens")! 
+ let urlString = "\(endpoint)/\(version)/\(modelResourceName):countTokens" + guard let url = URL(string: urlString) else { + throw AILog.makeInternalError(message: "Malformed URL: \(urlString)", code: .malformedURL) + } + return url } } diff --git a/FirebaseAI/Tests/TestApp/FirebaseAITestApp.xcodeproj/project.pbxproj b/FirebaseAI/Tests/TestApp/FirebaseAITestApp.xcodeproj/project.pbxproj index 8b1b80e54d8..2ce772d1fc6 100644 --- a/FirebaseAI/Tests/TestApp/FirebaseAITestApp.xcodeproj/project.pbxproj +++ b/FirebaseAI/Tests/TestApp/FirebaseAITestApp.xcodeproj/project.pbxproj @@ -31,6 +31,7 @@ 86E850612DBAFBC3002E8D94 /* FirebaseStorage in Frameworks */ = {isa = PBXBuildFile; productRef = 86E850602DBAFBC3002E8D94 /* FirebaseStorage */; }; DEF0BB4F2DA74F680093E9F4 /* TestHelpers.swift in Sources */ = {isa = PBXBuildFile; fileRef = DEF0BB4E2DA74F460093E9F4 /* TestHelpers.swift */; }; DEF0BB512DA9B7450093E9F4 /* SchemaTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = DEF0BB502DA9B7400093E9F4 /* SchemaTests.swift */; }; + DEF4634B2EA1AA77004E79B1 /* ServerPromptTemplateIntegrationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = DEF4634A2EA1AA77004E79B1 /* ServerPromptTemplateIntegrationTests.swift */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -66,6 +67,7 @@ 86D77E032D7B6C95003D155D /* InstanceConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InstanceConfig.swift; sourceTree = ""; }; DEF0BB4E2DA74F460093E9F4 /* TestHelpers.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TestHelpers.swift; sourceTree = ""; }; DEF0BB502DA9B7400093E9F4 /* SchemaTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SchemaTests.swift; sourceTree = ""; }; + DEF4634A2EA1AA77004E79B1 /* ServerPromptTemplateIntegrationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ServerPromptTemplateIntegrationTests.swift; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -144,6 +146,7 @@ 868A7C572CCC27AF00E449DD /* Integration */ = { isa = PBXGroup; children = ( + DEF4634A2EA1AA77004E79B1 /* ServerPromptTemplateIntegrationTests.swift */, 0E460FAA2E9858E4007E26A6 /* LiveSessionTests.swift */, DEF0BB502DA9B7400093E9F4 /* SchemaTests.swift */, DEF0BB4E2DA74F460093E9F4 /* TestHelpers.swift */, @@ -307,6 +310,7 @@ 864F8F712D4980DD0002EA7E /* ImagenIntegrationTests.swift in Sources */, 862218812D04E098007ED2D4 /* IntegrationTestUtils.swift in Sources */, 86D77DFC2D7A5340003D155D /* GenerateContentIntegrationTests.swift in Sources */, + DEF4634B2EA1AA77004E79B1 /* ServerPromptTemplateIntegrationTests.swift in Sources */, 8661386E2CC943DE00F4B78E /* IntegrationTests.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; diff --git a/FirebaseAI/Tests/TestApp/Resources/TestApp.entitlements b/FirebaseAI/Tests/TestApp/Resources/TestApp.entitlements index ee95ab7e582..225aa48bc8c 100644 --- a/FirebaseAI/Tests/TestApp/Resources/TestApp.entitlements +++ b/FirebaseAI/Tests/TestApp/Resources/TestApp.entitlements @@ -6,5 +6,7 @@ com.apple.security.network.client + keychain-access-groups + diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/ServerPromptTemplateIntegrationTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/ServerPromptTemplateIntegrationTests.swift new file mode 100644 index 00000000000..d3b5a8c96e1 --- /dev/null +++ 
b/FirebaseAI/Tests/TestApp/Tests/Integration/ServerPromptTemplateIntegrationTests.swift @@ -0,0 +1,205 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TODO: remove @testable when Template Chat is restored to the public API. +@testable import FirebaseAILogic +import Testing +#if canImport(UIKit) + import UIKit +#endif + +struct ServerPromptTemplateIntegrationTests { + private static let testConfigs: [InstanceConfig] = [ + .googleAI_v1beta, + .vertexAI_v1beta, + .vertexAI_v1beta_global, + ] + private static let imageGenerationTestConfigs: [InstanceConfig] = [.vertexAI_v1beta] + + @Test(arguments: testConfigs) + func generateContentWithText(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).templateGenerativeModel() + let userName = "paul" + let response = try await model.generateContent( + templateID: "greeting-5", + inputs: [ + "name": userName, + "language": "Spanish", + ] + ) + let text = try #require(response.text) + #expect(text.contains("Paul")) + } + + @Test(arguments: testConfigs) + func generateContentStream(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).templateGenerativeModel() + let userName = "paul" + let stream = try model.generateContentStream( + templateID: "greeting-5", + inputs: [ + "name": userName, + "language": "English", + ] + ) + var resultText = "" + for try await response in stream { + if let text = response.text { + resultText += text + } + } + #expect(resultText.contains("Paul")) + } + + @Test(arguments: [ + InstanceConfig.googleAI_v1beta, + InstanceConfig.vertexAI_v1beta, + ]) + func generateImages(_ config: InstanceConfig) async throws { + let imagenModel = FirebaseAI.componentInstance(config).templateImagenModel() + let imagenPrompt = "firefly" + let response = try await imagenModel.generateImages( + templateID: "image-generation-basic", + inputs: [ + "prompt": imagenPrompt, + ] + ) + #expect(response.images.count == 4) + } + + @Test(arguments: testConfigs) + func generateContentWithMedia(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).templateGenerativeModel() + #if canImport(UIKit) + let image = UIImage(systemName: "photo")! + #elseif canImport(AppKit) + let image = NSImage(systemSymbolName: "photo", accessibilityDescription: nil)! + #endif + let imageBytes = try #require( + image.jpegData(compressionQuality: 0.8), "Could not get image data." 
+ ) + let base64Image = imageBytes.base64EncodedString() + + let response = try await model.generateContent( + templateID: "media", + inputs: [ + "imageData": [ + "isInline": true, + "mimeType": "image/jpeg", + "contents": base64Image, + ], + ] + ) + let text = try #require(response.text) + #expect(!text.isEmpty) + } + + @Test(arguments: testConfigs) + func generateContentStreamWithMedia(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).templateGenerativeModel() + #if canImport(UIKit) + let image = UIImage(systemName: "photo")! + #elseif canImport(AppKit) + let image = NSImage(systemSymbolName: "photo", accessibilityDescription: nil)! + #endif + let imageBytes = try #require( + image.jpegData(compressionQuality: 0.8), "Could not get image data." + ) + let base64Image = imageBytes.base64EncodedString() + + let stream = try model.generateContentStream( + templateID: "media", + inputs: [ + "imageData": [ + "isInline": true, + "mimeType": "image/jpeg", + "contents": base64Image, + ], + ] + ) + var resultText = "" + for try await response in stream { + if let text = response.text { + resultText += text + } + } + #expect(!resultText.isEmpty) + } + + @Test(arguments: testConfigs) + func chat(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).templateGenerativeModel() + let initialHistory = [ + ModelContent(role: "user", parts: "Hello!"), + ModelContent(role: "model", parts: "Hi there! How can I help?"), + ] + let chatSession = model.startChat(templateID: "chat-history", history: initialHistory) + + let userMessage = "What's the weather like?" + + let response = try await chatSession.sendMessage( + userMessage, + inputs: ["message": userMessage] + ) + let text = try #require(response.text) + #expect(!text.isEmpty) + #expect(chatSession.history.count == 4) + let textPart = try #require(chatSession.history[2].parts.first as? TextPart) + #expect(textPart.text == userMessage) + } + + @Test(arguments: testConfigs) + func chatStream(_ config: InstanceConfig) async throws { + let model = FirebaseAI.componentInstance(config).templateGenerativeModel() + let initialHistory = [ + ModelContent(role: "user", parts: "Hello!"), + ModelContent(role: "model", parts: "Hi there! How can I help?"), + ] + let chatSession = model.startChat(templateID: "chat-history", history: initialHistory) + + let userMessage = "What's the weather like?" + + let stream = try chatSession.sendMessageStream( + userMessage, + inputs: ["message": userMessage] + ) + var resultText = "" + for try await response in stream { + if let text = response.text { + resultText += text + } + } + #expect(!resultText.isEmpty) + #expect(chatSession.history.count == 4) + let textPart = try #require(chatSession.history[2].parts.first as? TextPart) + #expect(textPart.text == userMessage) + } +} + +#if canImport(AppKit) + import AppKit + + extension NSImage { + func jpegData(compressionQuality: CGFloat) -> Data? 
{ + guard let tiffRepresentation = tiffRepresentation, + let bitmapImage = NSBitmapImageRep(data: tiffRepresentation) else { + return nil + } + return bitmapImage.representation( + using: .jpeg, + properties: [.compressionFactor: compressionQuality] + ) + } + } +#endif diff --git a/FirebaseAI/Tests/Unit/TemplateChatSessionTests.swift b/FirebaseAI/Tests/Unit/TemplateChatSessionTests.swift new file mode 100644 index 00000000000..3ff5ad14ff0 --- /dev/null +++ b/FirebaseAI/Tests/Unit/TemplateChatSessionTests.swift @@ -0,0 +1,121 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@testable import FirebaseAILogic +import FirebaseCore +import XCTest + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class TemplateChatSessionTests: XCTestCase { + var model: TemplateGenerativeModel! + var urlSession: URLSession! + + override func setUp() { + super.setUp() + let configuration = URLSessionConfiguration.default + configuration.protocolClasses = [MockURLProtocol.self] + urlSession = URLSession(configuration: configuration) + let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo() + let generativeAIService = GenerativeAIService( + firebaseInfo: firebaseInfo, + urlSession: urlSession + ) + let apiConfig = APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta) + model = TemplateGenerativeModel(generativeAIService: generativeAIService, apiConfig: apiConfig) + } + + func testSendMessage() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + subdirectory: "mock-responses/googleai", + isTemplateRequest: true + ) + let chat = model.startChat(templateID: "test-template") + let response = try await chat.sendMessage("Hello", inputs: ["name": "test"]) + XCTAssertEqual(chat.history.count, 2) + XCTAssertEqual(chat.history[0].role, "user") + XCTAssertEqual((chat.history[0].parts.first as? TextPart)?.text, "Hello") + XCTAssertEqual(chat.history[1].role, "model") + XCTAssertEqual( + (chat.history[1].parts.first as? TextPart)?.text, + "Google's headquarters, also known as the Googleplex, is located in **Mountain View, California**.\n" + ) + XCTAssertEqual(response.candidates.count, 1) + } + + func testSendMessageStream() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt", + subdirectory: "mock-responses/googleai", + isTemplateRequest: true + ) + let chat = model.startChat(templateID: "test-template") + let stream = try chat.sendMessageStream("Hello", inputs: ["name": "test"]) + + let content = try await GenerativeModelTestUtil.collectTextFromStream(stream) + + XCTAssertEqual(content, "The capital of Wyoming is **Cheyenne**.\n") + XCTAssertEqual(chat.history.count, 2) + XCTAssertEqual(chat.history[0].role, "user") + XCTAssertEqual((chat.history[0].parts.first as? 
TextPart)?.text, "Hello") + XCTAssertEqual(chat.history[1].role, "model") + } + + func testSendMessageWithModelContent() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + subdirectory: "mock-responses/googleai", + isTemplateRequest: true + ) + let chat = model.startChat(templateID: "test-template") + let response = try await chat.sendMessage( + [ModelContent(parts: [TextPart("Hello")])], + inputs: ["name": "test"] + ) + XCTAssertEqual(chat.history.count, 2) + XCTAssertEqual(chat.history[0].role, "user") + XCTAssertEqual((chat.history[0].parts.first as? TextPart)?.text, "Hello") + XCTAssertEqual(chat.history[1].role, "model") + XCTAssertEqual( + (chat.history[1].parts.first as? TextPart)?.text, + "Google's headquarters, also known as the Googleplex, is located in **Mountain View, California**.\n" + ) + XCTAssertEqual(response.candidates.count, 1) + } + + func testSendMessageStreamWithModelContent() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt", + subdirectory: "mock-responses/googleai", + isTemplateRequest: true + ) + let chat = model.startChat(templateID: "test-template") + let stream = try chat.sendMessageStream( + [ModelContent(parts: [TextPart("Hello")])], + inputs: ["name": "test"] + ) + + let content = try await GenerativeModelTestUtil.collectTextFromStream(stream) + + XCTAssertEqual(content, "The capital of Wyoming is **Cheyenne**.\n") + XCTAssertEqual(chat.history.count, 2) + XCTAssertEqual(chat.history[0].role, "user") + XCTAssertEqual((chat.history[0].parts.first as? TextPart)?.text, "Hello") + XCTAssertEqual(chat.history[1].role, "model") + } +} diff --git a/FirebaseAI/Tests/Unit/TemplateGenerativeModelTests.swift b/FirebaseAI/Tests/Unit/TemplateGenerativeModelTests.swift new file mode 100644 index 00000000000..a9994b8cf7a --- /dev/null +++ b/FirebaseAI/Tests/Unit/TemplateGenerativeModelTests.swift @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@testable import FirebaseAILogic +import FirebaseCore +import XCTest + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class TemplateGenerativeModelTests: XCTestCase { + var urlSession: URLSession! + var model: TemplateGenerativeModel! 
+
+  override func setUp() {
+    super.setUp()
+    let configuration = URLSessionConfiguration.default
+    configuration.protocolClasses = [MockURLProtocol.self]
+    urlSession = URLSession(configuration: configuration)
+    let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo()
+    let generativeAIService = GenerativeAIService(
+      firebaseInfo: firebaseInfo,
+      urlSession: urlSession
+    )
+    let apiConfig = APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta)
+    model = TemplateGenerativeModel(generativeAIService: generativeAIService, apiConfig: apiConfig)
+  }
+
+  func testGenerateContent() async throws {
+    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+      forResource: "unary-success-basic-reply-short",
+      withExtension: "json",
+      subdirectory: "mock-responses/googleai",
+      isTemplateRequest: true
+    )
+
+    let response = try await model.generateContent(
+      templateID: "test-template",
+      inputs: ["name": "test"]
+    )
+    XCTAssertEqual(
+      response.text,
+      "Google's headquarters, also known as the Googleplex, is located in **Mountain View, California**.\n"
+    )
+  }
+
+  func testGenerateContentStream() async throws {
+    MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler(
+      forResource: "streaming-success-basic-reply-short",
+      withExtension: "txt",
+      subdirectory: "mock-responses/googleai",
+      isTemplateRequest: true
+    )
+
+    let stream = try model.generateContentStream(
+      templateID: "test-template",
+      inputs: ["name": "test"]
+    )
+
+    let content = try await GenerativeModelTestUtil.collectTextFromStream(stream)
+    XCTAssertEqual(content, "The capital of Wyoming is **Cheyenne**.\n")
+  }
+}
diff --git a/FirebaseAI/Tests/Unit/TemplateImagenModelTests.swift b/FirebaseAI/Tests/Unit/TemplateImagenModelTests.swift
new file mode 100644
index 00000000000..04712b377b8
--- /dev/null
+++ b/FirebaseAI/Tests/Unit/TemplateImagenModelTests.swift
@@ -0,0 +1,52 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+@testable import FirebaseAILogic
+import XCTest
+
+@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
+final class TemplateImagenModelTests: XCTestCase {
+  var urlSession: URLSession!
+  var model: TemplateImagenModel!
+ + override func setUp() { + super.setUp() + let configuration = URLSessionConfiguration.default + configuration.protocolClasses = [MockURLProtocol.self] + urlSession = URLSession(configuration: configuration) + let firebaseInfo = GenerativeModelTestUtil.testFirebaseInfo() + let generativeAIService = GenerativeAIService( + firebaseInfo: firebaseInfo, + urlSession: urlSession + ) + let apiConfig = APIConfig(service: .googleAI(endpoint: .firebaseProxyProd), version: .v1beta) + model = TemplateImagenModel(generativeAIService: generativeAIService, apiConfig: apiConfig) + } + + func testGenerateImages() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-generate-images-base64", + withExtension: "json", + subdirectory: "mock-responses/vertexai", + isTemplateRequest: true + ) + + let response = try await model.generateImages( + templateID: "test-template", + inputs: ["prompt": "a cat picture"] + ) + XCTAssertEqual(response.images.count, 4) + XCTAssertNotNil(response.images.first?.data) + } +} diff --git a/FirebaseAI/Tests/Unit/TemplateInputTests.swift b/FirebaseAI/Tests/Unit/TemplateInputTests.swift new file mode 100644 index 00000000000..2ed428be12b --- /dev/null +++ b/FirebaseAI/Tests/Unit/TemplateInputTests.swift @@ -0,0 +1,29 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +@testable import FirebaseAILogic +import XCTest + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class TemplateInputTests: XCTestCase { + func testInitWithFloat() throws { + let floatValue: Float = 3.14 + let templateInput = try TemplateInput(value: floatValue) + guard case let .double(doubleValue) = templateInput else { + XCTFail("Expected a .double case, but got \(templateInput)") + return + } + XCTAssertEqual(doubleValue, Double(floatValue), accuracy: 1e-6) + } +} diff --git a/FirebaseAI/Tests/Unit/TestUtilities/GenerativeModelTestUtil.swift b/FirebaseAI/Tests/Unit/TestUtilities/GenerativeModelTestUtil.swift index 7f9a8724363..84062c58a2a 100644 --- a/FirebaseAI/Tests/Unit/TestUtilities/GenerativeModelTestUtil.swift +++ b/FirebaseAI/Tests/Unit/TestUtilities/GenerativeModelTestUtil.swift @@ -30,10 +30,12 @@ enum GenerativeModelTestUtil { timeout: TimeInterval = RequestOptions().timeout, appCheckToken: String? = nil, authToken: String? = nil, - dataCollection: Bool = true) throws -> ((URLRequest) throws -> ( - URLResponse, - AsyncLineSequence? - )) { + dataCollection: Bool = true, + isTemplateRequest: Bool = false) throws + -> ((URLRequest) throws -> ( + URLResponse, + AsyncLineSequence? + )) { // Skip tests using MockURLProtocol on watchOS; unsupported in watchOS 2 and later, see // https://developer.apple.com/documentation/foundation/urlprotocol for details. 
#if os(watchOS) @@ -45,7 +47,14 @@ enum GenerativeModelTestUtil { ) return { request in let requestURL = try XCTUnwrap(request.url) - XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1) + if isTemplateRequest { + XCTAssertEqual( + requestURL.path.occurrenceCount(of: "templates/test-template:template"), + 1 + ) + } else { + XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1) + } XCTAssertEqual(request.timeoutInterval, timeout) let apiClientTags = try XCTUnwrap(request.value(forHTTPHeaderField: "x-goog-api-client")) .components(separatedBy: " ") @@ -79,6 +88,19 @@ enum GenerativeModelTestUtil { #endif // os(watchOS) } + static func collectTextFromStream(_ stream: AsyncThrowingStream< + GenerateContentResponse, + Error + >) async throws -> String { + var content = "" + for try await response in stream { + if let text = response.text { + content += text + } + } + return content + } + static func nonHTTPRequestHandler() throws -> ((URLRequest) -> ( URLResponse, AsyncLineSequence? diff --git a/FirebaseAI/Tests/Unit/Types/Imagen/ImagenGenerationRequestTests.swift b/FirebaseAI/Tests/Unit/Types/Imagen/ImagenGenerationRequestTests.swift index f36376061d7..70a98a54321 100644 --- a/FirebaseAI/Tests/Unit/Types/Imagen/ImagenGenerationRequestTests.swift +++ b/FirebaseAI/Tests/Unit/Types/Imagen/ImagenGenerationRequestTests.swift @@ -60,7 +60,7 @@ final class ImagenGenerationRequestTests: XCTestCase { XCTAssertEqual(request.instances, [instance]) XCTAssertEqual(request.parameters, parameters) XCTAssertEqual( - request.url, + try request.getURL(), URL(string: "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict") ) @@ -80,7 +80,7 @@ final class ImagenGenerationRequestTests: XCTestCase { XCTAssertEqual(request.instances, [instance]) XCTAssertEqual(request.parameters, parameters) XCTAssertEqual( - request.url, + try request.getURL(), URL(string: "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict") )