diff --git a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py index 604819b..0203563 100644 --- a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py +++ b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py @@ -37,14 +37,20 @@ class ConformanceService(Protocol): async def unary( self, request: connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryRequest, - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryResponse, + ], ) -> connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def server_stream( self, request: connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamRequest, - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamResponse, + ], ) -> AsyncIterator[ connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamResponse ]: @@ -55,7 +61,10 @@ async def client_stream( request: AsyncIterator[ connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamRequest ], - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamResponse, + ], ) -> connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") @@ -64,7 +73,10 @@ def bidi_stream( request: AsyncIterator[ connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamRequest ], - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamResponse, + ], ) -> AsyncIterator[ connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamResponse ]: @@ -73,14 +85,20 @@ def bidi_stream( async def unimplemented( self, request: connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedRequest, - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedResponse, + ], ) -> connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") async def idempotent_unary( self, request: connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryRequest, - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryResponse, + ], ) -> connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") @@ -307,14 +325,20 @@ class ConformanceServiceSync(Protocol): def unary( self, request: connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryRequest, - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryResponse, + ], ) -> connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def server_stream( self, request: connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamRequest, - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamResponse, + ], ) -> Iterator[ connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamResponse ]: @@ -325,7 +349,10 @@ def client_stream( request: Iterator[ connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamRequest ], - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamResponse, + ], ) -> connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") @@ -334,7 +361,10 @@ def bidi_stream( request: Iterator[ connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamRequest ], - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamResponse, + ], ) -> Iterator[ connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamResponse ]: @@ -343,14 +373,20 @@ def bidi_stream( def unimplemented( self, request: connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedRequest, - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedResponse, + ], ) -> connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def idempotent_unary( self, request: connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryRequest, - ctx: RequestContext, + ctx: RequestContext[ + connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryRequest, + connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryResponse, + ], ) -> connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") diff --git a/example/example/eliza_connect.py b/example/example/eliza_connect.py index 69de382..6faaf13 100644 --- a/example/example/eliza_connect.py +++ b/example/example/eliza_connect.py @@ -35,19 +35,31 @@ class ElizaService(Protocol): async def say( - self, request: example_dot_eliza__pb2.SayRequest, ctx: RequestContext + self, + request: example_dot_eliza__pb2.SayRequest, + ctx: RequestContext[ + example_dot_eliza__pb2.SayRequest, example_dot_eliza__pb2.SayResponse + ], ) -> example_dot_eliza__pb2.SayResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def converse( self, request: AsyncIterator[example_dot_eliza__pb2.ConverseRequest], - ctx: RequestContext, + ctx: RequestContext[ + example_dot_eliza__pb2.ConverseRequest, + example_dot_eliza__pb2.ConverseResponse, + ], ) -> AsyncIterator[example_dot_eliza__pb2.ConverseResponse]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def introduce( - self, request: example_dot_eliza__pb2.IntroduceRequest, ctx: RequestContext + self, + request: example_dot_eliza__pb2.IntroduceRequest, + ctx: RequestContext[ + example_dot_eliza__pb2.IntroduceRequest, + example_dot_eliza__pb2.IntroduceResponse, + ], ) -> AsyncIterator[example_dot_eliza__pb2.IntroduceResponse]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") @@ -174,19 +186,31 @@ def introduce( class ElizaServiceSync(Protocol): def say( - self, request: example_dot_eliza__pb2.SayRequest, ctx: RequestContext + self, + request: example_dot_eliza__pb2.SayRequest, + ctx: RequestContext[ + example_dot_eliza__pb2.SayRequest, example_dot_eliza__pb2.SayResponse + ], ) -> example_dot_eliza__pb2.SayResponse: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def converse( self, request: Iterator[example_dot_eliza__pb2.ConverseRequest], - ctx: RequestContext, + ctx: RequestContext[ + example_dot_eliza__pb2.ConverseRequest, + example_dot_eliza__pb2.ConverseResponse, + ], ) -> Iterator[example_dot_eliza__pb2.ConverseResponse]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def introduce( - self, request: example_dot_eliza__pb2.IntroduceRequest, ctx: RequestContext + self, + request: example_dot_eliza__pb2.IntroduceRequest, + ctx: RequestContext[ + example_dot_eliza__pb2.IntroduceRequest, + example_dot_eliza__pb2.IntroduceResponse, + ], ) -> Iterator[example_dot_eliza__pb2.IntroduceResponse]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 2add65c..3e562c6 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -67,7 +67,7 @@ from connectrpc.server import ConnectASGIApplication, ConnectWSGIApplication, En {{if not .SkipAsync }} {{- range .Services}} class {{.Name}}(Protocol):{{- range .Methods }} - {{if not .ResponseStream }}async {{end}}def {{.PythonName}}(self, request: {{if .RequestStream}}AsyncIterator[{{end}}{{.InputType}}{{if .RequestStream}}]{{end}}, ctx: RequestContext) -> {{if .ResponseStream}}AsyncIterator[{{end}}{{.OutputType}}{{if .ResponseStream}}]{{end}}: + {{if not .ResponseStream }}async {{end}}def {{.PythonName}}(self, request: {{if .RequestStream}}AsyncIterator[{{end}}{{.InputType}}{{if .RequestStream}}]{{end}}, ctx: RequestContext[{{.InputType}}, {{.OutputType}}]) -> {{if .ResponseStream}}AsyncIterator[{{end}}{{.OutputType}}{{if .ResponseStream}}]{{end}}: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") {{ end }} @@ -131,7 +131,7 @@ class {{.Name}}Client(ConnectClient):{{range .Methods}} {{if not .SkipSync }} {{range .Services}} class {{.Name}}Sync(Protocol):{{- range .Methods }} - def {{.PythonName}}(self, request: {{if .RequestStream}}Iterator[{{end}}{{.InputType}}{{if .RequestStream}}]{{end}}, ctx: RequestContext) -> {{if .ResponseStream}}Iterator[{{end}}{{.OutputType}}{{if .ResponseStream}}]{{end}}: + def {{.PythonName}}(self, request: {{if .RequestStream}}Iterator[{{end}}{{.InputType}}{{if .RequestStream}}]{{end}}, ctx: RequestContext[{{.InputType}}, {{.OutputType}}]) -> {{if .ResponseStream}}Iterator[{{end}}{{.OutputType}}{{if .ResponseStream}}]{{end}}: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") {{- end }} diff --git a/protoc-gen-connect-python/generator/template_test.go b/protoc-gen-connect-python/generator/template_test.go index ee57afc..9a38aa1 100644 --- a/protoc-gen-connect-python/generator/template_test.go +++ b/protoc-gen-connect-python/generator/template_test.go @@ -90,3 +90,54 @@ func TestConnectTemplate(t *testing.T) { }) } } + +func TestConnectTemplateRequestContextTypeParams(t *testing.T) { + t.Parallel() + + vars := ConnectTemplateVariables{ + FileName: "test.proto", + ModuleName: "test", + Services: []*ConnectService{ + { + Package: "test", + Name: "TestService", + Methods: []*ConnectMethod{ + { + Package: "test", + ServiceName: "TestService", + Name: "Unary", + PythonName: "Unary", + InputType: "_pb2.TestRequest", + OutputType: "_pb2.TestResponse", + }, + { + Package: "test", + ServiceName: "TestService", + Name: "Bidi", + PythonName: "Bidi", + InputType: "_pb2.StreamRequest", + OutputType: "_pb2.StreamResponse", + Stream: true, + RequestStream: true, + ResponseStream: true, + }, + }, + }, + }, + } + + var buf bytes.Buffer + if err := ConnectTemplate.Execute(&buf, vars); err != nil { + t.Fatalf("Template execution failed: %v", err) + } + result := buf.String() + + for _, want := range []string{ + "ctx: RequestContext[_pb2.TestRequest, _pb2.TestResponse]", + "ctx: RequestContext[_pb2.StreamRequest, _pb2.StreamResponse]", + } { + if !strings.Contains(result, want) { + t.Errorf("generated handler missing parameterized context %q\n--- got ---\n%s", want, result) + } + } +} diff --git a/test/haberdasher_connect.py b/test/haberdasher_connect.py index 4524198..e5c813f 100644 --- a/test/haberdasher_connect.py +++ b/test/haberdasher_connect.py @@ -37,32 +37,49 @@ class Haberdasher(Protocol): async def make_hat( - self, request: haberdasher__pb2.Size, ctx: RequestContext + self, + request: haberdasher__pb2.Size, + ctx: RequestContext[haberdasher__pb2.Size, haberdasher__pb2.Hat], ) -> haberdasher__pb2.Hat: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") async def make_flexible_hat( - self, request: AsyncIterator[haberdasher__pb2.Size], ctx: RequestContext + self, + request: AsyncIterator[haberdasher__pb2.Size], + ctx: RequestContext[haberdasher__pb2.Size, haberdasher__pb2.Hat], ) -> haberdasher__pb2.Hat: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def make_similar_hats( - self, request: haberdasher__pb2.Size, ctx: RequestContext + self, + request: haberdasher__pb2.Size, + ctx: RequestContext[haberdasher__pb2.Size, haberdasher__pb2.Hat], ) -> AsyncIterator[haberdasher__pb2.Hat]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def make_various_hats( - self, request: AsyncIterator[haberdasher__pb2.Size], ctx: RequestContext + self, + request: AsyncIterator[haberdasher__pb2.Size], + ctx: RequestContext[haberdasher__pb2.Size, haberdasher__pb2.Hat], ) -> AsyncIterator[haberdasher__pb2.Hat]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def list_parts( - self, request: google_dot_protobuf_dot_empty__pb2.Empty, ctx: RequestContext + self, + request: google_dot_protobuf_dot_empty__pb2.Empty, + ctx: RequestContext[ + google_dot_protobuf_dot_empty__pb2.Empty, haberdasher__pb2.Hat.Part + ], ) -> AsyncIterator[haberdasher__pb2.Hat.Part]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") async def do_nothing( - self, request: google_dot_protobuf_dot_empty__pb2.Empty, ctx: RequestContext + self, + request: google_dot_protobuf_dot_empty__pb2.Empty, + ctx: RequestContext[ + google_dot_protobuf_dot_empty__pb2.Empty, + google_dot_protobuf_dot_empty__pb2.Empty, + ], ) -> google_dot_protobuf_dot_empty__pb2.Empty: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") @@ -279,32 +296,49 @@ async def do_nothing( class HaberdasherSync(Protocol): def make_hat( - self, request: haberdasher__pb2.Size, ctx: RequestContext + self, + request: haberdasher__pb2.Size, + ctx: RequestContext[haberdasher__pb2.Size, haberdasher__pb2.Hat], ) -> haberdasher__pb2.Hat: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def make_flexible_hat( - self, request: Iterator[haberdasher__pb2.Size], ctx: RequestContext + self, + request: Iterator[haberdasher__pb2.Size], + ctx: RequestContext[haberdasher__pb2.Size, haberdasher__pb2.Hat], ) -> haberdasher__pb2.Hat: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def make_similar_hats( - self, request: haberdasher__pb2.Size, ctx: RequestContext + self, + request: haberdasher__pb2.Size, + ctx: RequestContext[haberdasher__pb2.Size, haberdasher__pb2.Hat], ) -> Iterator[haberdasher__pb2.Hat]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def make_various_hats( - self, request: Iterator[haberdasher__pb2.Size], ctx: RequestContext + self, + request: Iterator[haberdasher__pb2.Size], + ctx: RequestContext[haberdasher__pb2.Size, haberdasher__pb2.Hat], ) -> Iterator[haberdasher__pb2.Hat]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def list_parts( - self, request: google_dot_protobuf_dot_empty__pb2.Empty, ctx: RequestContext + self, + request: google_dot_protobuf_dot_empty__pb2.Empty, + ctx: RequestContext[ + google_dot_protobuf_dot_empty__pb2.Empty, haberdasher__pb2.Hat.Part + ], ) -> Iterator[haberdasher__pb2.Hat.Part]: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented") def do_nothing( - self, request: google_dot_protobuf_dot_empty__pb2.Empty, ctx: RequestContext + self, + request: google_dot_protobuf_dot_empty__pb2.Empty, + ctx: RequestContext[ + google_dot_protobuf_dot_empty__pb2.Empty, + google_dot_protobuf_dot_empty__pb2.Empty, + ], ) -> google_dot_protobuf_dot_empty__pb2.Empty: raise ConnectError(Code.UNIMPLEMENTED, "Not implemented")