Skip to content

Commit 7267569

Browse files
jackton1pre-commit-ci[bot]Lint Action
authored
Resolve bug with reseeding migrations. (#136)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lint Action <github-action[bot]@github.com>
1 parent da2f1e1 commit 7267569

File tree

3 files changed

+33
-66
lines changed

3 files changed

+33
-66
lines changed

migration_fixer/management/commands/makemigrations.py

Lines changed: 31 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,14 @@
1111
from django.core.management.commands.makemigrations import Command as BaseCommand
1212
from django.db import DEFAULT_DB_ALIAS, connections, router
1313
from django.db.migrations.loader import MigrationLoader
14-
from git import InvalidGitRepositoryError, Repo
14+
from git import GitCommandError, InvalidGitRepositoryError, Repo
1515

1616
from migration_fixer.utils import (
1717
fix_numbered_migration,
1818
get_filename,
1919
get_migration_module_path,
2020
migration_sorter,
2121
no_translations,
22-
sibling_nodes,
2322
)
2423

2524

@@ -127,11 +126,19 @@ def handle(self, *app_labels, **options):
127126
force=self.force_update,
128127
)
129128
else:
130-
remote = self.repo.remotes[self.remote]
131-
remote.fetch(
132-
f"{self.default_branch}:{self.default_branch}",
133-
force=self.force_update,
134-
)
129+
try:
130+
remote = self.repo.remotes[self.remote]
131+
remote.fetch(
132+
f"{self.default_branch}:{self.default_branch}",
133+
force=self.force_update,
134+
)
135+
except GitCommandError as e: # pragma: no cover
136+
raise CommandError(
137+
self.style.ERROR(
138+
f"Unable to fetch {self.remote} branch "
139+
f"'{self.default_branch}': {e.stderr}",
140+
),
141+
)
135142

136143
if self.verbosity >= 2:
137144
self.stdout.write(
@@ -170,13 +177,9 @@ def handle(self, *app_labels, **options):
170177
):
171178
loader.check_consistent_history(connection)
172179

173-
conflicts = {
174-
app_name: sibling_nodes(loader.graph, app_name)
175-
for app_name in loader.detect_conflicts()
176-
}
180+
conflict_leaf_nodes = loader.detect_conflicts()
177181

178-
for app_label in conflicts:
179-
conflict = conflicts[app_label]
182+
for app_label, leaf_nodes in conflict_leaf_nodes.items():
180183
migration_module, _ = loader.migrations_module(app_label)
181184
migration_path = get_migration_module_path(migration_module)
182185

@@ -202,43 +205,24 @@ def handle(self, *app_labels, **options):
202205
)
203206
]
204207

205-
# Only consider files from the current conflict.
206-
conflict_base = [
207-
get_filename(path)
208-
for path in changed_files
209-
if get_filename(path) in conflict
210-
][0]
211-
212208
sorted_changed_files = sorted(
213209
changed_files,
214210
key=partial(migration_sorter, app_label=app_label),
215211
)
216212

217-
changed_files = [
218-
path
219-
for path in sorted_changed_files
220-
if (
221-
int(get_filename(path).split("_")[0])
222-
>= int(conflict_base.split("_")[0])
223-
)
224-
]
225-
226213
# Local migration
227214
local_filenames = [
228-
get_filename(p) for p in changed_files
215+
get_filename(p) for p in sorted_changed_files
229216
]
230-
if self.verbosity >= 2:
231-
self.stdout.write(
232-
f"Retrieving the last migration on: {self.default_branch}"
233-
)
234217

235-
last_remote = [
218+
# Calculate the last changed file on the default branch
219+
conflict_bases = [
236220
name
237-
for name in conflict
221+
for name in leaf_nodes
238222
if name not in local_filenames
239223
]
240224

241-
if not last_remote: # pragma: no cover
225+
if not conflict_bases: # pragma: no cover
242226
raise CommandError(
243227
self.style.ERROR(
244228
f"Unable to determine the last migration on: "
@@ -248,12 +232,14 @@ def handle(self, *app_labels, **options):
248232
)
249233
)
250234

251-
last_remote_filename, *rest = last_remote
252-
changed_files = changed_files or [
253-
f"{fname}.py" for fname in rest
254-
]
235+
conflict_base = conflict_bases[0]
236+
237+
if self.verbosity >= 2:
238+
self.stdout.write(
239+
f"Retrieving the last migration on: {self.default_branch}"
240+
)
255241

256-
seed_split = last_remote_filename.split("_")
242+
seed_split = conflict_base.split("_")
257243

258244
if (
259245
seed_split
@@ -269,8 +255,8 @@ def handle(self, *app_labels, **options):
269255
app_label=app_label,
270256
migration_path=migration_path,
271257
seed=int(seed_split[0]),
272-
start_name=last_remote_filename,
273-
changed_files=changed_files,
258+
start_name=conflict_base,
259+
changed_files=sorted_changed_files,
274260
writer=(
275261
lambda m: self.stdout.write(m)
276262
if self.verbosity >= 2
@@ -279,7 +265,7 @@ def handle(self, *app_labels, **options):
279265
)
280266
else: # pragma: no cover
281267
raise ValueError(
282-
f"Unable to fix migration: {last_remote_filename}. \n"
268+
f"Unable to fix migration: {conflict_base}. \n"
283269
f"NOTE: It needs to begin with a number. eg. 0001_*",
284270
)
285271
except (ValueError, IndexError, TypeError) as e:

migration_fixer/utils.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
from importlib import import_module
44
from itertools import count
55
from pathlib import Path
6-
from typing import Callable, List, Optional
7-
8-
from django.db.migrations.graph import MigrationGraph
6+
from typing import Callable, List
97

108
DEFAULT_TIMEOUT = 120
119
MIGRATION_REGEX = "\\((?P<comma>['\"]){app_label}(['\"]),\\s(['\"])(?P<conflict_migration>.*)(['\"])\\),"
@@ -140,20 +138,3 @@ def get_migration_module_path(migration_module_path: str) -> Path:
140138
raise
141139

142140
return Path(os.path.dirname(os.path.abspath(migration_module.__file__)))
143-
144-
145-
def sibling_nodes(graph: MigrationGraph, app_name: Optional[str] = None) -> List[str]:
146-
"""
147-
Return all sibling nodes that have the same parent
148-
- it's usually the result of a VCS merge and needs some user input.
149-
"""
150-
siblings = set()
151-
152-
for node in graph.nodes:
153-
if len(graph.node_map[node].children) > 1 and (
154-
not app_name or app_name == node[0]
155-
):
156-
for child in graph.node_map[node].children:
157-
siblings.add(child[-1])
158-
159-
return sorted(siblings)

0 commit comments

Comments
 (0)