1
0
mirror of https://github.com/django/django.git synced 2025-07-05 02:09:13 +00:00

[soc2009/multidb] Merged up to trunk r11756.

git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2009/multidb@11758 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
Alex Gaynor 2009-11-21 06:55:11 +00:00
parent 2bc7422b52
commit e9e73c4b68
198 changed files with 5793 additions and 1878 deletions

View File

@ -16,6 +16,8 @@ The PRIMARY AUTHORS are (and/or have been):
* Brian Rosner * Brian Rosner
* Justin Bronn * Justin Bronn
* Karen Tracey * Karen Tracey
* Jannis Leidel
* James Tauber
More information on the main contributors to Django can be found in More information on the main contributors to Django can be found in
docs/internals/committers.txt. docs/internals/committers.txt.
@ -26,6 +28,7 @@ answer newbie questions, and generally made Django that much better:
ajs <adi@sieker.info> ajs <adi@sieker.info>
alang@bright-green.com alang@bright-green.com
Andi Albrecht <albrecht.andi@gmail.com>
Marty Alchin <gulopine@gamemusic.org> Marty Alchin <gulopine@gamemusic.org>
Ahmad Alhashemi <trans@ahmadh.com> Ahmad Alhashemi <trans@ahmadh.com>
Daniel Alves Barbosa de Oliveira Vaz <danielvaz@gmail.com> Daniel Alves Barbosa de Oliveira Vaz <danielvaz@gmail.com>
@ -267,7 +270,6 @@ answer newbie questions, and generally made Django that much better:
lcordier@point45.com lcordier@point45.com
Jeong-Min Lee <falsetru@gmail.com> Jeong-Min Lee <falsetru@gmail.com>
Tai Lee <real.human@mrmachine.net> Tai Lee <real.human@mrmachine.net>
Jannis Leidel <jl@websushi.org>
Christopher Lenz <http://www.cmlenz.net/> Christopher Lenz <http://www.cmlenz.net/>
lerouxb@gmail.com lerouxb@gmail.com
Piotr Lewandowski <piotr.lewandowski@gmail.com> Piotr Lewandowski <piotr.lewandowski@gmail.com>
@ -422,7 +424,7 @@ answer newbie questions, and generally made Django that much better:
Travis Terry <tdterry7@gmail.com> Travis Terry <tdterry7@gmail.com>
thebjorn <bp@datakortet.no> thebjorn <bp@datakortet.no>
Zach Thompson <zthompson47@gmail.com> Zach Thompson <zthompson47@gmail.com>
Michael Thornhill Michael Thornhill <michael.thornhill@gmail.com>
Deepak Thukral <deep.thukral@gmail.com> Deepak Thukral <deep.thukral@gmail.com>
tibimicu@gmx.net tibimicu@gmx.net
tobias@neuyork.de tobias@neuyork.de
@ -470,6 +472,8 @@ answer newbie questions, and generally made Django that much better:
Gasper Zejn <zejn@kiberpipa.org> Gasper Zejn <zejn@kiberpipa.org>
Jarek Zgoda <jarek.zgoda@gmail.com> Jarek Zgoda <jarek.zgoda@gmail.com>
Cheng Zhang Cheng Zhang
Glenn Maynard <glenn@zewt.org>
bthomas
A big THANK YOU goes to: A big THANK YOU goes to:

10
INSTALL
View File

@ -1,22 +1,16 @@
Thanks for downloading Django. Thanks for downloading Django.
To install it, make sure you have Python 2.3 or greater installed. Then run To install it, make sure you have Python 2.4 or greater installed. Then run
this command from the command prompt: this command from the command prompt:
python setup.py install python setup.py install
Note this requires a working Internet connection if you don't already have the
Python utility "setuptools" installed.
AS AN ALTERNATIVE, you can just copy the entire "django" directory to Python's AS AN ALTERNATIVE, you can just copy the entire "django" directory to Python's
site-packages directory, which is located wherever your Python installation site-packages directory, which is located wherever your Python installation
lives. Some places you might check are: lives. Some places you might check are:
/usr/lib/python2.5/site-packages (Unix, Python 2.5)
/usr/lib/python2.4/site-packages (Unix, Python 2.4) /usr/lib/python2.4/site-packages (Unix, Python 2.4)
/usr/lib/python2.3/site-packages (Unix, Python 2.3)
C:\\PYTHON\site-packages (Windows) C:\\PYTHON\site-packages (Windows)
This second solution does not require a working Internet connection; it
bypasses "setuptools" entirely.
For more detailed instructions, see docs/intro/install.txt. For more detailed instructions, see docs/intro/install.txt.

View File

@ -108,9 +108,6 @@ class Settings(object):
os.environ['TZ'] = self.TIME_ZONE os.environ['TZ'] = self.TIME_ZONE
time.tzset() time.tzset()
def get_all_members(self):
return dir(self)
class UserSettingsHolder(object): class UserSettingsHolder(object):
""" """
Holder for user configured settings. Holder for user configured settings.
@ -129,8 +126,11 @@ class UserSettingsHolder(object):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.default_settings, name) return getattr(self.default_settings, name)
def get_all_members(self): def __dir__(self):
return dir(self) + dir(self.default_settings) return dir(self) + dir(self.default_settings)
# For Python < 2.6:
__members__ = property(lambda self: self.__dir__())
settings = LazySettings() settings = LazySettings()

View File

@ -134,6 +134,12 @@ DATABASE_OPTIONS = {} # Set to empty dictionary for default.
DATABASES = { DATABASES = {
} }
# The email backend to use. For possible shortcuts see django.core.mail.
# The default is to use the SMTP backend.
# Third-party backends can be specified by providing a Python path
# to a module that defines an EmailBackend class.
EMAIL_BACKEND = 'django.core.mail.backends.smtp'
# Host for sending e-mail. # Host for sending e-mail.
EMAIL_HOST = 'localhost' EMAIL_HOST = 'localhost'
@ -303,6 +309,7 @@ DEFAULT_INDEX_TABLESPACE = ''
MIDDLEWARE_CLASSES = ( MIDDLEWARE_CLASSES = (
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware',
# 'django.middleware.http.ConditionalGetMiddleware', # 'django.middleware.http.ConditionalGetMiddleware',
# 'django.middleware.gzip.GZipMiddleware', # 'django.middleware.gzip.GZipMiddleware',
@ -377,6 +384,18 @@ LOGIN_REDIRECT_URL = '/accounts/profile/'
# The number of days a password reset link is valid for # The number of days a password reset link is valid for
PASSWORD_RESET_TIMEOUT_DAYS = 3 PASSWORD_RESET_TIMEOUT_DAYS = 3
########
# CSRF #
########
# Dotted path to callable to be used as view when a request is
# rejected by the CSRF middleware.
CSRF_FAILURE_VIEW = 'django.views.csrf.csrf_failure'
# Name and domain for CSRF cookie.
CSRF_COOKIE_NAME = 'csrftoken'
CSRF_COOKIE_DOMAIN = None
########### ###########
# TESTING # # TESTING #
########### ###########

View File

@ -5,7 +5,7 @@ msgid ""
msgstr "" msgstr ""
"Project-Id-Version: Django\n" "Project-Id-Version: Django\n"
"Report-Msgid-Bugs-To: \n" "Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2009-07-17 21:59+0200\n" "POT-Creation-Date: 2009-10-25 20:56+0100\n"
"PO-Revision-Date: 2008-02-25 15:53+0100\n" "PO-Revision-Date: 2008-02-25 15:53+0100\n"
"Last-Translator: Jarek Zgoda <jarek.zgoda@gmail.com>\n" "Last-Translator: Jarek Zgoda <jarek.zgoda@gmail.com>\n"
"MIME-Version: 1.0\n" "MIME-Version: 1.0\n"
@ -266,15 +266,15 @@ msgstr "Ten miesiąc"
msgid "This year" msgid "This year"
msgstr "Ten rok" msgstr "Ten rok"
#: contrib/admin/filterspecs.py:147 forms/widgets.py:434 #: contrib/admin/filterspecs.py:147 forms/widgets.py:435
msgid "Yes" msgid "Yes"
msgstr "Tak" msgstr "Tak"
#: contrib/admin/filterspecs.py:147 forms/widgets.py:434 #: contrib/admin/filterspecs.py:147 forms/widgets.py:435
msgid "No" msgid "No"
msgstr "Nie" msgstr "Nie"
#: contrib/admin/filterspecs.py:154 forms/widgets.py:434 #: contrib/admin/filterspecs.py:154 forms/widgets.py:435
msgid "Unknown" msgid "Unknown"
msgstr "Nieznany" msgstr "Nieznany"
@ -320,8 +320,8 @@ msgid "Changed %s."
msgstr "Zmieniono %s" msgstr "Zmieniono %s"
#: contrib/admin/options.py:519 contrib/admin/options.py:529 #: contrib/admin/options.py:519 contrib/admin/options.py:529
#: contrib/comments/templates/comments/preview.html:16 forms/models.py:388 #: contrib/comments/templates/comments/preview.html:16 forms/models.py:384
#: forms/models.py:600 #: forms/models.py:596
msgid "and" msgid "and"
msgstr "i" msgstr "i"
@ -417,11 +417,11 @@ msgstr ""
"Proszę wpisać poprawną nazwę użytkownika i hasło. Uwaga: wielkość liter ma " "Proszę wpisać poprawną nazwę użytkownika i hasło. Uwaga: wielkość liter ma "
"znaczenie." "znaczenie."
#: contrib/admin/sites.py:285 contrib/admin/views/decorators.py:40 #: contrib/admin/sites.py:288 contrib/admin/views/decorators.py:40
msgid "Please log in again, because your session has expired." msgid "Please log in again, because your session has expired."
msgstr "Twoja sesja wygasła, zaloguj się ponownie." msgstr "Twoja sesja wygasła, zaloguj się ponownie."
#: contrib/admin/sites.py:292 contrib/admin/views/decorators.py:47 #: contrib/admin/sites.py:295 contrib/admin/views/decorators.py:47
msgid "" msgid ""
"Looks like your browser isn't configured to accept cookies. Please enable " "Looks like your browser isn't configured to accept cookies. Please enable "
"cookies, reload this page, and try again." "cookies, reload this page, and try again."
@ -429,27 +429,27 @@ msgstr ""
"Twoja przeglądarka nie chce akceptować ciasteczek. Zmień jej ustawienia i " "Twoja przeglądarka nie chce akceptować ciasteczek. Zmień jej ustawienia i "
"spróbuj ponownie." "spróbuj ponownie."
#: contrib/admin/sites.py:308 contrib/admin/sites.py:314 #: contrib/admin/sites.py:311 contrib/admin/sites.py:317
#: contrib/admin/views/decorators.py:66 #: contrib/admin/views/decorators.py:66
msgid "Usernames cannot contain the '@' character." msgid "Usernames cannot contain the '@' character."
msgstr "Nazwy użytkowników nie mogą zawierać znaku '@'." msgstr "Nazwy użytkowników nie mogą zawierać znaku '@'."
#: contrib/admin/sites.py:311 contrib/admin/views/decorators.py:62 #: contrib/admin/sites.py:314 contrib/admin/views/decorators.py:62
#, python-format #, python-format
msgid "Your e-mail address is not your username. Try '%s' instead." msgid "Your e-mail address is not your username. Try '%s' instead."
msgstr "Podany adres e-mail nie jest Twoją nazwą użytkownika. Spróbuj '%s'." msgstr "Podany adres e-mail nie jest Twoją nazwą użytkownika. Spróbuj '%s'."
#: contrib/admin/sites.py:367 #: contrib/admin/sites.py:370
msgid "Site administration" msgid "Site administration"
msgstr "Administracja stroną" msgstr "Administracja stroną"
#: contrib/admin/sites.py:381 contrib/admin/templates/admin/login.html:26 #: contrib/admin/sites.py:384 contrib/admin/templates/admin/login.html:26
#: contrib/admin/templates/registration/password_reset_complete.html:14 #: contrib/admin/templates/registration/password_reset_complete.html:14
#: contrib/admin/views/decorators.py:20 #: contrib/admin/views/decorators.py:20
msgid "Log in" msgid "Log in"
msgstr "Zaloguj się" msgstr "Zaloguj się"
#: contrib/admin/sites.py:426 #: contrib/admin/sites.py:429
#, python-format #, python-format
msgid "%s administration" msgid "%s administration"
msgstr "%s - administracja" msgstr "%s - administracja"
@ -464,27 +464,27 @@ msgstr "Jedno lub więcej %(fieldname)s w %(name)s: %(obj)s"
msgid "One or more %(fieldname)s in %(name)s:" msgid "One or more %(fieldname)s in %(name)s:"
msgstr "Jedno lub więcej %(fieldname)s w %(name)s:" msgstr "Jedno lub więcej %(fieldname)s w %(name)s:"
#: contrib/admin/widgets.py:71 #: contrib/admin/widgets.py:72
msgid "Date:" msgid "Date:"
msgstr "Data:" msgstr "Data:"
#: contrib/admin/widgets.py:71 #: contrib/admin/widgets.py:72
msgid "Time:" msgid "Time:"
msgstr "Czas:" msgstr "Czas:"
#: contrib/admin/widgets.py:95 #: contrib/admin/widgets.py:96
msgid "Currently:" msgid "Currently:"
msgstr "Teraz:" msgstr "Teraz:"
#: contrib/admin/widgets.py:95 #: contrib/admin/widgets.py:96
msgid "Change:" msgid "Change:"
msgstr "Zmień:" msgstr "Zmień:"
#: contrib/admin/widgets.py:124 #: contrib/admin/widgets.py:125
msgid "Lookup" msgid "Lookup"
msgstr "Szukaj" msgstr "Szukaj"
#: contrib/admin/widgets.py:235 #: contrib/admin/widgets.py:237
msgid "Add Another" msgid "Add Another"
msgstr "Dodaj kolejny" msgstr "Dodaj kolejny"
@ -598,7 +598,7 @@ msgstr "Historia"
#: contrib/admin/templates/admin/change_form.html:28 #: contrib/admin/templates/admin/change_form.html:28
#: contrib/admin/templates/admin/edit_inline/stacked.html:13 #: contrib/admin/templates/admin/edit_inline/stacked.html:13
#: contrib/admin/templates/admin/edit_inline/tabular.html:27 #: contrib/admin/templates/admin/edit_inline/tabular.html:28
msgid "View on site" msgid "View on site"
msgstr "Pokaż na stronie" msgstr "Pokaż na stronie"
@ -668,10 +668,10 @@ msgstr ""
#, python-format #, python-format
msgid "" msgid ""
"Are you sure you want to delete the selected %(object_name)s objects? All of " "Are you sure you want to delete the selected %(object_name)s objects? All of "
"the following objects and it's related items will be deleted:" "the following objects and their related items will be deleted:"
msgstr "" msgstr ""
"Czy chcesz skasować %(object_name)s? Następujące obiekty i zależne od nich " "Czy chcesz skasować wybrane %(object_name)s? Następujące obiekty i zależne od "
"zostaną skasowane:" "nich zostaną skasowane:"
#: contrib/admin/templates/admin/filter.html:2 #: contrib/admin/templates/admin/filter.html:2
#, python-format #, python-format
@ -734,7 +734,6 @@ msgid "User"
msgstr "Użytkownik" msgstr "Użytkownik"
#: contrib/admin/templates/admin/object_history.html:24 #: contrib/admin/templates/admin/object_history.html:24
#: contrib/comments/templates/comments/moderation_queue.html:33
msgid "Action" msgid "Action"
msgstr "Akcja" msgstr "Akcja"
@ -1125,7 +1124,6 @@ msgid "Time"
msgstr "Czas" msgstr "Czas"
#: contrib/admindocs/views.py:359 contrib/comments/forms.py:95 #: contrib/admindocs/views.py:359 contrib/comments/forms.py:95
#: contrib/comments/templates/comments/moderation_queue.html:37
#: contrib/flatpages/admin.py:8 contrib/flatpages/models.py:7 #: contrib/flatpages/admin.py:8 contrib/flatpages/models.py:7
msgid "URL" msgid "URL"
msgstr "URL" msgstr "URL"
@ -1428,22 +1426,54 @@ msgstr "użytkownicy"
msgid "message" msgid "message"
msgstr "wiadomość" msgstr "wiadomość"
#: contrib/auth/views.py:56 #: contrib/auth/views.py:58
msgid "Logged out" msgid "Logged out"
msgstr "Wylogowany" msgstr "Wylogowany"
#: contrib/auth/management/commands/createsuperuser.py:23 forms/fields.py:429 #: contrib/auth/management/commands/createsuperuser.py:23 forms/fields.py:428
msgid "Enter a valid e-mail address." msgid "Enter a valid e-mail address."
msgstr "Wprowadź poprawny adres e-mail." msgstr "Wprowadź poprawny adres e-mail."
#: contrib/comments/admin.py:11 #: contrib/comments/admin.py:12
msgid "Content" msgid "Content"
msgstr "Zawartość" msgstr "Zawartość"
#: contrib/comments/admin.py:14 #: contrib/comments/admin.py:15
msgid "Metadata" msgid "Metadata"
msgstr "Metadane" msgstr "Metadane"
#: contrib/comments/admin.py:39
msgid "flagged"
msgstr "oflagowany"
#: contrib/comments/admin.py:40
msgid "Flag selected comments"
msgstr "Oflaguj wybrane komentarze"
#: contrib/comments/admin.py:43
msgid "approved"
msgstr "zaakceptowany"
#: contrib/comments/admin.py:44
msgid "Approve selected comments"
msgstr "Zaakceptuj wybrane komentarze"
#: contrib/comments/admin.py:47
msgid "removed"
msgstr "usunięty"
#: contrib/comments/admin.py:48
msgid "Remove selected comments"
msgstr "Usuń wybrane komentarze"
#: contrib/comments/admin.py:60
#, python-format
msgid "1 comment was successfully %(action)s."
msgid_plural "%(count)s comments were successfully %(action)s."
msgstr[0] "1 komentarz został %(action)s"
msgstr[1] "%(count)s komentarze zostały %(action)s"
msgstr[2] "%(count)s komentarzy zostało %(action)s"
#: contrib/comments/feeds.py:13 #: contrib/comments/feeds.py:13
#, python-format #, python-format
msgid "%(site_name)s comments" msgid "%(site_name)s comments"
@ -1455,7 +1485,6 @@ msgid "Latest comments on %(site_name)s"
msgstr "Ostatnie komentarze na %(site_name)s" msgstr "Ostatnie komentarze na %(site_name)s"
#: contrib/comments/forms.py:93 #: contrib/comments/forms.py:93
#: contrib/comments/templates/comments/moderation_queue.html:34
msgid "Name" msgid "Name"
msgstr "Nazwa" msgstr "Nazwa"
@ -1464,7 +1493,6 @@ msgid "Email address"
msgstr "Adres e-mail" msgstr "Adres e-mail"
#: contrib/comments/forms.py:96 #: contrib/comments/forms.py:96
#: contrib/comments/templates/comments/moderation_queue.html:35
msgid "Comment" msgid "Comment"
msgstr "Komentarz" msgstr "Komentarz"
@ -1592,7 +1620,6 @@ msgid "Really make this comment public?"
msgstr "Czy ten komentarz na pewno ma być publiczny?" msgstr "Czy ten komentarz na pewno ma być publiczny?"
#: contrib/comments/templates/comments/approve.html:12 #: contrib/comments/templates/comments/approve.html:12
#: contrib/comments/templates/comments/moderation_queue.html:49
msgid "Approve" msgid "Approve"
msgstr "Zaakceptuj" msgstr "Zaakceptuj"
@ -1618,7 +1645,6 @@ msgid "Really remove this comment?"
msgstr "Czy na pewno usunąć ten komentarz?" msgstr "Czy na pewno usunąć ten komentarz?"
#: contrib/comments/templates/comments/delete.html:12 #: contrib/comments/templates/comments/delete.html:12
#: contrib/comments/templates/comments/moderation_queue.html:53
msgid "Remove" msgid "Remove"
msgstr "Usuń" msgstr "Usuń"
@ -1652,39 +1678,6 @@ msgstr "Zapisz"
msgid "Preview" msgid "Preview"
msgstr "Podgląd" msgstr "Podgląd"
#: contrib/comments/templates/comments/moderation_queue.html:4
#: contrib/comments/templates/comments/moderation_queue.html:19
msgid "Comment moderation queue"
msgstr "Kolejka moderacji komentarzy"
#: contrib/comments/templates/comments/moderation_queue.html:26
msgid "No comments to moderate"
msgstr "Żaden komentarz nie oczekuje na akceptację"
#: contrib/comments/templates/comments/moderation_queue.html:36
msgid "Email"
msgstr "E-mail"
#: contrib/comments/templates/comments/moderation_queue.html:38
msgid "Authenticated?"
msgstr "Zalogowany?"
#: contrib/comments/templates/comments/moderation_queue.html:39
msgid "IP Address"
msgstr "Adres IP"
#: contrib/comments/templates/comments/moderation_queue.html:40
msgid "Date posted"
msgstr "Data dodania"
#: contrib/comments/templates/comments/moderation_queue.html:63
msgid "yes"
msgstr "tak"
#: contrib/comments/templates/comments/moderation_queue.html:63
msgid "no"
msgstr "nie"
#: contrib/comments/templates/comments/posted.html:4 #: contrib/comments/templates/comments/posted.html:4
msgid "Thanks for commenting" msgid "Thanks for commenting"
msgstr "Dziękujemy za dodanie komentarza" msgstr "Dziękujemy za dodanie komentarza"
@ -2599,6 +2592,10 @@ msgstr "Niepoprawna suma kontrolna numeru konta bankowego."
msgid "Enter a valid Finnish social security number." msgid "Enter a valid Finnish social security number."
msgstr "Wpis poprawny numer fińskiego ubezpieczenia socjalnego." msgstr "Wpis poprawny numer fińskiego ubezpieczenia socjalnego."
#: contrib/localflavor/fr/forms.py:30
msgid "Phone numbers must be in 0X XX XX XX XX format."
msgstr "Numery telefoniczne muszą być w formacie 0X XX XX XX XX."
#: contrib/localflavor/in_/forms.py:14 #: contrib/localflavor/in_/forms.py:14
msgid "Enter a zip code in the format XXXXXXX." msgid "Enter a zip code in the format XXXXXXX."
msgstr "Wpisz kod pocztowy w formacie XXXXXXX." msgstr "Wpisz kod pocztowy w formacie XXXXXXX."
@ -3944,86 +3941,86 @@ msgstr[2] ""
"Proszę podać poprawne identyfikatory %(self)s. Wartości %(value)r są " "Proszę podać poprawne identyfikatory %(self)s. Wartości %(value)r są "
"niepoprawne." "niepoprawne."
#: forms/fields.py:54 #: forms/fields.py:53
msgid "This field is required." msgid "This field is required."
msgstr "To pole jest wymagane." msgstr "To pole jest wymagane."
#: forms/fields.py:55 #: forms/fields.py:54
msgid "Enter a valid value." msgid "Enter a valid value."
msgstr "Wpisz poprawną wartość." msgstr "Wpisz poprawną wartość."
#: forms/fields.py:138 #: forms/fields.py:137
#, python-format #, python-format
msgid "Ensure this value has at most %(max)d characters (it has %(length)d)." msgid "Ensure this value has at most %(max)d characters (it has %(length)d)."
msgstr "" msgstr ""
"Upewnij się, że ta wartość ma co najwyżej %(max)d znaków (ma długość %" "Upewnij się, że ta wartość ma co najwyżej %(max)d znaków (ma długość %"
"(length)d)." "(length)d)."
#: forms/fields.py:139 #: forms/fields.py:138
#, python-format #, python-format
msgid "Ensure this value has at least %(min)d characters (it has %(length)d)." msgid "Ensure this value has at least %(min)d characters (it has %(length)d)."
msgstr "" msgstr ""
"Upewnij się, że ta wartość ma co najmniej %(min)d znaków (ma długość %" "Upewnij się, że ta wartość ma co najmniej %(min)d znaków (ma długość %"
"(length)d)." "(length)d)."
#: forms/fields.py:166 #: forms/fields.py:165
msgid "Enter a whole number." msgid "Enter a whole number."
msgstr "Wpisz liczbę całkowitą." msgstr "Wpisz liczbę całkowitą."
#: forms/fields.py:167 forms/fields.py:196 forms/fields.py:225 #: forms/fields.py:166 forms/fields.py:195 forms/fields.py:224
#, python-format #, python-format
msgid "Ensure this value is less than or equal to %s." msgid "Ensure this value is less than or equal to %s."
msgstr "Upewnij się, że ta wartość jest mniejsza lub równa %s." msgstr "Upewnij się, że ta wartość jest mniejsza lub równa %s."
#: forms/fields.py:168 forms/fields.py:197 forms/fields.py:226 #: forms/fields.py:167 forms/fields.py:196 forms/fields.py:225
#, python-format #, python-format
msgid "Ensure this value is greater than or equal to %s." msgid "Ensure this value is greater than or equal to %s."
msgstr "Upewnij się, że ta wartość jest większa lub równa %s." msgstr "Upewnij się, że ta wartość jest większa lub równa %s."
#: forms/fields.py:195 forms/fields.py:224 #: forms/fields.py:194 forms/fields.py:223
msgid "Enter a number." msgid "Enter a number."
msgstr "Wpisz liczbę." msgstr "Wpisz liczbę."
#: forms/fields.py:227 #: forms/fields.py:226
#, python-format #, python-format
msgid "Ensure that there are no more than %s digits in total." msgid "Ensure that there are no more than %s digits in total."
msgstr "Upewnij się, że jest nie więcej niż %s cyfr." msgstr "Upewnij się, że jest nie więcej niż %s cyfr."
#: forms/fields.py:228 #: forms/fields.py:227
#, python-format #, python-format
msgid "Ensure that there are no more than %s decimal places." msgid "Ensure that there are no more than %s decimal places."
msgstr "Upewnij się, że jest nie więcej niż %s miejsc po przecinku." msgstr "Upewnij się, że jest nie więcej niż %s miejsc po przecinku."
#: forms/fields.py:229 #: forms/fields.py:228
#, python-format #, python-format
msgid "Ensure that there are no more than %s digits before the decimal point." msgid "Ensure that there are no more than %s digits before the decimal point."
msgstr "Upewnij się, że jest nie więcej niż %s miejsc przed przecinkiem." msgstr "Upewnij się, że jest nie więcej niż %s miejsc przed przecinkiem."
#: forms/fields.py:288 forms/fields.py:863 #: forms/fields.py:287 forms/fields.py:862
msgid "Enter a valid date." msgid "Enter a valid date."
msgstr "Wpisz poprawną datę." msgstr "Wpisz poprawną datę."
#: forms/fields.py:322 forms/fields.py:864 #: forms/fields.py:321 forms/fields.py:863
msgid "Enter a valid time." msgid "Enter a valid time."
msgstr "Wpisz poprawną godzinę." msgstr "Wpisz poprawną godzinę."
#: forms/fields.py:361 #: forms/fields.py:360
msgid "Enter a valid date/time." msgid "Enter a valid date/time."
msgstr "Wpisz poprawną datę/godzinę." msgstr "Wpisz poprawną datę/godzinę."
#: forms/fields.py:447 #: forms/fields.py:446
msgid "No file was submitted. Check the encoding type on the form." msgid "No file was submitted. Check the encoding type on the form."
msgstr "Nie wysłano żadnego pliku. Sprawdź typ kodowania formularza." msgstr "Nie wysłano żadnego pliku. Sprawdź typ kodowania formularza."
#: forms/fields.py:448 #: forms/fields.py:447
msgid "No file was submitted." msgid "No file was submitted."
msgstr "Żaden plik nie został przesłany." msgstr "Żaden plik nie został przesłany."
#: forms/fields.py:449 #: forms/fields.py:448
msgid "The submitted file is empty." msgid "The submitted file is empty."
msgstr "Wysłany plik jest pusty." msgstr "Wysłany plik jest pusty."
#: forms/fields.py:450 #: forms/fields.py:449
#, python-format #, python-format
msgid "" msgid ""
"Ensure this filename has at most %(max)d characters (it has %(length)d)." "Ensure this filename has at most %(max)d characters (it has %(length)d)."
@ -4031,7 +4028,7 @@ msgstr ""
"Upewnij się, że nazwa tego pliku ma co najwyżej %(max)d znaków (ma długość %" "Upewnij się, że nazwa tego pliku ma co najwyżej %(max)d znaków (ma długość %"
"(length)d)." "(length)d)."
#: forms/fields.py:483 #: forms/fields.py:482
msgid "" msgid ""
"Upload a valid image. The file you uploaded was either not an image or a " "Upload a valid image. The file you uploaded was either not an image or a "
"corrupted image." "corrupted image."
@ -4039,29 +4036,29 @@ msgstr ""
"Wgraj poprawny plik graficzny. Ten, który został wgrany, nie jest obrazem, " "Wgraj poprawny plik graficzny. Ten, który został wgrany, nie jest obrazem, "
"albo jest uszkodzony." "albo jest uszkodzony."
#: forms/fields.py:544 #: forms/fields.py:543
msgid "Enter a valid URL." msgid "Enter a valid URL."
msgstr "Wpisz poprawny URL." msgstr "Wpisz poprawny URL."
#: forms/fields.py:545 #: forms/fields.py:544
msgid "This URL appears to be a broken link." msgid "This URL appears to be a broken link."
msgstr "Ten odnośnik jest nieprawidłowy." msgstr "Ten odnośnik jest nieprawidłowy."
#: forms/fields.py:625 forms/fields.py:703 #: forms/fields.py:624 forms/fields.py:702
#, python-format #, python-format
msgid "Select a valid choice. %(value)s is not one of the available choices." msgid "Select a valid choice. %(value)s is not one of the available choices."
msgstr "" msgstr ""
"Wybierz poprawną wartość. %(value)s nie jest jednym z dostępnych wyborów." "Wybierz poprawną wartość. %(value)s nie jest jednym z dostępnych wyborów."
#: forms/fields.py:704 forms/fields.py:765 forms/models.py:1003 #: forms/fields.py:703 forms/fields.py:764 forms/models.py:999
msgid "Enter a list of values." msgid "Enter a list of values."
msgstr "Podaj listę wartości." msgstr "Podaj listę wartości."
#: forms/fields.py:892 #: forms/fields.py:891
msgid "Enter a valid IPv4 address." msgid "Enter a valid IPv4 address."
msgstr "Wprowadź poprawny adres IPv4." msgstr "Wprowadź poprawny adres IPv4."
#: forms/fields.py:902 #: forms/fields.py:901
msgid "" msgid ""
"Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens." "Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."
msgstr "To pole może zawierać jedynie litery, cyfry, podkreślenia i myślniki." msgstr "To pole może zawierać jedynie litery, cyfry, podkreślenia i myślniki."
@ -4070,29 +4067,29 @@ msgstr "To pole może zawierać jedynie litery, cyfry, podkreślenia i myślniki
msgid "Order" msgid "Order"
msgstr "Porządek" msgstr "Porządek"
#: forms/models.py:367 #: forms/models.py:363
#, python-format #, python-format
msgid "%(field_name)s must be unique for %(date_field)s %(lookup)s." msgid "%(field_name)s must be unique for %(date_field)s %(lookup)s."
msgstr "" msgstr ""
"Wartości w %(field_name)s muszą być unikalne dla wyszukiwań %(lookup)s w %" "Wartości w %(field_name)s muszą być unikalne dla wyszukiwań %(lookup)s w %"
"(date_field)s" "(date_field)s"
#: forms/models.py:381 forms/models.py:389 #: forms/models.py:377 forms/models.py:385
#, python-format #, python-format
msgid "%(model_name)s with this %(field_label)s already exists." msgid "%(model_name)s with this %(field_label)s already exists."
msgstr "%(field_label)s już istnieje w %(model_name)s." msgstr "%(field_label)s już istnieje w %(model_name)s."
#: forms/models.py:594 #: forms/models.py:590
#, python-format #, python-format
msgid "Please correct the duplicate data for %(field)s." msgid "Please correct the duplicate data for %(field)s."
msgstr "Popraw zduplikowane dane w %(field)s." msgstr "Popraw zduplikowane dane w %(field)s."
#: forms/models.py:598 #: forms/models.py:594
#, python-format #, python-format
msgid "Please correct the duplicate data for %(field)s, which must be unique." msgid "Please correct the duplicate data for %(field)s, which must be unique."
msgstr "Popraw zduplikowane dane w %(field)s, które wymaga unikalności." msgstr "Popraw zduplikowane dane w %(field)s, które wymaga unikalności."
#: forms/models.py:604 #: forms/models.py:600
#, python-format #, python-format
msgid "" msgid ""
"Please correct the duplicate data for %(field_name)s which must be unique " "Please correct the duplicate data for %(field_name)s which must be unique "
@ -4101,24 +4098,24 @@ msgstr ""
"Popraw zduplikowane dane w %(field_name)s, które wymaga unikalności dla %" "Popraw zduplikowane dane w %(field_name)s, które wymaga unikalności dla %"
"(lookup)s w polu %(date_field)s." "(lookup)s w polu %(date_field)s."
#: forms/models.py:612 #: forms/models.py:608
msgid "Please correct the duplicate values below." msgid "Please correct the duplicate values below."
msgstr "Popraw poniższe zduplikowane wartości." msgstr "Popraw poniższe zduplikowane wartości."
#: forms/models.py:867 #: forms/models.py:863
msgid "The inline foreign key did not match the parent instance primary key." msgid "The inline foreign key did not match the parent instance primary key."
msgstr "Osadzony klucz obcy nie pasuje do klucza głównego obiektu rodzica." msgstr "Osadzony klucz obcy nie pasuje do klucza głównego obiektu rodzica."
#: forms/models.py:930 #: forms/models.py:926
msgid "Select a valid choice. That choice is not one of the available choices." msgid "Select a valid choice. That choice is not one of the available choices."
msgstr "Wybierz poprawną wartość. Podana nie jest jednym z dostępnych wyborów." msgstr "Wybierz poprawną wartość. Podana nie jest jednym z dostępnych wyborów."
#: forms/models.py:1004 #: forms/models.py:1000
#, python-format #, python-format
msgid "Select a valid choice. %s is not one of the available choices." msgid "Select a valid choice. %s is not one of the available choices."
msgstr "Wybierz poprawną wartość. %s nie jest jednym z dostępnych wyborów." msgstr "Wybierz poprawną wartość. %s nie jest jednym z dostępnych wyborów."
#: forms/models.py:1006 #: forms/models.py:1002
#, python-format #, python-format
msgid "\"%s\" is not a valid value for a primary key." msgid "\"%s\" is not a valid value for a primary key."
msgstr "\"%s\" nie jest poprawną wartością klucza głównego." msgstr "\"%s\" nie jest poprawną wartością klucza głównego."
@ -4444,3 +4441,27 @@ msgstr "%(verbose_name)s zostało pomyślnie zmienione."
#, python-format #, python-format
msgid "The %(verbose_name)s was deleted." msgid "The %(verbose_name)s was deleted."
msgstr "%(verbose_name)s zostało usunięte." msgstr "%(verbose_name)s zostało usunięte."
#~ msgid "Comment moderation queue"
#~ msgstr "Kolejka moderacji komentarzy"
#~ msgid "No comments to moderate"
#~ msgstr "Żaden komentarz nie oczekuje na akceptację"
#~ msgid "Email"
#~ msgstr "E-mail"
#~ msgid "Authenticated?"
#~ msgstr "Zalogowany?"
#~ msgid "IP Address"
#~ msgstr "Adres IP"
#~ msgid "Date posted"
#~ msgstr "Data dodania"
#~ msgid "yes"
#~ msgstr "tak"
#~ msgid "no"
#~ msgstr "nie"

View File

@ -65,6 +65,7 @@ TEMPLATE_LOADERS = (
MIDDLEWARE_CLASSES = ( MIDDLEWARE_CLASSES = (
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware',
) )

View File

@ -53,7 +53,7 @@
vertical-align: middle; vertical-align: middle;
} }
#changelist table thead th:first-child { #changelist table thead th.action-checkbox-column {
width: 1.5em; width: 1.5em;
text-align: center; text-align: center;
} }

