1 | import pkgutil |
---|
2 | import importlib |
---|
3 | import logging |
---|
4 | from bisect import insort, bisect_right |
---|
5 | |
---|
6 | from pyramid_zodbconn import get_connection |
---|
7 | from pyramid.scripting import prepare |
---|
8 | from pyramid.paster import bootstrap |
---|
9 | from pyramid.events import subscriber, ApplicationCreated |
---|
10 | |
---|
11 | from transaction import commit |
---|
12 | |
---|
log = logging.getLogger(__name__)

# Cache of the highest migration index found in each migrations package,
# keyed by package name; populated by get_max() and never cleared here.
MAX_CACHE = {}
---|
16 | |
---|
17 | |
---|
def get_indexes(package_name):
    """
    Return the sorted list of migration indexes found in a package.

    Every direct submodule (or subpackage) of ``package_name`` whose name
    parses with ``int()`` counts as a migration; any other submodule, such
    as a helper module, is ignored.

    :param package_name: dotted name of the package holding the migrations.
    :returns: ascending list of ints (empty if there are no migrations).
    """
    package = importlib.import_module(package_name)
    indexes = []
    for _finder, name, _ispkg in pkgutil.iter_modules(package.__path__):
        try:
            indexes.append(int(name))
        except ValueError:
            # Not a numbered migration module; skip it.
            continue
    # Sort once at the end instead of insort-ing per item (O(n log n)).
    indexes.sort()
    return indexes
---|
30 | |
---|
31 | |
---|
def get_max(package_name):
    """
    Return the highest migration index available in ``package_name``.

    Returns 0 when the package contains no numbered migrations. The result
    is memoized in ``MAX_CACHE``; nothing in this module clears that cache,
    so later calls for the same package reuse the first answer.
    """
    try:
        return MAX_CACHE[package_name]
    except KeyError:
        pass

    available = get_indexes(package_name)
    # get_indexes returns an ascending list, so the last item is the maximum.
    highest = available[-1] if available else 0
    MAX_CACHE[package_name] = highest
    return highest
---|
46 | |
---|
47 | |
---|
def set_max_version(zodb_root, package_name):
    """
    Stamp ``zodb_root`` with the newest migration index of ``package_name``.

    The index comes from ``get_max`` (0 when the package has no
    migrations). The same ``zodb_root`` mapping is returned.
    """
    latest = get_max(package_name)
    return set_version(zodb_root, latest)
---|
55 | |
---|
56 | |
---|
def reset_version(request, version):
    """
    Force the stored database version to ``version`` and commit.

    The value is written to the root of the ZODB connection obtained from
    ``request``, regardless of which migrations have actually been run.
    """
    set_version(get_connection(request).root(), version)
    commit()
---|
64 | |
---|
65 | |
---|
def set_version(zodb_root, version):
    """
    Record ``version`` under the ``'database_version'`` key of ``zodb_root``.

    Does not commit. Returns the same mapping so calls can be chained.
    """
    version_key = 'database_version'
    zodb_root[version_key] = version
    return zodb_root
---|
72 | |
---|
73 | |
---|
def _docstring_summary(func):
    """Return the first paragraph of ``func``'s docstring as one line, or
    ``None`` when it has no docstring."""
    doc = getattr(func, '__doc__', None)
    if doc is None:
        return None
    description = []
    for line in doc.strip().split('\n'):
        line = line.strip()
        if not line:
            # A blank line ends the first paragraph.
            break
        description.append(line)
    return ' '.join(description)


