从源码开始重新认识ThreadLocal

从源码开始重新认识ThreadLocal

最近在巩固Java基础,发现很多平时在使用的东西,其实自己并不了解它的原理,在看了JDK1.8中ThreadLocal这个工具类的源码的同时,也翻看了很多大牛写的博客,总结下来,加深记忆。

简介

从JDK1.2开始,Java就提供了ThreadLocal类。
image

所谓ThreadLocal,是Thread Local Variable(线程局部变量)的意思,ThreadLocal是java.lang包下提供的一个工具类,主要的作用是隔离线程资源,保证线程安全,通过ThreadLocal类,我们可以为每个线程创建一个独立的变量副本,从而避免并发访问时的线程安全问题。

基本方法

ThreadLocal类似于HashMap,保存的是k:v型数据结构,但是他只能保存一个,各个线程的数据互不影响。

ThreadLocal只提供了一个空的构造函数。

1
2
3
4
5
6
/**
* Creates a thread local variable.
* @see #withInitial(java.util.function.Supplier)
*/
public ThreadLocal() {
}

ThreadLocal中的get()方法,不用传入任何参数

1
public T get();

ThreadLocal的set()方法,放入的是一个泛型参数

1
public void set(T value);

ThreadLocal的remove()方法

1
public void remove();

针对ThreadLocal的主要使用就是这三个方法,所以说ThreadLocal的使用其实并没有任何难度,不需要写任何同步代码就可以实现线程安全。

示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
public class ThreadLocalExample {
//定义一个String类型的ThreadLocal
private static ThreadLocal<String> localVariable = new ThreadLocal<>();
//定义一个Integer类型的ThreadLocal
private static ThreadLocal<Integer> localVariable1 = new ThreadLocal<>();

/**
* 打印函数
*
* @param
*/
private static void print() {
//打印当前线程本地内存中localVariable变量的值
System.out.println(Thread.currentThread().getName() + " String类型的ThreadLocal: " + localVariable.get());
System.out.println(Thread.currentThread().getName() + " Integer类型的ThreadLocal: " + localVariable1.get());
localVariable.remove();
}

public static void main(String[] args) {
new Thread(new ThreadLocalThread("线程1 data", 9090900), "线程1").start();
new Thread(new ThreadLocalThread("线程2 data", 9999999), "线程2").start();
}

static class ThreadLocalThread implements Runnable {
private String stringThreadLocal;
private Integer integerThreadLocal;

public ThreadLocalThread(String stringThreadLocal, Integer integerThreadLocal) {
this.stringThreadLocal = stringThreadLocal;
this.integerThreadLocal = integerThreadLocal;
}

@Override
public void run() {
out.println("当前线程:" + Thread.currentThread().getName());
localVariable.set(stringThreadLocal);
localVariable1.set(integerThreadLocal);
//调用打印函数
print();
//打印本地变量
System.out.println(Thread.currentThread().getName() + " remove after: " + localVariable.get());
}
}
}

输出结果

1
2
3
4
5
6
7
8
当前线程:线程1
线程1 String类型的ThreadLocal: 线程1 data
线程1 Integer类型的ThreadLocal: 9090900
线程1 remove after: null
当前线程:线程2
线程2 String类型的ThreadLocal: 线程2 data
线程2 Integer类型的ThreadLocal: 9999999
线程2 remove after: null

可以看出线程1和线程2的变量完全隔离开了。

从源码看原理

那ThreadLocal是如何做到这些的呢,先来看看set方法的源码。