View File

@ -6,6 +6,7 @@ from django.contrib.contenttypes.models import ContentType
from django.contrib.admin import widgets from django.contrib.admin import widgets
from django.contrib.admin import helpers from django.contrib.admin import helpers
from django.contrib.admin.util import unquote, flatten_fieldsets, get_deleted_objects, model_ngettext, model_format_dict from django.contrib.admin.util import unquote, flatten_fieldsets, get_deleted_objects, model_ngettext, model_format_dict
from django.views.decorators.csrf import csrf_protect
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.db import models, transaction from django.db import models, transaction
from django.db.models.fields import BLANK_CHOICE_DASH from django.db.models.fields import BLANK_CHOICE_DASH
@ -152,8 +153,9 @@ class BaseModelAdmin(object):
""" """
Get a form Field for a ManyToManyField. Get a form Field for a ManyToManyField.
""" """
# If it uses an intermediary model, don't show field in admin. # If it uses an intermediary model that isn't auto created, don't show
if db_field.rel.through is not None: # a field in admin.
if not db_field.rel.through._meta.auto_created:
return None return None
if db_field.name in self.raw_id_fields: if db_field.name in self.raw_id_fields:
@ -701,6 +703,8 @@ class ModelAdmin(BaseModelAdmin):
else: else:
return HttpResponseRedirect(".") return HttpResponseRedirect(".")
@csrf_protect
@transaction.commit_on_success
def add_view(self, request, form_url='', extra_context=None): def add_view(self, request, form_url='', extra_context=None):
"The 'add' admin view for this model." "The 'add' admin view for this model."
model = self.model model = self.model
@ -782,8 +786,9 @@ class ModelAdmin(BaseModelAdmin):
} }
context.update(extra_context or {}) context.update(extra_context or {})
return self.render_change_form(request, context, form_url=form_url, add=True) return self.render_change_form(request, context, form_url=form_url, add=True)
add_view = transaction.commit_on_success(add_view)
@csrf_protect
@transaction.commit_on_success
def change_view(self, request, object_id, extra_context=None): def change_view(self, request, object_id, extra_context=None):
"The 'change' admin view for this model." "The 'change' admin view for this model."
model = self.model model = self.model
@ -871,8 +876,8 @@ class ModelAdmin(BaseModelAdmin):
} }
context.update(extra_context or {}) context.update(extra_context or {})
return self.render_change_form(request, context, change=True, obj=obj) return self.render_change_form(request, context, change=True, obj=obj)
change_view = transaction.commit_on_success(change_view)
@csrf_protect
def changelist_view(self, request, extra_context=None): def changelist_view(self, request, extra_context=None):
"The 'change list' admin view for this model." "The 'change list' admin view for this model."
from django.contrib.admin.views.main import ChangeList, ERROR_FLAG from django.contrib.admin.views.main import ChangeList, ERROR_FLAG
@ -985,6 +990,7 @@ class ModelAdmin(BaseModelAdmin):
'admin/change_list.html' 'admin/change_list.html'
], context, context_instance=context_instance) ], context, context_instance=context_instance)
@csrf_protect
def delete_view(self, request, object_id, extra_context=None): def delete_view(self, request, object_id, extra_context=None):
"The 'delete' admin view for this model." "The 'delete' admin view for this model."
opts = self.model._meta opts = self.model._meta

View File

@ -3,6 +3,7 @@ from django import http, template
from django.contrib.admin import ModelAdmin from django.contrib.admin import ModelAdmin
from django.contrib.admin import actions from django.contrib.admin import actions
from django.contrib.auth import authenticate, login from django.contrib.auth import authenticate, login
from django.views.decorators.csrf import csrf_protect
from django.db.models.base import ModelBase from django.db.models.base import ModelBase
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
@ -186,11 +187,17 @@ class AdminSite(object):
return view(request, *args, **kwargs) return view(request, *args, **kwargs)
if not cacheable: if not cacheable:
inner = never_cache(inner) inner = never_cache(inner)
# We add csrf_protect here so this function can be used as a utility
# function for any view, without having to repeat 'csrf_protect'.
inner = csrf_protect(inner)
return update_wrapper(inner, view) return update_wrapper(inner, view)
def get_urls(self): def get_urls(self):
from django.conf.urls.defaults import patterns, url, include from django.conf.urls.defaults import patterns, url, include
if settings.DEBUG:
self.check_dependencies()
def wrap(view, cacheable=False): def wrap(view, cacheable=False):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
return self.admin_view(view, cacheable)(*args, **kwargs) return self.admin_view(view, cacheable)(*args, **kwargs)

View File

@ -15,7 +15,7 @@
</div> </div>
{% endif %}{% endblock %} {% endif %}{% endblock %}
{% block content %}<div id="content-main"> {% block content %}<div id="content-main">
<form action="{{ form_url }}" method="post" id="{{ opts.module_name }}_form">{% block form_top %}{% endblock %} <form action="{{ form_url }}" method="post" id="{{ opts.module_name }}_form">{% csrf_token %}{% block form_top %}{% endblock %}
<div> <div>
{% if is_popup %}<input type="hidden" name="_popup" value="1" />{% endif %} {% if is_popup %}<input type="hidden" name="_popup" value="1" />{% endif %}
{% if form.errors %} {% if form.errors %}

View File

@ -29,7 +29,7 @@
</ul> </ul>
{% endif %}{% endif %} {% endif %}{% endif %}
{% endblock %} {% endblock %}
<form {% if has_file_field %}enctype="multipart/form-data" {% endif %}action="{{ form_url }}" method="post" id="{{ opts.module_name }}_form">{% block form_top %}{% endblock %} <form {% if has_file_field %}enctype="multipart/form-data" {% endif %}action="{{ form_url }}" method="post" id="{{ opts.module_name }}_form">{% csrf_token %}{% block form_top %}{% endblock %}
<div> <div>
{% if is_popup %}<input type="hidden" name="_popup" value="1" />{% endif %} {% if is_popup %}<input type="hidden" name="_popup" value="1" />{% endif %}
{% if save_on_top %}{% submit_row %}{% endif %} {% if save_on_top %}{% submit_row %}{% endif %}

View File

@ -68,7 +68,7 @@
{% endif %} {% endif %}
{% endblock %} {% endblock %}
<form action="" method="post"{% if cl.formset.is_multipart %} enctype="multipart/form-data"{% endif %}> <form action="" method="post"{% if cl.formset.is_multipart %} enctype="multipart/form-data"{% endif %}>{% csrf_token %}
{% if cl.formset %} {% if cl.formset %}
{{ cl.formset.management_form }} {{ cl.formset.management_form }}
{% endif %} {% endif %}

View File

@ -22,7 +22,7 @@
{% else %} {% else %}
<p>{% blocktrans with object as escaped_object %}Are you sure you want to delete the {{ object_name }} "{{ escaped_object }}"? All of the following related items will be deleted:{% endblocktrans %}</p> <p>{% blocktrans with object as escaped_object %}Are you sure you want to delete the {{ object_name }} "{{ escaped_object }}"? All of the following related items will be deleted:{% endblocktrans %}</p>
<ul>{{ deleted_objects|unordered_list }}</ul> <ul>{{ deleted_objects|unordered_list }}</ul>
<form action="" method="post"> <form action="" method="post">{% csrf_token %}
<div> <div>
<input type="hidden" name="post" value="yes" /> <input type="hidden" name="post" value="yes" />
<input type="submit" value="{% trans "Yes, I'm sure" %}" /> <input type="submit" value="{% trans "Yes, I'm sure" %}" />

View File

@ -23,7 +23,7 @@
{% for deleteable_object in deletable_objects %} {% for deleteable_object in deletable_objects %}
<ul>{{ deleteable_object|unordered_list }}</ul> <ul>{{ deleteable_object|unordered_list }}</ul>
{% endfor %} {% endfor %}
<form action="" method="post"> <form action="" method="post">{% csrf_token %}
<div> <div>
{% for obj in queryset %} {% for obj in queryset %}
<input type="hidden" name="{{ action_checkbox_name }}" value="{{ obj.pk }}" /> <input type="hidden" name="{{ action_checkbox_name }}" value="{{ obj.pk }}" />

View File

@ -14,7 +14,7 @@
<p class="errornote">{{ error_message }}</p> <p class="errornote">{{ error_message }}</p>
{% endif %} {% endif %}
<div id="content-main"> <div id="content-main">
<form action="{{ app_path }}" method="post" id="login-form"> <form action="{{ app_path }}" method="post" id="login-form">{% csrf_token %}
<div class="form-row"> <div class="form-row">
<label for="id_username">{% trans 'Username:' %}</label> <input type="text" name="username" id="id_username" /> <label for="id_username">{% trans 'Username:' %}</label> <input type="text" name="username" id="id_username" />
</div> </div>

View File

@ -4,7 +4,7 @@
<div id="content-main"> <div id="content-main">
<form action="" method="post"> <form action="" method="post">{% csrf_token %}
{% if form.errors %} {% if form.errors %}
<p class="errornote">Your template had {{ form.errors|length }} error{{ form.errors|pluralize }}:</p> <p class="errornote">Your template had {{ form.errors|length }} error{{ form.errors|pluralize }}:</p>

View File

@ -11,7 +11,7 @@
<p>{% trans "Please enter your old password, for security's sake, and then enter your new password twice so we can verify you typed it in correctly." %}</p> <p>{% trans "Please enter your old password, for security's sake, and then enter your new password twice so we can verify you typed it in correctly." %}</p>
<form action="" method="post"> <form action="" method="post">{% csrf_token %}
{{ form.old_password.errors }} {{ form.old_password.errors }}
<p class="aligned wide"><label for="id_old_password">{% trans 'Old password:' %}</label>{{ form.old_password }}</p> <p class="aligned wide"><label for="id_old_password">{% trans 'Old password:' %}</label>{{ form.old_password }}</p>

View File

@ -13,7 +13,7 @@
<p>{% trans "Please enter your new password twice so we can verify you typed it in correctly." %}</p> <p>{% trans "Please enter your new password twice so we can verify you typed it in correctly." %}</p>
<form action="" method="post"> <form action="" method="post">{% csrf_token %}
{{ form.new_password1.errors }} {{ form.new_password1.errors }}
<p class="aligned wide"><label for="id_new_password1">{% trans 'New password:' %}</label>{{ form.new_password1 }}</p> <p class="aligned wide"><label for="id_new_password1">{% trans 'New password:' %}</label>{{ form.new_password1 }}</p>
{{ form.new_password2.errors }} {{ form.new_password2.errors }}

View File

@ -11,7 +11,7 @@
<p>{% trans "Forgotten your password? Enter your e-mail address below, and we'll e-mail instructions for setting a new one." %}</p> <p>{% trans "Forgotten your password? Enter your e-mail address below, and we'll e-mail instructions for setting a new one." %}</p>
<form action="" method="post"> <form action="" method="post">{% csrf_token %}
{{ form.email.errors }} {{ form.email.errors }}
<p><label for="id_email">{% trans 'E-mail address:' %}</label> {{ form.email }} <input type="submit" value="{% trans 'Reset my password' %}" /></p> <p><label for="id_email">{% trans 'E-mail address:' %}</label> {{ form.email }} <input type="submit" value="{% trans 'Reset my password' %}" /></p>
</form> </form>

View File

@ -106,6 +106,11 @@ def result_headers(cl):
else: else:
header = field_name header = field_name
header = header.replace('_', ' ') header = header.replace('_', ' ')
# if the field is the action checkbox: no sorting and special class
if field_name == 'action_checkbox':
yield {"text": header,
"class_attrib": mark_safe(' class="action-checkbox-column"')}
continue
# It is a non-field, but perhaps one that is sortable # It is a non-field, but perhaps one that is sortable
admin_order_field = getattr(attr, "admin_order_field", None) admin_order_field = getattr(attr, "admin_order_field", None)

View File

