
Implementing recursive function with __iter__ method in a Python class

I'm working on a problem in which I need to create a Python class that generates all permutations of a list, and I've run into the following questions:

  1. I can do this easily with a simple recursive function, but as a class it seems like I should use the __iter__ method. My __iter__ calls a recursive function (list_all) that is almost identical to __iter__ itself, which is very unsettling. How do I modify my recursive function to follow best practices for __iter__?
  2. I wrote this code and saw that it worked, but I feel like I don't understand it. I tried to trace the code line by line on a test case, and to me it looks like the first element in the list is frozen each time while the rest of the list is shuffled. Instead, the output comes out in an unexpected order. I'm not understanding something!

Thanks!

class permutations():
  def __init__(self, ls):
    self.list = ls

  def __iter__(self):
    ls = self.list
    length = len(ls)
    if length <= 1:
      yield ls
    else:
      for p in self.list_all(ls[1:]):
        for x in range(length):
          yield p[:x] + ls[0:1] + p[x:]  

  def list_all(self, ls):
    length = len(ls)
    if length <= 1:
      yield ls
    else:
      for p in self.list_all(ls[1:]):
        for x in range(length):
          yield p[:x] + ls[0:1] + p[x:]
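To see where the "unexpected" order in question 2 comes from, here is the same insertion algorithm pulled out into a plain module-level generator (a sketch; the function name list_all just mirrors the method in the question). The key point is that ls[0] is not the frozen element: it is the one that moves. The recursion first produces each permutation p of the tail ls[1:], then slides ls[0] through all `length` insertion slots of p, so each tail permutation stays fixed for a run of `length` outputs.

```python
def list_all(ls):
    """Generate permutations of ls using the same insertion scheme as the question."""
    length = len(ls)
    if length <= 1:
        # Base case: an empty or one-element list is its own only permutation.
        yield ls
    else:
        # Recursively permute the tail ls[1:] ...
        for p in list_all(ls[1:]):
            # ... then insert the head ls[0] at each position 0..length-1 of p.
            for x in range(length):
                yield p[:x] + ls[0:1] + p[x:]


for perm in list_all([1, 2, 3]):
    print(perm)
# [1, 2, 3]
# [2, 1, 3]
# [2, 3, 1]
# [1, 3, 2]
# [3, 1, 2]
# [3, 2, 1]
```

The tail [2, 3] yields [2, 3] and then [3, 2]. For each of those, 1 is placed at the front, in the middle, and at the end, which gives exactly the order above.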

Just call self.list_all from __iter__:

class permutations():
  def __init__(self, ls):
    self.list = ls

  def __iter__(self):
    for item in self.list_all(self.list):
      yield item

  def list_all(self, ls):
    length = len(ls)
    if length <= 1:
      yield ls
    else:
      for p in self.list_all(ls[1:]):
        for x in range(length):
          yield p[:x] + ls[0:1] + p[x:]
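On Python 3.3 and later, the explicit `for item in ...: yield item` loop can be written with `yield from` (PEP 380), which delegates to the sub-generator directly. This variant is a sketch, not part of the original answer, and for plain iteration it should behave the same:

```python
class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        # For plain iteration, equivalent to "for item in ...: yield item".
        yield from self.list_all(self.list)

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            # Permute the tail, then insert the head at every position.
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]


print(list(permutations(['a', 'b'])))  # [['a', 'b'], ['b', 'a']]
```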

Your list_all method is already a generator, so you can return the generator it creates directly from __iter__:

class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        return self.list_all(self.list)

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]

This is both cleaner to read and slightly faster, because it skips the extra generator layer that re-yields every item.
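One consequence of returning a fresh generator from __iter__ is that a permutations object is re-iterable: every for loop (or list() call) invokes __iter__ again and gets a new generator, whereas a bare generator is exhausted after one pass. A short sketch using the class from this answer:

```python
class permutations():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        # A new generator object is created on every call.
        return self.list_all(self.list)

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]


perms = permutations([1, 2, 3])
first = list(perms)
second = list(perms)            # __iter__ runs again, so this is not empty
print(len(first), len(second))  # 6 6

gen = perms.list_all([1, 2, 3])         # a bare generator, by contrast
print(len(list(gen)), len(list(gen)))   # 6 0
```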

Another option is to define list_all inside __iter__.

class permutations2():
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        def list_all(ls):
            length = len(ls)
            if length <= 1:
                yield ls
            else:
                for p in list_all(ls[1:]):
                    for x in range(length):
                        yield p[:x] + ls[0:1] + p[x:]
                    
        return list_all(self.list)

Timing permutations against my permutations2 gives almost identical run times.
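If you want to reproduce the timing comparison, a sketch with the standard timeit module might look like the following. The classes are condensed copies of the versions above (Wrapped re-yields each item, Direct returns the generator, Nested defines the helper inside __iter__); these class names, the input size, and the repeat count are arbitrary choices for illustration, and the numbers you get will depend on your machine and Python version. It also checks the output against the standard library's itertools.permutations, which yields the same permutations in a different order.

```python
import itertools
import timeit


class Wrapped:
    """Like the first answer: __iter__ re-yields from list_all."""
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        for item in self.list_all(self.list):
            yield item

    def list_all(self, ls):
        length = len(ls)
        if length <= 1:
            yield ls
        else:
            for p in self.list_all(ls[1:]):
                for x in range(length):
                    yield p[:x] + ls[0:1] + p[x:]


class Direct(Wrapped):
    """Like the second answer: __iter__ returns the generator itself."""
    def __iter__(self):
        return self.list_all(self.list)


class Nested:
    """Like permutations2: list_all is a closure defined inside __iter__."""
    def __init__(self, ls):
        self.list = ls

    def __iter__(self):
        def list_all(ls):
            length = len(ls)
            if length <= 1:
                yield ls
            else:
                for p in list_all(ls[1:]):
                    for x in range(length):
                        yield p[:x] + ls[0:1] + p[x:]
        return list_all(self.list)


data = list(range(7))  # 7! = 5040 permutations per run

# Sanity check: all three classes agree with itertools (order aside).
expected = sorted(itertools.permutations(data))
for cls in (Wrapped, Direct, Nested):
    assert sorted(tuple(p) for p in cls(data)) == expected

for cls in (Wrapped, Direct, Nested):
    seconds = timeit.timeit(lambda: list(cls(data)), number=20)
    print(f"{cls.__name__:8} {seconds:.3f} s")
```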
