Lua中的table.sort算法原理
table.sort的介绍
table.sort是Lua自带的一个排序函数,函数原型为:
table.sort(list[, comp])
其中list是目标table,comp是一个可选参数,可以自定义比较函数;当不提供comp函数时则默认按照升序进行排序;这里需要注意table.sort是一个不稳定的排序算法;同时排序的table必须是一个数组,并且数组的索引必须是连续的;
table.sort的算法原理
首先来看Lua源码中的sort方法的实现原理:
1: static int sort (lua_State *L) {
2:   int n = aux_getn(L, 1);
3:   luaL_checkstack(L, 40, "");  /* assume array is smaller than 2^40 */
4:   if (!lua_isnoneornil(L, 2))  /* is there a 2nd argument? */
5:     luaL_checktype(L, 2, LUA_TFUNCTION);
6:   lua_settop(L, 2);  /* make sure there is two arguments */
7:   auxsort(L, 1, n);
8:   return 0;
9: }
这个方法主要是获取数组的大小,并对参数进行相应的校验;之后调用auxsort方法来实现排序算法;
 1: static void auxsort (lua_State *L, int l, int u) {
 2:   while (l < u) {  /* for tail recursion */
 3:     int i, j;
 4:     /* sort elements a[l], a[(l+u)/2] and a[u] */
 5:     lua_rawgeti(L, 1, l);
 6:     lua_rawgeti(L, 1, u);
 7:     if (sort_comp(L, -1, -2))  /* a[u] < a[l]? */
 8:       set2(L, l, u);  /* swap a[l] - a[u] */
 9:     else
10:       lua_pop(L, 2);
11:     if (u-l == 1) break;  /* only 2 elements */
12:     i = (l+u)/2;
13:     lua_rawgeti(L, 1, i);
14:     lua_rawgeti(L, 1, l);
15:     if (sort_comp(L, -2, -1))  /* a[i]<a[l]? */
16:       set2(L, i, l);
17:     else {
18:       lua_pop(L, 1);  /* remove a[l] */
19:       lua_rawgeti(L, 1, u);
20:       if (sort_comp(L, -1, -2))  /* a[u]<a[i]? */
21:         set2(L, i, u);
22:       else
23:         lua_pop(L, 2);
24:     }
25:     if (u-l == 2) break;  /* only 3 elements */
26:     lua_rawgeti(L, 1, i);  /* Pivot */
27:     lua_pushvalue(L, -1);
28:     lua_rawgeti(L, 1, u-1);
29:     set2(L, i, u-1);
30:     /* a[l] <= P == a[u-1] <= a[u], only need to sort from l+1 to u-2 */
31:     i = l; j = u-1;
32:     for (;;) {  /* invariant: a[l..i] <= P <= a[j..u] */
33:       /* repeat ++i until a[i] >= P */
34:       while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) {
35:         if (i>u) luaL_error(L, "invalid order function for sorting");
36:         lua_pop(L, 1);  /* remove a[i] */
37:       }
38:       /* repeat --j until a[j] <= P */
39:       while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) {
40:         if (j<l) luaL_error(L, "invalid order function for sorting");
41:         lua_pop(L, 1);  /* remove a[j] */
42:       }
43:       if (j<i) {
44:         lua_pop(L, 3);  /* pop pivot, a[i], a[j] */
45:         break;
46:       }
47:       set2(L, i, j);
48:     }
49:     lua_rawgeti(L, 1, u-1);
50:     lua_rawgeti(L, 1, i);
51:     set2(L, u-1, i);  /* swap pivot (a[u-1]) with a[i] */
52:     /* a[l..i-1] <= a[i] == P <= a[i+1..u] */
53:     /* adjust so that smaller half is in [j..i] and larger one in [l..u] */
54:     if (i-l < u-i) {
55:       j=l; i=i-1; l=i+2;
56:     }
57:     else {
58:       j=i+1; i=u; u=j-2;
59:     }
60:     auxsort(L, j, i);  /* call recursively the smaller one */
61:   }  /* repeat the routine for the larger one */
62: }
上面是Lua源码中的sort方法的实现,Lua源码中包含了很多对堆栈的操作,为了更直观的看到算法的实现,对上面的方法进行了伪代码的实现:
 1: sort(array list, int l, int u)
 2: {
 3:     while(l < u)
 4:     {
 5:         if(list[u] < list[l])
 6:             swap(list[u], list[l])
 7:
 8:         if(u - l == 1)
 9:             break
10:         int i = (l + u)/2
11:         if(list[i] < list[l])
12:             swap(list[i], list[l])
13:         else if(list[u] < list[i])
14:             swap(list[u], list[i])
15:
16:         if(u - l == 2)
17:             break
18:
19:         P = list[i]
20:         swap(list[i], list[u-1])
21:         i = l
22:         j = u-1
23:         for(;;)
24:         {
25:             while(++i, list[i] < P)
26:             {
27:                 if(i > u) error("invalid order function for sorting")
28:             }
29:             while(--j, P < list[j])
30:             {
31:                 if(j < l) error("invalid order function for sorting")
32:             }
33:             if(j < i)
34:                 break
35:
36:             swap(list[i], list[j])
37:         }
38:
39:         swap(list[u-1], list[i])
40:         if (i-l < u-i)
41:             j=l; i=i-1; l=i+2;
42:         else
43:             j=i+1; i=u; u=j-2;
44:
45:         sort(list, j, i)
46:     }
47: }
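通过伪代码可以看出,sort的核心算法本质上是快速排序,因此table.sort是不稳定的排序;第5-17行使用三数取中的方法选取基准值(同时把list[l]、list[(l+u)/2]、list[u]这三个元素排好序),对快速排序进行了优化;第39-43行先把基准值交换到它的最终位置i,再比较左右两个子区间的大小:对较小的子区间递归调用sort(第45行),较大的子区间则通过修改l或u交给外层while循环继续处理,从而限制递归的深度;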
下面看一下sort_comp方法的实现:
 1: static int sort_comp (lua_State *L, int a, int b) {
 2:   if (!lua_isnil(L, 2)) {  /* function? */
 3:     int res;
 4:     lua_pushvalue(L, 2);
 5:     lua_pushvalue(L, a-1);  /* -1 to compensate function */
 6:     lua_pushvalue(L, b-2);  /* -2 to compensate function and `a' */
 7:     lua_call(L, 2, 1);
 8:     res = lua_toboolean(L, -1);
 9:     lua_pop(L, 1);
10:     return res;
11:   }
12:   else  /* a < b? */
13:     return lua_lessthan(L, a, b);
14: }
可以看到在第2行对table.sort的第二个参数comp进行了检查:当提供了comp时,调用comp(a, b)并将其返回值转换为布尔值作为比较结果;当comp为nil时,则使用lua_lessthan(即Lua的"<"运算符)进行比较,因此默认按升序排序;
(注:以上都是自己的理解,欢迎各位大佬指正!)