diff --git a/django/core/management/commands/loaddata.py b/django/core/management/commands/loaddata.py index 021c6b7ee1..20428f9f10 100644 --- a/django/core/management/commands/loaddata.py +++ b/django/core/management/commands/loaddata.py @@ -84,6 +84,33 @@ class Command(BaseCommand): if transaction.get_autocommit(self.using): connections[self.using].close() + @cached_property + def compression_formats(self): + """A dict mapping format names to (open function, mode arg) tuples.""" + # Forcing binary mode may be revisited after dropping Python 2 support (see #22399) + compression_formats = { + None: (open, 'rb'), + 'gz': (gzip.GzipFile, 'rb'), + 'zip': (SingleZipReader, 'r'), + 'stdin': (lambda *args: sys.stdin, None), + } + if has_bz2: + compression_formats['bz2'] = (bz2.BZ2File, 'r') + if has_lzma: + compression_formats['lzma'] = (lzma.LZMAFile, 'r') + compression_formats['xz'] = (lzma.LZMAFile, 'r') + return compression_formats + + def reset_sequences(self, connection, models): + """Reset database sequences for the given connection and models.""" + sequence_sql = connection.ops.sequence_reset_sql(no_style(), models) + if sequence_sql: + if self.verbosity >= 2: + self.stdout.write('Resetting sequences') + with connection.cursor() as cursor: + for line in sequence_sql: + cursor.execute(line) + def loaddata(self, fixture_labels): connection = connections[self.using] @@ -94,18 +121,6 @@ class Command(BaseCommand): self.models = set() self.serialization_formats = serializers.get_public_serializer_formats() - # Forcing binary mode may be revisited after dropping Python 2 support (see #22399) - self.compression_formats = { - None: (open, 'rb'), - 'gz': (gzip.GzipFile, 'rb'), - 'zip': (SingleZipReader, 'r'), - 'stdin': (lambda *args: sys.stdin, None), - } - if has_bz2: - self.compression_formats['bz2'] = (bz2.BZ2File, 'r') - if has_lzma: - self.compression_formats['lzma'] = (lzma.LZMAFile, 'r') - self.compression_formats['xz'] = (lzma.LZMAFile, 'r') # Django's test suite repeatedly tries to load initial_data fixtures # from apps that don't have any fixtures. Because disabling constraint @@ -136,13 +151,7 @@ class Command(BaseCommand): # If we found even one object in a fixture, we need to reset the # database sequences. if self.loaded_object_count > 0: - sequence_sql = connection.ops.sequence_reset_sql(no_style(), self.models) - if sequence_sql: - if self.verbosity >= 2: - self.stdout.write('Resetting sequences') - with connection.cursor() as cursor: - for line in sequence_sql: - cursor.execute(line) + self.reset_sequences(connection, self.models) if self.verbosity >= 1: if self.fixture_object_count == self.loaded_object_count: @@ -156,6 +165,31 @@ class Command(BaseCommand): % (self.loaded_object_count, self.fixture_object_count, self.fixture_count) ) + def save_obj(self, obj): + """Save an object if permitted.""" + if ( + obj.object._meta.app_config in self.excluded_apps or + type(obj.object) in self.excluded_models + ): + return False + saved = False + if router.allow_migrate_model(self.using, obj.object.__class__): + saved = True + self.models.add(obj.object.__class__) + try: + obj.save(using=self.using) + # psycopg2 raises ValueError if data contains NUL chars. + except (DatabaseError, IntegrityError, ValueError) as e: + e.args = ('Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s' % { + 'object_label': obj.object._meta.label, + 'pk': obj.object.pk, + 'error_msg': e, + },) + raise + if obj.deferred_fields: + self.objs_with_deferred_fields.append(obj) + return saved + def load_label(self, fixture_label): """Load fixtures files for a given label.""" show_progress = self.verbosity >= 3 @@ -179,29 +213,13 @@ class Command(BaseCommand): for obj in objects: objects_in_fixture += 1 - if (obj.object._meta.app_config in self.excluded_apps or - type(obj.object) in self.excluded_models): - continue - if router.allow_migrate_model(self.using, obj.object.__class__): + if self.save_obj(obj): loaded_objects_in_fixture += 1 - self.models.add(obj.object.__class__) - try: - obj.save(using=self.using) - # psycopg2 raises ValueError if data contains NUL chars. - except (DatabaseError, IntegrityError, ValueError) as e: - e.args = ("Could not load %(object_label)s(pk=%(pk)s): %(error_msg)s" % { - 'object_label': obj.object._meta.label, - 'pk': obj.object.pk, - 'error_msg': e, - },) - raise if show_progress: self.stdout.write( '\rProcessed %i object(s).' % loaded_objects_in_fixture, ending='' ) - if obj.deferred_fields: - self.objs_with_deferred_fields.append(obj) except Exception as e: if not isinstance(e, CommandError): e.args = ("Problem installing fixture '%s': %s" % (fixture_file, e),) @@ -221,20 +239,7 @@ class Command(BaseCommand): RuntimeWarning ) - @functools.lru_cache(maxsize=None) - def find_fixtures(self, fixture_label): - """Find fixture files for a given label.""" - if fixture_label == READ_STDIN: - return [(READ_STDIN, None, READ_STDIN)] - - fixture_name, ser_fmt, cmp_fmt = self.parse_name(fixture_label) - databases = [self.using, None] - cmp_fmts = list(self.compression_formats) if cmp_fmt is None else [cmp_fmt] - ser_fmts = self.serialization_formats if ser_fmt is None else [ser_fmt] - - if self.verbosity >= 2: - self.stdout.write("Loading '%s' fixtures..." % fixture_name) - + def get_fixture_name_and_dirs(self, fixture_name): dirname, basename = os.path.split(fixture_name) if os.path.isabs(fixture_name): fixture_dirs = [dirname] @@ -242,25 +247,48 @@ class Command(BaseCommand): fixture_dirs = self.fixture_dirs if os.path.sep in os.path.normpath(fixture_name): fixture_dirs = [os.path.join(dir_, dirname) for dir_ in fixture_dirs] - fixture_name = basename + return basename, fixture_dirs - suffixes = ( - '.'.join(ext for ext in combo if ext) - for combo in product(databases, ser_fmts, cmp_fmts) - ) - targets = {'.'.join((fixture_name, suffix)) for suffix in suffixes} + def get_targets(self, fixture_name, ser_fmt, cmp_fmt): + databases = [self.using, None] + cmp_fmts = self.compression_formats if cmp_fmt is None else [cmp_fmt] + ser_fmts = self.serialization_formats if ser_fmt is None else [ser_fmt] + return { + '%s.%s' % ( + fixture_name, + '.'.join([ext for ext in combo if ext]), + ) for combo in product(databases, ser_fmts, cmp_fmts) + } + def find_fixture_files_in_dir(self, fixture_dir, fixture_name, targets): + fixture_files_in_dir = [] + path = os.path.join(fixture_dir, fixture_name) + for candidate in glob.iglob(glob.escape(path) + '*'): + if os.path.basename(candidate) in targets: + # Save the fixture_dir and fixture_name for future error + # messages. + fixture_files_in_dir.append((candidate, fixture_dir, fixture_name)) + return fixture_files_in_dir + + @functools.lru_cache(maxsize=None) + def find_fixtures(self, fixture_label): + """Find fixture files for a given label.""" + if fixture_label == READ_STDIN: + return [(READ_STDIN, None, READ_STDIN)] + + fixture_name, ser_fmt, cmp_fmt = self.parse_name(fixture_label) + if self.verbosity >= 2: + self.stdout.write("Loading '%s' fixtures..." % fixture_name) + + fixture_name, fixture_dirs = self.get_fixture_name_and_dirs(fixture_name) + targets = self.get_targets(fixture_name, ser_fmt, cmp_fmt) fixture_files = [] for fixture_dir in fixture_dirs: if self.verbosity >= 2: self.stdout.write("Checking %s for fixtures..." % humanize(fixture_dir)) - fixture_files_in_dir = [] - path = os.path.join(fixture_dir, fixture_name) - for candidate in glob.iglob(glob.escape(path) + '*'): - if os.path.basename(candidate) in targets: - # Save the fixture_dir and fixture_name for future error messages. - fixture_files_in_dir.append((candidate, fixture_dir, fixture_name)) - + fixture_files_in_dir = self.find_fixture_files_in_dir( + fixture_dir, fixture_name, targets, + ) if self.verbosity >= 2 and not fixture_files_in_dir: self.stdout.write("No fixture '%s' in %s." % (fixture_name, humanize(fixture_dir)))