292 lines
7.8 KiB
Python
292 lines
7.8 KiB
Python
"""
|
|
This module adds an explicit axis argument to calls that did not specify one,
for the functions whose default axis was changed in NumPy.
|
|
|
|
The functions changed are
|
|
|
|
add axis=-1 (axis becomes the second argument unless noted otherwise)
|
|
======
|
|
nansum
|
|
nanmax
|
|
nanmin
|
|
nanargmax
|
|
nanargmin
|
|
argmax
|
|
argmin
|
|
compress (axis is its third argument)
|
|
|
|
|
|
add axis=0 (axis becomes the second argument unless noted otherwise)
|
|
======
|
|
take (axis is its third argument)
|
|
repeat (axis is its third argument)
|
|
sum # might cause problems with builtin.
|
|
product
|
|
sometrue
|
|
alltrue
|
|
cumsum
|
|
cumproduct
|
|
average
|
|
ptp
|
|
cumprod
|
|
prod
|
|
std
|
|
mean
|
|
"""
|
|
# Public API: the three conversion entry points.
__all__ = [
    'convertfile',
    'convertall',
    'converttree',
]
|
|
|
|
import sys
|
|
import os
|
|
import re
|
|
import glob
|
|
|
|
|
|
# Functions whose axis is their *third* argument: up to two arguments may
# precede the appended axis instead of one.
_args3 = ['compress', 'take', 'repeat']

# Calls of these functions get an explicit ``axis=-1`` appended.
_funcm1 = ['nansum', 'nanmax', 'nanmin', 'nanargmax', 'nanargmin',
           'argmax', 'argmin', 'compress']

# Calls of these functions get an explicit ``axis=0`` appended.  Note that a
# bare ``sum(`` also matches calls of the Python builtin ``sum``.
_func0 = ['take', 'repeat', 'sum', 'product', 'sometrue', 'alltrue',
          'cumsum', 'cumproduct', 'average', 'ptp', 'cumprod', 'prod',
          'std', 'mean']

_all = _func0 + _funcm1

# name -> compiled pattern for the start of a call: the name, optional
# whitespace, then an opening parenthesis.  Built from a generator
# expression so no loop variables leak into the module namespace.
func_re = dict((_fname, re.compile(r"%s\s*[(]" % re.escape(_fname)))
               for _fname in _all)

import string

# Characters that can appear in an identifier.  A match preceded by one of
# them is the tail of a longer name (``sum`` inside ``cumsum``), not a call.
# ASCII-only on purpose: ``string.uppercase``/``lowercase`` are
# locale-dependent in Python 2 and do not exist in Python 3.
disallowed = '_' + string.ascii_uppercase + string.ascii_lowercase + string.digits


