summaryrefslogtreecommitdiff
path: root/structure/correlations/compute_tau.py
blob: c4e47b1a8ba225f87e7a961b6fb72e42fc6ac8e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# This file is part of MAMMULT: Metrics And Models for Multilayer Networks
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at
# your option) any later version.
# 
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
####
##
## Take as input two files, whose n^th line contains the ranking of
## element n, and compute Kendall's \tau_b rank correlation
## coefficient between the two rankings
##
##

import sys
from numpy import *


def _count_tied_pairs(perm, columns):
    """Count the pairs of items that are tied on every sequence in `columns`.

    `perm` must list the item indices in an order that places items with
    equal keys next to each other.  A run of m equal items contributes
    m * (m - 1) / 2 tied pairs.
    """
    n = len(perm)
    pairs = 0
    first = 0  # position in perm where the current run of equal keys starts
    for i in range(1, n):
        a = perm[first]
        b = perm[i]
        for col in columns:
            if col[a] != col[b]:
                # the run [first, i) has ended: account for it, start a new one
                run = i - first
                pairs += run * (run - 1) // 2
                first = i
                break
    run = n - first
    return pairs + run * (run - 1) // 2


def _sort_counting_exchanges(perm, y, temp, offs, length):
    """Merge-sort perm[offs:offs+length] in place by the values y[perm[.]].

    `temp` is a scratch list with room for at least `length` entries.
    Returns the number of adjacent swaps an equivalent bubble sort would
    need, i.e. the number of pairs that are out of order with respect to y.
    """
    if length < 2:
        # nothing to sort; this also covers length == 0, which would
        # otherwise recurse forever
        return 0
    if length == 2:
        if y[perm[offs]] <= y[perm[offs + 1]]:
            return 0
        perm[offs], perm[offs + 1] = perm[offs + 1], perm[offs]
        return 1
    length0 = length // 2
    length1 = length - length0
    middle = offs + length0
    exchcnt = _sort_counting_exchanges(perm, y, temp, offs, length0)
    exchcnt += _sort_counting_exchanges(perm, y, temp, middle, length1)
    if y[perm[middle - 1]] < y[perm[middle]]:
        # the two halves are already in order: no merge needed
        return exchcnt
    # merging
    i = j = k = 0
    while j < length0 or k < length1:
        if k >= length1 or (j < length0 and
                            y[perm[offs + j]] <= y[perm[middle + k]]):
            temp[i] = perm[offs + j]
            # an element of the left half moves right by i - j == k places,
            # one for every right-half element that overtook it
            exchcnt += k
            j += 1
        else:
            # elements of the right half only move left and are not counted
            temp[i] = perm[middle + k]
            k += 1
        i += 1
    perm[offs:offs + length] = temp[0:length]
    return exchcnt


def kendalltau(x, y, initial_sort_with_lexsort=True):
    """Compute Kendall's tau_b rank correlation between x and y.

    The items are sorted by x (ties broken by y), then the number of
    exchanges needed to sort them by y is counted with a merge sort, so no
    explicit loop over all pairs is required.  With

        tot = n * (n - 1) / 2   total number of pairs
        u, v                    pairs tied in x, pairs tied in y
        t                       pairs tied in both x and y

    the result is

        ((tot - (u + v - t)) - 2 * exchanges) / sqrt((tot - u) * (tot - v))

    Parameters:
        x, y: equally long indexable sequences of mutually comparable values.
        initial_sort_with_lexsort: if True (the default) the initial sort
            uses numpy's lexsort; if False it uses the built-in sort with a
            (x, y) key.  Both give the same ordering.

    Returns:
        1 if every pair is tied in both x and y (this includes inputs with
        fewer than two items), otherwise the value of the formula above.
        If exactly one of the inputs is constant, the formula is 0/0 and is
        returned as computed by numpy (NaN).

    Raises:
        ValueError: if x and y have different lengths.
    """
    n = len(x)
    if len(y) != n:
        raise ValueError("x and y must have the same length (%d != %d)"
                         % (n, len(y)))
    if n < 2:
        # no pairs at all: same outcome as the "all ties" special case below
        return 1

    # initial sort on values of x and, if tied, on values of y
    if initial_sort_with_lexsort:
        perm = lexsort((y, x))
    else:
        perm = sorted(range(n), key=lambda idx: (x[idx], y[idx]))

    # joint ties and ties in x: equal keys are adjacent after the sort above
    t = _count_tied_pairs(perm, (x, y))
    u = _count_tied_pairs(perm, (x,))

    # count exchanges; afterwards perm is ordered by y
    exchanges = _sort_counting_exchanges(perm, y, [0] * n, 0, n)

    # ties in y can only be counted now that equal y values are adjacent
    v = _count_tied_pairs(perm, (y,))

    tot = n * (n - 1) // 2
    if tot == u and tot == v:
        return 1    # Special case for all ties in both ranks

    return ((tot - (v + u - t)) - 2.0 * exchanges) / \
        sqrt(float(tot - u) * float(tot - v))

def _parse_rank(token):
    """Convert a text token to an int when possible, otherwise to a float."""
    try:
        return int(token)
    except ValueError:
        return float(token)


def _read_ranking(path):
    """Read one ranking per line from the file at `path`.

    Only the first whitespace-separated field of each line is used.  Blank
    lines are skipped, so a trailing empty line does not abort the run.
    """
    values = []
    with open(path) as infile:
        for line in infile:
            fields = line.split()
            if fields:
                values.append(_parse_rank(fields[0]))
    return values


def main():
    """Command-line entry point: print Kendall's tau_b of two ranking files.

    Expects two file names in sys.argv[1] and sys.argv[2].  With fewer
    arguments a usage message is printed and the process exits with
    status 1.
    """
    if len(sys.argv) < 3:
        print("Usage: %s <file1> <file2>" % sys.argv[0])
        sys.exit(1)

    x1 = _read_ranking(sys.argv[1])
    x2 = _read_ranking(sys.argv[2])

    tau = kendalltau(x1, x2)
    print(tau)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if  __name__ == "__main__":
    main()