【Python源码阅读】list迭代器
Python源代码GitHub地址:https://github.com/python/cpython
在Python Shell中输入以下命令:
1
2
3
4
>>> a = [1, 2, 3]
>>> it = iter(a)
>>> it.__class__
<class 'list_iterator'>
可以看到list
的迭代器的类型为list_iterator
。
“list_iterator
类”的定义
list_iterator
实际上是collections.abc
模块定义的一个别名:
1
list_iterator = type(iter([]))
list.__iter__()
方法源代码
内置函数iter(obj)
等价于obj.__iter__()
,因此尝试寻找list
类的__iter__()
方法源代码。list
类的源代码在CPython的Objects/listobject.c文件中:
1
2
3
4
5
6
7
8
9
PyTypeObject PyList_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"list",
sizeof(PyListObject),
...
list_iter, /* tp_iter */
0, /* tp_iternext */
...
};
其中list_iter
为构造其迭代器的函数(对应__iter__()
方法),其定义为:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
static PyObject *
list_iter(PyObject *seq)
{
listiterobject *it;
if (!PyList_Check(seq)) {
PyErr_BadInternalCall();
return NULL;
}
it = PyObject_GC_New(listiterobject, &PyListIter_Type);
if (it == NULL)
return NULL;
it->it_index = 0;
Py_INCREF(seq);
it->it_seq = (PyListObject *)seq;
_PyObject_GC_TRACK(it);
return (PyObject *)it;
}
该函数返回了一个listiterobject
类型的指针it
。listiterobject
类型的定义:
1
2
3
4
5
typedef struct {
PyObject_HEAD
Py_ssize_t it_index;
PyListObject *it_seq; /* Set to NULL when iterator is exhausted */
} listiterobject;
可以看到listiterobject
保存了一个指向数据域的指针it_seq
和一个索引值it_index
(基地址+偏移量),使用iter(lst)
或lst.__iter__()
构造list
的迭代器时直接将list
对象的数据域指针seq
赋给了迭代器的it_seq
:it->it_seq = (PyListObject *)seq;
list_iterator
的底层类型及其__next__()
方法
list_iterator
对应的底层实际类型为PyListIter_Type
(相当于listiterobject
的包装类型):
1
2
3
4
5
6
7
8
9
10
PyTypeObject PyListIter_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
"list_iterator", /* tp_name */
sizeof(listiterobject), /* tp_basicsize */
...
PyObject_SelfIter, /* tp_iter */
(iternextfunc)listiter_next, /* tp_iternext */
listiter_methods, /* tp_methods */
0, /* tp_members */
};
PyObject_SelfIter
和listiter_next
两个函数即分别为迭代器的__iter__()
和__next__()
方法的底层函数。迭代器的__iter__()
方法总是返回其自身,__next__()
方法的底层函数listiter_next
的定义:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
static PyObject *
listiter_next(listiterobject *it)
{
PyListObject *seq;
PyObject *item;
assert(it != NULL);
seq = it->it_seq;
if (seq == NULL)
return NULL;
assert(PyList_Check(seq));
if (it->it_index < PyList_GET_SIZE(seq)) {
item = PyList_GET_ITEM(seq, it->it_index);
++it->it_index;
Py_INCREF(item);
return item;
}
it->it_seq = NULL;
Py_DECREF(seq);
return NULL;
}
可以很清楚地看到迭代时执行的操作:当迭代未结束时,根据索引值it_index
(相当于偏移量)从数据域it_seq
中取得下一个元素item
,将索引值+1并返回取得的元素;当迭代结束时返回NULL
,上层函数将抛出StopIteration
异常。