def _add_axis(fstr, name, repl):
    """Append ``repl`` as the last argument of every qualifying call of ``name``.

    A call qualifies when
      * ``name`` is not the tail of a longer identifier,
      * it is unqualified or qualified by ``numpy``, ``N``, ``numarray``,
        ``numerix`` or ``Numeric`` (other attribute calls such as
        ``a.sum()`` are skipped),
      * it is not the name in a ``def``/``class`` statement,
      * its parentheses balance within ``fstr``,
      * it has at least one argument, and
      * it has at most one top-level comma (two for the functions in
        ``_args3``), so calls already passing more arguments are skipped.

    Nested calls such as ``sum(sum(a))`` are each rewritten.

    fstr : source text to rewrite.
    name : function name; must be a key of ``func_re``.
    repl : argument text to append, e.g. ``'axis=0'``.

    Returns (newcode, alter): the rewritten text and the number of calls
    changed.  ``newcode`` is ``fstr`` itself when nothing changed.
    """
    allowed_comma = 1 if name in _args3 else 0
    prefixes = ('numpy', 'N', 'numarray', 'numerix', 'Numeric')
    size = len(fstr)
    inserts = []  # indices of the closing ')' of each call to extend
    for obj in func_re[name].finditer(fstr):
        start, end = obj.span()
        # '' at the very start of the text (avoids fstr[-1] wrapping around).
        prev = fstr[start - 1] if start > 0 else ''
        if prev and prev in disallowed:
            continue
        # Attribute call: only rewrite ``<numeric module>.<name>(``.
        if prev == '.' and not fstr.endswith(prefixes, 0, start - 1):
            continue
        # ``def name(`` / ``class name(`` is a definition, not a call.
        if prev in ('\t', ' '):
            k = start - 2
            while k >= 0 and fstr[k] in ('\t', ' '):
                k -= 1
            if fstr.endswith(('def', 'class'), 0, k + 1):
                continue
        # Walk to the matching ')' while counting top-level commas.
        depth = 1
        ncommas = 0
        k = end
        while depth and k < size:
            ch = fstr[k]
            if ch == ')':
                depth -= 1
            elif ch == '(':
                depth += 1
            elif depth == 1 and ch == ',':
                ncommas += 1
                if ncommas > allowed_comma:
                    break
            k += 1
        if depth:
            # Unbalanced parentheses, or more arguments than allowed.
            continue
        close = k - 1
        if not fstr[end:close].strip():
            # ``name()``: appending ",repl" would produce invalid syntax.
            continue
        inserts.append(close)

    if not inserts:
        return fstr, 0
    # Insert relative to the ORIGINAL text, in ascending order, so nested
    # calls (whose closing parens are distinct positions) all stay valid.
    pieces = []
    last = 0
    for pos in sorted(inserts):
        pieces.append(fstr[last:pos])
        pieces.append(',' + repl)
        last = pos
    pieces.append(fstr[last:])
    return ''.join(pieces), len(inserts)
|
|
|
|
def _import_change(fstr, names):
    """Rewrite numpy imports so that ``names`` come from ``numpy.oldnumeric``.

    This is the alternative to appending explicit axis arguments.  Four
    import forms are handled, in this order:

    1. ``import numpy`` followed by use of ``numpy.<name>`` becomes
       ``import numpy.oldnumeric as numpy``.
    2. ``import numpy as XXXX`` followed by use of ``XXXX.<name>`` becomes
       ``import numpy.oldnumeric as XXXX``.
    3. ``from numpy import *`` followed by use of one of the names becomes
       ``from numpy.oldnumeric import *``.
    4. ``from numpy import ..., <name>, ...`` (possibly continued with
       backslashes) has the names removed, and a
       ``from numpy.oldnumeric import <name>, ...`` line is added after it.

    Returns (fstr, num): the rewritten text and the number of import
    statements changed.
    """
    num = 0
    for rewrite in (_change_plain_import, _change_aliased_imports,
                    _change_star_import, _change_from_imports):
        fstr, count = rewrite(fstr, names)
        num += count
    return fstr, num


def _change_plain_import(fstr, names):
    """Case 1: ``import numpy`` -> ``import numpy.oldnumeric as numpy``.

    Only the first line-leading bare ``import numpy`` is considered.
    ``import numpy.sub`` and ``import numpy as X`` are excluded because
    rewriting them would give invalid syntax.  The rewrite happens only
    if ``numpy.<name>`` occurs after the import.
    """
    match = re.search(r"^[ \t]*import numpy(?![\w.])(?![ \t]+as\b)",
                      fstr, re.M)
    if match is None:
        return fstr, 0
    tail = match.end()
    if not any(fstr.find("numpy.%s" % name, tail) >= 0 for name in names):
        return fstr, 0
    head = tail - len("import numpy")
    return fstr[:head] + "import numpy.oldnumeric as numpy" + fstr[tail:], 1