@ -149,12 +149,16 @@ def validate(cls, model):
validate_inline(inline, cls, model) validate_inline(inline, cls, model)
def validate_inline(cls, parent, parent_model): def validate_inline(cls, parent, parent_model):
# model is already verified to exist and be a Model # model is already verified to exist and be a Model
if cls.fk_name: # default value is None if cls.fk_name: # default value is None
f = get_field(cls, cls.model, cls.model._meta, 'fk_name', cls.fk_name) f = get_field(cls, cls.model, cls.model._meta, 'fk_name', cls.fk_name)
if not isinstance(f, models.ForeignKey): if not isinstance(f, models.ForeignKey):
raise ImproperlyConfigured("'%s.fk_name is not an instance of " raise ImproperlyConfigured("'%s.fk_name is not an instance of "
"models.ForeignKey." % cls.__name__) "models.ForeignKey." % cls.__name__)
fk = _get_foreign_key(parent_model, cls.model, fk_name=cls.fk_name, can_fail=True)
# extra = 3 # extra = 3
# max_num = 0 # max_num = 0
for attr in ('extra', 'max_num'): for attr in ('extra', 'max_num'):
@ -169,7 +173,6 @@ def validate_inline(cls, parent, parent_model):
# exclude # exclude
if hasattr(cls, 'exclude') and cls.exclude: if hasattr(cls, 'exclude') and cls.exclude:
fk = _get_foreign_key(parent_model, cls.model, can_fail=True)
if fk and fk.name in cls.exclude: if fk and fk.name in cls.exclude:
raise ImproperlyConfigured("%s cannot exclude the field " raise ImproperlyConfigured("%s cannot exclude the field "
"'%s' - this is the foreign key to the parent model " "'%s' - this is the foreign key to the parent model "
@ -193,6 +196,11 @@ def validate_base(cls, model):
check_isseq(cls, 'fields', cls.fields) check_isseq(cls, 'fields', cls.fields)
for field in cls.fields: for field in cls.fields:
check_formfield(cls, model, opts, 'fields', field) check_formfield(cls, model, opts, 'fields', field)
f = get_field(cls, model, opts, 'fields', field)
if isinstance(f, models.ManyToManyField) and not f.rel.through._meta.auto_created:
raise ImproperlyConfigured("'%s.fields' can't include the ManyToManyField "
"field '%s' because '%s' manually specifies "
"a 'through' model." % (cls.__name__, field, field))
if cls.fieldsets: if cls.fieldsets:
raise ImproperlyConfigured('Both fieldsets and fields are specified in %s.' % cls.__name__) raise ImproperlyConfigured('Both fieldsets and fields are specified in %s.' % cls.__name__)
if len(cls.fields) > len(set(cls.fields)): if len(cls.fields) > len(set(cls.fields)):
@ -211,11 +219,28 @@ def validate_base(cls, model):
raise ImproperlyConfigured("'fields' key is required in " raise ImproperlyConfigured("'fields' key is required in "
"%s.fieldsets[%d][1] field options dict." "%s.fieldsets[%d][1] field options dict."
% (cls.__name__, idx)) % (cls.__name__, idx))
for fields in fieldset[1]['fields']:
# The entry in fields might be a tuple. If it is a standalone
# field, make it into a tuple to make processing easier.
if type(fields) != tuple:
fields = (fields,)
for field in fields:
check_formfield(cls, model, opts, "fieldsets[%d][1]['fields']" % idx, field)
try:
f = opts.get_field(field)
if isinstance(f, models.ManyToManyField) and not f.rel.through._meta.auto_created:
raise ImproperlyConfigured("'%s.fieldsets[%d][1]['fields']' "
"can't include the ManyToManyField field '%s' because "
"'%s' manually specifies a 'through' model." % (
cls.__name__, idx, field, field))
except models.FieldDoesNotExist:
# If we can't find a field on the model that matches,
# it could be an extra field on the form.
pass
flattened_fieldsets = flatten_fieldsets(cls.fieldsets) flattened_fieldsets = flatten_fieldsets(cls.fieldsets)
if len(flattened_fieldsets) > len(set(flattened_fieldsets)): if len(flattened_fieldsets) > len(set(flattened_fieldsets)):
raise ImproperlyConfigured('There are duplicate field(s) in %s.fieldsets' % cls.__name__) raise ImproperlyConfigured('There are duplicate field(s) in %s.fieldsets' % cls.__name__)
for field in flattened_fieldsets:
check_formfield(cls, model, opts, "fieldsets[%d][1]['fields']" % idx, field)
# form # form
if hasattr(cls, 'form') and not issubclass(cls.form, BaseModelForm): if hasattr(cls, 'form') and not issubclass(cls.form, BaseModelForm):

View File

@ -2,7 +2,7 @@ from datetime import datetime
from django.conf import settings from django.conf import settings
from django.contrib.auth.backends import RemoteUserBackend from django.contrib.auth.backends import RemoteUserBackend
from django.contrib.auth.models import AnonymousUser, User from django.contrib.auth.models import User
from django.test import TestCase from django.test import TestCase
@ -30,15 +30,15 @@ class RemoteUserTest(TestCase):
num_users = User.objects.count() num_users = User.objects.count()
response = self.client.get('/remote_user/') response = self.client.get('/remote_user/')
self.assert_(isinstance(response.context['user'], AnonymousUser)) self.assert_(response.context['user'].is_anonymous())
self.assertEqual(User.objects.count(), num_users) self.assertEqual(User.objects.count(), num_users)
response = self.client.get('/remote_user/', REMOTE_USER=None) response = self.client.get('/remote_user/', REMOTE_USER=None)
self.assert_(isinstance(response.context['user'], AnonymousUser)) self.assert_(response.context['user'].is_anonymous())
self.assertEqual(User.objects.count(), num_users) self.assertEqual(User.objects.count(), num_users)
response = self.client.get('/remote_user/', REMOTE_USER='') response = self.client.get('/remote_user/', REMOTE_USER='')
self.assert_(isinstance(response.context['user'], AnonymousUser)) self.assert_(response.context['user'].is_anonymous())
self.assertEqual(User.objects.count(), num_users) self.assertEqual(User.objects.count(), num_users)
def test_unknown_user(self): def test_unknown_user(self):
@ -115,7 +115,7 @@ class RemoteUserNoCreateTest(RemoteUserTest):
def test_unknown_user(self): def test_unknown_user(self):
num_users = User.objects.count() num_users = User.objects.count()
response = self.client.get('/remote_user/', REMOTE_USER='newuser') response = self.client.get('/remote_user/', REMOTE_USER='newuser')
self.assert_(isinstance(response.context['user'], AnonymousUser)) self.assert_(response.context['user'].is_anonymous())
self.assertEqual(User.objects.count(), num_users) self.assertEqual(User.objects.count(), num_users)

View File

@ -4,6 +4,7 @@ from django.contrib.auth.decorators import login_required
from django.contrib.auth.forms import AuthenticationForm from django.contrib.auth.forms import AuthenticationForm
from django.contrib.auth.forms import PasswordResetForm, SetPasswordForm, PasswordChangeForm from django.contrib.auth.forms import PasswordResetForm, SetPasswordForm, PasswordChangeForm
from django.contrib.auth.tokens import default_token_generator from django.contrib.auth.tokens import default_token_generator
from django.views.decorators.csrf import csrf_protect
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.shortcuts import render_to_response, get_object_or_404 from django.shortcuts import render_to_response, get_object_or_404
from django.contrib.sites.models import Site, RequestSite from django.contrib.sites.models import Site, RequestSite
@ -14,11 +15,15 @@ from django.utils.translation import ugettext as _
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.views.decorators.cache import never_cache from django.views.decorators.cache import never_cache
def login(request, template_name='registration/login.html', redirect_field_name=REDIRECT_FIELD_NAME): @csrf_protect
@never_cache
def login(request, template_name='registration/login.html',
redirect_field_name=REDIRECT_FIELD_NAME,
authentication_form=AuthenticationForm):
"Displays the login form and handles the login action." "Displays the login form and handles the login action."
redirect_to = request.REQUEST.get(redirect_field_name, '') redirect_to = request.REQUEST.get(redirect_field_name, '')
if request.method == "POST": if request.method == "POST":
form = AuthenticationForm(data=request.POST) form = authentication_form(data=request.POST)
if form.is_valid(): if form.is_valid():
# Light security check -- make sure redirect_to isn't garbage. # Light security check -- make sure redirect_to isn't garbage.
if not redirect_to or '//' in redirect_to or ' ' in redirect_to: if not redirect_to or '//' in redirect_to or ' ' in redirect_to:
@ -29,7 +34,7 @@ def login(request, template_name='registration/login.html', redirect_field_name=
request.session.delete_test_cookie() request.session.delete_test_cookie()
return HttpResponseRedirect(redirect_to) return HttpResponseRedirect(redirect_to)
else: else:
form = AuthenticationForm(request) form = authentication_form(request)
request.session.set_test_cookie() request.session.set_test_cookie()
if Site._meta.installed: if Site._meta.installed:
current_site = Site.objects.get_current() current_site = Site.objects.get_current()
@ -41,7 +46,6 @@ def login(request, template_name='registration/login.html', redirect_field_name=
'site': current_site, 'site': current_site,
'site_name': current_site.name, 'site_name': current_site.name,
}, context_instance=RequestContext(request)) }, context_instance=RequestContext(request))
login = never_cache(login)
def logout(request, next_page=None, template_name='registration/logged_out.html', redirect_field_name=REDIRECT_FIELD_NAME): def logout(request, next_page=None, template_name='registration/logged_out.html', redirect_field_name=REDIRECT_FIELD_NAME):
"Logs out the user and displays 'You are logged out' message." "Logs out the user and displays 'You are logged out' message."
@ -78,6 +82,7 @@ def redirect_to_login(next, login_url=None, redirect_field_name=REDIRECT_FIELD_N
# prompts for a new password # prompts for a new password
# - password_reset_complete shows a success message for the above # - password_reset_complete shows a success message for the above
@csrf_protect
def password_reset(request, is_admin_site=False, template_name='registration/password_reset_form.html', def password_reset(request, is_admin_site=False, template_name='registration/password_reset_form.html',
email_template_name='registration/password_reset_email.html', email_template_name='registration/password_reset_email.html',
password_reset_form=PasswordResetForm, token_generator=default_token_generator, password_reset_form=PasswordResetForm, token_generator=default_token_generator,
@ -107,6 +112,7 @@ def password_reset(request, is_admin_site=False, template_name='registration/pas
def password_reset_done(request, template_name='registration/password_reset_done.html'): def password_reset_done(request, template_name='registration/password_reset_done.html'):
return render_to_response(template_name, context_instance=RequestContext(request)) return render_to_response(template_name, context_instance=RequestContext(request))
# Doesn't need csrf_protect since no-one can guess the URL
def password_reset_confirm(request, uidb36=None, token=None, template_name='registration/password_reset_confirm.html', def password_reset_confirm(request, uidb36=None, token=None, template_name='registration/password_reset_confirm.html',
token_generator=default_token_generator, set_password_form=SetPasswordForm, token_generator=default_token_generator, set_password_form=SetPasswordForm,
post_reset_redirect=None): post_reset_redirect=None):
@ -144,21 +150,22 @@ def password_reset_complete(request, template_name='registration/password_reset_
return render_to_response(template_name, context_instance=RequestContext(request, return render_to_response(template_name, context_instance=RequestContext(request,
{'login_url': settings.LOGIN_URL})) {'login_url': settings.LOGIN_URL}))
@csrf_protect
@login_required
def password_change(request, template_name='registration/password_change_form.html', def password_change(request, template_name='registration/password_change_form.html',
post_change_redirect=None): post_change_redirect=None, password_change_form=PasswordChangeForm):
if post_change_redirect is None: if post_change_redirect is None:
post_change_redirect = reverse('django.contrib.auth.views.password_change_done') post_change_redirect = reverse('django.contrib.auth.views.password_change_done')
if request.method == "POST": if request.method == "POST":
form = PasswordChangeForm(request.user, request.POST) form = password_change_form(user=request.user, data=request.POST)
if form.is_valid(): if form.is_valid():
form.save() form.save()
return HttpResponseRedirect(post_change_redirect) return HttpResponseRedirect(post_change_redirect)
else: else:
form = PasswordChangeForm(request.user) form = password_change_form(user=request.user)
return render_to_response(template_name, { return render_to_response(template_name, {
'form': form, 'form': form,
}, context_instance=RequestContext(request)) }, context_instance=RequestContext(request))
password_change = login_required(password_change)
def password_change_done(request, template_name='registration/password_change_done.html'): def password_change_done(request, template_name='registration/password_change_done.html'):
return render_to_response(template_name, context_instance=RequestContext(request)) return render_to_response(template_name, context_instance=RequestContext(request))

View File

@ -1,7 +1,8 @@
from django.contrib import admin from django.contrib import admin
from django.contrib.comments.models import Comment from django.contrib.comments.models import Comment
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _, ungettext
from django.contrib.comments import get_model from django.contrib.comments import get_model
from django.contrib.comments.views.moderation import perform_flag, perform_approve, perform_delete
class CommentsAdmin(admin.ModelAdmin): class CommentsAdmin(admin.ModelAdmin):
fieldsets = ( fieldsets = (
@ -22,6 +23,44 @@ class CommentsAdmin(admin.ModelAdmin):
ordering = ('-submit_date',) ordering = ('-submit_date',)
raw_id_fields = ('user',) raw_id_fields = ('user',)
search_fields = ('comment', 'user__username', 'user_name', 'user_email', 'user_url', 'ip_address') search_fields = ('comment', 'user__username', 'user_name', 'user_email', 'user_url', 'ip_address')
actions = ["flag_comments", "approve_comments", "remove_comments"]
def get_actions(self, request):
actions = super(CommentsAdmin, self).get_actions(request)
# Only superusers should be able to delete the comments from the DB.
if not request.user.is_superuser:
actions.pop('delete_selected')
if not request.user.has_perm('comments.can_moderate'):
actions.pop('approve_comments')
actions.pop('remove_comments')
return actions
def flag_comments(self, request, queryset):
self._bulk_flag(request, queryset, perform_flag, _("flagged"))
flag_comments.short_description = _("Flag selected comments")
def approve_comments(self, request, queryset):
self._bulk_flag(request, queryset, perform_approve, _('approved'))
approve_comments.short_description = _("Approve selected comments")
def remove_comments(self, request, queryset):
self._bulk_flag(request, queryset, perform_delete, _('removed'))
remove_comments.short_description = _("Remove selected comments")
def _bulk_flag(self, request, queryset, action, description):
"""
Flag, approve, or remove some comments from an admin action. Actually
calls the `action` argument to perform the heavy lifting.
"""
n_comments = 0
for comment in queryset:
action(request, comment)
n_comments += 1
msg = ungettext(u'1 comment was successfully %(action)s.',
u'%(count)s comments were successfully %(action)s.',
n_comments)
self.message_user(request, msg % {'count': n_comments, 'action': description})
# Only register the default admin if the model is the built-in comment model # Only register the default admin if the model is the built-in comment model
# (this won't be true if there's a custom comment app). # (this won't be true if there's a custom comment app).

View File

@ -6,7 +6,7 @@
{% block content %} {% block content %}
<h1>{% trans "Really make this comment public?" %}</h1> <h1>{% trans "Really make this comment public?" %}</h1>
<blockquote>{{ comment|linebreaks }}</blockquote> <blockquote>{{ comment|linebreaks }}</blockquote>
<form action="." method="post"> <form action="." method="post">{% csrf_token %}
{% if next %}<input type="hidden" name="next" value="{{ next }}" id="next" />{% endif %} {% if next %}<input type="hidden" name="next" value="{{ next }}" id="next" />{% endif %}
<p class="submit"> <p class="submit">
<input type="submit" name="submit" value="{% trans "Approve" %}" /> or <a href="{{ comment.get_absolute_url }}">cancel</a> <input type="submit" name="submit" value="{% trans "Approve" %}" /> or <a href="{{ comment.get_absolute_url }}">cancel</a>

View File

@ -6,7 +6,7 @@
{% block content %} {% block content %}
<h1>{% trans "Really remove this comment?" %}</h1> <h1>{% trans "Really remove this comment?" %}</h1>
<blockquote>{{ comment|linebreaks }}</blockquote> <blockquote>{{ comment|linebreaks }}</blockquote>
<form action="." method="post"> <form action="." method="post">{% csrf_token %}
{% if next %}<input type="hidden" name="next" value="{{ next }}" id="next" />{% endif %} {% if next %}<input type="hidden" name="next" value="{{ next }}" id="next" />{% endif %}
<p class="submit"> <p class="submit">
<input type="submit" name="submit" value="{% trans "Remove" %}" /> or <a href="{{ comment.get_absolute_url }}">cancel</a> <input type="submit" name="submit" value="{% trans "Remove" %}" /> or <a href="{{ comment.get_absolute_url }}">cancel</a>

View File

@ -6,7 +6,7 @@
{% block content %} {% block content %}
<h1>{% trans "Really flag this comment?" %}</h1> <h1>{% trans "Really flag this comment?" %}</h1>
<blockquote>{{ comment|linebreaks }}</blockquote> <blockquote>{{ comment|linebreaks }}</blockquote>
<form action="." method="post"> <form action="." method="post">{% csrf_token %}
{% if next %}<input type="hidden" name="next" value="{{ next }}" id="next" />{% endif %} {% if next %}<input type="hidden" name="next" value="{{ next }}" id="next" />{% endif %}
<p class="submit"> <p class="submit">
<input type="submit" name="submit" value="{% trans "Flag" %}" /> or <a href="{{ comment.get_absolute_url }}">cancel</a> <input type="submit" name="submit" value="{% trans "Flag" %}" /> or <a href="{{ comment.get_absolute_url }}">cancel</a>

View File

@ -1,5 +1,5 @@
{% load comments i18n %} {% load comments i18n %}
<form action="{% comment_form_target %}" method="post"> <form action="{% comment_form_target %}" method="post">{% csrf_token %}
{% if next %}<input type="hidden" name="next" value="{{ next }}" />{% endif %} {% if next %}<input type="hidden" name="next" value="{{ next }}" />{% endif %}
{% for field in form %} {% for field in form %}
{% if field.is_hidden %} {% if field.is_hidden %}

View File

@ -1,75 +0,0 @@
{% extends "admin/change_list.html" %}
{% load adminmedia i18n %}
{% block title %}{% trans "Comment moderation queue" %}{% endblock %}
{% block extrahead %}
{{ block.super }}
<style type="text/css" media="screen">
p#nocomments { font-size: 200%; text-align: center; border: 1px #ccc dashed; padding: 4em; }
td.actions { width: 11em; }
td.actions form { display: inline; }
td.actions form input.submit { width: 5em; padding: 2px 4px; margin-right: 4px;}
td.actions form input.approve { background: green; color: white; }
td.actions form input.remove { background: red; color: white; }
</style>
{% endblock %}
{% block branding %}
<h1 id="site-name">{% trans "Comment moderation queue" %}</h1>
{% endblock %}
{% block breadcrumbs %}{% endblock %}
{% block content %}
{% if empty %}
<p id="nocomments">{% trans "No comments to moderate" %}.</p>
{% else %}
<div id="content-main">
<div class="module" id="changelist">
<table cellspacing="0">
<thead>
<tr>
<th>{% trans "Action" %}</th>
<th>{% trans "Name" %}</th>
<th>{% trans "Comment" %}</th>
<th>{% trans "Email" %}</th>
<th>{% trans "URL" %}</th>
<th>{% trans "Authenticated?" %}</th>
<th>{% trans "IP Address" %}</th>
<th class="sorted desc">{% trans "Date posted" %}</th>
</tr>
</thead>
<tbody>
{% for comment in comments %}
<tr class="{% cycle 'row1' 'row2' %}">
<td class="actions">
<form action="{% url comments-approve comment.pk %}" method="post">
<input type="hidden" name="next" value="{% url comments-moderation-queue %}" />
<input class="approve submit" type="submit" name="submit" value="{% trans "Approve" %}" />
</form>
<form action="{% url comments-delete comment.pk %}" method="post">
<input type="hidden" name="next" value="{% url comments-moderation-queue %}" />
<input class="remove submit" type="submit" name="submit" value="{% trans "Remove" %}" />
</form>
</td>
<td>{{ comment.name }}</td>
<td>{{ comment.comment|truncatewords:"50" }}</td>
<td>{{ comment.email }}</td>
<td>{{ comment.url }}</td>
<td>
<img
src="{% admin_media_prefix %}img/admin/icon-{% if comment.user %}yes{% else %}no{% endif %}.gif"
alt="{% if comment.user %}{% trans "yes" %}{% else %}{% trans "no" %}{% endif %}"
/>
</td>
<td>{{ comment.ip_address }}</td>
<td>{{ comment.submit_date|date:"F j, P" }}</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
</div>
{% endif %}
{% endblock %}

View File

@ -5,7 +5,7 @@
{% block content %} {% block content %}
{% load comments %} {% load comments %}
<form action="{% comment_form_target %}" method="post"> <form action="{% comment_form_target %}" method="post">{% csrf_token %}
{% if next %}<input type="hidden" name="next" value="{{ next }}" />{% endif %} {% if next %}<input type="hidden" name="next" value="{{ next }}" />{% endif %}
{% if form.errors %} {% if form.errors %}
<h1>{% blocktrans count form.errors|length as counter %}Please correct the error below{% plural %}Please correct the errors below{% endblocktrans %}</h1> <h1>{% blocktrans count form.errors|length as counter %}Please correct the error below{% plural %}Please correct the errors below{% endblocktrans %}</h1>

View File

@ -7,7 +7,6 @@ urlpatterns = patterns('django.contrib.comments.views',
url(r'^flagged/$', 'moderation.flag_done', name='comments-flag-done'), url(r'^flagged/$', 'moderation.flag_done', name='comments-flag-done'),
url(r'^delete/(\d+)/$', 'moderation.delete', name='comments-delete'), url(r'^delete/(\d+)/$', 'moderation.delete', name='comments-delete'),
url(r'^deleted/$', 'moderation.delete_done', name='comments-delete-done'), url(r'^deleted/$', 'moderation.delete_done', name='comments-delete-done'),
url(r'^moderate/$', 'moderation.moderation_queue', name='comments-moderation-queue'),
url(r'^approve/(\d+)/$', 'moderation.approve', name='comments-approve'), url(r'^approve/(\d+)/$', 'moderation.approve', name='comments-approve'),
url(r'^approved/$', 'moderation.approve_done', name='comments-approve-done'), url(r'^approved/$', 'moderation.approve_done', name='comments-approve-done'),
) )

View File

@ -10,6 +10,7 @@ from django.utils.html import escape
from django.views.decorators.http import require_POST from django.views.decorators.http import require_POST
from django.contrib import comments from django.contrib import comments
from django.contrib.comments import signals from django.contrib.comments import signals
from django.views.decorators.csrf import csrf_protect
class CommentPostBadRequest(http.HttpResponseBadRequest): class CommentPostBadRequest(http.HttpResponseBadRequest):
""" """
@ -22,6 +23,8 @@ class CommentPostBadRequest(http.HttpResponseBadRequest):
if settings.DEBUG: if settings.DEBUG:
self.content = render_to_string("comments/400-debug.html", {"why": why}) self.content = render_to_string("comments/400-debug.html", {"why": why})
@csrf_protect
@require_POST
def post_comment(request, next=None): def post_comment(request, next=None):
""" """
Post a comment. Post a comment.
@ -116,8 +119,6 @@ def post_comment(request, next=None):
return next_redirect(data, next, comment_done, c=comment._get_pk_val()) return next_redirect(data, next, comment_done, c=comment._get_pk_val())
post_comment = require_POST(post_comment)
comment_done = confirmation_view( comment_done = confirmation_view(
template = "comments/posted.html", template = "comments/posted.html",
doc = """Display a "comment was posted" success page.""" doc = """Display a "comment was posted" success page."""

View File

@ -3,12 +3,12 @@ from django.conf import settings
from django.shortcuts import get_object_or_404, render_to_response from django.shortcuts import get_object_or_404, render_to_response
from django.contrib.auth.decorators import login_required, permission_required from django.contrib.auth.decorators import login_required, permission_required
from utils import next_redirect, confirmation_view from utils import next_redirect, confirmation_view
from django.core.paginator import Paginator, InvalidPage
from django.http import Http404
from django.contrib import comments from django.contrib import comments
from django.contrib.comments import signals from django.contrib.comments import signals
from django.views.decorators.csrf import csrf_protect
#@login_required @csrf_protect
@login_required
def flag(request, comment_id, next=None): def flag(request, comment_id, next=None):
""" """
Flags a comment. Confirmation on GET, action on POST. Flags a comment. Confirmation on GET, action on POST.
@ -22,18 +22,7 @@ def flag(request, comment_id, next=None):
# Flag on POST # Flag on POST
if request.method == 'POST': if request.method == 'POST':
flag, created = comments.models.CommentFlag.objects.get_or_create( perform_flag(request, comment)
comment = comment,
user = request.user,
flag = comments.models.CommentFlag.SUGGEST_REMOVAL
)
signals.comment_was_flagged.send(
sender = comment.__class__,
comment = comment,
flag = flag,
created = created,
request = request,
)
return next_redirect(request.POST.copy(), next, flag_done, c=comment.pk) return next_redirect(request.POST.copy(), next, flag_done, c=comment.pk)
# Render a form on GET # Render a form on GET
@ -42,9 +31,9 @@ def flag(request, comment_id, next=None):
{'comment': comment, "next": next}, {'comment': comment, "next": next},
template.RequestContext(request) template.RequestContext(request)
) )
flag = login_required(flag)
#@permission_required("comments.delete_comment") @csrf_protect
@permission_required("comments.can_moderate")
def delete(request, comment_id, next=None): def delete(request, comment_id, next=None):
""" """
Deletes a comment. Confirmation on GET, action on POST. Requires the "can Deletes a comment. Confirmation on GET, action on POST. Requires the "can
@ -60,20 +49,7 @@ def delete(request, comment_id, next=None):
# Delete on POST # Delete on POST
if request.method == 'POST': if request.method == 'POST':
# Flag the comment as deleted instead of actually deleting it. # Flag the comment as deleted instead of actually deleting it.
flag, created = comments.models.CommentFlag.objects.get_or_create( perform_delete(request, comment)
comment = comment,
user = request.user,
flag = comments.models.CommentFlag.MODERATOR_DELETION
)
comment.is_removed = True
comment.save()
signals.comment_was_flagged.send(
sender = comment.__class__,
comment = comment,
flag = flag,
created = created,
request = request,
)
return next_redirect(request.POST.copy(), next, delete_done, c=comment.pk) return next_redirect(request.POST.copy(), next, delete_done, c=comment.pk)
# Render a form on GET # Render a form on GET
@ -82,9 +58,9 @@ def delete(request, comment_id, next=None):
{'comment': comment, "next": next}, {'comment': comment, "next": next},
template.RequestContext(request) template.RequestContext(request)
) )
delete = permission_required("comments.can_moderate")(delete)
#@permission_required("comments.can_moderate") @csrf_protect
@permission_required("comments.can_moderate")
def approve(request, comment_id, next=None): def approve(request, comment_id, next=None):
""" """
Approve a comment (that is, mark it as public and non-removed). Confirmation Approve a comment (that is, mark it as public and non-removed). Confirmation
@ -100,23 +76,7 @@ def approve(request, comment_id, next=None):
# Delete on POST # Delete on POST
if request.method == 'POST': if request.method == 'POST':
# Flag the comment as approved. # Flag the comment as approved.
flag, created = comments.models.CommentFlag.objects.get_or_create( perform_approve(request, comment)
comment = comment,
user = request.user,
flag = comments.models.CommentFlag.MODERATOR_APPROVAL,
)
comment.is_removed = False
comment.is_public = True
comment.save()
signals.comment_was_flagged.send(
sender = comment.__class__,
comment = comment,
flag = flag,
created = created,
request = request,
)
return next_redirect(request.POST.copy(), next, approve_done, c=comment.pk) return next_redirect(request.POST.copy(), next, approve_done, c=comment.pk)
# Render a form on GET # Render a form on GET
@ -126,69 +86,64 @@ def approve(request, comment_id, next=None):
template.RequestContext(request) template.RequestContext(request)
) )
approve = permission_required("comments.can_moderate")(approve) # The following functions actually perform the various flag/aprove/delete
# actions. They've been broken out into seperate functions to that they
# may be called from admin actions.
def perform_flag(request, comment):
#@permission_required("comments.can_moderate")
def moderation_queue(request):
""" """
Displays a list of unapproved comments to be approved. Actually perform the flagging of a comment from a request.
Templates: `comments/moderation_queue.html`
Context:
comments
Comments to be approved (paginated).
empty
Is the comment list empty?
is_paginated
Is there more than one page?
results_per_page
Number of comments per page
has_next
Is there a next page?
has_previous
Is there a previous page?
page
The current page number
next
The next page number
pages
Number of pages
hits
Total number of comments
page_range
Range of page numbers
""" """
qs = comments.get_model().objects.filter(is_public=False, is_removed=False) flag, created = comments.models.CommentFlag.objects.get_or_create(
paginator = Paginator(qs, 100) comment = comment,
user = request.user,
flag = comments.models.CommentFlag.SUGGEST_REMOVAL
)
signals.comment_was_flagged.send(
sender = comment.__class__,
comment = comment,
flag = flag,
created = created,
request = request,
)
try: def perform_delete(request, comment):
page = int(request.GET.get("page", 1)) flag, created = comments.models.CommentFlag.objects.get_or_create(
except ValueError: comment = comment,
raise Http404 user = request.user,
flag = comments.models.CommentFlag.MODERATOR_DELETION
)
comment.is_removed = True
comment.save()
signals.comment_was_flagged.send(
sender = comment.__class__,
comment = comment,
flag = flag,
created = created,
request = request,
)
try:
comments_per_page = paginator.page(page)
except InvalidPage:
raise Http404
return render_to_response("comments/moderation_queue.html", { def perform_approve(request, comment):
'comments' : comments_per_page.object_list, flag, created = comments.models.CommentFlag.objects.get_or_create(
'empty' : page == 1 and paginator.count == 0, comment = comment,
'is_paginated': paginator.num_pages > 1, user = request.user,
'results_per_page': 100, flag = comments.models.CommentFlag.MODERATOR_APPROVAL,
'has_next': comments_per_page.has_next(), )
'has_previous': comments_per_page.has_previous(),
'page': page,
'next': page + 1,
'previous': page - 1,
'pages': paginator.num_pages,
'hits' : paginator.count,
'page_range' : paginator.page_range
}, context_instance=template.RequestContext(request))
moderation_queue = permission_required("comments.can_moderate")(moderation_queue) comment.is_removed = False
comment.is_public = True
comment.save()
signals.comment_was_flagged.send(
sender = comment.__class__,
comment = comment,
flag = flag,
created = created,
request = request,
)
# Confirmation views.
flag_done = confirmation_view( flag_done = confirmation_view(
template = "comments/flagged.html", template = "comments/flagged.html",

View File

@ -105,8 +105,6 @@ class GenericRelation(RelatedField, Field):
limit_choices_to=kwargs.pop('limit_choices_to', None), limit_choices_to=kwargs.pop('limit_choices_to', None),
symmetrical=kwargs.pop('symmetrical', True)) symmetrical=kwargs.pop('symmetrical', True))
# By its very nature, a GenericRelation doesn't create a table.
self.creates_table = False
# Override content-type/object-id field names on the related class # Override content-type/object-id field names on the related class
self.object_id_field_name = kwargs.pop("object_id_field", "object_id") self.object_id_field_name = kwargs.pop("object_id_field", "object_id")

View File

@ -1,160 +1,7 @@
""" from django.middleware.csrf import CsrfMiddleware, CsrfViewMiddleware, CsrfResponseMiddleware
Cross Site Request Forgery Middleware. from django.views.decorators.csrf import csrf_exempt, csrf_view_exempt, csrf_response_exempt
This module provides a middleware that implements protection import warnings
against request forgeries from other sites. warnings.warn("This import for CSRF functionality is deprecated. Please use django.middleware.csrf for the middleware and django.views.decorators.csrf for decorators.",
""" PendingDeprecationWarning
)
import re
import itertools
try:
from functools import wraps
except ImportError:
from django.utils.functional import wraps # Python 2.3, 2.4 fallback.
from django.conf import settings
from django.http import HttpResponseForbidden
from django.utils.hashcompat import md5_constructor
from django.utils.safestring import mark_safe
_ERROR_MSG = mark_safe('<html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en"><body><h1>403 Forbidden</h1><p>Cross Site Request Forgery detected. Request aborted.</p></body></html>')
_POST_FORM_RE = \
re.compile(r'(<form\W[^>]*\bmethod\s*=\s*(\'|"|)POST(\'|"|)\b[^>]*>)', re.IGNORECASE)
_HTML_TYPES = ('text/html', 'application/xhtml+xml')
def _make_token(session_id):
return md5_constructor(settings.SECRET_KEY + session_id).hexdigest()
class CsrfViewMiddleware(object):
"""
Middleware that requires a present and correct csrfmiddlewaretoken
for POST requests that have an active session.
"""
def process_view(self, request, callback, callback_args, callback_kwargs):
if request.method == 'POST':
if getattr(callback, 'csrf_exempt', False):
return None
if request.is_ajax():
return None
try:
session_id = request.COOKIES[settings.SESSION_COOKIE_NAME]
except KeyError:
# No session, no check required
return None
csrf_token = _make_token(session_id)
# check incoming token
try:
request_csrf_token = request.POST['csrfmiddlewaretoken']
except KeyError:
return HttpResponseForbidden(_ERROR_MSG)
if request_csrf_token != csrf_token:
return HttpResponseForbidden(_ERROR_MSG)
return None
class CsrfResponseMiddleware(object):
"""
Middleware that post-processes a response to add a
csrfmiddlewaretoken if the response/request have an active
session.
"""
def process_response(self, request, response):
if getattr(response, 'csrf_exempt', False):
return response
csrf_token = None
try:
# This covers a corner case in which the outgoing response
# both contains a form and sets a session cookie. This
# really should not be needed, since it is best if views
# that create a new session (login pages) also do a
# redirect, as is done by all such view functions in
# Django.
cookie = response.cookies[settings.SESSION_COOKIE_NAME]
csrf_token = _make_token(cookie.value)
except KeyError:
# Normal case - look for existing session cookie
try:
session_id = request.COOKIES[settings.SESSION_COOKIE_NAME]
csrf_token = _make_token(session_id)
except KeyError:
# no incoming or outgoing cookie
pass
if csrf_token is not None and \
response['Content-Type'].split(';')[0] in _HTML_TYPES:
# ensure we don't add the 'id' attribute twice (HTML validity)
idattributes = itertools.chain(("id='csrfmiddlewaretoken'",),
itertools.repeat(''))
def add_csrf_field(match):
"""Returns the matched <form> tag plus the added <input> element"""
return mark_safe(match.group() + "<div style='display:none;'>" + \
"<input type='hidden' " + idattributes.next() + \
" name='csrfmiddlewaretoken' value='" + csrf_token + \
"' /></div>")
# Modify any POST forms
response.content = _POST_FORM_RE.sub(add_csrf_field, response.content)
return response
class CsrfMiddleware(CsrfViewMiddleware, CsrfResponseMiddleware):
"""Django middleware that adds protection against Cross Site
Request Forgeries by adding hidden form fields to POST forms and
checking requests for the correct value.
In the list of middlewares, SessionMiddleware is required, and
must come after this middleware. CsrfMiddleWare must come after
compression middleware.
If a session ID cookie is present, it is hashed with the
SECRET_KEY setting to create an authentication token. This token
is added to all outgoing POST forms and is expected on all
incoming POST requests that have a session ID cookie.
If you are setting cookies directly, instead of using Django's
session framework, this middleware will not work.
CsrfMiddleWare is composed of two middleware, CsrfViewMiddleware
and CsrfResponseMiddleware which can be used independently.
"""
pass
def csrf_response_exempt(view_func):
"""
Modifies a view function so that its response is exempt
from the post-processing of the CSRF middleware.
"""
def wrapped_view(*args, **kwargs):
resp = view_func(*args, **kwargs)
resp.csrf_exempt = True
return resp
return wraps(view_func)(wrapped_view)
def csrf_view_exempt(view_func):
"""
Marks a view function as being exempt from CSRF view protection.
"""
# We could just do view_func.csrf_exempt = True, but decorators
# are nicer if they don't have side-effects, so we return a new
# function.
def wrapped_view(*args, **kwargs):
return view_func(*args, **kwargs)
wrapped_view.csrf_exempt = True
return wraps(view_func)(wrapped_view)
def csrf_exempt(view_func):
"""
Marks a view function as being exempt from the CSRF checks
and post processing.
This is the same as using both the csrf_view_exempt and
csrf_response_exempt decorators.
"""
return csrf_response_exempt(csrf_view_exempt(view_func))

View File

@ -1,144 +0,0 @@
# -*- coding: utf-8 -*-
from django.test import TestCase
from django.http import HttpRequest, HttpResponse, HttpResponseForbidden
from django.contrib.csrf.middleware import CsrfMiddleware, _make_token, csrf_exempt
from django.conf import settings
def post_form_response():
resp = HttpResponse(content="""
<html><body><form method="POST"><input type="text" /></form></body></html>
""", mimetype="text/html")
return resp
def test_view(request):
return post_form_response()
class CsrfMiddlewareTest(TestCase):
_session_id = "1"
def _get_GET_no_session_request(self):
return HttpRequest()
def _get_GET_session_request(self):
req = self._get_GET_no_session_request()
req.COOKIES[settings.SESSION_COOKIE_NAME] = self._session_id
return req
def _get_POST_session_request(self):
req = self._get_GET_session_request()
req.method = "POST"
return req
def _get_POST_no_session_request(self):
req = self._get_GET_no_session_request()
req.method = "POST"
return req
def _get_POST_session_request_with_token(self):
req = self._get_POST_session_request()
req.POST['csrfmiddlewaretoken'] = _make_token(self._session_id)
return req
def _get_post_form_response(self):
return post_form_response()
def _get_new_session_response(self):
resp = self._get_post_form_response()
resp.cookies[settings.SESSION_COOKIE_NAME] = self._session_id
return resp
def _check_token_present(self, response):
self.assertContains(response, "name='csrfmiddlewaretoken' value='%s'" % _make_token(self._session_id))
def get_view(self):
return test_view
# Check the post processing
def test_process_response_no_session(self):
"""
Check the post-processor does nothing if no session active
"""
req = self._get_GET_no_session_request()
resp = self._get_post_form_response()
resp_content = resp.content # needed because process_response modifies resp
resp2 = CsrfMiddleware().process_response(req, resp)
self.assertEquals(resp_content, resp2.content)
def test_process_response_existing_session(self):
"""
Check that the token is inserted if there is an existing session
"""
req = self._get_GET_session_request()
resp = self._get_post_form_response()
resp_content = resp.content # needed because process_response modifies resp
resp2 = CsrfMiddleware().process_response(req, resp)
self.assertNotEqual(resp_content, resp2.content)
self._check_token_present(resp2)
def test_process_response_new_session(self):
"""
Check that the token is inserted if there is a new session being started
"""
req = self._get_GET_no_session_request() # no session in request
resp = self._get_new_session_response() # but new session started
resp_content = resp.content # needed because process_response modifies resp
resp2 = CsrfMiddleware().process_response(req, resp)
self.assertNotEqual(resp_content, resp2.content)
self._check_token_present(resp2)
def test_process_response_exempt_view(self):
"""
Check that no post processing is done for an exempt view
"""
req = self._get_POST_session_request()
resp = csrf_exempt(self.get_view())(req)
resp_content = resp.content
resp2 = CsrfMiddleware().process_response(req, resp)
self.assertEquals(resp_content, resp2.content)
# Check the request processing
def test_process_request_no_session(self):
"""
Check that if no session is present, the middleware does nothing.
to the incoming request.
"""
req = self._get_POST_no_session_request()
req2 = CsrfMiddleware().process_view(req, self.get_view(), (), {})
self.assertEquals(None, req2)
def test_process_request_session_no_token(self):
"""
Check that if a session is present but no token, we get a 'forbidden'
"""
req = self._get_POST_session_request()
req2 = CsrfMiddleware().process_view(req, self.get_view(), (), {})
self.assertEquals(HttpResponseForbidden, req2.__class__)
def test_process_request_session_and_token(self):
"""
Check that if a session is present and a token, the middleware lets it through
"""
req = self._get_POST_session_request_with_token()
req2 = CsrfMiddleware().process_view(req, self.get_view(), (), {})
self.assertEquals(None, req2)
def test_process_request_session_no_token_exempt_view(self):
"""
Check that if a session is present and no token, but the csrf_exempt
decorator has been applied to the view, the middleware lets it through
"""
req = self._get_POST_session_request()
req2 = CsrfMiddleware().process_view(req, csrf_exempt(self.get_view()), (), {})
self.assertEquals(None, req2)
def test_ajax_exemption(self):
"""
Check that AJAX requests are automatically exempted.
"""
req = self._get_POST_session_request()
req.META['HTTP_X_REQUESTED_WITH'] = 'XMLHttpRequest'
req2 = CsrfMiddleware().process_view(req, self.get_view(), (), {})
self.assertEquals(None, req2)

View File

@ -4,7 +4,7 @@
{% if form.errors %}<h1>Please correct the following errors</h1>{% else %}<h1>Submit</h1>{% endif %} {% if form.errors %}<h1>Please correct the following errors</h1>{% else %}<h1>Submit</h1>{% endif %}
<form action="" method="post"> <form action="" method="post">{% csrf_token %}
<table> <table>
{{ form }} {{ form }}
</table> </table>

View File

@ -15,7 +15,7 @@
<p>Security hash: {{ hash_value }}</p> <p>Security hash: {{ hash_value }}</p>
<form action="" method="post"> <form action="" method="post">{% csrf_token %}
{% for field in form %}{{ field.as_hidden }} {% for field in form %}{{ field.as_hidden }}
{% endfor %} {% endfor %}
<input type="hidden" name="{{ stage_field }}" value="2" /> <input type="hidden" name="{{ stage_field }}" value="2" />
@ -25,7 +25,7 @@
<h1>Or edit it again</h1> <h1>Or edit it again</h1>
<form action="" method="post"> <form action="" method="post">{% csrf_token %}
<table> <table>
{{ form }} {{ form }}
</table> </table>

View File

@ -147,15 +147,18 @@ class WizardPageTwoForm(forms.Form):
class WizardClass(wizard.FormWizard): class WizardClass(wizard.FormWizard):
def render_template(self, *args, **kw): def render_template(self, *args, **kw):
return "" return http.HttpResponse("")
def done(self, request, cleaned_data): def done(self, request, cleaned_data):
return http.HttpResponse(success_string) return http.HttpResponse(success_string)
class DummyRequest(object): class DummyRequest(http.HttpRequest):
def __init__(self, POST=None): def __init__(self, POST=None):
super(DummyRequest, self).__init__()
self.method = POST and "POST" or "GET" self.method = POST and "POST" or "GET"
self.POST = POST if POST is not None:
self.POST.update(POST)
self._dont_enforce_csrf_checks = True
class WizardTests(TestCase): class WizardTests(TestCase):
def test_step_starts_at_zero(self): def test_step_starts_at_zero(self):

View File

@ -14,6 +14,7 @@ from django.template.context import RequestContext
from django.utils.hashcompat import md5_constructor from django.utils.hashcompat import md5_constructor
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.contrib.formtools.utils import security_hash from django.contrib.formtools.utils import security_hash
from django.views.decorators.csrf import csrf_protect
class FormWizard(object): class FormWizard(object):
# Dictionary of extra template context variables. # Dictionary of extra template context variables.
@ -44,6 +45,7 @@ class FormWizard(object):
# hook methods might alter self.form_list. # hook methods might alter self.form_list.
return len(self.form_list) return len(self.form_list)
@csrf_protect
def __call__(self, request, *args, **kwargs): def __call__(self, request, *args, **kwargs):
""" """
Main method that does all the hard work, conforming to the Django view Main method that does all the hard work, conforming to the Django view

View File

@ -18,18 +18,21 @@ SpatialBackend = BaseSpatialBackend(name='postgis', postgis=True,
distance_spheroid=DISTANCE_SPHEROID, distance_spheroid=DISTANCE_SPHEROID,
envelope=ENVELOPE, envelope=ENVELOPE,
extent=EXTENT, extent=EXTENT,
extent3d=EXTENT3D,
gis_terms=POSTGIS_TERMS, gis_terms=POSTGIS_TERMS,
geojson=ASGEOJSON, geojson=ASGEOJSON,
gml=ASGML, gml=ASGML,
intersection=INTERSECTION, intersection=INTERSECTION,
kml=ASKML, kml=ASKML,
length=LENGTH, length=LENGTH,
length3d=LENGTH3D,
length_spheroid=LENGTH_SPHEROID, length_spheroid=LENGTH_SPHEROID,
make_line=MAKE_LINE, make_line=MAKE_LINE,
mem_size=MEM_SIZE, mem_size=MEM_SIZE,
num_geom=NUM_GEOM, num_geom=NUM_GEOM,
num_points=NUM_POINTS, num_points=NUM_POINTS,
perimeter=PERIMETER, perimeter=PERIMETER,
perimeter3d=PERIMETER3D,
point_on_surface=POINT_ON_SURFACE, point_on_surface=POINT_ON_SURFACE,
scale=SCALE, scale=SCALE,
select=GEOM_SELECT, select=GEOM_SELECT,

View File

@ -2,7 +2,7 @@
This object provides quoting for GEOS geometries into PostgreSQL/PostGIS. This object provides quoting for GEOS geometries into PostgreSQL/PostGIS.
""" """
from django.contrib.gis.db.backend.postgis.query import GEOM_FROM_WKB from django.contrib.gis.db.backend.postgis.query import GEOM_FROM_EWKB
from psycopg2 import Binary from psycopg2 import Binary
from psycopg2.extensions import ISQLQuote from psycopg2.extensions import ISQLQuote
@ -11,7 +11,7 @@ class PostGISAdaptor(object):
"Initializes on the geometry." "Initializes on the geometry."
# Getting the WKB (in string form, to allow easy pickling of # Getting the WKB (in string form, to allow easy pickling of
# the adaptor) and the SRID from the geometry. # the adaptor) and the SRID from the geometry.
self.wkb = str(geom.wkb) self.ewkb = str(geom.ewkb)
self.srid = geom.srid self.srid = geom.srid
def __conform__(self, proto): def __conform__(self, proto):
@ -30,7 +30,7 @@ class PostGISAdaptor(object):
def getquoted(self): def getquoted(self):
"Returns a properly quoted string for use in PostgreSQL/PostGIS." "Returns a properly quoted string for use in PostgreSQL/PostGIS."
# Want to use WKB, so wrap with psycopg2 Binary() to quote properly. # Want to use WKB, so wrap with psycopg2 Binary() to quote properly.
return "%s(%s, %s)" % (GEOM_FROM_WKB, Binary(self.wkb), self.srid or -1) return "%s(E%s)" % (GEOM_FROM_EWKB, Binary(self.ewkb))
def prepare_database_save(self, unused): def prepare_database_save(self, unused):
return self return self

View File

@ -63,17 +63,21 @@ if MAJOR_VERSION >= 1:
DISTANCE_SPHERE = get_func('distance_sphere') DISTANCE_SPHERE = get_func('distance_sphere')
DISTANCE_SPHEROID = get_func('distance_spheroid') DISTANCE_SPHEROID = get_func('distance_spheroid')
ENVELOPE = get_func('Envelope') ENVELOPE = get_func('Envelope')
EXTENT = get_func('extent') EXTENT = get_func('Extent')
EXTENT3D = get_func('Extent3D')
GEOM_FROM_TEXT = get_func('GeomFromText') GEOM_FROM_TEXT = get_func('GeomFromText')
GEOM_FROM_EWKB = get_func('GeomFromEWKB')
GEOM_FROM_WKB = get_func('GeomFromWKB') GEOM_FROM_WKB = get_func('GeomFromWKB')
INTERSECTION = get_func('Intersection') INTERSECTION = get_func('Intersection')
LENGTH = get_func('Length') LENGTH = get_func('Length')
LENGTH3D = get_func('Length3D')
LENGTH_SPHEROID = get_func('length_spheroid') LENGTH_SPHEROID = get_func('length_spheroid')
MAKE_LINE = get_func('MakeLine') MAKE_LINE = get_func('MakeLine')
MEM_SIZE = get_func('mem_size') MEM_SIZE = get_func('mem_size')
NUM_GEOM = get_func('NumGeometries') NUM_GEOM = get_func('NumGeometries')
NUM_POINTS = get_func('npoints') NUM_POINTS = get_func('npoints')
PERIMETER = get_func('Perimeter') PERIMETER = get_func('Perimeter')
PERIMETER3D = get_func('Perimeter3D')
POINT_ON_SURFACE = get_func('PointOnSurface') POINT_ON_SURFACE = get_func('PointOnSurface')
SCALE = get_func('Scale') SCALE = get_func('Scale')
SNAP_TO_GRID = get_func('SnapToGrid') SNAP_TO_GRID = get_func('SnapToGrid')

View File

@ -24,6 +24,9 @@ class Collect(GeoAggregate):
class Extent(GeoAggregate): class Extent(GeoAggregate):
name = 'Extent' name = 'Extent'
class Extent3D(GeoAggregate):
name = 'Extent3D'
class MakeLine(GeoAggregate): class MakeLine(GeoAggregate):
name = 'MakeLine' name = 'MakeLine'

View File

@ -34,6 +34,9 @@ class GeoManager(Manager):
def extent(self, *args, **kwargs): def extent(self, *args, **kwargs):
return self.get_query_set().extent(*args, **kwargs) return self.get_query_set().extent(*args, **kwargs)
def extent3d(self, *args, **kwargs):
return self.get_query_set().extent3d(*args, **kwargs)
def geojson(self, *args, **kwargs): def geojson(self, *args, **kwargs):
return self.get_query_set().geojson(*args, **kwargs) return self.get_query_set().geojson(*args, **kwargs)

View File

@ -110,6 +110,14 @@ class GeoQuerySet(QuerySet):
""" """
return self._spatial_aggregate(aggregates.Extent, **kwargs) return self._spatial_aggregate(aggregates.Extent, **kwargs)
def extent3d(self, **kwargs):
"""
Returns the aggregate extent, in 3D, of the features in the
GeoQuerySet. It is returned as a 6-tuple, comprising:
(xmin, ymin, zmin, xmax, ymax, zmax).
"""
return self._spatial_aggregate(aggregates.Extent3D, **kwargs)
def geojson(self, precision=8, crs=False, bbox=False, **kwargs): def geojson(self, precision=8, crs=False, bbox=False, **kwargs):
""" """
Returns a GeoJSON representation of the geomtry field in a `geojson` Returns a GeoJSON representation of the geomtry field in a `geojson`
@ -524,12 +532,14 @@ class GeoQuerySet(QuerySet):
else: else:
dist_att = Distance.unit_attname(geo_field.units_name) dist_att = Distance.unit_attname(geo_field.units_name)
# Shortcut booleans for what distance function we're using. # Shortcut booleans for what distance function we're using and
# whether the geometry field is 3D.
distance = func == 'distance' distance = func == 'distance'
length = func == 'length' length = func == 'length'
perimeter = func == 'perimeter' perimeter = func == 'perimeter'
if not (distance or length or perimeter): if not (distance or length or perimeter):
raise ValueError('Unknown distance function: %s' % func) raise ValueError('Unknown distance function: %s' % func)
geom_3d = geo_field.dim == 3
# The field's get_db_prep_lookup() is used to get any # The field's get_db_prep_lookup() is used to get any
# extra distance parameters. Here we set up the # extra distance parameters. Here we set up the
@ -604,7 +614,7 @@ class GeoQuerySet(QuerySet):
# some error checking is required. # some error checking is required.
if not isinstance(geo_field, PointField): if not isinstance(geo_field, PointField):
raise ValueError('Spherical distance calculation only supported on PointFields.') raise ValueError('Spherical distance calculation only supported on PointFields.')
if not str(SpatialBackend.Geometry(buffer(params[0].wkb)).geom_type) == 'Point': if not str(SpatialBackend.Geometry(buffer(params[0].ewkb)).geom_type) == 'Point':
raise ValueError('Spherical distance calculation only supported with Point Geometry parameters') raise ValueError('Spherical distance calculation only supported with Point Geometry parameters')
# The `function` procedure argument needs to be set differently for # The `function` procedure argument needs to be set differently for
# geodetic distance calculations. # geodetic distance calculations.
@ -617,9 +627,16 @@ class GeoQuerySet(QuerySet):
elif length or perimeter: elif length or perimeter:
procedure_fmt = '%(geo_col)s' procedure_fmt = '%(geo_col)s'
if geodetic and length: if geodetic and length:
# There's no `length_sphere` # There's no `length_sphere`, and `length_spheroid` also
# works on 3D geometries.
procedure_fmt += ',%(spheroid)s' procedure_fmt += ',%(spheroid)s'
procedure_args.update({'function' : SpatialBackend.length_spheroid, 'spheroid' : where[1]}) procedure_args.update({'function' : SpatialBackend.length_spheroid, 'spheroid' : where[1]})
elif geom_3d and SpatialBackend.postgis:
# Use 3D variants of perimeter and length routines on PostGIS.
if perimeter:
procedure_args.update({'function' : SpatialBackend.perimeter3d})
elif length:
procedure_args.update({'function' : SpatialBackend.length3d})
# Setting up the settings for `_spatial_attribute`. # Setting up the settings for `_spatial_attribute`.
s = {'select_field' : DistanceField(dist_att), s = {'select_field' : DistanceField(dist_att),

View File

@ -11,6 +11,9 @@ geo_template = '%(function)s(%(field)s)'
def convert_extent(box): def convert_extent(box):
raise NotImplementedError('Aggregate extent not implemented for this spatial backend.') raise NotImplementedError('Aggregate extent not implemented for this spatial backend.')
def convert_extent3d(box):
raise NotImplementedError('Aggregate 3D extent not implemented for this spatial backend.')
def convert_geom(wkt, geo_field): def convert_geom(wkt, geo_field):
raise NotImplementedError('Aggregate method not implemented for this spatial backend.') raise NotImplementedError('Aggregate method not implemented for this spatial backend.')
@ -23,6 +26,14 @@ if SpatialBackend.postgis:
xmax, ymax = map(float, ur.split()) xmax, ymax = map(float, ur.split())
return (xmin, ymin, xmax, ymax) return (xmin, ymin, xmax, ymax)
def convert_extent3d(box3d):
# Box text will be something like "BOX3D(-90.0 30.0 1, -85.0 40.0 2)";
# parsing out and returning as a 4-tuple.
ll, ur = box3d[6:-1].split(',')
xmin, ymin, zmin = map(float, ll.split())
xmax, ymax, zmax = map(float, ur.split())
return (xmin, ymin, zmin, xmax, ymax, zmax)
def convert_geom(hex, geo_field): def convert_geom(hex, geo_field):
if hex: return SpatialBackend.Geometry(hex) if hex: return SpatialBackend.Geometry(hex)
else: return None else: return None
@ -94,7 +105,7 @@ class Collect(GeoAggregate):
sql_function = SpatialBackend.collect sql_function = SpatialBackend.collect
class Extent(GeoAggregate): class Extent(GeoAggregate):
is_extent = True is_extent = '2D'
sql_function = SpatialBackend.extent sql_function = SpatialBackend.extent
if SpatialBackend.oracle: if SpatialBackend.oracle:
@ -102,6 +113,10 @@ if SpatialBackend.oracle:
Extent.conversion_class = GeomField Extent.conversion_class = GeomField
Extent.sql_template = '%(function)s(%(field)s)' Extent.sql_template = '%(function)s(%(field)s)'
class Extent3D(GeoAggregate):
is_extent = '3D'
sql_function = SpatialBackend.extent3d
class MakeLine(GeoAggregate): class MakeLine(GeoAggregate):
conversion_class = GeomField conversion_class = GeomField
sql_function = SpatialBackend.make_line sql_function = SpatialBackend.make_line

View File

@ -262,7 +262,10 @@ class GeoQuery(sql.Query):
""" """
if isinstance(aggregate, self.aggregates_module.GeoAggregate): if isinstance(aggregate, self.aggregates_module.GeoAggregate):
if aggregate.is_extent: if aggregate.is_extent:
return self.aggregates_module.convert_extent(value) if aggregate.is_extent == '3D':
return self.aggregates_module.convert_extent3d(value)
else:
return self.aggregates_module.convert_extent(value)
else: else:
return self.aggregates_module.convert_geom(value, aggregate.source) return self.aggregates_module.convert_geom(value, aggregate.source)
else: else:

View File

@ -179,10 +179,17 @@ class OGRGeometry(GDALBase):
"Returns 0 for points, 1 for lines, and 2 for surfaces." "Returns 0 for points, 1 for lines, and 2 for surfaces."
return capi.get_dims(self.ptr) return capi.get_dims(self.ptr)
@property def _get_coord_dim(self):
def coord_dim(self):
"Returns the coordinate dimension of the Geometry." "Returns the coordinate dimension of the Geometry."
return capi.get_coord_dims(self.ptr) return capi.get_coord_dim(self.ptr)
def _set_coord_dim(self, dim):
"Sets the coordinate dimension of this Geometry."
if not dim in (2, 3):
raise ValueError('Geometry dimension must be either 2 or 3')
capi.set_coord_dim(self.ptr, dim)
coord_dim = property(_get_coord_dim, _set_coord_dim)
@property @property
def geom_count(self): def geom_count(self):
@ -207,13 +214,7 @@ class OGRGeometry(GDALBase):
@property @property
def geom_type(self): def geom_type(self):
"Returns the Type for this Geometry." "Returns the Type for this Geometry."
try: return OGRGeomType(capi.get_geom_type(self.ptr))
return OGRGeomType(capi.get_geom_type(self.ptr))
except OGRException:
# VRT datasources return an invalid geometry type
# number, but a valid name -- we'll try that instead.
# See: http://trac.osgeo.org/gdal/ticket/2491
return OGRGeomType(capi.get_geom_name(self.ptr))
@property @property
def geom_name(self): def geom_name(self):
@ -249,11 +250,15 @@ class OGRGeometry(GDALBase):
def _set_srs(self, srs): def _set_srs(self, srs):
"Sets the SpatialReference for this geometry." "Sets the SpatialReference for this geometry."
# Do not have to clone the `SpatialReference` object pointer because
# when it is assigned to this `OGRGeometry` it's internal OGR
# reference count is incremented, and will likewise be released
# (decremented) when this geometry's destructor is called.
if isinstance(srs, SpatialReference): if isinstance(srs, SpatialReference):
srs_ptr = srs_api.clone_srs(srs.ptr) srs_ptr = srs.ptr
elif isinstance(srs, (int, long, basestring)): elif isinstance(srs, (int, long, basestring)):
sr = SpatialReference(srs) sr = SpatialReference(srs)
srs_ptr = srs_api.clone_srs(sr.ptr) srs_ptr = sr.ptr
else: else:
raise TypeError('Cannot assign spatial reference with object of type: %s' % type(srs)) raise TypeError('Cannot assign spatial reference with object of type: %s' % type(srs))
capi.assign_srs(self.ptr, srs_ptr) capi.assign_srs(self.ptr, srs_ptr)
@ -363,6 +368,16 @@ class OGRGeometry(GDALBase):
klone = self.clone() klone = self.clone()
klone.transform(coord_trans) klone.transform(coord_trans)
return klone return klone
# Have to get the coordinate dimension of the original geometry
# so it can be used to reset the transformed geometry's dimension
# afterwards. This is done because of GDAL bug (in versions prior
# to 1.7) that turns geometries 3D after transformation, see:
# http://trac.osgeo.org/gdal/changeset/17792
orig_dim = self.coord_dim
# Depending on the input type, use the appropriate OGR routine
# to perform the transformation.
if isinstance(coord_trans, CoordTransform): if isinstance(coord_trans, CoordTransform):
capi.geom_transform(self.ptr, coord_trans.ptr) capi.geom_transform(self.ptr, coord_trans.ptr)
elif isinstance(coord_trans, SpatialReference): elif isinstance(coord_trans, SpatialReference):
@ -373,6 +388,10 @@ class OGRGeometry(GDALBase):
else: else:
raise TypeError('Transform only accepts CoordTransform, SpatialReference, string, and integer objects.') raise TypeError('Transform only accepts CoordTransform, SpatialReference, string, and integer objects.')
# Setting with original dimension, see comment above.
if self.coord_dim != orig_dim:
self.coord_dim = orig_dim
def transform_to(self, srs): def transform_to(self, srs):
"For backwards-compatibility." "For backwards-compatibility."
self.transform(srs) self.transform(srs)
@ -659,4 +678,11 @@ GEO_CLASSES = {1 : Point,
6 : MultiPolygon, 6 : MultiPolygon,
7 : GeometryCollection, 7 : GeometryCollection,
101: LinearRing, 101: LinearRing,
1 + OGRGeomType.wkb25bit : Point,
2 + OGRGeomType.wkb25bit : LineString,
3 + OGRGeomType.wkb25bit : Polygon,
4 + OGRGeomType.wkb25bit : MultiPoint,
5 + OGRGeomType.wkb25bit : MultiLineString,
6 + OGRGeomType.wkb25bit : MultiPolygon,
7 + OGRGeomType.wkb25bit : GeometryCollection,
} }

View File

@ -4,6 +4,8 @@ from django.contrib.gis.gdal.error import OGRException
class OGRGeomType(object): class OGRGeomType(object):
"Encapulates OGR Geometry Types." "Encapulates OGR Geometry Types."
wkb25bit = -2147483648
# Dictionary of acceptable OGRwkbGeometryType s and their string names. # Dictionary of acceptable OGRwkbGeometryType s and their string names.
_types = {0 : 'Unknown', _types = {0 : 'Unknown',
1 : 'Point', 1 : 'Point',
@ -15,6 +17,13 @@ class OGRGeomType(object):
7 : 'GeometryCollection', 7 : 'GeometryCollection',
100 : 'None', 100 : 'None',
101 : 'LinearRing', 101 : 'LinearRing',
1 + wkb25bit: 'Point25D',
2 + wkb25bit: 'LineString25D',
3 + wkb25bit: 'Polygon25D',
4 + wkb25bit: 'MultiPoint25D',
5 + wkb25bit : 'MultiLineString25D',
6 + wkb25bit : 'MultiPolygon25D',
7 + wkb25bit : 'GeometryCollection25D',
} }
# Reverse type dictionary, keyed by lower-case of the name. # Reverse type dictionary, keyed by lower-case of the name.
_str_types = dict([(v.lower(), k) for k, v in _types.items()]) _str_types = dict([(v.lower(), k) for k, v in _types.items()])
@ -68,7 +77,7 @@ class OGRGeomType(object):
@property @property
def django(self): def django(self):
"Returns the Django GeometryField for this OGR Type." "Returns the Django GeometryField for this OGR Type."
s = self.name s = self.name.replace('25D', '')
if s in ('LinearRing', 'None'): if s in ('LinearRing', 'None'):
return None return None
elif s == 'Unknown': elif s == 'Unknown':

View File

@ -1,5 +1,5 @@
# Needed ctypes routines # Needed ctypes routines
from ctypes import byref from ctypes import c_double, byref
# Other GDAL imports. # Other GDAL imports.
from django.contrib.gis.gdal.base import GDALBase from django.contrib.gis.gdal.base import GDALBase
@ -7,11 +7,12 @@ from django.contrib.gis.gdal.envelope import Envelope, OGREnvelope
from django.contrib.gis.gdal.error import OGRException, OGRIndexError, SRSException from django.contrib.gis.gdal.error import OGRException, OGRIndexError, SRSException
from django.contrib.gis.gdal.feature import Feature from django.contrib.gis.gdal.feature import Feature
from django.contrib.gis.gdal.field import OGRFieldTypes from django.contrib.gis.gdal.field import OGRFieldTypes
from django.contrib.gis.gdal.geometries import OGRGeomType from django.contrib.gis.gdal.geomtype import OGRGeomType
from django.contrib.gis.gdal.geometries import OGRGeometry
from django.contrib.gis.gdal.srs import SpatialReference from django.contrib.gis.gdal.srs import SpatialReference
# GDAL ctypes function prototypes. # GDAL ctypes function prototypes.
from django.contrib.gis.gdal.prototypes import ds as capi, srs as srs_api from django.contrib.gis.gdal.prototypes import ds as capi, geom as geom_api, srs as srs_api
# For more information, see the OGR C API source code: # For more information, see the OGR C API source code:
# http://www.gdal.org/ogr/ogr__api_8h.html # http://www.gdal.org/ogr/ogr__api_8h.html
@ -156,6 +157,29 @@ class Layer(GDALBase):
return [capi.get_field_precision(capi.get_field_defn(self._ldefn, i)) return [capi.get_field_precision(capi.get_field_defn(self._ldefn, i))
for i in xrange(self.num_fields)] for i in xrange(self.num_fields)]
def _get_spatial_filter(self):
try:
return OGRGeometry(geom_api.clone_geom(capi.get_spatial_filter(self.ptr)))
except OGRException:
return None
def _set_spatial_filter(self, filter):
if isinstance(filter, OGRGeometry):
capi.set_spatial_filter(self.ptr, filter.ptr)
elif isinstance(filter, (tuple, list)):
if not len(filter) == 4:
raise ValueError('Spatial filter list/tuple must have 4 elements.')
# Map c_double onto params -- if a bad type is passed in it
# will be caught here.
xmin, ymin, xmax, ymax = map(c_double, filter)
capi.set_spatial_filter_rect(self.ptr, xmin, ymin, xmax, ymax)
elif filter is None:
capi.set_spatial_filter(self.ptr, None)
else:
raise TypeError('Spatial filter must be either an OGRGeometry instance, a 4-tuple, or None.')
spatial_filter = property(_get_spatial_filter, _set_spatial_filter)
#### Layer Methods #### #### Layer Methods ####
def get_fields(self, field_name): def get_fields(self, field_name):
""" """

View File

@ -3,7 +3,7 @@
related data structures. OGR_Dr_*, OGR_DS_*, OGR_L_*, OGR_F_*, related data structures. OGR_Dr_*, OGR_DS_*, OGR_L_*, OGR_F_*,
OGR_Fld_* routines are relevant here. OGR_Fld_* routines are relevant here.
""" """
from ctypes import c_char_p, c_int, c_long, c_void_p, POINTER from ctypes import c_char_p, c_double, c_int, c_long, c_void_p, POINTER
from django.contrib.gis.gdal.envelope import OGREnvelope from django.contrib.gis.gdal.envelope import OGREnvelope
from django.contrib.gis.gdal.libgdal import lgdal from django.contrib.gis.gdal.libgdal import lgdal
from django.contrib.gis.gdal.prototypes.generation import \ from django.contrib.gis.gdal.prototypes.generation import \
@ -38,6 +38,9 @@ get_layer_srs = srs_output(lgdal.OGR_L_GetSpatialRef, [c_void_p])
get_next_feature = voidptr_output(lgdal.OGR_L_GetNextFeature, [c_void_p]) get_next_feature = voidptr_output(lgdal.OGR_L_GetNextFeature, [c_void_p])
reset_reading = void_output(lgdal.OGR_L_ResetReading, [c_void_p], errcheck=False) reset_reading = void_output(lgdal.OGR_L_ResetReading, [c_void_p], errcheck=False)
test_capability = int_output(lgdal.OGR_L_TestCapability, [c_void_p, c_char_p]) test_capability = int_output(lgdal.OGR_L_TestCapability, [c_void_p, c_char_p])
get_spatial_filter = geom_output(lgdal.OGR_L_GetSpatialFilter, [c_void_p])
set_spatial_filter = void_output(lgdal.OGR_L_SetSpatialFilter, [c_void_p, c_void_p], errcheck=False)
set_spatial_filter_rect = void_output(lgdal.OGR_L_SetSpatialFilterRect, [c_void_p, c_double, c_double, c_double, c_double], errcheck=False)
### Feature Definition Routines ### ### Feature Definition Routines ###
get_fd_geom_type = int_output(lgdal.OGR_FD_GetGeomType, [c_void_p]) get_fd_geom_type = int_output(lgdal.OGR_FD_GetGeomType, [c_void_p])

View File

@ -83,7 +83,8 @@ get_geom_srs = srs_output(lgdal.OGR_G_GetSpatialReference, [c_void_p])
get_area = double_output(lgdal.OGR_G_GetArea, [c_void_p]) get_area = double_output(lgdal.OGR_G_GetArea, [c_void_p])
get_centroid = void_output(lgdal.OGR_G_Centroid, [c_void_p, c_void_p]) get_centroid = void_output(lgdal.OGR_G_Centroid, [c_void_p, c_void_p])
get_dims = int_output(lgdal.OGR_G_GetDimension, [c_void_p]) get_dims = int_output(lgdal.OGR_G_GetDimension, [c_void_p])
get_coord_dims = int_output(lgdal.OGR_G_GetCoordinateDimension, [c_void_p]) get_coord_dim = int_output(lgdal.OGR_G_GetCoordinateDimension, [c_void_p])
set_coord_dim = void_output(lgdal.OGR_G_SetCoordinateDimension, [c_void_p, c_int], errcheck=False)
get_geom_count = int_output(lgdal.OGR_G_GetGeometryCount, [c_void_p]) get_geom_count = int_output(lgdal.OGR_G_GetGeometryCount, [c_void_p])
get_geom_name = const_string_output(lgdal.OGR_G_GetGeometryName, [c_void_p]) get_geom_name = const_string_output(lgdal.OGR_G_GetGeometryName, [c_void_p])

View File

@ -1,13 +1,11 @@
import os, os.path, unittest import os, os.path, unittest
from django.contrib.gis.gdal import DataSource, Envelope, OGRException, OGRIndexError from django.contrib.gis.gdal import DataSource, Envelope, OGRGeometry, OGRException, OGRIndexError
from django.contrib.gis.gdal.field import OFTReal, OFTInteger, OFTString from django.contrib.gis.gdal.field import OFTReal, OFTInteger, OFTString
from django.contrib import gis from django.contrib import gis
# Path for SHP files # Path for SHP files
data_path = os.path.join(os.path.dirname(gis.__file__), 'tests' + os.sep + 'data') data_path = os.path.join(os.path.dirname(gis.__file__), 'tests' + os.sep + 'data')
def get_ds_file(name, ext): def get_ds_file(name, ext):
return os.sep.join([data_path, name, name + '.%s' % ext]) return os.sep.join([data_path, name, name + '.%s' % ext])
# Test SHP data source object # Test SHP data source object
@ -25,7 +23,7 @@ ds_list = (TestDS('test_point', nfeat=5, nfld=3, geom='POINT', gtype=1, driver='
srs_wkt='GEOGCS["GCS_WGS_1984",DATUM["WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]', srs_wkt='GEOGCS["GCS_WGS_1984",DATUM["WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]',
field_values={'dbl' : [float(i) for i in range(1, 6)], 'int' : range(1, 6), 'str' : [str(i) for i in range(1, 6)]}, field_values={'dbl' : [float(i) for i in range(1, 6)], 'int' : range(1, 6), 'str' : [str(i) for i in range(1, 6)]},
fids=range(5)), fids=range(5)),
TestDS('test_vrt', ext='vrt', nfeat=3, nfld=3, geom='POINT', gtype=1, driver='VRT', TestDS('test_vrt', ext='vrt', nfeat=3, nfld=3, geom='POINT', gtype='Point25D', driver='VRT',
fields={'POINT_X' : OFTString, 'POINT_Y' : OFTString, 'NUM' : OFTString}, # VRT uses CSV, which all types are OFTString. fields={'POINT_X' : OFTString, 'POINT_Y' : OFTString, 'NUM' : OFTString}, # VRT uses CSV, which all types are OFTString.
extent=(1.0, 2.0, 100.0, 523.5), # Min/Max from CSV extent=(1.0, 2.0, 100.0, 523.5), # Min/Max from CSV
field_values={'POINT_X' : ['1.0', '5.0', '100.0'], 'POINT_Y' : ['2.0', '23.0', '523.5'], 'NUM' : ['5', '17', '23']}, field_values={'POINT_X' : ['1.0', '5.0', '100.0'], 'POINT_Y' : ['2.0', '23.0', '523.5'], 'NUM' : ['5', '17', '23']},
@ -191,6 +189,40 @@ class DataSourceTest(unittest.TestCase):
if hasattr(source, 'srs_wkt'): if hasattr(source, 'srs_wkt'):
self.assertEqual(source.srs_wkt, g.srs.wkt) self.assertEqual(source.srs_wkt, g.srs.wkt)
def test06_spatial_filter(self):
"Testing the Layer.spatial_filter property."
ds = DataSource(get_ds_file('cities', 'shp'))
lyr = ds[0]
# When not set, it should be None.
self.assertEqual(None, lyr.spatial_filter)
# Must be set a/an OGRGeometry or 4-tuple.
self.assertRaises(TypeError, lyr._set_spatial_filter, 'foo')
# Setting the spatial filter with a tuple/list with the extent of
# a buffer centering around Pueblo.
self.assertRaises(ValueError, lyr._set_spatial_filter, range(5))
filter_extent = (-105.609252, 37.255001, -103.609252, 39.255001)
lyr.spatial_filter = (-105.609252, 37.255001, -103.609252, 39.255001)
self.assertEqual(OGRGeometry.from_bbox(filter_extent), lyr.spatial_filter)
feats = [feat for feat in lyr]
self.assertEqual(1, len(feats))
self.assertEqual('Pueblo', feats[0].get('Name'))
# Setting the spatial filter with an OGRGeometry for buffer centering
# around Houston.
filter_geom = OGRGeometry('POLYGON((-96.363151 28.763374,-94.363151 28.763374,-94.363151 30.763374,-96.363151 30.763374,-96.363151 28.763374))')
lyr.spatial_filter = filter_geom
self.assertEqual(filter_geom, lyr.spatial_filter)
feats = [feat for feat in lyr]
self.assertEqual(1, len(feats))
self.assertEqual('Houston', feats[0].get('Name'))
# Clearing the spatial filter by setting it to None. Now
# should indicate that there are 3 features in the Layer.
lyr.spatial_filter = None
self.assertEqual(3, len(lyr))
def suite(): def suite():
s = unittest.TestSuite() s = unittest.TestSuite()

View File

@ -46,6 +46,13 @@ class OGRGeomTest(unittest.TestCase):
self.assertEqual(0, gt.num) self.assertEqual(0, gt.num)
self.assertEqual('Unknown', gt.name) self.assertEqual('Unknown', gt.name)
def test00b_geomtype_25d(self):
"Testing OGRGeomType object with 25D types."
wkb25bit = OGRGeomType.wkb25bit
self.failUnless(OGRGeomType(wkb25bit + 1) == 'Point25D')
self.failUnless(OGRGeomType('MultiLineString25D') == (5 + wkb25bit))
self.assertEqual('GeometryCollectionField', OGRGeomType('GeometryCollection25D').django)
def test01a_wkt(self): def test01a_wkt(self):
"Testing WKT output." "Testing WKT output."
for g in wkt_out: for g in wkt_out:
@ -319,6 +326,18 @@ class OGRGeomTest(unittest.TestCase):
self.assertAlmostEqual(trans.x, p.x, prec) self.assertAlmostEqual(trans.x, p.x, prec)
self.assertAlmostEqual(trans.y, p.y, prec) self.assertAlmostEqual(trans.y, p.y, prec)
def test09c_transform_dim(self):
"Testing coordinate dimension is the same on transformed geometries."
ls_orig = OGRGeometry('LINESTRING(-104.609 38.255)', 4326)
ls_trans = OGRGeometry('LINESTRING(992385.4472045 481455.4944650)', 2774)
prec = 3
ls_orig.transform(ls_trans.srs)
# Making sure the coordinate dimension is still 2D.
self.assertEqual(2, ls_orig.coord_dim)
self.assertAlmostEqual(ls_trans.x[0], ls_orig.x[0], prec)
self.assertAlmostEqual(ls_trans.y[0], ls_orig.y[0], prec)
def test10_difference(self): def test10_difference(self):
"Testing difference()." "Testing difference()."
for i in xrange(len(topology_geoms)): for i in xrange(len(topology_geoms)):
@ -406,6 +425,17 @@ class OGRGeomTest(unittest.TestCase):
xmax, ymax = max(x), max(y) xmax, ymax = max(x), max(y)
self.assertEqual((xmin, ymin, xmax, ymax), poly.extent) self.assertEqual((xmin, ymin, xmax, ymax), poly.extent)
def test16_25D(self):
"Testing 2.5D geometries."
pnt_25d = OGRGeometry('POINT(1 2 3)')
self.assertEqual('Point25D', pnt_25d.geom_type.name)
self.assertEqual(3.0, pnt_25d.z)
self.assertEqual(3, pnt_25d.coord_dim)
ls_25d = OGRGeometry('LINESTRING(1 1 1,2 2 2,3 3 3)')
self.assertEqual('LineString25D', ls_25d.geom_type.name)
self.assertEqual([1.0, 2.0, 3.0], ls_25d.z)
self.assertEqual(3, ls_25d.coord_dim)
def suite(): def suite():
s = unittest.TestSuite() s = unittest.TestSuite()
s.addTest(unittest.makeSuite(OGRGeomTest)) s.addTest(unittest.makeSuite(OGRGeomTest))

View File

@ -357,26 +357,46 @@ class GEOSGeometry(GEOSBase, ListMixin):
#### Output Routines #### #### Output Routines ####
@property @property
def ewkt(self): def ewkt(self):
"Returns the EWKT (WKT + SRID) of the Geometry." """
Returns the EWKT (WKT + SRID) of the Geometry. Note that Z values
are *not* included in this representation because GEOS does not yet
support serializing them.
"""
if self.get_srid(): return 'SRID=%s;%s' % (self.srid, self.wkt) if self.get_srid(): return 'SRID=%s;%s' % (self.srid, self.wkt)
else: return self.wkt else: return self.wkt
@property @property
def wkt(self): def wkt(self):
"Returns the WKT (Well-Known Text) of the Geometry." "Returns the WKT (Well-Known Text) representation of this Geometry."
return wkt_w.write(self) return wkt_w.write(self)
@property @property
def hex(self): def hex(self):
""" """
Returns the HEX of the Geometry -- please note that the SRID is not Returns the WKB of this Geometry in hexadecimal form. Please note
included in this representation, because the GEOS C library uses that the SRID and Z values are not included in this representation
-1 by default, even if the SRID is set. because it is not a part of the OGC specification (use the `hexewkb`
property instead).
""" """
# A possible faster, all-python, implementation: # A possible faster, all-python, implementation:
# str(self.wkb).encode('hex') # str(self.wkb).encode('hex')
return wkb_w.write_hex(self) return wkb_w.write_hex(self)
@property
def hexewkb(self):
"""
Returns the EWKB of this Geometry in hexadecimal form. This is an
extension of the WKB specification that includes SRID and Z values
that are a part of this geometry.
"""
if self.hasz:
if not GEOS_PREPARE:
# See: http://trac.osgeo.org/geos/ticket/216
raise GEOSException('Upgrade GEOS to 3.1 to get valid 3D HEXEWKB.')
return ewkb_w3d.write_hex(self)
else:
return ewkb_w.write_hex(self)
@property @property
def json(self): def json(self):
""" """
@ -391,9 +411,28 @@ class GEOSGeometry(GEOSBase, ListMixin):
@property @property
def wkb(self): def wkb(self):
"Returns the WKB of the Geometry as a buffer." """
Returns the WKB (Well-Known Binary) representation of this Geometry
as a Python buffer. SRID and Z values are not included, use the
`ewkb` property instead.
"""
return wkb_w.write(self) return wkb_w.write(self)
@property
def ewkb(self):
"""
Return the EWKB representation of this Geometry as a Python buffer.
This is an extension of the WKB specification that includes any SRID
and Z values that are a part of this geometry.
"""
if self.hasz:
if not GEOS_PREPARE:
# See: http://trac.osgeo.org/geos/ticket/216
raise GEOSException('Upgrade GEOS to 3.1 to get valid 3D EWKB.')
return ewkb_w3d.write(self)
else:
return ewkb_w.write(self)
@property @property
def kml(self): def kml(self):
"Returns the KML representation of this Geometry." "Returns the KML representation of this Geometry."
@ -617,7 +656,7 @@ GEOS_CLASSES = {0 : Point,
} }
# Similarly, import the GEOS I/O instances here to avoid conflicts. # Similarly, import the GEOS I/O instances here to avoid conflicts.
from django.contrib.gis.geos.io import wkt_r, wkt_w, wkb_r, wkb_w from django.contrib.gis.geos.io import wkt_r, wkt_w, wkb_r, wkb_w, ewkb_w, ewkb_w3d
# If supported, import the PreparedGeometry class. # If supported, import the PreparedGeometry class.
if GEOS_PREPARE: if GEOS_PREPARE:

View File

@ -14,19 +14,19 @@ class IOBase(GEOSBase):
"Base class for GEOS I/O objects." "Base class for GEOS I/O objects."
def __init__(self): def __init__(self):
# Getting the pointer with the constructor. # Getting the pointer with the constructor.
self.ptr = self.constructor() self.ptr = self._constructor()
def __del__(self): def __del__(self):
# Cleaning up with the appropriate destructor. # Cleaning up with the appropriate destructor.
if self._ptr: self.destructor(self._ptr) if self._ptr: self._destructor(self._ptr)
### WKT Reading and Writing objects ### ### WKT Reading and Writing objects ###
# Non-public class for internal use because its `read` method returns # Non-public class for internal use because its `read` method returns
# _pointers_ instead of a GEOSGeometry object. # _pointers_ instead of a GEOSGeometry object.
class _WKTReader(IOBase): class _WKTReader(IOBase):
constructor = capi.wkt_reader_create _constructor = capi.wkt_reader_create
destructor = capi.wkt_reader_destroy _destructor = capi.wkt_reader_destroy
ptr_type = capi.WKT_READ_PTR ptr_type = capi.WKT_READ_PTR
def read(self, wkt): def read(self, wkt):
@ -39,8 +39,8 @@ class WKTReader(_WKTReader):
return GEOSGeometry(super(WKTReader, self).read(wkt)) return GEOSGeometry(super(WKTReader, self).read(wkt))
class WKTWriter(IOBase): class WKTWriter(IOBase):
constructor = capi.wkt_writer_create _constructor = capi.wkt_writer_create
destructor = capi.wkt_writer_destroy _destructor = capi.wkt_writer_destroy
ptr_type = capi.WKT_WRITE_PTR ptr_type = capi.WKT_WRITE_PTR
def write(self, geom): def write(self, geom):
@ -51,8 +51,8 @@ class WKTWriter(IOBase):
# Non-public class for the same reason as _WKTReader above. # Non-public class for the same reason as _WKTReader above.
class _WKBReader(IOBase): class _WKBReader(IOBase):
constructor = capi.wkb_reader_create _constructor = capi.wkb_reader_create
destructor = capi.wkb_reader_destroy _destructor = capi.wkb_reader_destroy
ptr_type = capi.WKB_READ_PTR ptr_type = capi.WKB_READ_PTR
def read(self, wkb): def read(self, wkb):
@ -71,8 +71,8 @@ class WKBReader(_WKBReader):
return GEOSGeometry(super(WKBReader, self).read(wkb)) return GEOSGeometry(super(WKBReader, self).read(wkb))
class WKBWriter(IOBase): class WKBWriter(IOBase):
constructor = capi.wkb_writer_create _constructor = capi.wkb_writer_create
destructor = capi.wkb_writer_destroy _destructor = capi.wkb_writer_destroy
ptr_type = capi.WKB_WRITE_PTR ptr_type = capi.WKB_WRITE_PTR
def write(self, geom): def write(self, geom):
@ -121,3 +121,10 @@ wkt_r = _WKTReader()
wkt_w = WKTWriter() wkt_w = WKTWriter()
wkb_r = _WKBReader() wkb_r = _WKBReader()
wkb_w = WKBWriter() wkb_w = WKBWriter()
# These instances are for writing EWKB in 2D and 3D.
ewkb_w = WKBWriter()
ewkb_w.srid = True
ewkb_w3d = WKBWriter()
ewkb_w3d.srid = True
ewkb_w3d.outdim = 3

View File

@ -62,17 +62,16 @@ def string_from_geom(func):
### ctypes prototypes ### ### ctypes prototypes ###
# Deprecated creation routines from WKB, HEX, WKT # Deprecated creation and output routines from WKB, HEX, WKT
from_hex = bin_constructor(lgeos.GEOSGeomFromHEX_buf) from_hex = bin_constructor(lgeos.GEOSGeomFromHEX_buf)
from_wkb = bin_constructor(lgeos.GEOSGeomFromWKB_buf) from_wkb = bin_constructor(lgeos.GEOSGeomFromWKB_buf)
from_wkt = geom_output(lgeos.GEOSGeomFromWKT, [c_char_p]) from_wkt = geom_output(lgeos.GEOSGeomFromWKT, [c_char_p])
# Output routines
to_hex = bin_output(lgeos.GEOSGeomToHEX_buf) to_hex = bin_output(lgeos.GEOSGeomToHEX_buf)
to_wkb = bin_output(lgeos.GEOSGeomToWKB_buf) to_wkb = bin_output(lgeos.GEOSGeomToWKB_buf)
to_wkt = string_from_geom(lgeos.GEOSGeomToWKT) to_wkt = string_from_geom(lgeos.GEOSGeomToWKT)
# The GEOS geometry type, typeid, num_coordites and number of geometries # The GEOS geometry type, typeid, num_coordinates and number of geometries
geos_normalize = int_from_geom(lgeos.GEOSNormalize) geos_normalize = int_from_geom(lgeos.GEOSNormalize)
geos_type = string_from_geom(lgeos.GEOSGeomType) geos_type = string_from_geom(lgeos.GEOSGeomType)
geos_typeid = int_from_geom(lgeos.GEOSGeomTypeId) geos_typeid = int_from_geom(lgeos.GEOSGeomTypeId)

View File

@ -71,6 +71,49 @@ class GEOSTest(unittest.TestCase):
geom = fromstr(g.wkt) geom = fromstr(g.wkt)
self.assertEqual(g.hex, geom.hex) self.assertEqual(g.hex, geom.hex)
def test01b_hexewkb(self):
"Testing (HEX)EWKB output."
from binascii import a2b_hex
pnt_2d = Point(0, 1, srid=4326)
pnt_3d = Point(0, 1, 2, srid=4326)
# OGC-compliant HEX will not have SRID nor Z value.
self.assertEqual(ogc_hex, pnt_2d.hex)
self.assertEqual(ogc_hex, pnt_3d.hex)
# HEXEWKB should be appropriate for its dimension -- have to use an
# a WKBWriter w/dimension set accordingly, else GEOS will insert
# garbage into 3D coordinate if there is none. Also, GEOS has a
# a bug in versions prior to 3.1 that puts the X coordinate in
# place of Z; an exception should be raised on those versions.
self.assertEqual(hexewkb_2d, pnt_2d.hexewkb)
if GEOS_PREPARE:
self.assertEqual(hexewkb_3d, pnt_3d.hexewkb)
self.assertEqual(True, GEOSGeometry(hexewkb_3d).hasz)
else:
try:
hexewkb = pnt_3d.hexewkb
except GEOSException:
pass
else:
self.fail('Should have raised GEOSException.')
# Same for EWKB.
self.assertEqual(buffer(a2b_hex(hexewkb_2d)), pnt_2d.ewkb)
if GEOS_PREPARE:
self.assertEqual(buffer(a2b_hex(hexewkb_3d)), pnt_3d.ewkb)
else:
try:
ewkb = pnt_3d.ewkb
except GEOSException:
pass
else:
self.fail('Should have raised GEOSException')
# Redundant sanity check.
self.assertEqual(4326, GEOSGeometry(hexewkb_2d).srid)
def test01c_kml(self): def test01c_kml(self):
"Testing KML output." "Testing KML output."
for tg in wkt_out: for tg in wkt_out:

View File

@ -9,9 +9,10 @@ def geo_suite():
some backends). some backends).
""" """
from django.conf import settings from django.conf import settings
from django.contrib.gis.geos import GEOS_PREPARE
from django.contrib.gis.gdal import HAS_GDAL from django.contrib.gis.gdal import HAS_GDAL
from django.contrib.gis.utils import HAS_GEOIP from django.contrib.gis.utils import HAS_GEOIP
from django.contrib.gis.tests.utils import mysql from django.contrib.gis.tests.utils import postgis, mysql
# The test suite. # The test suite.
s = unittest.TestSuite() s = unittest.TestSuite()
@ -32,6 +33,10 @@ def geo_suite():
if not mysql: if not mysql:
test_apps.append('distapp') test_apps.append('distapp')
# Only PostGIS using GEOS 3.1+ can support 3D so far.
if postgis and GEOS_PREPARE:
test_apps.append('geo3d')
if HAS_GDAL: if HAS_GDAL:
# These tests require GDAL. # These tests require GDAL.
test_suite_names.extend(['test_spatialrefsys', 'test_geoforms']) test_suite_names.extend(['test_spatialrefsys', 'test_geoforms'])
@ -164,20 +169,3 @@ def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=[], suite=
# Returning the total failures and errors # Returning the total failures and errors
return len(result.failures) + len(result.errors) return len(result.failures) + len(result.errors)
# Class for creating a fake module with a run method. This is for the
# GEOS and GDAL tests that were moved to their respective modules.
class _DeprecatedTestModule(object):
def __init__(self, mod_name):
self.mod_name = mod_name
def run(self):
from warnings import warn
warn('This test module is deprecated because it has moved to ' \
'`django.contrib.gis.%s.tests` and will disappear in 1.2.' %
self.mod_name, DeprecationWarning)
tests = import_module('django.contrib.gis.%s.tests' % self.mod_name)
tests.run()
test_geos = _DeprecatedTestModule('geos')
test_gdal = _DeprecatedTestModule('gdal')

View File

@ -1,7 +1,7 @@
<OGRVRTDataSource> <OGRVRTDataSource>
<OGRVRTLayer name="test_vrt"> <OGRVRTLayer name="test_vrt">
<SrcDataSource relativeToVRT="1">test_vrt.csv</SrcDataSource> <SrcDataSource relativeToVRT="1">test_vrt.csv</SrcDataSource>
<GeometryType>wkbPoint</GeometryType> <GeometryType>wkbPoint25D</GeometryType>
<GeometryField encoding="PointFromColumns" x="POINT_X" y="POINT_Y" z="NUM"/> <GeometryField encoding="PointFromColumns" x="POINT_X" y="POINT_Y" z="NUM"/>
</OGRVRTLayer> </OGRVRTLayer>
</OGRVRTDataSource> </OGRVRTDataSource>

View File

@ -0,0 +1,69 @@
from django.contrib.gis.db import models
class City3D(models.Model):
name = models.CharField(max_length=30)
point = models.PointField(dim=3)
objects = models.GeoManager()
def __unicode__(self):
return self.name
class Interstate2D(models.Model):
name = models.CharField(max_length=30)
line = models.LineStringField(srid=4269)
objects = models.GeoManager()
def __unicode__(self):
return self.name
class Interstate3D(models.Model):
name = models.CharField(max_length=30)
line = models.LineStringField(dim=3, srid=4269)
objects = models.GeoManager()
def __unicode__(self):
return self.name
class InterstateProj2D(models.Model):
name = models.CharField(max_length=30)
line = models.LineStringField(srid=32140)
objects = models.GeoManager()
def __unicode__(self):
return self.name
class InterstateProj3D(models.Model):
name = models.CharField(max_length=30)
line = models.LineStringField(dim=3, srid=32140)
objects = models.GeoManager()
def __unicode__(self):
return self.name
class Polygon2D(models.Model):
name = models.CharField(max_length=30)
poly = models.PolygonField(srid=32140)
objects = models.GeoManager()
def __unicode__(self):
return self.name
class Polygon3D(models.Model):
name = models.CharField(max_length=30)
poly = models.PolygonField(dim=3, srid=32140)
objects = models.GeoManager()
def __unicode__(self):
return self.name
class Point2D(models.Model):
point = models.PointField()
objects = models.GeoManager()
class Point3D(models.Model):
point = models.PointField(dim=3)
objects = models.GeoManager()
class MultiPoint3D(models.Model):
mpoint = models.MultiPointField(dim=3)
objects = models.GeoManager()

View File

@ -0,0 +1,234 @@
import os, re, unittest
from django.contrib.gis.db.models import Union, Extent3D
from django.contrib.gis.geos import GEOSGeometry, Point, Polygon
from django.contrib.gis.utils import LayerMapping, LayerMapError
from models import City3D, Interstate2D, Interstate3D, \
InterstateProj2D, InterstateProj3D, \
Point2D, Point3D, MultiPoint3D, Polygon2D, Polygon3D
data_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '..', 'data'))
city_file = os.path.join(data_path, 'cities', 'cities.shp')
vrt_file = os.path.join(data_path, 'test_vrt', 'test_vrt.vrt')
# The coordinates of each city, with Z values corresponding to their
# altitude in meters.
city_data = (
('Houston', (-95.363151, 29.763374, 18)),
('Dallas', (-96.801611, 32.782057, 147)),
('Oklahoma City', (-97.521157, 34.464642, 380)),
('Wellington', (174.783117, -41.315268, 14)),
('Pueblo', (-104.609252, 38.255001, 1433)),
('Lawrence', (-95.235060, 38.971823, 251)),
('Chicago', (-87.650175, 41.850385, 181)),
('Victoria', (-123.305196, 48.462611, 15)),
)
# Reference mapping of city name to its altitude (Z value).
city_dict = dict((name, coords) for name, coords in city_data)
# 3D freeway data derived from the National Elevation Dataset:
# http://seamless.usgs.gov/products/9arc.php
interstate_data = (
('I-45',
'LINESTRING(-95.3708481 29.7765870 11.339,-95.3694580 29.7787980 4.536,-95.3690305 29.7797359 9.762,-95.3691886 29.7812450 12.448,-95.3696447 29.7850144 10.457,-95.3702511 29.7868518 9.418,-95.3706724 29.7881286 14.858,-95.3711632 29.7896157 15.386,-95.3714525 29.7936267 13.168,-95.3717848 29.7955007 15.104,-95.3717719 29.7969804 16.516,-95.3717305 29.7982117 13.923,-95.3717254 29.8000778 14.385,-95.3719875 29.8013539 15.160,-95.3720575 29.8026785 15.544,-95.3721321 29.8040912 14.975,-95.3722074 29.8050998 15.688,-95.3722779 29.8060430 16.099,-95.3733818 29.8076750 15.197,-95.3741563 29.8103686 17.268,-95.3749458 29.8129927 19.857,-95.3763564 29.8144557 15.435)',
( 11.339, 4.536, 9.762, 12.448, 10.457, 9.418, 14.858,
15.386, 13.168, 15.104, 16.516, 13.923, 14.385, 15.16 ,
15.544, 14.975, 15.688, 16.099, 15.197, 17.268, 19.857,
15.435),
),
)
# Bounding box polygon for inner-loop of Houston (in projected coordinate
# system 32140), with elevation values from the National Elevation Dataset
# (see above).
bbox_wkt = 'POLYGON((941527.97 4225693.20,962596.48 4226349.75,963152.57 4209023.95,942051.75 4208366.38,941527.97 4225693.20))'
bbox_z = (21.71, 13.21, 9.12, 16.40, 21.71)
def gen_bbox():
bbox_2d = GEOSGeometry(bbox_wkt, srid=32140)
bbox_3d = Polygon(tuple((x, y, z) for (x, y), z in zip(bbox_2d[0].coords, bbox_z)), srid=32140)
return bbox_2d, bbox_3d
class Geo3DTest(unittest.TestCase):
"""
Only a subset of the PostGIS routines are 3D-enabled, and this TestCase
tries to test the features that can handle 3D and that are also
available within GeoDjango. For more information, see the PostGIS docs
on the routines that support 3D:
http://postgis.refractions.net/documentation/manual-1.4/ch08.html#PostGIS_3D_Functions
"""
def test01_3d(self):
"Test the creation of 3D models."
# 3D models for the rest of the tests will be populated in here.
# For each 3D data set create model (and 2D version if necessary),
# retrieve, and assert geometry is in 3D and contains the expected
# 3D values.
for name, pnt_data in city_data:
x, y, z = pnt_data
pnt = Point(x, y, z, srid=4326)
City3D.objects.create(name=name, point=pnt)
city = City3D.objects.get(name=name)
self.failUnless(city.point.hasz)
self.assertEqual(z, city.point.z)
# Interstate (2D / 3D and Geographic/Projected variants)
for name, line, exp_z in interstate_data:
line_3d = GEOSGeometry(line, srid=4269)
# Using `hex` attribute because it omits 3D.
line_2d = GEOSGeometry(line_3d.hex, srid=4269)
# Creating a geographic and projected version of the
# interstate in both 2D and 3D.
Interstate3D.objects.create(name=name, line=line_3d)
InterstateProj3D.objects.create(name=name, line=line_3d)
Interstate2D.objects.create(name=name, line=line_2d)
InterstateProj2D.objects.create(name=name, line=line_2d)
# Retrieving and making sure it's 3D and has expected
# Z values -- shouldn't change because of coordinate system.
interstate = Interstate3D.objects.get(name=name)
interstate_proj = InterstateProj3D.objects.get(name=name)
for i in [interstate, interstate_proj]:
self.failUnless(i.line.hasz)
self.assertEqual(exp_z, tuple(i.line.z))
# Creating 3D Polygon.
bbox2d, bbox3d = gen_bbox()
Polygon2D.objects.create(name='2D BBox', poly=bbox2d)
Polygon3D.objects.create(name='3D BBox', poly=bbox3d)
p3d = Polygon3D.objects.get(name='3D BBox')
self.failUnless(p3d.poly.hasz)
self.assertEqual(bbox3d, p3d.poly)
def test01a_3d_layermapping(self):
"Testing LayerMapping on 3D models."
from models import Point2D, Point3D
point_mapping = {'point' : 'POINT'}
mpoint_mapping = {'mpoint' : 'MULTIPOINT'}
# The VRT is 3D, but should still be able to map sans the Z.
lm = LayerMapping(Point2D, vrt_file, point_mapping, transform=False)
lm.save()
self.assertEqual(3, Point2D.objects.count())
# The city shapefile is 2D, and won't be able to fill the coordinates
# in the 3D model -- thus, a LayerMapError is raised.
self.assertRaises(LayerMapError, LayerMapping,
Point3D, city_file, point_mapping, transform=False)
# 3D model should take 3D data just fine.
lm = LayerMapping(Point3D, vrt_file, point_mapping, transform=False)
lm.save()
self.assertEqual(3, Point3D.objects.count())
# Making sure LayerMapping.make_multi works right, by converting
# a Point25D into a MultiPoint25D.
lm = LayerMapping(MultiPoint3D, vrt_file, mpoint_mapping, transform=False)
lm.save()
self.assertEqual(3, MultiPoint3D.objects.count())
def test02a_kml(self):
"Test GeoQuerySet.kml() with Z values."
h = City3D.objects.kml(precision=6).get(name='Houston')
# KML should be 3D.
# `SELECT ST_AsKML(point, 6) FROM geo3d_city3d WHERE name = 'Houston';`
ref_kml_regex = re.compile(r'^<Point><coordinates>-95.363\d+,29.763\d+,18</coordinates></Point>$')
self.failUnless(ref_kml_regex.match(h.kml))
def test02b_geojson(self):
"Test GeoQuerySet.geojson() with Z values."
h = City3D.objects.geojson(precision=6).get(name='Houston')
# GeoJSON should be 3D
# `SELECT ST_AsGeoJSON(point, 6) FROM geo3d_city3d WHERE name='Houston';`
ref_json_regex = re.compile(r'^{"type":"Point","coordinates":\[-95.363151,29.763374,18(\.0+)?\]}$')
self.failUnless(ref_json_regex.match(h.geojson))
def test03a_union(self):
"Testing the Union aggregate of 3D models."
# PostGIS query that returned the reference EWKT for this test:
# `SELECT ST_AsText(ST_Union(point)) FROM geo3d_city3d;`
ref_ewkt = 'SRID=4326;MULTIPOINT(-123.305196 48.462611 15,-104.609252 38.255001 1433,-97.521157 34.464642 380,-96.801611 32.782057 147,-95.363151 29.763374 18,-95.23506 38.971823 251,-87.650175 41.850385 181,174.783117 -41.315268 14)'
ref_union = GEOSGeometry(ref_ewkt)
union = City3D.objects.aggregate(Union('point'))['point__union']
self.failUnless(union.hasz)
self.assertEqual(ref_union, union)
def test03b_extent(self):
"Testing the Extent3D aggregate for 3D models."
# `SELECT ST_Extent3D(point) FROM geo3d_city3d;`
ref_extent3d = (-123.305196, -41.315268, 14,174.783117, 48.462611, 1433)
extent1 = City3D.objects.aggregate(Extent3D('point'))['point__extent3d']
extent2 = City3D.objects.extent3d()
def check_extent3d(extent3d, tol=6):
for ref_val, ext_val in zip(ref_extent3d, extent3d):
self.assertAlmostEqual(ref_val, ext_val, tol)
for e3d in [extent1, extent2]:
check_extent3d(e3d)
def test04_perimeter(self):
"Testing GeoQuerySet.perimeter() on 3D fields."
# Reference query for values below:
# `SELECT ST_Perimeter3D(poly), ST_Perimeter2D(poly) FROM geo3d_polygon3d;`
ref_perim_3d = 76859.2620451
ref_perim_2d = 76859.2577803
tol = 6
self.assertAlmostEqual(ref_perim_2d,
Polygon2D.objects.perimeter().get(name='2D BBox').perimeter.m,
tol)
self.assertAlmostEqual(ref_perim_3d,
Polygon3D.objects.perimeter().get(name='3D BBox').perimeter.m,
tol)
def test05_length(self):
"Testing GeoQuerySet.length() on 3D fields."
# ST_Length_Spheroid Z-aware, and thus does not need to use
# a separate function internally.
# `SELECT ST_Length_Spheroid(line, 'SPHEROID["GRS 1980",6378137,298.257222101]')
# FROM geo3d_interstate[2d|3d];`
tol = 3
ref_length_2d = 4368.1721949481
ref_length_3d = 4368.62547052088
self.assertAlmostEqual(ref_length_2d,
Interstate2D.objects.length().get(name='I-45').length.m,
tol)
self.assertAlmostEqual(ref_length_3d,
Interstate3D.objects.length().get(name='I-45').length.m,
tol)
# Making sure `ST_Length3D` is used on for a projected
# and 3D model rather than `ST_Length`.
# `SELECT ST_Length(line) FROM geo3d_interstateproj2d;`
ref_length_2d = 4367.71564892392
# `SELECT ST_Length3D(line) FROM geo3d_interstateproj3d;`
ref_length_3d = 4368.16897234101
self.assertAlmostEqual(ref_length_2d,
InterstateProj2D.objects.length().get(name='I-45').length.m,
tol)
self.assertAlmostEqual(ref_length_3d,
InterstateProj3D.objects.length().get(name='I-45').length.m,
tol)
def test06_scale(self):
"Testing GeoQuerySet.scale() on Z values."
# Mapping of City name to reference Z values.
zscales = (-3, 4, 23)
for zscale in zscales:
for city in City3D.objects.scale(1.0, 1.0, zscale):
self.assertEqual(city_dict[city.name][2] * zscale, city.scale.z)
def test07_translate(self):
"Testing GeoQuerySet.translate() on Z values."
ztranslations = (5.23, 23, -17)
for ztrans in ztranslations:
for city in City3D.objects.translate(0, 0, ztrans):
self.assertEqual(city_dict[city.name][2] + ztrans, city.translate.z)
def suite():
s = unittest.TestSuite()
s.addTest(unittest.makeSuite(Geo3DTest))
return s

