From c2d83677161cf51fd312f42de84a340967bceedd Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Tue, 28 Apr 2026 07:41:07 -0700 Subject: [PATCH 01/11] feat: add protoc-gen-pony plugin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generates Pony source code (class val records + sister Codec primitives) from .proto files. Calls into the Pony `protobuf` runtime library at straw-hat-team/trogonai.com/mrmeeseeks/protobuf for WireReader/ WireWriter/Tag/Scalar/WireType/WireError. Built on `google.golang.org/protobuf/compiler/protogen` for descriptor traversal. Pre-injects M-stub mappings so users targeting Pony don't have to set go_package or M-params just to satisfy protogen's Go-import requirement. v1 scope: singular implicit-presence proto3 scalars (bool, int32/64, uint32/64, sint32/64, fixed32/64, sfixed32/64, float, double, string, bytes). Repeated, optional explicit presence, oneofs, maps, embedded messages, and enums emit a `// TODO protoc-gen-pony` placeholder until the corresponding codegen lands. Services (gRPC) are out of scope. End-to-end smoke verified: protoc --pony_out=... produces .pony files that compile cleanly against the runtime and round-trip a User message through encode→bytes→decode with field values preserved. 9 tests pass, all running in parallel via t.Parallel(). 
Signed-off-by: Erik Nilsen --- Taskfile.yml | 7 + cmd/protoc-gen-pony/CHANGELOG.md | 10 + cmd/protoc-gen-pony/README.md | 65 ++++++ cmd/protoc-gen-pony/generate.go | 371 +++++++++++++++++++++++++++++++ cmd/protoc-gen-pony/main.go | 127 +++++++++++ cmd/protoc-gen-pony/main_test.go | 230 +++++++++++++++++++ 6 files changed, 810 insertions(+) create mode 100644 cmd/protoc-gen-pony/CHANGELOG.md create mode 100644 cmd/protoc-gen-pony/README.md create mode 100644 cmd/protoc-gen-pony/generate.go create mode 100644 cmd/protoc-gen-pony/main.go create mode 100644 cmd/protoc-gen-pony/main_test.go diff --git a/Taskfile.yml b/Taskfile.yml index ab01dc7..b706f6a 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -64,6 +64,7 @@ tasks: - echo "Building protoc plugin binaries..." - go build -o protoc-gen-connect-go-servicestruct ./cmd/protoc-gen-connect-go-servicestruct - go build -o protoc-gen-elixir-grpc ./cmd/protoc-gen-elixir-grpc + - go build -o protoc-gen-pony ./cmd/protoc-gen-pony build-plugin-go: desc: Build the Go Connect protoc plugin binary @@ -77,6 +78,12 @@ tasks: - echo "Building Elixir gRPC protoc plugin binary..." - go build -o protoc-gen-elixir-grpc ./cmd/protoc-gen-elixir-grpc + build-plugin-pony: + desc: Build the Pony protoc plugin binary + cmds: + - echo "Building Pony protoc plugin binary..." + - go build -o protoc-gen-pony ./cmd/protoc-gen-pony + clean: desc: Clean build artifacts and coverage files cmds: diff --git a/cmd/protoc-gen-pony/CHANGELOG.md b/cmd/protoc-gen-pony/CHANGELOG.md new file mode 100644 index 0000000..87a09d8 --- /dev/null +++ b/cmd/protoc-gen-pony/CHANGELOG.md @@ -0,0 +1,10 @@ +# Changelog + +## Unreleased + +### Features + +* Initial release. Generates Pony `class val` records + sister `Codec` + primitives for proto3 messages with singular implicit-presence scalar + fields. Repeated, `optional`, oneof, map, embedded-message, and enum + fields surface as `// TODO protoc-gen-pony` placeholders. 
diff --git a/cmd/protoc-gen-pony/README.md b/cmd/protoc-gen-pony/README.md new file mode 100644 index 0000000..ed5e66a --- /dev/null +++ b/cmd/protoc-gen-pony/README.md @@ -0,0 +1,65 @@ +# protoc-gen-pony + +A Protobuf compiler plugin that generates Pony source code — `class val` +records plus sister `Codec` primitives that decode and encode against the +[`protobuf` Pony runtime library][runtime]. + +## Install + +```bash +go install github.com/TrogonStack/protoc-gen/cmd/protoc-gen-pony@latest +``` + +## Usage + +With `protoc`: + +```bash +protoc --pony_out=gen path/to/file.proto +``` + +With [buf]: + +```yaml +# buf.gen.yaml +version: v2 +plugins: + - local: protoc-gen-pony + out: gen +``` + +## Output + +For a `User` message in `acme/v1/user.proto`: + +```protobuf +syntax = "proto3"; +package acme.v1; + +message User { + int32 id = 1; + string name = 2; + bool active = 3; +} +``` + +The plugin writes `gen/acme/v1/user.pony` with a `class val User` record and a +`primitive UserCodec` exposing `decode(reader: WireReader ref): (User val | +WireError)` and `encode(writer: WireWriter ref, msg: User val)`. + +## Runtime requirement + +Generated code calls into the Pony `protobuf` package — `WireReader`, +`WireWriter`, `Tag`, `Scalar`, `WireType`, `WireError`. See [the runtime +sources][runtime]. + +## Coverage + +v1 supports singular implicit-presence proto3 scalars (bool, int32/64, +uint32/64, sint32/64, fixed32/64, sfixed32/64, float, double, string, +bytes). Repeated fields, `optional` explicit presence, oneofs, maps, +embedded messages, and enums emit a `// TODO protoc-gen-pony` comment +until the corresponding codegen lands. Services (gRPC) are out of scope. 
+ +[buf]: https://buf.build +[runtime]: https://github.com/straw-hat-team/trogonai.com/tree/main/mrmeeseeks/protobuf diff --git a/cmd/protoc-gen-pony/generate.go b/cmd/protoc-gen-pony/generate.go new file mode 100644 index 0000000..f100078 --- /dev/null +++ b/cmd/protoc-gen-pony/generate.go @@ -0,0 +1,371 @@ +package main + +import ( + "fmt" + "strings" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/reflect/protoreflect" +) + +// generateFile emits one `.pony` source per `.proto` input. Pony has no +// notion of a `package` keyword inside the file (packages are directories), +// so we just stream each top-level message in the file. Nested messages +// flatten with `_` (Outer_Inner) — see collectAndEmitMessages. +// +// Output path mirrors the proto file's path layout (e.g. +// `acme/users/v1/user.proto` → `acme/users/v1/user.pony`). We compute it +// from file.Desc.Path() directly because protogen's GeneratedFilenamePrefix +// is prefixed with the Go import path, which is irrelevant for Pony output. +func generateFile(plugin *protogen.Plugin, file *protogen.File) { + if len(file.Messages) == 0 && len(file.Enums) == 0 { + return + } + outPath := strings.TrimSuffix(file.Desc.Path(), ".proto") + ".pony" + g := plugin.NewGeneratedFile(outPath, protogen.GoImportPath("")) + + // Use `//` not `"""` — Pony allows only one package docstring per + // directory, and the runtime's protobuf.pony already owns it. + g.P(`// Generated by protoc-gen-pony. DO NOT EDIT.`) + g.P(`// Source: `, file.Desc.Path()) + g.P() + + collectAndEmitMessages(g, file.Messages, "") + for _, enum := range file.Enums { + emitEnumTodo(g, enum, "") + } +} + +// collectAndEmitMessages walks a slice of messages depth-first, flattening +// nested types into their parent's namespace via `Outer_Inner` mangling. +// Messages are emitted in source order — protogen already gives us a +// deterministic walk, so we don't need to sort. 
+func collectAndEmitMessages(g *protogen.GeneratedFile, messages []*protogen.Message, namePrefix string) { + for _, msg := range messages { + flatName := namePrefix + string(msg.Desc.Name()) + emitMessage(g, msg, flatName) + if len(msg.Messages) > 0 { + collectAndEmitMessages(g, msg.Messages, flatName+"_") + } + for _, enum := range msg.Enums { + emitEnumTodo(g, enum, flatName+"_") + } + } +} + +func emitMessage(g *protogen.GeneratedFile, msg *protogen.Message, className string) { + supported, unsupported := classifyFields(msg.Fields) + + emitClass(g, className, msg.Fields, supported, unsupported) + g.P() + emitCodec(g, className, supported) + g.P() +} + +func emitClass(g *protogen.GeneratedFile, className string, all []*protogen.Field, supported, unsupported []*protogen.Field) { + g.P(`class val `, className) + for _, field := range all { + if isSupported(field) { + g.P(` let `, fieldName(field), `: `, ponyType(field)) + } else { + g.P(` // TODO protoc-gen-pony: field `, fieldName(field), ` (`, fieldShape(field), `)`) + } + } + g.P() + emitConstructor(g, supported) + _ = unsupported +} + +func emitConstructor(g *protogen.GeneratedFile, supported []*protogen.Field) { + if len(supported) == 0 { + g.P(` new val create() => None`) + return + } + g.P(` new val create(`) + for i, field := range supported { + suffix := "," + if i == len(supported)-1 { + suffix = ")" + } + g.P(` `, fieldName(field), `': `, ponyType(field), ` = `, ponyDefault(field), suffix) + } + g.P(` =>`) + for _, field := range supported { + g.P(` `, fieldName(field), ` = `, fieldName(field), `'`) + } +} + +func emitCodec(g *protogen.GeneratedFile, className string, supported []*protogen.Field) { + g.P(`primitive `, className, `Codec`) + emitDecode(g, className, supported) + g.P() + emitEncode(g, className, supported) +} + +func emitDecode(g *protogen.GeneratedFile, className string, supported []*protogen.Field) { + g.P(` fun decode(reader: WireReader ref): (`, className, ` val | WireError) =>`) + for 
_, field := range supported { + g.P(` var `, fieldName(field), `: `, ponyType(field), ` = `, ponyDefault(field)) + } + g.P(` while not reader.at_end() do`) + g.P(` match reader.read_tag()`) + g.P(` | let t: Tag =>`) + g.P(` match (t.field_number, t.wire_type)`) + for _, field := range supported { + g.P(` | (`, field.Desc.Number(), `, `, ponyWireType(field), `) =>`) + g.P(` match `, ponyReadExpr(field)) + g.P(` | let v: `, ponyType(field), ` => `, fieldName(field), ` = v`) + g.P(` | let e: WireError => return e`) + g.P(` end`) + } + g.P(` else`) + g.P(` match reader.skip(t.wire_type)`) + g.P(` | None => None`) + g.P(` | let e: WireError => return e`) + g.P(` end`) + g.P(` end`) + g.P(` | let e: WireError => return e`) + g.P(` end`) + g.P(` end`) + emitConstructorCall(g, className, supported) +} + +func emitConstructorCall(g *protogen.GeneratedFile, className string, supported []*protogen.Field) { + if len(supported) == 0 { + g.P(` `, className) + return + } + parts := make([]string, len(supported)) + for i, field := range supported { + parts[i] = fieldName(field) + } + g.P(` `, className, `(`, strings.Join(parts, ", "), `)`) +} + +func emitEncode(g *protogen.GeneratedFile, className string, supported []*protogen.Field) { + g.P(` fun encode(writer: WireWriter ref, msg: `, className, ` val) =>`) + if len(supported) == 0 { + g.P(` None`) + return + } + for _, field := range supported { + emitEncodeField(g, field) + } +} + +func emitEncodeField(g *protogen.GeneratedFile, field *protogen.Field) { + num := field.Desc.Number() + if field.Desc.Kind() == protoreflect.StringKind { + // write_string_field handles the empty-string skip internally. 
+ g.P(` writer.write_string_field(`, num, `, msg.`, fieldName(field), `)`) + return + } + g.P(` if `, presenceCheck(field), ` then`) + g.P(` writer.write_tag(Tag(`, num, `, `, ponyWireType(field), `))`) + g.P(` `, ponyWriteCall(field, "msg."+fieldName(field))) + g.P(` end`) +} + +// emitEnumTodo lays down a placeholder for enums until enum codegen lands. +func emitEnumTodo(g *protogen.GeneratedFile, enum *protogen.Enum, namePrefix string) { + g.P(`// TODO protoc-gen-pony: enum `, namePrefix, enum.Desc.Name()) + g.P() +} + +// ── classifiers ───────────────────────────────────────────────────────── + +// isSupported is the v1 cut: singular implicit-presence proto3 scalars only. +// Repeated, optional explicit presence, oneofs, maps, messages, enums, and +// groups all surface as TODO comments until the corresponding codegen lands. +func isSupported(field *protogen.Field) bool { + if field.Desc.IsList() || field.Desc.IsMap() { + return false + } + if field.Desc.HasPresence() && !field.Desc.HasOptionalKeyword() { + // Editions field with explicit presence; defer until presence codegen. + // (HasOptionalKeyword catches the proto3 `optional` case below.) + } + if field.Desc.HasOptionalKeyword() { + return false + } + if field.Oneof != nil && !field.Desc.HasOptionalKeyword() { + return false + } + switch field.Desc.Kind() { + case protoreflect.MessageKind, protoreflect.GroupKind, protoreflect.EnumKind: + return false + } + return true +} + +func classifyFields(fields []*protogen.Field) (supported, unsupported []*protogen.Field) { + for _, f := range fields { + if isSupported(f) { + supported = append(supported, f) + } else { + unsupported = append(unsupported, f) + } + } + return +} + +// fieldShape returns a short string describing why a field is unsupported, +// for the TODO comment. 
+func fieldShape(field *protogen.Field) string { + parts := []string{} + if field.Desc.IsMap() { + parts = append(parts, "map") + } else if field.Desc.IsList() { + parts = append(parts, "repeated") + } + if field.Desc.HasOptionalKeyword() { + parts = append(parts, "optional") + } + if field.Oneof != nil { + parts = append(parts, "oneof") + } + parts = append(parts, field.Desc.Kind().String()) + return strings.Join(parts, " ") +} + +// fieldName returns the Pony-side identifier for a proto field. Proto field +// names are already snake_case so this is the identity for now; the helper +// exists so future name-collision handling lives in one place. +func fieldName(field *protogen.Field) string { + return string(field.Desc.Name()) +} + +// ── scalar type lookup ────────────────────────────────────────────────── + +func ponyType(field *protogen.Field) string { + switch field.Desc.Kind() { + case protoreflect.BoolKind: + return "Bool" + case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: + return "I32" + case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: + return "I64" + case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: + return "U32" + case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: + return "U64" + case protoreflect.FloatKind: + return "F32" + case protoreflect.DoubleKind: + return "F64" + case protoreflect.StringKind: + return "String val" + case protoreflect.BytesKind: + return "Array[U8] val" + } + return "" +} + +func ponyDefault(field *protogen.Field) string { + switch field.Desc.Kind() { + case protoreflect.BoolKind: + return "false" + case protoreflect.FloatKind, protoreflect.DoubleKind: + return "0.0" + case protoreflect.StringKind: + return `""` + case protoreflect.BytesKind: + return "recover val Array[U8] end" + } + return "0" +} + +func ponyWireType(field *protogen.Field) string { + switch field.Desc.Kind() { + case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, 
protoreflect.FloatKind: + return "WireFixed32" + case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: + return "WireFixed64" + case protoreflect.StringKind, protoreflect.BytesKind: + return "WireLenDelim" + } + return "WireVarint" +} + +func ponyReadExpr(field *protogen.Field) string { + switch field.Desc.Kind() { + case protoreflect.BoolKind: + return "Scalar.read_bool(reader)" + case protoreflect.Int32Kind: + return "Scalar.read_int32(reader)" + case protoreflect.Int64Kind: + return "Scalar.read_int64(reader)" + case protoreflect.Uint32Kind: + return "Scalar.read_uint32(reader)" + case protoreflect.Uint64Kind: + return "Scalar.read_uint64(reader)" + case protoreflect.Sint32Kind: + return "Scalar.read_sint32(reader)" + case protoreflect.Sint64Kind: + return "Scalar.read_sint64(reader)" + case protoreflect.Fixed32Kind: + return "Scalar.read_fixed32(reader)" + case protoreflect.Fixed64Kind: + return "Scalar.read_fixed64(reader)" + case protoreflect.Sfixed32Kind: + return "Scalar.read_sfixed32(reader)" + case protoreflect.Sfixed64Kind: + return "Scalar.read_sfixed64(reader)" + case protoreflect.FloatKind: + return "Scalar.read_float(reader)" + case protoreflect.DoubleKind: + return "Scalar.read_double(reader)" + case protoreflect.StringKind: + return "reader.read_string()" + case protoreflect.BytesKind: + return "reader.read_len_delim()" + } + return "" +} + +func ponyWriteCall(field *protogen.Field, valueRef string) string { + switch field.Desc.Kind() { + case protoreflect.BoolKind: + return fmt.Sprintf("Scalar.write_bool(writer, %s)", valueRef) + case protoreflect.Int32Kind: + return fmt.Sprintf("Scalar.write_int32(writer, %s)", valueRef) + case protoreflect.Int64Kind: + return fmt.Sprintf("Scalar.write_int64(writer, %s)", valueRef) + case protoreflect.Uint32Kind: + return fmt.Sprintf("Scalar.write_uint32(writer, %s)", valueRef) + case protoreflect.Uint64Kind: + return fmt.Sprintf("Scalar.write_uint64(writer, %s)", valueRef) + case 
protoreflect.Sint32Kind: + return fmt.Sprintf("Scalar.write_sint32(writer, %s)", valueRef) + case protoreflect.Sint64Kind: + return fmt.Sprintf("Scalar.write_sint64(writer, %s)", valueRef) + case protoreflect.Fixed32Kind: + return fmt.Sprintf("Scalar.write_fixed32(writer, %s)", valueRef) + case protoreflect.Fixed64Kind: + return fmt.Sprintf("Scalar.write_fixed64(writer, %s)", valueRef) + case protoreflect.Sfixed32Kind: + return fmt.Sprintf("Scalar.write_sfixed32(writer, %s)", valueRef) + case protoreflect.Sfixed64Kind: + return fmt.Sprintf("Scalar.write_sfixed64(writer, %s)", valueRef) + case protoreflect.FloatKind: + return fmt.Sprintf("Scalar.write_float(writer, %s)", valueRef) + case protoreflect.DoubleKind: + return fmt.Sprintf("Scalar.write_double(writer, %s)", valueRef) + case protoreflect.BytesKind: + return fmt.Sprintf("writer.write_len_delim(%s)", valueRef) + } + return "" +} + +func presenceCheck(field *protogen.Field) string { + ref := "msg." + fieldName(field) + switch field.Desc.Kind() { + case protoreflect.BoolKind: + return ref + case protoreflect.FloatKind, protoreflect.DoubleKind: + return ref + " != 0.0" + case protoreflect.BytesKind: + return ref + ".size() > 0" + } + return ref + " != 0" +} diff --git a/cmd/protoc-gen-pony/main.go b/cmd/protoc-gen-pony/main.go new file mode 100644 index 0000000..14800cb --- /dev/null +++ b/cmd/protoc-gen-pony/main.go @@ -0,0 +1,127 @@ +// protoc-gen-pony is a plugin for the Protobuf compiler that generates +// Pony code (class val records + sister Codec primitives that decode/encode +// against the Pony `protobuf` runtime library). To use it, build this +// program and make it available on your PATH as protoc-gen-pony. 
+// +// With protoc: +// +// protoc --pony_out=gen path/to/file.proto +// +// With [buf], your buf.gen.yaml will look like this: +// +// version: v2 +// plugins: +// - local: protoc-gen-pony +// out: gen +// +// Generated files import the `protobuf` runtime library — see +// https://github.com/straw-hat-team/trogonai.com/tree/main/mrmeeseeks/protobuf +// for the Pony source. The runtime exposes WireReader/WireWriter, Tag, +// Scalar, the WireType union, and the WireError typed-error union. +// +// [buf]: https://buf.build +package main + +import ( + "flag" + "fmt" + "io" + "os" + "strings" + + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/pluginpb" +) + +var ( + // Set by ldflags during build time. + version = "dev" + commit = "unknown" + date = "unknown" +) + +const usage = "\n\nFlags:\n -h, --help\tPrint this help and exit.\n --version\tPrint the version and exit." + +// goImportStub keeps protogen happy on non-Go targets. protogen.Options.New +// errors out if any input file lacks a Go import path; users targeting Pony +// shouldn't have to set go_package or M-mappings just for that. We prepend +// a stub M-entry for every file before calling protogen — user-provided +// params still win because they appear later in the Parameter string. 
+const goImportStub = "protoc-gen-pony/stub" + +func main() { + if len(os.Args) == 2 && os.Args[1] == "--version" { + fmt.Printf("protoc-gen-pony %s (commit: %s, built: %s)\n", version, commit, date) + os.Exit(0) + } + if len(os.Args) == 2 && (os.Args[1] == "-h" || os.Args[1] == "--help") { + if _, err := fmt.Fprintln(os.Stdout, usage); err != nil { + os.Exit(1) + } + os.Exit(0) + } + if len(os.Args) != 1 { + if _, err := fmt.Fprintln(os.Stderr, usage); err != nil { + os.Exit(1) + } + os.Exit(1) + } + + in, err := io.ReadAll(os.Stdin) + if err != nil { + fmt.Fprintf(os.Stderr, "protoc-gen-pony: read stdin: %v\n", err) + os.Exit(1) + } + var req pluginpb.CodeGeneratorRequest + if err := proto.Unmarshal(in, &req); err != nil { + fmt.Fprintf(os.Stderr, "protoc-gen-pony: unmarshal request: %v\n", err) + os.Exit(1) + } + injectGoImportStubs(&req) + + var flagSet flag.FlagSet + plugin, err := protogen.Options{ParamFunc: flagSet.Set}.New(&req) + if err != nil { + fmt.Fprintf(os.Stderr, "protoc-gen-pony: %v\n", err) + os.Exit(1) + } + plugin.SupportedFeatures = + uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) | + uint64(pluginpb.CodeGeneratorResponse_FEATURE_SUPPORTS_EDITIONS) + plugin.SupportedEditionsMinimum = descriptorpb.Edition_EDITION_PROTO2 + plugin.SupportedEditionsMaximum = descriptorpb.Edition_EDITION_2024 + + for _, file := range plugin.Files { + if file.Generate { + generateFile(plugin, file) + } + } + + resp := plugin.Response() + out, err := proto.Marshal(resp) + if err != nil { + fmt.Fprintf(os.Stderr, "protoc-gen-pony: marshal response: %v\n", err) + os.Exit(1) + } + if _, err := os.Stdout.Write(out); err != nil { + fmt.Fprintf(os.Stderr, "protoc-gen-pony: write stdout: %v\n", err) + os.Exit(1) + } +} + +// injectGoImportStubs prepends an M=stub for every file in the request +// so protogen.New doesn't reject the input. User-provided M-params still +// take precedence because they appear later in the comma-separated string. 
+func injectGoImportStubs(req *pluginpb.CodeGeneratorRequest) { + parts := make([]string, 0, len(req.GetProtoFile())+1) + for _, file := range req.GetProtoFile() { + parts = append(parts, "M"+file.GetName()+"="+goImportStub) + } + if existing := req.GetParameter(); existing != "" { + parts = append(parts, existing) + } + combined := strings.Join(parts, ",") + req.Parameter = &combined +} diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go new file mode 100644 index 0000000..7a573ae --- /dev/null +++ b/cmd/protoc-gen-pony/main_test.go @@ -0,0 +1,230 @@ +package main + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/pluginpb" +) + +// runPlugin builds a synthetic plugin invocation around the supplied +// FileDescriptorProtos and returns the generated content for the file +// matching `wantFilename`. Test helper for the rest of the file. +func runPlugin(t *testing.T, files []*descriptorpb.FileDescriptorProto, wantFilename string) string { + t.Helper() + toGenerate := make([]string, 0, len(files)) + mappings := make([]string, 0, len(files)) + for _, f := range files { + toGenerate = append(toGenerate, f.GetName()) + // protogen insists on a Go import path even for non-Go targets; + // supplying M= stops it from erroring on missing go_package. 
+ mappings = append(mappings, "M"+f.GetName()+"=example.com/test") + } + param := strings.Join(mappings, ",") + req := &pluginpb.CodeGeneratorRequest{ + FileToGenerate: toGenerate, + ProtoFile: files, + Parameter: &param, + } + plugin, err := protogen.Options{}.New(req) + require.NoError(t, err) + for _, f := range plugin.Files { + if f.Generate { + generateFile(plugin, f) + } + } + resp := plugin.Response() + require.Empty(t, resp.GetError(), "plugin reported error") + for _, f := range resp.GetFile() { + if f.GetName() == wantFilename { + return f.GetContent() + } + } + t.Fatalf("expected output file %q not in plugin response (got %v)", wantFilename, fileNames(resp.GetFile())) + return "" +} + +func fileNames(files []*pluginpb.CodeGeneratorResponse_File) []string { + names := make([]string, len(files)) + for i, f := range files { + names[i] = f.GetName() + } + return names +} + +// scalarMessageProto returns a FileDescriptor with one User message +// containing one int32, one string, and one bool field. 
+func scalarMessageProto() *descriptorpb.FileDescriptorProto { + return &descriptorpb.FileDescriptorProto{ + Name: proto.String("user.proto"), + Package: proto.String("acme.v1"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("User"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("id"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(), + JsonName: proto.String("id"), + }, + { + Name: proto.String("name"), + Number: proto.Int32(2), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("name"), + }, + { + Name: proto.String("active"), + Number: proto.Int32(3), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_BOOL.Enum(), + JsonName: proto.String("active"), + }, + }, + }, + }, + } +} + +func TestScalarMessage_ClassDecl(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{scalarMessageProto()}, "user.pony") + assert.Contains(t, out, "class val User") + assert.Contains(t, out, "let id: I32") + assert.Contains(t, out, "let name: String val") + assert.Contains(t, out, "let active: Bool") +} + +func TestScalarMessage_Constructor(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{scalarMessageProto()}, "user.pony") + assert.Contains(t, out, "new val create(") + assert.Contains(t, out, "id': I32 = 0") + assert.Contains(t, out, `name': String val = ""`) + assert.Contains(t, out, "active': Bool = false") +} + +func TestScalarMessage_Codec(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{scalarMessageProto()}, "user.pony") + assert.Contains(t, out, "primitive UserCodec") + assert.Contains(t, out, "fun decode(reader: 
WireReader ref): (User val | WireError)") + assert.Contains(t, out, "fun encode(writer: WireWriter ref, msg: User val)") +} + +func TestScalarMessage_DecodeDispatch(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{scalarMessageProto()}, "user.pony") + assert.Contains(t, out, "(1, WireVarint)") + assert.Contains(t, out, "(2, WireLenDelim)") + assert.Contains(t, out, "(3, WireVarint)") + assert.Contains(t, out, "Scalar.read_int32(reader)") + assert.Contains(t, out, "reader.read_string()") + assert.Contains(t, out, "Scalar.read_bool(reader)") +} + +func TestScalarMessage_EncodePresence(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{scalarMessageProto()}, "user.pony") + // String fields go through write_string_field (handles empty check). + assert.Contains(t, out, "writer.write_string_field(2, msg.name)") + // Numeric fields gate emission on != 0. + assert.Contains(t, out, "if msg.id != 0 then") + // Bool fields gate on the value itself. 
+ assert.Contains(t, out, "if msg.active then") +} + +func TestEmptyMessage(t *testing.T) { + t.Parallel() + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("empty.proto"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("Empty")}, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "empty.pony") + assert.Contains(t, out, "class val Empty") + assert.Contains(t, out, "new val create() => None") + assert.Contains(t, out, "primitive EmptyCodec") +} + +func TestUnsupportedShapesEmitTodo(t *testing.T) { + t.Parallel() + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("zoo.proto"), + Package: proto.String("zoo"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Zoo"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("id"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(), + }, + { + Name: proto.String("tags"), + Number: proto.Int32(2), + Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + }, + }, + }, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "zoo.pony") + assert.Contains(t, out, "let id: I32") + assert.Contains(t, out, "TODO protoc-gen-pony: field tags") + // Repeated fields don't appear in the constructor's supported list. 
+ supportedConstructorLine := strings.Contains(out, "tags':") + assert.False(t, supportedConstructorLine, "tags should be skipped from the constructor signature") +} + +func TestNestedMessageFlatNaming(t *testing.T) { + t.Parallel() + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("nested.proto"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Outer"), + NestedType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Inner"), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("value"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_INT64.Enum(), + }, + }, + }, + }, + }, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "nested.pony") + assert.Contains(t, out, "class val Outer") + assert.Contains(t, out, "class val Outer_Inner") + assert.Contains(t, out, "primitive Outer_InnerCodec") +} + +func TestFileHeaderHasSourceComment(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{scalarMessageProto()}, "user.pony") + assert.Contains(t, out, "// Generated by protoc-gen-pony. DO NOT EDIT.") + assert.Contains(t, out, "// Source: user.proto") +} From 2a253ff3ca3425206b83933fbad97384cd688552 Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Tue, 28 Apr 2026 08:05:15 -0700 Subject: [PATCH 02/11] refactor: collapse scalar dispatch tables and trim duplication Five parallel switches over protoreflect.Kind (ponyType / ponyDefault / ponyWireType / ponyReadExpr / ponyWriteCall / presenceCheck) collapse into a single scalarSpecs map. Helpers become one-line lookups; adding a Kind now requires editing one entry instead of risking a missed update across five tables. Also: - Fix isSupported's empty-if branch for editions explicit-presence fields; they now correctly fall into the unsupported (TODO) bucket. 
- Drop classifyFields's unused `unsupported` return and the dead `_ = unsupported` at the call site. emitClass now walks msg.Fields directly, branching on isSupported per field. - Inline fieldName (one-liner placeholder) into its 8 call sites. - Wire test runPlugin through injectGoImportStubs so the production workaround gets test coverage and the test stops re-implementing it. - Add a `field()` test fixture builder; collapses 6-line FieldDescriptor literals to one-liners. - Trim a duplicate godoc and a WHAT-comment. Net: -147 lines across the three files. End-to-end protoc smoke still generates Pony that compiles against the runtime and round-trips a User message correctly. 9 tests pass with -race. Signed-off-by: Erik Nilsen --- cmd/protoc-gen-pony/generate.go | 273 ++++++++++++++----------------- cmd/protoc-gen-pony/main.go | 3 - cmd/protoc-gen-pony/main_test.go | 70 +++----- 3 files changed, 147 insertions(+), 199 deletions(-) diff --git a/cmd/protoc-gen-pony/generate.go b/cmd/protoc-gen-pony/generate.go index f100078..1983246 100644 --- a/cmd/protoc-gen-pony/generate.go +++ b/cmd/protoc-gen-pony/generate.go @@ -54,26 +54,24 @@ func collectAndEmitMessages(g *protogen.GeneratedFile, messages []*protogen.Mess } func emitMessage(g *protogen.GeneratedFile, msg *protogen.Message, className string) { - supported, unsupported := classifyFields(msg.Fields) - - emitClass(g, className, msg.Fields, supported, unsupported) + supported := supportedFields(msg.Fields) + emitClass(g, className, msg.Fields, supported) g.P() emitCodec(g, className, supported) g.P() } -func emitClass(g *protogen.GeneratedFile, className string, all []*protogen.Field, supported, unsupported []*protogen.Field) { +func emitClass(g *protogen.GeneratedFile, className string, all, supported []*protogen.Field) { g.P(`class val `, className) for _, field := range all { if isSupported(field) { - g.P(` let `, fieldName(field), `: `, ponyType(field)) + g.P(` let `, field.Desc.Name(), `: `, ponyType(field)) } 
else { - g.P(` // TODO protoc-gen-pony: field `, fieldName(field), ` (`, fieldShape(field), `)`) + g.P(` // TODO protoc-gen-pony: field `, field.Desc.Name(), ` (`, fieldShape(field), `)`) } } g.P() emitConstructor(g, supported) - _ = unsupported } func emitConstructor(g *protogen.GeneratedFile, supported []*protogen.Field) { @@ -87,11 +85,11 @@ func emitConstructor(g *protogen.GeneratedFile, supported []*protogen.Field) { if i == len(supported)-1 { suffix = ")" } - g.P(` `, fieldName(field), `': `, ponyType(field), ` = `, ponyDefault(field), suffix) + g.P(` `, field.Desc.Name(), `': `, ponyType(field), ` = `, ponyDefault(field), suffix) } g.P(` =>`) for _, field := range supported { - g.P(` `, fieldName(field), ` = `, fieldName(field), `'`) + g.P(` `, field.Desc.Name(), ` = `, field.Desc.Name(), `'`) } } @@ -105,7 +103,7 @@ func emitCodec(g *protogen.GeneratedFile, className string, supported []*protoge func emitDecode(g *protogen.GeneratedFile, className string, supported []*protogen.Field) { g.P(` fun decode(reader: WireReader ref): (`, className, ` val | WireError) =>`) for _, field := range supported { - g.P(` var `, fieldName(field), `: `, ponyType(field), ` = `, ponyDefault(field)) + g.P(` var `, field.Desc.Name(), `: `, ponyType(field), ` = `, ponyDefault(field)) } g.P(` while not reader.at_end() do`) g.P(` match reader.read_tag()`) @@ -114,7 +112,7 @@ func emitDecode(g *protogen.GeneratedFile, className string, supported []*protog for _, field := range supported { g.P(` | (`, field.Desc.Number(), `, `, ponyWireType(field), `) =>`) g.P(` match `, ponyReadExpr(field)) - g.P(` | let v: `, ponyType(field), ` => `, fieldName(field), ` = v`) + g.P(` | let v: `, ponyType(field), ` => `, field.Desc.Name(), ` = v`) g.P(` | let e: WireError => return e`) g.P(` end`) } @@ -137,7 +135,7 @@ func emitConstructorCall(g *protogen.GeneratedFile, className string, supported } parts := make([]string, len(supported)) for i, field := range supported { - parts[i] = 
fieldName(field) + parts[i] = string(field.Desc.Name()) } g.P(` `, className, `(`, strings.Join(parts, ", "), `)`) } @@ -157,38 +155,33 @@ func emitEncodeField(g *protogen.GeneratedFile, field *protogen.Field) { num := field.Desc.Number() if field.Desc.Kind() == protoreflect.StringKind { // write_string_field handles the empty-string skip internally. - g.P(` writer.write_string_field(`, num, `, msg.`, fieldName(field), `)`) + g.P(` writer.write_string_field(`, num, `, msg.`, field.Desc.Name(), `)`) return } g.P(` if `, presenceCheck(field), ` then`) g.P(` writer.write_tag(Tag(`, num, `, `, ponyWireType(field), `))`) - g.P(` `, ponyWriteCall(field, "msg."+fieldName(field))) + g.P(` `, ponyWriteCall(field, "msg."+string(field.Desc.Name()))) g.P(` end`) } -// emitEnumTodo lays down a placeholder for enums until enum codegen lands. func emitEnumTodo(g *protogen.GeneratedFile, enum *protogen.Enum, namePrefix string) { g.P(`// TODO protoc-gen-pony: enum `, namePrefix, enum.Desc.Name()) g.P() } -// ── classifiers ───────────────────────────────────────────────────────── - // isSupported is the v1 cut: singular implicit-presence proto3 scalars only. -// Repeated, optional explicit presence, oneofs, maps, messages, enums, and -// groups all surface as TODO comments until the corresponding codegen lands. +// Repeated, optional explicit presence (proto3 `optional` and editions +// EXPLICIT presence), oneofs, maps, messages, enums, and groups all surface +// as TODO comments until the corresponding codegen lands. func isSupported(field *protogen.Field) bool { if field.Desc.IsList() || field.Desc.IsMap() { return false } - if field.Desc.HasPresence() && !field.Desc.HasOptionalKeyword() { - // Editions field with explicit presence; defer until presence codegen. - // (HasOptionalKeyword catches the proto3 `optional` case below.) - } - if field.Desc.HasOptionalKeyword() { + // Editions fields with explicit presence (and proto3 `optional`). 
+ if field.Desc.HasPresence() { return false } - if field.Oneof != nil && !field.Desc.HasOptionalKeyword() { + if field.Oneof != nil { return false } switch field.Desc.Kind() { @@ -198,15 +191,14 @@ func isSupported(field *protogen.Field) bool { return true } -func classifyFields(fields []*protogen.Field) (supported, unsupported []*protogen.Field) { +func supportedFields(fields []*protogen.Field) []*protogen.Field { + var out []*protogen.Field for _, f := range fields { if isSupported(f) { - supported = append(supported, f) - } else { - unsupported = append(unsupported, f) + out = append(out, f) } } - return + return out } // fieldShape returns a short string describing why a field is unsupported, @@ -228,144 +220,125 @@ func fieldShape(field *protogen.Field) string { return strings.Join(parts, " ") } -// fieldName returns the Pony-side identifier for a proto field. Proto field -// names are already snake_case so this is the identity for now; the helper -// exists so future name-collision handling lives in one place. -func fieldName(field *protogen.Field) string { - return string(field.Desc.Name()) +// scalarSpec lookups for Pony codegen. Single source of truth for the +// FieldKind → (Pony type, default, wire type, read expr, write fmt, +// presence-check fmt) mapping. writeFmt and presence are fmt.Sprintf +// patterns with one `%s` for the value reference; an empty writeFmt means +// the kind is special-cased in emitEncodeField (StringKind goes through +// write_string_field). Adding a new Kind requires editing one entry. 
+type scalarSpec struct { + ponyType string + ponyDefault string + wireType string + readExpr string + writeFmt string + presence string } -// ── scalar type lookup ────────────────────────────────────────────────── +var scalarSpecs = map[protoreflect.Kind]scalarSpec{ + protoreflect.BoolKind: { + ponyType: "Bool", ponyDefault: "false", wireType: "WireVarint", + readExpr: "Scalar.read_bool(reader)", writeFmt: "Scalar.write_bool(writer, %s)", + presence: "%s", + }, + protoreflect.Int32Kind: { + ponyType: "I32", ponyDefault: "0", wireType: "WireVarint", + readExpr: "Scalar.read_int32(reader)", writeFmt: "Scalar.write_int32(writer, %s)", + presence: "%s != 0", + }, + protoreflect.Int64Kind: { + ponyType: "I64", ponyDefault: "0", wireType: "WireVarint", + readExpr: "Scalar.read_int64(reader)", writeFmt: "Scalar.write_int64(writer, %s)", + presence: "%s != 0", + }, + protoreflect.Uint32Kind: { + ponyType: "U32", ponyDefault: "0", wireType: "WireVarint", + readExpr: "Scalar.read_uint32(reader)", writeFmt: "Scalar.write_uint32(writer, %s)", + presence: "%s != 0", + }, + protoreflect.Uint64Kind: { + ponyType: "U64", ponyDefault: "0", wireType: "WireVarint", + readExpr: "Scalar.read_uint64(reader)", writeFmt: "Scalar.write_uint64(writer, %s)", + presence: "%s != 0", + }, + protoreflect.Sint32Kind: { + ponyType: "I32", ponyDefault: "0", wireType: "WireVarint", + readExpr: "Scalar.read_sint32(reader)", writeFmt: "Scalar.write_sint32(writer, %s)", + presence: "%s != 0", + }, + protoreflect.Sint64Kind: { + ponyType: "I64", ponyDefault: "0", wireType: "WireVarint", + readExpr: "Scalar.read_sint64(reader)", writeFmt: "Scalar.write_sint64(writer, %s)", + presence: "%s != 0", + }, + protoreflect.Fixed32Kind: { + ponyType: "U32", ponyDefault: "0", wireType: "WireFixed32", + readExpr: "Scalar.read_fixed32(reader)", writeFmt: "Scalar.write_fixed32(writer, %s)", + presence: "%s != 0", + }, + protoreflect.Fixed64Kind: { + ponyType: "U64", ponyDefault: "0", wireType: "WireFixed64", + 
readExpr: "Scalar.read_fixed64(reader)", writeFmt: "Scalar.write_fixed64(writer, %s)", + presence: "%s != 0", + }, + protoreflect.Sfixed32Kind: { + ponyType: "I32", ponyDefault: "0", wireType: "WireFixed32", + readExpr: "Scalar.read_sfixed32(reader)", writeFmt: "Scalar.write_sfixed32(writer, %s)", + presence: "%s != 0", + }, + protoreflect.Sfixed64Kind: { + ponyType: "I64", ponyDefault: "0", wireType: "WireFixed64", + readExpr: "Scalar.read_sfixed64(reader)", writeFmt: "Scalar.write_sfixed64(writer, %s)", + presence: "%s != 0", + }, + protoreflect.FloatKind: { + ponyType: "F32", ponyDefault: "0.0", wireType: "WireFixed32", + readExpr: "Scalar.read_float(reader)", writeFmt: "Scalar.write_float(writer, %s)", + presence: "%s != 0.0", + }, + protoreflect.DoubleKind: { + ponyType: "F64", ponyDefault: "0.0", wireType: "WireFixed64", + readExpr: "Scalar.read_double(reader)", writeFmt: "Scalar.write_double(writer, %s)", + presence: "%s != 0.0", + }, + protoreflect.StringKind: { + ponyType: "String val", ponyDefault: `""`, wireType: "WireLenDelim", + readExpr: "reader.read_string()", + // writeFmt empty: emitEncodeField special-cases StringKind through + // write_string_field. presence empty for the same reason. 
+ }, + protoreflect.BytesKind: { + ponyType: "Array[U8] val", ponyDefault: "recover val Array[U8] end", wireType: "WireLenDelim", + readExpr: "reader.read_len_delim()", writeFmt: "writer.write_len_delim(%s)", + presence: "%s.size() > 0", + }, +} func ponyType(field *protogen.Field) string { - switch field.Desc.Kind() { - case protoreflect.BoolKind: - return "Bool" - case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: - return "I32" - case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: - return "I64" - case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: - return "U32" - case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: - return "U64" - case protoreflect.FloatKind: - return "F32" - case protoreflect.DoubleKind: - return "F64" - case protoreflect.StringKind: - return "String val" - case protoreflect.BytesKind: - return "Array[U8] val" - } - return "" + return scalarSpecs[field.Desc.Kind()].ponyType } func ponyDefault(field *protogen.Field) string { - switch field.Desc.Kind() { - case protoreflect.BoolKind: - return "false" - case protoreflect.FloatKind, protoreflect.DoubleKind: - return "0.0" - case protoreflect.StringKind: - return `""` - case protoreflect.BytesKind: - return "recover val Array[U8] end" - } - return "0" + return scalarSpecs[field.Desc.Kind()].ponyDefault } func ponyWireType(field *protogen.Field) string { - switch field.Desc.Kind() { - case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind: - return "WireFixed32" - case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: - return "WireFixed64" - case protoreflect.StringKind, protoreflect.BytesKind: - return "WireLenDelim" - } - return "WireVarint" + return scalarSpecs[field.Desc.Kind()].wireType } func ponyReadExpr(field *protogen.Field) string { - switch field.Desc.Kind() { - case protoreflect.BoolKind: - return "Scalar.read_bool(reader)" - case protoreflect.Int32Kind: - return 
"Scalar.read_int32(reader)" - case protoreflect.Int64Kind: - return "Scalar.read_int64(reader)" - case protoreflect.Uint32Kind: - return "Scalar.read_uint32(reader)" - case protoreflect.Uint64Kind: - return "Scalar.read_uint64(reader)" - case protoreflect.Sint32Kind: - return "Scalar.read_sint32(reader)" - case protoreflect.Sint64Kind: - return "Scalar.read_sint64(reader)" - case protoreflect.Fixed32Kind: - return "Scalar.read_fixed32(reader)" - case protoreflect.Fixed64Kind: - return "Scalar.read_fixed64(reader)" - case protoreflect.Sfixed32Kind: - return "Scalar.read_sfixed32(reader)" - case protoreflect.Sfixed64Kind: - return "Scalar.read_sfixed64(reader)" - case protoreflect.FloatKind: - return "Scalar.read_float(reader)" - case protoreflect.DoubleKind: - return "Scalar.read_double(reader)" - case protoreflect.StringKind: - return "reader.read_string()" - case protoreflect.BytesKind: - return "reader.read_len_delim()" - } - return "" + return scalarSpecs[field.Desc.Kind()].readExpr } func ponyWriteCall(field *protogen.Field, valueRef string) string { - switch field.Desc.Kind() { - case protoreflect.BoolKind: - return fmt.Sprintf("Scalar.write_bool(writer, %s)", valueRef) - case protoreflect.Int32Kind: - return fmt.Sprintf("Scalar.write_int32(writer, %s)", valueRef) - case protoreflect.Int64Kind: - return fmt.Sprintf("Scalar.write_int64(writer, %s)", valueRef) - case protoreflect.Uint32Kind: - return fmt.Sprintf("Scalar.write_uint32(writer, %s)", valueRef) - case protoreflect.Uint64Kind: - return fmt.Sprintf("Scalar.write_uint64(writer, %s)", valueRef) - case protoreflect.Sint32Kind: - return fmt.Sprintf("Scalar.write_sint32(writer, %s)", valueRef) - case protoreflect.Sint64Kind: - return fmt.Sprintf("Scalar.write_sint64(writer, %s)", valueRef) - case protoreflect.Fixed32Kind: - return fmt.Sprintf("Scalar.write_fixed32(writer, %s)", valueRef) - case protoreflect.Fixed64Kind: - return fmt.Sprintf("Scalar.write_fixed64(writer, %s)", valueRef) - case 
protoreflect.Sfixed32Kind: - return fmt.Sprintf("Scalar.write_sfixed32(writer, %s)", valueRef) - case protoreflect.Sfixed64Kind: - return fmt.Sprintf("Scalar.write_sfixed64(writer, %s)", valueRef) - case protoreflect.FloatKind: - return fmt.Sprintf("Scalar.write_float(writer, %s)", valueRef) - case protoreflect.DoubleKind: - return fmt.Sprintf("Scalar.write_double(writer, %s)", valueRef) - case protoreflect.BytesKind: - return fmt.Sprintf("writer.write_len_delim(%s)", valueRef) + spec := scalarSpecs[field.Desc.Kind()] + if spec.writeFmt == "" { + return "" } - return "" + return fmt.Sprintf(spec.writeFmt, valueRef) } func presenceCheck(field *protogen.Field) string { - ref := "msg." + fieldName(field) - switch field.Desc.Kind() { - case protoreflect.BoolKind: - return ref - case protoreflect.FloatKind, protoreflect.DoubleKind: - return ref + " != 0.0" - case protoreflect.BytesKind: - return ref + ".size() > 0" - } - return ref + " != 0" + spec := scalarSpecs[field.Desc.Kind()] + return fmt.Sprintf(spec.presence, "msg."+string(field.Desc.Name())) } diff --git a/cmd/protoc-gen-pony/main.go b/cmd/protoc-gen-pony/main.go index 14800cb..d3e80b9 100644 --- a/cmd/protoc-gen-pony/main.go +++ b/cmd/protoc-gen-pony/main.go @@ -111,9 +111,6 @@ func main() { } } -// injectGoImportStubs prepends an M=stub for every file in the request -// so protogen.New doesn't reject the input. User-provided M-params still -// take precedence because they appear later in the comma-separated string. 
func injectGoImportStubs(req *pluginpb.CodeGeneratorRequest) { parts := make([]string, 0, len(req.GetProtoFile())+1) for _, file := range req.GetProtoFile() { diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go index 7a573ae..e46458b 100644 --- a/cmd/protoc-gen-pony/main_test.go +++ b/cmd/protoc-gen-pony/main_test.go @@ -14,23 +14,19 @@ import ( // runPlugin builds a synthetic plugin invocation around the supplied // FileDescriptorProtos and returns the generated content for the file -// matching `wantFilename`. Test helper for the rest of the file. +// matching `wantFilename`. Goes through injectGoImportStubs so that helper +// gets exercised by the test suite. func runPlugin(t *testing.T, files []*descriptorpb.FileDescriptorProto, wantFilename string) string { t.Helper() toGenerate := make([]string, 0, len(files)) - mappings := make([]string, 0, len(files)) for _, f := range files { toGenerate = append(toGenerate, f.GetName()) - // protogen insists on a Go import path even for non-Go targets; - // supplying M= stops it from erroring on missing go_package. - mappings = append(mappings, "M"+f.GetName()+"=example.com/test") } - param := strings.Join(mappings, ",") req := &pluginpb.CodeGeneratorRequest{ FileToGenerate: toGenerate, ProtoFile: files, - Parameter: &param, } + injectGoImportStubs(req) plugin, err := protogen.Options{}.New(req) require.NoError(t, err) for _, f := range plugin.Files { @@ -57,6 +53,19 @@ func fileNames(files []*pluginpb.CodeGeneratorResponse_File) []string { return names } +// field is a fixture builder for FieldDescriptorProto that defaults Label +// to OPTIONAL — the singular-presence shape the v1 plugin generates code +// for. Tests that need a different label build the descriptor inline.
+func field(name string, num int32, kind descriptorpb.FieldDescriptorProto_Type) *descriptorpb.FieldDescriptorProto { + return &descriptorpb.FieldDescriptorProto{ + Name: proto.String(name), + Number: proto.Int32(num), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: kind.Enum(), + JsonName: proto.String(name), + } +} + // scalarMessageProto returns a FileDescriptor with one User message // containing one int32, one string, and one bool field. func scalarMessageProto() *descriptorpb.FileDescriptorProto { @@ -68,27 +77,9 @@ func scalarMessageProto() *descriptorpb.FileDescriptorProto { { Name: proto.String("User"), Field: []*descriptorpb.FieldDescriptorProto{ - { - Name: proto.String("id"), - Number: proto.Int32(1), - Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), - Type: descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(), - JsonName: proto.String("id"), - }, - { - Name: proto.String("name"), - Number: proto.Int32(2), - Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), - Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), - JsonName: proto.String("name"), - }, - { - Name: proto.String("active"), - Number: proto.Int32(3), - Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), - Type: descriptorpb.FieldDescriptorProto_TYPE_BOOL.Enum(), - JsonName: proto.String("active"), - }, + field("id", 1, descriptorpb.FieldDescriptorProto_TYPE_INT32), + field("name", 2, descriptorpb.FieldDescriptorProto_TYPE_STRING), + field("active", 3, descriptorpb.FieldDescriptorProto_TYPE_BOOL), }, }, }, @@ -160,6 +151,8 @@ func TestEmptyMessage(t *testing.T) { func TestUnsupportedShapesEmitTodo(t *testing.T) { t.Parallel() + tags := field("tags", 2, descriptorpb.FieldDescriptorProto_TYPE_STRING) + tags.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() file := &descriptorpb.FileDescriptorProto{ Name: proto.String("zoo.proto"), Package: proto.String("zoo"), @@ -168,18 +161,8 @@ func 
TestUnsupportedShapesEmitTodo(t *testing.T) { { Name: proto.String("Zoo"), Field: []*descriptorpb.FieldDescriptorProto{ - { - Name: proto.String("id"), - Number: proto.Int32(1), - Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), - Type: descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(), - }, - { - Name: proto.String("tags"), - Number: proto.Int32(2), - Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), - Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), - }, + field("id", 1, descriptorpb.FieldDescriptorProto_TYPE_INT32), + tags, }, }, }, @@ -204,12 +187,7 @@ func TestNestedMessageFlatNaming(t *testing.T) { { Name: proto.String("Inner"), Field: []*descriptorpb.FieldDescriptorProto{ - { - Name: proto.String("value"), - Number: proto.Int32(1), - Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), - Type: descriptorpb.FieldDescriptorProto_TYPE_INT64.Enum(), - }, + field("value", 1, descriptorpb.FieldDescriptorProto_TYPE_INT64), }, }, }, From f38b8fef7912826b25607000e3519fda107a2425 Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Tue, 28 Apr 2026 08:38:43 -0700 Subject: [PATCH 03/11] docs: point at TrogonStack/protobuf-pony for runtime The runtime extracted out of straw-hat-team/trogonai.com/mrmeeseeks into its own repo. README and main.go godoc now reference the new location. No code change. Signed-off-by: Erik Nilsen --- cmd/protoc-gen-pony/README.md | 2 +- cmd/protoc-gen-pony/main.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/protoc-gen-pony/README.md b/cmd/protoc-gen-pony/README.md index ed5e66a..f180210 100644 --- a/cmd/protoc-gen-pony/README.md +++ b/cmd/protoc-gen-pony/README.md @@ -62,4 +62,4 @@ embedded messages, and enums emit a `// TODO protoc-gen-pony` comment until the corresponding codegen lands. Services (gRPC) are out of scope. 
[buf]: https://buf.build -[runtime]: https://github.com/straw-hat-team/trogonai.com/tree/main/mrmeeseeks/protobuf +[runtime]: https://github.com/TrogonStack/protobuf-pony diff --git a/cmd/protoc-gen-pony/main.go b/cmd/protoc-gen-pony/main.go index d3e80b9..2bfd78d 100644 --- a/cmd/protoc-gen-pony/main.go +++ b/cmd/protoc-gen-pony/main.go @@ -15,9 +15,9 @@ // out: gen // // Generated files import the `protobuf` runtime library — see -// https://github.com/straw-hat-team/trogonai.com/tree/main/mrmeeseeks/protobuf -// for the Pony source. The runtime exposes WireReader/WireWriter, Tag, -// Scalar, the WireType union, and the WireError typed-error union. +// https://github.com/TrogonStack/protobuf-pony for the Pony source. The +// runtime exposes WireReader/WireWriter, Tag, Scalar, the WireType union, +// and the WireError typed-error union. // // [buf]: https://buf.build package main From 5c2f4a5005fa5dcc7a04020ffab80b6f8e429e97 Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Tue, 28 Apr 2026 13:01:06 -0700 Subject: [PATCH 04/11] fix(protoc-gen-pony): address PR review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - main.go: drop FEATURE_SUPPORTS_EDITIONS + FEATURE_PROTO3_OPTIONAL advertisement and the EDITION_PROTO2..EDITION_2024 range. The v1 generator emits TODO comments for explicit-presence/oneof/map/ embedded/enum fields rather than handling them, so advertising the feature flags risks protoc passing inputs we'd silently miscompile. When a future PR ships real codegen for those shapes, flip the flags back on. - cmd/protoc-gen-pony/CHANGELOG.md: deleted. release-please owns the changelog generation; the hand-written file conflicts with that flow. - .github/.release-please-config.json + .release-please-manifest.json: wire protoc-gen-pony in alongside the connect-go and elixir-grpc plugins. Manifest entry starts at 0.0.0 so the first feat: commit bumps to 0.1.0. 
- .github/goreleaser.yml: add a builds: stanza for protoc-gen-pony so tagged releases produce a binary across linux/darwin/windows × amd64/arm64. - Taskfile.yml: build-plugin now uses deps: [build-plugin-go, build-plugin-elixir, build-plugin-pony] instead of inlining the same go build commands. - main_test.go: TestUnsupportedShapesEmitTodo now exercises every unsupported shape — repeated, proto3 optional, oneof, embedded message, enum, and map — asserting both the TODO comment emission and that none of the unsupported field names appear in the constructor signature. Signed-off-by: Erik Nilsen --- .github/.release-please-config.json | 3 ++ .github/.release-please-manifest.json | 3 +- .github/goreleaser.yml | 19 +++++++ Taskfile.yml | 6 +-- cmd/protoc-gen-pony/CHANGELOG.md | 10 ---- cmd/protoc-gen-pony/main.go | 12 ++--- cmd/protoc-gen-pony/main_test.go | 73 +++++++++++++++++++++++++-- 7 files changed, 99 insertions(+), 27 deletions(-) delete mode 100644 cmd/protoc-gen-pony/CHANGELOG.md diff --git a/.github/.release-please-config.json b/.github/.release-please-config.json index b253e91..d9573a8 100644 --- a/.github/.release-please-config.json +++ b/.github/.release-please-config.json @@ -27,6 +27,9 @@ }, "cmd/protoc-gen-elixir-grpc": { "component": "protoc-gen-elixir-grpc" + }, + "cmd/protoc-gen-pony": { + "component": "protoc-gen-pony" } }, "plugins": [ diff --git a/.github/.release-please-manifest.json b/.github/.release-please-manifest.json index 1adb373..542199d 100644 --- a/.github/.release-please-manifest.json +++ b/.github/.release-please-manifest.json @@ -1,4 +1,5 @@ { "cmd/protoc-gen-connect-go-servicestruct": "0.2.0", - "cmd/protoc-gen-elixir-grpc": "0.4.2" + "cmd/protoc-gen-elixir-grpc": "0.4.2", + "cmd/protoc-gen-pony": "0.0.0" } diff --git a/.github/goreleaser.yml b/.github/goreleaser.yml index bc17e7b..c6f366a 100644 --- a/.github/goreleaser.yml +++ b/.github/goreleaser.yml @@ -44,6 +44,25 @@ builds: - -X main.commit={{.Commit}} - -X main.date={{.Date}} 
+ - id: protoc-gen-pony + main: ./cmd/protoc-gen-pony + binary: protoc-gen-pony + skip: '{{ ne .Env.BUILD_COMPONENT "protoc-gen-pony" }}' + env: + - CGO_ENABLED=0 + goos: + - linux + - darwin + - windows + goarch: + - amd64 + - arm64 + ldflags: + - -s -w + - -X main.version={{.Version}} + - -X main.commit={{.Commit}} + - -X main.date={{.Date}} + archives: - id: default name_template: >- diff --git a/Taskfile.yml b/Taskfile.yml index b706f6a..19c6a13 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -60,11 +60,7 @@ tasks: build-plugin: desc: Build all protoc plugin binaries - cmds: - - echo "Building protoc plugin binaries..." - - go build -o protoc-gen-connect-go-servicestruct ./cmd/protoc-gen-connect-go-servicestruct - - go build -o protoc-gen-elixir-grpc ./cmd/protoc-gen-elixir-grpc - - go build -o protoc-gen-pony ./cmd/protoc-gen-pony + deps: [build-plugin-go, build-plugin-elixir, build-plugin-pony] build-plugin-go: desc: Build the Go Connect protoc plugin binary diff --git a/cmd/protoc-gen-pony/CHANGELOG.md b/cmd/protoc-gen-pony/CHANGELOG.md deleted file mode 100644 index 87a09d8..0000000 --- a/cmd/protoc-gen-pony/CHANGELOG.md +++ /dev/null @@ -1,10 +0,0 @@ -# Changelog - -## Unreleased - -### Features - -* Initial release. Generates Pony `class val` records + sister `Codec` - primitives for proto3 messages with singular implicit-presence scalar - fields. Repeated, `optional`, oneof, map, embedded-message, and enum - fields surface as `// TODO protoc-gen-pony` placeholders. 
diff --git a/cmd/protoc-gen-pony/main.go b/cmd/protoc-gen-pony/main.go index 2bfd78d..7a1732f 100644 --- a/cmd/protoc-gen-pony/main.go +++ b/cmd/protoc-gen-pony/main.go @@ -31,7 +31,6 @@ import ( "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/pluginpb" ) @@ -87,11 +86,12 @@ func main() { fmt.Fprintf(os.Stderr, "protoc-gen-pony: %v\n", err) os.Exit(1) } - plugin.SupportedFeatures = - uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL) | - uint64(pluginpb.CodeGeneratorResponse_FEATURE_SUPPORTS_EDITIONS) - plugin.SupportedEditionsMinimum = descriptorpb.Edition_EDITION_PROTO2 - plugin.SupportedEditionsMaximum = descriptorpb.Edition_EDITION_2024 + // Don't advertise FEATURE_PROTO3_OPTIONAL or FEATURE_SUPPORTS_EDITIONS: + // the v1 generator emits TODO comments for explicit-presence/oneof/map/ + // embedded/enum fields rather than handling them. Advertising support + // for features we don't implement risks protoc passing input we'll + // silently miscompile. When a future PR lands real codegen for those + // shapes, flip the flags back on. for _, file := range plugin.Files { if file.Generate { diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go index e46458b..74489f3 100644 --- a/cmd/protoc-gen-pony/main_test.go +++ b/cmd/protoc-gen-pony/main_test.go @@ -151,8 +151,48 @@ func TestEmptyMessage(t *testing.T) { func TestUnsupportedShapesEmitTodo(t *testing.T) { t.Parallel() + + // Repeated string. tags := field("tags", 2, descriptorpb.FieldDescriptorProto_TYPE_STRING) tags.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() + + // Real oneof "kind" at OneofIndex 0 — synthetic oneofs for proto3 + // `optional` must come AFTER real oneofs in OneofDecl, so the real one + // is declared first. 
+ typeA := field("type_a", 4, descriptorpb.FieldDescriptorProto_TYPE_STRING) + typeA.OneofIndex = proto.Int32(0) + typeB := field("type_b", 5, descriptorpb.FieldDescriptorProto_TYPE_INT32) + typeB.OneofIndex = proto.Int32(0) + + // proto3 explicit `optional` — synthesized into a single-field oneof + // at OneofIndex 1. + optCount := field("count", 3, descriptorpb.FieldDescriptorProto_TYPE_INT32) + optCount.Proto3Optional = proto.Bool(true) + optCount.OneofIndex = proto.Int32(1) + + // Embedded message field. + parent := field("parent", 6, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + parent.TypeName = proto.String(".zoo.Parent") + + // Enum field. + status := field("status", 7, descriptorpb.FieldDescriptorProto_TYPE_ENUM) + status.TypeName = proto.String(".zoo.Status") + + // map<string, int32> — modeled in descriptors as a repeated MESSAGE + // field pointing at a synthetic nested MapEntry type with + // MessageOptions.map_entry=true. + metadata := field("metadata", 8, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + metadata.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() + metadata.TypeName = proto.String(".zoo.Zoo.MetadataEntry") + mapEntry := &descriptorpb.DescriptorProto{ + Name: proto.String("MetadataEntry"), + Field: []*descriptorpb.FieldDescriptorProto{ + field("key", 1, descriptorpb.FieldDescriptorProto_TYPE_STRING), + field("value", 2, descriptorpb.FieldDescriptorProto_TYPE_INT32), + }, + Options: &descriptorpb.MessageOptions{MapEntry: proto.Bool(true)}, + } + file := &descriptorpb.FileDescriptorProto{ Name: proto.String("zoo.proto"), Package: proto.String("zoo"), @@ -162,17 +202,40 @@ func TestUnsupportedShapesEmitTodo(t *testing.T) { Name: proto.String("Zoo"), Field: []*descriptorpb.FieldDescriptorProto{ field("id", 1, descriptorpb.FieldDescriptorProto_TYPE_INT32), - tags, + tags, optCount, typeA, typeB, parent, status, metadata, + }, + NestedType: []*descriptorpb.DescriptorProto{mapEntry}, + OneofDecl: []*descriptorpb.OneofDescriptorProto{
{Name: proto.String("kind")}, + {Name: proto.String("_count")}, + }, + }, + {Name: proto.String("Parent")}, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("Status"), + Value: []*descriptorpb.EnumValueDescriptorProto{ + {Name: proto.String("UNKNOWN"), Number: proto.Int32(0)}, + {Name: proto.String("ACTIVE"), Number: proto.Int32(1)}, }, }, }, } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "zoo.pony") + + // The one supported field stays. assert.Contains(t, out, "let id: I32") - assert.Contains(t, out, "TODO protoc-gen-pony: field tags") - // Repeated fields don't appear in the constructor's supported list. - supportedConstructorLine := strings.Contains(out, "tags':") - assert.False(t, supportedConstructorLine, "tags should be skipped from the constructor signature") + + // Every unsupported shape lays down a TODO and stays out of the constructor. + unsupported := []string{"tags", "count", "type_a", "type_b", "parent", "status", "metadata"} + for _, name := range unsupported { + assert.Contains(t, out, "TODO protoc-gen-pony: field "+name, + "missing TODO for %q", name) + assert.False(t, strings.Contains(out, name+"': "), + "%q should be skipped from the constructor signature", name) + } } func TestNestedMessageFlatNaming(t *testing.T) { From bc4840f13f06b927bd97b0a546213074256e8638 Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Tue, 28 Apr 2026 13:08:26 -0700 Subject: [PATCH 05/11] fix(cd): wire protoc-gen-pony into release component validation + clean up review nits MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - .github/workflows/cd.yml: add protoc-gen-pony to the validation case statement. Without this, a protoc-gen-pony@v0.1.0 tag would fail the "Valid component" gate and abort the release. Caught by simplify pass. - cmd/protoc-gen-pony/main.go: trim the no-features-advertised comment from 6 lines to 4. 
Speculation about future PR re-enabling the flags is recoverable from git history; the load-bearing WHY (silent miscompile risk) stays. - cmd/protoc-gen-pony/main_test.go: replace assert.False(strings.Contains(...)) with assert.NotContains, and drop the now-unused strings import. Drop a few WHAT-comments that restated the line below them; keep the protobuf-internals comments (synthetic-oneof ordering, optional encoding, MapEntry shape) — those document non-obvious descriptor semantics future readers need. Signed-off-by: Erik Nilsen --- .github/workflows/cd.yml | 4 ++-- cmd/protoc-gen-pony/main.go | 10 ++++------ cmd/protoc-gen-pony/main_test.go | 13 +++---------- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 461605e..c901dbc 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -76,12 +76,12 @@ jobs: # Validate component is known case "$COMPONENT" in - protoc-gen-elixir-grpc|protoc-gen-connect-go-servicestruct) + protoc-gen-elixir-grpc|protoc-gen-connect-go-servicestruct|protoc-gen-pony) echo "Valid component: ${COMPONENT}" ;; *) echo "ERROR: Unknown component: ${COMPONENT}" - echo "Valid components: protoc-gen-elixir-grpc, protoc-gen-connect-go-servicestruct" + echo "Valid components: protoc-gen-elixir-grpc, protoc-gen-connect-go-servicestruct, protoc-gen-pony" exit 1 ;; esac diff --git a/cmd/protoc-gen-pony/main.go b/cmd/protoc-gen-pony/main.go index 7a1732f..ca30c99 100644 --- a/cmd/protoc-gen-pony/main.go +++ b/cmd/protoc-gen-pony/main.go @@ -86,12 +86,10 @@ func main() { fmt.Fprintf(os.Stderr, "protoc-gen-pony: %v\n", err) os.Exit(1) } - // Don't advertise FEATURE_PROTO3_OPTIONAL or FEATURE_SUPPORTS_EDITIONS: - // the v1 generator emits TODO comments for explicit-presence/oneof/map/ - // embedded/enum fields rather than handling them. Advertising support - // for features we don't implement risks protoc passing input we'll - // silently miscompile. 
When a future PR lands real codegen for those - // shapes, flip the flags back on. + // Advertise no features — the v1 generator emits TODO comments for + // explicit-presence/oneof/map/embedded/enum fields rather than + // handling them. Advertising features we don't implement risks + // protoc passing input we'd silently miscompile. for _, file := range plugin.Files { if file.Generate { diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go index 74489f3..7a10abd 100644 --- a/cmd/protoc-gen-pony/main_test.go +++ b/cmd/protoc-gen-pony/main_test.go @@ -1,7 +1,6 @@ package main import ( - "strings" "testing" "github.com/stretchr/testify/assert" @@ -152,13 +151,11 @@ func TestEmptyMessage(t *testing.T) { func TestUnsupportedShapesEmitTodo(t *testing.T) { t.Parallel() - // Repeated string. tags := field("tags", 2, descriptorpb.FieldDescriptorProto_TYPE_STRING) tags.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() - // Real oneof "kind" at OneofIndex 0 — synthetic oneofs for proto3 - // `optional` must come AFTER real oneofs in OneofDecl, so the real one - // is declared first. + // Synthetic oneofs (proto3 `optional`) must come AFTER real oneofs in + // OneofDecl, so the real "kind" oneof is declared first at index 0. typeA := field("type_a", 4, descriptorpb.FieldDescriptorProto_TYPE_STRING) typeA.OneofIndex = proto.Int32(0) typeB := field("type_b", 5, descriptorpb.FieldDescriptorProto_TYPE_INT32) @@ -170,11 +167,9 @@ func TestUnsupportedShapesEmitTodo(t *testing.T) { optCount.Proto3Optional = proto.Bool(true) optCount.OneofIndex = proto.Int32(1) - // Embedded message field. parent := field("parent", 6, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) parent.TypeName = proto.String(".zoo.Parent") - // Enum field. 
status := field("status", 7, descriptorpb.FieldDescriptorProto_TYPE_ENUM) status.TypeName = proto.String(".zoo.Status") @@ -225,15 +220,13 @@ func TestUnsupportedShapesEmitTodo(t *testing.T) { out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "zoo.pony") - // The one supported field stays. assert.Contains(t, out, "let id: I32") - // Every unsupported shape lays down a TODO and stays out of the constructor. unsupported := []string{"tags", "count", "type_a", "type_b", "parent", "status", "metadata"} for _, name := range unsupported { assert.Contains(t, out, "TODO protoc-gen-pony: field "+name, "missing TODO for %q", name) - assert.False(t, strings.Contains(out, name+"': "), + assert.NotContains(t, out, name+"': ", "%q should be skipped from the constructor signature", name) } } From ef333da6b674edd4d0411d6d16873f4e8a5d2dcf Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Tue, 28 Apr 2026 13:13:43 -0700 Subject: [PATCH 06/11] test(protoc-gen-pony): lock in injectGoImportStubs parameter preservation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cover the "user-provided params still win" contract — existing parameter entries (including empty values and duplicates) must appear verbatim after the injected M-stubs, since protogen's later-wins semantics depend on it. Signed-off-by: Erik Nilsen --- cmd/protoc-gen-pony/main_test.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go index 7a10abd..040144c 100644 --- a/cmd/protoc-gen-pony/main_test.go +++ b/cmd/protoc-gen-pony/main_test.go @@ -262,3 +262,23 @@ func TestFileHeaderHasSourceComment(t *testing.T) { assert.Contains(t, out, "// Generated by protoc-gen-pony. 
DO NOT EDIT.") assert.Contains(t, out, "// Source: user.proto") } + +// "User-provided params still win" (main.go:50) relies on existing entries +// appearing verbatim after the injected stubs — protogen's later-wins +// semantics depend on it. Lock in: nothing dropped or reordered, including +// empty values and duplicates. +func TestInjectGoImportStubs_PreservesExistingParameters(t *testing.T) { + t.Parallel() + const existing = "foo=1,bar=,foo=2" + req := &pluginpb.CodeGeneratorRequest{ + ProtoFile: []*descriptorpb.FileDescriptorProto{ + {Name: proto.String("user.proto")}, + {Name: proto.String("admin.proto")}, + }, + Parameter: proto.String(existing), + } + injectGoImportStubs(req) + assert.Equal(t, + "Muser.proto=protoc-gen-pony/stub,Madmin.proto=protoc-gen-pony/stub,"+existing, + req.GetParameter()) +} From fc84e7cf0733e333bf12b63fa86da51e08593e9d Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Wed, 29 Apr 2026 09:40:37 -0700 Subject: [PATCH 07/11] feat(protoc-gen-pony): enums, embedded messages, repeated fields, optional, cross-dir refs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Enums: primitive per value + type alias + FromValue dispatcher - Singular message fields: (Child val | None) type, sub-codec decode/encode - Repeated scalar/enum fields: packed wire format, Array[T] trn → val - Repeated message fields: per-entry len-delim, same trn accumulator pattern - proto3 optional scalars/enums: (T | None) type, match-on-None encode (explicit presence) - Cross-file same-directory refs: path.Dir equality instead of exact path match - emitPackedEncode helper: eliminates 4 near-identical packed encode blocks - ponyEnumFromValueName helper: centralises the FromValue name derivation - ponyMessageClassName/enumNamePrefix: O(n) append+reverse instead of O(n²) prepend - Manifest: bump protoc-gen-pony to 0.0.1 per Yordis review comment Signed-off-by: Erik Nilsen --- .github/.release-please-manifest.json | 2 +- 
cmd/protoc-gen-pony/generate.go | 590 ++++++++++++++++++++------ cmd/protoc-gen-pony/main_test.go | 251 ++++++++++- 3 files changed, 717 insertions(+), 126 deletions(-) diff --git a/.github/.release-please-manifest.json b/.github/.release-please-manifest.json index 542199d..ec07d14 100644 --- a/.github/.release-please-manifest.json +++ b/.github/.release-please-manifest.json @@ -1,5 +1,5 @@ { "cmd/protoc-gen-connect-go-servicestruct": "0.2.0", "cmd/protoc-gen-elixir-grpc": "0.4.2", - "cmd/protoc-gen-pony": "0.0.0" + "cmd/protoc-gen-pony": "0.0.1" } diff --git a/cmd/protoc-gen-pony/generate.go b/cmd/protoc-gen-pony/generate.go index 1983246..ecdfb0b 100644 --- a/cmd/protoc-gen-pony/generate.go +++ b/cmd/protoc-gen-pony/generate.go @@ -2,12 +2,21 @@ package main import ( "fmt" + "path" "strings" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/reflect/protoreflect" ) +// genCtx carries the plugin and file context through all emit functions so +// that same-file message/enum name resolution is available everywhere. +type genCtx struct { + plugin *protogen.Plugin + file *protogen.File + g *protogen.GeneratedFile +} + // generateFile emits one `.pony` source per `.proto` input. Pony has no // notion of a `package` keyword inside the file (packages are directories), // so we just stream each top-level message in the file. Nested messages @@ -30,9 +39,10 @@ func generateFile(plugin *protogen.Plugin, file *protogen.File) { g.P(`// Source: `, file.Desc.Path()) g.P() - collectAndEmitMessages(g, file.Messages, "") + ctx := &genCtx{plugin: plugin, file: file, g: g} + ctx.collectAndEmitMessages(file.Messages, "") for _, enum := range file.Enums { - emitEnumTodo(g, enum, "") + ctx.emitEnum(enum, "") } } @@ -40,171 +50,542 @@ func generateFile(plugin *protogen.Plugin, file *protogen.File) { // nested types into their parent's namespace via `Outer_Inner` mangling. 
// Messages are emitted in source order — protogen already gives us a // deterministic walk, so we don't need to sort. -func collectAndEmitMessages(g *protogen.GeneratedFile, messages []*protogen.Message, namePrefix string) { +func (ctx *genCtx) collectAndEmitMessages(messages []*protogen.Message, namePrefix string) { for _, msg := range messages { - flatName := namePrefix + string(msg.Desc.Name()) - emitMessage(g, msg, flatName) - if len(msg.Messages) > 0 { - collectAndEmitMessages(g, msg.Messages, flatName+"_") + if msg.Desc.IsMapEntry() { + continue // synthetic map-entry type — map fields emit TODO } + flatName := namePrefix + string(msg.Desc.Name()) + ctx.emitMessage(msg, flatName) + ctx.collectAndEmitMessages(msg.Messages, flatName+"_") for _, enum := range msg.Enums { - emitEnumTodo(g, enum, flatName+"_") + ctx.emitEnum(enum, flatName+"_") } } } -func emitMessage(g *protogen.GeneratedFile, msg *protogen.Message, className string) { - supported := supportedFields(msg.Fields) - emitClass(g, className, msg.Fields, supported) - g.P() - emitCodec(g, className, supported) - g.P() +func (ctx *genCtx) emitMessage(msg *protogen.Message, className string) { + supported := ctx.supportedFields(msg.Fields) + ctx.emitClass(className, msg.Fields, supported) + ctx.g.P() + ctx.emitCodec(className, supported) + ctx.g.P() } -func emitClass(g *protogen.GeneratedFile, className string, all, supported []*protogen.Field) { - g.P(`class val `, className) +func (ctx *genCtx) emitClass(className string, all, supported []*protogen.Field) { + ctx.g.P(`class val `, className) for _, field := range all { - if isSupported(field) { - g.P(` let `, field.Desc.Name(), `: `, ponyType(field)) + if ctx.isSupported(field) { + ctx.g.P(` let `, field.Desc.Name(), `: `, ctx.fieldPonyType(field)) } else { - g.P(` // TODO protoc-gen-pony: field `, field.Desc.Name(), ` (`, fieldShape(field), `)`) + ctx.g.P(` // TODO protoc-gen-pony: field `, field.Desc.Name(), ` (`, fieldShape(field), `)`) } } - g.P() 
- emitConstructor(g, supported) + ctx.g.P() + ctx.emitConstructor(supported) } -func emitConstructor(g *protogen.GeneratedFile, supported []*protogen.Field) { +func (ctx *genCtx) emitConstructor(supported []*protogen.Field) { if len(supported) == 0 { - g.P(` new val create() => None`) + ctx.g.P(` new val create() => None`) return } - g.P(` new val create(`) + ctx.g.P(` new val create(`) for i, field := range supported { suffix := "," if i == len(supported)-1 { suffix = ")" } - g.P(` `, field.Desc.Name(), `': `, ponyType(field), ` = `, ponyDefault(field), suffix) + ctx.g.P(` `, field.Desc.Name(), `': `, ctx.fieldPonyType(field), ` = `, ctx.fieldPonyDefault(field), suffix) } - g.P(` =>`) + ctx.g.P(` =>`) for _, field := range supported { - g.P(` `, field.Desc.Name(), ` = `, field.Desc.Name(), `'`) + ctx.g.P(` `, field.Desc.Name(), ` = `, field.Desc.Name(), `'`) } } -func emitCodec(g *protogen.GeneratedFile, className string, supported []*protogen.Field) { - g.P(`primitive `, className, `Codec`) - emitDecode(g, className, supported) - g.P() - emitEncode(g, className, supported) +func (ctx *genCtx) emitCodec(className string, supported []*protogen.Field) { + ctx.g.P(`primitive `, className, `Codec`) + ctx.emitDecode(className, supported) + ctx.g.P() + ctx.emitEncode(className, supported) } -func emitDecode(g *protogen.GeneratedFile, className string, supported []*protogen.Field) { - g.P(` fun decode(reader: WireReader ref): (`, className, ` val | WireError) =>`) +func (ctx *genCtx) emitDecode(className string, supported []*protogen.Field) { + ctx.g.P(` fun decode(reader: WireReader ref): (`, className, ` val | WireError) =>`) for _, field := range supported { - g.P(` var `, field.Desc.Name(), `: `, ponyType(field), ` = `, ponyDefault(field)) + name := string(field.Desc.Name()) + if field.Desc.IsList() { + elem := ctx.elemPonyType(field) + ctx.g.P(` var `, name, `: Array[`, elem, `] trn = recover trn Array[`, elem, `] end`) + } else { + ctx.g.P(` var `, name, `: `, 
ctx.fieldPonyType(field), ` = `, ctx.fieldPonyDefault(field)) + } } - g.P(` while not reader.at_end() do`) - g.P(` match reader.read_tag()`) - g.P(` | let t: Tag =>`) - g.P(` match (t.field_number, t.wire_type)`) + ctx.g.P(` while not reader.at_end() do`) + ctx.g.P(` match reader.read_tag()`) + ctx.g.P(` | let t: Tag =>`) + ctx.g.P(` match (t.field_number, t.wire_type)`) for _, field := range supported { - g.P(` | (`, field.Desc.Number(), `, `, ponyWireType(field), `) =>`) - g.P(` match `, ponyReadExpr(field)) - g.P(` | let v: `, ponyType(field), ` => `, field.Desc.Name(), ` = v`) - g.P(` | let e: WireError => return e`) - g.P(` end`) - } - g.P(` else`) - g.P(` match reader.skip(t.wire_type)`) - g.P(` | None => None`) - g.P(` | let e: WireError => return e`) - g.P(` end`) - g.P(` end`) - g.P(` | let e: WireError => return e`) - g.P(` end`) - g.P(` end`) - emitConstructorCall(g, className, supported) + ctx.emitDecodeArm(field) + } + ctx.g.P(` else`) + ctx.g.P(` match reader.skip(t.wire_type)`) + ctx.g.P(` | None => None`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + ctx.g.P(` end`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + ctx.g.P(` end`) + ctx.emitConstructorCall(className, supported) +} + +func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { + name := string(field.Desc.Name()) + num := field.Desc.Number() + wt := ctx.fieldPonyWireType(field) + ctx.g.P(` | (`, num, `, `, wt, `) =>`) + + switch { + case field.Desc.IsList() && field.Desc.Kind() == protoreflect.MessageKind: + // non-packed: one tag + len-delim per element (proto3 repeated message) + codec := ponyMessageClassName(field.Message) + "Codec" + elemType := ponyMessageClassName(field.Message) + " val" + ctx.g.P(` match reader.read_len_delim()`) + ctx.g.P(` | let b: Array[U8] val =>`) + ctx.g.P(` match `, codec, `.decode(WireReader(b))`) + ctx.g.P(` | let v: `, elemType, ` => `, name, `.push(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + 
ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + + case field.Desc.IsList() && field.Desc.Kind() == protoreflect.EnumKind: + fromValue := ponyEnumFromValueName(field.Enum) + ctx.g.P(` match reader.read_len_delim()`) + ctx.g.P(` | let b: Array[U8] val =>`) + ctx.g.P(` let sub = WireReader(b)`) + ctx.g.P(` while not sub.at_end() do`) + ctx.g.P(` match Scalar.read_int32(sub)`) + ctx.g.P(` | let v: I32 => `, name, `.push(`, fromValue, `(v))`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + ctx.g.P(` end`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + + case field.Desc.IsList(): + spec := scalarSpecs[field.Desc.Kind()] + readExpr := strings.Replace(spec.readExpr, "reader", "sub", 1) + ctx.g.P(` match reader.read_len_delim()`) + ctx.g.P(` | let b: Array[U8] val =>`) + ctx.g.P(` let sub = WireReader(b)`) + ctx.g.P(` while not sub.at_end() do`) + ctx.g.P(` match `, readExpr) + ctx.g.P(` | let v: `, spec.ponyType, ` => `, name, `.push(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + ctx.g.P(` end`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + + case field.Desc.Kind() == protoreflect.MessageKind: + codec := ponyMessageClassName(field.Message) + "Codec" + msgType := ponyMessageClassName(field.Message) + " val" + ctx.g.P(` match reader.read_len_delim()`) + ctx.g.P(` | let b: Array[U8] val =>`) + ctx.g.P(` match `, codec, `.decode(WireReader(b))`) + ctx.g.P(` | let v: `, msgType, ` => `, name, ` = v`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + + case field.Desc.Kind() == protoreflect.EnumKind: + fromValue := ponyEnumFromValueName(field.Enum) + ctx.g.P(` match Scalar.read_int32(reader)`) + ctx.g.P(` | let v: I32 => `, name, ` = `, fromValue, `(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + + default: + spec := scalarSpecs[field.Desc.Kind()] + ctx.g.P(` match `, spec.readExpr) + 
ctx.g.P(` | let v: `, spec.ponyType, ` => `, name, ` = v`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + } } -func emitConstructorCall(g *protogen.GeneratedFile, className string, supported []*protogen.Field) { +func (ctx *genCtx) emitConstructorCall(className string, supported []*protogen.Field) { if len(supported) == 0 { - g.P(` `, className) + ctx.g.P(` `, className) return } parts := make([]string, len(supported)) for i, field := range supported { - parts[i] = string(field.Desc.Name()) + if field.Desc.IsList() { + parts[i] = "consume " + string(field.Desc.Name()) + } else { + parts[i] = string(field.Desc.Name()) + } } - g.P(` `, className, `(`, strings.Join(parts, ", "), `)`) + ctx.g.P(` `, className, `(`, strings.Join(parts, ", "), `)`) } -func emitEncode(g *protogen.GeneratedFile, className string, supported []*protogen.Field) { - g.P(` fun encode(writer: WireWriter ref, msg: `, className, ` val) =>`) +func (ctx *genCtx) emitEncode(className string, supported []*protogen.Field) { + ctx.g.P(` fun encode(writer: WireWriter ref, msg: `, className, ` val) =>`) if len(supported) == 0 { - g.P(` None`) + ctx.g.P(` None`) return } for _, field := range supported { - emitEncodeField(g, field) + ctx.emitEncodeField(field) } } -func emitEncodeField(g *protogen.GeneratedFile, field *protogen.Field) { +func (ctx *genCtx) emitEncodeField(field *protogen.Field) { + ref := "msg." + string(field.Desc.Name()) num := field.Desc.Number() - if field.Desc.Kind() == protoreflect.StringKind { - // write_string_field handles the empty-string skip internally. 
- g.P(` writer.write_string_field(`, num, `, msg.`, field.Desc.Name(), `)`) - return + + switch { + case field.Desc.IsList() && field.Desc.Kind() == protoreflect.MessageKind: + // non-packed: one tag + len-delim per element (proto3 repeated message) + codec := ponyMessageClassName(field.Message) + "Codec" + ctx.g.P(` for v in `, ref, `.values() do`) + ctx.g.P(` let sub = WireWriter`) + ctx.g.P(` `, codec, `.encode(sub, v)`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_len_delim(sub.done())`) + ctx.g.P(` end`) + + case field.Desc.IsList() && field.Desc.Kind() == protoreflect.EnumKind: + ctx.emitPackedEncode(ref, num, "Scalar.write_int32(sub, v.value())") + + case field.Desc.IsList() && field.Desc.Kind() == protoreflect.StringKind: + ctx.emitPackedEncode(ref, num, "sub.write_string(v)") + + case field.Desc.IsList() && field.Desc.Kind() == protoreflect.BytesKind: + ctx.emitPackedEncode(ref, num, "sub.write_len_delim(v)") + + case field.Desc.IsList(): + spec := scalarSpecs[field.Desc.Kind()] + ctx.emitPackedEncode(ref, num, strings.Replace(fmt.Sprintf(spec.writeFmt, "v"), "writer", "sub", 1)) + + case field.Desc.Kind() == protoreflect.MessageKind: + codec := ponyMessageClassName(field.Message) + "Codec" + msgType := ponyMessageClassName(field.Message) + " val" + ctx.g.P(` match `, ref) + ctx.g.P(` | let v: `, msgType, ` =>`) + ctx.g.P(` let sub = WireWriter`) + ctx.g.P(` `, codec, `.encode(sub, v)`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_len_delim(sub.done())`) + ctx.g.P(` end`) + + case !field.Desc.IsList() && field.Desc.HasOptionalKeyword(): + // optional field: explicit presence — emit match on None (never skip zero values) + ctx.emitOptionalEncodeField(field) + + case field.Desc.Kind() == protoreflect.EnumKind: + // singular enum: skip zero value (implicit absence) + ctx.g.P(` if `, ref, `.value() != 0 then`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireVarint))`) + ctx.g.P(` 
Scalar.write_int32(writer, `, ref, `.value())`) + ctx.g.P(` end`) + + case field.Desc.Kind() == protoreflect.StringKind: + // write_string_field handles empty-string skip internally + ctx.g.P(` writer.write_string_field(`, num, `, `, ref, `)`) + + default: + spec := scalarSpecs[field.Desc.Kind()] + ctx.g.P(` if `, fmt.Sprintf(spec.presence, ref), ` then`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, `, spec.wireType, `))`) + ctx.g.P(` `, fmt.Sprintf(spec.writeFmt, ref)) + ctx.g.P(` end`) } - g.P(` if `, presenceCheck(field), ` then`) - g.P(` writer.write_tag(Tag(`, num, `, `, ponyWireType(field), `))`) - g.P(` `, ponyWriteCall(field, "msg."+string(field.Desc.Name()))) - g.P(` end`) } -func emitEnumTodo(g *protogen.GeneratedFile, enum *protogen.Enum, namePrefix string) { - g.P(`// TODO protoc-gen-pony: enum `, namePrefix, enum.Desc.Name()) - g.P() +func (ctx *genCtx) emitPackedEncode(ref string, num protoreflect.FieldNumber, writeOp string) { + ctx.g.P(` if `, ref, `.size() > 0 then`) + ctx.g.P(` let sub = WireWriter`) + ctx.g.P(` for v in `, ref, `.values() do`) + ctx.g.P(` `, writeOp) + ctx.g.P(` end`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_len_delim(sub.done())`) + ctx.g.P(` end`) } -// isSupported is the v1 cut: singular implicit-presence proto3 scalars only. -// Repeated, optional explicit presence (proto3 `optional` and editions -// EXPLICIT presence), oneofs, maps, messages, enums, and groups all surface -// as TODO comments until the corresponding codegen lands. -func isSupported(field *protogen.Field) bool { - if field.Desc.IsList() || field.Desc.IsMap() { - return false +func (ctx *genCtx) emitOptionalEncodeField(field *protogen.Field) { + ref := "msg." 
+ string(field.Desc.Name()) + num := field.Desc.Number() + switch field.Desc.Kind() { + case protoreflect.EnumKind: + enumType := ponyEnumTypeName(field.Enum) + ctx.g.P(` match `, ref) + ctx.g.P(` | let v: `, enumType, ` =>`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireVarint))`) + ctx.g.P(` Scalar.write_int32(writer, v.value())`) + ctx.g.P(` end`) + case protoreflect.StringKind: + ctx.g.P(` match `, ref) + ctx.g.P(` | let v: String val =>`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_string(v)`) + ctx.g.P(` end`) + case protoreflect.BytesKind: + ctx.g.P(` match `, ref) + ctx.g.P(` | let v: Array[U8] val =>`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_len_delim(v)`) + ctx.g.P(` end`) + default: + spec := scalarSpecs[field.Desc.Kind()] + ctx.g.P(` match `, ref) + ctx.g.P(` | let v: `, spec.ponyType, ` =>`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, `, spec.wireType, `))`) + ctx.g.P(` `, fmt.Sprintf(spec.writeFmt, "v")) + ctx.g.P(` end`) + } +} + +func (ctx *genCtx) emitEnum(enum *protogen.Enum, namePrefix string) { + enumTypeName := namePrefix + string(enum.Desc.Name()) + fromValueName := enumTypeName + "FromValue" + + var zeroName string + primNames := make([]string, 0, len(enum.Values)) + for _, v := range enum.Values { + prim := namePrefix + screamingToPascal(string(v.Desc.Name())) + ctx.g.P(`primitive `, prim, ` fun value(): I32 => `, v.Desc.Number()) + primNames = append(primNames, prim) + if v.Desc.Number() == 0 && zeroName == "" { + zeroName = prim + } + } + if zeroName == "" && len(primNames) > 0 { + zeroName = primNames[0] } - // Editions fields with explicit presence (and proto3 `optional`). 
- if field.Desc.HasPresence() { + ctx.g.P() + ctx.g.P(`type `, enumTypeName, ` is (`, strings.Join(primNames, " | "), `)`) + ctx.g.P() + ctx.g.P(`primitive `, fromValueName) + ctx.g.P(` fun apply(v: I32): `, enumTypeName, ` =>`) + ctx.g.P(` match v`) + for _, v := range enum.Values { + if v.Desc.Number() == 0 { + continue // zero value is the else branch + } + prim := namePrefix + screamingToPascal(string(v.Desc.Name())) + ctx.g.P(` | `, v.Desc.Number(), ` => `, prim) + } + ctx.g.P(` else `, zeroName) + ctx.g.P(` end`) + ctx.g.P() +} + +// isSupported returns true for field shapes we generate code for. +// Out: maps, real oneofs, groups, and message/enum refs whose defining file +// lives in a different directory. In: scalars, same-directory messages, and +// same-directory enums (each singular, repeated, or proto3 `optional`). +func (ctx *genCtx) isSupported(field *protogen.Field) bool { + if field.Desc.IsMap() { return false } - if field.Oneof != nil { + if field.Oneof != nil && !field.Desc.HasOptionalKeyword() { return false } switch field.Desc.Kind() { - case protoreflect.MessageKind, protoreflect.GroupKind, protoreflect.EnumKind: + case protoreflect.GroupKind: return false + case protoreflect.MessageKind: + if field.Message == nil || field.Message.Desc.IsMapEntry() { + return false + } + return ctx.isSamePonyPackage(field.Message.Desc.ParentFile()) + case protoreflect.EnumKind: + if field.Enum == nil { + return false + } + return ctx.isSamePonyPackage(field.Enum.Desc.ParentFile()) } return true } -func supportedFields(fields []*protogen.Field) []*protogen.Field { +func (ctx *genCtx) supportedFields(fields []*protogen.Field) []*protogen.Field { var out []*protogen.Field for _, f := range fields { - if isSupported(f) { + if ctx.isSupported(f) { out = append(out, f) } } return out } +func (ctx *genCtx) isSamePonyPackage(parentFile protoreflect.FileDescriptor) bool { + return path.Dir(parentFile.Path()) == path.Dir(ctx.file.Desc.Path()) +} + +// fieldPonyType returns the
Pony type declaration for a field. +// Repeated fields become Array[elem] val. +// Singular message fields become (ChildMsg val | None). +// Singular enum fields become the enum type alias. +// Scalars use scalarSpecs. +func (ctx *genCtx) fieldPonyType(field *protogen.Field) string { + if field.Desc.IsList() { + return "Array[" + ctx.elemPonyType(field) + "] val" + } + switch field.Desc.Kind() { + case protoreflect.MessageKind: + return "(" + ponyMessageClassName(field.Message) + " val | None)" + case protoreflect.EnumKind: + if field.Desc.HasOptionalKeyword() { + return "(" + ponyEnumTypeName(field.Enum) + " | None)" + } + return ponyEnumTypeName(field.Enum) + } + if field.Desc.HasOptionalKeyword() { + return "(" + scalarSpecs[field.Desc.Kind()].ponyType + " | None)" + } + return scalarSpecs[field.Desc.Kind()].ponyType +} + +// elemPonyType returns the element type for repeated fields (no None wrapper +// for messages — each element is always present). +func (ctx *genCtx) elemPonyType(field *protogen.Field) string { + switch field.Desc.Kind() { + case protoreflect.MessageKind: + return ponyMessageClassName(field.Message) + " val" + case protoreflect.EnumKind: + return ponyEnumTypeName(field.Enum) + } + return scalarSpecs[field.Desc.Kind()].ponyType +} + +// fieldPonyDefault returns the default value expression for a field. +func (ctx *genCtx) fieldPonyDefault(field *protogen.Field) string { + if field.Desc.IsList() { + return "recover val Array[" + ctx.elemPonyType(field) + "] end" + } + switch field.Desc.Kind() { + case protoreflect.MessageKind: + return "None" + case protoreflect.EnumKind: + if field.Desc.HasOptionalKeyword() { + return "None" + } + return ponyEnumZeroValuePrimitive(field.Enum) + } + if field.Desc.HasOptionalKeyword() { + return "None" + } + return scalarSpecs[field.Desc.Kind()].ponyDefault +} + +// fieldPonyWireType returns the wire type string for the decode match arm. 
+func (ctx *genCtx) fieldPonyWireType(field *protogen.Field) string { + if field.Desc.IsList() { + return "WireLenDelim" // packed scalars + per-entry messages + } + switch field.Desc.Kind() { + case protoreflect.MessageKind: + return "WireLenDelim" + case protoreflect.EnumKind: + return "WireVarint" + } + return scalarSpecs[field.Desc.Kind()].wireType +} + +// ── Name helpers ───────────────────────────────────────────────────────────── + +// screamingToPascal converts SCREAMING_SNAKE_CASE to PascalCase. +// STATUS_UNKNOWN → StatusUnknown, ACTIVE → Active. +func screamingToPascal(s string) string { + parts := strings.Split(strings.ToLower(s), "_") + var b strings.Builder + for _, p := range parts { + if len(p) > 0 { + b.WriteString(strings.ToUpper(p[:1]) + p[1:]) + } + } + return b.String() +} + +// ponyMessageClassName builds the flat Pony class name for a message by +// walking the parent chain (Outer.Inner → Outer_Inner). +func ponyMessageClassName(msg *protogen.Message) string { + var ancestors []string + parent := msg.Desc.Parent() + for { + parentMsg, ok := parent.(protoreflect.MessageDescriptor) + if !ok { + break + } + ancestors = append(ancestors, string(parentMsg.Name())) + parent = parentMsg.Parent() + } + // ancestors are innermost-first; reverse to outermost-first + for i, j := 0, len(ancestors)-1; i < j; i, j = i+1, j-1 { + ancestors[i], ancestors[j] = ancestors[j], ancestors[i] + } + return strings.Join(append(ancestors, string(msg.Desc.Name())), "_") +} + +// enumNamePrefix returns the message-hierarchy prefix (e.g. "Zoo_") for +// an enum nested inside messages. Top-level enums return "". 
+func enumNamePrefix(enum *protogen.Enum) string { + var ancestors []string + parent := enum.Desc.Parent() + for { + parentMsg, ok := parent.(protoreflect.MessageDescriptor) + if !ok { + break + } + ancestors = append(ancestors, string(parentMsg.Name())) + parent = parentMsg.Parent() + } + if len(ancestors) == 0 { + return "" + } + for i, j := 0, len(ancestors)-1; i < j; i, j = i+1, j-1 { + ancestors[i], ancestors[j] = ancestors[j], ancestors[i] + } + return strings.Join(ancestors, "_") + "_" +} + +// ponyEnumTypeName returns the Pony type alias name for an enum. +func ponyEnumTypeName(enum *protogen.Enum) string { + return enumNamePrefix(enum) + string(enum.Desc.Name()) +} + +func ponyEnumFromValueName(enum *protogen.Enum) string { + return ponyEnumTypeName(enum) + "FromValue" +} + +// ponyEnumZeroValuePrimitive returns the Pony primitive name for the proto3 +// zero value (number == 0) of an enum. +func ponyEnumZeroValuePrimitive(enum *protogen.Enum) string { + prefix := enumNamePrefix(enum) + for _, v := range enum.Values { + if v.Desc.Number() == 0 { + return prefix + screamingToPascal(string(v.Desc.Name())) + } + } + if len(enum.Values) > 0 { + return prefix + screamingToPascal(string(enum.Values[0].Desc.Name())) + } + return ponyEnumTypeName(enum) + "Zero" +} + +// ── fieldShape ──────────────────────────────────────────────────────────────── + // fieldShape returns a short string describing why a field is unsupported, // for the TODO comment. func fieldShape(field *protogen.Field) string { - parts := []string{} + var parts []string if field.Desc.IsMap() { parts = append(parts, "map") } else if field.Desc.IsList() { @@ -220,6 +601,8 @@ func fieldShape(field *protogen.Field) string { return strings.Join(parts, " ") } +// ── scalarSpecs ─────────────────────────────────────────────────────────────── + // scalarSpec lookups for Pony codegen. 
Single source of truth for the // FieldKind → (Pony type, default, wire type, read expr, write fmt, // presence-check fmt) mapping. writeFmt and presence are fmt.Sprintf @@ -313,32 +696,3 @@ var scalarSpecs = map[protoreflect.Kind]scalarSpec{ presence: "%s.size() > 0", }, } - -func ponyType(field *protogen.Field) string { - return scalarSpecs[field.Desc.Kind()].ponyType -} - -func ponyDefault(field *protogen.Field) string { - return scalarSpecs[field.Desc.Kind()].ponyDefault -} - -func ponyWireType(field *protogen.Field) string { - return scalarSpecs[field.Desc.Kind()].wireType -} - -func ponyReadExpr(field *protogen.Field) string { - return scalarSpecs[field.Desc.Kind()].readExpr -} - -func ponyWriteCall(field *protogen.Field, valueRef string) string { - spec := scalarSpecs[field.Desc.Kind()] - if spec.writeFmt == "" { - return "" - } - return fmt.Sprintf(spec.writeFmt, valueRef) -} - -func presenceCheck(field *protogen.Field) string { - spec := scalarSpecs[field.Desc.Kind()] - return fmt.Sprintf(spec.presence, "msg."+string(field.Desc.Name())) -} diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go index 040144c..a8b5858 100644 --- a/cmd/protoc-gen-pony/main_test.go +++ b/cmd/protoc-gen-pony/main_test.go @@ -148,9 +148,9 @@ func TestEmptyMessage(t *testing.T) { assert.Contains(t, out, "primitive EmptyCodec") } -func TestUnsupportedShapesEmitTodo(t *testing.T) { - t.Parallel() - +// zooFileProto is a shared fixture for tests that need same-file message, +// enum, repeated, and unsupported shapes together. 
+func zooFileProto() *descriptorpb.FileDescriptorProto { tags := field("tags", 2, descriptorpb.FieldDescriptorProto_TYPE_STRING) tags.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() @@ -188,7 +188,7 @@ func TestUnsupportedShapesEmitTodo(t *testing.T) { Options: &descriptorpb.MessageOptions{MapEntry: proto.Bool(true)}, } - file := &descriptorpb.FileDescriptorProto{ + return &descriptorpb.FileDescriptorProto{ Name: proto.String("zoo.proto"), Package: proto.String("zoo"), Syntax: proto.String("proto3"), @@ -217,13 +217,25 @@ func TestUnsupportedShapesEmitTodo(t *testing.T) { }, }, } +} - out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "zoo.pony") +func TestUnsupportedShapesEmitTodo(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") assert.Contains(t, out, "let id: I32") - unsupported := []string{"tags", "count", "type_a", "type_b", "parent", "status", "metadata"} - for _, name := range unsupported { + // These shapes are now generated — confirm they are NOT TODO comments. + assert.Contains(t, out, "let tags: Array[String val] val") + assert.Contains(t, out, "let parent: (Parent val | None)") + assert.Contains(t, out, "let status: Status") + + // proto3 optional int32 is now supported. + assert.Contains(t, out, "let count: (I32 | None)") + + // These shapes remain unsupported — confirm TODO comments, not constructor params. + stillUnsupported := []string{"type_a", "type_b", "metadata"} + for _, name := range stillUnsupported { assert.Contains(t, out, "TODO protoc-gen-pony: field "+name, "missing TODO for %q", name) assert.NotContains(t, out, name+"': ", @@ -231,6 +243,116 @@ func TestUnsupportedShapesEmitTodo(t *testing.T) { } } +func TestEnumGeneration(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + + // Primitives for each enum value. 
+ assert.Contains(t, out, "primitive Unknown fun value(): I32 => 0") + assert.Contains(t, out, "primitive Active fun value(): I32 => 1") + + // Type alias union. + assert.Contains(t, out, "type Status is (Unknown | Active)") + + // FromValue dispatcher with zero-value fallback. + assert.Contains(t, out, "primitive StatusFromValue") + assert.Contains(t, out, "| 1 => Active") + assert.Contains(t, out, "else Unknown") +} + +func TestEnumField_ClassAndCodec(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + + // Class declaration and constructor default. + assert.Contains(t, out, "let status: Status") + assert.Contains(t, out, "status': Status = Unknown") + + // Decode: reads an I32 and applies FromValue. + assert.Contains(t, out, "status = StatusFromValue(v)") + + // Encode: skips zero value. + assert.Contains(t, out, "if msg.status.value() != 0 then") + assert.Contains(t, out, "Scalar.write_int32(writer, msg.status.value())") +} + +func TestMessageField_ClassAndCodec(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + + // Class declaration and constructor default. + assert.Contains(t, out, "let parent: (Parent val | None)") + assert.Contains(t, out, "parent': (Parent val | None) = None") + + // Decode: reads len-delim bytes, hands to sub-codec. + assert.Contains(t, out, "match ParentCodec.decode(WireReader(b))") + assert.Contains(t, out, "| let v: Parent val => parent = v") + + // Encode: match on None, emit sub-writer only when present. + assert.Contains(t, out, "match msg.parent") + assert.Contains(t, out, "| let v: Parent val =>") + assert.Contains(t, out, "ParentCodec.encode(sub, v)") +} + +func TestRepeatedScalarField(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + + // Class field and constructor default. 
+ assert.Contains(t, out, "let tags: Array[String val] val") + assert.Contains(t, out, "tags': Array[String val] val = recover val Array[String val] end") + + // Decode: trn accumulator + packed sub-reader loop. + assert.Contains(t, out, "var tags: Array[String val] trn = recover trn Array[String val] end") + assert.Contains(t, out, "let sub = WireReader(b)") + assert.Contains(t, out, "tags.push(v)") + + // Constructor call consumes the trn. + assert.Contains(t, out, "consume tags") + + // Encode: gate on non-empty, packed sub-writer. + assert.Contains(t, out, "if msg.tags.size() > 0 then") + assert.Contains(t, out, "sub.write_string(v)") + assert.Contains(t, out, "writer.write_tag(Tag(2, WireLenDelim))") +} + +func TestRepeatedMessageField(t *testing.T) { + t.Parallel() + + item := field("item", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + item.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() + item.TypeName = proto.String(".pkg.Item") + + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("pkg.proto"), + Package: proto.String("pkg"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Container"), + Field: []*descriptorpb.FieldDescriptorProto{item}, + }, + {Name: proto.String("Item")}, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "pkg.pony") + + // Class field and default. + assert.Contains(t, out, "let item: Array[Item val] val") + assert.Contains(t, out, "item': Array[Item val] val = recover val Array[Item val] end") + + // Decode: trn accumulator, per-entry sub-codec. + assert.Contains(t, out, "var item: Array[Item val] trn = recover trn Array[Item val] end") + assert.Contains(t, out, "match ItemCodec.decode(WireReader(b))") + assert.Contains(t, out, "| let v: Item val => item.push(v)") + assert.Contains(t, out, "consume item") + + // Encode: one tag+len-delim per element. 
+ assert.Contains(t, out, "for v in msg.item.values() do") + assert.Contains(t, out, "ItemCodec.encode(sub, v)") + assert.Contains(t, out, "writer.write_tag(Tag(2, WireLenDelim))") +} + func TestNestedMessageFlatNaming(t *testing.T) { t.Parallel() file := &descriptorpb.FileDescriptorProto{ @@ -267,6 +389,121 @@ func TestFileHeaderHasSourceComment(t *testing.T) { // appearing verbatim after the injected stubs — protogen's later-wins // semantics depend on it. Lock in: nothing dropped or reordered, including // empty values and duplicates. +func TestOptionalScalarField(t *testing.T) { + t.Parallel() + + score := field("score", 2, descriptorpb.FieldDescriptorProto_TYPE_INT32) + score.Proto3Optional = proto.Bool(true) + score.OneofIndex = proto.Int32(0) + + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("player.proto"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Player"), + Field: []*descriptorpb.FieldDescriptorProto{score}, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + {Name: proto.String("_score")}, + }, + }, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "player.pony") + + // Class field and constructor default. + assert.Contains(t, out, "let score: (I32 | None)") + assert.Contains(t, out, "score': (I32 | None) = None") + + // Decode var: (I32 | None) + assert.Contains(t, out, "var score: (I32 | None) = None") + + // Encode: match on None (explicit presence — zero is emitted when set). 
+ assert.Contains(t, out, "match msg.score") + assert.Contains(t, out, "| let v: I32 =>") + assert.NotContains(t, out, "if msg.score != 0") +} + +func TestOptionalEnumField(t *testing.T) { + t.Parallel() + + status := field("status", 1, descriptorpb.FieldDescriptorProto_TYPE_ENUM) + status.TypeName = proto.String(".opt_test.Color") + status.Proto3Optional = proto.Bool(true) + status.OneofIndex = proto.Int32(0) + + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("opt_test.proto"), + Package: proto.String("opt_test"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Palette"), + Field: []*descriptorpb.FieldDescriptorProto{status}, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + {Name: proto.String("_status")}, + }, + }, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("Color"), + Value: []*descriptorpb.EnumValueDescriptorProto{ + {Name: proto.String("RED"), Number: proto.Int32(0)}, + {Name: proto.String("BLUE"), Number: proto.Int32(1)}, + }, + }, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "opt_test.pony") + + // Class field: (Color | None), default None. + assert.Contains(t, out, "let status: (Color | None)") + assert.Contains(t, out, "status': (Color | None) = None") + + // Encode: match on None (not zero-check). 
+ assert.Contains(t, out, "match msg.status") + assert.Contains(t, out, "| let v: Color =>") + assert.NotContains(t, out, "if msg.status.value() != 0") +} + +func TestCrossFileSameDirectoryRef(t *testing.T) { + t.Parallel() + + addrField := field("address", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + addrField.TypeName = proto.String(".geo.Address") + + personFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("geo/person.proto"), + Package: proto.String("geo"), + Syntax: proto.String("proto3"), + Dependency: []string{"geo/address.proto"}, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Person"), + Field: []*descriptorpb.FieldDescriptorProto{addrField}, + }, + }, + } + addressFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("geo/address.proto"), + Package: proto.String("geo"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("Address")}, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{addressFile, personFile}, "geo/person.pony") + + // Cross-file same-directory ref should be generated, not TODO. + assert.Contains(t, out, "let address: (Address val | None)") + assert.NotContains(t, out, "TODO protoc-gen-pony: field address") + + // Sub-codec calls present. 
+ assert.Contains(t, out, "AddressCodec.decode(WireReader(b))") + assert.Contains(t, out, "AddressCodec.encode(sub, v)") +} + func TestInjectGoImportStubs_PreservesExistingParameters(t *testing.T) { t.Parallel() const existing = "foo=1,bar=,foo=2" From 761401e7798360fa6b0389679a72becb95f1ec2b Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Wed, 29 Apr 2026 10:00:07 -0700 Subject: [PATCH 08/11] =?UTF-8?q?fix(protoc-gen-pony):=20address=20CodeRab?= =?UTF-8?q?bit=20review=20=E2=80=94=20wire=20correctness=20+=20enum=20roun?= =?UTF-8?q?d-trip?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Services now included in empty-file early return; emit TODO comment per service - Repeated string/bytes get per-element WireLenDelim arms (never packable) - Packable repeated (numeric/enum) emit both packed (WireLenDelim) and unpacked (element wire type) decode arms per proto3 decoder spec - Add StatusRaw class val for unknown enum values; FromValue preserves raw I32 instead of collapsing to zero, fixing encode round-trip for unknown values - Remove unused fieldPonyWireType helper Signed-off-by: Erik Nilsen --- cmd/protoc-gen-pony/generate.go | 95 ++++++++++++++++++++------------ cmd/protoc-gen-pony/main_test.go | 66 ++++++++++++++++++---- 2 files changed, 117 insertions(+), 44 deletions(-) diff --git a/cmd/protoc-gen-pony/generate.go b/cmd/protoc-gen-pony/generate.go index ecdfb0b..a14d00d 100644 --- a/cmd/protoc-gen-pony/generate.go +++ b/cmd/protoc-gen-pony/generate.go @@ -27,7 +27,7 @@ type genCtx struct { // from file.Desc.Path() directly because protogen's GeneratedFilenamePrefix // is prefixed with the Go import path, which is irrelevant for Pony output. 
func generateFile(plugin *protogen.Plugin, file *protogen.File) { - if len(file.Messages) == 0 && len(file.Enums) == 0 { + if len(file.Messages) == 0 && len(file.Enums) == 0 && len(file.Services) == 0 { return } outPath := strings.TrimSuffix(file.Desc.Path(), ".proto") + ".pony" @@ -44,6 +44,9 @@ func generateFile(plugin *protogen.Plugin, file *protogen.File) { for _, enum := range file.Enums { ctx.emitEnum(enum, "") } + for _, svc := range file.Services { + g.P(`// TODO protoc-gen-pony: service `, svc.Desc.Name(), ` (service)`) + } } // collectAndEmitMessages walks a slice of messages depth-first, flattening @@ -144,14 +147,13 @@ func (ctx *genCtx) emitDecode(className string, supported []*protogen.Field) { func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { name := string(field.Desc.Name()) num := field.Desc.Number() - wt := ctx.fieldPonyWireType(field) - ctx.g.P(` | (`, num, `, `, wt, `) =>`) switch { case field.Desc.IsList() && field.Desc.Kind() == protoreflect.MessageKind: - // non-packed: one tag + len-delim per element (proto3 repeated message) + // non-packed: one tag + len-delim per element codec := ponyMessageClassName(field.Message) + "Codec" elemType := ponyMessageClassName(field.Message) + " val" + ctx.g.P(` | (`, num, `, WireLenDelim) =>`) ctx.g.P(` match reader.read_len_delim()`) ctx.g.P(` | let b: Array[U8] val =>`) ctx.g.P(` match `, codec, `.decode(WireReader(b))`) @@ -161,8 +163,25 @@ func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { ctx.g.P(` | let e: WireError => return e`) ctx.g.P(` end`) + case field.Desc.IsList() && field.Desc.Kind() == protoreflect.StringKind: + // non-packed: string/bytes are never packable; one len-delim per element + ctx.g.P(` | (`, num, `, WireLenDelim) =>`) + ctx.g.P(` match reader.read_string()`) + ctx.g.P(` | let v: String val => `, name, `.push(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + + case field.Desc.IsList() && field.Desc.Kind() == protoreflect.BytesKind: + ctx.g.P(` | 
(`, num, `, WireLenDelim) =>`) + ctx.g.P(` match reader.read_len_delim()`) + ctx.g.P(` | let v: Array[U8] val => `, name, `.push(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + case field.Desc.IsList() && field.Desc.Kind() == protoreflect.EnumKind: + // packed primary arm + unpacked arm (proto3 decoders must accept both) fromValue := ponyEnumFromValueName(field.Enum) + ctx.g.P(` | (`, num, `, WireLenDelim) =>`) ctx.g.P(` match reader.read_len_delim()`) ctx.g.P(` | let b: Array[U8] val =>`) ctx.g.P(` let sub = WireReader(b)`) @@ -174,25 +193,38 @@ func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { ctx.g.P(` end`) ctx.g.P(` | let e: WireError => return e`) ctx.g.P(` end`) + ctx.g.P(` | (`, num, `, WireVarint) =>`) + ctx.g.P(` match Scalar.read_int32(reader)`) + ctx.g.P(` | let v: I32 => `, name, `.push(`, fromValue, `(v))`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) case field.Desc.IsList(): + // packable numeric scalar: packed arm + unpacked arm spec := scalarSpecs[field.Desc.Kind()] - readExpr := strings.Replace(spec.readExpr, "reader", "sub", 1) + readExprSub := strings.Replace(spec.readExpr, "reader", "sub", 1) + ctx.g.P(` | (`, num, `, WireLenDelim) =>`) ctx.g.P(` match reader.read_len_delim()`) ctx.g.P(` | let b: Array[U8] val =>`) ctx.g.P(` let sub = WireReader(b)`) ctx.g.P(` while not sub.at_end() do`) - ctx.g.P(` match `, readExpr) + ctx.g.P(` match `, readExprSub) ctx.g.P(` | let v: `, spec.ponyType, ` => `, name, `.push(v)`) ctx.g.P(` | let e: WireError => return e`) ctx.g.P(` end`) ctx.g.P(` end`) ctx.g.P(` | let e: WireError => return e`) ctx.g.P(` end`) + ctx.g.P(` | (`, num, `, `, spec.wireType, `) =>`) + ctx.g.P(` match `, spec.readExpr) + ctx.g.P(` | let v: `, spec.ponyType, ` => `, name, `.push(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) case field.Desc.Kind() == protoreflect.MessageKind: codec := ponyMessageClassName(field.Message) + "Codec" msgType := 
ponyMessageClassName(field.Message) + " val" + ctx.g.P(` | (`, num, `, WireLenDelim) =>`) ctx.g.P(` match reader.read_len_delim()`) ctx.g.P(` | let b: Array[U8] val =>`) ctx.g.P(` match `, codec, `.decode(WireReader(b))`) @@ -204,6 +236,7 @@ func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { case field.Desc.Kind() == protoreflect.EnumKind: fromValue := ponyEnumFromValueName(field.Enum) + ctx.g.P(` | (`, num, `, WireVarint) =>`) ctx.g.P(` match Scalar.read_int32(reader)`) ctx.g.P(` | let v: I32 => `, name, ` = `, fromValue, `(v)`) ctx.g.P(` | let e: WireError => return e`) @@ -211,6 +244,7 @@ func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { default: spec := scalarSpecs[field.Desc.Kind()] + ctx.g.P(` | (`, num, `, `, spec.wireType, `) =>`) ctx.g.P(` match `, spec.readExpr) ctx.g.P(` | let v: `, spec.ponyType, ` => `, name, ` = v`) ctx.g.P(` | let e: WireError => return e`) @@ -264,10 +298,17 @@ func (ctx *genCtx) emitEncodeField(field *protogen.Field) { ctx.emitPackedEncode(ref, num, "Scalar.write_int32(sub, v.value())") case field.Desc.IsList() && field.Desc.Kind() == protoreflect.StringKind: - ctx.emitPackedEncode(ref, num, "sub.write_string(v)") + // non-packed: string is never packable; one tag + value per element + ctx.g.P(` for v in `, ref, `.values() do`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_string(v)`) + ctx.g.P(` end`) case field.Desc.IsList() && field.Desc.Kind() == protoreflect.BytesKind: - ctx.emitPackedEncode(ref, num, "sub.write_len_delim(v)") + ctx.g.P(` for v in `, ref, `.values() do`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_len_delim(v)`) + ctx.g.P(` end`) case field.Desc.IsList(): spec := scalarSpecs[field.Desc.Kind()] @@ -355,34 +396,34 @@ func (ctx *genCtx) emitOptionalEncodeField(field *protogen.Field) { func (ctx *genCtx) emitEnum(enum *protogen.Enum, namePrefix string) { enumTypeName := namePrefix + string(enum.Desc.Name()) 
fromValueName := enumTypeName + "FromValue" + rawName := enumTypeName + "Raw" - var zeroName string primNames := make([]string, 0, len(enum.Values)) for _, v := range enum.Values { prim := namePrefix + screamingToPascal(string(v.Desc.Name())) ctx.g.P(`primitive `, prim, ` fun value(): I32 => `, v.Desc.Number()) primNames = append(primNames, prim) - if v.Desc.Number() == 0 && zeroName == "" { - zeroName = prim - } - } - if zeroName == "" && len(primNames) > 0 { - zeroName = primNames[0] } + + // Raw class preserves unknown numeric values across decode/re-encode (proto3 + // forward-compat: a peer may send values not yet in this schema). + ctx.g.P() + ctx.g.P(`class val `, rawName) + ctx.g.P(` let _v: I32`) + ctx.g.P(` new val create(v: I32) => _v = v`) + ctx.g.P(` fun value(): I32 => _v`) + ctx.g.P() - ctx.g.P(`type `, enumTypeName, ` is (`, strings.Join(primNames, " | "), `)`) + ctx.g.P(`type `, enumTypeName, ` is (`, strings.Join(append(primNames, rawName), " | "), `)`) ctx.g.P() ctx.g.P(`primitive `, fromValueName) ctx.g.P(` fun apply(v: I32): `, enumTypeName, ` =>`) ctx.g.P(` match v`) for _, v := range enum.Values { - if v.Desc.Number() == 0 { - continue // zero value is the else branch - } prim := namePrefix + screamingToPascal(string(v.Desc.Name())) ctx.g.P(` | `, v.Desc.Number(), ` => `, prim) } - ctx.g.P(` else `, zeroName) + ctx.g.P(` else `, rawName, `(v)`) ctx.g.P(` end`) ctx.g.P() } @@ -485,20 +526,6 @@ func (ctx *genCtx) fieldPonyDefault(field *protogen.Field) string { return scalarSpecs[field.Desc.Kind()].ponyDefault } -// fieldPonyWireType returns the wire type string for the decode match arm. 
-func (ctx *genCtx) fieldPonyWireType(field *protogen.Field) string { - if field.Desc.IsList() { - return "WireLenDelim" // packed scalars + per-entry messages - } - switch field.Desc.Kind() { - case protoreflect.MessageKind: - return "WireLenDelim" - case protoreflect.EnumKind: - return "WireVarint" - } - return scalarSpecs[field.Desc.Kind()].wireType -} - // ── Name helpers ───────────────────────────────────────────────────────────── // screamingToPascal converts SCREAMING_SNAKE_CASE to PascalCase. diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go index a8b5858..b33edec 100644 --- a/cmd/protoc-gen-pony/main_test.go +++ b/cmd/protoc-gen-pony/main_test.go @@ -251,13 +251,18 @@ func TestEnumGeneration(t *testing.T) { assert.Contains(t, out, "primitive Unknown fun value(): I32 => 0") assert.Contains(t, out, "primitive Active fun value(): I32 => 1") - // Type alias union. - assert.Contains(t, out, "type Status is (Unknown | Active)") + // Raw class val preserves unknown numeric values (proto3 forward-compat). + assert.Contains(t, out, "class val StatusRaw") + assert.Contains(t, out, "fun value(): I32 => _v") - // FromValue dispatcher with zero-value fallback. + // Type alias union includes the Raw fallback. + assert.Contains(t, out, "type Status is (Unknown | Active | StatusRaw)") + + // FromValue: all known values explicit, unknowns go to Raw(v). assert.Contains(t, out, "primitive StatusFromValue") + assert.Contains(t, out, "| 0 => Unknown") assert.Contains(t, out, "| 1 => Active") - assert.Contains(t, out, "else Unknown") + assert.Contains(t, out, "else StatusRaw(v)") } func TestEnumField_ClassAndCodec(t *testing.T) { @@ -302,18 +307,18 @@ func TestRepeatedScalarField(t *testing.T) { assert.Contains(t, out, "let tags: Array[String val] val") assert.Contains(t, out, "tags': Array[String val] val = recover val Array[String val] end") - // Decode: trn accumulator + packed sub-reader loop. 
+ // Decode: trn accumulator + per-element read (string is never packed). assert.Contains(t, out, "var tags: Array[String val] trn = recover trn Array[String val] end") - assert.Contains(t, out, "let sub = WireReader(b)") - assert.Contains(t, out, "tags.push(v)") + assert.Contains(t, out, "reader.read_string()") + assert.Contains(t, out, "| let v: String val => tags.push(v)") // Constructor call consumes the trn. assert.Contains(t, out, "consume tags") - // Encode: gate on non-empty, packed sub-writer. - assert.Contains(t, out, "if msg.tags.size() > 0 then") - assert.Contains(t, out, "sub.write_string(v)") + // Encode: per-element (no packed blob for strings). + assert.Contains(t, out, "for v in msg.tags.values() do") assert.Contains(t, out, "writer.write_tag(Tag(2, WireLenDelim))") + assert.Contains(t, out, "writer.write_string(v)") } func TestRepeatedMessageField(t *testing.T) { @@ -385,6 +390,47 @@ func TestFileHeaderHasSourceComment(t *testing.T) { assert.Contains(t, out, "// Source: user.proto") } +func TestRepeatedNumericField(t *testing.T) { + t.Parallel() + + scores := field("scores", 2, descriptorpb.FieldDescriptorProto_TYPE_INT32) + scores.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() + + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("game.proto"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("Game"), Field: []*descriptorpb.FieldDescriptorProto{scores}}, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "game.pony") + + // Both packed (WireLenDelim) and unpacked (WireVarint) arms must be emitted. + assert.Contains(t, out, "(2, WireLenDelim)") + assert.Contains(t, out, "(2, WireVarint)") + // Packed arm uses sub-reader; unpacked arm reads directly. + assert.Contains(t, out, "Scalar.read_int32(sub)") + assert.Contains(t, out, "Scalar.read_int32(reader)") + + // Encode uses packed format. 
+ assert.Contains(t, out, "if msg.scores.size() > 0 then") + assert.Contains(t, out, "Scalar.write_int32(sub, v)") +} + +func TestServiceEmitsTodo(t *testing.T) { + t.Parallel() + + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("svc.proto"), + Syntax: proto.String("proto3"), + Service: []*descriptorpb.ServiceDescriptorProto{ + {Name: proto.String("GreeterService")}, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "svc.pony") + assert.Contains(t, out, "// TODO protoc-gen-pony: service GreeterService (service)") +} + // "User-provided params still win" (main.go:50) relies on existing entries // appearing verbatim after the injected stubs — protogen's later-wins // semantics depend on it. Lock in: nothing dropped or reordered, including From 95ca568861bde7d78144bff2a331c33013d6edcf Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Wed, 29 Apr 2026 13:37:46 -0700 Subject: [PATCH 09/11] feat(protoc-gen-pony): real oneof support + cross-dir use directives + property tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Real oneofs: emits wrapper class per member (ZooKindTypeA) + union type alias (ZooKind); whole oneof stays TODO if any member is WKT/group - Cross-dir refs: collects dep dirs across all nested messages/oneofs, computes relative paths via protoRelDir, emits sorted use directives - Property tests: protoRelDir inverse round-trip + snakeToPascal never contains underscore (testing/quick, no new dependency) - Simplify: screamingToPascal delegates to snakeToPascal; isSupported delegates kind-check to isSupportedOneofMember; crossDirUseDirectives caches depDir→relDir to avoid recomputing protoRelDir per field Signed-off-by: Erik Nilsen --- cmd/protoc-gen-pony/generate.go | 432 +++++++++++++++++++++++++++---- cmd/protoc-gen-pony/main_test.go | 331 ++++++++++++++++++++++- 2 files changed, 699 insertions(+), 64 deletions(-) diff --git a/cmd/protoc-gen-pony/generate.go 
b/cmd/protoc-gen-pony/generate.go index a14d00d..9a45ebb 100644 --- a/cmd/protoc-gen-pony/generate.go +++ b/cmd/protoc-gen-pony/generate.go @@ -3,6 +3,7 @@ package main import ( "fmt" "path" + "sort" "strings" "google.golang.org/protobuf/compiler/protogen" @@ -40,6 +41,14 @@ func generateFile(plugin *protogen.Plugin, file *protogen.File) { g.P() ctx := &genCtx{plugin: plugin, file: file, g: g} + + if useDirs := ctx.crossDirUseDirectives(); len(useDirs) > 0 { + for _, u := range useDirs { + g.P(`use "`, u, `"`) + } + g.P() + } + ctx.collectAndEmitMessages(file.Messages, "") for _, enum := range file.Enums { ctx.emitEnum(enum, "") @@ -69,52 +78,90 @@ func (ctx *genCtx) collectAndEmitMessages(messages []*protogen.Message, namePref func (ctx *genCtx) emitMessage(msg *protogen.Message, className string) { supported := ctx.supportedFields(msg.Fields) - ctx.emitClass(className, msg.Fields, supported) + oneofs := ctx.supportedRealOneofs(msg) + for _, oo := range oneofs { + ctx.emitOneofWrapperTypes(className, oo) + } + ctx.emitClass(className, msg.Fields, oneofs) + ctx.emitConstructor(supported, oneofs, className) ctx.g.P() - ctx.emitCodec(className, supported) + ctx.emitCodec(className, supported, oneofs) ctx.g.P() } -func (ctx *genCtx) emitClass(className string, all, supported []*protogen.Field) { +func (ctx *genCtx) emitClass(className string, all []*protogen.Field, oneofs []*protogen.Oneof) { + inSupportedOneof := func(oo *protogen.Oneof) bool { + for _, o := range oneofs { + if o == oo { + return true + } + } + return false + } + todo := func(f *protogen.Field) { + ctx.g.P(` // TODO protoc-gen-pony: field `, f.Desc.Name(), ` (`, fieldShape(f), `)`) + } ctx.g.P(`class val `, className) + seenOneof := make(map[*protogen.Oneof]bool) for _, field := range all { - if ctx.isSupported(field) { + if field.Oneof != nil && !field.Desc.HasOptionalKeyword() { + oo := field.Oneof + if inSupportedOneof(oo) { + if !seenOneof[oo] { + seenOneof[oo] = true + ctx.g.P(` let `, 
oo.Desc.Name(), `: `, oneofTypeName(className, oo)) + } + } else { + todo(field) + } + } else if ctx.isSupported(field) { ctx.g.P(` let `, field.Desc.Name(), `: `, ctx.fieldPonyType(field)) } else { - ctx.g.P(` // TODO protoc-gen-pony: field `, field.Desc.Name(), ` (`, fieldShape(field), `)`) + todo(field) } } ctx.g.P() - ctx.emitConstructor(supported) } -func (ctx *genCtx) emitConstructor(supported []*protogen.Field) { - if len(supported) == 0 { +func (ctx *genCtx) emitConstructor(supported []*protogen.Field, oneofs []*protogen.Oneof, className string) { + if len(supported) == 0 && len(oneofs) == 0 { ctx.g.P(` new val create() => None`) return } - ctx.g.P(` new val create(`) - for i, field := range supported { - suffix := "," - if i == len(supported)-1 { - suffix = ")" + total := len(supported) + len(oneofs) + suffix := func(i int) string { + if i == total { + return ")" } - ctx.g.P(` `, field.Desc.Name(), `': `, ctx.fieldPonyType(field), ` = `, ctx.fieldPonyDefault(field), suffix) + return "," + } + ctx.g.P(` new val create(`) + i := 0 + for _, field := range supported { + i++ + ctx.g.P(` `, field.Desc.Name(), `': `, ctx.fieldPonyType(field), ` = `, ctx.fieldPonyDefault(field), suffix(i)) + } + for _, oo := range oneofs { + i++ + ctx.g.P(` `, oo.Desc.Name(), `': `, oneofTypeName(className, oo), ` = None`, suffix(i)) } ctx.g.P(` =>`) for _, field := range supported { ctx.g.P(` `, field.Desc.Name(), ` = `, field.Desc.Name(), `'`) } + for _, oo := range oneofs { + ctx.g.P(` `, oo.Desc.Name(), ` = `, oo.Desc.Name(), `'`) + } } -func (ctx *genCtx) emitCodec(className string, supported []*protogen.Field) { +func (ctx *genCtx) emitCodec(className string, supported []*protogen.Field, oneofs []*protogen.Oneof) { ctx.g.P(`primitive `, className, `Codec`) - ctx.emitDecode(className, supported) + ctx.emitDecode(className, supported, oneofs) ctx.g.P() - ctx.emitEncode(className, supported) + ctx.emitEncode(className, supported, oneofs) } -func (ctx *genCtx) 
emitDecode(className string, supported []*protogen.Field) { +func (ctx *genCtx) emitDecode(className string, supported []*protogen.Field, oneofs []*protogen.Oneof) { ctx.g.P(` fun decode(reader: WireReader ref): (`, className, ` val | WireError) =>`) for _, field := range supported { name := string(field.Desc.Name()) @@ -125,6 +172,9 @@ func (ctx *genCtx) emitDecode(className string, supported []*protogen.Field) { ctx.g.P(` var `, name, `: `, ctx.fieldPonyType(field), ` = `, ctx.fieldPonyDefault(field)) } } + for _, oo := range oneofs { + ctx.g.P(` var `, oo.Desc.Name(), `: `, oneofTypeName(className, oo), ` = None`) + } ctx.g.P(` while not reader.at_end() do`) ctx.g.P(` match reader.read_tag()`) ctx.g.P(` | let t: Tag =>`) @@ -132,6 +182,11 @@ func (ctx *genCtx) emitDecode(className string, supported []*protogen.Field) { for _, field := range supported { ctx.emitDecodeArm(field) } + for _, oo := range oneofs { + for _, f := range oo.Fields { + ctx.emitOneofDecodeArm(f, string(oo.Desc.Name()), oneofCaseName(className, oo, f)) + } + } ctx.g.P(` else`) ctx.g.P(` match reader.skip(t.wire_type)`) ctx.g.P(` | None => None`) @@ -141,7 +196,7 @@ func (ctx *genCtx) emitDecode(className string, supported []*protogen.Field) { ctx.g.P(` | let e: WireError => return e`) ctx.g.P(` end`) ctx.g.P(` end`) - ctx.emitConstructorCall(className, supported) + ctx.emitConstructorCall(className, supported, oneofs) } func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { @@ -252,31 +307,37 @@ func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { } } -func (ctx *genCtx) emitConstructorCall(className string, supported []*protogen.Field) { - if len(supported) == 0 { +func (ctx *genCtx) emitConstructorCall(className string, supported []*protogen.Field, oneofs []*protogen.Oneof) { + if len(supported) == 0 && len(oneofs) == 0 { ctx.g.P(` `, className) return } - parts := make([]string, len(supported)) - for i, field := range supported { + parts := make([]string, 0, 
len(supported)+len(oneofs)) + for _, field := range supported { if field.Desc.IsList() { - parts[i] = "consume " + string(field.Desc.Name()) + parts = append(parts, "consume "+string(field.Desc.Name())) } else { - parts[i] = string(field.Desc.Name()) + parts = append(parts, string(field.Desc.Name())) } } + for _, oo := range oneofs { + parts = append(parts, string(oo.Desc.Name())) + } ctx.g.P(` `, className, `(`, strings.Join(parts, ", "), `)`) } -func (ctx *genCtx) emitEncode(className string, supported []*protogen.Field) { +func (ctx *genCtx) emitEncode(className string, supported []*protogen.Field, oneofs []*protogen.Oneof) { ctx.g.P(` fun encode(writer: WireWriter ref, msg: `, className, ` val) =>`) - if len(supported) == 0 { + if len(supported) == 0 && len(oneofs) == 0 { ctx.g.P(` None`) return } for _, field := range supported { ctx.emitEncodeField(field) } + for _, oo := range oneofs { + ctx.emitOneofEncodeBlock(className, oo) + } } func (ctx *genCtx) emitEncodeField(field *protogen.Field) { @@ -393,6 +454,200 @@ func (ctx *genCtx) emitOptionalEncodeField(field *protogen.Field) { } } +func (ctx *genCtx) emitOneofWrapperTypes(className string, oo *protogen.Oneof) { + caseNames := make([]string, 0, len(oo.Fields)) + for _, f := range oo.Fields { + caseName := oneofCaseName(className, oo, f) + valType := ctx.oneofMemberPonyType(f) + valDefault := ctx.oneofMemberDefault(f) + ctx.g.P(`class val `, caseName) + ctx.g.P(` let value: `, valType) + if valDefault != "" { + ctx.g.P(` new val create(value': `, valType, ` = `, valDefault, `) => value = value'`) + } else { + ctx.g.P(` new val create(value': `, valType, `) => value = value'`) + } + ctx.g.P() + caseNames = append(caseNames, caseName) + } + ctx.g.P(`type `, oneofTypeName(className, oo), ` is (`, strings.Join(append(caseNames, "None"), " | "), `)`) + ctx.g.P() +} + +// emitOneofDecodeArm emits a match arm that reads a single oneof member and +// assigns `varName = CaseClass(decoded_value)`. 
+func (ctx *genCtx) emitOneofDecodeArm(field *protogen.Field, varName, caseName string) { + num := field.Desc.Number() + switch field.Desc.Kind() { + case protoreflect.MessageKind: + codec := ponyMessageClassName(field.Message) + "Codec" + msgType := ponyMessageClassName(field.Message) + " val" + ctx.g.P(` | (`, num, `, WireLenDelim) =>`) + ctx.g.P(` match reader.read_len_delim()`) + ctx.g.P(` | let b: Array[U8] val =>`) + ctx.g.P(` match `, codec, `.decode(WireReader(b))`) + ctx.g.P(` | let v: `, msgType, ` => `, varName, ` = `, caseName, `(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + case protoreflect.EnumKind: + fromValue := ponyEnumFromValueName(field.Enum) + ctx.g.P(` | (`, num, `, WireVarint) =>`) + ctx.g.P(` match Scalar.read_int32(reader)`) + ctx.g.P(` | let v: I32 => `, varName, ` = `, caseName, `(`, fromValue, `(v))`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + case protoreflect.StringKind: + ctx.g.P(` | (`, num, `, WireLenDelim) =>`) + ctx.g.P(` match reader.read_string()`) + ctx.g.P(` | let v: String val => `, varName, ` = `, caseName, `(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + case protoreflect.BytesKind: + ctx.g.P(` | (`, num, `, WireLenDelim) =>`) + ctx.g.P(` match reader.read_len_delim()`) + ctx.g.P(` | let v: Array[U8] val => `, varName, ` = `, caseName, `(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + default: + spec := scalarSpecs[field.Desc.Kind()] + ctx.g.P(` | (`, num, `, `, spec.wireType, `) =>`) + ctx.g.P(` match `, spec.readExpr) + ctx.g.P(` | let v: `, spec.ponyType, ` => `, varName, ` = `, caseName, `(v)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + } +} + +// emitOneofEncodeBlock emits `match msg.kind | let v: CaseA => ... | None => None end`. 
+func (ctx *genCtx) emitOneofEncodeBlock(className string, oo *protogen.Oneof) { + ctx.g.P(` match msg.`, oo.Desc.Name()) + for _, f := range oo.Fields { + caseName := oneofCaseName(className, oo, f) + ctx.g.P(` | let v: `, caseName, ` =>`) + ctx.emitOneofCaseEncode(f) + } + ctx.g.P(` | None => None`) + ctx.g.P(` end`) +} + +func (ctx *genCtx) emitOneofCaseEncode(field *protogen.Field) { + num := field.Desc.Number() + switch field.Desc.Kind() { + case protoreflect.MessageKind: + codec := ponyMessageClassName(field.Message) + "Codec" + ctx.g.P(` let sub = WireWriter`) + ctx.g.P(` `, codec, `.encode(sub, v.value)`) + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_len_delim(sub.done())`) + case protoreflect.EnumKind: + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireVarint))`) + ctx.g.P(` Scalar.write_int32(writer, v.value.value())`) + case protoreflect.StringKind: + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_string(v.value)`) + case protoreflect.BytesKind: + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_len_delim(v.value)`) + default: + spec := scalarSpecs[field.Desc.Kind()] + ctx.g.P(` writer.write_tag(Tag(`, num, `, `, spec.wireType, `))`) + ctx.g.P(` `, fmt.Sprintf(spec.writeFmt, "v.value")) + } +} + +// supportedRealOneofs returns the non-synthetic oneofs in msg whose every +// member passes isSupportedOneofMember. If any member is unsupported (WKT, +// group, etc.), the whole oneof emits TODO comments instead. 
+func (ctx *genCtx) supportedRealOneofs(msg *protogen.Message) []*protogen.Oneof { + var result []*protogen.Oneof + for _, oo := range msg.Oneofs { + if oo.Desc.IsSynthetic() { + continue + } + ok := true + for _, f := range oo.Fields { + if !ctx.isSupportedOneofMember(f) { + ok = false + break + } + } + if ok { + result = append(result, oo) + } + } + return result +} + +// isSupportedOneofMember applies the field-kind checks to a oneof member +// without the real-oneof gate that isSupported enforces. +func (ctx *genCtx) isSupportedOneofMember(field *protogen.Field) bool { + switch field.Desc.Kind() { + case protoreflect.GroupKind: + return false + case protoreflect.MessageKind: + if field.Message == nil || field.Message.Desc.IsMapEntry() { + return false + } + return !isWKT(field.Message.Desc.ParentFile()) + case protoreflect.EnumKind: + if field.Enum == nil { + return false + } + return !isWKT(field.Enum.Desc.ParentFile()) + } + return true +} + +// oneofMemberPonyType returns the raw Pony type for a oneof member value — +// no None wrapper, since None comes from the outer union type alias. +func (ctx *genCtx) oneofMemberPonyType(field *protogen.Field) string { + switch field.Desc.Kind() { + case protoreflect.MessageKind: + return ponyMessageClassName(field.Message) + " val" + case protoreflect.EnumKind: + return ponyEnumTypeName(field.Enum) + } + return scalarSpecs[field.Desc.Kind()].ponyType +} + +// oneofMemberDefault returns the constructor default for a oneof member's +// value field. Returns "" for message kinds (no sensible zero value exists). +func (ctx *genCtx) oneofMemberDefault(field *protogen.Field) string { + switch field.Desc.Kind() { + case protoreflect.MessageKind: + return "" + case protoreflect.EnumKind: + return ponyEnumZeroValuePrimitive(field.Enum) + } + return scalarSpecs[field.Desc.Kind()].ponyDefault +} + +// oneofTypeName builds the Pony type alias name for a real oneof. +// e.g. 
className="Zoo", oneof="kind" → "ZooKind" +func oneofTypeName(className string, oo *protogen.Oneof) string { + return className + snakeToPascal(string(oo.Desc.Name())) +} + +// oneofCaseName builds the Pony wrapper class name for one oneof member. +// e.g. className="Zoo", oneof="kind", field="type_a" → "ZooKindTypeA" +func oneofCaseName(className string, oo *protogen.Oneof, field *protogen.Field) string { + return className + snakeToPascal(string(oo.Desc.Name())) + snakeToPascal(string(field.Desc.Name())) +} + +// snakeToPascal converts snake_case to PascalCase: "type_a" → "TypeA". +func snakeToPascal(s string) string { + parts := strings.Split(s, "_") + var b strings.Builder + for _, p := range parts { + if len(p) > 0 { + b.WriteString(strings.ToUpper(p[:1]) + p[1:]) + } + } + return b.String() +} + func (ctx *genCtx) emitEnum(enum *protogen.Enum, namePrefix string) { enumTypeName := namePrefix + string(enum.Desc.Name()) fromValueName := enumTypeName + "FromValue" @@ -429,9 +684,9 @@ func (ctx *genCtx) emitEnum(enum *protogen.Enum, namePrefix string) { } // isSupported returns true for field shapes we generate code for. -// Out: maps, oneofs, proto3 optional (synthetic oneof), groups, cross-file -// message/enum refs. In: scalars (singular + repeated), same-file messages -// (singular + repeated), same-file enums (singular + repeated). +// Out: maps, real oneofs, groups, WKT message/enum refs. +// In: scalars (singular + repeated), messages (singular + repeated, any dir), +// enums (singular + repeated, any dir), proto3 optional. 
func (ctx *genCtx) isSupported(field *protogen.Field) bool { if field.Desc.IsMap() { return false @@ -439,21 +694,101 @@ func (ctx *genCtx) isSupported(field *protogen.Field) bool { if field.Oneof != nil && !field.Desc.HasOptionalKeyword() { return false } - switch field.Desc.Kind() { - case protoreflect.GroupKind: - return false - case protoreflect.MessageKind: - if field.Message == nil || field.Message.Desc.IsMapEntry() { - return false + return ctx.isSupportedOneofMember(field) +} + +// isWKT reports whether f is a well-known-type file. WKT support requires +// hand-written runtime shims that don't exist yet, so these refs stay TODO. +func isWKT(f protoreflect.FileDescriptor) bool { + return strings.HasPrefix(f.Path(), "google/protobuf/") +} + +// crossDirUseDirectives returns sorted relative-path strings for Pony `use` +// directives needed to reference types from other proto packages. Same- +// directory refs are already in the same Pony package and need no import. +func (ctx *genCtx) crossDirUseDirectives() []string { + thisDir := path.Dir(ctx.file.Desc.Path()) + dirToRel := make(map[string]string) + + addDepPath := func(depPath string) { + if depPath == "" { + return } - return ctx.isSamePonyPackage(field.Message.Desc.ParentFile()) - case protoreflect.EnumKind: - if field.Enum == nil { - return false + depDir := path.Dir(depPath) + if depDir == thisDir { + return + } + if _, ok := dirToRel[depDir]; !ok { + dirToRel[depDir] = protoRelDir(thisDir, depDir) } - return ctx.isSamePonyPackage(field.Enum.Desc.ParentFile()) } - return true + + checkField := func(f *protogen.Field) { + switch f.Desc.Kind() { + case protoreflect.MessageKind: + if f.Message != nil { + addDepPath(f.Message.Desc.ParentFile().Path()) + } + case protoreflect.EnumKind: + if f.Enum != nil { + addDepPath(f.Enum.Desc.ParentFile().Path()) + } + } + } + + var walkMessages func([]*protogen.Message) + walkMessages = func(msgs []*protogen.Message) { + for _, msg := range msgs { + if 
msg.Desc.IsMapEntry() { + continue + } + // Collect cross-dir deps from regular supported fields. + for _, f := range msg.Fields { + if ctx.isSupported(f) { + checkField(f) + } + } + // Also collect from oneof members in supported real oneofs. + for _, oo := range ctx.supportedRealOneofs(msg) { + for _, f := range oo.Fields { + checkField(f) + } + } + walkMessages(msg.Messages) + } + } + walkMessages(ctx.file.Messages) + + result := make([]string, 0, len(dirToRel)) + for _, rel := range dirToRel { + result = append(result, rel) + } + sort.Strings(result) + return result +} + +// protoRelDir returns the relative path from fromDir to toDir using +// slash-separated proto path components (e.g. "geo" → "../common"). +// If the result doesn't start with "..", it is prefixed with "./" to +// distinguish it from a package-name lookup. +func protoRelDir(fromDir, toDir string) string { + from := strings.Split(path.Clean(fromDir), "/") + to := strings.Split(path.Clean(toDir), "/") + // trim common prefix + i := 0 + for i < len(from) && i < len(to) && from[i] == to[i] { + i++ + } + var parts []string + for range from[i:] { + parts = append(parts, "..") + } + parts = append(parts, to[i:]...) + rel := strings.Join(parts, "/") + if !strings.HasPrefix(rel, "..") { + rel = "./" + rel + } + return rel } func (ctx *genCtx) supportedFields(fields []*protogen.Field) []*protogen.Field { @@ -466,10 +801,6 @@ func (ctx *genCtx) supportedFields(fields []*protogen.Field) []*protogen.Field { return out } -func (ctx *genCtx) isSamePonyPackage(parentFile protoreflect.FileDescriptor) bool { - return path.Dir(parentFile.Path()) == path.Dir(ctx.file.Desc.Path()) -} - // fieldPonyType returns the Pony type declaration for a field. // Repeated fields become Array[elem] val. // Singular message fields become (ChildMsg val | None). @@ -531,14 +862,7 @@ func (ctx *genCtx) fieldPonyDefault(field *protogen.Field) string { // screamingToPascal converts SCREAMING_SNAKE_CASE to PascalCase. 
// STATUS_UNKNOWN → StatusUnknown, ACTIVE → Active. func screamingToPascal(s string) string { - parts := strings.Split(strings.ToLower(s), "_") - var b strings.Builder - for _, p := range parts { - if len(p) > 0 { - b.WriteString(strings.ToUpper(p[:1]) + p[1:]) - } - } - return b.String() + return snakeToPascal(strings.ToLower(s)) } // ponyMessageClassName builds the flat Pony class name for a message by diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go index b33edec..45c9176 100644 --- a/cmd/protoc-gen-pony/main_test.go +++ b/cmd/protoc-gen-pony/main_test.go @@ -1,7 +1,10 @@ package main import ( + "path" + "strings" "testing" + "testing/quick" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -229,18 +232,16 @@ func TestUnsupportedShapesEmitTodo(t *testing.T) { assert.Contains(t, out, "let tags: Array[String val] val") assert.Contains(t, out, "let parent: (Parent val | None)") assert.Contains(t, out, "let status: Status") - - // proto3 optional int32 is now supported. assert.Contains(t, out, "let count: (I32 | None)") - // These shapes remain unsupported — confirm TODO comments, not constructor params. - stillUnsupported := []string{"type_a", "type_b", "metadata"} - for _, name := range stillUnsupported { - assert.Contains(t, out, "TODO protoc-gen-pony: field "+name, - "missing TODO for %q", name) - assert.NotContains(t, out, name+"': ", - "%q should be skipped from the constructor signature", name) - } + // The `kind` oneof (type_a + type_b) is now supported via oneof codegen. + assert.Contains(t, out, "let kind: ZooKind") + assert.NotContains(t, out, "TODO protoc-gen-pony: field type_a") + assert.NotContains(t, out, "TODO protoc-gen-pony: field type_b") + + // map field remains unsupported. 
+ assert.Contains(t, out, "TODO protoc-gen-pony: field metadata") + assert.NotContains(t, out, "metadata': ", "metadata should be skipped from constructor") } func TestEnumGeneration(t *testing.T) { @@ -550,6 +551,316 @@ func TestCrossFileSameDirectoryRef(t *testing.T) { assert.Contains(t, out, "AddressCodec.encode(sub, v)") } +func TestOneofWrapperTypes(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + + // Wrapper classes for each oneof member. + assert.Contains(t, out, "class val ZooKindTypeA") + assert.Contains(t, out, "let value: String val") + assert.Contains(t, out, "new val create(value': String val = \"\") => value = value'") + + assert.Contains(t, out, "class val ZooKindTypeB") + assert.Contains(t, out, "let value: I32") + assert.Contains(t, out, "new val create(value': I32 = 0) => value = value'") + + // Type alias union includes None. + assert.Contains(t, out, "type ZooKind is (ZooKindTypeA | ZooKindTypeB | None)") +} + +func TestOneofConstructorParam(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + assert.Contains(t, out, "kind': ZooKind = None") + assert.Contains(t, out, "kind = kind'") +} + +func TestOneofDecode(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + + // Declare the oneof var. + assert.Contains(t, out, "var kind: ZooKind = None") + + // String arm (type_a = field 4) wraps in ZooKindTypeA. + assert.Contains(t, out, "(4, WireLenDelim)") + assert.Contains(t, out, "| let v: String val => kind = ZooKindTypeA(v)") + + // Int32 arm (type_b = field 5) wraps in ZooKindTypeB. + assert.Contains(t, out, "(5, WireVarint)") + assert.Contains(t, out, "| let v: I32 => kind = ZooKindTypeB(v)") + + // Constructor call includes kind. 
+ assert.Contains(t, out, "kind)") +} + +func TestOneofEncode(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + + assert.Contains(t, out, "match msg.kind") + assert.Contains(t, out, "| let v: ZooKindTypeA =>") + assert.Contains(t, out, "writer.write_tag(Tag(4, WireLenDelim))") + assert.Contains(t, out, "writer.write_string(v.value)") + assert.Contains(t, out, "| let v: ZooKindTypeB =>") + assert.Contains(t, out, "writer.write_tag(Tag(5, WireVarint))") + assert.Contains(t, out, "Scalar.write_int32(writer, v.value)") + assert.Contains(t, out, "| None => None") +} + +func TestOneofWithMessageMember(t *testing.T) { + t.Parallel() + + childField := field("child", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + childField.TypeName = proto.String(".pkg.Child") + childField.OneofIndex = proto.Int32(0) + + numField := field("num", 3, descriptorpb.FieldDescriptorProto_TYPE_INT64) + numField.OneofIndex = proto.Int32(0) + + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("pkg.proto"), + Package: proto.String("pkg"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Parent"), + Field: []*descriptorpb.FieldDescriptorProto{childField, numField}, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + {Name: proto.String("payload")}, + }, + }, + {Name: proto.String("Child")}, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "pkg.pony") + + // Message member: no default in wrapper constructor (no = value). + assert.Contains(t, out, "class val ParentPayloadChild") + assert.Contains(t, out, "let value: Child val") + assert.Contains(t, out, "new val create(value': Child val) => value = value'") + assert.NotContains(t, out, "new val create(value': Child val =") + + // Int64 member has default. 
+ assert.Contains(t, out, "class val ParentPayloadNum") + assert.Contains(t, out, "new val create(value': I64 = 0)") + + // Type alias. + assert.Contains(t, out, "type ParentPayload is (ParentPayloadChild | ParentPayloadNum | None)") + + // Decode: message arm reads sub-codec. + assert.Contains(t, out, "match ChildCodec.decode(WireReader(b))") + assert.Contains(t, out, "| let v: Child val => payload = ParentPayloadChild(v)") + + // Encode: message arm uses sub-writer. + assert.Contains(t, out, "| let v: ParentPayloadChild =>") + assert.Contains(t, out, "ChildCodec.encode(sub, v.value)") +} + +func TestOneofUnsupportedWhenMemberIsWKT(t *testing.T) { + t.Parallel() + + tsField := field("created_at", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + tsField.TypeName = proto.String(".google.protobuf.Timestamp") + tsField.OneofIndex = proto.Int32(0) + strField := field("label", 3, descriptorpb.FieldDescriptorProto_TYPE_STRING) + strField.OneofIndex = proto.Int32(0) + + tsFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("google/protobuf/timestamp.proto"), Package: proto.String("google.protobuf"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{{Name: proto.String("Timestamp")}}, + } + eventFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("event.proto"), + Package: proto.String("event"), + Syntax: proto.String("proto3"), + Dependency: []string{"google/protobuf/timestamp.proto"}, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Event"), + Field: []*descriptorpb.FieldDescriptorProto{tsField, strField}, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + {Name: proto.String("when")}, + }, + }, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{tsFile, eventFile}, "event.pony") + + // Whole oneof stays TODO because one member (Timestamp) is WKT. 
+ assert.Contains(t, out, "TODO protoc-gen-pony: field created_at") + assert.Contains(t, out, "TODO protoc-gen-pony: field label") + assert.NotContains(t, out, "type EventWhen") +} + +func TestCrossDirectoryRef(t *testing.T) { + t.Parallel() + + addrField := field("address", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + addrField.TypeName = proto.String(".common.Address") + + personFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("geo/person.proto"), + Package: proto.String("geo"), + Syntax: proto.String("proto3"), + Dependency: []string{"common/address.proto"}, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Person"), + Field: []*descriptorpb.FieldDescriptorProto{addrField}, + }, + }, + } + addressFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("common/address.proto"), + Package: proto.String("common"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("Address")}, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{addressFile, personFile}, "geo/person.pony") + + // use directive for the cross-directory dep. + assert.Contains(t, out, `use "../common"`) + + // Field generated (not TODO). + assert.Contains(t, out, "let address: (Address val | None)") + assert.NotContains(t, out, "TODO protoc-gen-pony: field address") + + // Codec calls generated. 
+ assert.Contains(t, out, "AddressCodec.decode(WireReader(b))") + assert.Contains(t, out, "AddressCodec.encode(sub, v)") +} + +func TestCrossDirectoryDedupedUse(t *testing.T) { + t.Parallel() + + cityField := field("city", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + cityField.TypeName = proto.String(".common.City") + countryField := field("country", 3, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + countryField.TypeName = proto.String(".common.Country") + + personFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("geo/person.proto"), + Package: proto.String("geo"), + Syntax: proto.String("proto3"), + Dependency: []string{"common/city.proto", "common/country.proto"}, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Person"), + Field: []*descriptorpb.FieldDescriptorProto{cityField, countryField}, + }, + }, + } + cityFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("common/city.proto"), + Package: proto.String("common"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("City")}, + }, + } + countryFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("common/country.proto"), + Package: proto.String("common"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("Country")}, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{cityFile, countryFile, personFile}, "geo/person.pony") + + // Only one use directive for common/ even though two deps come from there. 
+ assert.Equal(t, 1, strings.Count(out, `use "../common"`)) +} + +func TestWKTRefEmitsTodo(t *testing.T) { + t.Parallel() + + tsField := field("created_at", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + tsField.TypeName = proto.String(".google.protobuf.Timestamp") + + tsFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("google/protobuf/timestamp.proto"), + Package: proto.String("google.protobuf"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("Timestamp")}, + }, + } + eventFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("events/event.proto"), + Package: proto.String("events"), + Syntax: proto.String("proto3"), + Dependency: []string{"google/protobuf/timestamp.proto"}, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Event"), + Field: []*descriptorpb.FieldDescriptorProto{tsField}, + }, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{tsFile, eventFile}, "events/event.pony") + + // WKT field stays as TODO. + assert.Contains(t, out, "TODO protoc-gen-pony: field created_at") + + // No use directive for WKT. + assert.NotContains(t, out, `use "google/protobuf"`) + assert.NotContains(t, out, `use "../google/protobuf"`) +} + +// cleanDirSegs filters an arbitrary []string down to a valid proto directory +// path (slash-separated ASCII alphanumeric components). Returns ("", false) +// when no valid segments remain. +func cleanDirSegs(segs []string) (string, bool) { + var out []string + for _, s := range segs { + clean := strings.Map(func(r rune) rune { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') { + return r + } + return -1 + }, s) + if clean != "" { + out = append(out, clean) + } + } + if len(out) == 0 { + return "", false + } + return strings.Join(out, "/"), true +} + +// TestProtoRelDir_Inverse verifies that joining `from` with the result of +// protoRelDir always resolves back to `to`. 
+func TestProtoRelDir_Inverse(t *testing.T) { + f := func(fromSegs, toSegs []string) bool { + from, ok1 := cleanDirSegs(fromSegs) + to, ok2 := cleanDirSegs(toSegs) + if !ok1 || !ok2 || from == to { + return true // skip: preconditions not met + } + return path.Join(from, protoRelDir(from, to)) == to + } + if err := quick.Check(f, &quick.Config{MaxCount: 2000}); err != nil { + t.Error(err) + } +} + +// TestSnakeToPascal_NeverContainsUnderscore verifies that snakeToPascal removes +// all underscores regardless of the input string. +func TestSnakeToPascal_NeverContainsUnderscore(t *testing.T) { + f := func(s string) bool { + return !strings.Contains(snakeToPascal(s), "_") + } + if err := quick.Check(f, &quick.Config{MaxCount: 2000}); err != nil { + t.Error(err) + } +} + func TestInjectGoImportStubs_PreservesExistingParameters(t *testing.T) { t.Parallel() const existing = "foo=1,bar=,foo=2" From 92c7ca2d49f4fc8fc575e9b76ecc9544357abefd Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Wed, 29 Apr 2026 13:59:29 -0700 Subject: [PATCH 10/11] feat(protoc-gen-pony): map field support - isSupportedMapField: scalar and non-WKT enum values supported; message values stay TODO (no proto3 zero-value default available) - Decode: Map[K, V] trn accumulator, per-entry sub-reader (entry_sub), key field 1 + value field 2 arms, skip-unknown, assign via (k) = v - Encode: pairs() loop, always writes both key and value (no zero elision inside map entries per proto3 spec) - use "collections" emitted when any supported map field is present; sorted together with cross-dir relative use directives - crossDirUseDirectives: handles cross-dir enum value deps in map fields - 4 new tests (37 total): ClassAndCodec, UseCollections, MessageValueTodo, EnumValue Signed-off-by: Erik Nilsen --- cmd/protoc-gen-pony/generate.go | 194 +++++++++++++++++++++++++++++-- cmd/protoc-gen-pony/main_test.go | 130 ++++++++++++++++++++- 2 files changed, 311 insertions(+), 13 deletions(-) diff --git 
a/cmd/protoc-gen-pony/generate.go b/cmd/protoc-gen-pony/generate.go index 9a45ebb..5e4ab94 100644 --- a/cmd/protoc-gen-pony/generate.go +++ b/cmd/protoc-gen-pony/generate.go @@ -42,8 +42,13 @@ func generateFile(plugin *protogen.Plugin, file *protogen.File) { ctx := &genCtx{plugin: plugin, file: file, g: g} - if useDirs := ctx.crossDirUseDirectives(); len(useDirs) > 0 { - for _, u := range useDirs { + useDirectives := ctx.crossDirUseDirectives() + if ctx.fileHasMaps() { + useDirectives = append(useDirectives, "collections") + sort.Strings(useDirectives) + } + if len(useDirectives) > 0 { + for _, u := range useDirectives { g.P(`use "`, u, `"`) } g.P() @@ -165,7 +170,11 @@ func (ctx *genCtx) emitDecode(className string, supported []*protogen.Field, one ctx.g.P(` fun decode(reader: WireReader ref): (`, className, ` val | WireError) =>`) for _, field := range supported { name := string(field.Desc.Name()) - if field.Desc.IsList() { + if field.Desc.IsMap() { + kType := ctx.mapKeyPonyType(field.Message.Fields[0]) + vType := ctx.mapValuePonyType(field.Message.Fields[1]) + ctx.g.P(` var `, name, `: Map[`, kType, `, `, vType, `] trn = recover trn Map[`, kType, `, `, vType, `] end`) + } else if field.Desc.IsList() { elem := ctx.elemPonyType(field) ctx.g.P(` var `, name, `: Array[`, elem, `] trn = recover trn Array[`, elem, `] end`) } else { @@ -204,6 +213,8 @@ func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { num := field.Desc.Number() switch { + case field.Desc.IsMap(): + ctx.emitMapDecodeArm(field, name, num) case field.Desc.IsList() && field.Desc.Kind() == protoreflect.MessageKind: // non-packed: one tag + len-delim per element codec := ponyMessageClassName(field.Message) + "Codec" @@ -307,6 +318,59 @@ func (ctx *genCtx) emitDecodeArm(field *protogen.Field) { } } +func (ctx *genCtx) emitMapDecodeArm(field *protogen.Field, name string, num protoreflect.FieldNumber) { + keyField := field.Message.Fields[0] + valField := field.Message.Fields[1] + keySpec := 
scalarSpecs[keyField.Desc.Kind()] + kType := ctx.mapKeyPonyType(keyField) + vType := ctx.mapValuePonyType(valField) + keyReadExpr := strings.Replace(keySpec.readExpr, "reader", "entry_sub", 1) + + ctx.g.P(` | (`, num, `, WireLenDelim) =>`) + ctx.g.P(` match reader.read_len_delim()`) + ctx.g.P(` | let b: Array[U8] val =>`) + ctx.g.P(` let entry_sub = WireReader(b)`) + ctx.g.P(` var entry_k: `, kType, ` = `, keySpec.ponyDefault) + ctx.g.P(` var entry_v: `, vType, ` = `, ctx.mapValueDefault(valField)) + ctx.g.P(` while not entry_sub.at_end() do`) + ctx.g.P(` match entry_sub.read_tag()`) + ctx.g.P(` | let t: Tag =>`) + ctx.g.P(` match (t.field_number, t.wire_type)`) + ctx.g.P(` | (1, `, keySpec.wireType, `) =>`) + ctx.g.P(` match `, keyReadExpr) + ctx.g.P(` | let kk: `, kType, ` => entry_k = kk`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + if valField.Desc.Kind() == protoreflect.EnumKind { + fromValue := ponyEnumFromValueName(valField.Enum) + ctx.g.P(` | (2, WireVarint) =>`) + ctx.g.P(` match Scalar.read_int32(entry_sub)`) + ctx.g.P(` | let vv: I32 => entry_v = `, fromValue, `(vv)`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + } else { + valSpec := scalarSpecs[valField.Desc.Kind()] + valReadExpr := strings.Replace(valSpec.readExpr, "reader", "entry_sub", 1) + ctx.g.P(` | (2, `, valSpec.wireType, `) =>`) + ctx.g.P(` match `, valReadExpr) + ctx.g.P(` | let vv: `, valSpec.ponyType, ` => entry_v = vv`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + } + ctx.g.P(` else`) + ctx.g.P(` match entry_sub.skip(t.wire_type)`) + ctx.g.P(` | None => None`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + ctx.g.P(` end`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + ctx.g.P(` end`) + ctx.g.P(` `, name, `(entry_k) = entry_v`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) +} + func (ctx *genCtx) emitConstructorCall(className string, supported []*protogen.Field, oneofs 
[]*protogen.Oneof) { if len(supported) == 0 && len(oneofs) == 0 { ctx.g.P(` `, className) @@ -314,7 +378,7 @@ func (ctx *genCtx) emitConstructorCall(className string, supported []*protogen.F } parts := make([]string, 0, len(supported)+len(oneofs)) for _, field := range supported { - if field.Desc.IsList() { + if field.Desc.IsList() || field.Desc.IsMap() { parts = append(parts, "consume "+string(field.Desc.Name())) } else { parts = append(parts, string(field.Desc.Name())) @@ -345,6 +409,8 @@ func (ctx *genCtx) emitEncodeField(field *protogen.Field) { num := field.Desc.Number() switch { + case field.Desc.IsMap(): + ctx.emitMapEncodeField(field, ref, num) case field.Desc.IsList() && field.Desc.Kind() == protoreflect.MessageKind: // non-packed: one tag + len-delim per element (proto3 repeated message) codec := ponyMessageClassName(field.Message) + "Codec" @@ -410,6 +476,36 @@ func (ctx *genCtx) emitEncodeField(field *protogen.Field) { } } +func (ctx *genCtx) emitMapEncodeField(field *protogen.Field, ref string, num protoreflect.FieldNumber) { + keyField := field.Message.Fields[0] + valField := field.Message.Fields[1] + keySpec := scalarSpecs[keyField.Desc.Kind()] + + ctx.g.P(` for (k, v) in `, ref, `.pairs() do`) + ctx.g.P(` let sub = WireWriter`) + if keyField.Desc.Kind() == protoreflect.StringKind { + ctx.g.P(` sub.write_tag(Tag(1, WireLenDelim))`) + ctx.g.P(` sub.write_string(k)`) + } else { + ctx.g.P(` sub.write_tag(Tag(1, `, keySpec.wireType, `))`) + ctx.g.P(` `, fmt.Sprintf(strings.Replace(keySpec.writeFmt, "writer", "sub", 1), "k")) + } + if valField.Desc.Kind() == protoreflect.EnumKind { + ctx.g.P(` sub.write_tag(Tag(2, WireVarint))`) + ctx.g.P(` Scalar.write_int32(sub, v.value())`) + } else if valField.Desc.Kind() == protoreflect.StringKind { + ctx.g.P(` sub.write_tag(Tag(2, WireLenDelim))`) + ctx.g.P(` sub.write_string(v)`) + } else { + valSpec := scalarSpecs[valField.Desc.Kind()] + ctx.g.P(` sub.write_tag(Tag(2, `, valSpec.wireType, `))`) + ctx.g.P(` `, 
fmt.Sprintf(strings.Replace(valSpec.writeFmt, "writer", "sub", 1), "v")) + } + ctx.g.P(` writer.write_tag(Tag(`, num, `, WireLenDelim))`) + ctx.g.P(` writer.write_len_delim(sub.done())`) + ctx.g.P(` end`) +} + func (ctx *genCtx) emitPackedEncode(ref string, num protoreflect.FieldNumber, writeOp string) { ctx.g.P(` if `, ref, `.size() > 0 then`) ctx.g.P(` let sub = WireWriter`) @@ -580,6 +676,26 @@ func (ctx *genCtx) supportedRealOneofs(msg *protogen.Message) []*protogen.Oneof return result } +// isSupportedMapField returns true for map fields whose value kind we can +// generate — scalar and non-WKT enum. Message values have no proto3 zero and +// would require a default() on the codec; they stay TODO. +func (ctx *genCtx) isSupportedMapField(field *protogen.Field) bool { + if field.Message == nil || len(field.Message.Fields) != 2 { + return false + } + valField := field.Message.Fields[1] + switch valField.Desc.Kind() { + case protoreflect.GroupKind, protoreflect.MessageKind: + return false + case protoreflect.EnumKind: + if valField.Enum == nil { + return false + } + return !isWKT(valField.Enum.Desc.ParentFile()) + } + return true +} + // isSupportedOneofMember applies the field-kind checks to a oneof member // without the real-oneof gate that isSupported enforces. 
func (ctx *genCtx) isSupportedOneofMember(field *protogen.Field) bool { @@ -624,6 +740,24 @@ func (ctx *genCtx) oneofMemberDefault(field *protogen.Field) string { return scalarSpecs[field.Desc.Kind()].ponyDefault } +func (ctx *genCtx) mapKeyPonyType(field *protogen.Field) string { + return scalarSpecs[field.Desc.Kind()].ponyType +} + +func (ctx *genCtx) mapValuePonyType(field *protogen.Field) string { + if field.Desc.Kind() == protoreflect.EnumKind { + return ponyEnumTypeName(field.Enum) + } + return scalarSpecs[field.Desc.Kind()].ponyType +} + +func (ctx *genCtx) mapValueDefault(field *protogen.Field) string { + if field.Desc.Kind() == protoreflect.EnumKind { + return ponyEnumZeroValuePrimitive(field.Enum) + } + return scalarSpecs[field.Desc.Kind()].ponyDefault +} + // oneofTypeName builds the Pony type alias name for a real oneof. // e.g. className="Zoo", oneof="kind" → "ZooKind" func oneofTypeName(className string, oo *protogen.Oneof) string { @@ -684,12 +818,12 @@ func (ctx *genCtx) emitEnum(enum *protogen.Enum, namePrefix string) { } // isSupported returns true for field shapes we generate code for. -// Out: maps, real oneofs, groups, WKT message/enum refs. +// Out: real oneofs, groups, WKT message/enum refs, map (no zero value). // In: scalars (singular + repeated), messages (singular + repeated, any dir), -// enums (singular + repeated, any dir), proto3 optional. +// enums (singular + repeated, any dir), proto3 optional, map. 
func (ctx *genCtx) isSupported(field *protogen.Field) bool { if field.Desc.IsMap() { - return false + return ctx.isSupportedMapField(field) } if field.Oneof != nil && !field.Desc.HasOptionalKeyword() { return false @@ -703,6 +837,27 @@ func isWKT(f protoreflect.FileDescriptor) bool { return strings.HasPrefix(f.Path(), "google/protobuf/") } +func (ctx *genCtx) fileHasMaps() bool { + var check func([]*protogen.Message) bool + check = func(msgs []*protogen.Message) bool { + for _, msg := range msgs { + if msg.Desc.IsMapEntry() { + continue + } + for _, f := range msg.Fields { + if f.Desc.IsMap() && ctx.isSupported(f) { + return true + } + } + if check(msg.Messages) { + return true + } + } + return false + } + return check(ctx.file.Messages) +} + // crossDirUseDirectives returns sorted relative-path strings for Pony `use` // directives needed to reference types from other proto packages. Same- // directory refs are already in the same Pony package and need no import. @@ -724,6 +879,16 @@ func (ctx *genCtx) crossDirUseDirectives() []string { } checkField := func(f *protogen.Field) { + if f.Desc.IsMap() { + // Map entry message is always co-located; only value enums can cross dirs. + if f.Message != nil && len(f.Message.Fields) == 2 { + valField := f.Message.Fields[1] + if valField.Desc.Kind() == protoreflect.EnumKind && valField.Enum != nil { + addDepPath(valField.Enum.Desc.ParentFile().Path()) + } + } + return + } switch f.Desc.Kind() { case protoreflect.MessageKind: if f.Message != nil { @@ -802,11 +967,15 @@ func (ctx *genCtx) supportedFields(fields []*protogen.Field) []*protogen.Field { } // fieldPonyType returns the Pony type declaration for a field. -// Repeated fields become Array[elem] val. +// Map fields become Map[K, V] val. Repeated fields become Array[elem] val. // Singular message fields become (ChildMsg val | None). -// Singular enum fields become the enum type alias. -// Scalars use scalarSpecs. +// Singular enum fields become the enum type alias. 
Scalars use scalarSpecs. func (ctx *genCtx) fieldPonyType(field *protogen.Field) string { + if field.Desc.IsMap() { + kType := ctx.mapKeyPonyType(field.Message.Fields[0]) + vType := ctx.mapValuePonyType(field.Message.Fields[1]) + return "Map[" + kType + ", " + vType + "] val" + } if field.Desc.IsList() { return "Array[" + ctx.elemPonyType(field) + "] val" } @@ -839,6 +1008,11 @@ func (ctx *genCtx) elemPonyType(field *protogen.Field) string { // fieldPonyDefault returns the default value expression for a field. func (ctx *genCtx) fieldPonyDefault(field *protogen.Field) string { + if field.Desc.IsMap() { + kType := ctx.mapKeyPonyType(field.Message.Fields[0]) + vType := ctx.mapValuePonyType(field.Message.Fields[1]) + return "recover val Map[" + kType + ", " + vType + "] end" + } if field.Desc.IsList() { return "recover val Array[" + ctx.elemPonyType(field) + "] end" } diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go index 45c9176..3061af0 100644 --- a/cmd/protoc-gen-pony/main_test.go +++ b/cmd/protoc-gen-pony/main_test.go @@ -239,9 +239,9 @@ func TestUnsupportedShapesEmitTodo(t *testing.T) { assert.NotContains(t, out, "TODO protoc-gen-pony: field type_a") assert.NotContains(t, out, "TODO protoc-gen-pony: field type_b") - // map field remains unsupported. - assert.Contains(t, out, "TODO protoc-gen-pony: field metadata") - assert.NotContains(t, out, "metadata': ", "metadata should be skipped from constructor") + // map is now generated. + assert.Contains(t, out, "let metadata: Map[String val, I32] val") + assert.NotContains(t, out, "TODO protoc-gen-pony: field metadata") } func TestEnumGeneration(t *testing.T) { @@ -861,6 +861,130 @@ func TestSnakeToPascal_NeverContainsUnderscore(t *testing.T) { } } +func TestMapField_ClassAndCodec(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + + // Class field and constructor default. 
+ assert.Contains(t, out, "let metadata: Map[String val, I32] val") + assert.Contains(t, out, "metadata': Map[String val, I32] val = recover val Map[String val, I32] end") + + // Decode: trn accumulator, entry sub-reader, key/value arms, final assign. + assert.Contains(t, out, "var metadata: Map[String val, I32] trn = recover trn Map[String val, I32] end") + assert.Contains(t, out, "let entry_sub = WireReader(b)") + assert.Contains(t, out, "var entry_k: String val") + assert.Contains(t, out, "var entry_v: I32") + assert.Contains(t, out, "metadata(entry_k) = entry_v") + assert.Contains(t, out, "consume metadata") + + // Encode: pairs() loop, always-write key + value. + assert.Contains(t, out, "for (k, v) in msg.metadata.pairs() do") + assert.Contains(t, out, "sub.write_string(k)") + assert.Contains(t, out, "Scalar.write_int32(sub, v)") +} + +func TestMapField_UseCollections(t *testing.T) { + t.Parallel() + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{zooFileProto()}, "zoo.pony") + assert.Contains(t, out, `use "collections"`) +} + +func TestMapField_MessageValueTodo(t *testing.T) { + t.Parallel() + + entryField := field("items", 1, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + entryField.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() + entryField.TypeName = proto.String(".pkg.Container.ItemsEntry") + mapEntry := &descriptorpb.DescriptorProto{ + Name: proto.String("ItemsEntry"), + Field: []*descriptorpb.FieldDescriptorProto{ + field("key", 1, descriptorpb.FieldDescriptorProto_TYPE_STRING), + { + Name: proto.String("value"), + Number: proto.Int32(2), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(), + TypeName: proto.String(".pkg.Item"), + JsonName: proto.String("value"), + }, + }, + Options: &descriptorpb.MessageOptions{MapEntry: proto.Bool(true)}, + } + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("pkg.proto"), + Package: 
proto.String("pkg"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Container"), + Field: []*descriptorpb.FieldDescriptorProto{entryField}, + NestedType: []*descriptorpb.DescriptorProto{mapEntry}, + }, + {Name: proto.String("Item")}, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "pkg.pony") + assert.Contains(t, out, "TODO protoc-gen-pony: field items") + assert.NotContains(t, out, `use "collections"`) +} + +func TestMapField_EnumValue(t *testing.T) { + t.Parallel() + + entryField := field("by_status", 1, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + entryField.Label = descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum() + entryField.TypeName = proto.String(".pkg.Lookup.ByStatusEntry") + mapEntry := &descriptorpb.DescriptorProto{ + Name: proto.String("ByStatusEntry"), + Field: []*descriptorpb.FieldDescriptorProto{ + field("key", 1, descriptorpb.FieldDescriptorProto_TYPE_STRING), + { + Name: proto.String("value"), + Number: proto.Int32(2), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_ENUM.Enum(), + TypeName: proto.String(".pkg.Color"), + JsonName: proto.String("value"), + }, + }, + Options: &descriptorpb.MessageOptions{MapEntry: proto.Bool(true)}, + } + file := &descriptorpb.FileDescriptorProto{ + Name: proto.String("pkg.proto"), + Package: proto.String("pkg"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Lookup"), + Field: []*descriptorpb.FieldDescriptorProto{entryField}, + NestedType: []*descriptorpb.DescriptorProto{mapEntry}, + }, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("Color"), + Value: []*descriptorpb.EnumValueDescriptorProto{ + {Name: proto.String("RED"), Number: proto.Int32(0)}, + {Name: proto.String("BLUE"), Number: proto.Int32(1)}, + }, + }, + }, + } + out := runPlugin(t, 
[]*descriptorpb.FileDescriptorProto{file}, "pkg.pony") + + // Class field uses enum type as map value. + assert.Contains(t, out, "let by_status: Map[String val, Color] val") + + // Decode: FromValue applied to I32. + assert.Contains(t, out, "| let vv: I32 => entry_v = ColorFromValue(vv)") + + // Encode: .value() call on enum. + assert.Contains(t, out, "Scalar.write_int32(sub, v.value())") + + // use "collections" emitted. + assert.Contains(t, out, `use "collections"`) +} + func TestInjectGoImportStubs_PreservesExistingParameters(t *testing.T) { t.Parallel() const existing = "foo=1,bar=,foo=2" From 0b7466548c38442aaaff8e270296ab83eb5e22f3 Mon Sep 17 00:00:00 2001 From: Erik Nilsen Date: Wed, 29 Apr 2026 14:11:01 -0700 Subject: [PATCH 11/11] feat(protoc-gen-pony): WKTs, map, codec.default(), v1.0 docs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Well-known types: - Narrow isWKT blocklist to only the 4 circular/JSON-only files (struct, type, api, descriptor). Timestamp, Duration, Any, wrappers, FieldMask, Empty, etc. now generate as regular proto3 messages. - Update TestWKTRefEmitsTodo to use struct.proto (still blocked); add TestWKT_TimestampGenerates to confirm Timestamp generates. - Update TestOneofUnsupportedWhenMemberIsWKT to use struct.proto Value. map values: - Add fun default(): ClassName val => ClassName to every XxxCodec — map message values initialise to codec.default() when the value field is absent on the wire (proto3 zero semantics). - isSupportedMapField: accept non-WKT message values (was always TODO). - mapValuePonyType/mapValueDefault: handle MessageKind (Child val, ChildCodec.default()). - emitMapDecodeArm: sub-codec decode arm for message values. - emitMapEncodeField: vsub WireWriter + codec encode for message values. - crossDirUseDirectives: also collect message-valued map cross-dir deps. - Rename TestMapField_MessageValueTodo → TestMapField_MessageValue (now a positive test). 
Release / docs: - Add // Requires protobuf-pony runtime >= 0.1.0 to generated header. - CHANGELOG.md created with full feature coverage list. - README.md: replace outdated coverage stub with accurate supported / known-limitations sections. Signed-off-by: Erik Nilsen --- cmd/protoc-gen-pony/CHANGELOG.md | 22 +++++++ cmd/protoc-gen-pony/README.md | 24 ++++++-- cmd/protoc-gen-pony/generate.go | 84 +++++++++++++++++++++----- cmd/protoc-gen-pony/main_test.go | 100 ++++++++++++++++++++++++------- 4 files changed, 189 insertions(+), 41 deletions(-) create mode 100644 cmd/protoc-gen-pony/CHANGELOG.md diff --git a/cmd/protoc-gen-pony/CHANGELOG.md b/cmd/protoc-gen-pony/CHANGELOG.md new file mode 100644 index 0000000..cd93888 --- /dev/null +++ b/cmd/protoc-gen-pony/CHANGELOG.md @@ -0,0 +1,22 @@ +# Changelog + +## Unreleased + +### Features + +- All proto3 scalar types (bool, int32/64, uint32/64, sint32/64, fixed32/64, sfixed32/64, float, double, string, bytes) +- Enums: primitive per value, type alias union, `FromValue` dispatcher, `Raw` class for unknown values +- Singular and repeated embedded messages; sub-codec decode/encode +- Repeated scalar and enum fields (packed wire format) +- proto3 `optional` explicit presence (`(T | None)` type, match-on-None encode) +- Real `oneof` fields: wrapper class per member, union type alias, full decode/encode +- `map` fields: scalar, enum, and message values; `use "collections"` auto-emitted +- Cross-directory `use` directives (relative paths, deduplicated per directory) +- Well-known types (`google/protobuf/timestamp.proto`, `duration.proto`, `any.proto`, `wrappers.proto`, `field_mask.proto`, `empty.proto`, etc.) 
generate as regular proto3 messages +- Generated file header includes minimum required `protobuf-pony` runtime version + +### Known limitations + +- `google/protobuf/struct.proto`, `type.proto`, `api.proto`, `descriptor.proto` emit `TODO` comments — circular or JSON-only semantics +- JSON-specific WKT encoding (Timestamp as RFC 3339, etc.) is out of scope +- Services (gRPC stubs) are not generated diff --git a/cmd/protoc-gen-pony/README.md b/cmd/protoc-gen-pony/README.md index f180210..239ec6a 100644 --- a/cmd/protoc-gen-pony/README.md +++ b/cmd/protoc-gen-pony/README.md @@ -55,11 +55,25 @@ sources][runtime]. ## Coverage -v1 supports singular implicit-presence proto3 scalars (bool, int32/64, -uint32/64, sint32/64, fixed32/64, sfixed32/64, float, double, string, -bytes). Repeated fields, `optional` explicit presence, oneofs, maps, -embedded messages, and enums emit a `// TODO protoc-gen-pony` comment -until the corresponding codegen lands. Services (gRPC) are out of scope. +Supported (no `TODO` comments emitted): + +- All proto3 scalar types — bool, int32/64, uint32/64, sint32/64, + fixed32/64, sfixed32/64, float, double, string, bytes +- Enums (primitives + type alias + `FromValue` dispatcher + `Raw` fallback) +- Singular and repeated embedded messages +- proto3 `optional` explicit presence (`(T | None)` type) +- Real `oneof` fields (wrapper class per member, union type alias) +- `map<K, V>` where V is a scalar, enum, or non-blocked message +- Cross-directory `use` directives (relative path, auto-deduped) +- Well-known types: Timestamp, Duration, Any, FieldMask, wrappers, Empty, etc. + generate as regular proto3 messages with no special treatment + +**Known limitations:** + +- `google/protobuf/struct.proto`, `type.proto`, `api.proto`, `descriptor.proto` + stay as `TODO` — circular or JSON-only types not representable in plain proto3 +- JSON-specific WKT encoding (Timestamp as RFC 3339, etc.)
is out of scope +- Services (gRPC stubs) are not generated [buf]: https://buf.build [runtime]: https://github.com/TrogonStack/protobuf-pony diff --git a/cmd/protoc-gen-pony/generate.go b/cmd/protoc-gen-pony/generate.go index 5e4ab94..614fab7 100644 --- a/cmd/protoc-gen-pony/generate.go +++ b/cmd/protoc-gen-pony/generate.go @@ -38,6 +38,7 @@ func generateFile(plugin *protogen.Plugin, file *protogen.File) { // directory, and the runtime's protobuf.pony already owns it. g.P(`// Generated by protoc-gen-pony. DO NOT EDIT.`) g.P(`// Source: `, file.Desc.Path()) + g.P(`// Requires protobuf-pony runtime >= 0.1.0`) g.P() ctx := &genCtx{plugin: plugin, file: file, g: g} @@ -161,6 +162,7 @@ func (ctx *genCtx) emitConstructor(supported []*protogen.Field, oneofs []*protog func (ctx *genCtx) emitCodec(className string, supported []*protogen.Field, oneofs []*protogen.Oneof) { ctx.g.P(`primitive `, className, `Codec`) + ctx.g.P(` fun default(): `, className, ` val => `, className) ctx.emitDecode(className, supported, oneofs) ctx.g.P() ctx.emitEncode(className, supported, oneofs) @@ -341,14 +343,27 @@ func (ctx *genCtx) emitMapDecodeArm(field *protogen.Field, name string, num prot ctx.g.P(` | let kk: `, kType, ` => entry_k = kk`) ctx.g.P(` | let e: WireError => return e`) ctx.g.P(` end`) - if valField.Desc.Kind() == protoreflect.EnumKind { + switch valField.Desc.Kind() { + case protoreflect.MessageKind: + codec := ponyMessageClassName(valField.Message) + "Codec" + msgType := ponyMessageClassName(valField.Message) + " val" + ctx.g.P(` | (2, WireLenDelim) =>`) + ctx.g.P(` match entry_sub.read_len_delim()`) + ctx.g.P(` | let vb: Array[U8] val =>`) + ctx.g.P(` match `, codec, `.decode(WireReader(vb))`) + ctx.g.P(` | let vv: `, msgType, ` => entry_v = vv`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + ctx.g.P(` | let e: WireError => return e`) + ctx.g.P(` end`) + case protoreflect.EnumKind: fromValue := ponyEnumFromValueName(valField.Enum) ctx.g.P(` | (2, WireVarint) 
=>`) ctx.g.P(` match Scalar.read_int32(entry_sub)`) ctx.g.P(` | let vv: I32 => entry_v = `, fromValue, `(vv)`) ctx.g.P(` | let e: WireError => return e`) ctx.g.P(` end`) - } else { + default: valSpec := scalarSpecs[valField.Desc.Kind()] valReadExpr := strings.Replace(valSpec.readExpr, "reader", "entry_sub", 1) ctx.g.P(` | (2, `, valSpec.wireType, `) =>`) @@ -490,13 +505,20 @@ func (ctx *genCtx) emitMapEncodeField(field *protogen.Field, ref string, num pro ctx.g.P(` sub.write_tag(Tag(1, `, keySpec.wireType, `))`) ctx.g.P(` `, fmt.Sprintf(strings.Replace(keySpec.writeFmt, "writer", "sub", 1), "k")) } - if valField.Desc.Kind() == protoreflect.EnumKind { + switch valField.Desc.Kind() { + case protoreflect.MessageKind: + codec := ponyMessageClassName(valField.Message) + "Codec" + ctx.g.P(` let vsub = WireWriter`) + ctx.g.P(` `, codec, `.encode(vsub, v)`) + ctx.g.P(` sub.write_tag(Tag(2, WireLenDelim))`) + ctx.g.P(` sub.write_len_delim(vsub.done())`) + case protoreflect.EnumKind: ctx.g.P(` sub.write_tag(Tag(2, WireVarint))`) ctx.g.P(` Scalar.write_int32(sub, v.value())`) - } else if valField.Desc.Kind() == protoreflect.StringKind { + case protoreflect.StringKind: ctx.g.P(` sub.write_tag(Tag(2, WireLenDelim))`) ctx.g.P(` sub.write_string(v)`) - } else { + default: valSpec := scalarSpecs[valField.Desc.Kind()] ctx.g.P(` sub.write_tag(Tag(2, `, valSpec.wireType, `))`) ctx.g.P(` `, fmt.Sprintf(strings.Replace(valSpec.writeFmt, "writer", "sub", 1), "v")) @@ -677,16 +699,21 @@ func (ctx *genCtx) supportedRealOneofs(msg *protogen.Message) []*protogen.Oneof } // isSupportedMapField returns true for map fields whose value kind we can -// generate — scalar and non-WKT enum. Message values have no proto3 zero and -// would require a default() on the codec; they stay TODO. +// generate — scalar, non-WKT enum, non-WKT message (codec.default() supplies +// the zero value for missing entries on the wire). 
func (ctx *genCtx) isSupportedMapField(field *protogen.Field) bool { if field.Message == nil || len(field.Message.Fields) != 2 { return false } valField := field.Message.Fields[1] switch valField.Desc.Kind() { - case protoreflect.GroupKind, protoreflect.MessageKind: + case protoreflect.GroupKind: return false + case protoreflect.MessageKind: + if valField.Message == nil || valField.Message.Desc.IsMapEntry() { + return false + } + return !isWKT(valField.Message.Desc.ParentFile()) case protoreflect.EnumKind: if valField.Enum == nil { return false @@ -745,14 +772,20 @@ func (ctx *genCtx) mapKeyPonyType(field *protogen.Field) string { } func (ctx *genCtx) mapValuePonyType(field *protogen.Field) string { - if field.Desc.Kind() == protoreflect.EnumKind { + switch field.Desc.Kind() { + case protoreflect.MessageKind: + return ponyMessageClassName(field.Message) + " val" + case protoreflect.EnumKind: return ponyEnumTypeName(field.Enum) } return scalarSpecs[field.Desc.Kind()].ponyType } func (ctx *genCtx) mapValueDefault(field *protogen.Field) string { - if field.Desc.Kind() == protoreflect.EnumKind { + switch field.Desc.Kind() { + case protoreflect.MessageKind: + return ponyMessageClassName(field.Message) + "Codec.default()" + case protoreflect.EnumKind: return ponyEnumZeroValuePrimitive(field.Enum) } return scalarSpecs[field.Desc.Kind()].ponyDefault @@ -831,10 +864,22 @@ func (ctx *genCtx) isSupported(field *protogen.Field) bool { return ctx.isSupportedOneofMember(field) } -// isWKT reports whether f is a well-known-type file. WKT support requires -// hand-written runtime shims that don't exist yet, so these refs stay TODO. +// wktBlocklist holds the google/protobuf/ files that can't be generated as +// plain proto3 messages: struct.proto has a circular Value↔ListValue↔Struct +// oneof, type.proto and api.proto depend on it, and descriptor.proto is +// descriptor-only. Everything else (Timestamp, Duration, Any, wrappers, etc.) 
+// generates naturally as regular proto3 messages. +var wktBlocklist = map[string]bool{ + "google/protobuf/struct.proto": true, + "google/protobuf/type.proto": true, + "google/protobuf/api.proto": true, + "google/protobuf/descriptor.proto": true, +} + +// isWKT reports whether f is a blocked well-known-type file. Most WKTs +// generate as regular proto3 messages; only the circular/JSON-only ones stay TODO. func isWKT(f protoreflect.FileDescriptor) bool { - return strings.HasPrefix(f.Path(), "google/protobuf/") + return wktBlocklist[f.Path()] } func (ctx *genCtx) fileHasMaps() bool { @@ -880,11 +925,18 @@ func (ctx *genCtx) crossDirUseDirectives() []string { checkField := func(f *protogen.Field) { if f.Desc.IsMap() { - // Map entry message is always co-located; only value enums can cross dirs. + // Map entry message is always co-located; check value for cross-dir deps. if f.Message != nil && len(f.Message.Fields) == 2 { valField := f.Message.Fields[1] - if valField.Desc.Kind() == protoreflect.EnumKind && valField.Enum != nil { - addDepPath(valField.Enum.Desc.ParentFile().Path()) + switch valField.Desc.Kind() { + case protoreflect.MessageKind: + if valField.Message != nil { + addDepPath(valField.Message.Desc.ParentFile().Path()) + } + case protoreflect.EnumKind: + if valField.Enum != nil { + addDepPath(valField.Enum.Desc.ParentFile().Path()) + } } } return diff --git a/cmd/protoc-gen-pony/main_test.go b/cmd/protoc-gen-pony/main_test.go index 3061af0..3c9dce4 100644 --- a/cmd/protoc-gen-pony/main_test.go +++ b/cmd/protoc-gen-pony/main_test.go @@ -660,36 +660,39 @@ func TestOneofWithMessageMember(t *testing.T) { func TestOneofUnsupportedWhenMemberIsWKT(t *testing.T) { t.Parallel() - tsField := field("created_at", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) - tsField.TypeName = proto.String(".google.protobuf.Timestamp") - tsField.OneofIndex = proto.Int32(0) + // google/protobuf/struct.proto is in the blocklist — a Value oneof member + // keeps the whole oneof as 
TODO. + valueField := field("payload", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + valueField.TypeName = proto.String(".google.protobuf.Value") + valueField.OneofIndex = proto.Int32(0) strField := field("label", 3, descriptorpb.FieldDescriptorProto_TYPE_STRING) strField.OneofIndex = proto.Int32(0) - tsFile := &descriptorpb.FileDescriptorProto{ - Name: proto.String("google/protobuf/timestamp.proto"), Package: proto.String("google.protobuf"), + structFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("google/protobuf/struct.proto"), + Package: proto.String("google.protobuf"), Syntax: proto.String("proto3"), - MessageType: []*descriptorpb.DescriptorProto{{Name: proto.String("Timestamp")}}, + MessageType: []*descriptorpb.DescriptorProto{{Name: proto.String("Value")}}, } eventFile := &descriptorpb.FileDescriptorProto{ Name: proto.String("event.proto"), Package: proto.String("event"), Syntax: proto.String("proto3"), - Dependency: []string{"google/protobuf/timestamp.proto"}, + Dependency: []string{"google/protobuf/struct.proto"}, MessageType: []*descriptorpb.DescriptorProto{ { Name: proto.String("Event"), - Field: []*descriptorpb.FieldDescriptorProto{tsField, strField}, + Field: []*descriptorpb.FieldDescriptorProto{valueField, strField}, OneofDecl: []*descriptorpb.OneofDescriptorProto{ {Name: proto.String("when")}, }, }, }, } - out := runPlugin(t, []*descriptorpb.FileDescriptorProto{tsFile, eventFile}, "event.pony") + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{structFile, eventFile}, "event.pony") - // Whole oneof stays TODO because one member (Timestamp) is WKT. - assert.Contains(t, out, "TODO protoc-gen-pony: field created_at") + // Whole oneof stays TODO because one member (Value from struct.proto) is blocked. 
+ assert.Contains(t, out, "TODO protoc-gen-pony: field payload") assert.Contains(t, out, "TODO protoc-gen-pony: field label") assert.NotContains(t, out, "type EventWhen") } @@ -779,6 +782,42 @@ func TestCrossDirectoryDedupedUse(t *testing.T) { func TestWKTRefEmitsTodo(t *testing.T) { t.Parallel() + // google/protobuf/struct.proto is in the blocklist (circular Value type). + valueField := field("payload", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) + valueField.TypeName = proto.String(".google.protobuf.Value") + + structFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("google/protobuf/struct.proto"), + Package: proto.String("google.protobuf"), + Syntax: proto.String("proto3"), + MessageType: []*descriptorpb.DescriptorProto{ + {Name: proto.String("Value")}, + }, + } + eventFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("events/event.proto"), + Package: proto.String("events"), + Syntax: proto.String("proto3"), + Dependency: []string{"google/protobuf/struct.proto"}, + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Event"), + Field: []*descriptorpb.FieldDescriptorProto{valueField}, + }, + }, + } + out := runPlugin(t, []*descriptorpb.FileDescriptorProto{structFile, eventFile}, "events/event.pony") + + // Blocked WKT field stays as TODO. + assert.Contains(t, out, "TODO protoc-gen-pony: field payload") + + // No use directive for blocked WKT. 
+ assert.NotContains(t, out, `use "../google/protobuf"`) +} + +func TestWKT_TimestampGenerates(t *testing.T) { + t.Parallel() + tsField := field("created_at", 2, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) tsField.TypeName = proto.String(".google.protobuf.Timestamp") @@ -787,7 +826,13 @@ func TestWKTRefEmitsTodo(t *testing.T) { Package: proto.String("google.protobuf"), Syntax: proto.String("proto3"), MessageType: []*descriptorpb.DescriptorProto{ - {Name: proto.String("Timestamp")}, + { + Name: proto.String("Timestamp"), + Field: []*descriptorpb.FieldDescriptorProto{ + field("seconds", 1, descriptorpb.FieldDescriptorProto_TYPE_INT64), + field("nanos", 2, descriptorpb.FieldDescriptorProto_TYPE_INT32), + }, + }, }, } eventFile := &descriptorpb.FileDescriptorProto{ @@ -804,12 +849,12 @@ func TestWKTRefEmitsTodo(t *testing.T) { } out := runPlugin(t, []*descriptorpb.FileDescriptorProto{tsFile, eventFile}, "events/event.pony") - // WKT field stays as TODO. - assert.Contains(t, out, "TODO protoc-gen-pony: field created_at") + // Timestamp generates as a real class (not TODO). + assert.Contains(t, out, "let created_at: (Timestamp val | None)") + assert.NotContains(t, out, "TODO protoc-gen-pony: field created_at") - // No use directive for WKT. - assert.NotContains(t, out, `use "google/protobuf"`) - assert.NotContains(t, out, `use "../google/protobuf"`) + // Cross-dir use directive emitted. 
+ assert.Contains(t, out, `use "../google/protobuf"`) } // cleanDirSegs filters an arbitrary []string down to a valid proto directory @@ -889,7 +934,7 @@ func TestMapField_UseCollections(t *testing.T) { assert.Contains(t, out, `use "collections"`) } -func TestMapField_MessageValueTodo(t *testing.T) { +func TestMapField_MessageValue(t *testing.T) { t.Parallel() entryField := field("items", 1, descriptorpb.FieldDescriptorProto_TYPE_MESSAGE) @@ -924,8 +969,23 @@ func TestMapField_MessageValueTodo(t *testing.T) { }, } out := runPlugin(t, []*descriptorpb.FileDescriptorProto{file}, "pkg.pony") - assert.Contains(t, out, "TODO protoc-gen-pony: field items") - assert.NotContains(t, out, `use "collections"`) + + // map now generates. + assert.Contains(t, out, "let items: Map[String val, Item val] val") + assert.NotContains(t, out, "TODO protoc-gen-pony: field items") + + // codec.default() emitted for Item and Container. + assert.Contains(t, out, "fun default(): Item val => Item") + assert.Contains(t, out, "fun default(): Container val => Container") + + // Decode: sub-codec with ItemCodec.default() as initial value. + assert.Contains(t, out, "var entry_v: Item val = ItemCodec.default()") + assert.Contains(t, out, "match ItemCodec.decode(WireReader(vb))") + + // Encode: sub-writer for value. + assert.Contains(t, out, "ItemCodec.encode(vsub, v)") + + assert.Contains(t, out, `use "collections"`) } func TestMapField_EnumValue(t *testing.T) {