Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/code-generator-static-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ export class StaticCodeGenerator implements CodeGenerator, SourceBuilder {
case "raw":
return segment.content;
case "code":
return `(*ss_ptr) << ${this.#renderString(segment.content)};\n`;
return `ss << ${this.#renderString(segment.content)};\n`;
case "expression":
return `(*ss_ptr) << ${segment.content};\n`;
return `ss << ${segment.content};\n`;
}
})
.join("");
Expand Down Expand Up @@ -205,10 +205,10 @@ export class StaticCodeGenerator implements CodeGenerator, SourceBuilder {

#pragma push_macro("MainFunctionStart")
#undef MainFunctionStart
#define MainFunctionStart ss_ptr = &shader_helper.MainFunctionBody
#define MainFunctionStart() { [[maybe_unused]] auto& ss = shader_helper.MainFunctionBody();
#pragma push_macro("MainFunctionEnd")
#undef MainFunctionEnd
#define MainFunctionEnd ss_ptr = &shader_helper.AdditionalImplementation
#define MainFunctionEnd() }

// Helper templates

Expand Down Expand Up @@ -271,7 +271,7 @@ std::string pass_as_string(T&& v) {
filePath
)}>::type ${paramsIsNotUsed ? "" : "params"}) {`
);
implContent.push(" OStringStream* ss_ptr = &shader_helper.AdditionalImplementation();");
implContent.push(" [[maybe_unused]] auto& ss = shader_helper.AdditionalImplementation();");
implContent.push("");

// Add parameter assignments for easier access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

template <>
Status ApplyTemplate<"shader/triangle.wgsl.template">(ShaderHelper& shader_helper, TemplateParameter<"shader/triangle.wgsl.template">::type ) {
OStringStream* ss_ptr = &shader_helper.AdditionalImplementation();
[[maybe_unused]] auto& ss = shader_helper.AdditionalImplementation();

(*ss_ptr) << __str_0;
ss << __str_0;


return Status::OK();
Expand Down
6 changes: 3 additions & 3 deletions test/testcases/build-basic/expected/static-cpp/index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

#pragma push_macro("MainFunctionStart")
#undef MainFunctionStart
#define MainFunctionStart ss_ptr = &shader_helper.MainFunctionBody
#define MainFunctionStart() { [[maybe_unused]] auto& ss = shader_helper.MainFunctionBody();
#pragma push_macro("MainFunctionEnd")
#undef MainFunctionEnd
#define MainFunctionEnd ss_ptr = &shader_helper.AdditionalImplementation
#define MainFunctionEnd() }

// Helper templates

Expand All @@ -35,7 +35,7 @@ std::string pass_as_string(T&& v) {

// Include template implementations

#include "generated/shader/triangle.h" // f5cec46558b917d8a4ec739f0ab02e71a31046d9c0d14028daa4aa0557a72da6
#include "generated/shader/triangle.h" // d61822d677fc73d87adbe0b7b8e81d556ca515fef270bbe01dcdcb10596c350b

#pragma pop_macro("MainFunctionStart")
#pragma pop_macro("MainFunctionEnd")
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

template <>
Status ApplyTemplate<"tensor/pad.wgsl.template">(ShaderHelper& shader_helper, TemplateParameter<"tensor/pad.wgsl.template">::type params) {
OStringStream* ss_ptr = &shader_helper.AdditionalImplementation();
[[maybe_unused]] auto& ss = shader_helper.AdditionalImplementation();

// Extract parameters
auto& __param_dim_value_zero = params.param_dim_value_zero;
Expand All @@ -18,51 +18,51 @@ Status ApplyTemplate<"tensor/pad.wgsl.template">(ShaderHelper& shader_helper, Te
auto& __var_output = *params.var_output;

MainFunctionStart();
(*ss_ptr) << "\n ";
(*ss_ptr) << shader_helper.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
(*ss_ptr) << ";\n\n let constant_value =\n";
ss << "\n ";
ss << shader_helper.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
ss << ";\n\n let constant_value =\n";
if (__param_is_float16) {
(*ss_ptr) << " bitcast<vec2<f16>>(uniforms.constant_value)[0];\n";
ss << " bitcast<vec2<f16>>(uniforms.constant_value)[0];\n";
} else {
(*ss_ptr) << " bitcast<output_value_t>(uniforms.constant_value);\n";
ss << " bitcast<output_value_t>(uniforms.constant_value);\n";
}
if (__param_dim_value_zero) {
(*ss_ptr) << " output[global_idx] = constant_value;\n";
ss << " output[global_idx] = constant_value;\n";
} else {
(*ss_ptr) << " let output_indices = ";
(*ss_ptr) << __var_output.OffsetToIndices("global_idx");
(*ss_ptr) << ";\n var input_index = u32(0);\n var use_pad_value = false;\n var in_coord = i32(0);\n\n for (var dim = 0; dim < ";
(*ss_ptr) << __var_output.Rank();
(*ss_ptr) << " && !use_pad_value; dim++) {\n let output_index = i32(";
(*ss_ptr) << GetElementAt("output_indices", "dim", __var_output.Rank());
(*ss_ptr) << ");\n let lower_pads = ";
(*ss_ptr) << GetElementAt("uniforms.lower_pads", "dim", __var_output.Rank());
(*ss_ptr) << ";\n let data_shape = i32(";
(*ss_ptr) << GetElementAt("uniforms.data_shape", "dim", __var_output.Rank());
(*ss_ptr) << ");\n";
ss << " let output_indices = ";
ss << __var_output.OffsetToIndices("global_idx");
ss << ";\n var input_index = u32(0);\n var use_pad_value = false;\n var in_coord = i32(0);\n\n for (var dim = 0; dim < ";
ss << __var_output.Rank();
ss << " && !use_pad_value; dim++) {\n let output_index = i32(";
ss << GetElementAt("output_indices", "dim", __var_output.Rank());
ss << ");\n let lower_pads = ";
ss << GetElementAt("uniforms.lower_pads", "dim", __var_output.Rank());
ss << ";\n let data_shape = i32(";
ss << GetElementAt("uniforms.data_shape", "dim", __var_output.Rank());
ss << ");\n";
if (__param_pad_mode == 0) {
(*ss_ptr) << " if (output_index < lower_pads || output_index >= data_shape + lower_pads) {\n use_pad_value = true;\n";
ss << " if (output_index < lower_pads || output_index >= data_shape + lower_pads) {\n use_pad_value = true;\n";
} else if (__param_pad_mode == 2) {
(*ss_ptr) << " if (output_index < lower_pads) {\n in_coord = 0;\n } else if (output_index >= data_shape + lower_pads) {\n in_coord = data_shape - 1;\n";
ss << " if (output_index < lower_pads) {\n in_coord = 0;\n } else if (output_index >= data_shape + lower_pads) {\n in_coord = data_shape - 1;\n";
} else if (__param_pad_mode == 1) {
(*ss_ptr) << " if (output_index < lower_pads || output_index >= data_shape + lower_pads) {\n in_coord = output_index - lower_pads;\n if (in_coord < 0) {\n in_coord = -in_coord;\n }\n let _2n_1 = 2 * (data_shape - 1);\n in_coord = in_coord % _2n_1;\n if (in_coord >= data_shape) {\n in_coord = _2n_1 - in_coord;\n }\n";
ss << " if (output_index < lower_pads || output_index >= data_shape + lower_pads) {\n in_coord = output_index - lower_pads;\n if (in_coord < 0) {\n in_coord = -in_coord;\n }\n let _2n_1 = 2 * (data_shape - 1);\n in_coord = in_coord % _2n_1;\n if (in_coord >= data_shape) {\n in_coord = _2n_1 - in_coord;\n }\n";
} else {
(*ss_ptr) << " if (output_index < lower_pads) {\n in_coord = data_shape + output_index - lower_pads;\n } else if (output_index >= data_shape + lower_pads) {\n in_coord = output_index - data_shape - lower_pads;\n";
ss << " if (output_index < lower_pads) {\n in_coord = data_shape + output_index - lower_pads;\n } else if (output_index >= data_shape + lower_pads) {\n in_coord = output_index - data_shape - lower_pads;\n";
}
(*ss_ptr) << " } else {\n in_coord = output_index - lower_pads;\n }\n\n input_index += select(u32(in_coord)\n";
ss << " } else {\n in_coord = output_index - lower_pads;\n }\n\n input_index += select(u32(in_coord)\n";
if (__var_output.Rank() > 1) {
(*ss_ptr) << " * ";
(*ss_ptr) << GetElementAt("uniforms.data_stride", "dim", __var_output.Rank() - 1);
(*ss_ptr) << "\n";
ss << " * ";
ss << GetElementAt("uniforms.data_stride", "dim", __var_output.Rank() - 1);
ss << "\n";
}
(*ss_ptr) << " , u32(in_coord), dim == ";
(*ss_ptr) << __var_output.Rank();
(*ss_ptr) << " - 1);\n }\n\n ";
(*ss_ptr) << __var_output.SetByOffset("global_idx", "select(data[input_index], constant_value, use_pad_value)");
(*ss_ptr) << ";\n";
ss << " , u32(in_coord), dim == ";
ss << __var_output.Rank();
ss << " - 1);\n }\n\n ";
ss << __var_output.SetByOffset("global_idx", "select(data[input_index], constant_value, use_pad_value)");
ss << ";\n";
}
MainFunctionEnd();
(*ss_ptr) << "\n";
ss << "\n";


return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

#pragma push_macro("MainFunctionStart")
#undef MainFunctionStart
#define MainFunctionStart ss_ptr = &shader_helper.MainFunctionBody
#define MainFunctionStart() { [[maybe_unused]] auto& ss = shader_helper.MainFunctionBody();
#pragma push_macro("MainFunctionEnd")
#undef MainFunctionEnd
#define MainFunctionEnd ss_ptr = &shader_helper.AdditionalImplementation
#define MainFunctionEnd() }

// Helper templates

Expand All @@ -34,7 +34,7 @@ std::string pass_as_string(T&& v) {

// Include template implementations

#include "generated/tensor/pad.h" // 0a0ab18c4abbbd85c08852d67af2741d77f44c28c604fe599c54a0d050ea354f
#include "generated/tensor/pad.h" // 6903e8c7560b2507fffd2da3b327d9c6a354da219640c3131c09a92781356c9f

#pragma pop_macro("MainFunctionStart")
#pragma pop_macro("MainFunctionEnd")
Original file line number Diff line number Diff line change
@@ -1 +1 @@
(*ss_ptr) << "@compute @workgroup_size(128)\nfn main(@builtin(global_invocation_id) global_id: vec3<u32>) {\n let buffer_count = 4;\n let index = global_id.x;\n}\n";
ss << "@compute @workgroup_size(128)\nfn main(@builtin(global_invocation_id) global_id: vec3<u32>) {\n let buffer_count = 4;\n let index = global_id.x;\n}\n";
Original file line number Diff line number Diff line change
@@ -1,46 +1,46 @@
MainFunctionStart();
(*ss_ptr) << "\n ";
(*ss_ptr) << shader_helper.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
(*ss_ptr) << ";\n\n let constant_value =\n";
ss << "\n ";
ss << shader_helper.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
ss << ";\n\n let constant_value =\n";
if (__param_is_float16) {
(*ss_ptr) << " bitcast<vec2<f16>>(uniforms.constant_value)[0];\n";
ss << " bitcast<vec2<f16>>(uniforms.constant_value)[0];\n";
} else {
(*ss_ptr) << " bitcast<output_value_t>(uniforms.constant_value);\n";
ss << " bitcast<output_value_t>(uniforms.constant_value);\n";
}
if (__param_dim_value_zero) {
(*ss_ptr) << " output[global_idx] = constant_value;\n";
ss << " output[global_idx] = constant_value;\n";
} else {
(*ss_ptr) << " let output_indices = ";
(*ss_ptr) << __var_output.OffsetToIndices("global_idx");
(*ss_ptr) << ";\n var input_index = u32(0);\n var use_pad_value = false;\n var in_coord = i32(0);\n\n for (var dim = 0; dim < ";
(*ss_ptr) << __var_output.Rank();
(*ss_ptr) << " && !use_pad_value; dim++) {\n let output_index = i32(";
(*ss_ptr) << GetElementAt("output_indices", "dim", __var_output.Rank());
(*ss_ptr) << ");\n let lower_pads = ";
(*ss_ptr) << GetElementAt("uniforms.lower_pads", "dim", __var_output.Rank());
(*ss_ptr) << ";\n let data_shape = i32(";
(*ss_ptr) << GetElementAt("uniforms.data_shape", "dim", __var_output.Rank());
(*ss_ptr) << ");\n";
ss << " let output_indices = ";
ss << __var_output.OffsetToIndices("global_idx");
ss << ";\n var input_index = u32(0);\n var use_pad_value = false;\n var in_coord = i32(0);\n\n for (var dim = 0; dim < ";
ss << __var_output.Rank();
ss << " && !use_pad_value; dim++) {\n let output_index = i32(";
ss << GetElementAt("output_indices", "dim", __var_output.Rank());
ss << ");\n let lower_pads = ";
ss << GetElementAt("uniforms.lower_pads", "dim", __var_output.Rank());
ss << ";\n let data_shape = i32(";
ss << GetElementAt("uniforms.data_shape", "dim", __var_output.Rank());
ss << ");\n";
if (__param_pad_mode == 0) {
(*ss_ptr) << " if (output_index < lower_pads || output_index >= data_shape + lower_pads) {\n use_pad_value = true;\n";
ss << " if (output_index < lower_pads || output_index >= data_shape + lower_pads) {\n use_pad_value = true;\n";
} else if (__param_pad_mode == 1) {
(*ss_ptr) << " if (output_index < lower_pads) {\n in_coord = 0;\n } else if (output_index >= data_shape + lower_pads) {\n in_coord = data_shape - 1;\n";
ss << " if (output_index < lower_pads) {\n in_coord = 0;\n } else if (output_index >= data_shape + lower_pads) {\n in_coord = data_shape - 1;\n";
} else if (__param_pad_mode == 2) {
(*ss_ptr) << " if (output_index < lower_pads || output_index >= data_shape + lower_pads) {\n in_coord = output_index - lower_pads;\n if (in_coord < 0) {\n in_coord = -in_coord;\n }\n let _2n_1 = 2 * (data_shape - 1);\n in_coord = in_coord % _2n_1;\n if (in_coord >= data_shape) {\n in_coord = _2n_1 - in_coord;\n }\n";
ss << " if (output_index < lower_pads || output_index >= data_shape + lower_pads) {\n in_coord = output_index - lower_pads;\n if (in_coord < 0) {\n in_coord = -in_coord;\n }\n let _2n_1 = 2 * (data_shape - 1);\n in_coord = in_coord % _2n_1;\n if (in_coord >= data_shape) {\n in_coord = _2n_1 - in_coord;\n }\n";
} else {
(*ss_ptr) << " if (output_index < lower_pads) {\n in_coord = data_shape + output_index - lower_pads;\n } else if (output_index >= data_shape + lower_pads) {\n in_coord = output_index - data_shape - lower_pads;\n";
ss << " if (output_index < lower_pads) {\n in_coord = data_shape + output_index - lower_pads;\n } else if (output_index >= data_shape + lower_pads) {\n in_coord = output_index - data_shape - lower_pads;\n";
}
(*ss_ptr) << " } else {\n in_coord = output_index - lower_pads;\n }\n\n input_index += select(u32(in_coord)\n";
ss << " } else {\n in_coord = output_index - lower_pads;\n }\n\n input_index += select(u32(in_coord)\n";
if (__var_output.Rank() > 1) {
(*ss_ptr) << " * ";
(*ss_ptr) << GetElementAt("uniforms.data_stride", "dim", __var_output.Rank() - 1);
(*ss_ptr) << "\n";
ss << " * ";
ss << GetElementAt("uniforms.data_stride", "dim", __var_output.Rank() - 1);
ss << "\n";
}
(*ss_ptr) << " , u32(in_coord), dim == ";
(*ss_ptr) << __var_output.Rank();
(*ss_ptr) << " - 1);\n }\n\n ";
(*ss_ptr) << __var_output.SetByOffset("global_idx", "select(data[input_index], constant_value, use_pad_value)");
(*ss_ptr) << ";\n";
ss << " , u32(in_coord), dim == ";
ss << __var_output.Rank();
ss << " - 1);\n }\n\n ";
ss << __var_output.SetByOffset("global_idx", "select(data[input_index], constant_value, use_pad_value)");
ss << ";\n";
}
MainFunctionEnd();
(*ss_ptr) << "\n";
ss << "\n";
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
MainFunctionStart();
(*ss_ptr) << "\n let index = global_id.x;\n\n output[index] = input[index] * 2.0;\n";
ss << "\n let index = global_id.x;\n\n output[index] = input[index] * 2.0;\n";
MainFunctionEnd();
(*ss_ptr) << "\n";
ss << "\n";
Original file line number Diff line number Diff line change
@@ -1 +1 @@
(*ss_ptr) << "let size = BUFFER_SIZE;\n";
ss << "let size = BUFFER_SIZE;\n";
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
(*ss_ptr) << "let size = ";
(*ss_ptr) << __param_BUFFER_SIZE;
(*ss_ptr) << ";\n";
ss << "let size = ";
ss << __param_BUFFER_SIZE;
ss << ";\n";
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
// 1 | // This is a one line comment
// 2 |
// 3 | @compute @workgroup_size(128)
(*ss_ptr) << "@compute @workgroup_size(128)\n";
ss << "@compute @workgroup_size(128)\n";
// 4 | fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
(*ss_ptr) << "fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {\n";
ss << "fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {\n";
// 5 | let buffer_count = 4; /* This
(*ss_ptr) << " let buffer_count = 4;\n";
ss << " let buffer_count = 4;\n";
// 6 | is
(*ss_ptr) << "\n";
ss << "\n";
// 7 | a
// 8 | multi-line
// 9 | comment */
// 10 | let index = global_id.x;
(*ss_ptr) << " let index = global_id.x;\n";
ss << " let index = global_id.x;\n";
// 11 | }
(*ss_ptr) << "}\n";
ss << "}\n";
// 12 |