View File

@ -0,0 +1 @@
# Create your views here.

View File

@ -28,8 +28,11 @@ class GeoRegressionTests(unittest.TestCase):
kmz = render_to_kmz('gis/kml/placemarks.kml', {'places' : places}) kmz = render_to_kmz('gis/kml/placemarks.kml', {'places' : places})
@no_spatialite @no_spatialite
@no_mysql
def test03_extent(self): def test03_extent(self):
"Testing `extent` on a table with a single point, see #11827." "Testing `extent` on a table with a single point, see #11827."
pnt = City.objects.get(name='Pueblo').point pnt = City.objects.get(name='Pueblo').point
ref_ext = (pnt.x, pnt.y, pnt.x, pnt.y) ref_ext = (pnt.x, pnt.y, pnt.x, pnt.y)
self.assertEqual(ref_ext, City.objects.filter(name='Pueblo').extent()) extent = City.objects.filter(name='Pueblo').extent()
for ref_val, val in zip(ref_ext, extent):
self.assertAlmostEqual(ref_val, val, 4)

View File

@ -171,3 +171,10 @@ json_geoms = (TestGeom('POINT(100 0)', json='{ "type": "Point", "coordinates": [
not_equal=True, not_equal=True,
), ),
) )
# For testing HEX(EWKB).
ogc_hex = '01010000000000000000000000000000000000F03F'
# `SELECT ST_AsHEXEWKB(ST_GeomFromText('POINT(0 1)', 4326));`
hexewkb_2d = '0101000020E61000000000000000000000000000000000F03F'
# `SELECT ST_AsHEXEWKB(ST_GeomFromEWKT('SRID=4326;POINT(0 1 2)'));`
hexewkb_3d = '01010000A0E61000000000000000000000000000000000F03F0000000000000040'

