Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ public bool ShouldWriteBarrelsIfClassExists
}
private static readonly HashSet<GenerationLanguage> BarreledLanguages = [
GenerationLanguage.Ruby,
GenerationLanguage.Rust,
];
private static readonly HashSet<GenerationLanguage> BarreledLanguagesWithConstantFileName = [];
public bool CleanOutput
Expand Down
3 changes: 2 additions & 1 deletion src/Kiota.Builder/GenerationLanguage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ public enum GenerationLanguage
Go,
Ruby,
Dart,
HTTP
HTTP,
Rust
}
43 changes: 43 additions & 0 deletions src/Kiota.Builder/PathSegmenters/RustPathSegmenter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using System;
using System.Collections.Generic;
using System.Linq;

using Kiota.Builder.CodeDOM;
using Kiota.Builder.Extensions;

namespace Kiota.Builder.PathSegmenters;

public class RustPathSegmenter(string rootPath, string clientNamespaceName) : CommonPathSegmenter(rootPath, clientNamespaceName)
{
public override string FileSuffix => ".rs";
public override IEnumerable<string> GetAdditionalSegment(CodeElement currentElement, string fileName)
{
if (currentElement is CodeNamespace ns && IsRootNamespace(ns))
return Enumerable.Empty<string>(); // lib.rs at output root, no subdirectory

return currentElement switch
{
CodeNamespace => new[] { GetLastFileNameSegment(currentElement) },
_ => Enumerable.Empty<string>(),
};
}

public override string NormalizeFileName(CodeElement currentElement)
{
if (currentElement is CodeNamespace ns && IsRootNamespace(ns))
return "lib"; // root namespace becomes lib.rs

return currentElement switch
{
CodeNamespace => "mod",
_ => GetLastFileNameSegment(currentElement).ToSnakeCase(),
};
}

public override string NormalizeNamespaceSegment(string segmentName) => segmentName?.ToSnakeCase() ?? string.Empty;

private bool IsRootNamespace(CodeNamespace ns)
{
return ns.Name.Equals(ClientNamespaceName, StringComparison.OrdinalIgnoreCase);
}
}
3 changes: 3 additions & 0 deletions src/Kiota.Builder/Refiners/ILanguageRefiner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ public static async Task RefineAsync(GenerationConfiguration config, CodeNamespa
case GenerationLanguage.Dart:
await new DartRefiner(config).RefineAsync(generatedCode, cancellationToken).ConfigureAwait(false);
break;
case GenerationLanguage.Rust:
await new RustRefiner(config).RefineAsync(generatedCode, cancellationToken).ConfigureAwait(false);
break;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;

namespace Kiota.Builder.Refiners;

public class RustExceptionsReservedNamesProvider : IReservedNamesProvider
{
private readonly Lazy<HashSet<string>> _reservedNames = new(static () => new(StringComparer.OrdinalIgnoreCase)
{
"error",
"source",
"description",
});
public HashSet<string> ReservedNames => _reservedNames.Value;
}
271 changes: 271 additions & 0 deletions src/Kiota.Builder/Refiners/RustRefiner.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using Kiota.Builder.CodeDOM;
using Kiota.Builder.Configuration;
using Kiota.Builder.Extensions;

namespace Kiota.Builder.Refiners;

public class RustRefiner : CommonLanguageRefiner, ILanguageRefiner
{
private const string AbstractionsNamespaceName = "kiota_abstractions";
private const string SerializationNamespaceName = "kiota_abstractions::serialization";

public RustRefiner(GenerationConfiguration configuration) : base(configuration) { }

public override Task RefineAsync(CodeNamespace generatedCode, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
_configuration.NamespaceNameSeparator = "::";
return Task.Run(() =>
{
cancellationToken.ThrowIfCancellationRequested();

DeduplicateErrorMappings(generatedCode);
MoveRequestBuilderPropertiesToBaseType(generatedCode,
new CodeUsing
{
Name = "BaseRequestBuilder",
Declaration = new CodeType
{
Name = AbstractionsNamespaceName,
IsExternal = true
}
},
accessModifier: AccessModifier.Public);
ReplaceIndexersByMethodsWithParameter(
generatedCode,
false,
static x => $"by_{x.ToSnakeCase()}",
static x => x.ToSnakeCase(),
GenerationLanguage.Rust);
cancellationToken.ThrowIfCancellationRequested();

AddInnerClasses(generatedCode, true, string.Empty, false);
cancellationToken.ThrowIfCancellationRequested();

RemoveRequestConfigurationClasses(generatedCode,
new CodeUsing
{
Name = "RequestConfiguration",
Declaration = new CodeType { Name = AbstractionsNamespaceName, IsExternal = true }
},
new CodeType { Name = "DefaultQueryParameters", IsExternal = true });
RemoveCancellationParameter(generatedCode);
cancellationToken.ThrowIfCancellationRequested();

ConvertUnionTypesToWrapper(
generatedCode,
_configuration.UsesBackingStore,
static s => s.ToSnakeCase(),
true,
string.Empty,
string.Empty,
"is_composed_type"
);
PromoteComposedTypesToNamespace(generatedCode);
cancellationToken.ThrowIfCancellationRequested();

ReplaceReservedNames(
generatedCode,
new RustReservedNamesProvider(),
x => $"r#{x}",
shouldReplaceCallback: static x => x is not CodeEnumOption && x is not CodeEnum);
ReplaceReservedExceptionPropertyNames(
generatedCode,
new RustExceptionsReservedNamesProvider(),
static x => $"{x}_prop");

AddPropertiesAndMethodTypesImports(generatedCode, true, false, true);
AddDefaultImports(generatedCode, defaultUsingEvaluators);
cancellationToken.ThrowIfCancellationRequested();

CorrectCoreType(generatedCode, CorrectMethodType, CorrectPropertyType, CorrectImplements);
DisableActionOf(generatedCode, CodeParameterKind.RequestConfiguration);

AddGetterAndSetterMethods(generatedCode,
new()
{
CodePropertyKind.Custom,
CodePropertyKind.AdditionalData,
CodePropertyKind.BackingStore,
},
static (_, s) => s.ToSnakeCase(),
_configuration.UsesBackingStore,
false, "get_", "set_");
AddConstructorsForDefaultValues(generatedCode, true, true);
MakeModelPropertiesNullable(generatedCode);
cancellationToken.ThrowIfCancellationRequested();

var defaultConfiguration = new GenerationConfiguration();
ReplaceDefaultSerializationModules(generatedCode, defaultConfiguration.Serializers,
new(StringComparer.OrdinalIgnoreCase)
{
"kiota_serialization_json::JsonSerializationWriterFactory",
"kiota_serialization_text::TextSerializationWriterFactory",
"kiota_serialization_form::FormSerializationWriterFactory",
});
ReplaceDefaultDeserializationModules(generatedCode, defaultConfiguration.Deserializers,
new(StringComparer.OrdinalIgnoreCase)
{
"kiota_serialization_json::JsonParseNodeFactory",
"kiota_serialization_text::TextParseNodeFactory",
"kiota_serialization_form::FormParseNodeFactory",
});
AddParentClassToErrorClasses(generatedCode, "ApiError", AbstractionsNamespaceName);
AddDiscriminatorMappingsUsingsToParentClasses(generatedCode, "ParseNode", true);
AddParsableImplementsForModelClasses(generatedCode, "Parsable");
cancellationToken.ThrowIfCancellationRequested();

ReplacePropertyNames(generatedCode,
new()
{
CodePropertyKind.Custom,
CodePropertyKind.QueryParameter,
},
static s => s.ToSnakeCase());
AddPrimaryErrorMessage(generatedCode, "error_message",
() => new CodeType { Name = "String", IsNullable = false, IsExternal = true });
NormalizeEnumNames(generatedCode);
}, cancellationToken);
}

/// Promotes model-kind classes (composed type wrappers) out of request builder classes
/// and into the request builder's parent namespace so they become separate .rs files.
private static void PromoteComposedTypesToNamespace(CodeElement currentElement)
{
if (currentElement is CodeClass parentClass && parentClass.IsOfKind(CodeClassKind.RequestBuilder))
{
var parentNamespace = parentClass.GetImmediateParentOfType<CodeNamespace>();
if (parentNamespace != null)
{
var toPromote = parentClass.InnerClasses
.Where(static c => !c.IsOfKind(CodeClassKind.QueryParameters, CodeClassKind.RequestConfiguration, CodeClassKind.ParameterSet))
.ToList();
foreach (var inner in toPromote)
{
parentClass.RemoveChildElementByName(inner.Name);
inner.Parent = parentNamespace;
parentNamespace.AddClass(inner);
}
}
}
CrawlTree(currentElement, PromoteComposedTypesToNamespace);
}

/// Normalize enum names: "Order_status" -> "OrderStatus"
private static void NormalizeEnumNames(CodeElement currentElement)
{
if (currentElement is CodeEnum codeEnum)
{
var newName = string.Join("", codeEnum.Name.Split('_').Select(static s => s.ToFirstCharacterUpperCase()));
if (!newName.Equals(codeEnum.Name, StringComparison.Ordinal))
codeEnum.Name = newName;
}
CrawlTree(currentElement, NormalizeEnumNames);
}

private static readonly AdditionalUsingEvaluator[] defaultUsingEvaluators =
[
new(static x => x is CodeProperty prop && prop.IsOfKind(CodePropertyKind.RequestAdapter),
AbstractionsNamespaceName, "RequestAdapter"),
new(static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.RequestGenerator),
AbstractionsNamespaceName, "RequestInformation", "HttpMethod", "RequestOption"),
new(static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.Serializer),
SerializationNamespaceName, "SerializationWriter"),
new(static x => x is CodeMethod method && method.IsOfKind(CodeMethodKind.Deserializer, CodeMethodKind.Factory),
SerializationNamespaceName, "ParseNode", "Parsable"),
new(static x => x is CodeClass cls && cls.IsOfKind(CodeClassKind.Model),
SerializationNamespaceName, "Parsable"),
new(static x => x is CodeClass cls && cls.IsOfKind(CodeClassKind.Model) &&
cls.Properties.Any(static p => p.IsOfKind(CodePropertyKind.AdditionalData)),
SerializationNamespaceName, "AdditionalDataHolder"),
new(static x => x is CodeProperty prop && prop.IsOfKind(CodePropertyKind.Headers),
AbstractionsNamespaceName, "RequestHeaders"),
];