def run_migrations(request, root, package_name):
    """
    Run the pending migrations from a package.

    Migrations are the integer-named submodules reported by
    ``get_indexes``. Each one whose index is greater than the
    ``'database_version'`` stored in the root of ``request``'s ZODB
    connection (0 if unset) is imported, in ascending order, and its
    ``migrate(root)`` function is called. After every migration the stored
    version is updated and the transaction committed. A failure part-way
    through therefore leaves the database at the last successful migration.

    :param request: request used to obtain the ZODB connection.
    :param root: object passed to every ``migrate`` function.
    :param package_name: dotted name of the package holding the migrations.
    :returns: ``False`` if a migration module has no ``migrate`` function
        (it and later migrations are not run); ``None`` otherwise.
    """
    indexes = get_indexes(package_name)
    dbroot = get_connection(request).root()
    current = dbroot.get('database_version', 0)
    # indexes is sorted, so everything right of `current` is still pending.
    migrations_to_apply = indexes[bisect_right(indexes, current):]
    if not migrations_to_apply:
        log.info("Your database is in the latest version: '%i'. No migrations"
                 " will be applied.", current)
    else:
        log.info("Starting migrations from %i", current)
    for index in migrations_to_apply:
        migration = importlib.import_module("%s.%s" % (package_name, index))
        if not hasattr(migration, 'migrate'):
            log.error('No migrate method found for %s', migration.__name__)
            return False
        log.info("Running migration %i", index)
        # migrate() may have no docstring (__doc__ is None); only log a
        # summary when there is one instead of crashing on None.strip().
        summary = _docstring_summary(migration.migrate)
        if summary is not None:
            log.info('"""%s"""', summary)
        migration.migrate(root)
        set_version(dbroot, index)
        commit()
---|
106 | |
---|
107 | |
---|
def closer_wrapper(env):
    """
    Run the migrations for an application environment, then close it.

    ``env`` is the environment dict produced by ``bootstrap``/``prepare``.
    The migrations package is taken from the ``migrations_package`` setting.
    It defaults to the ``migrations`` subpackage of the module that defines
    the root factory.

    ``env['closer']`` is always called, even when resolving the package or
    running the migrations raises.
    """
    closer = env['closer']
    try:
        # Resolve everything inside the try so a missing key or setting
        # still lets the environment be closed.
        registry = env['registry']
        root_factory = env['root_factory']
        package = registry.settings.get(
            'migrations_package', root_factory.__module__ + '.migrations')
        run_migrations(env['request'], env['root'], package)
    finally:
        closer()
---|
119 | |
---|
120 | |
---|
def output(*args):  # pragma: nocover
    """Print ``args`` to stdout separated by spaces, exactly like ``print``.

    Unpacking is required: ``print(args)`` would print the tuple repr,
    e.g. ``('message',)``, instead of the message itself.
    """
    print(*args)
---|
123 | |
---|
124 | |
---|
def command_line_main(argv):
    """
    Migrate the given database (command line entry point).

    This command uses pyramid's bootstrap system to start the application
    without a web server, then runs any pending migrations.

    Use this function to create a command line caller for the migrations;
    pass ``sys.argv`` as the parameter.

    :param argv: full argument vector. ``argv[1]`` must be the configuration
        file URI, and no further arguments may follow it.
    :returns: 0 on success, 1 when the number of arguments is wrong.
    """
    args = argv[1:]
    # Exactly one argument is accepted, so the message must say so.
    if len(args) != 1:
        output('You must provide exactly one argument: configuration file')
        return 1
    config_uri = args[0]
    # Prepare the application; closer_wrapper runs the migrations and then
    # calls the environment's closer.
    env = bootstrap(config_uri)
    closer_wrapper(env)
    return 0
---|
145 | |
---|
146 | |
---|
@subscriber(ApplicationCreated)
def application_created(event):
    """
    Run pending migrations once the Pyramid application has been created.

    A scripting environment is prepared from the new application's registry
    and handed to ``closer_wrapper``, which runs the migrations and closes
    the environment.
    """
    closer_wrapper(prepare(registry=event.app.registry))
---|
152 | |
---|
153 | |
---|
def includeme(config):
    """
    Pyramid include hook: scan the ``ow.migrate`` package.

    Scanning is what registers decorated handlers such as the
    ``@subscriber(ApplicationCreated)`` one; this presumably is the package
    containing this module — confirm if the module is relocated.
    """
    scanned_package = 'ow.migrate'
    config.scan(scanned_package)
---|