Many-to-many intersection in sqlalchemy
I have a Character class with a .tags attribute; the .tags attribute is a list of Tag objects, in a many-to-many relationship. I'm trying to write a query that will find all pairs of characters that don't have the same name and that have at least one tag in common; how would I go about doing this?
You go about this as follows:
- think of an SQL query which will give you the desired result
- create a corresponding SA query
The SQL query (using a WITH clause on SQL Server to supply the test data) is shown below (obviously your table and column names might be different):
-- Test data is supplied through common table expressions.
-- String literals use single quotes: with QUOTED_IDENTIFIER ON, SQL Server
-- parses double-quoted text as an identifier rather than a string.
WITH t_character (id, name)
AS (      SELECT 1, 'ch-1'
    UNION SELECT 2, 'ch-2'
    UNION SELECT 3, 'ch-3'
    UNION SELECT 4, 'ch-4'
)
, t_tag (id, name)
AS (      SELECT 1, 'tag-1'
    UNION SELECT 2, 'tag-2'
    UNION SELECT 3, 'tag-3'
)
, t_character_tag (character_id, tag_id)
AS (      SELECT 1, 1
    UNION SELECT 2, 1
    UNION SELECT 2, 2
    UNION SELECT 3, 1
    UNION SELECT 3, 2
    UNION SELECT 3, 3
    UNION SELECT 4, 3
)
-- Expected result: the pairs (1, 2), (1, 3), (2, 3) and (3, 4).
-- The pair (2, 3) matches twice (tags 1 and 2); DISTINCT collapses it
-- into a single row.
SELECT DISTINCT -- will filter out duplicates
        c1.id, c2.id
FROM    t_character c1
INNER JOIN t_character c2
    ON  c1.id < c2.id -- every unordered pair once, no self-pairs
INNER JOIN t_character_tag r1
    ON  r1.character_id = c1.id
INNER JOIN t_character_tag r2
    ON  r2.character_id = c2.id
WHERE   r1.tag_id = r2.tag_id -- the two characters share this tag
    AND c1.name <> c2.name    -- names must differ; redundant if names are unique
ORDER BY c1.id, c2.id
The complete sample code with the query you need is below:
from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Table
from sqlalchemy.orm import relationship, scoped_session, sessionmaker, aliased
from sqlalchemy.ext.declarative import declarative_base
# Set up an in-memory SQLite engine and a scoped session for the example.
engine = create_engine("sqlite:///:memory:", echo=False)
session = scoped_session(
    sessionmaker(
        autoflush=False,
        bind=engine,
    )
)
class Base(object):
    """Helper base class that sets attributes from keyword arguments.

    Also provides a default ``__repr__`` that lists every entry of the
    instance ``__dict__`` except those whose name starts with ``_sa_`` or
    ``_backref_``, so populated relationship attributes are printed too.

    NOTE(review): the original author cautioned that printing may cause
    relationships to be loaded -- not confirmed from this code alone.
    """

    def __init__(self, **kwargs):
        """Assign each keyword argument as an attribute of the instance."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        """Return ``<ClassName(attr=value, ...)>`` with attributes sorted by name."""
        # Skip bookkeeping entries identified by their name prefix.
        visible = (
            key for key in sorted(self.__dict__)
            if not key.startswith(("_sa_", "_backref_"))
        )
        return "<%s(%s)>" % (
            self.__class__.__name__,
            ", ".join("%s=%r" % (key, self.__dict__[key]) for key in visible),
        )
# Rebind ``Base`` to the declarative base generated from the helper class
# defined above; the mapped classes below derive from this new ``Base``.
Base = declarative_base(cls=Base)

# Association table for the many-to-many Character <-> Tag relationship.
# Both foreign keys together form a composite primary key, so the same
# (character, tag) link cannot be stored twice.
t_character_tag = Table(
    "t_character_tag",
    Base.metadata,
    Column("character_id", Integer, ForeignKey("t_character.id"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("t_tag.id"), primary_key=True),
)
class Character(Base):
    """Mapped class for ``t_character``; ``tags`` relates to Tag via ``t_character_tag``."""

    __tablename__ = u"t_character"

    id = Column(Integer, primary_key=True)
    name = Column(String)
    # Many-to-many link; the backref exposes ``characters`` on Tag.
    tags = relationship(
        "Tag",
        secondary=t_character_tag,
        backref="characters",
    )
class Tag(Base):
    """Mapped class for ``t_tag``: an integer primary key and a name."""

    __tablename__ = u"t_tag"

    id = Column(Integer, primary_key=True)
    name = Column(String)
# Create the database schema for all mapped tables.
Base.metadata.create_all(engine)

# 0. Build the test data: four characters and three tags.
ch1 = Character(id=1, name="ch-1")
ch2 = Character(id=2, name="ch-2")
ch3 = Character(id=3, name="ch-3")
ch4 = Character(id=4, name="ch-4")

ta1 = Tag(id=1, name="tag-1")
ta2 = Tag(id=2, name="tag-2")
ta3 = Tag(id=3, name="tag-3")

# Tag assignment: ch1 -> {1}, ch2 -> {1, 2}, ch3 -> {1, 2, 3}, ch4 -> {3}.
ch1.tags.append(ta1)
ch2.tags.extend([ta1, ta2])
ch3.tags.extend([ta1, ta2, ta3])
ch4.tags.append(ta3)

session.add_all([ch1, ch2, ch3, ch4])
session.commit()
# 1. Sanity-check the stored data before running the real query.
session.expunge_all()
character_query = session.query(Character)
tag_query = session.query(Tag)
assert len(character_query.all()) == 4
assert tag_query.get(2).name == "tag-2"
assert len(character_query.get(3).tags) == 3
# 2. Build the final query (THE ANSWER TO THE QUESTION).
session.expunge_all()

# Two aliases per entity: each result row pairs one character with another,
# and each side is joined to its own tags.
t_c1 = aliased(Character)
t_c2 = aliased(Character)
t_t1 = aliased(Tag)
t_t2 = aliased(Tag)

# NOTE(review): join() is called with (target, onclause) tuples, the calling
# style of older SQLAlchemy releases -- confirm against the installed version.
q = (
    session.query(t_c1, t_c2)
    .join((t_c2, t_c1.id < t_c2.id))  # every unordered pair once, no self-pairs
    .join((t_t1, t_c1.tags))
    .join((t_t2, t_c2.tags))
    .filter(t_t1.id == t_t2.id)  # the two characters share this tag
    # If character names are unique this filter can be dropped: the id
    # comparison above already guarantees two different characters.
    .filter(t_c1.name != t_c2.name)
    .order_by(t_c1.id, t_c2.id)
    .distinct()  # a pair sharing several tags would otherwise repeat
)

res = q.all()  # all() already returns a list; no copy needed
assert len(res) == 4
for _r in res:
    # Parenthesized form works as a Python 2 statement and a Python 3 call.
    print(_r)
精彩评论