def _change_aliased_imports(fstr, names):
    """Case 2: ``import numpy as XXXX`` -> ``import numpy.oldnumeric as XXXX``.

    Every line-leading statement is rewritten when its alias is used as
    ``XXXX.<name>``.  Aliases are matched as whole words, so ``N`` and
    ``NX`` are told apart.
    """
    pattern = re.compile(r"^([ \t]*)import numpy as ([A-Za-z0-9_]+)", re.M)
    used = set()
    for alias in set(m.group(2) for m in pattern.finditer(fstr)):
        for name in names:
            usage = r"(?<![\w.])%s\.%s\b" % (re.escape(alias), re.escape(name))
            if re.search(usage, fstr):
                used.add(alias)
                break
    if not used:
        return fstr, 0

    def _swap(m):
        if m.group(2) not in used:
            return m.group(0)
        return "%simport numpy.oldnumeric as %s" % (m.group(1), m.group(2))

    count = len([m for m in pattern.finditer(fstr) if m.group(2) in used])
    return pattern.sub(_swap, fstr), count


def _change_star_import(fstr, names):
    """Case 3: ``from numpy import *`` -> ``from numpy.oldnumeric import *``.

    The first such statement is rewritten when any of ``names`` appears as
    a whole word anywhere after it.  Every occurrence is checked, so
    ``cumsum`` does not hide a later ``sum``.
    """
    importstr = "from numpy import *"
    ind = fstr.find(importstr)
    if ind < 0:
        return fstr, 0
    after = ind + len(importstr)
    for name in names:
        word = re.compile(r"(?<![A-Za-z0-9_])%s(?![A-Za-z0-9_])" % re.escape(name))
        if word.search(fstr, after):
            return fstr[:ind] + "from numpy.oldnumeric import *" + fstr[after:], 1
    return fstr, 0


def _change_from_imports(fstr, names):
    """Case 4: move ``names`` out of ``from numpy import a, b, ...``.

    Each such statement, including its backslash continuation lines, is
    replaced by two lines:

    * ``from numpy import <remaining names>``, omitted when none remain;
    * ``from numpy.oldnumeric import <moved names>``, at the same
      indentation as the original line.
    """
    importstr = "from numpy import"
    count = 0
    pos = 0
    while True:
        start = fstr.find(importstr, pos)
        if start < 0:
            break
        ind = start + len(importstr)
        # Find the end of the logical line: each backslash adds one more
        # physical line, each newline ends one.  Stop at end of text too.
        ptr = ind
        depth = 1
        while depth and ptr < len(fstr):
            if fstr[ptr] == '\\':
                depth += 1
            elif fstr[ptr] == '\n':
                depth -= 1
            ptr += 1
        stmt = fstr[ind:ptr].replace('\n', ' ').replace('\\', '')
        importnames = [x.strip() for x in stmt.split(',')]
        importnames = [x for x in importnames if x]
        addnames = [name for name in names if name in importnames]
        if not addnames:
            pos = ind
            continue
        keep = [x for x in importnames if x not in addnames]
        line_start = fstr.rfind('\n', 0, start) + 1
        indent = re.match(r"[ \t]*", fstr[line_start:start]).group()
        ending = '\n' if fstr[ptr - 1:ptr] == '\n' else ''
        moved = "from numpy.oldnumeric import %s" % ", ".join(addnames)
        if keep:
            new = "from numpy import %s\n%s%s" % (", ".join(keep), indent, moved)
        else:
            new = moved
        new += ending
        # Replace from the start of the statement (not after its prefix) so
        # "from numpy import" is not duplicated, then skip past the result.
        fstr = fstr[:start] + new + fstr[ptr:]
        pos = start + len(new)
        count += 1
    return fstr, count
|
|
|
|
def add_axis(fstr, import_change=False):
    """Update ``fstr`` for every function whose default axis changed.

    With ``import_change`` false (the default), ``axis=-1`` is appended to
    calls of the functions in ``_funcm1`` and then ``axis=0`` to calls of
    those in ``_func0``.  With ``import_change`` true, the source is handed
    to ``_import_change`` for all of those functions instead.

    Returns (new_source, number_of_changes).
    """
    if import_change:
        return _import_change(fstr, _funcm1 + _func0)
    total = 0
    for funcs, repl in ((_funcm1, 'axis=-1'), (_func0, 'axis=0')):
        for func in funcs:
            fstr, count = _add_axis(fstr, func, repl)
            total += count
    return fstr, total