View File

@ -29,6 +29,20 @@ class Interstate(models.Model):
path = models.LineStringField() path = models.LineStringField()
objects = models.GeoManager() objects = models.GeoManager()
# Same as `City` above, but for testing model inheritance.
class CityBase(models.Model):
name = models.CharField(max_length=25)
population = models.IntegerField()
density = models.DecimalField(max_digits=7, decimal_places=1)
point = models.PointField()
objects = models.GeoManager()
class ICity1(CityBase):
dt = models.DateField()
class ICity2(ICity1):
dt_time = models.DateTimeField(auto_now=True)
# Mapping dictionaries for the models above. # Mapping dictionaries for the models above.
co_mapping = {'name' : 'Name', co_mapping = {'name' : 'Name',
'state' : {'name' : 'State'}, # ForeignKey's use another mapping dictionary for the _related_ Model (State in this case). 'state' : {'name' : 'State'}, # ForeignKey's use another mapping dictionary for the _related_ Model (State in this case).

View File

@ -1,7 +1,7 @@
import os, unittest import os, unittest
from copy import copy from copy import copy
from decimal import Decimal from decimal import Decimal
from models import City, County, CountyFeat, Interstate, State, city_mapping, co_mapping, cofeat_mapping, inter_mapping from models import City, County, CountyFeat, Interstate, ICity1, ICity2, State, city_mapping, co_mapping, cofeat_mapping, inter_mapping
from django.contrib.gis.db.backend import SpatialBackend from django.contrib.gis.db.backend import SpatialBackend
from django.contrib.gis.utils.layermapping import LayerMapping, LayerMapError, InvalidDecimal, MissingForeignKey from django.contrib.gis.utils.layermapping import LayerMapping, LayerMapError, InvalidDecimal, MissingForeignKey
from django.contrib.gis.gdal import DataSource from django.contrib.gis.gdal import DataSource
@ -242,6 +242,26 @@ class LayerMapTest(unittest.TestCase):
lm.save(step=st, strict=True) lm.save(step=st, strict=True)
self.county_helper(county_feat=False) self.county_helper(county_feat=False)
def test06_model_inheritance(self):
"Tests LayerMapping on inherited models. See #12093."
icity_mapping = {'name' : 'Name',
'population' : 'Population',
'density' : 'Density',
'point' : 'POINT',
'dt' : 'Created',
}
# Parent model has geometry field.
lm1 = LayerMapping(ICity1, city_shp, icity_mapping)
lm1.save()
# Grandparent has geometry field.
lm2 = LayerMapping(ICity2, city_shp, icity_mapping)
lm2.save()
self.assertEqual(6, ICity1.objects.count())
self.assertEqual(3, ICity2.objects.count())
def suite(): def suite():
s = unittest.TestSuite() s = unittest.TestSuite()
s.addTest(unittest.makeSuite(LayerMapTest)) s.addTest(unittest.makeSuite(LayerMapTest))

View File

@ -10,7 +10,7 @@ if HAS_GDAL:
try: try:
# LayerMapping requires DJANGO_SETTINGS_MODULE to be set, # LayerMapping requires DJANGO_SETTINGS_MODULE to be set,
# so this needs to be in try/except. # so this needs to be in try/except.
from django.contrib.gis.utils.layermapping import LayerMapping from django.contrib.gis.utils.layermapping import LayerMapping, LayerMapError
except: except:
pass pass

View File

@ -133,6 +133,9 @@ class LayerMapping(object):
MULTI_TYPES = {1 : OGRGeomType('MultiPoint'), MULTI_TYPES = {1 : OGRGeomType('MultiPoint'),
2 : OGRGeomType('MultiLineString'), 2 : OGRGeomType('MultiLineString'),
3 : OGRGeomType('MultiPolygon'), 3 : OGRGeomType('MultiPolygon'),
OGRGeomType('Point25D').num : OGRGeomType('MultiPoint25D'),
OGRGeomType('LineString25D').num : OGRGeomType('MultiLineString25D'),
OGRGeomType('Polygon25D').num : OGRGeomType('MultiPolygon25D'),
} }
# Acceptable Django field types and corresponding acceptable OGR # Acceptable Django field types and corresponding acceptable OGR
@ -282,19 +285,28 @@ class LayerMapping(object):
if self.geom_field: if self.geom_field:
raise LayerMapError('LayerMapping does not support more than one GeometryField per model.') raise LayerMapError('LayerMapping does not support more than one GeometryField per model.')
# Getting the coordinate dimension of the geometry field.
coord_dim = model_field.dim
try: try:
gtype = OGRGeomType(ogr_name) if coord_dim == 3:
gtype = OGRGeomType(ogr_name + '25D')
else:
gtype = OGRGeomType(ogr_name)
except OGRException: except OGRException:
raise LayerMapError('Invalid mapping for GeometryField "%s".' % field_name) raise LayerMapError('Invalid mapping for GeometryField "%s".' % field_name)
# Making sure that the OGR Layer's Geometry is compatible. # Making sure that the OGR Layer's Geometry is compatible.
ltype = self.layer.geom_type ltype = self.layer.geom_type
if not (gtype == ltype or self.make_multi(ltype, model_field)): if not (ltype.name.startswith(gtype.name) or self.make_multi(ltype, model_field)):
raise LayerMapError('Invalid mapping geometry; model has %s, feature has %s.' % (fld_name, gtype)) raise LayerMapError('Invalid mapping geometry; model has %s%s, layer is %s.' %
(fld_name, (coord_dim == 3 and '(dim=3)') or '', ltype))
# Setting the `geom_field` attribute w/the name of the model field # Setting the `geom_field` attribute w/the name of the model field
# that is a Geometry. # that is a Geometry. Also setting the coordinate dimension
# attribute.
self.geom_field = field_name self.geom_field = field_name
self.coord_dim = coord_dim
fields_val = model_field fields_val = model_field
elif isinstance(model_field, models.ForeignKey): elif isinstance(model_field, models.ForeignKey):
if isinstance(ogr_name, dict): if isinstance(ogr_name, dict):
@ -482,6 +494,10 @@ class LayerMapping(object):
if necessary (for example if the model field is MultiPolygonField while if necessary (for example if the model field is MultiPolygonField while
the mapped shapefile only contains Polygons). the mapped shapefile only contains Polygons).
""" """
# Downgrade a 3D geom to a 2D one, if necessary.
if self.coord_dim != geom.coord_dim:
geom.coord_dim = self.coord_dim
if self.make_multi(geom.geom_type, model_field): if self.make_multi(geom.geom_type, model_field):
# Constructing a multi-geometry type to contain the single geometry # Constructing a multi-geometry type to contain the single geometry
multi_type = self.MULTI_TYPES[geom.geom_type.num] multi_type = self.MULTI_TYPES[geom.geom_type.num]
@ -514,16 +530,26 @@ class LayerMapping(object):
def geometry_column(self): def geometry_column(self):
"Returns the GeometryColumn model associated with the geographic column." "Returns the GeometryColumn model associated with the geographic column."
from django.contrib.gis.models import GeometryColumns from django.contrib.gis.models import GeometryColumns
# Getting the GeometryColumn object. # Use the `get_field_by_name` on the model's options so that we
# get the correct model if there's model inheritance -- otherwise
# the returned model is None.
opts = self.model._meta
fld, model, direct, m2m = opts.get_field_by_name(self.geom_field)
if model is None: model = self.model
# Trying to get the `GeometryColumns` object that corresponds to the
# the geometry field.
try: try:
db_table = self.model._meta.db_table db_table = model._meta.db_table
geo_col = self.geom_field geo_col = fld.column
if SpatialBackend.oracle: if SpatialBackend.oracle:
# Making upper case for Oracle. # Making upper case for Oracle.
db_table = db_table.upper() db_table = db_table.upper()
geo_col = geo_col.upper() geo_col = geo_col.upper()
gc_kwargs = {GeometryColumns.table_name_col() : db_table,
GeometryColumns.geom_col_name() : geo_col, gc_kwargs = { GeometryColumns.table_name_col() : db_table,
GeometryColumns.geom_col_name() : geo_col,
} }
return GeometryColumns.objects.get(**gc_kwargs) return GeometryColumns.objects.get(**gc_kwargs)
except Exception, msg: except Exception, msg:

View File

@ -8,6 +8,8 @@ RequestContext.
""" """
from django.conf import settings from django.conf import settings
from django.middleware.csrf import get_token
from django.utils.functional import lazy, memoize, SimpleLazyObject
def auth(request): def auth(request):
""" """
@ -17,17 +19,46 @@ def auth(request):
If there is no 'user' attribute in the request, uses AnonymousUser (from If there is no 'user' attribute in the request, uses AnonymousUser (from
django.contrib.auth). django.contrib.auth).
""" """
if hasattr(request, 'user'): # If we access request.user, request.session is accessed, which results in
user = request.user # 'Vary: Cookie' being sent in every request that uses this context
else: # processor, which can easily be every request on a site if
from django.contrib.auth.models import AnonymousUser # TEMPLATE_CONTEXT_PROCESSORS has this context processor added. This kills
user = AnonymousUser() # the ability to cache. So, we carefully ensure these attributes are lazy.
# We don't use django.utils.functional.lazy() for User, because that
# requires knowing the class of the object we want to proxy, which could
# break with custom auth backends. LazyObject is a less complete but more
# flexible solution that is a good enough wrapper for 'User'.
def get_user():
if hasattr(request, 'user'):
return request.user
else:
from django.contrib.auth.models import AnonymousUser
return AnonymousUser()
return { return {
'user': user, 'user': SimpleLazyObject(get_user),
'messages': user.get_and_delete_messages(), 'messages': lazy(memoize(lambda: get_user().get_and_delete_messages(), {}, 0), list)(),
'perms': PermWrapper(user), 'perms': lazy(lambda: PermWrapper(get_user()), PermWrapper)(),
} }
def csrf(request):
"""
Context processor that provides a CSRF token, or the string 'NOTPROVIDED' if
it has not been provided by either a view decorator or the middleware
"""
def _get_val():
token = get_token(request)
if token is None:
# In order to be able to provide debugging info in the
# case of misconfiguration, we use a sentinel value
# instead of returning an empty dict.
return 'NOTPROVIDED'
else:
return token
_get_val = lazy(_get_val, str)
return {'csrf_token': _get_val() }
def debug(request): def debug(request):
"Returns context variables helpful for debugging." "Returns context variables helpful for debugging."
context_extras = {} context_extras = {}

View File

@ -118,10 +118,6 @@ class Storage(object):
""" """
raise NotImplementedError() raise NotImplementedError()
# Needed by django.utils.functional.LazyObject (via DefaultStorage).
def get_all_members(self):
return self.__members__
class FileSystemStorage(Storage): class FileSystemStorage(Storage):
""" """
Standard filesystem storage Standard filesystem storage

View File

@ -68,6 +68,9 @@ class BaseHandler(object):
from django.core import exceptions, urlresolvers from django.core import exceptions, urlresolvers
from django.conf import settings from django.conf import settings
# Reset the urlconf for this thread.
urlresolvers.set_urlconf(None)
# Apply request middleware # Apply request middleware
for middleware_method in self._request_middleware: for middleware_method in self._request_middleware:
response = middleware_method(request) response = middleware_method(request)
@ -77,61 +80,69 @@ class BaseHandler(object):
# Get urlconf from request object, if available. Otherwise use default. # Get urlconf from request object, if available. Otherwise use default.
urlconf = getattr(request, "urlconf", settings.ROOT_URLCONF) urlconf = getattr(request, "urlconf", settings.ROOT_URLCONF)
# Set the urlconf for this thread to the one specified above.
urlresolvers.set_urlconf(urlconf)
resolver = urlresolvers.RegexURLResolver(r'^/', urlconf) resolver = urlresolvers.RegexURLResolver(r'^/', urlconf)
try: try:
callback, callback_args, callback_kwargs = resolver.resolve(
request.path_info)
# Apply view middleware
for middleware_method in self._view_middleware:
response = middleware_method(request, callback, callback_args, callback_kwargs)
if response:
return response
try: try:
response = callback(request, *callback_args, **callback_kwargs) callback, callback_args, callback_kwargs = resolver.resolve(
except Exception, e: request.path_info)
# If the view raised an exception, run it through exception
# middleware, and if the exception middleware returns a # Apply view middleware
# response, use that. Otherwise, reraise the exception. for middleware_method in self._view_middleware:
for middleware_method in self._exception_middleware: response = middleware_method(request, callback, callback_args, callback_kwargs)
response = middleware_method(request, e)
if response: if response:
return response return response
raise
# Complain if the view returned None (a common error).
if response is None:
try: try:
view_name = callback.func_name # If it's a function response = callback(request, *callback_args, **callback_kwargs)
except AttributeError: except Exception, e:
view_name = callback.__class__.__name__ + '.__call__' # If it's a class # If the view raised an exception, run it through exception
raise ValueError, "The view %s.%s didn't return an HttpResponse object." % (callback.__module__, view_name) # middleware, and if the exception middleware returns a
# response, use that. Otherwise, reraise the exception.
for middleware_method in self._exception_middleware:
response = middleware_method(request, e)
if response:
return response
raise
return response # Complain if the view returned None (a common error).
except http.Http404, e: if response is None:
if settings.DEBUG:
from django.views import debug
return debug.technical_404_response(request, e)
else:
try:
callback, param_dict = resolver.resolve404()
return callback(request, **param_dict)
except:
try: try:
return self.handle_uncaught_exception(request, resolver, sys.exc_info()) view_name = callback.func_name # If it's a function
finally: except AttributeError:
receivers = signals.got_request_exception.send(sender=self.__class__, request=request) view_name = callback.__class__.__name__ + '.__call__' # If it's a class
except exceptions.PermissionDenied: raise ValueError, "The view %s.%s didn't return an HttpResponse object." % (callback.__module__, view_name)
return http.HttpResponseForbidden('<h1>Permission denied</h1>')
except SystemExit: return response
# Allow sys.exit() to actually exit. See tickets #1023 and #4701 except http.Http404, e:
raise if settings.DEBUG:
except: # Handle everything else, including SuspiciousOperation, etc. from django.views import debug
# Get the exception info now, in case another exception is thrown later. return debug.technical_404_response(request, e)
exc_info = sys.exc_info() else:
receivers = signals.got_request_exception.send(sender=self.__class__, request=request) try:
return self.handle_uncaught_exception(request, resolver, exc_info) callback, param_dict = resolver.resolve404()
return callback(request, **param_dict)
except:
try:
return self.handle_uncaught_exception(request, resolver, sys.exc_info())
finally:
receivers = signals.got_request_exception.send(sender=self.__class__, request=request)
except exceptions.PermissionDenied:
return http.HttpResponseForbidden('<h1>Permission denied</h1>')
except SystemExit:
# Allow sys.exit() to actually exit. See tickets #1023 and #4701
raise
except: # Handle everything else, including SuspiciousOperation, etc.
# Get the exception info now, in case another exception is thrown later.
exc_info = sys.exc_info()
receivers = signals.got_request_exception.send(sender=self.__class__, request=request)
return self.handle_uncaught_exception(request, resolver, exc_info)
finally:
# Reset URLconf for this thread on the way out for complete
# isolation of request.urlconf
urlresolvers.set_urlconf(None)
def handle_uncaught_exception(self, request, resolver, exc_info): def handle_uncaught_exception(self, request, resolver, exc_info):
""" """

View File

@ -0,0 +1,110 @@
"""
Tools for sending email.
"""
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.utils.importlib import import_module
# Imported for backwards compatibility, and for the sake
# of a cleaner namespace. These symbols used to be in
# django/core/mail.py before the introduction of email
# backends and the subsequent reorganization (See #10355)
from django.core.mail.utils import CachedDnsName, DNS_NAME
from django.core.mail.message import \
EmailMessage, EmailMultiAlternatives, \
SafeMIMEText, SafeMIMEMultipart, \
DEFAULT_ATTACHMENT_MIME_TYPE, make_msgid, \
BadHeaderError, forbid_multi_line_headers
from django.core.mail.backends.smtp import EmailBackend as _SMTPConnection
def get_connection(backend=None, fail_silently=False, **kwds):
"""Load an e-mail backend and return an instance of it.
If backend is None (default) settings.EMAIL_BACKEND is used.
Both fail_silently and other keyword arguments are used in the
constructor of the backend.
"""
path = backend or settings.EMAIL_BACKEND
try:
mod = import_module(path)
except ImportError, e:
raise ImproperlyConfigured(('Error importing email backend %s: "%s"'
% (path, e)))
try:
cls = getattr(mod, 'EmailBackend')
except AttributeError:
raise ImproperlyConfigured(('Module "%s" does not define a '
'"EmailBackend" class' % path))
return cls(fail_silently=fail_silently, **kwds)
def send_mail(subject, message, from_email, recipient_list,
fail_silently=False, auth_user=None, auth_password=None,
connection=None):
"""
Easy wrapper for sending a single message to a recipient list. All members
of the recipient list will see the other recipients in the 'To' field.
If auth_user is None, the EMAIL_HOST_USER setting is used.
If auth_password is None, the EMAIL_HOST_PASSWORD setting is used.
Note: The API for this method is frozen. New code wanting to extend the
functionality should use the EmailMessage class directly.
"""
connection = connection or get_connection(username=auth_user,
password=auth_password,
fail_silently=fail_silently)
return EmailMessage(subject, message, from_email, recipient_list,
connection=connection).send()
def send_mass_mail(datatuple, fail_silently=False, auth_user=None,
auth_password=None, connection=None):
"""
Given a datatuple of (subject, message, from_email, recipient_list), sends
each message to each recipient list. Returns the number of e-mails sent.
If from_email is None, the DEFAULT_FROM_EMAIL setting is used.
If auth_user and auth_password are set, they're used to log in.
If auth_user is None, the EMAIL_HOST_USER setting is used.
If auth_password is None, the EMAIL_HOST_PASSWORD setting is used.
Note: The API for this method is frozen. New code wanting to extend the
functionality should use the EmailMessage class directly.
"""
connection = connection or get_connection(username=auth_user,
password=auth_password,
fail_silently=fail_silently)
messages = [EmailMessage(subject, message, sender, recipient)
for subject, message, sender, recipient in datatuple]
return connection.send_messages(messages)
def mail_admins(subject, message, fail_silently=False, connection=None):
"""Sends a message to the admins, as defined by the ADMINS setting."""
if not settings.ADMINS:
return
EmailMessage(settings.EMAIL_SUBJECT_PREFIX + subject, message,
settings.SERVER_EMAIL, [a[1] for a in settings.ADMINS],
connection=connection).send(fail_silently=fail_silently)
def mail_managers(subject, message, fail_silently=False, connection=None):
"""Sends a message to the managers, as defined by the MANAGERS setting."""
if not settings.MANAGERS:
return
EmailMessage(settings.EMAIL_SUBJECT_PREFIX + subject, message,
settings.SERVER_EMAIL, [a[1] for a in settings.MANAGERS],
connection=connection).send(fail_silently=fail_silently)
class SMTPConnection(_SMTPConnection):
def __init__(self, *args, **kwds):
import warnings
warnings.warn(
'mail.SMTPConnection is deprecated; use mail.get_connection() instead.',
DeprecationWarning
)
super(SMTPConnection, self).__init__(*args, **kwds)

View File

@ -0,0 +1 @@
# Mail backends shipped with Django.

View File

@ -0,0 +1,39 @@
"""Base email backend class."""
class BaseEmailBackend(object):
"""
Base class for email backend implementations.
Subclasses must at least overwrite send_messages().
"""
def __init__(self, fail_silently=False, **kwargs):
self.fail_silently = fail_silently
def open(self):
"""Open a network connection.
This method can be overwritten by backend implementations to
open a network connection.
It's up to the backend implementation to track the status of
a network connection if it's needed by the backend.
This method can be called by applications to force a single
network connection to be used when sending mails. See the
send_messages() method of the SMTP backend for a reference
implementation.
The default implementation does nothing.
"""
pass
def close(self):
"""Close a network connection."""
pass
def send_messages(self, email_messages):
"""
Sends one or more EmailMessage objects and returns the number of email
messages sent.
"""
raise NotImplementedError

View File

@ -0,0 +1,37 @@
"""
Email backend that writes messages to console instead of sending them.
"""
import sys
import threading
from django.core.mail.backends.base import BaseEmailBackend
class EmailBackend(BaseEmailBackend):
def __init__(self, *args, **kwargs):
self.stream = kwargs.pop('stream', sys.stdout)
self._lock = threading.RLock()
super(EmailBackend, self).__init__(*args, **kwargs)
def send_messages(self, email_messages):
"""Write all messages to the stream in a thread-safe way."""
if not email_messages:
return
self._lock.acquire()
try:
# The try-except is nested to allow for
# Python 2.4 support (Refs #12147)
try:
stream_created = self.open()
for message in email_messages:
self.stream.write('%s\n' % message.message().as_string())
self.stream.write('-'*79)
self.stream.write('\n')
self.stream.flush() # flush after each message
if stream_created:
self.close()
except:
if not self.fail_silently:
raise
finally:
self._lock.release()
return len(email_messages)

View File

@ -0,0 +1,9 @@
"""
Dummy email backend that does nothing.
"""
from django.core.mail.backends.base import BaseEmailBackend
class EmailBackend(BaseEmailBackend):
def send_messages(self, email_messages):
return len(email_messages)

View File

@ -0,0 +1,59 @@
"""Email backend that writes messages to a file."""
import datetime
import os
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.mail.backends.console import EmailBackend as ConsoleEmailBackend
class EmailBackend(ConsoleEmailBackend):
def __init__(self, *args, **kwargs):
self._fname = None
if 'file_path' in kwargs:
self.file_path = kwargs.pop('file_path')
else:
self.file_path = getattr(settings, 'EMAIL_FILE_PATH',None)
# Make sure self.file_path is a string.
if not isinstance(self.file_path, basestring):
raise ImproperlyConfigured('Path for saving emails is invalid: %r' % self.file_path)
self.file_path = os.path.abspath(self.file_path)
# Make sure that self.file_path is an directory if it exists.
if os.path.exists(self.file_path) and not os.path.isdir(self.file_path):
raise ImproperlyConfigured('Path for saving email messages exists, but is not a directory: %s' % self.file_path)
# Try to create it, if it not exists.
elif not os.path.exists(self.file_path):
try:
os.makedirs(self.file_path)
except OSError, err:
raise ImproperlyConfigured('Could not create directory for saving email messages: %s (%s)' % (self.file_path, err))
# Make sure that self.file_path is writable.
if not os.access(self.file_path, os.W_OK):
raise ImproperlyConfigured('Could not write to directory: %s' % self.file_path)
# Finally, call super().
# Since we're using the console-based backend as a base,
# force the stream to be None, so we don't default to stdout
kwargs['stream'] = None
super(EmailBackend, self).__init__(*args, **kwargs)
def _get_filename(self):
"""Return a unique file name."""
if self._fname is None:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
fname = "%s-%s.log" % (timestamp, abs(id(self)))
self._fname = os.path.join(self.file_path, fname)
return self._fname
def open(self):
if self.stream is None:
self.stream = open(self._get_filename(), 'a')
return True
return False
def close(self):
try:
if self.stream is not None:
self.stream.close()
finally:
self.stream = None

View File

@ -0,0 +1,24 @@
"""
Backend for test environment.
"""
from django.core import mail
from django.core.mail.backends.base import BaseEmailBackend
class EmailBackend(BaseEmailBackend):
"""A email backend for use during test sessions.
The test connection stores email messages in a dummy outbox,
rather than sending them out on the wire.
The dummy outbox is accessible through the outbox instance attribute.
"""
def __init__(self, *args, **kwargs):
super(EmailBackend, self).__init__(*args, **kwargs)
if not hasattr(mail, 'outbox'):
mail.outbox = []
def send_messages(self, messages):
"""Redirect messages to the dummy outbox"""
mail.outbox.extend(messages)
return len(messages)

View File

@ -0,0 +1,106 @@
"""SMTP email backend class."""
import smtplib
import socket
import threading
from django.conf import settings
from django.core.mail.backends.base import BaseEmailBackend
from django.core.mail.utils import DNS_NAME
class EmailBackend(BaseEmailBackend):
"""
A wrapper that manages the SMTP network connection.
"""
def __init__(self, host=None, port=None, username=None, password=None,
use_tls=None, fail_silently=False, **kwargs):
super(EmailBackend, self).__init__(fail_silently=fail_silently)
self.host = host or settings.EMAIL_HOST
self.port = port or settings.EMAIL_PORT
self.username = username or settings.EMAIL_HOST_USER
self.password = password or settings.EMAIL_HOST_PASSWORD
if use_tls is None:
self.use_tls = settings.EMAIL_USE_TLS
else:
self.use_tls = use_tls
self.connection = None
self._lock = threading.RLock()
def open(self):
"""
Ensures we have a connection to the email server. Returns whether or
not a new connection was required (True or False).
"""
if self.connection:
# Nothing to do if the connection is already open.
return False
try:
# If local_hostname is not specified, socket.getfqdn() gets used.
# For performance, we use the cached FQDN for local_hostname.
self.connection = smtplib.SMTP(self.host, self.port,
local_hostname=DNS_NAME.get_fqdn())
if self.use_tls:
self.connection.ehlo()
self.connection.starttls()
self.connection.ehlo()
if self.username and self.password:
self.connection.login(self.username, self.password)
return True
except:
if not self.fail_silently:
raise
def close(self):
"""Closes the connection to the email server."""
try:
try:
self.connection.quit()
except socket.sslerror:
# This happens when calling quit() on a TLS connection
# sometimes.
self.connection.close()
except:
if self.fail_silently:
return
raise
finally:
self.connection = None
def send_messages(self, email_messages):
"""
Sends one or more EmailMessage objects and returns the number of email
messages sent.
"""
if not email_messages:
return
self._lock.acquire()
try:
new_conn_created = self.open()
if not self.connection:
# We failed silently on open().
# Trying to send would be pointless.
return
num_sent = 0
for message in email_messages:
sent = self._send(message)
if sent:
num_sent += 1
if new_conn_created:
self.close()
finally:
self._lock.release()
return num_sent
def _send(self, email_message):
"""A helper method that does the actual sending."""
if not email_message.recipients():
return False
try:
self.connection.sendmail(email_message.from_email,
email_message.recipients(),
email_message.message().as_string())
except:
if not self.fail_silently:
raise
return False
return True

View File

@ -1,21 +1,16 @@
"""
Tools for sending email.
"""
import mimetypes import mimetypes
import os import os
import smtplib
import socket
import time
import random import random
import time
from email import Charset, Encoders from email import Charset, Encoders
from email.MIMEText import MIMEText from email.MIMEText import MIMEText
from email.MIMEMultipart import MIMEMultipart from email.MIMEMultipart import MIMEMultipart
from email.MIMEBase import MIMEBase from email.MIMEBase import MIMEBase
from email.Header import Header from email.Header import Header
from email.Utils import formatdate, parseaddr, formataddr from email.Utils import formatdate, getaddresses, formataddr
from django.conf import settings from django.conf import settings
from django.core.mail.utils import DNS_NAME
from django.utils.encoding import smart_str, force_unicode from django.utils.encoding import smart_str, force_unicode
# Don't BASE64-encode UTF-8 messages so that we avoid unwanted attention from # Don't BASE64-encode UTF-8 messages so that we avoid unwanted attention from
@ -26,18 +21,10 @@ Charset.add_charset('utf-8', Charset.SHORTEST, Charset.QP, 'utf-8')
# and cannot be guessed). # and cannot be guessed).
DEFAULT_ATTACHMENT_MIME_TYPE = 'application/octet-stream' DEFAULT_ATTACHMENT_MIME_TYPE = 'application/octet-stream'
# Cache the hostname, but do it lazily: socket.getfqdn() can take a couple of
# seconds, which slows down the restart of the server.
class CachedDnsName(object):
def __str__(self):
return self.get_fqdn()
def get_fqdn(self): class BadHeaderError(ValueError):
if not hasattr(self, '_fqdn'): pass
self._fqdn = socket.getfqdn()
return self._fqdn
DNS_NAME = CachedDnsName()
# Copied from Python standard library, with the following modifications: # Copied from Python standard library, with the following modifications:
# * Used cached hostname for performance. # * Used cached hostname for performance.
@ -66,8 +53,6 @@ def make_msgid(idstring=None):
msgid = '<%s.%s.%s%s@%s>' % (utcdate, pid, randint, idstring, idhost) msgid = '<%s.%s.%s%s@%s>' % (utcdate, pid, randint, idstring, idhost)
return msgid return msgid
class BadHeaderError(ValueError):
pass
def forbid_multi_line_headers(name, val): def forbid_multi_line_headers(name, val):
"""Forbids multi-line headers, to prevent header injection.""" """Forbids multi-line headers, to prevent header injection."""
@ -79,8 +64,7 @@ def forbid_multi_line_headers(name, val):
except UnicodeEncodeError: except UnicodeEncodeError:
if name.lower() in ('to', 'from', 'cc'): if name.lower() in ('to', 'from', 'cc'):
result = [] result = []
for item in val.split(', '): for nm, addr in getaddresses((val,)):
nm, addr = parseaddr(item)
nm = str(Header(nm, settings.DEFAULT_CHARSET)) nm = str(Header(nm, settings.DEFAULT_CHARSET))
result.append(formataddr((nm, str(addr)))) result.append(formataddr((nm, str(addr))))
val = ', '.join(result) val = ', '.join(result)
@ -91,104 +75,18 @@ def forbid_multi_line_headers(name, val):
val = Header(val) val = Header(val)
return name, val return name, val
class SafeMIMEText(MIMEText): class SafeMIMEText(MIMEText):
def __setitem__(self, name, val): def __setitem__(self, name, val):
name, val = forbid_multi_line_headers(name, val) name, val = forbid_multi_line_headers(name, val)
MIMEText.__setitem__(self, name, val) MIMEText.__setitem__(self, name, val)
class SafeMIMEMultipart(MIMEMultipart): class SafeMIMEMultipart(MIMEMultipart):
def __setitem__(self, name, val): def __setitem__(self, name, val):
name, val = forbid_multi_line_headers(name, val) name, val = forbid_multi_line_headers(name, val)
MIMEMultipart.__setitem__(self, name, val) MIMEMultipart.__setitem__(self, name, val)
class SMTPConnection(object):
"""
A wrapper that manages the SMTP network connection.
"""
def __init__(self, host=None, port=None, username=None, password=None,
use_tls=None, fail_silently=False):
self.host = host or settings.EMAIL_HOST
self.port = port or settings.EMAIL_PORT
self.username = username or settings.EMAIL_HOST_USER
self.password = password or settings.EMAIL_HOST_PASSWORD
self.use_tls = (use_tls is not None) and use_tls or settings.EMAIL_USE_TLS
self.fail_silently = fail_silently
self.connection = None
def open(self):
"""
Ensures we have a connection to the email server. Returns whether or
not a new connection was required (True or False).
"""
if self.connection:
# Nothing to do if the connection is already open.
return False
try:
# If local_hostname is not specified, socket.getfqdn() gets used.
# For performance, we use the cached FQDN for local_hostname.
self.connection = smtplib.SMTP(self.host, self.port,
local_hostname=DNS_NAME.get_fqdn())
if self.use_tls:
self.connection.ehlo()
self.connection.starttls()
self.connection.ehlo()
if self.username and self.password:
self.connection.login(self.username, self.password)
return True
except:
if not self.fail_silently:
raise
def close(self):
"""Closes the connection to the email server."""
try:
try:
self.connection.quit()
except socket.sslerror:
# This happens when calling quit() on a TLS connection
# sometimes.
self.connection.close()
except:
if self.fail_silently:
return
raise
finally:
self.connection = None
def send_messages(self, email_messages):
"""
Sends one or more EmailMessage objects and returns the number of email
messages sent.
"""
if not email_messages:
return
new_conn_created = self.open()
if not self.connection:
# We failed silently on open(). Trying to send would be pointless.
return
num_sent = 0
for message in email_messages:
sent = self._send(message)
if sent:
num_sent += 1
if new_conn_created:
self.close()
return num_sent
def _send(self, email_message):
"""A helper method that does the actual sending."""
if not email_message.recipients():
return False
try:
self.connection.sendmail(email_message.from_email,
email_message.recipients(),
email_message.message().as_string())
except:
if not self.fail_silently:
raise
return False
return True
class EmailMessage(object): class EmailMessage(object):
""" """
@ -199,14 +97,14 @@ class EmailMessage(object):
encoding = None # None => use settings default encoding = None # None => use settings default
def __init__(self, subject='', body='', from_email=None, to=None, bcc=None, def __init__(self, subject='', body='', from_email=None, to=None, bcc=None,
connection=None, attachments=None, headers=None): connection=None, attachments=None, headers=None):
""" """
Initialize a single email message (which can be sent to multiple Initialize a single email message (which can be sent to multiple
recipients). recipients).
All strings used to create the message can be unicode strings (or UTF-8 All strings used to create the message can be unicode strings
bytestrings). The SafeMIMEText class will handle any necessary encoding (or UTF-8 bytestrings). The SafeMIMEText class will handle any
conversions. necessary encoding conversions.
""" """
if to: if to:
assert not isinstance(to, basestring), '"to" argument must be a list or tuple' assert not isinstance(to, basestring), '"to" argument must be a list or tuple'
@ -226,8 +124,9 @@ class EmailMessage(object):
self.connection = connection self.connection = connection
def get_connection(self, fail_silently=False): def get_connection(self, fail_silently=False):
from django.core.mail import get_connection
if not self.connection: if not self.connection:
self.connection = SMTPConnection(fail_silently=fail_silently) self.connection = get_connection(fail_silently=fail_silently)
return self.connection return self.connection
def message(self): def message(self):
@ -332,6 +231,7 @@ class EmailMessage(object):
filename=filename) filename=filename)
return attachment return attachment
class EmailMultiAlternatives(EmailMessage): class EmailMultiAlternatives(EmailMessage):
""" """
A version of EmailMessage that makes it easy to send multipart/alternative A version of EmailMessage that makes it easy to send multipart/alternative
@ -371,56 +271,3 @@ class EmailMultiAlternatives(EmailMessage):
for alternative in self.alternatives: for alternative in self.alternatives:
msg.attach(self._create_mime_attachment(*alternative)) msg.attach(self._create_mime_attachment(*alternative))
return msg return msg
def send_mail(subject, message, from_email, recipient_list,
fail_silently=False, auth_user=None, auth_password=None):
"""
Easy wrapper for sending a single message to a recipient list. All members
of the recipient list will see the other recipients in the 'To' field.
If auth_user is None, the EMAIL_HOST_USER setting is used.
If auth_password is None, the EMAIL_HOST_PASSWORD setting is used.
Note: The API for this method is frozen. New code wanting to extend the
functionality should use the EmailMessage class directly.
"""
connection = SMTPConnection(username=auth_user, password=auth_password,
fail_silently=fail_silently)
return EmailMessage(subject, message, from_email, recipient_list,
connection=connection).send()
def send_mass_mail(datatuple, fail_silently=False, auth_user=None,
auth_password=None):
"""
Given a datatuple of (subject, message, from_email, recipient_list), sends
each message to each recipient list. Returns the number of e-mails sent.
If from_email is None, the DEFAULT_FROM_EMAIL setting is used.
If auth_user and auth_password are set, they're used to log in.
If auth_user is None, the EMAIL_HOST_USER setting is used.
If auth_password is None, the EMAIL_HOST_PASSWORD setting is used.
Note: The API for this method is frozen. New code wanting to extend the
functionality should use the EmailMessage class directly.
"""
connection = SMTPConnection(username=auth_user, password=auth_password,
fail_silently=fail_silently)
messages = [EmailMessage(subject, message, sender, recipient)
for subject, message, sender, recipient in datatuple]
return connection.send_messages(messages)
def mail_admins(subject, message, fail_silently=False):
"""Sends a message to the admins, as defined by the ADMINS setting."""
if not settings.ADMINS:
return
EmailMessage(settings.EMAIL_SUBJECT_PREFIX + subject, message,
settings.SERVER_EMAIL, [a[1] for a in settings.ADMINS]
).send(fail_silently=fail_silently)
def mail_managers(subject, message, fail_silently=False):
"""Sends a message to the managers, as defined by the MANAGERS setting."""
if not settings.MANAGERS:
return
EmailMessage(settings.EMAIL_SUBJECT_PREFIX + subject, message,
settings.SERVER_EMAIL, [a[1] for a in settings.MANAGERS]
).send(fail_silently=fail_silently)

