[eb20dc8] | 1 | import pkgutil |
---|
| 2 | import importlib |
---|
| 3 | import logging |
---|
| 4 | from bisect import insort, bisect_right |
---|
| 5 | |
---|
| 6 | from pyramid_zodbconn import get_connection |
---|
| 7 | from pyramid.scripting import prepare |
---|
| 8 | from pyramid.paster import bootstrap |
---|
| 9 | from pyramid.events import subscriber, ApplicationCreated |
---|
| 10 | |
---|
| 11 | from transaction import commit |
---|
| 12 | |
---|
# Per-process memo used by get_max(): migrations package name -> highest
# migration number found in it. Nothing in this module ever clears it.
MAX_CACHE = dict()
log = logging.getLogger(__name__)
---|
| 16 | |
---|
| 17 | |
---|
def get_indexes(package_name):
    """
    Return the migration numbers found in a package, sorted ascending.

    Every submodule of ``package_name`` whose name parses with ``int()``
    (e.g. ``1``, ``2``, ``10``) counts as a migration. Any other submodule
    name (helpers, etc.) is ignored.
    """
    package = importlib.import_module(package_name)
    found = []
    for _finder, module_name, _is_pkg in pkgutil.iter_modules(package.__path__):
        try:
            number = int(module_name)
        except ValueError:
            # Not a numbered migration module; skip it.
            continue
        found.append(number)
    found.sort()
    return found
---|
| 30 | |
---|
| 31 | |
---|
def get_max(package_name):
    """
    Return the highest migration number in ``package_name`` (0 if it has none).

    The result is memoized in MAX_CACHE on first use. Migrations added to
    the package afterwards are not seen by later calls in the same process.
    """
    try:
        return MAX_CACHE[package_name]
    except KeyError:
        pass

    known = get_indexes(package_name)
    # get_indexes returns an ascending list, so the last entry is the maximum.
    highest = known[-1] if known else 0
    MAX_CACHE[package_name] = highest
    return highest
---|
| 46 | |
---|
| 47 | |
---|
def set_max_version(zodb_root, package_name):
    """
    Stamp ``zodb_root`` with the highest migration number in ``package_name``.

    The root is only mutated; nothing is committed here. Returns ``zodb_root``.
    """
    latest = get_max(package_name)
    return set_version(zodb_root, latest)
---|
| 55 | |
---|
| 56 | |
---|
def reset_version(request, version):
    """
    Force the stored database version to ``version`` and commit right away.

    The version is written to the root of the request's ZODB connection.
    No migrations are run.
    """
    connection = get_connection(request)
    set_version(connection.root(), version)
    commit()
---|
| 64 | |
---|
| 65 | |
---|
def set_version(zodb_root, version):
    """
    Store ``version`` under the ``'database_version'`` key of ``zodb_root``.

    Mutates ``zodb_root`` in place and returns it; no commit is done here.
    """
    version_key = 'database_version'
    zodb_root[version_key] = version
    return zodb_root
---|
| 72 | |
---|
| 73 | |
---|
def _first_paragraph(docstring):
    """
    Return the first paragraph of ``docstring`` collapsed onto one line.

    Lines are stripped and joined with single spaces. The paragraph ends at
    the first blank line.
    """
    description = []
    for line in docstring.strip().split('\n'):
        line = line.strip()
        if line == '':
            break
        description.append(line)
    return ' '.join(description)


