Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions core/src/main/java/io/grpc/internal/ServerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public final class ServerImpl extends io.grpc.Server implements InternalInstrume
private final ObjectPool<? extends Executor> executorPool;
/** Executor for application processing. Safe to read after {@link #start()}. */
private Executor executor;
private final HandlerRegistry registry;
private final InternalHandlerRegistry registry;
private final HandlerRegistry fallbackRegistry;
private final List<ServerTransportFilter> transportFilters;
// This is iterated on a per-call basis. Use an array instead of a Collection to avoid iterator
Expand Down Expand Up @@ -498,8 +498,12 @@ private void streamCreatedInternal(

final StatsTraceContext statsTraceCtx = Preconditions.checkNotNull(
stream.statsTraceContext(), "statsTraceCtx not present from stream");
final ServerMethodDefinition<?, ?> primaryMethod = registry.lookupMethod(methodName, null);

final Context.CancellableContext context = createContext(headers, statsTraceCtx);
if (primaryMethod != null) {
statsTraceCtx.serverCallMethodResolved(primaryMethod.getMethodDescriptor());
}

final Link link = PerfMark.linkOut();

Expand Down Expand Up @@ -536,7 +540,7 @@ private void runInternal() {
ServerMethodDefinition<?, ?> wrapMethod;
ServerCallParameters<?, ?> callParams;
try {
ServerMethodDefinition<?, ?> method = registry.lookupMethod(methodName);
ServerMethodDefinition<?, ?> method = primaryMethod;
if (method == null) {
method = fallbackRegistry.lookupMethod(methodName, stream.getAuthority());
}
Expand Down
23 changes: 23 additions & 0 deletions core/src/main/java/io/grpc/internal/StatsTraceContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.grpc.ClientStreamTracer;
import io.grpc.Context;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerStreamTracer;
import io.grpc.ServerStreamTracer.ServerCallInfo;
import io.grpc.Status;
Expand All @@ -38,6 +39,14 @@
*/
@ThreadSafe
public final class StatsTraceContext {
/**
* Internal hook for server tracers that can use the resolved method descriptor before
* {@link ServerStreamTracer#serverCallStarted(ServerCallInfo)} runs.
*/
public interface ServerCallMethodListener {
void serverCallMethodResolved(MethodDescriptor<?, ?> method);
}

public static final StatsTraceContext NOOP = new StatsTraceContext(new StreamTracer[0]);

private final StreamTracer[] tracers;
Expand Down Expand Up @@ -144,6 +153,20 @@ public void serverCallStarted(ServerCallInfo<?, ?> callInfo) {
}
}

/**
* Notifies server tracers that a primary-registry method descriptor was resolved before
* {@link ServerStreamTracer#serverCallStarted(ServerCallInfo)}.
*
* <p>Called from {@link io.grpc.internal.ServerImpl}.
*/
public void serverCallMethodResolved(MethodDescriptor<?, ?> method) {
for (StreamTracer tracer : tracers) {
if (tracer instanceof ServerCallMethodListener) {
((ServerCallMethodListener) tracer).serverCallMethodResolved(method);
}
}
}

/**
* See {@link StreamTracer#streamClosed}. This may be called multiple times, and only the first
* value will be taken.
Expand Down
224 changes: 224 additions & 0 deletions core/src/test/java/io/grpc/internal/ServerImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ public class ServerImplTest {
.setRequestMarshaller(STRING_MARSHALLER)
.setResponseMarshaller(INTEGER_MARSHALLER)
.build();
private static final MethodDescriptor<String, Integer> GENERATED_METHOD =
METHOD.toBuilder()
.setSampledToLocalTracing(true)
.build();
private static final Context.Key<String> SERVER_ONLY = Context.key("serverOnly");
private static final Context.Key<String> SERVER_TRACER_ADDED_KEY = Context.key("tracer-added");
private static final Context.CancellableContext SERVER_CONTEXT =
Expand All @@ -142,6 +146,60 @@ public boolean shouldAccept(Runnable runnable) {
};
private static final String AUTHORITY = "some_authority";

private static final class MethodNameCapturingTracer extends ServerStreamTracer
implements StatsTraceContext.ServerCallMethodListener {
@Nullable private ServerCallInfo<?, ?> serverCallInfo;
@Nullable private String recordedMethodName;
@Nullable private String resolvedMethodName;
private boolean streamClosed;

@Override
public synchronized void serverCallMethodResolved(MethodDescriptor<?, ?> method) {
resolvedMethodName =
recordMethodName(method.isSampledToLocalTracing(), method.getFullMethodName());
}

@Override
public synchronized void streamClosed(Status status) {
streamClosed = true;
if (serverCallInfo != null) {
recordedMethodName =
recordMethodName(
serverCallInfo.getMethodDescriptor().isSampledToLocalTracing(),
serverCallInfo.getMethodDescriptor().getFullMethodName());
} else if (resolvedMethodName != null) {
recordedMethodName = resolvedMethodName;
} else {
recordedMethodName = "other";
}
}

@Override
public synchronized void serverCallStarted(ServerCallInfo<?, ?> callInfo) {
serverCallInfo = callInfo;
if (streamClosed) {
recordedMethodName =
recordMethodName(
callInfo.getMethodDescriptor().isSampledToLocalTracing(),
callInfo.getMethodDescriptor().getFullMethodName());
}
}

@Nullable
synchronized ServerCallInfo<?, ?> getServerCallInfo() {
return serverCallInfo;
}

@Nullable
synchronized String getRecordedMethodName() {
return recordedMethodName;
}

private static String recordMethodName(boolean generatedMethod, String fullMethodName) {
return generatedMethod ? fullMethodName : "other";
}
}

@Rule public final MockitoRule mocks = MockitoJUnit.rule();

@BeforeClass
Expand Down Expand Up @@ -462,6 +520,172 @@ public void methodNotFound() throws Exception {
assertEquals(Status.Code.UNIMPLEMENTED, statusCaptor.getValue().getCode());
}

@Test
public void primaryRegistryGeneratedMethod_streamClosedBeforeStart_preservesMethodName()
throws Exception {
MethodNameCapturingTracer methodNameTracer = new MethodNameCapturingTracer();
streamTracerFactories =
Collections.singletonList(
new ServerStreamTracer.Factory() {
@Override
public ServerStreamTracer newServerStreamTracer(
String fullMethodName, Metadata headers) {
return methodNameTracer;
}
});
builder.addService(
ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", GENERATED_METHOD))
.addMethod(
GENERATED_METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
ServerCall<String, Integer> call, Metadata headers) {
return callListener;
}
})
.build());

createAndStartServer();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(
streamTracerFactories, GENERATED_METHOD.getFullMethodName(), requestHeaders);
when(stream.getAttributes()).thenReturn(Attributes.EMPTY);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);

transportListener.streamCreated(stream, GENERATED_METHOD.getFullMethodName(), requestHeaders);
verify(stream).setListener(isA(ServerStreamListener.class));
verify(stream, atLeast(1)).statsTraceContext();

statsTraceCtx.streamClosed(Status.CANCELLED);
assertNull(methodNameTracer.getServerCallInfo());
assertEquals(
GENERATED_METHOD.getFullMethodName(),
methodNameTracer.getRecordedMethodName());

assertEquals(1, executor.runDueTasks());

assertNotNull(methodNameTracer.getServerCallInfo());
assertSame(GENERATED_METHOD, methodNameTracer.getServerCallInfo().getMethodDescriptor());
assertEquals(
GENERATED_METHOD.getFullMethodName(),
methodNameTracer.getRecordedMethodName());
verify(fallbackRegistry, never()).lookupMethod(anyString(), any());
}

@Test
public void primaryRegistryNonGeneratedMethod_streamClosedBeforeStart_recordsOther()
throws Exception {
MethodNameCapturingTracer methodNameTracer = new MethodNameCapturingTracer();
streamTracerFactories =
Collections.singletonList(
new ServerStreamTracer.Factory() {
@Override
public ServerStreamTracer newServerStreamTracer(
String fullMethodName, Metadata headers) {
return methodNameTracer;
}
});
builder.addService(
ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", METHOD))
.addMethod(
METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
ServerCall<String, Integer> call, Metadata headers) {
return callListener;
}
})
.build());

createAndStartServer();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(
streamTracerFactories, METHOD.getFullMethodName(), requestHeaders);
when(stream.getAttributes()).thenReturn(Attributes.EMPTY);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);

transportListener.streamCreated(stream, METHOD.getFullMethodName(), requestHeaders);
verify(stream).setListener(isA(ServerStreamListener.class));
verify(stream, atLeast(1)).statsTraceContext();

statsTraceCtx.streamClosed(Status.CANCELLED);
assertNull(methodNameTracer.getServerCallInfo());
assertEquals("other", methodNameTracer.getRecordedMethodName());

assertEquals(1, executor.runDueTasks());

assertNotNull(methodNameTracer.getServerCallInfo());
assertSame(METHOD, methodNameTracer.getServerCallInfo().getMethodDescriptor());
assertEquals("other", methodNameTracer.getRecordedMethodName());
verify(fallbackRegistry, never()).lookupMethod(anyString(), any());
}

@Test
public void fallbackRegistryGeneratedMethod_streamClosedBeforeStart_resolvesOnAsyncLookup()
throws Exception {
MethodNameCapturingTracer methodNameTracer = new MethodNameCapturingTracer();
streamTracerFactories =
Collections.singletonList(
new ServerStreamTracer.Factory() {
@Override
public ServerStreamTracer newServerStreamTracer(
String fullMethodName, Metadata headers) {
return methodNameTracer;
}
});
mutableFallbackRegistry.addService(
ServerServiceDefinition.builder(new ServiceDescriptor("Waiter", GENERATED_METHOD))
.addMethod(
GENERATED_METHOD,
new ServerCallHandler<String, Integer>() {
@Override
public ServerCall.Listener<String> startCall(
ServerCall<String, Integer> call, Metadata headers) {
return callListener;
}
})
.build());

createAndStartServer();
ServerTransportListener transportListener
= transportServer.registerNewServerTransport(new SimpleServerTransport());
transportListener.transportReady(Attributes.EMPTY);
Metadata requestHeaders = new Metadata();
StatsTraceContext statsTraceCtx =
StatsTraceContext.newServerContext(
streamTracerFactories, GENERATED_METHOD.getFullMethodName(), requestHeaders);
when(stream.getAttributes()).thenReturn(Attributes.EMPTY);
when(stream.statsTraceContext()).thenReturn(statsTraceCtx);

transportListener.streamCreated(stream, GENERATED_METHOD.getFullMethodName(), requestHeaders);
verify(stream).setListener(isA(ServerStreamListener.class));
verify(stream, atLeast(1)).statsTraceContext();

statsTraceCtx.streamClosed(Status.CANCELLED);
assertNull(methodNameTracer.getServerCallInfo());
assertEquals("other", methodNameTracer.getRecordedMethodName());
verify(fallbackRegistry, never()).lookupMethod(anyString(), any());

assertEquals(1, executor.runDueTasks());

assertNotNull(methodNameTracer.getServerCallInfo());
assertSame(GENERATED_METHOD, methodNameTracer.getServerCallInfo().getMethodDescriptor());
assertEquals(
GENERATED_METHOD.getFullMethodName(),
methodNameTracer.getRecordedMethodName());
verify(fallbackRegistry).lookupMethod(GENERATED_METHOD.getFullMethodName(), AUTHORITY);
}


@Test
public void executorSupplierSameExecutorBasic() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StreamTracer;
import io.grpc.internal.StatsTraceContext.ServerCallMethodListener;
import io.grpc.opentelemetry.GrpcOpenTelemetry.TargetFilter;
import io.opentelemetry.api.baggage.Baggage;
import io.opentelemetry.api.common.AttributesBuilder;
Expand Down Expand Up @@ -526,7 +527,8 @@ void recordFinishedCall(CallOptions callOptions) {
}
}

private static final class ServerTracer extends ServerStreamTracer {
private static final class ServerTracer extends ServerStreamTracer
implements ServerCallMethodListener {
@Nullable private static final AtomicIntegerFieldUpdater<ServerTracer> streamClosedUpdater;
@Nullable private static final AtomicLongFieldUpdater<ServerTracer> outboundWireSizeUpdater;
@Nullable private static final AtomicLongFieldUpdater<ServerTracer> inboundWireSizeUpdater;
Expand Down Expand Up @@ -587,6 +589,11 @@ public io.grpc.Context filterContext(io.grpc.Context context) {
return context;
}

@Override
public void serverCallMethodResolved(MethodDescriptor<?, ?> method) {
isGeneratedMethod = method.isSampledToLocalTracing();
}

@Override
public void serverCallStarted(ServerCallInfo<?, ?> callInfo) {
// Only record method name as an attribute if isSampledToLocalTracing is set to true,
Expand Down Expand Up @@ -644,9 +651,24 @@ public void streamClosed(Status status) {
}
stopwatch.stop();
long elapsedTimeNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS);
AttributesBuilder builder = io.opentelemetry.api.common.Attributes.builder()
.put(METHOD_KEY, recordMethodName(fullMethodName, isGeneratedMethod))
.put(STATUS_KEY, status.getCode().toString());
recordClosedStream(
status,
elapsedTimeNanos,
outboundWireSize,
inboundWireSize,
isGeneratedMethod);
}

private void recordClosedStream(
Status status,
long elapsedTimeNanos,
long closedOutboundWireSize,
long closedInboundWireSize,
boolean generatedMethod) {
AttributesBuilder builder =
io.opentelemetry.api.common.Attributes.builder()
.put(METHOD_KEY, recordMethodName(fullMethodName, generatedMethod))
.put(STATUS_KEY, status.getCode().toString());
for (OpenTelemetryPlugin.ServerStreamPlugin plugin : streamPlugins) {
plugin.addLabels(builder);
}
Expand All @@ -658,11 +680,11 @@ public void streamClosed(Status status) {
}
if (module.resource.serverTotalSentCompressedMessageSizeCounter() != null) {
module.resource.serverTotalSentCompressedMessageSizeCounter()
.record(outboundWireSize, attributes, otelContext);
.record(closedOutboundWireSize, attributes, otelContext);
}
if (module.resource.serverTotalReceivedCompressedMessageSizeCounter() != null) {
module.resource.serverTotalReceivedCompressedMessageSizeCounter()
.record(inboundWireSize, attributes, otelContext);
.record(closedInboundWireSize, attributes, otelContext);
}
}
}
Expand Down Expand Up @@ -744,4 +766,3 @@ public void onClose(Status status, Metadata trailers) {
}
}
}

Loading
Loading