19
django/core/mail/utils.py Normal file
View File

@ -0,0 +1,19 @@
"""
Email message and email sending related helper functions.
"""
import socket
# Cache the hostname, but do it lazily: socket.getfqdn() can take a couple of
# seconds, which slows down the restart of the server.
class CachedDnsName(object):
def __str__(self):
return self.get_fqdn()
def get_fqdn(self):
if not hasattr(self, '_fqdn'):
self._fqdn = socket.getfqdn()
return self._fqdn
DNS_NAME = CachedDnsName()

View File

@ -299,7 +299,7 @@ class ManagementUtility(object):
# subcommand # subcommand
if cword == 1: if cword == 1:
print ' '.join(filter(lambda x: x.startswith(curr), subcommands)) print ' '.join(sorted(filter(lambda x: x.startswith(curr), subcommands)))
# subcommand options # subcommand options
# special case: the 'help' subcommand has no options # special case: the 'help' subcommand has no options
elif cwords[0] in subcommands and cwords[0] != 'help': elif cwords[0] in subcommands and cwords[0] != 'help':
@ -328,7 +328,7 @@ class ManagementUtility(object):
options = filter(lambda (x, v): x not in prev_opts, options) options = filter(lambda (x, v): x not in prev_opts, options)
# filter options by current input # filter options by current input
options = [(k, v) for k, v in options if k.startswith(curr)] options = sorted([(k, v) for k, v in options if k.startswith(curr)])
for option in options: for option in options:
opt_label = option[0] opt_label = option[0]
# append '=' to options which require args # append '=' to options which require args

