Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions closuretree/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
from django.utils.six import with_metaclass
import sys

try:
from django.db.transaction import atomic
except ImportError:
# On django < 1.6, instead of atomic() use commit_on_success()
from django.db.transaction import commit_on_success as atomic

def _closure_model_unicode(self):
"""__unicode__ implementation for the dynamically created
<Model>Closure model.
Expand Down Expand Up @@ -116,6 +122,17 @@ def __setattr__(self, name, value):
self._closure_change_init()
super(ClosureModel, self).__setattr__(name, value)

def save_base(self, *args, **kwargs):
"""Wrap the saving operation of this model and that of the closure
table in one transaction.
"""
# the superclass save_base() sends post_save() signal.
# post_save() is then received by closure_model_save() function,
# which saves closure model.
# we're going to wrap this series of operations in a transaction.
with atomic():
super(ClosureModel, self).save_base(*args, **kwargs)

@classmethod
def _toplevel(cls):
"""Find the top level of the chain we're in.
Expand Down
61 changes: 60 additions & 1 deletion closuretree/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
# pylint: disable=

from django import VERSION
from django.test import TestCase
from django.test import TestCase, TransactionTestCase
from django.db import models
from django.db.utils import IntegrityError
from django.core.management import call_command
from closuretree.models import ClosureModel
import uuid

Expand Down Expand Up @@ -118,6 +120,63 @@ def test_deletion(self):
self.b.delete()
self.failUnlessEqual(self.closure_model.objects.count(), 2)

class DirtyParentTestCase(TransactionTestCase):

normal_model = TC
closure_model = TCClosure

def setUp(self):
self.a = self.normal_model.objects.create(name="a")
self.b = self.normal_model.objects.create(name="b")
self.c = self.normal_model.objects.create(name="c")
self.d = self.normal_model.objects.create(name="d")

if VERSION < (1, 5):
def _post_teardown(self):
"""Support testing on older versions of Django.

We need to manually flush the DB to run the tests successfully in
Django < 1.5"""
super(DirtyParentTestCase, self)._post_teardown()
call_command('flush', interactive=False)

def test_selfreferencing_parent(self):
"""Tests that instances with self-referencing parents are not saved"""
self.b.parent2 = self.a
self.b.save()
self.c.parent2 = self.b
self.c.save()

# this is the prerequisite of this test:
# closure model contains 7 rows before we save dirty data
self.assertEqual(self.closure_model.objects.count(), 7)

# dirty data
self.b.parent2 = self.b
# save dirty data
self.assertRaises(IntegrityError, self.b.save)

# closure model contains the same number of data as before
self.assertEqual(self.closure_model.objects.count(), 7)

def test_parent_referencing_child(self):
"""Tests instances with circular-referencing parents are not saved"""
self.b.parent2 = self.a
self.b.save()
self.c.parent2 = self.b
self.c.save()

# this is the prerequisite of this test:
# closure model contains 7 rows before we save dirty data
self.assertEqual(self.closure_model.objects.count(), 7)

# dirty data
self.b.parent2 = self.c
# save dirty data
self.assertRaises(IntegrityError, self.b.save)

# closure model contains the same number of data as before
self.assertEqual(self.closure_model.objects.count(), 7)

if VERSION >= (1, 8):
class UUIDTC(ClosureModel):
Expand Down