Adds routing keys in message headers

If the incoming message has a stream_routekey header, we use that
to route the message to a named function. Also adding the header
to messages coming out of suppliers.

The biggest change here is sort of orthogonal: it fixes a bug where
Spring Integration would subscribe twice to the same input channel
if the FunctionCatalog contains both functions and consumers. Then
when a message comes in it is dispatched to one or the other, but not
both. So the routing key couldn't work without fixing that
problem.
This commit is contained in:
Dave Syer
2017-07-11 17:03:39 +01:00
parent 4b30721d02
commit 7e966c73ca
8 changed files with 176 additions and 30 deletions

View File

@@ -262,6 +262,9 @@ public class ContextFunctionCatalogAutoConfiguration {
else if (target instanceof Function) {
registration.target(target((Function<?, ?>) target, key));
}
for (String name : registration.getNames()) {
beans.put(name, key);
}
this.registrations.put(registration.getTarget(), key);
}

View File

@@ -132,6 +132,12 @@ public class StreamConfiguration {
@Override
public ConditionOutcome getMatchOutcome(ConditionContext context,
AnnotatedTypeMetadata metadata) {
return getMatchOutcomeForType(this.type, context, metadata);
}
protected ConditionOutcome getMatchOutcomeForType(Class<?> type,
ConditionContext context, AnnotatedTypeMetadata metadata) {
if (context.getBeanFactory().getBeanNamesForType(type, false,
false).length > 0) {
String endpoint = new RelaxedPropertyResolver(context.getEnvironment(),
@@ -175,5 +181,15 @@ public class StreamConfiguration {
public ConsumerCondition() {
super(Consumer.class);
}
@Override
public ConditionOutcome getMatchOutcome(ConditionContext context,
AnnotatedTypeMetadata metadata) {
if (getMatchOutcomeForType(Function.class, context, metadata).isMatch()) {
return ConditionOutcome
.noMatch(String.format("bean of type Function detected"));
}
return super.getMatchOutcome(context, metadata);
}
}
}

View File

@@ -27,11 +27,13 @@ public class StreamConfigurationProperties {
private String endpoint;
/**
* Interval to be used for the Duration (in milliseconds) of a non-Flux producing Supplier.
* Default is 0, which means the Supplier will only be invoked once.
* Interval to be used for the Duration (in milliseconds) of a non-Flux producing
* Supplier. Default is 0, which means the Supplier will only be invoked once.
*/
private long interval = 0L;
public static final String ROUTE_KEY = "stream_routekey";
public String getEndpoint() {
return endpoint;
}

View File

@@ -83,12 +83,22 @@ public class StreamListeningConsumerInvoker implements SmartInitializingSingleto
name = names.iterator().next();
}
else {
for (String candidate : names) {
Class<?> inputType = functionInspector.getInputType(candidate);
Object value = this.converter.fromMessage(input, inputType);
if (value != null && inputType.isInstance(value)) {
name = candidate;
break;
if (input.getHeaders()
.containsKey(StreamConfigurationProperties.ROUTE_KEY)) {
String key = (String) input.getHeaders()
.get(StreamConfigurationProperties.ROUTE_KEY);
if (functionCatalog.lookupFunction(key) != null) {
return key;
}
}
else {
for (String candidate : names) {
Class<?> inputType = functionInspector.getInputType(candidate);
Object value = this.converter.fromMessage(input, inputType);
if (value != null && inputType.isInstance(value)) {
name = candidate;
break;
}
}
}
}

View File

@@ -16,7 +16,14 @@
package org.springframework.cloud.function.stream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.springframework.beans.factory.SmartInitializingSingleton;
@@ -27,11 +34,13 @@ import org.springframework.cloud.stream.annotation.Output;
import org.springframework.cloud.stream.annotation.StreamListener;
import org.springframework.cloud.stream.converter.CompositeMessageConverterFactory;
import org.springframework.cloud.stream.messaging.Processor;
import org.springframework.cloud.stream.reactive.FluxSender;
import org.springframework.messaging.Message;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.support.MessageBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
* @author Mark Fisher
@@ -49,7 +58,11 @@ public class StreamListeningFunctionInvoker implements SmartInitializingSingleto
private final String defaultEndpoint;
private static final String NOENDPOINT = "__NOENDPOINT__";
private final Map<String, FluxMessageProcessor> processors = new HashMap<>();
private int count = -1;
private static final FluxMessageProcessor NOENDPOINT = flux -> Flux.empty();
public StreamListeningFunctionInvoker(FunctionCatalog functionCatalog,
FunctionInspector functionInspector,
@@ -66,40 +79,110 @@ public class StreamListeningFunctionInvoker implements SmartInitializingSingleto
}
@StreamListener
@Output(Processor.OUTPUT)
public Flux<?> handle(@Input(Processor.INPUT) Flux<Message<?>> input) {
return input.groupBy(this::select)
.filter(group -> functionCatalog.lookupFunction(group.key()) != null)
.flatMap(group -> process(group.key(), group));
public Mono<Void> handle(@Input(Processor.INPUT) Flux<Message<?>> input,
@Output(Processor.OUTPUT) FluxSender output) {
return output.send(
input.groupBy(this::select).flatMap(group -> group.key().process(group)));
}
private Flux<?> process(String name, Flux<Message<?>> flux) {
return (Flux<?>) functionCatalog.lookupFunction(name)
.apply(flux.map(message -> convertInput(name).apply(message)));
private Flux<Message<?>> function(String name, Flux<Message<?>> flux) {
// TODO: the routing key could be added here, but really it should be added in
// Spring Cloud Stream
// (https://github.com/spring-cloud/spring-cloud-stream/issues/1010)
AtomicReference<Map<String, Object>> headers = new AtomicReference<Map<String, Object>>(
new LinkedHashMap<>());
return ((Flux<?>) functionCatalog.lookupFunction(name).apply(flux.map(message -> {
Object applied = convertInput(name).apply(message);
headers.set(message.getHeaders());
return applied;
}))).map(result -> message(result, headers.get()));
}
private String select(Message<?> input) {
private Message<?> message(Object result, Map<String, Object> headers) {
return result instanceof Message ? (Message<?>) result
: MessageBuilder.withPayload(result).copyHeadersIfAbsent(headers).build();
}
private Flux<Message<?>> consumer(String name, Flux<Message<?>> flux) {
functionCatalog.lookupConsumer(name)
.accept(flux.map(message -> convertInput(name).apply(message)));
return Flux.empty();
}
private Flux<Message<?>> balance(List<String> names, Flux<Message<?>> flux) {
if (names.isEmpty()) {
return Flux.empty();
}
String name = choose(names);
if (functionCatalog.lookupConsumer(name) != null) {
return consumer(name, flux);
}
return function(name, flux);
}
private synchronized String choose(List<String> names) {
if (++count >= names.size() || count < 0) {
count = 0;
}
return names.get(count);
}
private FluxMessageProcessor select(Message<?> input) {
String name = defaultEndpoint;
if (name != null) {
name = stash(name);
}
if (name == null) {
Set<String> names = functionCatalog.getFunctionNames();
if (input.getHeaders().containsKey(StreamConfigurationProperties.ROUTE_KEY)) {
String key = (String) input.getHeaders()
.get(StreamConfigurationProperties.ROUTE_KEY);
name = stash(key);
}
}
if (name == null) {
Set<String> names = new LinkedHashSet<>(functionCatalog.getFunctionNames());
names.addAll(functionCatalog.getConsumerNames());
List<String> matches = new ArrayList<>();
if (names.size() == 1) {
name = names.iterator().next();
String key = names.iterator().next();
name = stash(key);
}
else {
for (String candidate : names) {
Class<?> inputType = functionInspector.getInputType(candidate);
Object value = this.converter.fromMessage(input, inputType);
if (value != null && inputType.isInstance(value)) {
name = candidate;
break;
matches.add(candidate);
}
}
if (matches.size() == 1) {
name = stash(matches.iterator().next());
}
else {
return flux -> balance(matches, flux);
}
}
}
if (name == null) {
return NOENDPOINT;
}
return name;
return processors.get(name);
}
private String stash(String key) {
if (functionCatalog.lookupFunction(key) != null) {
if (!processors.containsKey(key)) {
processors.put(key, flux -> function(key, flux));
}
return key;
}
else if (functionCatalog.lookupConsumer(key) != null) {
if (!processors.containsKey(key)) {
processors.put(key, flux -> consumer(key, flux));
}
return key;
}
return null;
}
private Function<Message<?>, Object> convertInput(String name) {
@@ -123,4 +206,9 @@ public class StreamListeningFunctionInvoker implements SmartInitializingSingleto
return this.converter.fromMessage(m, inputType);
}
}
interface FluxMessageProcessor {
Flux<Message<?>> process(Flux<Message<?>> flux);
}
}

View File

@@ -21,6 +21,7 @@ import java.util.function.Supplier;
import org.springframework.cloud.function.registry.FunctionCatalog;
import org.springframework.cloud.stream.messaging.Source;
import org.springframework.integration.endpoint.MessageProducerSupport;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
@@ -40,17 +41,19 @@ public class SupplierInvokingMessageProducer<T> extends MessageProducerSupport {
@Override
protected void doStart() {
supplier()
.subscribe(m -> this.sendMessage(MessageBuilder.withPayload(m).build()));
supplier().subscribe(m -> this.sendMessage(m));
}
private Flux<?> supplier() {
private Flux<Message<?>> supplier() {
Supplier<Flux<?>> supplier = null;
Flux<?> result = Flux.empty();
Flux<Message<?>> result = Flux.empty();
for (String name : functionCatalog.getSupplierNames()) {
supplier = functionCatalog.lookupSupplier(name);
Assert.notNull(supplier, "Supplier must not be null");
result = Flux.merge(result, supplier.get());
result = Flux.merge(result,
supplier.get().map(payload -> MessageBuilder.withPayload(payload)
.setHeader(StreamConfigurationProperties.ROUTE_KEY, name)
.build()));
}
return result;
}

View File

@@ -22,12 +22,14 @@ import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Function;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.cloud.function.stream.StreamConfigurationProperties;
import org.springframework.cloud.stream.messaging.Processor;
import org.springframework.cloud.stream.test.binder.MessageCollector;
import org.springframework.context.annotation.Bean;
@@ -53,8 +55,13 @@ public class PojoStreamingMixedTests {
@Autowired
List<Bar> collector;
@Before
public void init() {
collector.clear();
}
@Test
public void test() throws Exception {
public void balance() throws Exception {
processor.input()
.send(MessageBuilder.withPayload("{\"name\":\"hello\"}").build());
processor.input()
@@ -66,6 +73,19 @@ public class PojoStreamingMixedTests {
assertThat(collector).hasSize(1);
}
@Test
public void routing() throws Exception {
processor.input().send(MessageBuilder.withPayload("{\"name\":\"hello\"}")
.setHeader(StreamConfigurationProperties.ROUTE_KEY, "uppercase").build());
processor.input().send(MessageBuilder.withPayload("{\"name\":\"world\"}")
.setHeader(StreamConfigurationProperties.ROUTE_KEY, "uppercase").build());
Message<?> result = messageCollector.forChannel(processor.output()).poll(1000,
TimeUnit.MILLISECONDS);
assertThat(result.getPayload()).isInstanceOf(Foo.class);
// routing key sends messages to the function, not the consumer
assertThat(collector).hasSize(0);
}
@SpringBootApplication
public static class StreamingFunctionApplication {

View File

@@ -25,6 +25,7 @@ import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.cloud.function.stream.StreamConfigurationProperties;
import org.springframework.cloud.stream.messaging.Source;
import org.springframework.cloud.stream.test.binder.MessageCollector;
import org.springframework.context.annotation.Bean;
@@ -48,8 +49,11 @@ public class StreamSupplierTests {
@Test
public void test() throws Exception {
Message<?> result = messageCollector.forChannel(source.output()).poll(1000, TimeUnit.MILLISECONDS);
Message<?> result = messageCollector.forChannel(source.output()).poll(1000,
TimeUnit.MILLISECONDS);
assertThat(result.getPayload()).isEqualTo("foo");
assertThat(result.getHeaders().get(StreamConfigurationProperties.ROUTE_KEY))
.isEqualTo("simpleSupplier");
}
@SpringBootApplication