Skip to content

Commit 781881a

Browse files
authored
.Net: Deduplicate embedding generation management (#13615)
Closes #12508
1 parent 429dd1c commit 781881a

24 files changed

+334
-354
lines changed

dotnet/src/VectorData/AzureAISearch/AzureAISearchCollection.cs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,8 @@ floatVector is null
476476
ReadOnlyMemory<float> r => r,
477477
float[] f => new ReadOnlyMemory<float>(f),
478478
Embedding<float> e => e.Vector,
479-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<float>> generator
480-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
479+
_ when vectorProperty.EmbeddingGenerationDispatcher is not null
480+
=> ((Embedding<float>)await vectorProperty.GenerateEmbeddingAsync(searchValue, cancellationToken).ConfigureAwait(false)).Vector,
481481

482482
// A string was passed without an embedding generator being configured; send the string to Azure AI Search for backend embedding generation.
483483
string when vectorProperty.EmbeddingGenerator is null => (ReadOnlyMemory<float>?)null,
@@ -739,16 +739,8 @@ private static SearchOptions BuildSearchOptions(CollectionModel model, VectorSea
739739

740740
// TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties),
741741
// and generate embeddings for them in a single batch. That's some more complexity though.
742-
if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<float>>(records, cancellationToken, out var floatTask))
743-
{
744-
generatedEmbeddings ??= new IReadOnlyList<MEAI.Embedding>?[vectorPropertyCount];
745-
generatedEmbeddings[i] = await floatTask.ConfigureAwait(false);
746-
}
747-
else
748-
{
749-
throw new InvalidOperationException(
750-
$"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding<float>).Name}' for the given input type.");
751-
}
742+
generatedEmbeddings ??= new IReadOnlyList<MEAI.Embedding>?[vectorPropertyCount];
743+
generatedEmbeddings[i] = await vectorProperty.GenerateEmbeddingsAsync(records.Select(r => vectorProperty.GetValueAsObject(r)), cancellationToken).ConfigureAwait(false);
752744
}
753745

754746
return (records, generatedEmbeddings);

dotnet/src/VectorData/CosmosMongoDB/CosmosMongoCollection.cs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -332,16 +332,8 @@ private static TKey GetStorageKey(BsonDocument document)
332332

333333
// TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties),
334334
// and generate embeddings for them in a single batch. That's some more complexity though.
335-
if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<float>>(records, cancellationToken, out var floatTask))
336-
{
337-
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
338-
generatedEmbeddings[i] = await floatTask.ConfigureAwait(false);
339-
}
340-
else
341-
{
342-
throw new InvalidOperationException(
343-
$"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding<float>).Name}' for the given input type.");
344-
}
335+
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
336+
generatedEmbeddings[i] = await vectorProperty.GenerateEmbeddingsAsync(records.Select(r => vectorProperty.GetValueAsObject(r)), cancellationToken).ConfigureAwait(false);
345337
}
346338

347339
return (records, generatedEmbeddings);
@@ -373,8 +365,8 @@ public override async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync<
373365
float[] f => f,
374366
Embedding<float> e => Unwrap(e.Vector),
375367

376-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<float>> generator
377-
=> Unwrap(await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false)),
368+
_ when vectorProperty.EmbeddingGenerationDispatcher is not null
369+
=> Unwrap(((Embedding<float>)await vectorProperty.GenerateEmbeddingAsync(searchValue, cancellationToken).ConfigureAwait(false)).Vector),
378370

379371
_ => vectorProperty.EmbeddingGenerator is null
380372
? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), MongoModelBuilder.SupportedVectorTypes))

dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlCollection.cs

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -510,26 +510,8 @@ await this.RunOperationAsync(OperationName, () =>
510510

511511
// TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties),
512512
// and generate embeddings for them in a single batch. That's some more complexity though.
513-
if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<float>>(records, cancellationToken, out var floatTask))
514-
{
515-
generatedEmbeddings ??= new IReadOnlyList<MEAI.Embedding>?[vectorPropertyCount];
516-
generatedEmbeddings[i] = await floatTask.ConfigureAwait(false);
517-
}
518-
else if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<byte>>(records, cancellationToken, out var byteTask))
519-
{
520-
generatedEmbeddings ??= new IReadOnlyList<MEAI.Embedding>?[vectorPropertyCount];
521-
generatedEmbeddings[i] = await byteTask.ConfigureAwait(false);
522-
}
523-
else if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<sbyte>>(records, cancellationToken, out var sbyteTask))
524-
{
525-
generatedEmbeddings ??= new IReadOnlyList<MEAI.Embedding>?[vectorPropertyCount];
526-
generatedEmbeddings[i] = await sbyteTask.ConfigureAwait(false);
527-
}
528-
else
529-
{
530-
throw new InvalidOperationException(
531-
$"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding<float>).Name}' for the given input type.");
532-
}
513+
generatedEmbeddings ??= new IReadOnlyList<MEAI.Embedding>?[vectorPropertyCount];
514+
generatedEmbeddings[i] = await vectorProperty.GenerateEmbeddingsAsync(records.Select(r => vectorProperty.GetValueAsObject(r)), cancellationToken).ConfigureAwait(false);
533515
}
534516

535517
return (records, generatedEmbeddings);
@@ -592,22 +574,19 @@ private static async ValueTask<object> GetSearchVectorAsync<TInput>(TInput searc
592574
ReadOnlyMemory<float> m => m,
593575
float[] a => new ReadOnlyMemory<float>(a),
594576
Embedding<float> e => e.Vector,
595-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<float>> generator
596-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
597577

598578
// int8
599579
ReadOnlyMemory<sbyte> m => m,
600580
sbyte[] a => new ReadOnlyMemory<sbyte>(a),
601581
Embedding<sbyte> e => e.Vector,
602-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<sbyte>> generator
603-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
604582

605583
// uint8
606584
ReadOnlyMemory<byte> m => m,
607585
byte[] a => new ReadOnlyMemory<byte>(a),
608586
Embedding<byte> e => e.Vector,
609-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<byte>> generator
610-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
587+
588+
_ when vectorProperty.EmbeddingGenerationDispatcher is not null
589+
=> await vectorProperty.GenerateEmbeddingAsync(searchValue, cancellationToken).ConfigureAwait(false),
611590

612591
_ => vectorProperty.EmbeddingGenerator is null
613592
? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), CosmosNoSqlModelBuilder.SupportedVectorTypes))

dotnet/src/VectorData/CosmosNoSql/CosmosNoSqlModelBuilder.cs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,13 @@ static bool IsValid(Type type)
7474
protected override bool IsVectorPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes)
7575
=> IsVectorPropertyTypeValidCore(type, out supportedTypes);
7676

77-
protected override Type? ResolveEmbeddingType(
78-
VectorPropertyModel vectorProperty,
79-
IEmbeddingGenerator embeddingGenerator,
80-
Type? userRequestedEmbeddingType)
81-
// Resolve embedding type for float, byte, and sbyte embedding generators.
82-
=> vectorProperty.ResolveEmbeddingType<Embedding<float>>(embeddingGenerator, userRequestedEmbeddingType)
83-
?? vectorProperty.ResolveEmbeddingType<Embedding<byte>>(embeddingGenerator, userRequestedEmbeddingType)
84-
?? vectorProperty.ResolveEmbeddingType<Embedding<sbyte>>(embeddingGenerator, userRequestedEmbeddingType);
77+
/// <inheritdoc />
78+
protected override IReadOnlyList<EmbeddingGenerationDispatcher> EmbeddingGenerationDispatchers { get; } =
79+
[
80+
EmbeddingGenerationDispatcher.Create<Embedding<float>>(),
81+
EmbeddingGenerationDispatcher.Create<Embedding<byte>>(),
82+
EmbeddingGenerationDispatcher.Create<Embedding<sbyte>>()
83+
];
8584

