开发者

Many-to-many intersection in sqlalchemy

I have a Character class with a .tags attribute; the .tags attribute is a list of Tag objects in a many-to-many relationship. I'm trying to write a query that will find all pairs of characters that don't have the same name and that have at least one tag in common; how would I go about doing this?


You can approach this as follows:

  1. think of an SQL query which will give you the desired result
  2. create the corresponding SQLAlchemy (SA) query

The SQL query is below. It is written for SQL Server and uses a WITH clause only to supply test data; your actual table and column names will probably differ:

-- Test data is supplied through CTEs. String literals use single quotes:
-- under SQL Server's default QUOTED_IDENTIFIER ON setting, "..." denotes an
-- identifier, not a string.
WITH t_character (id, name)
AS (    SELECT  1, 'ch-1'
UNION   SELECT  2, 'ch-2'
UNION   SELECT  3, 'ch-3'
UNION   SELECT  4, 'ch-4'
)
, t_tag (id, name)
AS (    SELECT  1, 'tag-1'
UNION   SELECT  2, 'tag-2'
UNION   SELECT  3, 'tag-3'
)
, t_character_tag (character_id, tag_id)
AS (    SELECT  1, 1
UNION   SELECT  2, 1
UNION   SELECT  2, 2
UNION   SELECT  3, 1
UNION   SELECT  3, 2
UNION   SELECT  3, 3
UNION   SELECT  4, 3
)
-- Before DISTINCT the joins yield (1, 2), (1, 3), (2, 3) twice (shared via
-- tag 1 and via tag 2), and (3, 4); the final result is therefore
-- (1, 2), (1, 3), (2, 3), (3, 4).
SELECT      DISTINCT -- collapses pairs that share more than one tag
            c1.id, c2.id
FROM        t_character c1
INNER JOIN  t_character c2
        ON  c1.id < c2.id          -- each unordered pair once, no self-pairs
        AND c1.name <> c2.name     -- the question requires different names
INNER JOIN  t_character_tag r1
        ON  r1.character_id = c1.id
INNER JOIN  t_character_tag r2
        ON  r2.character_id = c2.id
WHERE       r1.tag_id = r2.tag_id  -- at least one tag in common
ORDER BY    c1.id, c2.id

The complete sample code with the query you need is below:

from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Table
from sqlalchemy.orm import relationship, scoped_session, sessionmaker, aliased
from sqlalchemy.ext.declarative import declarative_base

# Test database for the demo: an in-memory SQLite engine with SQL echo off,
# and a scoped_session registry built from a sessionmaker bound to it.
# autoflush is disabled, so pending objects are written only on an explicit
# flush/commit (the script commits explicitly after creating its data).
engine = create_engine("sqlite:///:memory:", echo=False)
session = scoped_session(session_factory=sessionmaker(autoflush=False, bind=engine))