View File

@ -31,10 +31,7 @@ class Command(NoArgsCommand):
self.style = no_style() self.style = no_style()
if not options['database']: connection = connections[options["database"]]
dbs = connections
else:
dbs = [options['database']]
# Import the 'management' module within each installed app, to register # Import the 'management' module within each installed app, to register
# dispatcher events. # dispatcher events.
@ -55,85 +52,92 @@ class Command(NoArgsCommand):
if not msg.startswith('No module named') or 'management' not in msg: if not msg.startswith('No module named') or 'management' not in msg:
raise raise
for db in dbs: cursor = connection.cursor()
connection = connections[db]
cursor = connection.cursor()
# Get a list of already installed *models* so that references work right. # Get a list of already installed *models* so that references work right.
tables = connection.introspection.table_names() tables = connection.introspection.table_names()
seen_models = connection.introspection.installed_models(tables) seen_models = connection.introspection.installed_models(tables)
created_models = set() created_models = set()
pending_references = {} pending_references = {}
# Create the tables for each model # Create the tables for each model
for app in models.get_apps(): for app in models.get_apps():
app_name = app.__name__.split('.')[-2] app_name = app.__name__.split('.')[-2]
model_list = models.get_models(app) model_list = models.get_models(app, include_auto_created=True)
for model in model_list: for model in model_list:
# Create the model's database table, if it doesn't already exist. # Create the model's database table, if it doesn't already exist.
if verbosity >= 2: if verbosity >= 2:
print "Processing %s.%s model" % (app_name, model._meta.object_name) print "Processing %s.%s model" % (app_name, model._meta.object_name)
if connection.introspection.table_name_converter(model._meta.db_table) in tables: opts = model._meta
continue if (connection.introspection.table_name_converter(opts.db_table) in tables or
sql, references = connection.creation.sql_create_model(model, self.style, seen_models) (opts.auto_created and
seen_models.add(model) connection.introspection.table_name_converter(opts.auto_created._meta.db_table) in tables)):
created_models.add(model) continue
for refto, refs in references.items(): sql, references = connection.creation.sql_create_model(model, self.style, seen_models)
pending_references.setdefault(refto, []).extend(refs) seen_models.add(model)
if refto in seen_models: created_models.add(model)
sql.extend(connection.creation.sql_for_pending_references(refto, self.style, pending_references)) for refto, refs in references.items():
sql.extend(connection.creation.sql_for_pending_references(model, self.style, pending_references)) pending_references.setdefault(refto, []).extend(refs)
if verbosity >= 1 and sql: if refto in seen_models:
print "Creating table %s" % model._meta.db_table sql.extend(connection.creation.sql_for_pending_references(refto, self.style, pending_references))
for statement in sql: sql.extend(connection.creation.sql_for_pending_references(model, self.style, pending_references))
cursor.execute(statement) if verbosity >= 1 and sql:
tables.append(connection.introspection.table_name_converter(model._meta.db_table)) print "Creating table %s" % model._meta.db_table
for statement in sql:
cursor.execute(statement)
tables.append(connection.introspection.table_name_converter(model._meta.db_table))
# Create the m2m tables. This must be done after all tables have been created
# to ensure that all referred tables will exist.
for app in models.get_apps():
app_name = app.__name__.split('.')[-2]
model_list = models.get_models(app)
for model in model_list:
if model in created_models:
sql = connection.creation.sql_for_many_to_many(model, self.style)
if sql:
if verbosity >= 2:
print "Creating many-to-many tables for %s.%s model" % (app_name, model._meta.object_name)
for statement in sql:
cursor.execute(statement)
transaction.commit_unless_managed(using=db) transaction.commit_unless_managed()
# Send the post_syncdb signal, so individual apps can do whatever they need # Send the post_syncdb signal, so individual apps can do whatever they need
# to do at this point. # to do at this point.
emit_post_sync_signal(created_models, verbosity, interactive, db) emit_post_sync_signal(created_models, verbosity, interactive)
# The connection may have been closed by a syncdb handler. # The connection may have been closed by a syncdb handler.
cursor = connection.cursor() cursor = connection.cursor()
# Install custom SQL for the app (but only if this # Install custom SQL for the app (but only if this
# is a model we've just created) # is a model we've just created)
for app in models.get_apps(): for app in models.get_apps():
app_name = app.__name__.split('.')[-2] app_name = app.__name__.split('.')[-2]
for model in models.get_models(app): for model in models.get_models(app):
if model in created_models: if model in created_models:
custom_sql = custom_sql_for_model(model, self.style, connection) custom_sql = custom_sql_for_model(model, self.style)
if custom_sql: if custom_sql:
if verbosity >= 1: if verbosity >= 1:
print "Installing custom SQL for %s.%s model" % (app_name, model._meta.object_name) print "Installing custom SQL for %s.%s model" % (app_name, model._meta.object_name)
try: try:
for sql in custom_sql: for sql in custom_sql:
cursor.execute(sql) cursor.execute(sql)
except Exception, e: except Exception, e:
sys.stderr.write("Failed to install custom SQL for %s.%s model: %s\n" % \ sys.stderr.write("Failed to install custom SQL for %s.%s model: %s\n" % \
(app_name, model._meta.object_name, e)) (app_name, model._meta.object_name, e))
if show_traceback: if show_traceback:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
transaction.rollback_unless_managed(using=db) transaction.rollback_unless_managed()
else: else:
transaction.commit_unless_managed(using=db) transaction.commit_unless_managed()
else:
if verbosity >= 2:
print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name)
# Install SQL indicies for all newly created models
for app in models.get_apps():
app_name = app.__name__.split('.')[-2]
for model in models.get_models(app):
if model in created_models:
index_sql = connection.creation.sql_indexes_for_model(model, self.style)
if index_sql:
if verbosity >= 1:
print "Installing index for %s.%s model" % (app_name, model._meta.object_name)
try:
for sql in index_sql:
cursor.execute(sql)
except Exception, e:
sys.stderr.write("Failed to install index for %s.%s model: %s\n" % \
(app_name, model._meta.object_name, e))
transaction.rollback_unless_managed()
else: else:
if verbosity >= 2: if verbosity >= 2:
print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name) print "No custom SQL for %s.%s model" % (app_name, model._meta.object_name)

View File

@ -28,7 +28,7 @@ def sql_create(app, style, connection):
# We trim models from the current app so that the sqlreset command does not # We trim models from the current app so that the sqlreset command does not
# generate invalid SQL (leaving models out of known_models is harmless, so # generate invalid SQL (leaving models out of known_models is harmless, so
# we can be conservative). # we can be conservative).
app_models = models.get_models(app) app_models = models.get_models(app, include_auto_created=True)
final_output = [] final_output = []
tables = connection.introspection.table_names() tables = connection.introspection.table_names()
known_models = set([model for model in connection.introspection.installed_models(tables) if model not in app_models]) known_models = set([model for model in connection.introspection.installed_models(tables) if model not in app_models])
@ -45,10 +45,6 @@ def sql_create(app, style, connection):
# Keep track of the fact that we've created the table for this model. # Keep track of the fact that we've created the table for this model.
known_models.add(model) known_models.add(model)
# Create the many-to-many join tables.
for model in app_models:
final_output.extend(connection.creation.sql_for_many_to_many(model, style))
# Handle references to tables that are from other apps # Handle references to tables that are from other apps
# but don't exist physically. # but don't exist physically.
not_installed_models = set(pending_references.keys()) not_installed_models = set(pending_references.keys())
@ -84,7 +80,7 @@ def sql_delete(app, style, connection):
to_delete = set() to_delete = set()
references_to_delete = {} references_to_delete = {}
app_models = models.get_models(app) app_models = models.get_models(app, include_auto_created=True)
for model in app_models: for model in app_models:
if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names: if cursor and connection.introspection.table_name_converter(model._meta.db_table) in table_names:
# The table exists, so it needs to be dropped # The table exists, so it needs to be dropped
@ -99,13 +95,6 @@ def sql_delete(app, style, connection):
if connection.introspection.table_name_converter(model._meta.db_table) in table_names: if connection.introspection.table_name_converter(model._meta.db_table) in table_names:
output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style)) output.extend(connection.creation.sql_destroy_model(model, references_to_delete, style))
# Output DROP TABLE statements for many-to-many tables.
for model in app_models:
opts = model._meta
for f in opts.local_many_to_many:
if cursor and connection.introspection.table_name_converter(f.m2m_db_table()) in table_names:
output.extend(connection.creation.sql_destroy_many_to_many(model, f, style))
# Close database connection explicitly, in case this output is being piped # Close database connection explicitly, in case this output is being piped
# directly into a database client, to avoid locking issues. # directly into a database client, to avoid locking issues.
if cursor: if cursor:

View File

