
Python compare every line in file with all others

I am implementing a statistical program and have run into a performance bottleneck. I'm hoping the community can point me in the direction of an optimization.

I am creating a set for each row in a file and finding the intersection of that set with the set of every other row in the same file. I then use the size of that intersection to filter certain sets from the output. The problem is that I have a nested for loop (O(n²)), and the files coming into the program are typically just over 20,000 lines long. I have timed the algorithm: for under 500 lines it runs in about 20 minutes, but for the big files it takes about 8 hours to finish.

I have 16GB of RAM at my disposal and a reasonably fast 4-core Intel i7 processor. I noticed no significant difference in memory use or speed when I copied list1 and compared against a second list instead of re-reading the file (maybe this is because I have an SSD?). I thought the 'with open' mechanism reads/writes directly to the disk, which is slower, but noticed no difference when using two lists. In fact, the program rarely uses more than 1GB of RAM during operation.

I am hoping that other people have used a suitable data type, or understand multiprocessing in Python better than I do, and might be able to help me speed things up. I appreciate any help, and I hope my code isn't too poorly written.

import ast, sys, os, shutil
list1 = []
end = 0
pos = 0          # current row number in infile
filterValue = 3

# creates output file with filterValue appended to name
with open(arg2 + arg1 + "/filteredSets" + str(filterValue) , "w") as outfile:
    with open(arg2 + arg1 + "/file", "r") as infile:
        # create a list of sets of rows in file
        for row in infile:
            list1.append(set(ast.literal_eval(row)))

            infile.seek(0)
            for row in infile:
                # if file only has one row, no comparisons need to be made
                if not(len(list1) == 1):
                    # get the first set from the list and...
                    set1 = set(ast.literal_eval(row))
                    # ...find the intersection of every other set in the file
                    for i in range(0, len(list1)):
                        # don't compare the set with itself
                        if not(pos == i):
                            set2 = list1[i]
                            set3 = set1.intersection(set2)
                            # if the two sets have less than 3 items in common
                            if(len(set3) < filterValue):
                                # and you've reached the end of the file
                                if(i == len(list1)):
                                    # append the row in outfile
                                    outfile.write(row)
                                    # increase position in infile
                                    pos += 1
                            else:
                                break
                        else:
                            outfile.write(row)

Sample input would be a file with this format:

[userID1, userID2, userID3]
[userID5, userID3, userID9]
[userID10, userID2, userID3, userID1]
[userID8, userID20, userID11, userID1]

If this were the input file, the output file would be:

[userID5, userID3, userID9]
[userID8, userID20, userID11, userID1]

...because the two removed sets contained three or more of the same user IDs.

Visual example (image in the original post)
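
To make the filtering rule concrete, here is the overlap check for the first and third rows (the user IDs are quoted here so the lines parse as Python literals):

import ast

row1 = set(ast.literal_eval('["userID1", "userID2", "userID3"]'))
row3 = set(ast.literal_eval('["userID10", "userID2", "userID3", "userID1"]'))

# 3 shared IDs >= filterValue (3), so both rows are removed
print(len(row1.intersection(row3)))   # 3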

This answer is not about how to split code into functions, name variables, etc. It's about a faster algorithm in terms of complexity.


I'd use a dictionary. I won't write exact code; you can finish it yourself.

Sets = dict()
for rowID, row in enumerate(Rows):
    for userID in row:
        if Sets.get(userID) is None:
            Sets[userID] = set()
        Sets[userID].add(rowID)

So now we have a dictionary that can be used to quickly obtain the row numbers of all rows containing a given userID.
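
With the four sample rows from the question (treating the user IDs as strings, which is an assumption about the data), the index would look like this:

# Hypothetical index for the four sample rows, written out literally (0-based row IDs)
Sets = {
    "userID1": {0, 2, 3}, "userID2": {0, 2}, "userID3": {0, 1, 2},
    "userID5": {1}, "userID9": {1}, "userID10": {2},
    "userID8": {3}, "userID20": {3}, "userID11": {3},
}
print(Sets["userID3"])   # {0, 1, 2}: every row containing userID3, in one lookup

The second pass below then uses this index to count, for each row, how many user IDs it shares with every other row: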

BadRows = set()
for rowID, row in enumerate(Rows):
    Intersections = dict()
    for userID in row:
        for rowID_cmp in Sets[userID]:
            if rowID_cmp != rowID:
                Intersections[rowID_cmp] = Intersections.get(rowID_cmp, 0) + 1
    # Now Intersections records how many userIDs each other row
    # (rowID_cmp) shares with the current row
    filteredOut = False
    for rowID_cmp in Intersections:
        if Intersections[rowID_cmp] >= filterValue:
            BadRows.add(rowID_cmp)
            filteredOut = True
    if filteredOut:
        BadRows.add(rowID)

Having saved the row numbers of all filtered-out rows in BadRows, we iterate one last time:

for rowID, row in enumerate(Rows):
    if rowID not in BadRows:
        # output row

This works in three passes over the data, and its cost is driven by how many rows share each userID rather than by comparing every pair of rows, so it avoids the O(n²) nested loop. Maybe you'd have to rework how Rows is iterated, because in your case it's a file, but that doesn't really change much.

I'm not sure about the Python syntax and details, but you get the idea behind the code.
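
For reference, here is a self-contained sketch of the same idea adapted to the question's file format. It is not the code above verbatim; the function and file names are illustrative, and it assumes each line parses with ast.literal_eval into a list of user IDs, as in the question:

import ast
from collections import defaultdict

def filter_rows(in_path, out_path, filter_value=3):
    # Read the file once and parse every row into a set of user IDs.
    with open(in_path) as infile:
        lines = infile.readlines()
    rows = [set(ast.literal_eval(line)) for line in lines]

    # Pass 1: build the inverted index, userID -> set of row numbers.
    index = defaultdict(set)
    for row_id, row in enumerate(rows):
        for user_id in row:
            index[user_id].add(row_id)

    # Pass 2: for each row, count shared user IDs with other rows via the index.
    bad_rows = set()
    for row_id, row in enumerate(rows):
        counts = defaultdict(int)
        for user_id in row:
            for other_id in index[user_id]:
                if other_id != row_id:
                    counts[other_id] += 1
        overlapping = [other for other, n in counts.items() if n >= filter_value]
        if overlapping:
            bad_rows.add(row_id)
            bad_rows.update(overlapping)

    # Pass 3: write out every row that was never marked bad.
    with open(out_path, "w") as outfile:
        for row_id, line in enumerate(lines):
            if row_id not in bad_rows:
                outfile.write(line)

On the four-row sample above with filter_value=3, this keeps the second and fourth rows, matching the expected output.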

First of all, please pack your code into functions that each do one thing well.

def get_data(*args):
    # get the data.
    ...

def find_intersections_sets(list1, list2):
    # do the intersections part.
    ...

def loop_over_some_result(result):
    # insert assertions so that you don't end up looping in infinity:
    assert result is not None
    ...

def myfunc(*args):
    source1, source2 = args
    L1, L2 = get_data(source1), get_data(source2)
    intersects = find_intersections_sets(L1,L2)
    ...

if __name__ == "__main__":
    myfunc()

then you can easily profile the code using:

if __name__ == "__main__":
    import cProfile
    cProfile.run('myfunc()')

which gives you invaluable insight into your code behaviour and allows you to track down logical bugs. For more on cProfile, see How can you profile a python script?
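
If the flat output is hard to read, the standard-library pstats module can sort it for you (a small usage sketch; 'profile.out' is just an illustrative file name):

import cProfile
import pstats

cProfile.run('myfunc()', 'profile.out')          # dump the stats to a file
stats = pstats.Stats('profile.out')
stats.sort_stats('cumulative').print_stats(20)   # 20 most expensive call sites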

An option for tracking down a logical flaw (we're all human, right?) is to use a timeout function in a decorator, like this (Python 2) or this (Python 3):

The code can then be changed to:

def get_data(*args):
    # get the data.
    ...

@timeout(10)  # seconds <---- the clever bit!
def find_intersections_sets(list1, list2):
    # do the intersections part.
    ...

def myfunc(*args):
    source1, source2 = args
    L1, L2 = get_data(source1), get_data(source2)
    intersects = find_intersections_sets(L1, L2)
    ...

...where the timeout operation will raise an error if it takes too long.
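
The linked recipes boil down to something like the following (a minimal sketch, not the linked code; it is Unix-only because it relies on signal.SIGALRM):

import signal
from functools import wraps

def timeout(seconds):
    """Raise TimeoutError if the decorated call runs longer than `seconds`."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            def handler(signum, frame):
                raise TimeoutError("%s timed out after %ss" % (func.__name__, seconds))
            old_handler = signal.signal(signal.SIGALRM, handler)
            signal.alarm(seconds)          # start the countdown
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)            # cancel the alarm
                signal.signal(signal.SIGALRM, old_handler)
        return wrapper
    return decorator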

Here is my best guess:

import ast 

def get_data(filename):
    with open(filename, 'r') as fi:
        data = fi.readlines()
    return data

def get_ast_set(line):
    return set(ast.literal_eval(line))

def less_than_x_in_common(set1, set2, limit=3):
    return len(set1.intersection(set2)) < limit

def check_infile(datafile, savefile, filtervalue=3):
    lines = get_data(datafile)
    sets = [get_ast_set(line) for line in lines]
    outlines = []
    for line, row in zip(lines, sets):
        # keep a row only if it shares fewer than filtervalue items
        # with every *other* row in the file
        if all(less_than_x_in_common(row, other, limit=filtervalue)
               for other in sets if other is not row):
            outlines.append(line)
    with open(savefile, 'w') as fo:
        fo.writelines(outlines)

if __name__ == "__main__":
    datafile = str(arg2 + arg1 + "/file")
    savefile = str(arg2 + arg1 + "/filteredSets" + str(filterValue))
    check_infile(datafile, savefile)
