Python failing __eq__ test

I'm new to Python and I'm learning to use pytest. I have a class defined as:

from itertools import zip_longest

class Matrix:

    def __init__(self, *rows):
        row_length = len(rows[0])
        for row in rows:
            # TODO skip first
            if len(row) != row_length:
                raise SystemError("Rows do not have equal length")

        self._rows = [*rows]

    def __eq__(self, other):
        return isinstance(self, other.__class__) and \
               all([x == y for x, y in zip_longest(self._rows, other._rows)])

    # other methods omitted for simplicity...

I also wrote a test for __eq__(self, other) like this:

def test_eq():
    m1 = Matrix([[1,2,3],[4,5,6]])
    m2 = Matrix([1,2,3],[4,5,6])
    m3 = Matrix([1,2,3],[5,4,6])
    assert m1 == m2
    assert m2 == m1
    assert m2 != m3

This should pass because m1 and m2 have the same rows, while m3 differs in the second row. However, when I run this test I get this output:

    def test_eq():
        m1 = Matrix([[1,2,3],[4,5,6]])
        m2 = Matrix([1,2,3],[4,5,6])
        m3 = Matrix([1,2,3],[5,4,6])
>       assert m1 == m2
E       assert <exercises.matrix.Matrix object at 0x10ccd67d0> == <exercises.matrix.Matrix object at 0x10ccd6810>

What am I missing here? I'm using Python 3.7.4 and pytest 5.1.2. Thanks in advance for your comments and answers.


NOTE: I changed the implementation based on ggorlen's answer, but I'm still having a similar issue.



Solution 1:[1]

The lines in your comparison should be something like:

for i, i_row in enumerate(self._rows):
    if i_row != other._rows[i]:
        return False
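To make the limitation of this loop concrete, here is a small standalone sketch that wraps it in a function operating on plain lists in place of self._rows and other._rows (the function and variable names are hypothetical, chosen only for illustration):

```python
def rows_equal_by_index(self_rows, other_rows):
    # The loop from the answer: compare each row with the row at the same index in the other list
    for i, i_row in enumerate(self_rows):
        if i_row != other_rows[i]:
            return False
    return True


shorter = [[1, 2, 3], [4, 5, 6]]
longer = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

print(rows_equal_by_index(shorter, shorter))  # → True
print(rows_equal_by_index(shorter, longer))   # → True, although longer has an extra row
try:
    rows_equal_by_index(longer, shorter)      # longer's third row has no partner in shorter
except IndexError as exc:
    print("IndexError:", exc)                 # → IndexError: list index out of range
```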

But this still won't return the correct result if other has more rows than self (the extra rows are never compared), and it raises an IndexError if other has fewer rows, so:

def __eq__(self, other):
    return isinstance(self, other.__class__) and \
           len(other._rows) == len(self._rows) and \
           all([x == y for x, y in zip(self._rows, other._rows)])

The attribute is called _rows, and we need to use [] to index into a list, not parentheses.
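A quick way to see why the length check is needed with zip(), and why the zip_longest() version in the updated question also handles unequal row counts, is to run both on plain lists standing in for _rows (the list names below are illustrative):

```python
from itertools import zip_longest

a_rows = [[1, 2, 3], [4, 5, 6]]
b_rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]  # one extra row

# zip() stops at the shorter input, so the extra row is silently ignored
print(all(x == y for x, y in zip(a_rows, b_rows)))  # → True (wrong)

# Adding the length check from the answer catches the mismatch
print(len(a_rows) == len(b_rows)
      and all(x == y for x, y in zip(a_rows, b_rows)))  # → False

# zip_longest() pads the shorter input with None, and None != [7, 8, 9]
print(all(x == y for x, y in zip_longest(a_rows, b_rows)))  # → False
```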

A potentially faster version, which bails out at the first mismatched row instead of first building the full list of comparisons that all([...]) receives, is:

def __eq__(self, other):
    if isinstance(self, other.__class__) and \
      len(other._rows) == len(self._rows):
        for i, row in enumerate(self._rows):
            if row != other._rows[i]:
                return False

        return True

    return False
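Because _rows is a plain list of lists, you could also lean on Python's built-in list equality: two lists compare equal only if they have the same length and every pair of corresponding items is equal, and the comparison stops at the first mismatch. A minimal sketch under that assumption (returning NotImplemented for non-Matrix operands is a common Python convention, not part of the original answer; row-length validation is omitted):

```python
class Matrix:
    def __init__(self, *rows):
        self._rows = list(rows)  # row-length validation omitted for brevity

    def __eq__(self, other):
        # Let Python try the other operand's __eq__ (or fall back to identity)
        if not isinstance(other, Matrix):
            return NotImplemented
        # List equality covers both the length check and the row-by-row comparison
        return self._rows == other._rows


m2 = Matrix([1, 2, 3], [4, 5, 6])
m3 = Matrix([1, 2, 3], [5, 4, 6])
print(m2 == Matrix([1, 2, 3], [4, 5, 6]))  # → True
print(m2 != m3)                            # → True
print(m2 == "not a matrix")                # → False
```

In Python 3, != falls back to inverting __eq__ when __ne__ is not defined, so the assert m2 != m3 line in the test works without extra code.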

In your test, you might have a typo:

m1 = Matrix([[1,2,3],[4,5,6]]) # <-- this matrix has an extra `[]` wrapper
m2 = Matrix([1,2,3],[4,5,6])   # <-- but this one just uses flat lists

so these matrices will not be equal: m1 has a single row (a list containing two lists), while m2 has two rows.
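With the typo fixed, the test passes. The sketch below also adds a __repr__ (a hypothetical addition, not part of the original class) so that a failing assertion shows the rows rather than <exercises.matrix.Matrix object at 0x...>, since pytest prints the repr of the compared objects. The __eq__ is the length-checking version from the answer, written with a generator expression:

```python
class Matrix:
    def __init__(self, *rows):
        self._rows = list(rows)  # row-length validation omitted for brevity

    def __eq__(self, other):
        return (isinstance(self, other.__class__)
                and len(other._rows) == len(self._rows)
                and all(x == y for x, y in zip(self._rows, other._rows)))

    def __repr__(self):
        # Hypothetical repr that mirrors the constructor call
        return f"Matrix({', '.join(repr(row) for row in self._rows)})"


def test_eq():
    m1 = Matrix([1, 2, 3], [4, 5, 6])  # typo fixed: two rows, no extra [] wrapper
    m2 = Matrix([1, 2, 3], [4, 5, 6])
    m3 = Matrix([1, 2, 3], [5, 4, 6])
    assert m1 == m2
    assert m2 == m1
    assert m2 != m3


test_eq()
print(repr(Matrix([[1, 2, 3], [4, 5, 6]])))  # → Matrix([[1, 2, 3], [4, 5, 6]])
print(repr(Matrix([1, 2, 3], [4, 5, 6])))    # → Matrix([1, 2, 3], [4, 5, 6])
```

Seen side by side, the two reprs make the extra wrapper in the original m1 easy to spot.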


Minor suggestions:

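  • Raise a ValueError instead of a SystemError on bad arguments (Python has no built-in ArgumentError).
  • Consider using NumPy instead of rolling your own matrix class; NumPy's documentation recommends regular arrays (numpy.array) over numpy.matrix.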
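A minimal sketch of the first suggestion, assuming you keep the question's constructor signature: it raises ValueError on ragged rows and, as the question's TODO asks, skips comparing the first row against itself. The empty-argument check is a hypothetical extra that avoids an IndexError on rows[0] when Matrix() is called with no rows:

```python
class Matrix:
    def __init__(self, *rows):
        if not rows:
            # Hypothetical extra check: Matrix() with no rows would otherwise fail on rows[0]
            raise ValueError("A matrix needs at least one row")
        row_length = len(rows[0])
        # Compare every row after the first against the first row's length
        for row in rows[1:]:
            if len(row) != row_length:
                raise ValueError("Rows do not have equal length")
        self._rows = list(rows)


try:
    Matrix([1, 2, 3], [4, 5])
except ValueError as exc:
    message = str(exc)
print(message)  # → Rows do not have equal length
```

In a pytest test, the same check can be written as with pytest.raises(ValueError): Matrix([1, 2, 3], [4, 5]).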

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1