private static void CorrectMethodType(CodeMethod currentMethod)
{
if (currentMethod.IsOfKind(CodeMethodKind.Serializer))
{
currentMethod.Parameters
.Where(static x => x.Type.Name.StartsWith('I'))
.ToList()
.ForEach(static x => x.Type.Name = x.Type.Name[1..]);
}
else if (currentMethod.IsOfKind(CodeMethodKind.Deserializer))
{
currentMethod.ReturnType.Name = "FieldDeserializers";
currentMethod.Name = "get_field_deserializers";
}
else if (currentMethod.IsOfKind(CodeMethodKind.Factory))
{
currentMethod.Parameters
.Where(static x => x.IsOfKind(CodeParameterKind.ParseNode) && x.Type.Name.StartsWith('I'))
.ToList()
.ForEach(static x => x.Type.Name = x.Type.Name[1..]);
}
else if (currentMethod.IsOfKind(CodeMethodKind.ClientConstructor, CodeMethodKind.Constructor, CodeMethodKind.RawUrlConstructor))
{
currentMethod.Parameters
.Where(static x => x.IsOfKind(CodeParameterKind.RequestAdapter) && x.Type.Name.StartsWith('I'))
.ToList()
.ForEach(static x => x.Type.Name = x.Type.Name[1..]);

if (currentMethod.Parameters.OfKind(CodeParameterKind.PathParameters) is CodeParameter pathsParam)
{
pathsParam.Type.Name = "HashMap<String, String>";
pathsParam.Type.IsNullable = true;
}
}

currentMethod.Parameters
.ToList()
.ForEach(static x => x.Name = x.Name.ToFirstCharacterLowerCase());
}