def run_migrations(request, root, package_name):
    """
    Apply, in ascending order, every migration in ``package_name`` that is
    newer than the version stored in the database.

    The current version is read from ``'database_version'`` on the request's
    ZODB connection root; 0 is used when the key is missing. Each migration
    lives in the submodule ``<package_name>.<number>`` and must define
    ``migrate(root)``. After each successful migration, its number is stored
    as the new version and the transaction is committed. Each migration is
    therefore persisted on its own.

    :param request: request used to obtain the ZODB connection.
    :param root: object handed to each ``migrate()`` call.
    :param package_name: dotted name of the migrations package.
    :returns: ``False`` if a migration module has no ``migrate`` attribute
        (earlier migrations stay committed); ``None`` otherwise.
    :raises: whatever a ``migrate()`` call raises. The current transaction
        is aborted first, so partial changes are not committed later.
    """
    import transaction

    indexes = get_indexes(package_name)
    dbroot = get_connection(request).root()
    current = dbroot.get('database_version', 0)
    # indexes is sorted, so everything right of `current` is still pending.
    migrations_to_apply = indexes[bisect_right(indexes, current):]
    if not migrations_to_apply:
        log.info("Your database is in the latest version: '%i'. No migrations"
                 " will be applied.", current)
        return None

    log.info("Starting migrations from %i", current)
    for index in migrations_to_apply:
        migration = importlib.import_module("%s.%s" % (package_name, index))
        if not hasattr(migration, 'migrate'):
            log.error('No migrate method found for %s', migration.__name__)
            return False
        log.info("Running migration %i", index)
        # __doc__ exists on every function but is None when no docstring was
        # written. Calling .strip() on it unconditionally used to crash.
        doc = getattr(migration.migrate, '__doc__', None)
        if doc is not None:
            log.info('"""%s"""', _first_paragraph(doc))
        try:
            migration.migrate(root)
            set_version(dbroot, index)
            commit()
        except Exception:
            # Drop the half-applied migration so no later commit persists it.
            transaction.abort()
            raise
    return None
---|
| 106 | |
---|
| 107 | |
---|
def closer_wrapper(env):
    """
    Run pending migrations for a Pyramid scripting environment, then close it.

    :param env: environment dict as produced by ``pyramid.paster.bootstrap``
        or ``pyramid.scripting.prepare``. The keys ``closer``, ``registry``,
        ``root_factory``, ``request`` and ``root`` are read.

    The migrations package is taken from the ``migrations_package`` setting.
    It defaults to ``<module of root_factory>.migrations``.
    ``env['closer']`` is always called, even if a lookup or a migration fails.

    :returns: the result of ``run_migrations``, i.e. ``False`` when a
        migration module lacks ``migrate()``, otherwise ``None``.
    """
    closer = env['closer']
    try:
        registry = env['registry']
        root_factory = env['root_factory']
        default_package = root_factory.__module__ + '.migrations'
        package = registry.settings.get('migrations_package', default_package)
        return run_migrations(env['request'], env['root'], package)
    finally:
        closer()
---|
| 119 | |
---|
| 120 | |
---|
def output(*args):  # pragma: nocover
    """
    Print ``args`` to stdout, separated by single spaces.

    Previously this did ``print(args)``, which printed the tuple repr
    (e.g. ``('message',)``) instead of the message itself.
    NOTE(review): presumably kept separate so tests can stub console output.
    """
    print(' '.join(str(arg) for arg in args))
---|
| 123 | |
---|
| 124 | |
---|
def command_line_main(argv):
    """
    Migrate the database of the application described by a config file
    (command line entry point).

    Uses Pyramid's bootstrap system to start the application without a
    web server, then runs the pending migrations.

    Use this function to build a command line caller for the migrations.
    Pass ``sys.argv`` as the parameter; exactly one argument (the
    configuration file URI) is expected after the program name.

    :returns: process exit code. 0 on success; 1 on wrong usage or when a
        migration module had no ``migrate()`` function.
    """
    args = argv[1:]
    if len(args) != 1:
        output('You must provide exactly one argument: configuration file')
        return 1
    config_uri = args[0]
    env = bootstrap(config_uri)
    # closer_wrapper reports a missing migrate() as False; surface it.
    if closer_wrapper(env) is False:
        return 1
    return 0
---|
| 145 | |
---|
| 146 | |
---|
@subscriber(ApplicationCreated)
def application_created(event):
    """
    On ApplicationCreated, build a scripting environment for the new app's
    registry and run any pending migrations against it.
    """
    closer_wrapper(prepare(registry=event.app.registry))
---|
| 152 | |
---|
| 153 | |
---|
def includeme(config):
    """
    Pyramid include hook: scan the ``ow.migrate`` package for decorated
    configuration (e.g. ``@subscriber`` handlers).
    """
    target_package = 'ow.migrate'
    config.scan(target_package)
---|