Python: failing `__eq__` test
I'm new to Python and I'm learning to use pytest. I have a class defined as:

    from itertools import zip_longest

    class Matrix:
        def __init__(self, *rows):
            row_length = len(rows[0])
            for row in rows:
                # TODO skip first
                if len(row) != row_length:
                    raise SystemError("Rows do not have equal length")
            self._rows = [*rows]

        def __eq__(self, other):
            return isinstance(self, other.__class__) and \
                all([x == y for x, y in zip_longest(self._rows, other._rows)])

        # other methods omitted for simplicity...

I wrote a test for `__eq__(self, other)` like this:

    def test_eq():
        m1 = Matrix([[1,2,3],[4,5,6]])
        m2 = Matrix([1,2,3],[4,5,6])
        m3 = Matrix([1,2,3],[5,4,6])
        assert m1 == m2
        assert m2 == m1
        assert m2 != m3

This should pass because `m1` and `m2` have the same rows, and `m3` differs in the second row. However, when I run the test I get this output:

        def test_eq():
            m1 = Matrix([[1,2,3],[4,5,6]])
            m2 = Matrix([1,2,3],[4,5,6])
            m3 = Matrix([1,2,3],[5,4,6])
    >       assert m1 == m2
    E       assert <exercises.matrix.Matrix object at 0x10ccd67d0> == <exercises.matrix.Matrix object at 0x10ccd6810>

What am I missing here? I'm using Python 3.7.4 and pytest 5.1.2. Thanks in advance for your comments and answers.

NOTE: I changed the implementation based on ggorlen's answer, but I'm still having a similar issue.
Solution 1:[1]
The lines in your comparison should be something like:

    for i, i_row in enumerate(self._rows):
        if i_row != other._rows[i]:
            return False

But this still won't return the correct result if `other` has more rows than `self`, so:

    def __eq__(self, other):
        return isinstance(self, other.__class__) and \
            len(other._rows) == len(self._rows) and \
            all([x == y for x, y in zip(self._rows, other._rows)])

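To see why the length check matters: `zip` stops at the end of the shorter input, so without it a matrix would compare equal to any larger matrix that merely starts with the same rows. A minimal sketch on plain lists of rows; the three helper functions are hypothetical, for illustration only:

```python
from itertools import zip_longest


def rows_equal_zip(a, b):
    # zip() stops at the shorter list, so extra rows in either input are ignored
    return all(x == y for x, y in zip(a, b))


def rows_equal_checked(a, b):
    # Comparing the row counts first closes that gap
    return len(a) == len(b) and all(x == y for x, y in zip(a, b))


def rows_equal_longest(a, b):
    # zip_longest() pads the shorter list with None, so a missing row never matches
    return all(x == y for x, y in zip_longest(a, b))


one_row = [[1, 2, 3]]
two_rows = [[1, 2, 3], [4, 5, 6]]

print(rows_equal_zip(one_row, two_rows))      # True (wrong: second row ignored)
print(rows_equal_checked(one_row, two_rows))  # False
print(rows_equal_longest(one_row, two_rows))  # False
```

The last helper mirrors the `zip_longest` call in the question's updated code, which already handles differing row counts.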
Note that the attribute is called `_rows`, and indexing into a list uses `[]`, not parentheses.
A version that bails out early, returning `False` at the first mismatching row instead of first building the full list that `all([...])` receives, is:

    def __eq__(self, other):
        if isinstance(self, other.__class__) and \
                len(other._rows) == len(self._rows):
            for i, row in enumerate(self._rows):
                if row != other._rows[i]:
                    return False
            return True
        return False

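`all()` also stops at the first `False` when it is given a generator expression instead of a list, so the early exit does not require an explicit loop. A sketch of that variant, which additionally returns `NotImplemented` for non-`Matrix` operands; that is a common Python convention (it lets the other operand's `__eq__` try, then falls back to identity) and a design choice not made in the answer above:

```python
class Matrix:
    def __init__(self, *rows):
        # Row-length validation omitted to keep the sketch short
        self._rows = [*rows]

    def __eq__(self, other):
        if not isinstance(other, Matrix):
            # Let Python try other.__eq__(self), then fall back to identity
            return NotImplemented
        # all() over a generator stops at the first mismatching row
        return len(self._rows) == len(other._rows) and all(
            x == y for x, y in zip(self._rows, other._rows)
        )


a = Matrix([1, 2, 3], [4, 5, 6])
b = Matrix([1, 2, 3], [4, 5, 6])
c = Matrix([1, 2, 3], [5, 4, 6])
print(a == b)               # True
print(a != c)               # True
print(a == "not a matrix")  # False
```

In Python 3, `!=` automatically inverts `__eq__`, so no `__ne__` is needed. Defining `__eq__` without `__hash__` also makes instances unhashable, which is usually what you want for a mutable matrix.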
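In your test, you might have a typo:

    m1 = Matrix([[1,2,3],[4,5,6]])  # <-- this matrix has an extra `[]` wrapper
    m2 = Matrix([1,2,3],[4,5,6])    # <-- but this one just uses flat lists

`m1` receives a single argument, so its `_rows` is `[[[1, 2, 3], [4, 5, 6]]]` (one row that contains both lists), whereas `m2._rows` is `[[1, 2, 3], [4, 5, 6]]`. These matrices will therefore not be equal; construct `m1` the same way as `m2`.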
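A sketch of the corrected test, assuming the question's `Matrix` constructor together with the length-checking `__eq__` above (repeated so the snippet runs on its own; `ValueError` replaces `SystemError`). Either pass each row as its own argument, or unpack an existing list of rows with `*`. The second test function is a hypothetical addition that pins down the extra-wrapper behaviour:

```python
class Matrix:
    def __init__(self, *rows):
        row_length = len(rows[0])
        for row in rows:
            if len(row) != row_length:
                raise ValueError("Rows do not have equal length")
        self._rows = [*rows]

    def __eq__(self, other):
        # Same checks as the answer: type, row count, then row by row
        return (
            isinstance(self, other.__class__)
            and len(other._rows) == len(self._rows)
            and all(x == y for x, y in zip(self._rows, other._rows))
        )


def test_eq():
    rows = [[1, 2, 3], [4, 5, 6]]
    m1 = Matrix(*rows)  # unpacking: same as Matrix([1, 2, 3], [4, 5, 6])
    m2 = Matrix([1, 2, 3], [4, 5, 6])
    m3 = Matrix([1, 2, 3], [5, 4, 6])
    assert m1 == m2
    assert m2 == m1
    assert m2 != m3


def test_extra_wrapper_gives_single_row():
    # Passing the whole list as one argument yields a one-row matrix
    wrapped = Matrix([[1, 2, 3], [4, 5, 6]])
    assert wrapped._rows == [[[1, 2, 3], [4, 5, 6]]]
    assert wrapped != Matrix([1, 2, 3], [4, 5, 6])
```

pytest collects both `test_*` functions automatically when the file is named like `test_matrix.py`.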
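Minor suggestions:

- Raise a `ValueError` (or a `TypeError` for arguments of the wrong type) instead of a `SystemError` on bad parameters; `SystemError` is meant for internal interpreter errors.
- Consider using NumPy instead of rolling your own matrix; its documentation recommends regular 2-D arrays (`numpy.ndarray`) rather than the `numpy.matrix` class.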
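A minimal sketch of the first suggestion, assuming rows are plain lists: the constructor raises `ValueError` for empty or ragged input, and the `rows[1:]` slice takes care of the question's `# TODO skip first`. The `__repr__` method is an extra not mentioned in the answer: pytest shows the `repr()` of each operand when an assertion fails, so defining it replaces the `<exercises.matrix.Matrix object at 0x...>` text in the question's output with the actual rows. Here `__eq__` simply compares the two lists of rows with `==`, which already checks the row count and then the rows in order.

```python
class Matrix:
    def __init__(self, *rows):
        if not rows:
            raise ValueError("Matrix needs at least one row")
        row_length = len(rows[0])
        # rows[1:] skips the first row, which defines the expected length
        for row in rows[1:]:
            if len(row) != row_length:
                raise ValueError(
                    f"Rows do not have equal length: expected {row_length}, got {len(row)}"
                )
        self._rows = [list(row) for row in rows]

    def __eq__(self, other):
        if not isinstance(other, Matrix):
            return NotImplemented
        # list == list compares lengths first, then items in order
        return self._rows == other._rows

    def __repr__(self):
        # Shown by pytest for each operand when an assert fails
        return f"Matrix({', '.join(repr(row) for row in self._rows)})"


try:
    Matrix([1, 2, 3], [4, 5])
except ValueError as exc:
    print(exc)  # Rows do not have equal length: expected 3, got 2

print(repr(Matrix([1, 2], [3, 4])))  # Matrix([1, 2], [3, 4])
```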
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 |