1
2
3
4
5
6
7
8
9
10
public void set(T value) {
//获取当前线程对象
Thread t = Thread.currentThread();
//获取ThreadLocalMap对象
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

重点就在于这个ThreadLocalMap,ThreadLocal就是通过这玩意来实现线程隔离的。下面是getMap方法:

1
2
3
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

这里返回的是t对象也就是当前线程对象里面的threadLocals这个变量。我们再看看Thread源码:

1
2
3
4
/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class.
*/
ThreadLocal.ThreadLocalMap threadLocals = null;

看注释的意思是:threadLocals是用于修饰当前线程的ThreadLocal值,这个ThreadLocalMap变量由ThreadLocal来维护。

看到这里明白了,ThreadLocal之所以能够隔离线程资源,是因为每个线程的ThreadLocalMap都在当前线程对象里,其他线程根本无法访问到。

继续看set方法的源码,获取到ThreadLocalMap对象后,开始设置值。其中有两个操作:map.set(this, value)和createMap(t, value),第一个是调用ThreadLocalMap的set方法,此处注意:传入的key是当前ThreadLocal对象,createMap方法是调用了ThreadLocalMap的构造方法,同样传入的key也是当前ThreadLocal对象,此处不贴代码了。

get()方法的源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public T get() {
//获取当前线程
Thread t = Thread.currentThread();
//获取ThreadLocalMap对象
ThreadLocalMap map = getMap(t);
if (map != null) {
//拿到ThreadLocalMap中的Entry
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}

从代码可以看出get方法要返回的值是ThreadLocalMap中的Entry对象的value值。

ThreadLocalMap

从上面的分析中,已经认识到了ThreadLocalMap这个类的重要性,ThreadLocalMap是ThreadLocal的一个静态内部类,从命名来看,这也是一个map结构,没错,其实ThreadLocal中很多东西都和HashMap中的很像,接下来继续看ThreadLocalMap的源码。

调用ThreadLocalMap的构造方法,会初始化一个长度为16的Entry数组,每一个
Entry对象保存的都是k-v键值对,key是ThreadLocal,调用ThreadLocal的set方法,相当于是把他自己当成key放进ThreadLocalMap中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/**
* The table, resized as necessary.
* table.length MUST always be a power of two.
*/
private Entry[] table;
/**
* The initial capacity -- MUST be a power of two.
*/
private static final int INITIAL_CAPACITY = 16;

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
//初始值16
table = new Entry[INITIAL_CAPACITY];
//计算下标,类似于HashMap计算bucket的位置,使用的是key的hashcode和length-1取模
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
//阈值默认为length的三分之二,从setThreshold()方法中可以得到
setThreshold(INITIAL_CAPACITY);
}

再看看Entry:

1
2
3
4
5
6
7
8
9
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

Entry继承了WeakReference这个类,并把key保存在了WeakReference中,这代表了Entry的key是一个弱引用,这会导致k也就是ThreadLocal对象在没有外部强引用指向它的时候,他会被gc强制回收。

ThreadLocalMap的set方法,ThreadLocal的set方法也是调用的这个方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
/**
* Set the value associated with key.
*
* @param key the thread local object
* @param value the value to be set
*/
private void set(ThreadLocal<?> key, Object value) {

// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.

Entry[] tab = table;
int len = tab.length;
//同HashMap,计算元素位置
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
//key相等,设置值
if (k == key) {
e.value = value;
return;
}
//遇到空槽,设置并替换过期的Entry
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}

tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

/**
* Increment i modulo len.
*/
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

set的基本过程是:

  1. 根据key(ThreadLocal)的hashcode计算出Entry的位置,每个ThreadLocal对象都有一个hash值threadLocalHashCode,每初始化一个ThreadLocal对象,hash值就增加一个固定的大小0x61c88647。
  2. 然后和计算出的Entry的key进行比较,如果相等,那么就放入新值
  3. 如果计算出的Entry的k为空,说明已经被gc,就替换过期的Entry值
  4. 如果都没有满足,说明计算出的Entry的key和当前要设置的值没有任何关系,初始化一个新的Entry放入当前的位置

ThreadLocal的内存泄漏

前面说过,Entry的key是个弱引用,如果被jvm的gc回收,那么就会出现一个问题,Entry的value在当前线程一直运行的情况下,Thread中持有ThreadLocalMap对象,相当于持有对Entry对象的强引用,如果线程不停止,Entry的value可能一直得不到回收,时间长了,就会发生内存泄漏。解决的办法是在使用了ThreadLocal的set方法后,显式的调用ThreadLocal的remove方法。

总结

这是一张手画的ThreadLocal的基本原理图

总结下来就是:每个Thread维护一个ThreadLocalMap映射表,这个map的key是ThreadLocal实例本身,value是真正需要存储的Object。ThreadLocal本身并不存储值,它只是作为一个key来让线程从map中获取value,虚线标识弱引用,表示ThreadLocalMap是使用ThreadLocal的弱引用作为key,弱引用在GC时会被回收。

源码看起来虽然很痛苦,但是却能学到很多东西,以前的自己很少去注意这些,只会使用,这样对于一个Java程序员修炼内功是极为不利的,如果有不对的地方,欢迎指出。

持续学习,夯实基础,共勉。

感谢

以下的博客给了我很多帮助

占小狼,狼哥的博客给了我很多帮助

https://www.jianshu.com/p/377bb840802f

@kiraSally

https://www.zybuluo.com/kiraSally/note/854555