8685
internal static bool IsVectorPropertyTypeValidCore(Type type, [NotNullWhen(false)] out string? supportedTypes)
8786
{

dotnet/src/VectorData/InMemory/InMemoryCollection.cs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,8 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
206206

207207
// TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties),
208208
// and generate embeddings for them in a single batch. That's some more complexity though.
209-
if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<float>>(records, cancellationToken, out var floatTask))
210-
{
211-
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
212-
generatedEmbeddings[i] = (IReadOnlyList<Embedding<float>>)await floatTask.ConfigureAwait(false);
213-
}
214-
else
215-
{
216-
throw new InvalidOperationException(
217-
$"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding<float>).Name}' for the given input type.");
218-
}
209+
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
210+
generatedEmbeddings[i] = await vectorProperty.GenerateEmbeddingsAsync(records.Select(r => vectorProperty.GetValueAsObject(r)), cancellationToken).ConfigureAwait(false);
219211
}
220212

221213
var collectionDictionary = this.GetCollectionDictionary();
@@ -282,8 +274,8 @@ public override async IAsyncEnumerable<VectorSearchResult<TRecord>> SearchAsync<
282274
ReadOnlyMemory<float> r => r,
283275
float[] f => new ReadOnlyMemory<float>(f),
284276
Embedding<float> e => e.Vector,
285-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<float>> generator
286-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
277+
_ when vectorProperty.EmbeddingGenerationDispatcher is not null
278+
=> ((Embedding<float>)await vectorProperty.GenerateEmbeddingAsync(searchValue, cancellationToken).ConfigureAwait(false)).Vector,
287279

288280
_ => vectorProperty.EmbeddingGenerator is null
289281
? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), InMemoryModelBuilder.SupportedVectorTypes))

dotnet/src/VectorData/MongoDB/MongoCollection.cs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -351,16 +351,8 @@ private static TKey GetStorageKey(BsonDocument document)
351351

352352
// TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties),
353353
// and generate embeddings for them in a single batch. That's some more complexity though.
354-
if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<float>>(records, cancellationToken, out var floatTask))
355-
{
356-
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
357-
generatedEmbeddings[i] = await floatTask.ConfigureAwait(false);
358-
}
359-
else
360-
{
361-
throw new InvalidOperationException(
362-
$"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding<float>).Name}' for the given input type.");
363-
}
354+
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
355+
generatedEmbeddings[i] = await vectorProperty.GenerateEmbeddingsAsync(records.Select(r => vectorProperty.GetValueAsObject(r)), cancellationToken).ConfigureAwait(false);
364356
}
365357

366358
return (records, generatedEmbeddings);
@@ -452,8 +444,8 @@ private static async ValueTask<float[]> GetSearchVectorArrayAsync<TInput>(TInput
452444
{
453445
ReadOnlyMemory<float> r => r,
454446
Embedding<float> e => e.Vector,
455-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<float>> generator
456-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
447+
_ when vectorProperty.EmbeddingGenerationDispatcher is not null
448+
=> ((Embedding<float>)await vectorProperty.GenerateEmbeddingAsync(searchValue, cancellationToken).ConfigureAwait(false)).Vector,
457449

458450
_ => vectorProperty.EmbeddingGenerator is null
459451
? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), MongoModelBuilder.SupportedVectorTypes))

dotnet/src/VectorData/PgVector/PostgresCollection.cs

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -209,28 +209,8 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
209209

