diff --git a/Sources/OpenAI/Public/Models/ChatQuery.swift b/Sources/OpenAI/Public/Models/ChatQuery.swift index 7e555f0c08c063bcaa7ea06776646599b4df3acc..7d1ef2e74a7fbe0403a3075aa1e07311dcdcd228 100644 --- a/Sources/OpenAI/Public/Models/ChatQuery.swift +++ b/Sources/OpenAI/Public/Models/ChatQuery.swift @@ -1,6 +1,6 @@ // // ChatQuery.swift -// +// // // Created by Sergii Kryvoblotskyi on 02/04/2023. // @@ -819,10 +819,12 @@ public struct ChatQuery: Equatable, Codable, Streamable, Sendable { case text case jsonObject case jsonSchema(name: String, type: StructuredOutput.Type) + case dynamicJsonSchema(DynamicJSONSchema) enum CodingKeys: String, CodingKey { case type case jsonSchema = "json_schema" + case dynamicJsonSchema } public func encode(to encoder: any Encoder) throws { @@ -836,6 +838,9 @@ public struct ChatQuery: Equatable, Codable, Streamable, Sendable { try container.encode("json_schema", forKey: .type) let schema = JSONSchema(name: name, schema: type.example) try container.encode(schema, forKey: .jsonSchema) + case .dynamicJsonSchema(let dynamicJSONSchema): + try container.encode("json_schema", forKey: .type) + try container.encode(dynamicJSONSchema, forKey: .jsonSchema) } } @@ -845,6 +850,8 @@ public struct ChatQuery: Equatable, Codable, Streamable, Sendable { case (.jsonObject, .jsonObject): return true case (.jsonSchema(let lhsName, let lhsType), .jsonSchema(let rhsName, let rhsType)): return lhsName == rhsName && lhsType == rhsType + case (.dynamicJsonSchema(let lhsSchema), .dynamicJsonSchema(let rhsSchema)): + return lhsSchema == rhsSchema default: return false } @@ -1072,6 +1079,53 @@ public struct ChatQuery: Equatable, Codable, Streamable, Sendable { } } } + + public struct DynamicJSONSchema: Encodable, Sendable, Equatable { + let name: String + let description: String? + let schema: Encodable & Sendable + let strict: Bool? + + enum CodingKeys: String, CodingKey { + case name + case description + case schema + case strict + } + + public init( + name: String, + description: String? = nil, + schema: Encodable & Sendable, + strict: Bool? = nil + ) { + self.name = name + self.description = description + self.schema = schema + self.strict = strict + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(name, forKey: .name) + if let description { + try container.encode(description, forKey: .description) + } + try container.encode(schema, forKey: .schema) + if let strict { + try container.encode(strict, forKey: .strict) + } + } + + public static func == (lhs: DynamicJSONSchema, rhs: DynamicJSONSchema) -> Bool { + guard lhs.name == rhs.name else { return false } + guard lhs.description == rhs.description else { return false } + guard lhs.strict == rhs.strict else { return false } + let lhsData = try? JSONEncoder().encode(lhs.schema) + let rhsData = try? JSONEncoder().encode(rhs.schema) + return lhsData == rhsData + } + } public enum ChatCompletionFunctionCallOptionParam: Codable, Equatable, Sendable { case none diff --git a/Tests/OpenAITests/OpenAITests.swift b/Tests/OpenAITests/OpenAITests.swift index dd4e138551606b70a2bb620dd804890f69c6f3fb..133607efb66b8165e963eda24026ad0cc8357b36 100644 --- a/Tests/OpenAITests/OpenAITests.swift +++ b/Tests/OpenAITests/OpenAITests.swift @@ -128,6 +128,71 @@ class OpenAITests: XCTestCase { let result = try await openAI.chats(query: query) XCTAssertEqual(result, chatResult) } + + func testChatQueryWithDynamicStructuredOutput() async throws { + + let chatResult = ChatResult( + id: "id-12312", created: 100, model: .gpt3_5Turbo, object: "foo", serviceTier: nil, systemFingerprint: "fing", + choices: [], + usage: .init(completionTokens: 200, promptTokens: 100, totalTokens: 300), + citations: nil + ) + try self.stub(result: chatResult) + + struct AnyEncodable: Encodable { + + private let _encode: (Encoder) throws -> Void + public init<T: Encodable>(_ wrapped: T) { + _encode = wrapped.encode + } + + func encode(to encoder: Encoder) throws { + try _encode(encoder) + } + } + + let schema = [ + "type": AnyEncodable("object"), + "properties": AnyEncodable([ + "title": AnyEncodable([ + "type": "string" + ]), + "director": AnyEncodable([ + "type": "string" + ]), + "release": AnyEncodable([ + "type": "string" + ]), + "genres": AnyEncodable([ + "type": AnyEncodable("array"), + "items": AnyEncodable([ + "type": AnyEncodable("string"), + "enum": AnyEncodable(["action", "drama", "comedy", "scifi"]) + ]) + ]), + "cast": AnyEncodable([ + "type": AnyEncodable("array"), + "items": AnyEncodable([ + "type": "string" + ]) + ]) + ]), + "additionalProperties": AnyEncodable(false) + ] + let query = ChatQuery( + messages: [.system(.init(content: "Return a structured response."))], + model: .gpt4_o, + responseFormat: .dynamicJsonSchema( + .init( + name: "movie-info", + schema: schema + ) + ) + ) + + let result = try await openAI.chats(query: query) + XCTAssertEqual(result, chatResult) + } func testChatsFunction() async throws { let query = ChatQuery(messages: [