learn and grow up

A small secret inside ThreadLocal: cleverly reducing the risk of memory leaks

2019/08/23

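ThreadLocal is probably the multithreading utility class we know best. It lets every thread in a multithreaded program hold its own copy of a variable that other threads cannot affect, which gives us data isolation.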
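To make the isolation concrete, here is a minimal sketch. The class name IsolationDemo, the field NAME and the thread names are made up for this illustration; only ThreadLocal itself comes from the JDK. Two threads write to the same ThreadLocal instance and each one reads back only its own value.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class IsolationDemo {
    // One ThreadLocal instance shared by all threads; each thread sees its own value.
    static final ThreadLocal<String> NAME = new ThreadLocal<>();

    // Starts two threads that each set a value and record what they read back.
    static Map<String, String> run() {
        Map<String, String> seen = new ConcurrentHashMap<>();
        Runnable task = () -> {
            String me = Thread.currentThread().getName();
            NAME.set("value-of-" + me);
            seen.put(me, NAME.get());
            NAME.remove();
        };
        Thread a = new Thread(task, "worker-a");
        Thread b = new Thread(task, "worker-b");
        a.start();
        b.start();
        try {
            a.join();
            b.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen;
    }

    public static void main(String[] args) {
        Map<String, String> seen = run();
        System.out.println(seen.get("worker-a")); // value-of-worker-a
        System.out.println(seen.get("worker-b")); // value-of-worker-b
        System.out.println(NAME.get());           // null, main never called set()
    }
}
```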

If we followed our everyday habits and tried to build this ourselves, we would probably end up writing a utility class like the one below:

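import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class MyThreadLocal {
    private MyThreadLocal() {
    }

    // Holds each thread's private variable, keyed by the thread itself.
    // Wrapped in synchronizedMap because a plain HashMap is not safe for concurrent access.
    private static final Map<Thread, Object> myLocalMap =
            Collections.synchronizedMap(new HashMap<Thread, Object>());
    private static volatile MyThreadLocal instance;

    public <T extends Thread> Object get(T t) {
        return myLocalMap.get(t);
    }

    public <T extends Thread> void set(T s, Object value) {
        myLocalMap.put(s, value);
    }

    // Singleton implemented with double-checked locking
    public static MyThreadLocal getInstance() {
        if (instance == null) {
            synchronized (MyThreadLocal.class) {
                if (instance == null) {
                    instance = new MyThreadLocal();
                }
            }
        }
        return instance;
    }
}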

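But this is the most naive implementation, and it carries a serious memory leak risk: after a Thread finishes, unless we manually remove its entry from the map, the static map keeps a strong reference to the value stored for that thread (and to the Thread object itself), so the memory can never be released.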
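The leak described above is easy to reproduce. The sketch below is an illustration only: NaiveMapLeakDemo and its VALUES map are hypothetical stand-ins for the utility class above. Every thread stores a value and then terminates, yet the entries are still in the static map afterwards.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class NaiveMapLeakDemo {
    // Static map keyed by Thread, same idea as the hand-written utility class.
    static final Map<Thread, Object> VALUES = new ConcurrentHashMap<>();

    // Runs the given number of short-lived threads and reports how many
    // entries are still held by the map after all of them have finished.
    static int entriesAfterThreadsFinish(int threads) {
        VALUES.clear();
        for (int i = 0; i < threads; i++) {
            Thread t = new Thread(() -> VALUES.put(Thread.currentThread(), new byte[1024]));
            t.start();
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return VALUES.size();
    }

    public static void main(String[] args) {
        // All three threads are dead, but nobody removed their entries.
        System.out.println(entriesAfterThreadsFinish(3)); // 3
    }
}
```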

So how does ThreadLocal implement it?

When we call ThreadLocal.get(), this is what happens internally:

java.lang.ThreadLocal.get() -> java.lang.ThreadLocal.getMap(Thread);

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

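As you can see, it first obtains the current thread and then reads threadLocals from that thread. The type of this field is ThreadLocal.ThreadLocalMap.

When the thread's threadLocals is null, it is handled like this:

java.lang.ThreadLocal.setInitialValue() -> java.lang.ThreadLocal.initialValue() -> java.lang.ThreadLocal.createMap(Thread, T);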

// Sets the default value. This is why we need to override initialValue():
// otherwise null is stored.
private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}
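The effect of initialValue() can be seen from the outside with a small sketch. The class and field names here are invented for the example. ThreadLocal.withInitial is the factory method available since Java 8 and is shown as an alternative to subclassing.

```java
class InitialValueDemo {
    // No initialValue() override: the first get() stores and returns null.
    static final ThreadLocal<Integer> PLAIN = new ThreadLocal<>();

    // Classic approach: override initialValue() in an anonymous subclass.
    static final ThreadLocal<Integer> OVERRIDDEN = new ThreadLocal<Integer>() {
        @Override
        protected Integer initialValue() {
            return 0;
        }
    };

    // Java 8+ alternative: supply the initial value with a lambda.
    static final ThreadLocal<Integer> WITH_INITIAL = ThreadLocal.withInitial(() -> 42);

    public static void main(String[] args) {
        System.out.println(PLAIN.get());        // null
        System.out.println(OVERRIDDEN.get());   // 0
        System.out.println(WITH_INITIAL.get()); // 42
    }
}
```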

ThreadLocal.set() works in much the same way:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

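At this point we can see that all of a thread's own variables are actually stored in its threadLocals field. ThreadLocal itself is just a utility class that implements the access logic.

In other words, the thread carries its own pocket around to hold these variables. When the thread terminates, nothing strongly references the contents of threadLocals any more, and the memory they occupy can be released.

But there is still a problem!

If the thread lives in a thread pool and never terminates, do we still have to release the unused variables in threadLocals by hand, or else leak memory exactly as before? Don't worry, the JDK authors thought of this too. How did they deal with it? The answer is WeakReference.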
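The thread pool concern can be shown with a short sketch. PoolReuseDemo and CONTEXT are hypothetical names; the sketch assumes a single-threaded executor so that both tasks are guaranteed to run on the same worker thread. The second task still sees what the first task left behind.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class PoolReuseDemo {
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    // Runs two tasks on the same pooled thread and returns what the second one reads.
    static String valueSeenBySecondTask() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        Runnable first = () -> CONTEXT.set("left over from task 1"); // never removed
        Callable<String> second = CONTEXT::get;
        try {
            pool.submit(first).get();
            return pool.submit(second).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(valueSeenBySecondTask()); // left over from task 1
    }
}
```

As long as the worker thread stays alive, the value stays reachable through that thread's threadLocals.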

Let's keep reading the code and see how threadLocals is created:

java.lang.ThreadLocal.createMap(Thread, T) -> java.lang.ThreadLocal.ThreadLocalMap.ThreadLocalMap(ThreadLocal<?>, Object)

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

...
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    table = new Entry[INITIAL_CAPACITY];
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // Store the firstValue object
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    setThreshold(INITIAL_CAPACITY);
}

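As you can see, ThreadLocal stores our value in table, which is declared as private Entry[] table;

PS: at this point it starts to look a bit like HashMap :)

The array index is derived from the key by a dedicated hashing scheme, and each element is of type java.lang.ThreadLocal.ThreadLocalMap.Entry. That type is the important part: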

static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}
// super(k) goes through WeakReference<T> extends Reference<T>
// and ends up in java.lang.ref.Reference.Reference(T),
// which stores the key in java.lang.ref.Reference.referent
Reference(T referent) {
    this(referent, null);
}

Now let's go back and look closely at how threadlocal.get() retrieves the value:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // The important part
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

...
//java.lang.ThreadLocal.ThreadLocalMap.getEntry(ThreadLocal<?>)
private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

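See? It first computes the index with the same algorithm and reads the entry at that position in table. It then checks whether e.get() is the same object as this ThreadLocal, and if so returns the stored value.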
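The index computation can be tried in isolation. The sketch below rests on two assumptions that are not shown in the excerpts above: that each new ThreadLocal takes its threadLocalHashCode from a shared counter advanced by 0x61c88647 (the HASH_INCREMENT constant in the OpenJDK 8 sources, as far as I know), and that the counter starts at 0 here for simplicity. SlotIndexDemo and slots are made-up names.

```java
import java.util.Arrays;

class SlotIndexDemo {
    // Assumed to match ThreadLocal.HASH_INCREMENT in OpenJDK 8.
    static final int HASH_INCREMENT = 0x61c88647;

    // Computes the table slot for the first `count` hash codes,
    // using the same mask as getEntry(): hash & (tableLength - 1).
    static int[] slots(int count, int tableLength) {
        int[] result = new int[count];
        int hash = 0; // simplification: the real counter is shared by every ThreadLocal in the JVM
        for (int k = 0; k < count; k++) {
            result[k] = hash & (tableLength - 1);
            hash += HASH_INCREMENT; // int overflow wraps around, which is intended
        }
        return result;
    }

    public static void main(String[] args) {
        // prints [0, 7, 14, 5, 12, 3, 10, 1, 8, 15, 6, 13, 4, 11, 2, 9]
        System.out.println(Arrays.toString(slots(16, 16)));
    }
}
```

With a table of 16 slots, 16 consecutive hash codes land in 16 different slots. Collisions can still occur in general, which is why getEntry falls back to getEntryAfterMiss.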

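So why is it written this way? Because entry.referent is only a weak reference to the ThreadLocal. The Entry refers to the ThreadLocal in order to hold the value that belongs to it, but that reference does not keep the ThreadLocal alive. Once every strong reference to the ThreadLocal is gone (for example, a method creates an object with a ThreadLocal field, and when the method returns the object becomes unreachable and the strong reference disappears), the GC reclaims the ThreadLocal in a later collection and entry.referent becomes null.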
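What an entry looks like after its key has been collected can be sketched without waiting for a real garbage collection. The Entry class below is a hypothetical stand-in for ThreadLocalMap.Entry with a plain Object key, and the call to clear() only simulates what the collector does once the key is no longer strongly reachable.

```java
import java.lang.ref.WeakReference;

class StaleEntryDemo {
    // Stand-in for ThreadLocalMap.Entry: weak reference to the key, strong reference to the value.
    static final class Entry extends WeakReference<Object> {
        Object value;

        Entry(Object key, Object value) {
            super(key);
            this.value = value;
        }
    }

    // Builds an entry and then simulates the GC clearing its referent.
    static Entry staleEntry() {
        Object key = new Object();
        Entry e = new Entry(key, "payload");
        e.clear(); // simulation only: a real GC clears the referent on its own schedule
        return e;
    }

    public static void main(String[] args) {
        Entry e = staleEntry();
        System.out.println(e.get()); // null, the key is gone
        System.out.println(e.value); // payload, the value is still strongly held
    }
}
```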

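The table inside thread.threadLocals then contains a slot like this: the referent of table[n] is null, while its value is still the original value. What has to happen before the memory held by table[n] can be fully reclaimed? Exactly: table[n] = null, which drops the reference. Let's see how ThreadLocal does this for us:

Every call to threadLocal.remove() ends up in the code below: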

//java.lang.ThreadLocal.ThreadLocalMap.remove(ThreadLocal<?>)
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}

............................
//java.lang.ThreadLocal.ThreadLocalMap.expungeStaleEntry(int)
// Key function: clears the table[] references held by ThreadLocal=value pairs
// that are no longer needed, so the GC can reclaim the memory.
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
    // Drop the strong references held by this slot
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    // Keep checking the following slots; entries whose key is null
    // also have their strong references dropped
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

...

// For example java.lang.ThreadLocal.ThreadLocalMap.set(ThreadLocal<?>, Object)
// Now look at set and get: both probe in a loop, and when they find a table[n]
// whose referent is null, they drop the strong references held by table[n]
// so the GC can reclaim the unused memory.
private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    // Read the entry at this ThreadLocal's slot. Because threadLocalHashCode
    // values can collide, probe in a loop and compare keys, and along the way...
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        // Key is null (the important case): tidy up the run and
        // clear the strong reference to the value
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

...........................
//java.lang.ThreadLocal.ThreadLocalMap.replaceStaleEntry(ThreadLocal<?>, Object, int)
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // Find either the key or trailing null slot of run, whichever
    // occurs first
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEntry
        // to remove or rehash all of the other entries in run.
        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
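The core idea of the cleanup, stripped of the rehashing details, fits in a few lines. This is a deliberately simplified sketch and not the JDK algorithm: it sweeps the whole table instead of one run, and it does not re-place the surviving entries, which a real linear-probing table has to do. SweepSketch, Entry and expungeAll are names invented for the example, and clear() again simulates the garbage collector.

```java
import java.lang.ref.WeakReference;

class SweepSketch {
    // Stand-in for ThreadLocalMap.Entry with a plain Object key.
    static final class Entry extends WeakReference<Object> {
        Object value;

        Entry(Object key, Object value) {
            super(key);
            this.value = value;
        }
    }

    // Kept in a static field so the live key stays strongly reachable.
    static final Object LIVE_KEY = new Object();

    // Nulls out every slot whose key has been cleared and returns how many were removed.
    static int expungeAll(Entry[] table) {
        int removed = 0;
        for (int i = 0; i < table.length; i++) {
            Entry e = table[i];
            if (e != null && e.get() == null) {
                e.value = null;  // drop the strong reference to the value
                table[i] = null; // drop the strong reference to the entry
                removed++;
            }
        }
        return removed;
    }

    // Builds a table with one live and one stale entry, sweeps it, and
    // returns the number of removed entries (or -1 if the table looks wrong).
    static int demo() {
        Entry[] table = new Entry[4];
        table[0] = new Entry(LIVE_KEY, "kept");
        table[1] = new Entry(new Object(), "stale");
        table[1].clear(); // simulation only: pretend the GC collected this key
        int removed = expungeAll(table);
        boolean ok = table[0] != null && "kept".equals(table[0].value) && table[1] == null;
        return ok ? removed : -1;
    }

    public static void main(String[] args) {
        System.out.println(demo()); // 1
    }
}
```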

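So, to prevent memory leaks in a thread pool, once a ThreadLocal is no longer needed we still have to explicitly call its get or set method, or better yet remove(), so that the memory can be reclaimed. get and set only clean up stale entries they happen to come across while probing, whereas remove() clears the entry for this ThreadLocal directly.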
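A common way to follow this advice is to call remove() in a finally block at the end of each task. The sketch below uses invented names (RemoveInFinallyDemo, CONTEXT) and a single-threaded executor so that both tasks share one worker thread. The second task no longer sees anything from the first.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class RemoveInFinallyDemo {
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    // Runs a task that cleans up after itself, then checks what the next task sees.
    static String valueSeenByNextTask() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        Runnable handle = () -> {
            CONTEXT.set("request-scoped data");
            try {
                // Placeholder for the real work that reads CONTEXT.
                CONTEXT.get();
            } finally {
                CONTEXT.remove(); // always clear before the thread goes back to the pool
            }
        };
        Callable<String> next = CONTEXT::get;
        try {
            pool.submit(handle).get();
            return pool.submit(next).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(valueSeenByNextTask()); // null
    }
}
```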