|
|
|
|
|
|
def makenewfile(name, filestr):
    """Write ``filestr`` to the file ``name``, replacing any existing content.

    Uses ``open`` (the Python 2-only ``file`` builtin is gone in Python 3)
    inside ``with`` so the file is closed even if the write fails.
    """
    with open(name, 'w') as fid:
        fid.write(filestr)
|
|
|
|
def getfile(name):
    """Return the whole text content of the file ``name``.

    Uses ``open`` (the Python 2-only ``file`` builtin is gone in Python 3)
    inside ``with`` so the file is closed even if reading fails.
    """
    with open(name) as fid:
        return fid.read()
|
|
|
|
def copyfile(name, fstr):
    """Save ``fstr`` as a backup next to ``name``.

    The backup path is ``name`` with its extension replaced by ``.orig``
    (``foo.py`` -> ``foo.orig``); it is written with ``makenewfile``.
    """
    backup = os.path.splitext(name)[0] + '.orig'
    makenewfile(backup, fstr)
|
|
|
|
def convertfile(filename, import_change=False):
    """Convert the filename given from using Numeric to using NumPy

    If anything needs changing, the original text is first saved as a
    backup (via ``copyfile``: ``filename`` with its extension replaced by
    ``.orig``) and the file is then overwritten with the updated code.
    Files that need no change are left untouched.

    import_change : when true, rewrite the numpy imports instead of adding
    explicit axis arguments (passed through to ``add_axis``).
    """
    filestr = getfile(filename)
    newstr, total = add_axis(filestr, import_change)
    if total > 0:
        # Same text as Python 2's ``print "Changing ", filename`` (which
        # inserts a second space), but valid in Python 2 and 3 alike.
        sys.stdout.write("Changing  %s\n" % filename)
        copyfile(filename, filestr)
        makenewfile(filename, newstr)
        sys.stdout.flush()
|
|
|
|
def fromargs(args):
    """Command-line driver: convert each file named in ``args[1:]``.

    ``args`` is an argv-style list whose first item is the program name.
    Every file is converted with ``convertfile`` (so a single file behaves
    as before).  Returns 0 on success.  If no file is given, writes a
    usage message to stderr and returns 2 instead of raising IndexError.
    """
    filenames = args[1:]
    if not filenames:
        prog = args[0] if args else '<script>'
        sys.stderr.write("usage: %s FILE [FILE ...]\n" % prog)
        return 2
    for filename in filenames:
        convertfile(filename)
    return 0
|
|
|
|
def convertall(direc=os.path.curdir, import_change=False):
    """Convert all .py files directly inside the directory given.

    Subdirectories are not visited.  Each file goes through
    ``convertfile``: a file that needs changes is backed up as
    <usesnumeric>.orig (the .py extension is replaced, see ``copyfile``)
    and <usesnumeric>.py is then rewritten with the updated code.  Files
    are processed in sorted order so runs are reproducible.
    """
    for afile in sorted(glob.glob(os.path.join(direc, '*.py'))):
        convertfile(afile, import_change)
|
|
|
|
def _func(arg, dirname, fnames):
    """Visitor callback in the ``os.path.walk`` style.

    ``arg`` carries the ``import_change`` flag.  ``fnames`` is ignored
    because ``convertall`` globs ``dirname`` itself.
    """
    return convertall(dirname, arg)
|
|
|
|
def converttree(direc=os.path.curdir, import_change=False):
    """Convert all .py files in the tree given.

    ``direc`` and every directory below it is passed to ``convertall``,
    top-down.  With ``os.walk``'s defaults, symbolic links to directories
    are not followed.
    """
    # os.walk replaces os.path.walk, which no longer exists in Python 3.
    for dirpath, _subdirs, _files in os.walk(direc):
        convertall(dirpath, import_change)
|
|
|
|
if __name__ == '__main__':
    # Use fromargs' return value (if any) as the process exit status;
    # sys.exit(None) exits with status 0.
    sys.exit(fromargs(sys.argv))
|