Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 116 additions & 1 deletion test/test_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

import base64
import datetime
import errno
import os
import pathlib
import tempfile
import unittest
from unittest import mock

from khard.contacts import Contact, multi_property_key
from khard.contacts import Contact, atomic_write, multi_property_key


class ContactFormatDateObject(unittest.TestCase):
Expand Down Expand Up @@ -80,3 +84,114 @@ def test_all_strings_are_sorted_before_dicts(self) -> None:
my_list = ["a", {"c": "d"}, "e", {"f": "g"}]
my_list.sort(key=multi_property_key) # type: ignore[arg-type]
self.assertEqual(my_list, ["a", "e", {"c": "d"}, {"f": "g"}])


class AtomicWrite(unittest.TestCase):
"""Tests for the atomic_write functionself.

These tests have been migrated from the original atomicwrites repository.
"""

def setUp(self) -> None:
self.t = tempfile.TemporaryDirectory()
self.tmpdir = pathlib.Path(self.t.name)
return super().setUp()

def tearDown(self) -> None:
self.t.cleanup()
return super().tearDown()

def assertFileContents(self, file: pathlib.Path, contents: str) -> None:
"""Assert the file contents of the given file"""
with file.open() as f:
return self.assertEqual(f.read(), contents)

@staticmethod
def write(path: pathlib.Path, contents: str) -> None:
"""Write a string to a file"""
with path.open("w") as f:
f.write(contents)

def test_atomic_write(self) -> None:
fname = self.tmpdir / 'ha'
for i in range(2):
with atomic_write(str(fname), overwrite=True) as f:
f.write('hoho')

with self.assertRaises(OSError) as excinfo:
with atomic_write(str(fname), overwrite=False) as f:
f.write('haha')

self.assertEqual(excinfo.exception.errno, errno.EEXIST)

self.assertFileContents(fname, 'hoho')
self.assertEqual(len(list(self.tmpdir.iterdir())), 1)


def test_teardown(self) -> None:
fname = self.tmpdir / 'ha'
with self.assertRaises(AssertionError):
with atomic_write(str(fname), overwrite=True):
self.fail("This code should not be reached")

self.assertFalse(any(self.tmpdir.iterdir()))


def test_replace_simultaneously_created_file(self) -> None:
fname = self.tmpdir / 'ha'
with atomic_write(str(fname), overwrite=True) as f:
f.write('hoho')
self.write(fname, 'harhar')
self.assertFileContents(fname, 'harhar')
self.assertFileContents(fname, 'hoho')
self.assertEqual(len(list(self.tmpdir.iterdir())), 1)


def test_dont_remove_simultaneously_created_file(self) -> None:
fname = self.tmpdir / 'ha'
with self.assertRaises(OSError) as excinfo:
with atomic_write(str(fname), overwrite=False) as f:
f.write('hoho')
self.write(fname, 'harhar')
self.assertFileContents(fname, 'harhar')

self.assertEqual(excinfo.exception.errno, errno.EEXIST)
self.assertFileContents(fname, 'harhar')
self.assertEqual(len(list(self.tmpdir.iterdir())), 1)


def test_open_reraise(self) -> None:
"""Verify that nested exceptions during rollback do not overwrite the
initial exception that triggered a rollback."""
fname = self.tmpdir / 'ha'
with self.assertRaises(AssertionError):
with atomic_write(str(fname), overwrite=False):
# Mess with internals; find and remove the temp file used by
# atomic_write internally. We're testing that the initial
# AssertionError triggered below is propagated up the stack,
# not the second exception triggered during commit.
tmp = next(self.tmpdir.iterdir())
tmp.unlink()
# Now trigger our own exception.
self.fail("Intentional failure for testing purposes")


def test_atomic_write_in_cwd(self) -> None:
orig_curdir = os.getcwd()
try:
os.chdir(str(self.tmpdir))
fname = 'ha'
for i in range(2):
with atomic_write(fname, overwrite=True) as f:
f.write('hoho')

with self.assertRaises(OSError) as excinfo:
with atomic_write(fname, overwrite=False) as f:
f.write('haha')

self.assertEqual(excinfo.exception.errno, errno.EEXIST)

self.assertFileContents(pathlib.Path(fname), 'hoho')
self.assertEqual(len(list(self.tmpdir.iterdir())), 1)
finally:
os.chdir(orig_curdir)