itertools
--- 為高效迴圈建立迭代器的函式¶
此模組實現了一些 迭代器 構建塊,其靈感來自於 APL、Haskell 和 SML 中的構想。每個構建塊都經過重構,以使其適用於 Python。
該模組將一組快速、高效利用記憶體的工具標準化,這些工具本身或組合起來都很有用。它們共同構成了一個“迭代器代數”,使得在純 Python 中能夠簡潔高效地構建專用工具。
例如,SML 提供了一個製表工具:tabulate(f)
,它可以生成序列 f(0), f(1), ...
。在 Python 中,可以透過組合 map()
和 count()
形成 map(f, count())
來實現相同的效果。
無限迭代器
迭代器 |
引數 |
結果 |
示例 |
---|---|---|---|
[start[, step]] |
start, start+step, start+2*step, … |
|
|
p |
p0, p1, … plast, p0, p1, … |
|
|
elem [,n] |
elem, elem, elem, … 無限重複或重複 n 次 |
|
在最短輸入序列上終止的迭代器
迭代器 |
引數 |
結果 |
示例 |
---|---|---|---|
p [,func] |
p0, p0+p1, p0+p1+p2, … |
|
|
p, n |
(p0, p1, …, p_n-1), … |
|
|
p, q, … |
p0, p1, … plast, q0, q1, … |
|
|
可迭代物件 |
p0, p1, … plast, q0, q1, … |
|
|
data, selectors |
(d[0] if s[0]), (d[1] if s[1]), … |
|
|
predicate, seq |
seq[n], seq[n+1], 從 predicate 為假時開始 |
|
|
predicate, seq |
seq 中 predicate(elem) 為假的元素 |
|
|
iterable[, key] |
按 key(v) 值分組的子迭代器 |
|
|
seq, [start,] stop [, step] |
來自 seq[start:stop:step] 的元素 |
|
|
可迭代物件 |
(p[0], p[1]), (p[1], p[2]) |
|
|
func, seq |
func(*seq[0]), func(*seq[1]), … |
|
|
predicate, seq |
seq[0], seq[1], 直到 predicate 為假 |
|
|
it, n |
it1, it2, … itn 將一個迭代器拆分為 n 個 |
|
|
p, q, … |
(p[0], q[0]), (p[1], q[1]), … |
|
組合迭代器
迭代器 |
引數 |
結果 |
---|---|---|
p, q, … [repeat=1] |
笛卡爾積,相當於巢狀的 for 迴圈 |
|
p[, r] |
長度為 r 的元組,所有可能的排序,無重複元素 |
|
p, r |
長度為 r 的元組,按排序順序,無重複元素 |
|
p, r |
長度為 r 的元組,按排序順序,有重複元素 |
示例: |
結果 |
---|---|
|
|
|
|
|
|
|
|
Itertool 函式¶
以下所有函式都用於構建並返回迭代器。有些提供無限長度的資料流,因此它們只應被那些會截斷資料流的函式或迴圈訪問。
- itertools.accumulate(iterable[, function, *, initial=None])¶
建立一個迭代器,返回累加的和或來自其他二元函式的累加結果。
function 預設為加法。function 應接受兩個引數:一個累加的總值和一個來自 iterable 的值。
如果提供了 initial 值,累加將從該值開始,並且輸出將比輸入的可迭代物件多一個元素。
大致相當於:
def accumulate(iterable, function=operator.add, *, initial=None): 'Return running totals' # accumulate([1,2,3,4,5]) → 1 3 6 10 15 # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115 # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120 iterator = iter(iterable) total = initial if initial is None: try: total = next(iterator) except StopIteration: return yield total for element in iterator: total = function(total, element) yield total
要計算一個執行中的最小值,請將 function 設定為
min()
。對於執行中的最大值,請將 function 設定為max()
。或者對於執行中的乘積,請將 function 設定為operator.mul()
。要構建攤銷表,可累加利息並應用付款。>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8] >>> list(accumulate(data, max)) # running maximum [3, 4, 6, 6, 6, 9, 9, 9, 9, 9] >>> list(accumulate(data, operator.mul)) # running product [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0] # Amortize a 5% loan of 1000 with 10 annual payments of 90 >>> update = lambda balance, payment: round(balance * 1.05) - payment >>> list(accumulate(repeat(90, 10), update, initial=1_000)) [1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]
另請參閱
functools.reduce()
,它是一個類似函式,只返回最終的累加值。在 3.2 版本加入。
在 3.3 版本發生變更: 添加了可選的 function 形參。
在 3.8 版本發生變更: 添加了可選的 initial 形參。
- itertools.batched(iterable, n, *, strict=False)¶
將來自 iterable 的資料分批成長度為 n 的元組。最後一批的長度可能小於 n。
如果 strict 為真,且最後一批的長度小於 n,將引發
ValueError
。遍歷輸入的可迭代物件,並將資料累積成大小至多為 n 的元組。輸入是被惰性消耗的,僅消耗足以填充一批的數量。一旦批次滿員或輸入的可迭代物件耗盡,就會產生結果。
>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet'] >>> unflattened = list(batched(flattened_data, 2)) >>> unflattened [('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]
大致相當於:
def batched(iterable, n, *, strict=False): # batched('ABCDEFG', 2) → AB CD EF G if n < 1: raise ValueError('n must be at least one') iterator = iter(iterable) while batch := tuple(islice(iterator, n)): if strict and len(batch) != n: raise ValueError('batched(): incomplete batch') yield batch
3.12 新版功能.
在 3.13 版本發生變更: 添加了 strict 選項。
- itertools.chain(*iterables)¶
建立一個迭代器,它從第一個可迭代物件中返回元素,直到耗盡,然後繼續到下一個可迭代物件,直到所有的可迭代物件都被耗盡。這將多個數據源組合成一個單一的迭代器。大致等同於:
def chain(*iterables): # chain('ABC', 'DEF') → A B C D E F for iterable in iterables: yield from iterable
- classmethod chain.from_iterable(iterable)¶
chain()
的備用建構函式。從一個惰性求值的單一可迭代物件引數中獲取鏈式輸入。大致等同於:def from_iterable(iterables): # chain.from_iterable(['ABC', 'DEF']) → A B C D E F for iterable in iterables: yield from iterable
- itertools.combinations(iterable, r)¶
返回輸入 iterable 中元素的長度為 r 的子序列。
輸出是
product()
的一個子序列,只保留那些是 iterable 的子序列的條目。輸出的長度由math.comb()
給出,當0 ≤ r ≤ n
時計算n! / r! / (n - r)!
,當r > n
時為零。組合元組按輸入 iterable 的順序以字典序發出。如果輸入 iterable 是排序的,則輸出的元組將按排序順序生成。
元素根據其位置而不是其值被視為唯一的。如果輸入元素是唯一的,則每個組合中將不會有重複的值。
大致相當於:
def combinations(iterable, r): # combinations('ABCD', 2) → AB AC AD BC BD CD # combinations(range(4), 3) → 012 013 023 123 pool = tuple(iterable) n = len(pool) if r > n: return indices = list(range(r)) yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != i + n - r: break else: return indices[i] += 1 for j in range(i+1, r): indices[j] = indices[j-1] + 1 yield tuple(pool[i] for i in indices)
- itertools.combinations_with_replacement(iterable, r)¶
返回輸入 iterable 中元素的長度為 r 的子序列,允許單個元素重複多次。
輸出是
product()
的一個子序列,只保留那些是 iterable 的子序列(可能帶有重複元素)的條目。當n > 0
時,返回的子序列數量為(n + r - 1)! / r! / (n - 1)!
。組合元組按輸入 iterable 的順序以字典序發出。如果輸入 iterable 是排序的,則輸出的元組將按排序順序生成。
元素根據其位置而不是其值被視為唯一的。如果輸入元素是唯一的,則生成的組合也將是唯一的。
大致相當於:
def combinations_with_replacement(iterable, r): # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC pool = tuple(iterable) n = len(pool) if not n and r: return indices = [0] * r yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != n - 1: break else: return indices[i:] = [indices[i] + 1] * (r - i) yield tuple(pool[i] for i in indices)
在 3.1 版本加入。
- itertools.compress(data, selectors)¶
建立一個迭代器,它返回 data 中那些在 selectors 中相應元素為真的元素。當 data 或 selectors 的任一可迭代物件被耗盡時停止。大致等同於:
def compress(data, selectors): # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F return (datum for datum, selector in zip(data, selectors) if selector)
在 3.1 版本加入。
- itertools.count(start=0, step=1)¶
建立一個迭代器,返回從 start 開始的等間距值。可與
map()
一起生成連續的資料點,或與zip()
一起新增序列號。大致等同於:def count(start=0, step=1): # count(10) → 10 11 12 13 14 ... # count(2.5, 0.5) → 2.5 3.0 3.5 ... n = start while True: yield n n += step
當使用浮點數計數時,透過使用乘法程式碼,有時可以獲得更好的精度,例如:
(start + step * i for i in count())
。在 3.1 版本發生變更: 增加了 step 引數並允許非整數引數。
- itertools.cycle(iterable)¶
建立一個迭代器,返回來自 iterable 的元素,並儲存每個元素的副本。當可迭代物件被耗盡時,從儲存的副本中返回元素。無限重複。大致等同於:
def cycle(iterable): # cycle('ABCD') → A B C D A B C D A B C D ... saved = [] for element in iterable: yield element saved.append(element) while saved: for element in saved: yield element
此迭代工具可能需要大量的輔助儲存空間(取決於可迭代物件的長度)。
- itertools.dropwhile(predicate, iterable)¶
建立一個迭代器,只要 predicate 為真,就從 iterable 中丟棄元素,之後返回每個元素。大致等同於:
def dropwhile(predicate, iterable): # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8 iterator = iter(iterable) for x in iterator: if not predicate(x): yield x break for x in iterator: yield x
請注意,在斷言首次變為假之前,它不會產生*任何*輸出,因此這個迭代工具可能會有很長的啟動時間。
- itertools.filterfalse(predicate, iterable)¶
建立一個迭代器,它從 iterable 中過濾元素,只返回那些 predicate 返回假值的元素。如果 predicate 為
None
,則返回值為假的專案。大致等同於:def filterfalse(predicate, iterable): # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8 if predicate is None: predicate = bool for x in iterable: if not predicate(x): yield x
- itertools.groupby(iterable, key=None)¶
建立一個迭代器,返回來自 iterable 的連續鍵和組。key 是一個計算每個元素的鍵值的函式。如果未指定或為
None
,key 預設為一個恆等函式並返回元素本身。通常,可迭代物件需要已經按相同的鍵函式排序。groupby()
的操作類似於 Unix 中的uniq
過濾器。每當鍵函式的值發生變化時,它就會生成一箇中斷或新組(這就是為什麼通常需要使用相同的鍵函式對資料進行排序的原因)。這種行為不同於 SQL 的 GROUP BY,後者聚合公共元素,而不考慮它們的輸入順序。返回的組本身是一個迭代器,它與
groupby()
共享底層的可迭代物件。因為源是共享的,當groupby()
物件前進時,前一個組就不再可見。因此,如果以後需要這些資料,應該將其儲存為一個列表。groups = [] uniquekeys = [] data = sorted(data, key=keyfunc) for k, g in groupby(data, keyfunc): groups.append(list(g)) # Store group iterator as a list uniquekeys.append(k)
groupby()
大致等同於:def groupby(iterable, key=None): # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D keyfunc = (lambda x: x) if key is None else key iterator = iter(iterable) exhausted = False def _grouper(target_key): nonlocal curr_value, curr_key, exhausted yield curr_value for curr_value in iterator: curr_key = keyfunc(curr_value) if curr_key != target_key: return yield curr_value exhausted = True try: curr_value = next(iterator) except StopIteration: return curr_key = keyfunc(curr_value) while not exhausted: target_key = curr_key curr_group = _grouper(target_key) yield curr_key, curr_group if curr_key == target_key: for _ in curr_group: pass
- itertools.islice(iterable, stop)¶
- itertools.islice(iterable, start, stop[, step])
建立一個迭代器,返回可迭代物件中的選定元素。工作方式類似於序列切片,但不支援 start、stop 或 step 的負值。
如果 start 為零或
None
,迭代從零開始。否則,跳過可迭代物件中的元素,直到達到 start。如果 stop 是
None
,迭代將持續進行直到輸入耗盡(如果會耗盡的話)。 否則,它會在指定位置停止。如果 step 是
None
,則步長預設為 1。元素會連續地被返回,除非 step 設定為大於 1 的值,這會導致一些條目被跳過。大致相當於:
def islice(iterable, *args): # islice('ABCDEFG', 2) → A B # islice('ABCDEFG', 2, 4) → C D # islice('ABCDEFG', 2, None) → C D E F G # islice('ABCDEFG', 0, None, 2) → A C E G s = slice(*args) start = 0 if s.start is None else s.start stop = s.stop step = 1 if s.step is None else s.step if start < 0 or (stop is not None and stop < 0) or step <= 0: raise ValueError indices = count() if stop is None else range(max(start, stop)) next_i = start for i, element in zip(indices, iterable): if i == next_i: yield element next_i += step
如果輸入是一個迭代器,那麼完全消耗 islice 會使輸入迭代器前進
max(start, stop)
步,而與 step 的值無關。
- itertools.pairwise(iterable)¶
返回從輸入 iterable 中獲取的連續重疊對。
輸出迭代器中的二元組數量將比輸入數量少一。如果輸入的可迭代物件的值少於兩個,它將為空。
大致相當於:
def pairwise(iterable): # pairwise('ABCDEFG') → AB BC CD DE EF FG iterator = iter(iterable) a = next(iterator, None) for b in iterator: yield a, b a = b
在 3.10 版本加入。
- itertools.permutations(iterable, r=None)¶
返回來自 iterable 的元素的連續 r 長度排列。
如果 r 未指定或為
None
,則 r 預設為 iterable 的長度,並生成所有可能的全長排列。輸出是
product()
的一個子序列,其中已過濾掉有重複元素的條目。輸出的長度由math.perm()
給出,當0 ≤ r ≤ n
時計算n! / (n - r)!
,當r > n
時為零。排列元組按輸入 iterable 的順序以字典序發出。如果輸入 iterable 是排序的,則輸出的元組將按排序順序生成。
元素根據其位置而不是其值被視為唯一的。如果輸入元素是唯一的,則排列中不會有重複的值。
大致相當於:
def permutations(iterable, r=None): # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC # permutations(range(3)) → 012 021 102 120 201 210 pool = tuple(iterable) n = len(pool) r = n if r is None else r if r > n: return indices = list(range(n)) cycles = list(range(n, n-r, -1)) yield tuple(pool[i] for i in indices[:r]) while n: for i in reversed(range(r)): cycles[i] -= 1 if cycles[i] == 0: indices[i:] = indices[i+1:] + indices[i:i+1] cycles[i] = n - i else: j = cycles[i] indices[i], indices[-j] = indices[-j], indices[i] yield tuple(pool[i] for i in indices[:r]) break else: return
- itertools.product(*iterables, repeat=1)¶
輸入可迭代物件的笛卡爾積。
大致相當於生成器表示式中的巢狀 for 迴圈。例如,
product(A, B)
返回的結果與((x,y) for x in A for y in B)
相同。巢狀的迴圈像里程錶一樣迴圈,最右邊的元素在每次迭代時都會前進。這種模式建立了一個字典序的排序,因此如果輸入的迭代器是排序的,那麼乘積元組將以排序的順序發出。
要計算一個可迭代物件與自身的乘積,請使用可選的 repeat 關鍵字引數指定重複次數。例如,
product(A, repeat=4)
與product(A, A, A, A)
的意思相同。此函式大致等同於以下程式碼,但實際實現不會在記憶體中構建中間結果:
def product(*iterables, repeat=1): # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111 if repeat < 0: raise ValueError('repeat argument cannot be negative') pools = [tuple(pool) for pool in iterables] * repeat result = [[]] for pool in pools: result = [x+[y] for x in result for y in pool] for prod in result: yield tuple(prod)
在
product()
執行之前,它會完全消耗輸入的迭代器,將值池儲存在記憶體中以生成乘積。因此,它只適用於有限的輸入。
- itertools.repeat(object[, times])¶
建立一個迭代器,它一次又一次地返回 object。除非指定了 times 引數,否則它會無限期執行。
大致相當於:
def repeat(object, times=None): # repeat(10, 3) → 10 10 10 if times is None: while True: yield object else: for i in range(times): yield object
repeat 的一個常見用途是向 map 或 zip 提供一個常量值流:
>>> list(map(pow, range(10), repeat(2))) [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
- itertools.starmap(function, iterable)¶
建立一個迭代器,使用從 iterable 獲得的引數來計算 function。當引數已經“預先打包”成元組時,使用它來代替
map()
。map()
和starmap()
之間的區別與function(a,b)
和function(*c)
之間的區別相類似。大致等同於:def starmap(function, iterable): # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000 for args in iterable: yield function(*args)
- itertools.takewhile(predicate, iterable)¶
建立一個迭代器,只要 predicate 為真,就從 iterable 中返回元素。大致等同於:
def takewhile(predicate, iterable): # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4 for x in iterable: if not predicate(x): break yield x
注意,第一個不滿足謂詞條件的元素會從輸入迭代器中消耗掉,並且無法訪問它。如果應用程式希望在 takewhile 執行完畢後繼續消耗輸入迭代器,這可能會成為一個問題。為了解決這個問題,可以考慮使用 more-itertools before_and_after()。
- itertools.tee(iterable, n=2)¶
從單個可迭代物件返回 n 個獨立的迭代器。
大致相當於:
def tee(iterable, n=2): if n < 0: raise ValueError if n == 0: return () iterator = _tee(iterable) result = [iterator] for _ in range(n - 1): result.append(_tee(iterator)) return tuple(result) class _tee: def __init__(self, iterable): it = iter(iterable) if isinstance(it, _tee): self.iterator = it.iterator self.link = it.link else: self.iterator = it self.link = [None, None] def __iter__(self): return self def __next__(self): link = self.link if link[1] is None: link[0] = next(self.iterator) link[1] = [None, None] value, self.link = link return value
當輸入 iterable 已經是一個 tee 迭代器物件時,返回的元組的所有成員的構造方式就好像它們是由上游的
tee()
呼叫產生的一樣。這種“扁平化步驟”允許巢狀的tee()
呼叫共享相同的底層資料鏈,並且只有一個更新步驟而不是一連串的呼叫。扁平化特性使得 tee 迭代器可以高效地進行窺視:
def lookahead(tee_iterator): "Return the next value without moving the input forward" [forked_iterator] = tee(tee_iterator, 1) return next(forked_iterator)
>>> iterator = iter('abcdef') >>> [iterator] = tee(iterator, 1) # Make the input peekable >>> next(iterator) # Move the iterator forward 'a' >>> lookahead(iterator) # Check next value 'b' >>> next(iterator) # Continue moving forward 'b'
tee
迭代器不是執行緒安全的。當同時使用由同一個tee()
呼叫返回的迭代器時,可能會引發RuntimeError
,即使原始的 iterable 是執行緒安全的。這個迭代工具可能需要大量的輔助儲存空間(取決於需要儲存多少臨時資料)。一般來說,如果一個迭代器在另一個迭代器開始之前使用了大部分或全部資料,使用
list()
會比使用tee()
更快。
- itertools.zip_longest(*iterables, fillvalue=None)¶
建立一個迭代器,聚合來自每個 iterables 的元素。
如果可迭代物件的長度不均勻,缺失的值將用 fillvalue 填充。如果未指定,fillvalue 預設為
None
。迭代將持續到最長的可迭代物件被耗盡。
大致相當於:
def zip_longest(*iterables, fillvalue=None): # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D- iterators = list(map(iter, iterables)) num_active = len(iterators) if not num_active: return while True: values = [] for i, iterator in enumerate(iterators): try: value = next(iterator) except StopIteration: num_active -= 1 if not num_active: return iterators[i] = repeat(fillvalue) value = fillvalue values.append(value) yield tuple(values)
如果其中一個可迭代物件可能是無限的,那麼
zip_longest()
函式應該被包裝在限制呼叫次數的東西中(例如islice()
或takewhile()
)。
Itertools 配方¶
本節展示了使用現有的 itertools 作為構建塊來建立擴充套件工具集的配方。
itertools 配方的主要目的是教育。這些配方展示了思考單個工具的各種方式——例如,chain.from_iterable
與扁平化的概念有關。這些配方也提供了關於工具如何組合的想法——例如,starmap()
和 repeat()
如何協同工作。這些配方還展示了將 itertools 與 operator
和 collections
模組以及內建的 itertools(如 map()
、filter()
、reversed()
和 enumerate()
)一起使用的模式。
配方的次要目的是作為孵化器。 accumulate()
, compress()
, and pairwise()
等 itertools 工具最初都是以配方的形式出現的。 當前,sliding_window()
, iter_index()
和 sieve()
等配方正在測試中,以檢驗它們是否能證明其價值。
基本上所有這些配方以及許多許多其他的配方都可以從 Python 包索引上的 more-itertools 專案安裝。
python -m pip install more-itertools
許多配方提供了與底層工具集相同的高效能。透過一次處理一個元素而不是將整個可迭代物件一次性載入到記憶體中,保持了卓越的記憶體效能。透過以函式式風格連結工具,程式碼量得以保持較小。透過優先選擇“向量化”構建塊,而不是使用會產生直譯器開銷的 for 迴圈和生成器,從而保持了高速。
from collections import Counter, deque
from contextlib import suppress
from functools import reduce
from math import comb, prod, sumprod, isqrt
from operator import itemgetter, getitem, mul, neg
def take(n, iterable):
"Return first n items of the iterable as a list."
return list(islice(iterable, n))
def prepend(value, iterable):
"Prepend a single value in front of an iterable."
# prepend(1, [2, 3, 4]) → 1 2 3 4
return chain([value], iterable)
def tabulate(function, start=0):
"Return function(0), function(1), ..."
return map(function, count(start))
def repeatfunc(function, times=None, *args):
"Repeat calls to a function with specified arguments."
if times is None:
return starmap(function, repeat(args))
return starmap(function, repeat(args, times))
def flatten(list_of_lists):
"Flatten one level of nesting."
return chain.from_iterable(list_of_lists)
def ncycles(iterable, n):
"Returns the sequence elements n times."
return chain.from_iterable(repeat(tuple(iterable), n))
def loops(n):
"Loop n times. Like range(n) but without creating integers."
# for _ in loops(100): ...
return repeat(None, n)
def tail(n, iterable):
"Return an iterator over the last n items."
# tail(3, 'ABCDEFG') → E F G
return iter(deque(iterable, maxlen=n))
def consume(iterator, n=None):
"Advance the iterator n-steps ahead. If n is None, consume entirely."
# Use functions that consume iterators at C speed.
if n is None:
deque(iterator, maxlen=0)
else:
next(islice(iterator, n, n), None)
def nth(iterable, n, default=None):
"Returns the nth item or a default value."
return next(islice(iterable, n, None), default)
def quantify(iterable, predicate=bool):
"Given a predicate that returns True or False, count the True results."
return sum(map(predicate, iterable))
def first_true(iterable, default=False, predicate=None):
"Returns the first true value or the *default* if there is no true value."
# first_true([a,b,c], x) → a or b or c or x
# first_true([a,b], x, f) → a if f(a) else b if f(b) else x
return next(filter(predicate, iterable), default)
def all_equal(iterable, key=None):
"Returns True if all the elements are equal to each other."
# all_equal('4٤௪౪໔', key=int) → True
return len(take(2, groupby(iterable, key))) <= 1
def unique_justseen(iterable, key=None):
"Yield unique elements, preserving order. Remember only the element just seen."
# unique_justseen('AAAABBBCCDAABBB') → A B C D A B
# unique_justseen('ABBcCAD', str.casefold) → A B c A D
if key is None:
return map(itemgetter(0), groupby(iterable))
return map(next, map(itemgetter(1), groupby(iterable, key)))
def unique_everseen(iterable, key=None):
"Yield unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') → A B C D
# unique_everseen('ABBcCAD', str.casefold) → A B c D
seen = set()
if key is None:
for element in filterfalse(seen.__contains__, iterable):
seen.add(element)
yield element
else:
for element in iterable:
k = key(element)
if k not in seen:
seen.add(k)
yield element
def unique(iterable, key=None, reverse=False):
"Yield unique elements in sorted order. Supports unhashable inputs."
# unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
sequenced = sorted(iterable, key=key, reverse=reverse)
return unique_justseen(sequenced, key=key)
def sliding_window(iterable, n):
"Collect data into overlapping fixed-length chunks or blocks."
# sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
iterator = iter(iterable)
window = deque(islice(iterator, n - 1), maxlen=n)
for x in iterator:
window.append(x)
yield tuple(window)
def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
"Collect data into non-overlapping fixed-length chunks or blocks."
# grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
# grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
# grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
iterators = [iter(iterable)] * n
match incomplete:
case 'fill':
return zip_longest(*iterators, fillvalue=fillvalue)
case 'strict':
return zip(*iterators, strict=True)
case 'ignore':
return zip(*iterators)
case _:
raise ValueError('Expected fill, strict, or ignore')
def roundrobin(*iterables):
"Visit input iterables in a cycle until each is exhausted."
# roundrobin('ABC', 'D', 'EF') → A D E B F C
# Algorithm credited to George Sakkis
iterators = map(iter, iterables)
for num_active in range(len(iterables), 0, -1):
iterators = cycle(islice(iterators, num_active))
yield from map(next, iterators)
def subslices(seq):
"Return all contiguous non-empty subslices of a sequence."
# subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
return map(getitem, repeat(seq), slices)
def iter_index(iterable, value, start=0, stop=None):
"Return indices where a value occurs in a sequence or iterable."
# iter_index('AABCADEAF', 'A') → 0 1 4 7
seq_index = getattr(iterable, 'index', None)
if seq_index is None:
iterator = islice(iterable, start, stop)
for i, element in enumerate(iterator, start):
if element is value or element == value:
yield i
else:
stop = len(iterable) if stop is None else stop
i = start
with suppress(ValueError):
while True:
yield (i := seq_index(value, i, stop))
i += 1
def iter_except(function, exception, first=None):
"Convert a call-until-exception interface to an iterator interface."
# iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
with suppress(exception):
if first is not None:
yield first()
while True:
yield function()
以下配方更具數學風格:
def multinomial(*counts):
"Number of distinct arrangements of a multiset."
# Counter('abracadabra').values() → 5 2 2 1 1
# multinomial(5, 2, 2, 1, 1) → 83160
return prod(map(comb, accumulate(counts), counts))
def powerset(iterable):
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
def sum_of_squares(iterable):
"Add up the squares of the input values."
# sum_of_squares([10, 20, 30]) → 1400
return sumprod(*tee(iterable))
def reshape(matrix, columns):
"Reshape a 2-D matrix to have a given number of columns."
# reshape([(0, 1), (2, 3), (4, 5)], 3) → (0, 1, 2), (3, 4, 5)
return batched(chain.from_iterable(matrix), columns, strict=True)
def transpose(matrix):
"Swap the rows and columns of a 2-D matrix."
# transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
return zip(*matrix, strict=True)
def matmul(m1, m2):
"Multiply two matrices."
# matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
n = len(m2[0])
return batched(starmap(sumprod, product(m1, transpose(m2))), n)
def convolve(signal, kernel):
"""Discrete linear convolution of two iterables.
Equivalent to polynomial multiplication.
Convolutions are mathematically commutative; however, the inputs are
evaluated differently. The signal is consumed lazily and can be
infinite. The kernel is fully consumed before the calculations begin.
Article: https://betterexplained.com/articles/intuitive-convolution/
Video: https://www.youtube.com/watch?v=KuXjwB4LzSA
"""
# convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
# convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
# convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
# convolve(data, [1, -2, 1]) → 2nd derivative estimate
kernel = tuple(kernel)[::-1]
n = len(kernel)
padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
windowed_signal = sliding_window(padded_signal, n)
return map(sumprod, repeat(kernel), windowed_signal)
def polynomial_from_roots(roots):
"""Compute a polynomial's coefficients from its roots.
(x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60
"""
# polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
factors = zip(repeat(1), map(neg, roots))
return list(reduce(convolve, factors, [1]))
def polynomial_eval(coefficients, x):
"""Evaluate a polynomial at a specific value.
Computes with better numeric stability than Horner's method.
"""
# Evaluate x³ -4x² -17x + 60 at x = 5
# polynomial_eval([1, -4, -17, 60], x=5) → 0
n = len(coefficients)
if not n:
return type(x)(0)
powers = map(pow, repeat(x), reversed(range(n)))
return sumprod(coefficients, powers)
def polynomial_derivative(coefficients):
"""Compute the first derivative of a polynomial.
f(x) = x³ -4x² -17x + 60
f'(x) = 3x² -8x -17
"""
# polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
n = len(coefficients)
powers = reversed(range(1, n))
return list(map(mul, coefficients, powers))
def sieve(n):
"Primes less than n."
# sieve(30) → 2 3 5 7 11 13 17 19 23 29
if n > 2:
yield 2
data = bytearray((0, 1)) * (n // 2)
for p in iter_index(data, 1, start=3, stop=isqrt(n) + 1):
data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
yield from iter_index(data, 1, start=3)
def factor(n):
"Prime factors of n."
# factor(99) → 3 3 11
# factor(1_000_000_000_000_007) → 47 59 360620266859
# factor(1_000_000_000_000_403) → 1000000000000403
for prime in sieve(isqrt(n) + 1):
while not n % prime:
yield prime
n //= prime
if n == 1:
return
if n > 1:
yield n
def is_prime(n):
"Return True if n is prime."
# is_prime(1_000_000_000_000_403) → True
return n > 1 and next(factor(n)) == n
def totient(n):
"Count of natural numbers up to n that are coprime to n."
# https://mathworld.wolfram.com/TotientFunction.html
# totient(12) → 4 because len([1, 5, 7, 11]) == 4
for prime in set(factor(n)):
n -= n // prime
return n