private static void CorrectPropertyType(CodeProperty currentProperty)
{
if (currentProperty.IsOfKind(CodePropertyKind.RequestAdapter))
{
if (currentProperty.Type.Name.StartsWith('I'))
currentProperty.Type.Name = currentProperty.Type.Name[1..];
}
else if (currentProperty.IsOfKind(CodePropertyKind.AdditionalData))
{
currentProperty.Type.Name = "HashMap<String, serde_json::Value>";
currentProperty.DefaultValue = "HashMap::new()";
}
else if (currentProperty.IsOfKind(CodePropertyKind.PathParameters))
{
currentProperty.Type.IsNullable = true;
currentProperty.Type.Name = "HashMap<String, String>";
if (!string.IsNullOrEmpty(currentProperty.DefaultValue))
currentProperty.DefaultValue = "HashMap::new()";
}
else if (currentProperty.IsOfKind(CodePropertyKind.Headers))
{
currentProperty.DefaultValue = "RequestHeaders::new()";
}
else if (currentProperty.IsOfKind(CodePropertyKind.Options))
{
currentProperty.Type.IsNullable = false;
currentProperty.Type.Name = "Vec<Box<dyn RequestOption>>";
}
else if (currentProperty.IsOfKind(CodePropertyKind.BackingStore))
{
if (currentProperty.Type.Name.StartsWith('I'))
currentProperty.Type.Name = currentProperty.Type.Name[1..];
}
}

