From de32fe83a2e4a20887972c69a0693b94eb25a88b Mon Sep 17 00:00:00 2001 From: William Schwartz Date: Wed, 30 Dec 2020 11:32:46 -0600 Subject: [PATCH] Fixed #32317 -- Refactored loaddata command to make it extensible. Moved deeply nested blocks out of inner loops to improve readability and maintainability. Thanks to Mariusz Felisiak, Shreyas Ravi, and Paolo Melchiorre for feedback. --- django/core/management/commands/loaddata.py | 154 ++++++++++++-------- 1 file changed, 91 insertions(+), 63 deletions(-) diff --git a/django/core/management/commands/loaddata.py b/django/core/management/commands/loaddata.py index 021c6b7ee1..20428f9f10 100644 --- a/django/core/management/commands/loaddata.py +++ b/django/core/management/commands/loaddata.py @@ -84,6 +84,33 @@ class Command(BaseCommand): if transaction.get_autocommit(self.using): connections[self.using].close() + @cached_property + def compression_formats(self): + """A dict mapping format names to (open function, mode arg) tuples.""" + # Forcing binary mode may be revisited after dropping Python 2 support (see #22399) + compression_formats = { + None: (open, 'rb'), + 'gz': (gzip.GzipFile, 'rb'), + 'zip': (SingleZipReader, 'r'), + 'stdin': (lambda *args: sys.stdin, None), + } + if has_bz2: + compression_formats['bz2'] = (bz2.BZ2File, 'r') + if has_lzma: + compression_formats['lzma'] = (lzma.LZMAFile, 'r') + compression_formats['xz'] = (lzma.LZMAFile, 'r') + return compression_formats + + def reset_sequences(self, connection, models): + """Reset database sequences for the given connection and models.""" + sequence_sql = connection.ops.sequence_reset_sql(no_style(), models) + if sequence_sql: + if self.verbosity >= 2: + self.stdout.write('Resetting sequences') + with connection.cursor() as cursor: + for line in sequence_sql: + cursor.execute(line) + def loaddata(self, fixture_labels): connection = connections[self.using] @@ -94,18 +121,6 @@ class Command(BaseCommand): self.models = set() self.serialization_formats = serializers.get_public_serializer_formats() - # Forcing binary mode may be revisited after dropping Python 2 support (see #22399) - self.compression_formats = { - None: (open, 'rb'), - 'gz': (gzip.GzipFile, 'rb'), - 'zip': (SingleZipReader, 'r'), - 'stdin': (lambda *args: sys.stdin, None), - } - if has_bz2: - self.compression_formats['bz2'] = (bz2.BZ2File, 'r') - if has_lzma: - self.compression_formats['lzma'] = (lzma.LZMAFile, 'r') - self.compression_formats['xz'] = (lzma.LZMAFile, 'r') # Django's test suite repeatedly tries to load initial_data fixtures # from apps that don't have any fixtures. Because disabling constraint @@ -136,13 +151,7 @@ class Command(BaseCommand): # If we found even one object in a fixture, we need to reset the # database sequences. if self.loaded_object_count > 0: - sequence_sql = connection.ops.sequence_reset_sql(no_style(), self.models) - if sequence_sql: - if self.verbosity >= 2: - self.stdout.write('Resetting sequences') - with connection.cursor() as cursor: - for line in sequence_sql: - cursor.execute(line) + self.reset_sequences(connection, self.models) if self.verbosity >= 1: if self.fixture_object_count == self.loaded_object_count: @@ -156,6 +165,31 @@ class Command(BaseCommand): % (self.loaded_object_count, self.fixture_object_count, self.fixture_count) ) + def save_obj(self, obj): + """Save an object if permitted.""" + if ( + obj.object._meta.app_config in self.excluded_apps or + type(obj.object) in self.excluded_models + ): + return False + saved = False + if router.allow_migrate_model(self.using, obj.object.__class__): + saved = True + self.models.add(obj.object.__class__) + try: + obj.save(using=self.using) + # psycopg2 raises ValueError if data contains NUL chars. + except (DatabaseError, IntegrityError, ValueError) as e: + e.args = ('Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s' % { + 'object_label': obj.object._meta.label, + 'pk': obj.object.pk, + 'error_msg': e, + },) + raise + if obj.deferred_fields: + self.objs_with_deferred_fields.append(obj) + return saved + def load_label(self, fixture_label): """Load fixtures files for a given label.""" show_progress = self.verbosity >= 3 @@ -179,29 +213,13 @@ class Command(BaseCommand): for obj in objects: objects_in_fixture += 1 - if (obj.object._meta.app_config in self.excluded_apps or - type(obj.object) in self.excluded_models): - continue - if router.allow_migrate_model(self.using, obj.object.__class__): + if self.save_obj(obj): loaded_objects_in_fixture += 1 - self.models.add(obj.object.__class__) - try: - obj.save(using=self.using) - # psycopg2 raises ValueError if data contains NUL chars. - except (DatabaseError, IntegrityError, ValueError) as e: - e.args = ("Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s" % { - 'object_label': obj.object._meta.label, - 'pk': obj.object.pk, - 'error_msg': e, - },) - raise if show_progress: self.stdout.write( '\rProcessed %i object(s).' % loaded_objects_in_fixture, ending='' ) - if obj.deferred_fields: - self.objs_with_deferred_fields.append(obj) except Exception as e: if not isinstance(e, CommandError): e.args = ("Problem installing fixture '%s': %s" % (fixture_file, e),) @@ -221,20 +239,7 @@ class Command(BaseCommand): RuntimeWarning ) - @functools.lru_cache(maxsize=None) - def find_fixtures(self, fixture_label): - """Find fixture files for a given label.""" - if fixture_label == READ_STDIN: - return [(READ_STDIN, None, READ_STDIN)] - - fixture_name, ser_fmt, cmp_fmt = self.parse_name(fixture_label) - databases = [self.using, None] - cmp_fmts = list(self.compression_formats) if cmp_fmt is None else [cmp_fmt] - ser_fmts = self.serialization_formats if ser_fmt is None else [ser_fmt] - - if self.verbosity >= 2: - self.stdout.write("Loading '%s' fixtures..." % fixture_name) - + def get_fixture_name_and_dirs(self, fixture_name): dirname, basename = os.path.split(fixture_name) if os.path.isabs(fixture_name): fixture_dirs = [dirname] @@ -242,25 +247,48 @@ class Command(BaseCommand): fixture_dirs = self.fixture_dirs if os.path.sep in os.path.normpath(fixture_name): fixture_dirs = [os.path.join(dir_, dirname) for dir_ in fixture_dirs] - fixture_name = basename + return basename, fixture_dirs - suffixes = ( - '.'.join(ext for ext in combo if ext) - for combo in product(databases, ser_fmts, cmp_fmts) - ) - targets = {'.'.join((fixture_name, suffix)) for suffix in suffixes} + def get_targets(self, fixture_name, ser_fmt, cmp_fmt): + databases = [self.using, None] + cmp_fmts = self.compression_formats if cmp_fmt is None else [cmp_fmt] + ser_fmts = self.serialization_formats if ser_fmt is None else [ser_fmt] + return { + '%s.%s' % ( + fixture_name, + '.'.join([ext for ext in combo if ext]), + ) for combo in product(databases, ser_fmts, cmp_fmts) + } + def find_fixture_files_in_dir(self, fixture_dir, fixture_name, targets): + fixture_files_in_dir = [] + path = os.path.join(fixture_dir, fixture_name) + for candidate in glob.iglob(glob.escape(path) + '*'): + if os.path.basename(candidate) in targets: + # Save the fixture_dir and fixture_name for future error + # messages. + fixture_files_in_dir.append((candidate, fixture_dir, fixture_name)) + return fixture_files_in_dir + + @functools.lru_cache(maxsize=None) + def find_fixtures(self, fixture_label): + """Find fixture files for a given label.""" + if fixture_label == READ_STDIN: + return [(READ_STDIN, None, READ_STDIN)] + + fixture_name, ser_fmt, cmp_fmt = self.parse_name(fixture_label) + if self.verbosity >= 2: + self.stdout.write("Loading '%s' fixtures..." % fixture_name) + + fixture_name, fixture_dirs = self.get_fixture_name_and_dirs(fixture_name) + targets = self.get_targets(fixture_name, ser_fmt, cmp_fmt) fixture_files = [] for fixture_dir in fixture_dirs: if self.verbosity >= 2: self.stdout.write("Checking %s for fixtures..." % humanize(fixture_dir)) - fixture_files_in_dir = [] - path = os.path.join(fixture_dir, fixture_name) - for candidate in glob.iglob(glob.escape(path) + '*'): - if os.path.basename(candidate) in targets: - # Save the fixture_dir and fixture_name for future error messages. - fixture_files_in_dir.append((candidate, fixture_dir, fixture_name)) - + fixture_files_in_dir = self.find_fixture_files_in_dir( + fixture_dir, fixture_name, targets, + ) if self.verbosity >= 2 and not fixture_files_in_dir: self.stdout.write("No fixture '%s' in %s." % (fixture_name, humanize(fixture_dir)))