@ -79,27 +79,28 @@ def get_validation_errors(outfile, app=None):
rel_opts = f.rel.to._meta rel_opts = f.rel.to._meta
rel_name = RelatedObject(f.rel.to, cls, f).get_accessor_name() rel_name = RelatedObject(f.rel.to, cls, f).get_accessor_name()
rel_query_name = f.related_query_name() rel_query_name = f.related_query_name()
for r in rel_opts.fields: if not f.rel.is_hidden():
if r.name == rel_name: for r in rel_opts.fields:
e.add(opts, "Accessor for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) if r.name == rel_name:
if r.name == rel_query_name: e.add(opts, "Accessor for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
e.add(opts, "Reverse query name for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) if r.name == rel_query_name:
for r in rel_opts.local_many_to_many: e.add(opts, "Reverse query name for field '%s' clashes with field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.name == rel_name: for r in rel_opts.local_many_to_many:
e.add(opts, "Accessor for field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) if r.name == rel_name:
if r.name == rel_query_name: e.add(opts, "Accessor for field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
e.add(opts, "Reverse query name for field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name)) if r.name == rel_query_name:
for r in rel_opts.get_all_related_many_to_many_objects(): e.add(opts, "Reverse query name for field '%s' clashes with m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.name, f.name))
if r.get_accessor_name() == rel_name: for r in rel_opts.get_all_related_many_to_many_objects():
e.add(opts, "Accessor for field '%s' clashes with related m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name))
if r.get_accessor_name() == rel_query_name:
e.add(opts, "Reverse query name for field '%s' clashes with related m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name))
for r in rel_opts.get_all_related_objects():
if r.field is not f:
if r.get_accessor_name() == rel_name: if r.get_accessor_name() == rel_name:
e.add(opts, "Accessor for field '%s' clashes with related field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name)) e.add(opts, "Accessor for field '%s' clashes with related m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name))
if r.get_accessor_name() == rel_query_name: if r.get_accessor_name() == rel_query_name:
e.add(opts, "Reverse query name for field '%s' clashes with related field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name)) e.add(opts, "Reverse query name for field '%s' clashes with related m2m field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name))
for r in rel_opts.get_all_related_objects():
if r.field is not f:
if r.get_accessor_name() == rel_name:
e.add(opts, "Accessor for field '%s' clashes with related field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name))
if r.get_accessor_name() == rel_query_name:
e.add(opts, "Reverse query name for field '%s' clashes with related field '%s.%s'. Add a related_name argument to the definition for '%s'." % (f.name, rel_opts.object_name, r.get_accessor_name(), f.name))
seen_intermediary_signatures = [] seen_intermediary_signatures = []
for i, f in enumerate(opts.local_many_to_many): for i, f in enumerate(opts.local_many_to_many):
@ -117,48 +118,80 @@ def get_validation_errors(outfile, app=None):
if f.unique: if f.unique:
e.add(opts, "ManyToManyFields cannot be unique. Remove the unique argument on '%s'." % f.name) e.add(opts, "ManyToManyFields cannot be unique. Remove the unique argument on '%s'." % f.name)
if getattr(f.rel, 'through', None) is not None: if f.rel.through is not None and not isinstance(f.rel.through, basestring):
if hasattr(f.rel, 'through_model'): from_model, to_model = cls, f.rel.to
from_model, to_model = cls, f.rel.to if from_model == to_model and f.rel.symmetrical and not f.rel.through._meta.auto_created:
if from_model == to_model and f.rel.symmetrical: e.add(opts, "Many-to-many fields with intermediate tables cannot be symmetrical.")
e.add(opts, "Many-to-many fields with intermediate tables cannot be symmetrical.") seen_from, seen_to, seen_self = False, False, 0
seen_from, seen_to, seen_self = False, False, 0 for inter_field in f.rel.through._meta.fields:
for inter_field in f.rel.through_model._meta.fields: rel_to = getattr(inter_field.rel, 'to', None)
rel_to = getattr(inter_field.rel, 'to', None) if from_model == to_model: # relation to self
if from_model == to_model: # relation to self if rel_to == from_model:
if rel_to == from_model: seen_self += 1
seen_self += 1 if seen_self > 2:
if seen_self > 2: e.add(opts, "Intermediary model %s has more than "
e.add(opts, "Intermediary model %s has more than two foreign keys to %s, which is ambiguous and is not permitted." % (f.rel.through_model._meta.object_name, from_model._meta.object_name)) "two foreign keys to %s, which is ambiguous "
else: "and is not permitted." % (
if rel_to == from_model: f.rel.through._meta.object_name,
if seen_from: from_model._meta.object_name
e.add(opts, "Intermediary model %s has more than one foreign key to %s, which is ambiguous and is not permitted." % (f.rel.through_model._meta.object_name, from_model._meta.object_name)) )
else: )
seen_from = True
elif rel_to == to_model:
if seen_to:
e.add(opts, "Intermediary model %s has more than one foreign key to %s, which is ambiguous and is not permitted." % (f.rel.through_model._meta.object_name, rel_to._meta.object_name))
else:
seen_to = True
if f.rel.through_model not in models.get_models():
e.add(opts, "'%s' specifies an m2m relation through model %s, which has not been installed." % (f.name, f.rel.through))
signature = (f.rel.to, cls, f.rel.through_model)
if signature in seen_intermediary_signatures:
e.add(opts, "The model %s has two manually-defined m2m relations through the model %s, which is not permitted. Please consider using an extra field on your intermediary model instead." % (cls._meta.object_name, f.rel.through_model._meta.object_name))
else: else:
seen_intermediary_signatures.append(signature) if rel_to == from_model:
seen_related_fk, seen_this_fk = False, False if seen_from:
for field in f.rel.through_model._meta.fields: e.add(opts, "Intermediary model %s has more "
if field.rel: "than one foreign key to %s, which is "
if not seen_related_fk and field.rel.to == f.rel.to: "ambiguous and is not permitted." % (
seen_related_fk = True f.rel.through._meta.object_name,
elif field.rel.to == cls: from_model._meta.object_name
seen_this_fk = True )
if not seen_related_fk or not seen_this_fk: )
e.add(opts, "'%s' has a manually-defined m2m relation through model %s, which does not have foreign keys to %s and %s" % (f.name, f.rel.through, f.rel.to._meta.object_name, cls._meta.object_name)) else:
seen_from = True
elif rel_to == to_model:
if seen_to:
e.add(opts, "Intermediary model %s has more "
"than one foreign key to %s, which is "
"ambiguous and is not permitted." % (
f.rel.through._meta.object_name,
rel_to._meta.object_name
)
)
else:
seen_to = True
if f.rel.through not in models.get_models(include_auto_created=True):
e.add(opts, "'%s' specifies an m2m relation through model "
"%s, which has not been installed." % (f.name, f.rel.through)
)
signature = (f.rel.to, cls, f.rel.through)
if signature in seen_intermediary_signatures:
e.add(opts, "The model %s has two manually-defined m2m "
"relations through the model %s, which is not "
"permitted. Please consider using an extra field on "
"your intermediary model instead." % (
cls._meta.object_name,
f.rel.through._meta.object_name
)
)
else: else:
e.add(opts, "'%s' specifies an m2m relation through model %s, which has not been installed" % (f.name, f.rel.through)) seen_intermediary_signatures.append(signature)
seen_related_fk, seen_this_fk = False, False
for field in f.rel.through._meta.fields:
if field.rel:
if not seen_related_fk and field.rel.to == f.rel.to:
seen_related_fk = True
elif field.rel.to == cls:
seen_this_fk = True
if not seen_related_fk or not seen_this_fk:
e.add(opts, "'%s' has a manually-defined m2m relation "
"through model %s, which does not have foreign keys "
"to %s and %s" % (f.name, f.rel.through._meta.object_name,
f.rel.to._meta.object_name, cls._meta.object_name)
)
elif isinstance(f.rel.through, basestring):
e.add(opts, "'%s' specifies an m2m relation through model %s, "
"which has not been installed" % (f.name, f.rel.through)
)
rel_opts = f.rel.to._meta rel_opts = f.rel.to._meta
rel_name = RelatedObject(f.rel.to, cls, f).get_accessor_name() rel_name = RelatedObject(f.rel.to, cls, f).get_accessor_name()

View File

@ -56,7 +56,7 @@ class Serializer(base.Serializer):
self._current[field.name] = smart_unicode(related, strings_only=True) self._current[field.name] = smart_unicode(related, strings_only=True)
def handle_m2m_field(self, obj, field): def handle_m2m_field(self, obj, field):
if field.creates_table: if field.rel.through._meta.auto_created:
self._current[field.name] = [smart_unicode(related._get_pk_val(), strings_only=True) self._current[field.name] = [smart_unicode(related._get_pk_val(), strings_only=True)
for related in getattr(obj, field.name).iterator()] for related in getattr(obj, field.name).iterator()]

View File

@ -98,7 +98,7 @@ class Serializer(base.Serializer):
serialized as references to the object's PK (i.e. the related *data* serialized as references to the object's PK (i.e. the related *data*
is not dumped, just the relation). is not dumped, just the relation).
""" """
if field.creates_table: if field.rel.through._meta.auto_created:
self._start_relational_field(field) self._start_relational_field(field)
for relobj in getattr(obj, field.name).iterator(): for relobj in getattr(obj, field.name).iterator():
self.xml.addQuickElement("object", attrs={"pk" : smart_unicode(relobj._get_pk_val())}) self.xml.addQuickElement("object", attrs={"pk" : smart_unicode(relobj._get_pk_val())})
@ -233,4 +233,3 @@ def getInnerText(node):
else: else:
pass pass
return u"".join(inner_text) return u"".join(inner_text)

View File

@ -10,6 +10,7 @@ a string) and returns a tuple in this format:
import re import re
from django.http import Http404 from django.http import Http404
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured, ViewDoesNotExist from django.core.exceptions import ImproperlyConfigured, ViewDoesNotExist
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
from django.utils.encoding import iri_to_uri, force_unicode, smart_str from django.utils.encoding import iri_to_uri, force_unicode, smart_str
@ -32,6 +33,9 @@ _callable_cache = {} # Maps view and url pattern names to their view functions.
# be empty. # be empty.
_prefixes = {} _prefixes = {}
# Overridden URLconfs for each thread are stored here.
_urlconfs = {}
class Resolver404(Http404): class Resolver404(Http404):
pass pass
@ -300,9 +304,13 @@ class RegexURLResolver(object):
"arguments '%s' not found." % (lookup_view_s, args, kwargs)) "arguments '%s' not found." % (lookup_view_s, args, kwargs))
def resolve(path, urlconf=None): def resolve(path, urlconf=None):
if urlconf is None:
urlconf = get_urlconf()
return get_resolver(urlconf).resolve(path) return get_resolver(urlconf).resolve(path)
def reverse(viewname, urlconf=None, args=None, kwargs=None, prefix=None, current_app=None): def reverse(viewname, urlconf=None, args=None, kwargs=None, prefix=None, current_app=None):
if urlconf is None:
urlconf = get_urlconf()
resolver = get_resolver(urlconf) resolver = get_resolver(urlconf)
args = args or [] args = args or []
kwargs = kwargs or {} kwargs = kwargs or {}
@ -370,3 +378,26 @@ def get_script_prefix():
instance is normally going to be a lot cleaner). instance is normally going to be a lot cleaner).
""" """
return _prefixes.get(currentThread(), u'/') return _prefixes.get(currentThread(), u'/')
def set_urlconf(urlconf_name):
"""
Sets the URLconf for the current thread (overriding the default one in
settings). Set to None to revert back to the default.
"""
thread = currentThread()
if urlconf_name:
_urlconfs[thread] = urlconf_name
else:
# faster than wrapping in a try/except
if thread in _urlconfs:
del _urlconfs[thread]
def get_urlconf(default=None):
"""
Returns the root URLconf to use for the current thread if it has been
changed from the default one.
"""
thread = currentThread()
if thread in _urlconfs:
return _urlconfs[thread]
return default

View File

@ -3,11 +3,6 @@ import types
import sys import sys
import os import os
from itertools import izip from itertools import izip
try:
set
except NameError:
from sets import Set as set # Python 2.3 fallback.
import django.db.models.manager # Imported to register signal handler. import django.db.models.manager # Imported to register signal handler.
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned, FieldError from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned, FieldError
from django.db.models.fields import AutoField, FieldDoesNotExist from django.db.models.fields import AutoField, FieldDoesNotExist
@ -22,7 +17,6 @@ from django.utils.functional import curry
from django.utils.encoding import smart_str, force_unicode, smart_unicode from django.utils.encoding import smart_str, force_unicode, smart_unicode
from django.conf import settings from django.conf import settings
class ModelBase(type): class ModelBase(type):
""" """
Metaclass for all models. Metaclass for all models.
@ -236,7 +230,6 @@ class ModelBase(type):
signals.class_prepared.send(sender=cls) signals.class_prepared.send(sender=cls)
class Model(object): class Model(object):
__metaclass__ = ModelBase __metaclass__ = ModelBase
_deferred = False _deferred = False
@ -300,7 +293,14 @@ class Model(object):
if rel_obj is None and field.null: if rel_obj is None and field.null:
val = None val = None
else: else:
val = kwargs.pop(field.attname, field.get_default()) try:
val = kwargs.pop(field.attname)
except KeyError:
# This is done with an exception rather than the
# default argument on pop because we don't want
# get_default() to be evaluated, and then not used.
# Refs #12057.
val = field.get_default()
else: else:
val = field.get_default() val = field.get_default()
if is_related_object: if is_related_object:
@ -352,21 +352,30 @@ class Model(object):
only module-level classes can be pickled by the default path. only module-level classes can be pickled by the default path.
""" """
data = self.__dict__ data = self.__dict__
if not self._deferred: model = self.__class__
return (self.__class__, (), data) # The obvious thing to do here is to invoke super().__reduce__()
# for the non-deferred case. Don't do that.
# On Python 2.4, there is something wierd with __reduce__,
# and as a result, the super call will cause an infinite recursion.
# See #10547 and #12121.
defers = [] defers = []
pk_val = None pk_val = None
for field in self._meta.fields: if self._deferred:
if isinstance(self.__class__.__dict__.get(field.attname), from django.db.models.query_utils import deferred_class_factory
DeferredAttribute): factory = deferred_class_factory
defers.append(field.attname) for field in self._meta.fields:
if pk_val is None: if isinstance(self.__class__.__dict__.get(field.attname),
# The pk_val and model values are the same for all DeferredAttribute):
# DeferredAttribute classes, so we only need to do this defers.append(field.attname)
# once. if pk_val is None:
obj = self.__class__.__dict__[field.attname] # The pk_val and model values are the same for all
model = obj.model_ref() # DeferredAttribute classes, so we only need to do this
return (model_unpickle, (model, defers), data) # once.
obj = self.__class__.__dict__[field.attname]
model = obj.model_ref()
else:
factory = simple_class_factory
return (model_unpickle, (model, defers, factory), data)
def _get_pk_val(self, meta=None): def _get_pk_val(self, meta=None):
if not meta: if not meta:
@ -430,7 +439,7 @@ class Model(object):
else: else:
meta = cls._meta meta = cls._meta
if origin: if origin and not meta.auto_created:
signals.pre_save.send(sender=origin, instance=self, raw=raw) signals.pre_save.send(sender=origin, instance=self, raw=raw)
# If we are in a raw save, save the object exactly as presented. # If we are in a raw save, save the object exactly as presented.
@ -469,7 +478,7 @@ class Model(object):
if pk_set: if pk_set:
# Determine whether a record with the primary key already exists. # Determine whether a record with the primary key already exists.
if (force_update or (not force_insert and if (force_update or (not force_insert and
manager.using(using).filter(pk=pk_val).extra(select={'a': 1}).values('a').order_by())): manager.using(using).filter(pk=pk_val).exists())):
# It does already exist, so do an UPDATE. # It does already exist, so do an UPDATE.
if force_update or non_pks: if force_update or non_pks:
values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks] values = [(f, None, (raw and getattr(self, f.attname) or f.pre_save(self, False))) for f in non_pks]
@ -505,7 +514,7 @@ class Model(object):
setattr(self, meta.pk.attname, result) setattr(self, meta.pk.attname, result)
transaction.commit_unless_managed(using=using) transaction.commit_unless_managed(using=using)
if origin: if origin and not meta.auto_created:
signals.post_save.send(sender=origin, instance=self, signals.post_save.send(sender=origin, instance=self,
created=(not record_exists), raw=raw) created=(not record_exists), raw=raw)
@ -542,7 +551,12 @@ class Model(object):
rel_descriptor = cls.__dict__[rel_opts_name] rel_descriptor = cls.__dict__[rel_opts_name]
break break
else: else:
raise AssertionError("Should never get here.") # in the case of a hidden fkey just skip it, it'll get
# processed as an m2m
if not related.field.rel.is_hidden():
raise AssertionError("Should never get here.")
else:
continue
delete_qs = rel_descriptor.delete_manager(self).all() delete_qs = rel_descriptor.delete_manager(self).all()
for sub_obj in delete_qs: for sub_obj in delete_qs:
sub_obj._collect_sub_objects(seen_objs, self.__class__, related.field.null) sub_obj._collect_sub_objects(seen_objs, self.__class__, related.field.null)
@ -653,12 +667,20 @@ def get_absolute_url(opts, func, self, *args, **kwargs):
class Empty(object): class Empty(object):
pass pass
def model_unpickle(model, attrs): def simple_class_factory(model, attrs):
"""Used to unpickle Models without deferred fields.
We need to do this the hard way, rather than just using
the default __reduce__ implementation, because of a
__deepcopy__ problem in Python 2.4
"""
return model
def model_unpickle(model, attrs, factory):
""" """
Used to unpickle Model subclasses with deferred fields. Used to unpickle Model subclasses with deferred fields.
""" """
from django.db.models.query_utils import deferred_class_factory cls = factory(model, attrs)
cls = deferred_class_factory(model, attrs)
return cls.__new__(cls) return cls.__new__(cls)
model_unpickle.__safe_for_unpickle__ = True model_unpickle.__safe_for_unpickle__ = True

View File

@ -58,6 +58,10 @@ def add_lazy_relation(cls, field, relation, operation):
# If we can't split, assume a model in current app # If we can't split, assume a model in current app
app_label = cls._meta.app_label app_label = cls._meta.app_label
model_name = relation model_name = relation
except AttributeError:
# If it doesn't have a split it's actually a model class
app_label = relation._meta.app_label
model_name = relation._meta.object_name
# Try to look up the related model, and if it's already loaded resolve the # Try to look up the related model, and if it's already loaded resolve the
# string right away. If get_model returns None, it means that the related # string right away. If get_model returns None, it means that the related
@ -96,7 +100,7 @@ class RelatedField(object):
self.rel.related_name = self.rel.related_name % {'class': cls.__name__.lower()} self.rel.related_name = self.rel.related_name % {'class': cls.__name__.lower()}
other = self.rel.to other = self.rel.to
if isinstance(other, basestring): if isinstance(other, basestring) or other._meta.pk is None:
def resolve_related_class(field, model, cls): def resolve_related_class(field, model, cls):
field.rel.to = model field.rel.to = model
field.do_related_class(model, cls) field.do_related_class(model, cls)
@ -401,22 +405,22 @@ class ForeignRelatedObjectsDescriptor(object):
return manager return manager
def create_many_related_manager(superclass, through=False): def create_many_related_manager(superclass, rel=False):
"""Creates a manager that subclasses 'superclass' (which is a Manager) """Creates a manager that subclasses 'superclass' (which is a Manager)
and adds behavior for many-to-many related objects.""" and adds behavior for many-to-many related objects."""
through = rel.through
class ManyRelatedManager(superclass): class ManyRelatedManager(superclass):
def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None, def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
join_table=None, source_col_name=None, target_col_name=None): join_table=None, source_field_name=None, target_field_name=None):
super(ManyRelatedManager, self).__init__() super(ManyRelatedManager, self).__init__()
self.core_filters = core_filters self.core_filters = core_filters
self.model = model self.model = model
self.symmetrical = symmetrical self.symmetrical = symmetrical
self.instance = instance self.instance = instance
self.join_table = join_table self.source_field_name = source_field_name
self.source_col_name = source_col_name self.target_field_name = target_field_name
self.target_col_name = target_col_name
self.through = through self.through = through
self._pk_val = self.instance._get_pk_val() self._pk_val = self.instance.pk
if self._pk_val is None: if self._pk_val is None:
raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__) raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__)
@ -425,36 +429,37 @@ def create_many_related_manager(superclass, through=False):
# If the ManyToMany relation has an intermediary model, # If the ManyToMany relation has an intermediary model,
# the add and remove methods do not exist. # the add and remove methods do not exist.
if through is None: if rel.through._meta.auto_created:
def add(self, *objs): def add(self, *objs):
self._add_items(self.source_col_name, self.target_col_name, *objs) self._add_items(self.source_field_name, self.target_field_name, *objs)
# If this is a symmetrical m2m relation to self, add the mirror entry in the m2m table # If this is a symmetrical m2m relation to self, add the mirror entry in the m2m table
if self.symmetrical: if self.symmetrical:
self._add_items(self.target_col_name, self.source_col_name, *objs) self._add_items(self.target_field_name, self.source_field_name, *objs)
add.alters_data = True add.alters_data = True
def remove(self, *objs): def remove(self, *objs):
self._remove_items(self.source_col_name, self.target_col_name, *objs) self._remove_items(self.source_field_name, self.target_field_name, *objs)
# If this is a symmetrical m2m relation to self, remove the mirror entry in the m2m table # If this is a symmetrical m2m relation to self, remove the mirror entry in the m2m table
if self.symmetrical: if self.symmetrical:
self._remove_items(self.target_col_name, self.source_col_name, *objs) self._remove_items(self.target_field_name, self.source_field_name, *objs)
remove.alters_data = True remove.alters_data = True
def clear(self): def clear(self):
self._clear_items(self.source_col_name) self._clear_items(self.source_field_name)
# If this is a symmetrical m2m relation to self, clear the mirror entry in the m2m table # If this is a symmetrical m2m relation to self, clear the mirror entry in the m2m table
if self.symmetrical: if self.symmetrical:
self._clear_items(self.target_col_name) self._clear_items(self.target_field_name)
clear.alters_data = True clear.alters_data = True
def create(self, **kwargs): def create(self, **kwargs):
# This check needs to be done here, since we can't later remove this # This check needs to be done here, since we can't later remove this
# from the method lookup table, as we do with add and remove. # from the method lookup table, as we do with add and remove.
if through is not None: if not rel.through._meta.auto_created:
raise AttributeError, "Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s's Manager instead." % through opts = through._meta
raise AttributeError, "Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
new_obj = super(ManyRelatedManager, self).create(**kwargs) new_obj = super(ManyRelatedManager, self).create(**kwargs)
self.add(new_obj) self.add(new_obj)
return new_obj return new_obj
@ -470,43 +475,38 @@ def create_many_related_manager(superclass, through=False):
return obj, created return obj, created
get_or_create.alters_data = True get_or_create.alters_data = True
def _add_items(self, source_col_name, target_col_name, *objs): def _add_items(self, source_field_name, target_field_name, *objs):
# join_table: name of the m2m link table # join_table: name of the m2m link table
# source_col_name: the PK colname in join_table for the source object # source_field_name: the PK fieldname in join_table for the source object
# target_col_name: the PK colname in join_table for the target object # target_col_name: the PK fieldname in join_table for the target object
# *objs - objects to add. Either object instances, or primary keys of object instances. # *objs - objects to add. Either object instances, or primary keys of object instances.
# If there aren't any objects, there is nothing to do. # If there aren't any objects, there is nothing to do.
from django.db.models import Model
if objs: if objs:
from django.db.models.base import Model
# Check that all the objects are of the right type
new_ids = set() new_ids = set()
for obj in objs: for obj in objs:
if isinstance(obj, self.model): if isinstance(obj, self.model):
new_ids.add(obj._get_pk_val()) new_ids.add(obj.pk)
elif isinstance(obj, Model): elif isinstance(obj, Model):
raise TypeError, "'%s' instance expected" % self.model._meta.object_name raise TypeError, "'%s' instance expected" % self.model._meta.object_name
else: else:
new_ids.add(obj) new_ids.add(obj)
# Add the newly created or already existing objects to the join table. vals = self.through._default_manager.values_list(target_field_name, flat=True)
# First find out which items are already added, to avoid adding them twice vals = vals.filter(**{
cursor = connection.cursor() source_field_name: self._pk_val,
cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \ '%s__in' % target_field_name: new_ids,
(target_col_name, self.join_table, source_col_name, })
target_col_name, ",".join(['%s'] * len(new_ids))), vals = set(vals)
[self._pk_val] + list(new_ids))
existing_ids = set([row[0] for row in cursor.fetchall()])
# Add the ones that aren't there already # Add the ones that aren't there already
for obj_id in (new_ids - existing_ids): for obj_id in (new_ids - vals):
cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \ self.through._default_manager.create(**{
(self.join_table, source_col_name, target_col_name), '%s_id' % source_field_name: self._pk_val,
[self._pk_val, obj_id]) '%s_id' % target_field_name: obj_id,
# FIXME, once this isn't in related.py it should conditionally })
# use the right DB.
transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)
def _remove_items(self, source_col_name, target_col_name, *objs): def _remove_items(self, source_field_name, target_field_name, *objs):
# source_col_name: the PK colname in join_table for the source object # source_col_name: the PK colname in join_table for the source object
# target_col_name: the PK colname in join_table for the target object # target_col_name: the PK colname in join_table for the target object
# *objs - objects to remove # *objs - objects to remove
@ -517,26 +517,20 @@ def create_many_related_manager(superclass, through=False):
old_ids = set() old_ids = set()
for obj in objs: for obj in objs:
if isinstance(obj, self.model): if isinstance(obj, self.model):
old_ids.add(obj._get_pk_val()) old_ids.add(obj.pk)
else: else:
old_ids.add(obj) old_ids.add(obj)
# Remove the specified objects from the join table # Remove the specified objects from the join table
cursor = connection.cursor() self.through._default_manager.filter(**{
cursor.execute("DELETE FROM %s WHERE %s = %%s AND %s IN (%s)" % \ source_field_name: self._pk_val,
(self.join_table, source_col_name, '%s__in' % target_field_name: old_ids
target_col_name, ",".join(['%s'] * len(old_ids))), }).delete()
[self._pk_val] + list(old_ids))
# TODO
transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)
def _clear_items(self, source_col_name): def _clear_items(self, source_field_name):
# source_col_name: the PK colname in join_table for the source object # source_col_name: the PK colname in join_table for the source object
cursor = connection.cursor() self.through._default_manager.filter(**{
cursor.execute("DELETE FROM %s WHERE %s = %%s" % \ source_field_name: self._pk_val
(self.join_table, source_col_name), }).delete()
[self._pk_val])
# TODO
transaction.commit_unless_managed(using=DEFAULT_DB_ALIAS)
return ManyRelatedManager return ManyRelatedManager
@ -558,17 +552,15 @@ class ManyRelatedObjectsDescriptor(object):
# model's default manager. # model's default manager.
rel_model = self.related.model rel_model = self.related.model
superclass = rel_model._default_manager.__class__ superclass = rel_model._default_manager.__class__
RelatedManager = create_many_related_manager(superclass, self.related.field.rel.through) RelatedManager = create_many_related_manager(superclass, self.related.field.rel)
qn = connection.ops.quote_name
manager = RelatedManager( manager = RelatedManager(
model=rel_model, model=rel_model,
core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()}, core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()},
instance=instance, instance=instance,
symmetrical=False, symmetrical=False,
join_table=qn(self.related.field.m2m_db_table()), source_field_name=self.related.field.m2m_reverse_field_name(),
source_col_name=qn(self.related.field.m2m_reverse_name()), target_field_name=self.related.field.m2m_field_name()
target_col_name=qn(self.related.field.m2m_column_name())
) )
return manager return manager
@ -577,9 +569,9 @@ class ManyRelatedObjectsDescriptor(object):
if instance is None: if instance is None:
raise AttributeError, "Manager must be accessed via instance" raise AttributeError, "Manager must be accessed via instance"
through = getattr(self.related.field.rel, 'through', None) if not self.related.field.rel.through._meta.auto_created:
if through is not None: opts = self.related.field.rel.through._meta
raise AttributeError, "Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s's Manager instead." % through raise AttributeError, "Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
manager = self.__get__(instance) manager = self.__get__(instance)
manager.clear() manager.clear()
@ -595,6 +587,13 @@ class ReverseManyRelatedObjectsDescriptor(object):
def __init__(self, m2m_field): def __init__(self, m2m_field):
self.field = m2m_field self.field = m2m_field
def _through(self):
# through is provided so that you have easy access to the through
# model (Book.authors.through) for inlines, etc. This is done as
# a property to ensure that the fully resolved value is returned.
return self.field.rel.through
through = property(_through)
def __get__(self, instance, instance_type=None): def __get__(self, instance, instance_type=None):
if instance is None: if instance is None:
return self return self
@ -603,17 +602,15 @@ class ReverseManyRelatedObjectsDescriptor(object):
# model's default manager. # model's default manager.
rel_model=self.field.rel.to rel_model=self.field.rel.to
superclass = rel_model._default_manager.__class__ superclass = rel_model._default_manager.__class__
RelatedManager = create_many_related_manager(superclass, self.field.rel.through) RelatedManager = create_many_related_manager(superclass, self.field.rel)
qn = connection.ops.quote_name
manager = RelatedManager( manager = RelatedManager(
model=rel_model, model=rel_model,
core_filters={'%s__pk' % self.field.related_query_name(): instance._get_pk_val()}, core_filters={'%s__pk' % self.field.related_query_name(): instance._get_pk_val()},
instance=instance, instance=instance,
symmetrical=(self.field.rel.symmetrical and isinstance(instance, rel_model)), symmetrical=(self.field.rel.symmetrical and isinstance(instance, rel_model)),
join_table=qn(self.field.m2m_db_table()), source_field_name=self.field.m2m_field_name(),
source_col_name=qn(self.field.m2m_column_name()), target_field_name=self.field.m2m_reverse_field_name()
target_col_name=qn(self.field.m2m_reverse_name())
) )
return manager return manager
@ -622,9 +619,9 @@ class ReverseManyRelatedObjectsDescriptor(object):
if instance is None: if instance is None:
raise AttributeError, "Manager must be accessed via instance" raise AttributeError, "Manager must be accessed via instance"
through = getattr(self.field.rel, 'through', None) if not self.field.rel.through._meta.auto_created:
if through is not None: opts = self.field.rel.through._meta
raise AttributeError, "Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s's Manager instead." % through raise AttributeError, "Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)
manager = self.__get__(instance) manager = self.__get__(instance)
manager.clear() manager.clear()
@ -646,6 +643,10 @@ class ManyToOneRel(object):
self.multiple = True self.multiple = True
self.parent_link = parent_link self.parent_link = parent_link
def is_hidden(self):
"Should the related object be hidden?"
return self.related_name and self.related_name[-1] == '+'
def get_related_field(self): def get_related_field(self):
""" """
Returns the Field in the 'to' object to which this relationship is Returns the Field in the 'to' object to which this relationship is
@ -677,6 +678,10 @@ class ManyToManyRel(object):
self.multiple = True self.multiple = True
self.through = through self.through = through
def is_hidden(self):
"Should the related object be hidden?"
return self.related_name and self.related_name[-1] == '+'
def get_related_field(self): def get_related_field(self):
""" """
Returns the field in the to' object to which this relationship is tied Returns the field in the to' object to which this relationship is tied
@ -694,7 +699,10 @@ class ForeignKey(RelatedField, Field):
assert isinstance(to, basestring), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT) assert isinstance(to, basestring), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT)
else: else:
assert not to._meta.abstract, "%s cannot define a relation with abstract class %s" % (self.__class__.__name__, to._meta.object_name) assert not to._meta.abstract, "%s cannot define a relation with abstract class %s" % (self.__class__.__name__, to._meta.object_name)
to_field = to_field or to._meta.pk.name # For backwards compatibility purposes, we need to *try* and set
# the to_field during FK construction. It won't be guaranteed to
# be correct until contribute_to_class is called. Refs #12190.
to_field = to_field or (to._meta.pk and to._meta.pk.name)
kwargs['verbose_name'] = kwargs.get('verbose_name', None) kwargs['verbose_name'] = kwargs.get('verbose_name', None)
kwargs['rel'] = rel_class(to, to_field, kwargs['rel'] = rel_class(to, to_field,
@ -748,7 +756,12 @@ class ForeignKey(RelatedField, Field):
cls._meta.duplicate_targets[self.column] = (target, "o2m") cls._meta.duplicate_targets[self.column] = (target, "o2m")
def contribute_to_related_class(self, cls, related): def contribute_to_related_class(self, cls, related):
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) # Internal FK's - i.e., those with a related name ending with '+' -
# don't get a related descriptor.
if not self.rel.is_hidden():
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
if self.rel.field_name is None:
self.rel.field_name = cls._meta.pk.name
def formfield(self, **kwargs): def formfield(self, **kwargs):
defaults = { defaults = {
@ -795,6 +808,45 @@ class OneToOneField(ForeignKey):
return None return None
return super(OneToOneField, self).formfield(**kwargs) return super(OneToOneField, self).formfield(**kwargs)
def create_many_to_many_intermediary_model(field, klass):
from django.db import models
managed = True
if isinstance(field.rel.to, basestring) and field.rel.to != RECURSIVE_RELATIONSHIP_CONSTANT:
to = field.rel.to
to_model = field.rel.to
def set_managed(field, model, cls):
field.rel.through._meta.managed = model._meta.managed or cls._meta.managed
add_lazy_relation(klass, field, to_model, set_managed)
elif isinstance(field.rel.to, basestring):
to = klass._meta.object_name
to_model = klass
managed = klass._meta.managed
else:
to = field.rel.to._meta.object_name
to_model = field.rel.to
managed = klass._meta.managed or to_model._meta.managed
name = '%s_%s' % (klass._meta.object_name, field.name)
if field.rel.to == RECURSIVE_RELATIONSHIP_CONSTANT or field.rel.to == klass._meta.object_name:
from_ = 'from_%s' % to.lower()
to = 'to_%s' % to.lower()
else:
from_ = klass._meta.object_name.lower()
to = to.lower()
meta = type('Meta', (object,), {
'db_table': field._get_m2m_db_table(klass._meta),
'managed': managed,
'auto_created': klass,
'app_label': klass._meta.app_label,
'unique_together': (from_, to)
})
# Construct and return the new class.
return type(name, (models.Model,), {
'Meta': meta,
'__module__': klass.__module__,
from_: models.ForeignKey(klass, related_name='%s+' % name),
to: models.ForeignKey(to_model, related_name='%s+' % name)
})
class ManyToManyField(RelatedField, Field): class ManyToManyField(RelatedField, Field):
def __init__(self, to, **kwargs): def __init__(self, to, **kwargs):
try: try:
@ -811,10 +863,7 @@ class ManyToManyField(RelatedField, Field):
self.db_table = kwargs.pop('db_table', None) self.db_table = kwargs.pop('db_table', None)
if kwargs['rel'].through is not None: if kwargs['rel'].through is not None:
self.creates_table = False
assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used." assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used."
else:
self.creates_table = True
Field.__init__(self, **kwargs) Field.__init__(self, **kwargs)
@ -827,62 +876,45 @@ class ManyToManyField(RelatedField, Field):
def _get_m2m_db_table(self, opts): def _get_m2m_db_table(self, opts):
"Function that can be curried to provide the m2m table name for this relation" "Function that can be curried to provide the m2m table name for this relation"
if self.rel.through is not None: if self.rel.through is not None:
return self.rel.through_model._meta.db_table return self.rel.through._meta.db_table
elif self.db_table: elif self.db_table:
return self.db_table return self.db_table
else: else:
return util.truncate_name('%s_%s' % (opts.db_table, self.name), return util.truncate_name('%s_%s' % (opts.db_table, self.name),
connection.ops.max_name_length()) connection.ops.max_name_length())
def _get_m2m_column_name(self, related): def _get_m2m_attr(self, related, attr):
"Function that can be curried to provide the source column name for the m2m table" "Function that can be curried to provide the source column name for the m2m table"
try: cache_attr = '_m2m_%s_cache' % attr
return self._m2m_column_name_cache if hasattr(self, cache_attr):
except: return getattr(self, cache_attr)
if self.rel.through is not None: for f in self.rel.through._meta.fields:
for f in self.rel.through_model._meta.fields: if hasattr(f,'rel') and f.rel and f.rel.to == related.model:
if hasattr(f,'rel') and f.rel and f.rel.to == related.model: setattr(self, cache_attr, getattr(f, attr))
self._m2m_column_name_cache = f.column return getattr(self, cache_attr)
break
# If this is an m2m relation to self, avoid the inevitable name clash
elif related.model == related.parent_model:
self._m2m_column_name_cache = 'from_' + related.model._meta.object_name.lower() + '_id'
else:
self._m2m_column_name_cache = related.model._meta.object_name.lower() + '_id'
# Return the newly cached value def _get_m2m_reverse_attr(self, related, attr):
return self._m2m_column_name_cache
def _get_m2m_reverse_name(self, related):
"Function that can be curried to provide the related column name for the m2m table" "Function that can be curried to provide the related column name for the m2m table"
try: cache_attr = '_m2m_reverse_%s_cache' % attr
return self._m2m_reverse_name_cache if hasattr(self, cache_attr):
except: return getattr(self, cache_attr)
if self.rel.through is not None: found = False
found = False for f in self.rel.through._meta.fields:
for f in self.rel.through_model._meta.fields: if hasattr(f,'rel') and f.rel and f.rel.to == related.parent_model:
if hasattr(f,'rel') and f.rel and f.rel.to == related.parent_model: if related.model == related.parent_model:
if related.model == related.parent_model: # If this is an m2m-intermediate to self,
# If this is an m2m-intermediate to self, # the first foreign key you find will be
# the first foreign key you find will be # the source column. Keep searching for
# the source column. Keep searching for # the second foreign key.
# the second foreign key. if found:
if found: setattr(self, cache_attr, getattr(f, attr))
self._m2m_reverse_name_cache = f.column break
break else:
else: found = True
found = True else:
else: setattr(self, cache_attr, getattr(f, attr))
self._m2m_reverse_name_cache = f.column break
break return getattr(self, cache_attr)
# If this is an m2m relation to self, avoid the inevitable name clash
elif related.model == related.parent_model:
self._m2m_reverse_name_cache = 'to_' + related.parent_model._meta.object_name.lower() + '_id'
else:
self._m2m_reverse_name_cache = related.parent_model._meta.object_name.lower() + '_id'
# Return the newly cached value
return self._m2m_reverse_name_cache
def isValidIDList(self, field_data, all_data): def isValidIDList(self, field_data, all_data):
"Validates that the value is a valid list of foreign keys" "Validates that the value is a valid list of foreign keys"
@ -924,10 +956,17 @@ class ManyToManyField(RelatedField, Field):
# specify *what* on my non-reversible relation?!"), so we set it up # specify *what* on my non-reversible relation?!"), so we set it up
# automatically. The funky name reduces the chance of an accidental # automatically. The funky name reduces the chance of an accidental
# clash. # clash.
if self.rel.symmetrical and self.rel.to == "self" and self.rel.related_name is None: if self.rel.symmetrical and (self.rel.to == "self" or self.rel.to == cls._meta.object_name):
self.rel.related_name = "%s_rel_+" % name self.rel.related_name = "%s_rel_+" % name
super(ManyToManyField, self).contribute_to_class(cls, name) super(ManyToManyField, self).contribute_to_class(cls, name)
# The intermediate m2m model is not auto created if:
# 1) There is a manually specified intermediate, or
# 2) The class owning the m2m field is abstract.
if not self.rel.through and not cls._meta.abstract:
self.rel.through = create_many_to_many_intermediary_model(self, cls)
# Add the descriptor for the m2m relation # Add the descriptor for the m2m relation
setattr(cls, self.name, ReverseManyRelatedObjectsDescriptor(self)) setattr(cls, self.name, ReverseManyRelatedObjectsDescriptor(self))
@ -938,11 +977,8 @@ class ManyToManyField(RelatedField, Field):
# work correctly. # work correctly.
if isinstance(self.rel.through, basestring): if isinstance(self.rel.through, basestring):
def resolve_through_model(field, model, cls): def resolve_through_model(field, model, cls):
field.rel.through_model = model field.rel.through = model
add_lazy_relation(cls, self, self.rel.through, resolve_through_model) add_lazy_relation(cls, self, self.rel.through, resolve_through_model)
elif self.rel.through:
self.rel.through_model = self.rel.through
self.rel.through = self.rel.through._meta.object_name
if isinstance(self.rel.to, basestring): if isinstance(self.rel.to, basestring):
target = self.rel.to target = self.rel.to
@ -951,15 +987,17 @@ class ManyToManyField(RelatedField, Field):
cls._meta.duplicate_targets[self.column] = (target, "m2m") cls._meta.duplicate_targets[self.column] = (target, "m2m")
def contribute_to_related_class(self, cls, related): def contribute_to_related_class(self, cls, related):
# m2m relations to self do not have a ManyRelatedObjectsDescriptor, # Internal M2Ms (i.e., those with a related name ending with '+')
# as it would be redundant - unless the field is non-symmetrical. # don't get a related descriptor.
if related.model != related.parent_model or not self.rel.symmetrical: if not self.rel.is_hidden():
# Add the descriptor for the m2m relation
setattr(cls, related.get_accessor_name(), ManyRelatedObjectsDescriptor(related)) setattr(cls, related.get_accessor_name(), ManyRelatedObjectsDescriptor(related))
# Set up the accessors for the column names on the m2m table # Set up the accessors for the column names on the m2m table
self.m2m_column_name = curry(self._get_m2m_column_name, related) self.m2m_column_name = curry(self._get_m2m_attr, related, 'column')
self.m2m_reverse_name = curry(self._get_m2m_reverse_name, related) self.m2m_reverse_name = curry(self._get_m2m_reverse_attr, related, 'column')
self.m2m_field_name = curry(self._get_m2m_attr, related, 'name')
self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name')
def set_attributes_from_rel(self): def set_attributes_from_rel(self):
pass pass

View File

@ -131,19 +131,25 @@ class AppCache(object):
self._populate() self._populate()
return self.app_errors return self.app_errors
def get_models(self, app_mod=None): def get_models(self, app_mod=None, include_auto_created=False):
""" """
Given a module containing models, returns a list of the models. Given a module containing models, returns a list of the models.
Otherwise returns a list of all installed models. Otherwise returns a list of all installed models.
By default, auto-created models (i.e., m2m models without an
explicit intermediate table) are not included. However, if you
specify include_auto_created=True, they will be.
""" """
self._populate() self._populate()
if app_mod: if app_mod:
return self.app_models.get(app_mod.__name__.split('.')[-2], SortedDict()).values() model_list = self.app_models.get(app_mod.__name__.split('.')[-2], SortedDict()).values()
else: else:
model_list = [] model_list = []
for app_entry in self.app_models.itervalues(): for app_entry in self.app_models.itervalues():
model_list.extend(app_entry.values()) model_list.extend(app_entry.values())
return model_list if not include_auto_created:
return filter(lambda o: not o._meta.auto_created, model_list)
return model_list
def get_model(self, app_label, model_name, seed_cache=True): def get_model(self, app_label, model_name, seed_cache=True):
""" """

View File

@ -1,5 +1,4 @@
import copy import copy
from django.db.models.query import QuerySet, EmptyQuerySet, insert_query from django.db.models.query import QuerySet, EmptyQuerySet, insert_query
from django.db.models import signals from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist
@ -176,6 +175,9 @@ class Manager(object):
def using(self, *args, **kwargs): def using(self, *args, **kwargs):
return self.get_query_set().using(*args, **kwargs) return self.get_query_set().using(*args, **kwargs)
def exists(self, *args, **kwargs):
return self.get_query_set().exists(*args, **kwargs)
def _insert(self, values, **kwargs): def _insert(self, values, **kwargs):
return insert_query(self.model, values, **kwargs) return insert_query(self.model, values, **kwargs)

View File

@ -21,7 +21,7 @@ get_verbose_name = lambda class_name: re.sub('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|
DEFAULT_NAMES = ('verbose_name', 'db_table', 'ordering', DEFAULT_NAMES = ('verbose_name', 'db_table', 'ordering',
'unique_together', 'permissions', 'get_latest_by', 'unique_together', 'permissions', 'get_latest_by',
'order_with_respect_to', 'app_label', 'db_tablespace', 'order_with_respect_to', 'app_label', 'db_tablespace',
'abstract', 'managed', 'proxy', 'using') 'abstract', 'managed', 'proxy', 'using', 'auto_created')
class Options(object): class Options(object):
def __init__(self, meta, app_label=None): def __init__(self, meta, app_label=None):
@ -48,6 +48,7 @@ class Options(object):
self.parents = SortedDict() self.parents = SortedDict()
self.duplicate_targets = {} self.duplicate_targets = {}
self.using = None self.using = None
self.auto_created = False
# To handle various inheritance situations, we need to track where # To handle various inheritance situations, we need to track where
# managers came from (concrete or abstract base classes). # managers came from (concrete or abstract base classes).

View File

@ -2,11 +2,6 @@
The main QuerySet implementation. This provides the public API for the ORM. The main QuerySet implementation. This provides the public API for the ORM.
""" """
try:
set
except NameError:
from sets import Set as set # Python 2.3 fallback
from copy import deepcopy from copy import deepcopy
from django.db import connections, transaction, IntegrityError, DEFAULT_DB_ALIAS from django.db import connections, transaction, IntegrityError, DEFAULT_DB_ALIAS
@ -15,7 +10,6 @@ from django.db.models.fields import DateField
from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory
from django.db.models import signals, sql from django.db.models import signals, sql
# Used to control how many objects are worked with at once in some cases (e.g. # Used to control how many objects are worked with at once in some cases (e.g.
# when deleting objects). # when deleting objects).
CHUNK_SIZE = 100 CHUNK_SIZE = 100
@ -453,6 +447,11 @@ class QuerySet(object):
return query.execute_sql(None) return query.execute_sql(None)
_update.alters_data = True _update.alters_data = True
def exists(self):
if self._result_cache is None:
return self.query.has_results()
return bool(self._result_cache)
################################################## ##################################################
# PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS # # PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
################################################## ##################################################
@ -1086,7 +1085,8 @@ def delete_objects(seen_objs, using):
# Pre-notify all instances to be deleted. # Pre-notify all instances to be deleted.
for pk_val, instance in items: for pk_val, instance in items:
signals.pre_delete.send(sender=cls, instance=instance) if not cls._meta.auto_created:
signals.pre_delete.send(sender=cls, instance=instance)
pk_list = [pk for pk,instance in items] pk_list = [pk for pk,instance in items]
del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection) del_query = connection.ops.query_class(sql.Query, sql.DeleteQuery)(cls, connection)
@ -1120,7 +1120,8 @@ def delete_objects(seen_objs, using):
if field.rel and field.null and field.rel.to in seen_objs: if field.rel and field.null and field.rel.to in seen_objs:
setattr(instance, field.attname, None) setattr(instance, field.attname, None)
signals.post_delete.send(sender=cls, instance=instance) if not cls._meta.auto_created:
signals.post_delete.send(sender=cls, instance=instance)
setattr(instance, cls._meta.pk.attname, None) setattr(instance, cls._meta.pk.attname, None)
if forced_managed: if forced_managed:

View File

@ -8,7 +8,6 @@ all about the internals of models in order to get the information it needs.
""" """
from copy import deepcopy from copy import deepcopy
from django.utils.tree import Node from django.utils.tree import Node
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.utils.encoding import force_unicode from django.utils.encoding import force_unicode
@ -24,11 +23,6 @@ from django.core.exceptions import FieldError
from datastructures import EmptyResultSet, Empty, MultiJoin from datastructures import EmptyResultSet, Empty, MultiJoin
from constants import * from constants import *
try:
set
except NameError:
from sets import Set as set # Python 2.3 fallback
__all__ = ['Query'] __all__ = ['Query']
class Query(object): class Query(object):
@ -386,6 +380,16 @@ class Query(object):
return number return number
def has_results(self):
q = self.clone()
q.add_extra({'a': 1}, None, None, None, None, None)
q.add_fields(())
q.set_extra_mask(('a',))
q.set_aggregate_mask(())
q.clear_ordering()
q.set_limits(high=1)
return bool(q.execute_sql(SINGLE))
def as_sql(self, with_limits=True, with_col_aliases=False): def as_sql(self, with_limits=True, with_col_aliases=False):
""" """
Creates the SQL for this query. Returns the SQL string and list of Creates the SQL for this query. Returns the SQL string and list of

View File

@ -421,7 +421,7 @@ class DateQuery(Query):
self.select = [select] self.select = [select]
self.select_fields = [None] self.select_fields = [None]
self.select_related = False # See #7097. self.select_related = False # See #7097.
self.extra = {} self.set_extra_mask([])
self.distinct = True self.distinct = True
self.order_by = order == 'ASC' and [1] or [-1] self.order_by = order == 'ASC' and [1] or [-1]

View File

@ -319,9 +319,7 @@ class BaseModelForm(BaseForm):
if self.instance.pk is not None: if self.instance.pk is not None:
qs = qs.exclude(pk=self.instance.pk) qs = qs.exclude(pk=self.instance.pk)
# This cute trick with extra/values is the most efficient way to if qs.exists():
# tell if a particular query returns any results.
if qs.extra(select={'a': 1}).values('a').order_by():
if len(unique_check) == 1: if len(unique_check) == 1:
self._errors[unique_check[0]] = ErrorList([self.unique_error_message(unique_check)]) self._errors[unique_check[0]] = ErrorList([self.unique_error_message(unique_check)])
else: else:
@ -354,9 +352,7 @@ class BaseModelForm(BaseForm):
if self.instance.pk is not None: if self.instance.pk is not None:
qs = qs.exclude(pk=self.instance.pk) qs = qs.exclude(pk=self.instance.pk)
# This cute trick with extra/values is the most efficient way to if qs.exists():
# tell if a particular query returns any results.
if qs.extra(select={'a': 1}).values('a').order_by():
self._errors[field] = ErrorList([ self._errors[field] = ErrorList([
self.date_error_message(lookup_type, field, unique_for) self.date_error_message(lookup_type, field, unique_for)
]) ])
@ -476,6 +472,7 @@ class BaseModelFormSet(BaseFormSet):
pk_field = self.model._meta.pk pk_field = self.model._meta.pk
pk = pk_field.get_db_prep_lookup('exact', pk, pk = pk_field.get_db_prep_lookup('exact', pk,
connection=self.get_queryset().query.connection) connection=self.get_queryset().query.connection)
pk = pk_field.get_db_prep_lookup('exact', pk)
if isinstance(pk, list): if isinstance(pk, list):
pk = pk[0] pk = pk[0]
kwargs['instance'] = self._existing_object(pk) kwargs['instance'] = self._existing_object(pk)
@ -710,7 +707,7 @@ class BaseInlineFormSet(BaseModelFormSet):
save_as_new=False, prefix=None): save_as_new=False, prefix=None):
from django.db.models.fields.related import RelatedObject from django.db.models.fields.related import RelatedObject
if instance is None: if instance is None:
self.instance = self.model() self.instance = self.fk.rel.to()
else: else:
self.instance = instance self.instance = instance
self.save_as_new = save_as_new self.save_as_new = save_as_new

Some files were not shown because too many files have changed in this diff Show More