Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Sources/Conduit/Core/Types/DeviceCapabilities.swift
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,10 @@ extension DeviceCapabilities {
if size > 0 {
var buffer = [CChar](repeating: 0, count: size)
sysctlbyname("machdep.cpu.brand_string", &buffer, &size, nil, 0)
// Use failable String initializer with null-terminated C string
return String(cString: buffer)
// Convert C buffer to UTF-8 bytes and truncate at first NULL.
let bytes = buffer.map(UInt8.init(bitPattern:))
let nullIndex = bytes.firstIndex(of: 0) ?? bytes.endIndex
return String(bytes: bytes[..<nullIndex], encoding: .utf8)
}
return nil
#elseif os(Linux)
Expand Down
4 changes: 2 additions & 2 deletions Sources/Conduit/Core/Types/GenerationSchema.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import struct Foundation.Decimal

/// Error that occurs during schema encoding.
private enum EncodingError: Error, LocalizedError {
case invalidValue(Any, Context)
case invalidValue(String, Context)

var errorDescription: String? {
switch self {
Expand Down Expand Up @@ -59,7 +59,7 @@ public struct GenerationSchema: Sendable, Codable, CustomDebugStringConvertible
for (name, node) in obj.properties {
guard let key = DynamicCodingKey(stringValue: name) else {
throw EncodingError.invalidValue(
name,
String(describing: name),
EncodingError.Context(
codingPath: container.codingPath,
debugDescription: "Unable to create coding key for property '\(name)'"
Expand Down
91 changes: 68 additions & 23 deletions Sources/Conduit/ImageGeneration/DiffusionModelDownloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public actor DiffusionModelDownloader {
#endif
private let token: String?
private var activeDownloads: [String: Task<URL, Error>] = [:]
private var reservedDiskBytesByModel: [String: Int64] = [:]
private let registry = DiffusionModelRegistry.shared

// MARK: - Initialization
Expand Down Expand Up @@ -84,9 +85,16 @@ public actor DiffusionModelDownloader {
}
}

// Check available disk space before downloading
// Reuse existing in-flight download for the same model.
if let existingTask = activeDownloads[modelId] {
return try await existingTask.value
}

// Check available disk space before downloading, accounting for reserved
// space from other in-flight downloads in this actor.
let requiredBytes = variant.sizeBytes
try checkAvailableDiskSpace(requiredBytes: requiredBytes)
let reservedBytes = Self.diskRequirementWithSafetyBuffer(requiredBytes: requiredBytes)

// Create download task
let task = Task<URL, Error> { [weak self] in
Expand Down Expand Up @@ -171,22 +179,15 @@ public actor DiffusionModelDownloader {
}
}

// Atomically insert task - if another task was inserted concurrently,
// cancel ours and use theirs instead
if let existingTask = activeDownloads[modelId] {
task.cancel()
return try await existingTask.value
}
activeDownloads[modelId] = task
reserveDiskSpace(for: modelId, bytes: reservedBytes)

do {
let result = try await task.value
cleanupDownloadTask(modelId: modelId)
return result
} catch {
defer {
cleanupDownloadTask(modelId: modelId)
throw error
releaseReservedDiskSpace(for: modelId)
}

return try await task.value
}

/// Cleans up a download task from the active downloads dictionary.
Expand All @@ -200,6 +201,7 @@ public actor DiffusionModelDownloader {
public func cancelDownload(modelId: String) {
activeDownloads[modelId]?.cancel()
activeDownloads.removeValue(forKey: modelId)
releaseReservedDiskSpace(for: modelId)
Comment on lines 202 to +204
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Preserve disk reservation until cancellation completes

Releasing reserved bytes inside cancelDownload before the underlying task has finished makes the new disk-safety accounting inaccurate under cooperative cancellation. The download task only checks cancellation around the long snapshot await, so a canceled transfer can keep consuming disk for a while; if another download starts in that window, checkAvailableDiskSpace no longer includes the in-flight reservation and can admit an overcommit that this change is meant to prevent. Keep the reservation until task termination (e.g., via the existing deferred cleanup) rather than clearing it immediately on cancel.

Useful? React with 👍 / 👎.

}

/// Checks if a model is currently being downloaded.
Expand All @@ -216,6 +218,7 @@ public actor DiffusionModelDownloader {
task.cancel()
}
activeDownloads.removeAll()
reservedDiskBytesByModel.removeAll()
}

/// Number of active downloads.
Expand All @@ -240,11 +243,21 @@ public actor DiffusionModelDownloader {
/// - Parameter modelId: The model ID to delete.
/// - Throws: Error if file deletion fails.
public func deleteModel(modelId: String) async throws {
if let task = activeDownloads.removeValue(forKey: modelId) {
task.cancel()
_ = try? await task.value
}
releaseReservedDiskSpace(for: modelId)

guard let path = await registry.localPath(for: modelId) else {
return // Not downloaded
}

try FileManager.default.removeItem(at: path)
do {
try FileManager.default.removeItem(at: path)
} catch let error as CocoaError where error.code == .fileNoSuchFile {
// Already gone from disk; still clear registry state.
}
await registry.removeDownloaded(modelId)
}

Expand All @@ -253,10 +266,37 @@ public actor DiffusionModelDownloader {
/// - Throws: Error if any deletion fails.
public func deleteAllModels() async throws {
let models = await registry.allDownloadedModels
var deletionErrors: [String] = []

for model in models {
try? FileManager.default.removeItem(at: model.localPath)
do {
try FileManager.default.removeItem(at: model.localPath)
await registry.removeDownloaded(model.id)
} catch let error as CocoaError where error.code == .fileNoSuchFile {
// Treat already-missing files as deleted and normalize registry state.
await registry.removeDownloaded(model.id)
} catch {
deletionErrors.append("\(model.id): \(error.localizedDescription)")
}
}
await registry.clearAllRecords()

if !deletionErrors.isEmpty {
throw AIError.fileError(underlying: SendableError(NSError(
domain: "DiffusionModelDownloader",
code: -2,
userInfo: [
NSLocalizedDescriptionKey: "Failed to delete \(deletionErrors.count) model(s): \(deletionErrors.joined(separator: "; "))"
]
)))
}
}

private func reserveDiskSpace(for modelId: String, bytes: Int64) {
reservedDiskBytesByModel[modelId] = bytes
}

private func releaseReservedDiskSpace(for modelId: String) {
reservedDiskBytesByModel.removeValue(forKey: modelId)
}

// MARK: - Helpers
Expand Down Expand Up @@ -298,7 +338,7 @@ public actor DiffusionModelDownloader {
///
/// - Parameter requiredBytes: The number of bytes required.
/// - Throws: `AIError.insufficientDiskSpace` if not enough space is available.
private nonisolated func checkAvailableDiskSpace(requiredBytes: Int64) throws {
private func checkAvailableDiskSpace(requiredBytes: Int64) throws {
let fileManager = FileManager.default

// Use the home directory to check available space
Expand All @@ -312,12 +352,14 @@ public actor DiffusionModelDownloader {
return
}

// Require 10% buffer above the model size for safety
let requiredWithBuffer = Int64(Double(requiredBytes) * 1.1)
// Require 10% buffer above the model size for safety and account for
// in-flight downloads that already reserved disk budget.
let requiredWithBuffer = Self.diskRequirementWithSafetyBuffer(requiredBytes: requiredBytes)
let totalRequired = requiredWithBuffer + reservedDiskBytesByModel.values.reduce(0, +)

if availableBytes < requiredWithBuffer {
if availableBytes < totalRequired {
throw AIError.insufficientDiskSpace(
required: ByteCount(requiredWithBuffer),
required: ByteCount(totalRequired),
available: ByteCount(availableBytes)
)
}
Expand All @@ -329,6 +371,10 @@ public actor DiffusionModelDownloader {
}
}

private nonisolated static func diskRequirementWithSafetyBuffer(requiredBytes: Int64) -> Int64 {
Int64(Double(requiredBytes) * 1.1)
}

// MARK: - Checksum Verification

/// Verifies the SHA256 checksum of downloaded files.
Expand Down Expand Up @@ -369,8 +415,7 @@ public actor DiffusionModelDownloader {
}

guard let fileToVerify = primaryFile else {
// No safetensors file found, skip verification
return
throw AIError.checksumMismatch(expected: expected, actual: "<missing .safetensors file>")
}

// Calculate SHA256 checksum
Expand Down
5 changes: 3 additions & 2 deletions Sources/Conduit/ModelManagement/ModelManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,9 @@ public actor ModelManager {
let task = DownloadTask(model: model)
activeTasks[model] = task

// Start the download in a detached task
Task.detached { [weak self] in
// Start the download in a child task so cancellation/lifetime can
// propagate from callers that hold this async context.
Task { [weak self] in
do {
_ = try await self?.download(model, progress: nil)
} catch {
Expand Down
19 changes: 14 additions & 5 deletions Sources/Conduit/Providers/OpenAI/OpenAIEndpoint.swift
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,20 @@ public enum OpenAIEndpoint: Sendable, Hashable {
return components.url ?? URL(string: "http://localhost:11434/v1")!

case .azure(let resource, _, _):
// Validate resource name - use a safe fallback if empty or invalid
let sanitizedResource = resource.isEmpty ? "default" : resource
// URL construction with a validated resource name should never fail.
// If it does, we use a fallback to avoid crashing the app.
return URL(string: "https://\(sanitizedResource).openai.azure.com/openai")!
let trimmed = resource.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
let input = trimmed.isEmpty ? "default" : trimmed
let allowed = CharacterSet(charactersIn: "abcdefghijklmnopqrstuvwxyz0123456789-")
let mapped = String(input.unicodeScalars.map { allowed.contains($0) ? Character($0) : "-" })
let collapsed = mapped.replacingOccurrences(of: "-+", with: "-", options: .regularExpression)
let sanitized = collapsed.trimmingCharacters(in: CharacterSet(charactersIn: "-")).isEmpty
? "default"
: collapsed.trimmingCharacters(in: CharacterSet(charactersIn: "-"))

var components = URLComponents()
components.scheme = "https"
components.host = "\(sanitized).openai.azure.com"
components.path = "/openai"
return components.url ?? URL(string: "https://default.openai.azure.com/openai")!

case .custom(let url):
return url
Expand Down
37 changes: 19 additions & 18 deletions Tests/ConduitTests/Core/ModelIdentifierTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -414,17 +414,17 @@ final class ModelIdentifierTests: XCTestCase {
func testRegistryContainsAllExpectedModels() {
let allModels = ModelRegistry.allModels

// Should have 16 models total
XCTAssertEqual(allModels.count, 16)
XCTAssertFalse(allModels.isEmpty)

// Count by provider
let mlxModels = allModels.filter { $0.identifier.provider == .mlx }
let hfModels = allModels.filter { $0.identifier.provider == .huggingFace }
let appleModels = allModels.filter { $0.identifier.provider == .foundationModels }

XCTAssertEqual(mlxModels.count, 10) // 7 text gen + 3 embedding
XCTAssertEqual(hfModels.count, 5)
XCTAssertFalse(mlxModels.isEmpty)
XCTAssertFalse(hfModels.isEmpty)
XCTAssertEqual(appleModels.count, 1)
XCTAssertEqual(Set(allModels.map(\.identifier)).count, allModels.count)
}

func testRegistryInfoLookup() {
Expand All @@ -448,13 +448,14 @@ final class ModelIdentifierTests: XCTestCase {
}

func testRegistryModelsByProvider() {
let allModels = ModelRegistry.allModels
let mlxModels = ModelRegistry.models(for: .mlx)
let hfModels = ModelRegistry.models(for: .huggingFace)
let appleModels = ModelRegistry.models(for: .foundationModels)

XCTAssertEqual(mlxModels.count, 10)
XCTAssertEqual(hfModels.count, 5)
XCTAssertEqual(appleModels.count, 1)
XCTAssertEqual(mlxModels.count, allModels.filter { $0.identifier.provider == .mlx }.count)
XCTAssertEqual(hfModels.count, allModels.filter { $0.identifier.provider == .huggingFace }.count)
XCTAssertEqual(appleModels.count, allModels.filter { $0.identifier.provider == .foundationModels }.count)

// Verify all MLX models are actually MLX
XCTAssertTrue(mlxModels.allSatisfy { $0.identifier.provider == .mlx })
Expand All @@ -463,17 +464,18 @@ final class ModelIdentifierTests: XCTestCase {
}

func testRegistryModelsByCapability() {
let allModels = ModelRegistry.allModels
let textGenModels = ModelRegistry.models(with: .textGeneration)
let embeddingModels = ModelRegistry.models(with: .embeddings)
let codeGenModels = ModelRegistry.models(with: .codeGeneration)
let reasoningModels = ModelRegistry.models(with: .reasoning)
let transcriptionModels = ModelRegistry.models(with: .transcription)

XCTAssertEqual(textGenModels.count, 12) // Most models support text generation
XCTAssertEqual(embeddingModels.count, 3) // BGE small, BGE large, Nomic
XCTAssertEqual(codeGenModels.count, 3) // Phi-3 Mini, Phi-4, Llama 3.1 70B
XCTAssertEqual(reasoningModels.count, 4) // Phi-3 Mini, Phi-4, Llama 3.1 70B, DeepSeek R1
XCTAssertEqual(transcriptionModels.count, 1) // Whisper Large V3
XCTAssertEqual(textGenModels.count, allModels.filter { $0.capabilities.contains(.textGeneration) }.count)
XCTAssertEqual(embeddingModels.count, allModels.filter { $0.capabilities.contains(.embeddings) }.count)
XCTAssertEqual(codeGenModels.count, allModels.filter { $0.capabilities.contains(.codeGeneration) }.count)
XCTAssertEqual(reasoningModels.count, allModels.filter { $0.capabilities.contains(.reasoning) }.count)
XCTAssertEqual(transcriptionModels.count, allModels.filter { $0.capabilities.contains(.transcription) }.count)

// Verify all embedding models actually have the capability
XCTAssertTrue(embeddingModels.allSatisfy { $0.capabilities.contains(.embeddings) })
Expand All @@ -499,9 +501,9 @@ final class ModelIdentifierTests: XCTestCase {

func testRegistryLocalModels() {
let localModels = ModelRegistry.localModels()
let expectedLocalModels = ModelRegistry.allModels.filter { !$0.identifier.requiresNetwork }

// Local models should be MLX + Apple
XCTAssertEqual(localModels.count, 11) // 10 MLX + 1 Apple
XCTAssertEqual(localModels.count, expectedLocalModels.count)

// All should not require network
XCTAssertTrue(localModels.allSatisfy { !$0.identifier.requiresNetwork })
Expand All @@ -516,16 +518,15 @@ final class ModelIdentifierTests: XCTestCase {

func testRegistryCloudModels() {
let cloudModels = ModelRegistry.cloudModels()
let expectedCloudModels = ModelRegistry.allModels.filter { $0.identifier.requiresNetwork }

// Cloud models should be HuggingFace only
XCTAssertEqual(cloudModels.count, 5)
XCTAssertEqual(cloudModels.count, expectedCloudModels.count)

// All should require network
XCTAssertTrue(cloudModels.allSatisfy { $0.identifier.requiresNetwork })
XCTAssertTrue(cloudModels.allSatisfy { !$0.identifier.isLocal })

// All should be HuggingFace
XCTAssertTrue(cloudModels.allSatisfy { $0.identifier.provider == .huggingFace })
XCTAssertEqual(Set(cloudModels.map(\.identifier)), Set(expectedCloudModels.map(\.identifier)))
}

// MARK: - ProviderType Tests
Expand Down
4 changes: 3 additions & 1 deletion Tests/ConduitTests/Core/ProtocolCompilationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ final class ProtocolCompilationTests: XCTestCase {

func testProviderTypeIsCaseIterable() {
let allCases = ProviderType.allCases
XCTAssertEqual(allCases.count, 10)
XCTAssertEqual(allCases.count, 12)
XCTAssertTrue(allCases.contains(.mlx))
XCTAssertTrue(allCases.contains(.coreml))
XCTAssertTrue(allCases.contains(.llama))
Expand All @@ -685,6 +685,8 @@ final class ProtocolCompilationTests: XCTestCase {
XCTAssertTrue(allCases.contains(.openRouter))
XCTAssertTrue(allCases.contains(.ollama))
XCTAssertTrue(allCases.contains(.anthropic))
XCTAssertTrue(allCases.contains(.kimi))
XCTAssertTrue(allCases.contains(.minimax))
XCTAssertTrue(allCases.contains(.azure))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//
// This file requires the MLX trait (Hub) to be enabled.

#if canImport(Hub)
#if CONDUIT_TRAIT_MLX && canImport(MLX) && (canImport(Hub) || canImport(HuggingFace))

import Foundation
import Testing
Expand Down Expand Up @@ -773,4 +773,4 @@ struct DiffusionModelDownloaderTests {
}
}

#endif // canImport(Hub)
#endif // CONDUIT_TRAIT_MLX && canImport(MLX) && (canImport(Hub) || canImport(HuggingFace))
Loading