class Base(object):
    """Mixin used as ``cls`` for ``declarative_base()``.

    ``__init__`` copies every keyword argument onto the instance as an
    attribute. NOTE(review): ``declarative_base()`` is believed to install
    its own default constructor on the generated base class, which would
    take precedence over this ``__init__`` for mapped classes -- confirm
    against the SQLAlchemy version in use.

    ``__repr__`` renders ``<ClassName(key=value, ...)>`` from the instance's
    ``__dict__`` in sorted key order, skipping keys that start with
    ``_sa_`` or ``_backref_``. Any related objects already present in
    ``__dict__`` are included too, rendered with their own ``repr``.
    """

    def __init__(self, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __repr__(self):
        state = self.__dict__
        # Hide SQLAlchemy bookkeeping entries (instance state, backref helpers).
        shown = [
            "%s=%r" % (key, state[key])
            for key in sorted(state)
            if not key.startswith(("_sa_", "_backref_"))
        ]
        return "<%s(%s)>" % (type(self).__name__, ", ".join(shown))
# Build the declarative base, mixing in the Base helper class defined above.
Base = declarative_base(cls=Base)

# Association table for the Character <-> Tag many-to-many relationship.
# The two foreign-key columns together form the primary key, so the same
# (character, tag) pair cannot be stored twice; duplicate association rows
# would otherwise multiply the rows produced by joins through this table.
t_character_tag = Table(
    "t_character_tag", Base.metadata,
    Column("character_id", Integer, ForeignKey("t_character.id"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("t_tag.id"), primary_key=True),
    )

class Character(Base):
    """A character that can carry any number of tags.

    Mapped to the ``t_character`` table. ``tags`` is a many-to-many
    relationship to ``Tag`` through the ``t_character_tag`` association
    table; ``backref="characters"`` also gives each Tag a ``characters``
    collection pointing back here.
    """
    __tablename__ = u"t_character"
    id = Column(Integer, primary_key=True)
    name = Column(String)
    tags = relationship("Tag", secondary=t_character_tag, backref="characters")

class Tag(Base):
    """A label that can be attached to many characters.

    Mapped to the ``t_tag`` table. The ``characters`` collection on each Tag
    is not declared here; it is created by ``backref="characters"`` on
    ``Character.tags``.
    """
    __tablename__ = u"t_tag"
    id = Column(Integer, primary_key=True)
    name = Column(String)

# Create every table registered on Base.metadata in the test database.
Base.metadata.create_all(bind=engine)


# 0. create test data
# Tag assignment:  ch-1: tag-1 | ch-2: tag-1, tag-2 | ch-3: tag-1, tag-2, tag-3 | ch-4: tag-3
ch1 = Character(id=1, name="ch-1")
ch2 = Character(id=2, name="ch-2")
ch3 = Character(id=3, name="ch-3")
ch4 = Character(id=4, name="ch-4")
ta1 = Tag(id=1, name="tag-1")
ta2 = Tag(id=2, name="tag-2")
ta3 = Tag(id=3, name="tag-3")
ch1.tags.append(ta1)
ch2.tags.extend([ta1, ta2])
ch3.tags.extend([ta1, ta2, ta3])
ch4.tags.append(ta3)
# Only the characters are added explicitly; their tags are reached through
# Character.tags.
session.add_all([ch1, ch2, ch3, ch4])
session.commit()

# 1. some data checks
# Clear the session first so previously created objects are not reused by
# the lookups below.
session.expunge_all()
# All four characters were persisted.
assert len(session.query(Character).all()) == 4
# Tag with id 2 was created with the name "tag-2".
assert session.query(Tag).get(2).name == "tag-2"
# Character 3 was given all three tags.
assert len(session.query(Character).get(3).tags) == 3

# 2. create a final query (THE ANSWER TO THE QUESTION):
# Self-join Character as (c1, c2), join each side to its own Tag alias, and
# keep pairs whose tags match: i.e. pairs sharing at least one tag.
session.expunge_all()
t_c1 = aliased(Character)
t_c2 = aliased(Character)
t_t1 = aliased(Tag)
t_t2 = aliased(Tag)
q = (
    session.query(t_c1, t_c2)
    # c1.id < c2.id drops self-pairs and keeps only one ordering of each pair.
    # Two-argument join(target, onclause): the tuple form join((target,
    # onclause)) is no longer accepted by current SQLAlchemy releases.
    .join(t_c2, t_c1.id < t_c2.id)
    .join(t_t1, t_c1.tags)
    .join(t_t2, t_c2.tags)
    .filter(t_t1.id == t_t2.id)
    # The question asks for characters with different names. This filter is
    # redundant only if *character* names are unique, since c1.id < c2.id
    # already excludes pairing a character with itself.
    .filter(t_c1.name != t_c2.name)
    .order_by(t_c1.id, t_c2.id)
)
# A pair sharing several tags is produced once per shared tag; DISTINCT
# collapses those duplicates.
q = q.distinct()
res = q.all()
# Expected pairs: (ch-1, ch-2), (ch-1, ch-3), (ch-2, ch-3), (ch-3, ch-4).
assert len(res) == 4
for _r in res:
    # Parenthesised print works on both Python 2 and Python 3.
    print(_r)
0

上一篇:

下一篇:

精彩评论

暂无评论...
验证码 换一张
取 消

最新问答

问答排行榜