private void CorrectImplements(ProprietableBlockDeclaration block)
{
block.ReplaceImplementByName(KiotaBuilder.AdditionalHolderInterface, "AdditionalDataHolder");
block.ReplaceImplementByName(KiotaBuilder.BackedModelInterface, "BackedModel");
}
}
22 changes: 22 additions & 0 deletions src/Kiota.Builder/Refiners/RustReservedNamesProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;

namespace Kiota.Builder.Refiners;

public class RustReservedNamesProvider : IReservedNamesProvider
{
private readonly Lazy<HashSet<string>> _reservedNames = new(static () => new(StringComparer.OrdinalIgnoreCase) {
// Strict keywords
"as", "break", "const", "continue", "crate", "else", "enum", "extern",
"false", "fn", "for", "if", "impl", "in", "let", "loop", "match",
"mod", "move", "mut", "pub", "ref", "return", "self", "Self",
"static", "struct", "super", "trait", "true", "type", "unsafe",
"use", "where", "while",
// Async/await
"async", "await", "dyn",
// Reserved for future use
"abstract", "become", "box", "do", "final", "macro", "override",
"priv", "try", "typeof", "unsized", "virtual", "yield",
});
public HashSet<string> ReservedNames => _reservedNames.Value;
}
2 changes: 2 additions & 0 deletions src/Kiota.Builder/Writers/LanguageWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Kiota.Builder.Writers.Php;
using Kiota.Builder.Writers.Python;
using Kiota.Builder.Writers.Ruby;
using Kiota.Builder.Writers.Rust;
using Kiota.Builder.Writers.TypeScript;

namespace Kiota.Builder.Writers;
Expand Down Expand Up @@ -191,6 +192,7 @@ public static LanguageWriter GetLanguageWriter(GenerationLanguage language, stri
GenerationLanguage.Go => new GoWriter(outputPath, clientNamespaceName, excludeBackwardCompatible),
GenerationLanguage.Dart => new DartWriter(outputPath, clientNamespaceName),
GenerationLanguage.HTTP => new HttpWriter(outputPath, clientNamespaceName),
GenerationLanguage.Rust => new RustWriter(outputPath, clientNamespaceName),
_ => throw new InvalidEnumArgumentException($"{language} language currently not supported."),
};
}
Expand Down
Loading