210210
// TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties),
211211
// and generate embeddings for them in a single batch. That's some more complexity though.
212-
if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<float>>(records, cancellationToken, out var floatTask))
213-
{
214-
generatedEmbeddings ??= new Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>(vectorPropertyCount);
215-
generatedEmbeddings[vectorProperty] = await floatTask.ConfigureAwait(false);
216-
}
217-
#if NET
218-
else if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<Half>>(records, cancellationToken, out var halfTask))
219-
{
220-
generatedEmbeddings ??= new Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>(vectorPropertyCount);
221-
generatedEmbeddings[vectorProperty] = await halfTask.ConfigureAwait(false);
222-
}
223-
#endif
224-
else if (vectorProperty.TryGenerateEmbeddings<TRecord, BinaryEmbedding>(records, cancellationToken, out var binaryTask))
225-
{
226-
generatedEmbeddings ??= new Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>(vectorPropertyCount);
227-
generatedEmbeddings[vectorProperty] = await binaryTask.ConfigureAwait(false);
228-
}
229-
else
230-
{
231-
throw new InvalidOperationException(
232-
$"The embedding generator configured on property '{vectorProperty.ModelName}' cannot produce an embedding of type '{typeof(Embedding<float>).Name}' for the given input type.");
233-
}
212+
generatedEmbeddings ??= new Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>(vectorPropertyCount);
213+
generatedEmbeddings[vectorProperty] = await vectorProperty.GenerateEmbeddingsAsync(records.Select(r => vectorProperty.GetValueAsObject(r)), cancellationToken).ConfigureAwait(false);
234214
}
235215

236216
using var connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
@@ -589,28 +569,25 @@ private async Task<object> ConvertSearchInputToVectorAsync<TInput>(TInput search
589569
ReadOnlyMemory<float> r => r,
590570
float[] f => new ReadOnlyMemory<float>(f),
591571
Embedding<float> e => e.Vector,
592-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<float>> generator
593-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
594572

595573
#if NET
596574
// Dense float16
597575
ReadOnlyMemory<Half> r => r,
598576
Half[] f => new ReadOnlyMemory<Half>(f),
599577
Embedding<Half> e => e.Vector,
600-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embedding<Half>> generator
601-
=> await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
602578
#endif
603579

604580
// Dense Binary
605581
BitArray b => b,
606582
BinaryEmbedding e => e.Vector,
607-
_ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, BinaryEmbedding> generator
608-
=> await generator.GenerateAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false),
609583

610584
// Sparse
611585
SparseVector sv => sv,
612586
// TODO: Add a PG-specific SparseVectorEmbedding type
613587

588+
_ when vectorProperty.EmbeddingGenerationDispatcher is not null
589+
=> await vectorProperty.GenerateEmbeddingAsync(searchValue, cancellationToken).ConfigureAwait(false),
590+
614591
_ => vectorProperty.EmbeddingGenerator is null
615592
? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), PostgresModelBuilder.SupportedVectorTypes))
616593
: throw new InvalidOperationException(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType(typeof(TInput), vectorProperty.EmbeddingGenerator.GetType()))

dotnet/src/VectorData/PgVector/PostgresModelBuilder.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,12 @@ protected override void ValidateProperty(PropertyModel propertyModel, VectorStor
115115
}
116116

117117
/// <inheritdoc />
118-
protected override Type? ResolveEmbeddingType(
119-
VectorPropertyModel vectorProperty,
120-
IEmbeddingGenerator embeddingGenerator,
121-
Type? userRequestedEmbeddingType)
122-
=> vectorProperty.ResolveEmbeddingType<Embedding<float>>(embeddingGenerator, userRequestedEmbeddingType)
118+
protected override IReadOnlyList<EmbeddingGenerationDispatcher> EmbeddingGenerationDispatchers { get; } =
119+
[
120+
EmbeddingGenerationDispatcher.Create<Embedding<float>>(),
123121
#if NET
124-
?? vectorProperty.ResolveEmbeddingType<Embedding<Half>>(embeddingGenerator, userRequestedEmbeddingType)
122+
EmbeddingGenerationDispatcher.Create<Embedding<Half>>(),
125123
#endif
126-
?? vectorProperty.ResolveEmbeddingType<BinaryEmbedding>(embeddingGenerator, userRequestedEmbeddingType);
124+
EmbeddingGenerationDispatcher.Create<BinaryEmbedding>()
125+
];
127126
}

0 commit comments

Comments
 (0)