Skip to content

Commit 6d2424c

Browse files
Merge pull request #29 from rstoyanchev/shortcuts
Add shortcuts to propagate from Reactor context to ThreadLocal values
2 parents cc44973 + 1880bce commit 6d2424c

File tree

5 files changed

+114
-46
lines changed

5 files changed

+114
-46
lines changed

context-propagation/src/main/java/io/micrometer/context/ContextAccessor.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ public interface ContextAccessor<READ, WRITE> {
4747
*/
4848
void readValues(READ sourceContext, Predicate<Object> keyPredicate, Map<Object, Object> readValues);
4949

50+
/**
51+
* Read a single value from the source context.
52+
* @param sourceContext the context to read from; the context type should be
53+
* checked with {@link #canReadFrom(Class)} before this method is called
54+
* @param key the key to use to look up the context value
55+
* @return the value, if any
56+
*/
57+
@Nullable
58+
<T> T readValue(READ sourceContext, Object key);
59+
5060
/**
5161
* Whether this accessor can restore values to the given type of context.
5262
* @param contextType the type of external context

context-propagation/src/main/java/io/micrometer/context/ContextSnapshot.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ static ContextSnapshot captureUsing(Predicate<Object> keyPredicate, Object... co
155155
/**
156156
* Variant of {@link #captureUsing(Predicate, Object...)} with a specific
157157
* {@link ContextRegistry} instead of the global instance.
158-
* @param contextRegistry the {@code ContextRegistry} instance to use
158+
* @param contextRegistry the registry with the accessors to use
159159
* @param keyPredicate predicate for context value keys
160160
* @param contexts one more context objects to extract values from
161161
* @return a snapshot with saved context values
@@ -166,6 +166,35 @@ static ContextSnapshot captureUsing(
166166
return DefaultContextSnapshot.capture(contextRegistry, keyPredicate, contexts);
167167
}
168168

169+
/**
170+
* Read the values specified by from the given source context, and if found,
171+
* use them to set {@link ThreadLocal} values. Essentially, a shortcut that
172+
* bypasses the need to create of {@link ContextSnapshot} first via
173+
* {@link #capture(Object...)}, followed by {@link #setThreadLocalValues()}.
174+
* @param sourceContext the source context to read values from
175+
* @param keys the keys of the values to read
176+
* @return an object that can be used to reset {@link ThreadLocal} values
177+
* at the end of the context scope, either removing them or restoring their
178+
* previous values, if any.
179+
*/
180+
static Scope setThreadLocalsFrom(Object sourceContext, String... keys) {
181+
return setThreadLocalsFrom(sourceContext, ContextRegistry.getInstance(), keys);
182+
}
183+
184+
/**
185+
* Variant of {@link #setThreadLocalsFrom(Object, String...)} with a specific
186+
* {@link ContextRegistry} instead of the global instance.
187+
* @param sourceContext the source context to read values from
188+
* @param contextRegistry the registry with the accessors to use
189+
* @param keys the keys of the values to read
190+
* @return an object that can be used to reset {@link ThreadLocal} values
191+
* at the end of the context scope, either removing them or restoring their
192+
* previous values, if any.
193+
*/
194+
static Scope setThreadLocalsFrom(Object sourceContext, ContextRegistry contextRegistry, String... keys) {
195+
return DefaultContextSnapshot.setThreadLocalsFrom(sourceContext, contextRegistry, keys);
196+
}
197+
169198

170199
/**
171200
* An object to use to reset {@link ThreadLocal} values at the end of a

context-propagation/src/main/java/io/micrometer/context/DefaultContextSnapshot.java

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
package io.micrometer.context;
1717

1818
import java.util.HashMap;
19-
import java.util.HashSet;
2019
import java.util.Map;
21-
import java.util.Set;
2220
import java.util.function.Predicate;
2321

2422
/**
@@ -31,11 +29,11 @@ final class DefaultContextSnapshot extends HashMap<Object, Object> implements Co
3129

3230
private static final ContextSnapshot emptyContextSnapshot = new DefaultContextSnapshot(new ContextRegistry());
3331

34-
private final ContextRegistry accessorRegistry;
32+
private final ContextRegistry contextRegistry;
3533

3634

37-
DefaultContextSnapshot(ContextRegistry accessorRegistry) {
38-
this.accessorRegistry = accessorRegistry;
35+
DefaultContextSnapshot(ContextRegistry contextRegistry) {
36+
this.contextRegistry = contextRegistry;
3937
}
4038

4139

@@ -48,7 +46,7 @@ public <C> C updateContext(C context) {
4846
public <C> C updateContext(C context, Predicate<Object> keyPredicate) {
4947
if (!isEmpty()) {
5048
Map<Object, Object> valuesToWrite = new HashMap<>();
51-
forEach((key, value) -> {
49+
this.forEach((key, value) -> {
5250
if (keyPredicate.test(key)) {
5351
valuesToWrite.put(key, value);
5452
}
@@ -61,7 +59,7 @@ public <C> C updateContext(C context, Predicate<Object> keyPredicate) {
6159
@SuppressWarnings("unchecked")
6260
private <C> C updateContextInternal(C context, Map<Object, Object> valueContainer) {
6361
if (!isEmpty()) {
64-
ContextAccessor<?, ?> accessor = this.accessorRegistry.getContextAccessorForWrite(context);
62+
ContextAccessor<?, ?> accessor = this.contextRegistry.getContextAccessorForWrite(context);
6563
context = ((ContextAccessor<?, C>) accessor).writeValues(valueContainer, context);
6664
}
6765
return context;
@@ -74,27 +72,41 @@ public Scope setThreadLocalValues() {
7472

7573
@Override
7674
public Scope setThreadLocalValues(Predicate<Object> keyPredicate) {
77-
Set<Object> keys = null;
7875
Map<Object, Object> previousValues = null;
79-
for (ThreadLocalAccessor<?> accessor : this.accessorRegistry.getThreadLocalAccessors()) {
76+
for (ThreadLocalAccessor<?> accessor : this.contextRegistry.getThreadLocalAccessors()) {
8077
Object key = accessor.key();
81-
if (keyPredicate.test(key) && containsKey(key)) {
82-
keys = (keys != null ? keys : new HashSet<>());
83-
keys.add(key);
84-
85-
Object previousValue = accessor.getValue();
86-
previousValues = (previousValues != null ? previousValues : new HashMap<>());
87-
previousValues.put(key, previousValue);
88-
89-
setThreadLocalValue(key, accessor);
78+
if (keyPredicate.test(key) && this.containsKey(key)) {
79+
previousValues = setThreadLocal(key, get(key), accessor, previousValues);
9080
}
9181
}
92-
return (keys != null ? new DefaultScope(keys, previousValues) : () -> { });
82+
return DefaultScope.from(previousValues, this.contextRegistry);
9383
}
9484

9585
@SuppressWarnings("unchecked")
96-
private <V> void setThreadLocalValue(Object key, ThreadLocalAccessor<?> accessor) {
97-
((ThreadLocalAccessor<V>) accessor).setValue((V) get(key));
86+
private static <V> Map<Object, Object> setThreadLocal(
87+
Object key, V value, ThreadLocalAccessor<?> accessor, @Nullable Map<Object, Object> previousValues) {
88+
89+
previousValues = (previousValues != null ? previousValues : new HashMap<>());
90+
previousValues.put(key, accessor.getValue());
91+
((ThreadLocalAccessor<V>) accessor).setValue(value);
92+
return previousValues;
93+
}
94+
95+
@SuppressWarnings("unchecked")
96+
static <C> Scope setThreadLocalsFrom(Object context, ContextRegistry registry, String... keys) {
97+
ContextAccessor<?, ?> contextAccessor = registry.getContextAccessorForRead(context);
98+
Map<Object, Object> previousValues = null;
99+
for (String key : keys) {
100+
Object value = ((ContextAccessor<C, ?>) contextAccessor).readValue((C) context, key);
101+
if (value != null) {
102+
for (ThreadLocalAccessor<?> threadLocalAccessor : registry.getThreadLocalAccessors()) {
103+
if (key.equals(threadLocalAccessor.key())) {
104+
previousValues = setThreadLocal(key, value, threadLocalAccessor, previousValues);
105+
}
106+
}
107+
}
108+
}
109+
return DefaultScope.from(previousValues, registry);
98110
}
99111

100112
@SuppressWarnings("unchecked")
@@ -128,21 +140,21 @@ public String toString() {
128140
/**
129141
* Default implementation of {@link Scope}.
130142
*/
131-
private class DefaultScope implements Scope {
132-
133-
private final Set<Object> keys;
143+
private static class DefaultScope implements Scope {
134144

135145
private final Map<Object, Object> previousValues;
136146

137-
private DefaultScope(Set<Object> keys, Map<Object, Object> previousValues) {
138-
this.keys = keys;
147+
private final ContextRegistry contextRegistry;
148+
149+
private DefaultScope(Map<Object, Object> previousValues, ContextRegistry contextRegistry) {
139150
this.previousValues = previousValues;
151+
this.contextRegistry = contextRegistry;
140152
}
141153

142154
@Override
143155
public void close() {
144-
for (ThreadLocalAccessor<?> accessor : accessorRegistry.getThreadLocalAccessors()) {
145-
if (this.keys.contains(accessor.key())) {
156+
for (ThreadLocalAccessor<?> accessor : this.contextRegistry.getThreadLocalAccessors()) {
157+
if (this.previousValues.containsKey(accessor.key())) {
146158
Object previousValue = this.previousValues.get(accessor.key());
147159
resetThreadLocalValue(accessor, previousValue);
148160
}
@@ -158,6 +170,11 @@ private <V> void resetThreadLocalValue(ThreadLocalAccessor<?> accessor, @Nullabl
158170
accessor.reset();
159171
}
160172
}
173+
174+
public static Scope from(@Nullable Map<Object, Object> previousValues, ContextRegistry registry) {
175+
return (previousValues != null ? new DefaultScope(previousValues, registry) : () -> { });
176+
}
177+
161178
}
162179

163180
}

context-propagation/src/test/java/io/micrometer/context/DefaultContextSnapshotTests.java

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
package io.micrometer.context;
1717

1818

19+
import java.util.Collections;
20+
import java.util.Map;
21+
22+
import io.micrometer.context.ContextSnapshot.Scope;
1923
import org.junit.jupiter.api.Test;
2024

2125
import static org.assertj.core.api.Assertions.assertThat;
@@ -35,31 +39,32 @@ public class DefaultContextSnapshotTests {
3539
void should_propagate_thread_local() {
3640
this.registry.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());
3741

38-
then(ObservationThreadLocalHolder.getValue()).isNull();
3942
ObservationThreadLocalHolder.setValue("hello");
40-
4143
ContextSnapshot snapshot = ContextSnapshot.captureUsing(this.registry, key -> true);
4244

43-
ObservationThreadLocalHolder.reset();
44-
then(ObservationThreadLocalHolder.getValue()).isNull();
45-
46-
try (ContextSnapshot.Scope scope = snapshot.setThreadLocalValues()) {
47-
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hello");
45+
ObservationThreadLocalHolder.setValue("hola");
46+
try {
47+
try (Scope scope = snapshot.setThreadLocalValues()) {
48+
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hello");
49+
}
50+
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hola");
51+
}
52+
finally {
53+
ObservationThreadLocalHolder.reset();
4854
}
49-
50-
then(ObservationThreadLocalHolder.getValue()).isNull();
5155
}
5256

5357
@Test
54-
void should_reset_to_thread_local_to_previous_value() {
58+
void should_propagate_single_thread_local_value() {
59+
this.registry.registerContextAccessor(new TestContextAccessor());
5560
this.registry.registerThreadLocalAccessor(new ObservationThreadLocalAccessor());
5661

57-
ObservationThreadLocalHolder.setValue("hello");
58-
ContextSnapshot snapshot = ContextSnapshot.captureUsing(this.registry, key -> true);
62+
String key = ObservationThreadLocalAccessor.KEY;
63+
Map<String, String> sourceContext = Collections.singletonMap(key, "hello");
5964

6065
ObservationThreadLocalHolder.setValue("hola");
6166
try {
62-
try (ContextSnapshot.Scope scope = snapshot.setThreadLocalValues()) {
67+
try (Scope scope = ContextSnapshot.setThreadLocalsFrom(sourceContext, this.registry, key)) {
6368
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hello");
6469
}
6570
then(ObservationThreadLocalHolder.getValue()).isEqualTo("hola");
@@ -80,7 +85,7 @@ void should_not_fail_on_empty_thread_local() {
8085
ObservationThreadLocalHolder.reset();
8186
then(ObservationThreadLocalHolder.getValue()).isNull();
8287

83-
try (ContextSnapshot.Scope scope = snapshot.setThreadLocalValues()) {
88+
try (Scope scope = snapshot.setThreadLocalValues()) {
8489
then(ObservationThreadLocalHolder.getValue()).isNull();
8590
}
8691

@@ -104,7 +109,7 @@ void should_filter_thread_locals_on_capture() {
104109
fooThreadLocal.remove();
105110
barThreadLocal.remove();
106111

107-
try (ContextSnapshot.Scope scope = snapshot.setThreadLocalValues()) {
112+
try (Scope scope = snapshot.setThreadLocalValues()) {
108113
then(fooThreadLocal.get()).isEqualTo("fooValue");
109114
then(barThreadLocal.get()).isNull();
110115
}
@@ -130,12 +135,12 @@ void should_filter_thread_locals_on_restore() {
130135
fooThreadLocal.remove();
131136
barThreadLocal.remove();
132137

133-
try (ContextSnapshot.Scope scope = snapshot.setThreadLocalValues(key -> key.equals("foo"))) {
138+
try (Scope scope = snapshot.setThreadLocalValues(key -> key.equals("foo"))) {
134139
then(fooThreadLocal.get()).isEqualTo("fooValue");
135140
then(barThreadLocal.get()).isNull();
136141
}
137142

138-
try (ContextSnapshot.Scope scope = snapshot.setThreadLocalValues(key -> key.equals("bar"))) {
143+
try (Scope scope = snapshot.setThreadLocalValues(key -> key.equals("bar"))) {
139144
then(fooThreadLocal.get()).isNull();
140145
then(barThreadLocal.get()).isEqualTo("barValue");
141146
}

context-propagation/src/test/java/io/micrometer/context/TestContextAccessor.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ public void readValues(Map<?, ?> sourceContext, Predicate<Object> keyPredicate,
3636
readValues.putAll(sourceContext);
3737
}
3838

39+
@SuppressWarnings("unchecked")
40+
@Nullable
41+
@Override
42+
public <T> T readValue(Map<?, ?> sourceContext, Object key) {
43+
return (T) sourceContext.get(key);
44+
}
45+
3946
@Override
4047
public boolean canWriteTo(Class<?> contextType) {
4148
return Map.class.isAssignableFrom(contextType);

0 commit comments

Comments
 (0)