diff --git a/src/murfey/cli/inject_spa_processing.py b/src/murfey/cli/inject_spa_processing.py index 69f350c6a..9a55a2ef9 100644 --- a/src/murfey/cli/inject_spa_processing.py +++ b/src/murfey/cli/inject_spa_processing.py @@ -13,7 +13,6 @@ from murfey.util.config import get_machine_config, get_microscope, get_security_config from murfey.util.db import ( AutoProcProgram, - ClassificationFeedbackParameters, ClientEnvironment, DataCollection, DataCollectionGroup, @@ -136,15 +135,11 @@ def run(): .where(AutoProcProgram.pj_id == ProcessingJob.id) .where(ProcessingJob.recipe == "em-spa-preprocess") ).one() - params = murfey_db.exec( - select(SPARelionParameters, ClassificationFeedbackParameters) - .where(SPARelionParameters.pj_id == collected_ids[2].id) - .where(ClassificationFeedbackParameters.pj_id == SPARelionParameters.pj_id) + proc_params = murfey_db.exec( + select(SPARelionParameters).where( + SPARelionParameters.pj_id == collected_ids[2].id + ) ).one() - proc_params: dict | None = dict(params[0]) - feedback_params = params[1] - if feedback_params.picker_murfey_id is None: - raise ValueError("No ISPyB picker ID was found") except sqlalchemy.exc.NoResultFound: proc_params = None @@ -196,7 +191,6 @@ def run(): "ft_bin": proc_params["motion_corr_binning"], "fm_dose": proc_params["dose_per_frame"], "gain_ref": proc_params["gain_ref"], - "picker_uuid": feedback_params.picker_murfey_id, "session_id": args.session_id, "particle_diameter": proc_params["particle_diameter"] or 0, "fm_int_file": args.eer_fractionation_file, diff --git a/src/murfey/server/api/auth.py b/src/murfey/server/api/auth.py index b0d5f292e..8c8988bf2 100644 --- a/src/murfey/server/api/auth.py +++ b/src/murfey/server/api/auth.py @@ -110,6 +110,8 @@ async def submit_to_auth_endpoint( Helper function to forward incoming requests to an authentication server to verify that they are allowed to inspect the """ + if security_config.auth_type == "none": + return {"valid": True} # Forward only essentials auth-related headers headers = { @@ -189,6 +191,9 @@ async def validate_instrument_token( """ Used by the backend routers to check the incoming instrument server token. """ + if security_config.instrument_auth_type == "none": + return None + try: # Validate using auth URL if provided if security_config.instrument_auth_url: diff --git a/src/murfey/server/api/session_info.py b/src/murfey/server/api/session_info.py index cd52ee011..d4f60a4bb 100644 --- a/src/murfey/server/api/session_info.py +++ b/src/murfey/server/api/session_info.py @@ -131,7 +131,7 @@ def all_visit_info( @router.get("/sessions/{session_id}/rsyncers", response_model=List[RsyncInstance]) -def get_rsyncers_for_client(session_id: MurfeySessionID, db=murfey_db): +def get_rsyncers_for_client(session_id: MurfeySessionID, db: Session = murfey_db): rsync_instances = db.exec( select(RsyncInstance).where(RsyncInstance.session_id == session_id) ) @@ -144,7 +144,9 @@ class SessionClients(BaseModel): @router.get("/sessions/{session_id}") -async def get_session(session_id: MurfeySessionID, db=murfey_db) -> SessionClients: +async def get_session( + session_id: MurfeySessionID, db: Session = murfey_db +) -> SessionClients: session = db.exec(select(Session).where(Session.id == session_id)).one() clients = db.exec( select(ClientEnvironment).where(ClientEnvironment.session_id == session_id) @@ -153,7 +155,7 @@ async def get_session(session_id: MurfeySessionID, db=murfey_db) -> SessionClien @router.get("/sessions") -async def get_sessions(db=murfey_db): +async def get_sessions(db: Session = murfey_db): sessions = db.exec(select(Session)).all() clients = db.exec(select(ClientEnvironment)).all() res = [] @@ -176,7 +178,7 @@ def create_session( visit: str, name: str, visit_end_time: VisitEndTime, - db=murfey_db, + db: Session = murfey_db, ) -> int: s = Session( name=name, @@ -196,7 +198,7 @@ def create_session( @router.post("/sessions/{session_id}") def update_session( - session_id: MurfeySessionID, process: bool = True, db=murfey_db + session_id: MurfeySessionID, process: bool = True, db: Session = murfey_db ) -> None: session = db.exec(select(Session).where(Session.id == session_id)).one() session.process = process @@ -206,13 +208,13 @@ def update_session( @router.delete("/sessions/{session_id}") -def remove_session(session_id: MurfeySessionID, db=murfey_db): +def remove_session(session_id: MurfeySessionID, db: Session = murfey_db): remove_session_by_id(session_id, db) @router.get("/instruments/{instrument_name}/visits/{visit_name}/sessions") def get_sessions_with_visit( - instrument_name: MurfeyInstrumentName, visit_name: str, db=murfey_db + instrument_name: MurfeyInstrumentName, visit_name: str, db: Session = murfey_db ) -> List[Session]: sessions = db.exec( select(Session) @@ -224,7 +226,7 @@ def get_sessions_with_visit( @router.get("/instruments/{instrument_name}/sessions") async def get_sessions_by_instrument_name( - instrument_name: MurfeyInstrumentName, db=murfey_db + instrument_name: MurfeyInstrumentName, db: Session = murfey_db ) -> List[Session]: sessions = db.exec( select(Session).where(Session.instrument_name == instrument_name) @@ -234,7 +236,7 @@ async def get_sessions_by_instrument_name( @router.get("/sessions/{session_id}/data_collection_groups") def get_dc_groups( - session_id: MurfeySessionID, db=murfey_db + session_id: MurfeySessionID, db: Session = murfey_db ) -> Dict[str, DataCollectionGroup]: data_collection_groups = db.exec( select(DataCollectionGroup).where(DataCollectionGroup.session_id == session_id) @@ -244,7 +246,7 @@ def get_dc_groups( @router.get("/sessions/{session_id}/data_collection_groups/{dcgid}/data_collections") def get_data_collections( - session_id: MurfeySessionID, dcgid: int, db=murfey_db + session_id: MurfeySessionID, dcgid: int, db: Session = murfey_db ) -> List[DataCollection]: data_collections = db.exec( select(DataCollection).where(DataCollection.dcg_id == dcgid) @@ -253,7 +255,7 @@ def get_data_collections( @router.get("/clients") -async def get_clients(db=murfey_db): +async def get_clients(db: Session = murfey_db): clients = db.exec(select(ClientEnvironment)).all() return clients @@ -264,7 +266,7 @@ class CurrentGainRef(BaseModel): @router.put("/sessions/{session_id}/current_gain_ref") def update_current_gain_ref( - session_id: MurfeySessionID, new_gain_ref: CurrentGainRef, db=murfey_db + session_id: MurfeySessionID, new_gain_ref: CurrentGainRef, db: Session = murfey_db ): session = db.exec(select(Session).where(Session.id == session_id)).one() session.current_gain_ref = new_gain_ref.path @@ -387,7 +389,7 @@ class ProcessingDetails(BaseModel): @spa_router.get("/sessions/{session_id}/spa_processing_parameters") def get_spa_proc_param_details( - session_id: MurfeySessionID, db=murfey_db + session_id: MurfeySessionID, db: Session = murfey_db ) -> Optional[List[ProcessingDetails]]: params = db.exec( select( @@ -436,7 +438,7 @@ def _parse(ps, i, dcg_id): "/sessions/{session_id}/data_collection_groups/{dcgid}/grid_squares/{gsid}/foil_holes/{fhid}/num_movies" ) def get_number_of_movies_from_foil_hole( - session_id: int, dcgid: int, gsid: int, fhid: int, db=murfey_db + session_id: int, dcgid: int, gsid: int, fhid: int, db: Session = murfey_db ) -> int: movies = db.exec( select(Movie, FoilHole, GridSquare, DataCollectionGroup) @@ -452,13 +454,13 @@ def get_number_of_movies_from_foil_hole( @spa_router.get("/sessions/{session_id}/grid_squares") -def get_grid_squares(session_id: MurfeySessionID, db=murfey_db): +def get_grid_squares(session_id: MurfeySessionID, db: Session = murfey_db): return _get_grid_squares(session_id, db) @spa_router.get("/sessions/{session_id}/data_collection_groups/{dcgid}/grid_squares") def get_grid_squares_from_dcg( - session_id: MurfeySessionID, dcgid: int, db=murfey_db + session_id: MurfeySessionID, dcgid: int, db: Session = murfey_db ) -> List[GridSquare]: return _get_grid_squares_from_dcg(session_id, dcgid, db) @@ -467,14 +469,14 @@ def get_grid_squares_from_dcg( "/sessions/{session_id}/data_collection_groups/{dcgid}/grid_squares/{gsid}/foil_holes" ) def get_foil_holes_from_grid_square( - session_id: MurfeySessionID, dcgid: int, gsid: int, db=murfey_db + session_id: MurfeySessionID, dcgid: int, gsid: int, db: Session = murfey_db ) -> List[FoilHole]: return _get_foil_holes_from_grid_square(session_id, dcgid, gsid, db) @spa_router.get("/sessions/{session_id}/foil_hole/{fh_name}") def get_foil_hole( - session_id: MurfeySessionID, fh_name: int, db=murfey_db + session_id: MurfeySessionID, fh_name: int, db: Session = murfey_db ) -> Dict[str, int]: return _get_foil_hole(session_id, fh_name, db) @@ -488,7 +490,7 @@ def get_foil_hole( @tomo_router.get("/sessions/{session_id}/tilt_series/{tilt_series_tag}/tilts") def get_tilts( - session_id: MurfeySessionID, tilt_series_tag: str, db=murfey_db + session_id: MurfeySessionID, tilt_series_tag: str, db: Session = murfey_db ) -> Dict[str, List[str]]: res = db.exec( select(TiltSeries, Tilt) @@ -513,7 +515,7 @@ def get_tilts( @correlative_router.get("/sessions/{session_id}/upstream_visits") -async def find_upstream_visits(session_id: MurfeySessionID, db=murfey_db): +async def find_upstream_visits(session_id: MurfeySessionID, db: Session = murfey_db): return _find_upstream_visits(session_id=session_id, db=db) @@ -524,7 +526,7 @@ async def gather_upstream_files( visit_name: str, session_id: MurfeySessionID, upstream_file_request: UpstreamFileRequestInfo, - db=murfey_db, + db: Session = murfey_db, ): return _gather_upstream_files( session_id=session_id, @@ -541,7 +543,7 @@ async def get_upstream_file( visit_name: str, session_id: MurfeySessionID, upstream_file_path: Path, - db=murfey_db, + db: Session = murfey_db, ): upstream_file = _get_upstream_file(upstream_file_path) return ( @@ -552,14 +554,18 @@ async def get_upstream_file( @correlative_router.get( "/visits/{visit_name}/sessions/{session_id}/upstream_tiff_paths" ) -async def gather_upstream_tiffs(visit_name: str, session_id: int, db=murfey_db): +async def gather_upstream_tiffs( + visit_name: str, session_id: int, db: Session = murfey_db +): return _gather_upstream_tiffs(visit_name=visit_name, session_id=session_id, db=db) @correlative_router.get( "/visits/{visit_name}/sessions/{session_id}/upstream_tiff/{tiff_path:path}" ) -async def get_tiff_file(visit_name: str, session_id: int, tiff_path: str, db=murfey_db): +async def get_tiff_file( + visit_name: str, session_id: int, tiff_path: str, db: Session = murfey_db +): tiff_file = _get_tiff_file( visit_name=visit_name, session_id=session_id, tiff_path=tiff_path, db=db ) diff --git a/src/murfey/server/api/workflow.py b/src/murfey/server/api/workflow.py index 467989a85..07294a0ee 100644 --- a/src/murfey/server/api/workflow.py +++ b/src/murfey/server/api/workflow.py @@ -15,10 +15,12 @@ BLSubSample, ) from pydantic import BaseModel -from sqlalchemy.exc import OperationalError +from sqlalchemy.exc import NoResultFound, OperationalError from sqlmodel import col, select from werkzeug.utils import secure_filename +import murfey.server + try: from smartem_backend.api_client import SmartEMAPIClient from smartem_common.schemas import ( @@ -33,7 +35,6 @@ SMARTEM_ACTIVE = False import murfey.server.prometheus as prom -from murfey.server import _transport_object from murfey.server.api.auth import ( MurfeySessionIDInstrument as MurfeySessionID, validate_instrument_token, @@ -52,7 +53,6 @@ from murfey.util.config import get_machine_config from murfey.util.db import ( AutoProcProgram, - ClassificationFeedbackParameters, DataCollection, DataCollectionGroup, FoilHole, @@ -112,7 +112,7 @@ def register_dc_group( visit_name: str, session_id: MurfeySessionID, dcg_params: DCGroupParameters, - db=murfey_db, + db: Session = murfey_db, ): ispyb_proposal_code = visit_name[:2] ispyb_proposal_number = visit_name.split("-")[0][2:] @@ -178,10 +178,10 @@ def register_dc_group( if smartem_grid_uuid: dcg_instance.smartem_grid_uuid = smartem_grid_uuid - if _transport_object: + if murfey.server._transport_object: if dcg_instance.atlas_id is not None: - _transport_object.send( - _transport_object.feedback_queue, + murfey.server._transport_object.send( + murfey.server._transport_object.feedback_queue, { "register": "atlas_update", "tag": dcg_instance.tag, @@ -194,7 +194,7 @@ def register_dc_group( }, ) else: - atlas_id_response = _transport_object.do_insert_atlas( + atlas_id_response = murfey.server._transport_object.do_insert_atlas( Atlas( dataCollectionGroupId=dcg_instance.id, atlasImage=dcg_params.atlas, @@ -228,9 +228,9 @@ def register_dc_group( # Case where we switch from atlas to processing original_tag = dcg_murfey[0].tag dcg_murfey[0].tag = dcg_params.tag or dcg_murfey[0].tag - if _transport_object: - _transport_object.send( - _transport_object.feedback_queue, + if murfey.server._transport_object: + murfey.server._transport_object.send( + murfey.server._transport_object.feedback_queue, { "register": "experiment_type_update", "experiment_type_id": dcg_params.experiment_type_id, @@ -257,9 +257,9 @@ def register_dc_group( "atlas_pixel_size": dcg_params.atlas_pixel_size, } - if _transport_object: - _transport_object.send( - _transport_object.feedback_queue, + if murfey.server._transport_object: + murfey.server._transport_object.send( + murfey.server._transport_object.feedback_queue, { "register": "data_collection_group", **dcg_parameters, @@ -303,7 +303,10 @@ class DCParameters(BaseModel): @router.post("/visits/{visit_name}/sessions/{session_id}/start_data_collection") def start_dc( - visit_name: str, session_id: MurfeySessionID, dc_params: DCParameters, db=murfey_db + visit_name: str, + session_id: MurfeySessionID, + dc_params: DCParameters, + db: Session = murfey_db, ): ispyb_proposal_code = visit_name[:2] ispyb_proposal_number = visit_name.split("-")[0][2:] @@ -345,9 +348,9 @@ def start_dc( "session_id": session_id, } - if _transport_object: - _transport_object.send( - _transport_object.feedback_queue, + if murfey.server._transport_object: + murfey.server._transport_object.send( + murfey.server._transport_object.feedback_queue, { "register": "data_collection", **dc_parameters, @@ -375,7 +378,7 @@ def register_proc( visit_name: str, session_id: MurfeySessionID, proc_params: ProcessingJobParameters, - db=murfey_db, + db: Session = murfey_db, ): proc_parameters: dict = { "session_id": session_id, @@ -408,9 +411,9 @@ def register_proc( ) proc_parameters["job_parameters"] = job_parameters - if _transport_object: - _transport_object.send( - _transport_object.feedback_queue, + if murfey.server._transport_object: + murfey.server._transport_object.send( + murfey.server._transport_object.feedback_queue, {"register": "processing_job", **proc_parameters}, ) return proc_params @@ -425,7 +428,9 @@ def register_proc( @spa_router.post("/sessions/{session_id}/spa_processing_parameters") def register_spa_proc_params( - session_id: MurfeySessionID, proc_params: ProcessingParametersSPA, db=murfey_db + session_id: MurfeySessionID, + proc_params: ProcessingParametersSPA, + db: Session = murfey_db, ): session_processing_parameters = db.exec( select(SessionProcessingParameters).where( @@ -445,8 +450,16 @@ def register_spa_proc_params( **dict(proc_params), "session_id": session_id, } - if _transport_object: - _transport_object.send(_transport_object.feedback_queue, zocalo_message) + if murfey.server._transport_object: + murfey.server._transport_object.send( + murfey.server._transport_object.feedback_queue, zocalo_message + ) + else: + logger.error( + f"Pre-processing was requested for {session_id} " + "but no Zocalo transport object was found" + ) + return proc_params class Tag(BaseModel): @@ -455,15 +468,17 @@ class Tag(BaseModel): @spa_router.post("/visits/{visit_name}/sessions/{session_id}/flush_spa_processing") def flush_spa_processing( - visit_name: str, session_id: MurfeySessionID, tag: Tag, db=murfey_db + visit_name: str, session_id: MurfeySessionID, tag: Tag, db: Session = murfey_db ): zocalo_message = { "register": "spa.flush_spa_preprocess", "session_id": session_id, "tag": tag.tag, } - if _transport_object: - _transport_object.send(_transport_object.feedback_queue, zocalo_message) + if murfey.server._transport_object: + murfey.server._transport_object.send( + murfey.server._transport_object.feedback_queue, zocalo_message + ) return @@ -490,7 +505,7 @@ async def request_spa_preprocessing( visit_name: str, session_id: MurfeySessionID, proc_file: SPAProcessFile, - db=murfey_db, + db: Session = murfey_db, ): instrument_name = ( db.exec(select(Session).where(Session.id == session_id)).one().instrument_name @@ -509,13 +524,17 @@ async def request_spa_preprocessing( .where(AutoProcProgram.pj_id == ProcessingJob.id) .where(ProcessingJob.recipe == "em-spa-preprocess") ).one() - params = db.exec( - select(SPARelionParameters, ClassificationFeedbackParameters) - .where(SPARelionParameters.pj_id == collected_ids[2].id) - .where(ClassificationFeedbackParameters.pj_id == SPARelionParameters.pj_id) - ).one() - proc_params: Optional[dict] = dict(params[0]) - feedback_params = params[1] + # SPARelionParameters is an ORM row, but the recipe parameters below are + # read by key, so convert to a dict. (April refactor dropped this + # dict() wrap and left the dict-style access, raising "'SPARelionParameters' + # object is not subscriptable" on every micrograph.) + proc_params: Optional[dict] = dict( + db.exec( + select(SPARelionParameters).where( + SPARelionParameters.pj_id == collected_ids[2].id + ) + ).one() + ) except sqlalchemy.exc.NoResultFound: proc_params = None try: @@ -530,6 +549,9 @@ async def request_spa_preprocessing( .one()[0] .id ) + except NoResultFound: + logger.warning("No foil hole ID found") + foil_hole_id = None except Exception as e: logger.warning( f"Foil hole ID not found for foil hole {sanitise(str(proc_file.foil_hole_id))}: {e}", @@ -540,10 +562,6 @@ async def request_spa_preprocessing( detached_ids = [c.id for c in collected_ids] murfey_ids = _murfey_id(detached_ids[3], db, number=2, close=False) - - if feedback_params.picker_murfey_id is None: - feedback_params.picker_murfey_id = murfey_ids[1] - db.add(feedback_params) movie = Movie( murfey_id=murfey_ids[0], data_collection_id=detached_ids[1], @@ -642,12 +660,11 @@ async def request_spa_preprocessing( ), }, } - # log.info(f"Sending Zocalo message {zocalo_message}") - if _transport_object: + if murfey.server._transport_object: zocalo_message["parameters"]["feedback_queue"] = ( - _transport_object.feedback_queue + murfey.server._transport_object.feedback_queue ) - _transport_object.send("processing_recipe", zocalo_message) + murfey.server._transport_object.send("processing_recipe", zocalo_message) else: logger.error( f"Pre-processing was requested for {sanitise(Path(proc_file.path).name)} " @@ -681,7 +698,9 @@ async def request_spa_preprocessing( @tomo_router.post("/sessions/{session_id}/tomography_processing_parameters") def register_tomo_proc_params( - session_id: MurfeySessionID, proc_params: ProcessingParametersTomo, db=murfey_db + session_id: MurfeySessionID, + proc_params: ProcessingParametersTomo, + db: Session = murfey_db, ): session_processing_parameters = db.exec( select(SessionProcessingParameters).where( @@ -700,8 +719,10 @@ def register_tomo_proc_params( **dict(proc_params), "session_id": session_id, } - if _transport_object: - _transport_object.send(_transport_object.feedback_queue, zocalo_message) + if murfey.server._transport_object: + murfey.server._transport_object.send( + murfey.server._transport_object.feedback_queue, zocalo_message + ) class Source(BaseModel): @@ -712,7 +733,10 @@ class Source(BaseModel): "/visits/{visit_name}/sessions/{session_id}/flush_tomography_processing" ) def flush_tomography_processing( - visit_name: str, session_id: MurfeySessionID, rsync_source: Source, db=murfey_db + visit_name: str, + session_id: MurfeySessionID, + rsync_source: Source, + db: Session = murfey_db, ): zocalo_message = { "register": "flush_tomography_preprocess", @@ -720,8 +744,10 @@ def flush_tomography_processing( "visit_name": visit_name, "data_collection_group_tag": rsync_source.rsync_source, } - if _transport_object: - _transport_object.send(_transport_object.feedback_queue, zocalo_message) + if murfey.server._transport_object: + murfey.server._transport_object.send( + murfey.server._transport_object.feedback_queue, zocalo_message + ) return @@ -733,7 +759,7 @@ class TiltSeriesInfo(BaseModel): @tomo_router.post("/visits/{visit_name}/tilt_series") def register_tilt_series( - visit_name: str, tilt_series_info: TiltSeriesInfo, db=murfey_db + visit_name: str, tilt_series_info: TiltSeriesInfo, db: Session = murfey_db ): session_id = tilt_series_info.session_id if db.exec( @@ -762,7 +788,7 @@ class TiltSeriesGroupInfo(BaseModel): def register_tilt_series_length( session_id: int, tilt_series_group: TiltSeriesGroupInfo, - db=murfey_db, + db: Session = murfey_db, ): tilt_series_db = db.exec( select(TiltSeries) @@ -800,7 +826,7 @@ async def request_tomography_preprocessing( visit_name: str, session_id: MurfeySessionID, proc_file: TomoProcessFile, - db=murfey_db, + db: Session = murfey_db, ): instrument_name = ( db.exec(select(Session).where(Session.id == session_id)).one().instrument_name @@ -884,11 +910,11 @@ async def request_tomography_preprocessing( "fm_int_file": proc_file.eer_fractionation_file, }, } - if _transport_object: + if murfey.server._transport_object: zocalo_message["parameters"]["feedback_queue"] = ( - _transport_object.feedback_queue + murfey.server._transport_object.feedback_queue ) - _transport_object.send("processing_recipe", zocalo_message) + murfey.server._transport_object.send("processing_recipe", zocalo_message) else: logger.error( f"Pre-processing was requested for {sanitise(Path(proc_file.path).name)} " @@ -915,7 +941,7 @@ def register_completed_tilt_series( visit_name: str, session_id: MurfeySessionID, tilt_series_group: TiltSeriesGroupInfo, - db=murfey_db, + db: Session = murfey_db, ): tilt_series_db = db.exec( select(TiltSeries) @@ -1003,9 +1029,9 @@ def register_completed_tilt_series( "y_location": ts.y_location, }, } - if _transport_object: + if murfey.server._transport_object: logger.info(f"Sending Zocalo message for processing: {zocalo_message}") - _transport_object.send( + murfey.server._transport_object.send( "processing_recipe", zocalo_message, new_connection=True ) else: @@ -1017,7 +1043,7 @@ def register_completed_tilt_series( @tomo_router.post("/visits/{visit_name}/rerun_tilt_series") def register_tilt_series_for_rerun( - visit_name: str, tilt_series_info: TiltSeriesInfo, db=murfey_db + visit_name: str, tilt_series_info: TiltSeriesInfo, db: Session = murfey_db ): """Set processing to false for cases where an extra tilt is found for a series""" session_id = tilt_series_info.session_id @@ -1042,7 +1068,10 @@ class TiltInfo(BaseModel): @tomo_router.post("/visits/{visit_name}/sessions/{session_id}/tilt") async def register_tilt( - visit_name: str, session_id: MurfeySessionID, tilt_info: TiltInfo, db=murfey_db + visit_name: str, + session_id: MurfeySessionID, + tilt_info: TiltInfo, + db: Session = murfey_db, ): def _add_tilt(): tilt_series_id = ( @@ -1135,8 +1164,8 @@ def get_samples(visit_name: str, db=ispyb_db) -> List[Sample]: def register_sample_group(visit_name: str, db=ispyb_db) -> dict: proposal_id = get_proposal_id(visit_name[:2], visit_name.split("-")[0][2:], db=db) record = BLSampleGroup(proposalId=proposal_id) - if _transport_object: - return _transport_object.do_insert_sample_group(record) + if murfey.server._transport_object: + return murfey.server._transport_object.do_insert_sample_group(record) return {"success": False} @@ -1147,8 +1176,10 @@ class BLSampleParameters(BaseModel): @correlative_router.post("/visit/{visit_name}/sample") def register_sample(visit_name: str, sample_params: BLSampleParameters) -> dict: record = BLSample() - if _transport_object: - return _transport_object.do_insert_sample(record, sample_params.sample_group_id) + if murfey.server._transport_object: + return murfey.server._transport_object.do_insert_sample( + record, sample_params.sample_group_id + ) return {"success": False} @@ -1164,8 +1195,8 @@ def register_subsample( record = BLSubSample( blSampleId=subsample_params.sample_id, imgFilePath=subsample_params.image_path ) - if _transport_object: - return _transport_object.do_insert_subsample(record) + if murfey.server._transport_object: + return murfey.server._transport_object.do_insert_subsample(record) return {"success": False} @@ -1182,6 +1213,6 @@ def register_sample_image( blSampleId=sample_image_params.sample_id, imageFullPath=sample_image_params.image_path, ) - if _transport_object: - return _transport_object.do_insert_sample_image(record) + if murfey.server._transport_object: + return murfey.server._transport_object.do_insert_sample_image(record) return {"success": False} diff --git a/src/murfey/server/feedback.py b/src/murfey/server/feedback.py index 6765951c9..1f93d9337 100644 --- a/src/murfey/server/feedback.py +++ b/src/murfey/server/feedback.py @@ -12,22 +12,17 @@ import subprocess import time from datetime import datetime -from functools import partial -from importlib.metadata import ( - EntryPoint, # For type hinting only - entry_points, -) from pathlib import Path from typing import Dict, List, NamedTuple, Tuple import mrcfile import numpy as np +from gemmi import cif +from pipeliner.project_graph import ProjectGraph +from pipeliner.star_keys import GENERAL_BLOCK, JOB_COUNTER from sqlalchemy import func from sqlalchemy.exc import ( InvalidRequestError, - NoResultFound, - OperationalError, - PendingRollbackError, SQLAlchemyError, ) from sqlalchemy.orm.exc import ObjectDeletedError @@ -50,6 +45,126 @@ logger = logging.getLogger("murfey.server.feedback") +# The first job number available to dynamic SPA feedback jobs. Jobs 1..6 are +# the fixed preprocessing jobs (Import, MotionCorr, CtfFind, AutoPick, Extract, +# Select); Class2D and everything downstream start here. Used as a floor when +# allocating so a feedback job is never handed a preprocessing job's number even +# if it is scheduled before Extract/Select have been registered by the node +# creator (which only happens once compute has finished). +FIRST_FEEDBACK_JOB = 7 + + +def _current_pipeline_job_counter(visit_name: str) -> int: + """Return the next jobNNN Pipeliner will allocate for visit_name. + + Reads the JOB_COUNTER value from default_pipeline.star so that + SPA feedback decisions are anchored to Pipeliner's actual state instead + of an independent integer counter that drifts. + + Falls back to ``FIRST_FEEDBACK_JOB`` if the file is missing — this preserves + the previous behaviour for non Doppio runs. + + NOTE: this is a non-reserving read. It is only safe when the result is used + immediately and no other feedback job will be scheduled before the read + value is registered in the pipeline. For job *allocation* use + ``_reserve_pipeline_job_numbers`` instead, which advances the counter under + the project lock so the numbers cannot be reused. + """ + pipeline_file = Path(visit_name) / "default_pipeline.star" + if not pipeline_file.is_file(): + return FIRST_FEEDBACK_JOB + try: + dp = cif.read_file(str(pipeline_file)) + block = dp.find_block(GENERAL_BLOCK) + if block is None: + return FIRST_FEEDBACK_JOB + return int(block.find_value(JOB_COUNTER)) + except Exception: + logger.warning( + "Failed to read JOB_COUNTER from %s — falling back to legacy job number", + pipeline_file, + exc_info=True, + ) + return FIRST_FEEDBACK_JOB + + +def _reserve_pipeline_job_numbers(visit_name: str, n_jobs: int) -> int: + """Atomically reserve ``n_jobs`` job numbers and return the first. + + Opens default_pipeline.star read/write under the project's ``.relion_lock`` + and advances ``_rlnPipeLineJobCounter`` by ``n_jobs ``. Because the counter + is consumed *now* — rather than when the job later completes and is + registered by the node creator — two feedback jobs (or a feedback job and a + manually launched job) can no longer be handed the same number during the + window between scheduling a job and its registration. + + The reserved block must cover every job the scheduling step will create + (e.g. InitialModel + Class3D), so the next allocation starts strictly after + them. The node creator's ``adjust_job_counter`` keeps the on-disk counter to + ``max(disk, job_number + 1)``, so a correctly sized block leaves the + counter exactly where this function set it (no gaps, no double counting). + + Falls back to ``FIRST_FEEDBACK_JOB`` without reserving when the pipeline file + does not yet exist (non Doppio runs / before the first job is registered). + """ + project_dir = Path(visit_name) + pipeline_file = project_dir / "default_pipeline.star" + if not pipeline_file.is_file(): + return FIRST_FEEDBACK_JOB + if n_jobs < 1: + raise ValueError("Must reserve at least one job number") + try: + with ProjectGraph( + read_only=False, pipeline_dir=str(project_dir), name="default" + ) as project: + base = max(project.job_counter, FIRST_FEEDBACK_JOB) + project.job_counter = base + n_jobs + return base + except Exception: + logger.warning( + "Failed to reserve %d job number(s) in %s — falling back to a " + "non-reserving counter read", + n_jobs, + pipeline_file, + exc_info=True, + ) + return _current_pipeline_job_counter(visit_name) + + +def _reserve_2d_classification_jobs( + visit_name: str, feedback_params: db.ClassificationFeedbackParameters +) -> int: + """Reserve the Pipeliner jobs for one complete 2D batch. + + A complete batch runs Class2D, then the autoselect Select job, and (the + first time only) the shared combine Select job that all batches feed into. + With icebreaker enabled an extra IceBreaker job sits between Class2D and + autoselect. This reserves them up front, sets ``feedback_params.next_job`` to + the Class2D number, and fills in ``star_combination_job`` (the combine + number) the first time it is called. The autoselect job is always + ``star_combination_job - 1`` (see select_classes). + + Returns the reserved Class2D job number. + """ + # Class2D (+ IceBreaker) + autoselect Select + per_batch_jobs = 3 if default_spa_parameters.do_icebreaker_jobs else 2 + if not feedback_params.star_combination_job: + # Also reserve the one-off shared combine Select job, one after the + # autoselect job. + base = _reserve_pipeline_job_numbers(visit_name, per_batch_jobs + 1) + feedback_params.star_combination_job = base + per_batch_jobs + else: + base = _reserve_pipeline_job_numbers(visit_name, per_batch_jobs) + feedback_params.next_job = base + return base + + +def _visit_name_for_session(session_id: int, _db) -> str: + """Return the visit (project directory) for a Murfey session id.""" + session_row = _db.exec(select(db.Session).where(db.Session.id == session_id)).one() + return session_row.visit + + try: _url = url(get_security_config()) engine = create_engine(_url) @@ -322,10 +437,6 @@ def _get_spa_params( def _release_2d_hold(message: dict, _db): relion_params, feedback_params = _get_spa_params(message["program_id"], _db) - if not feedback_params.star_combination_job: - feedback_params.star_combination_job = feedback_params.next_job + ( - 3 if default_spa_parameters.do_icebreaker_jobs else 2 - ) pj_id = _pj_id(message["program_id"], _db, recipe="em-spa-class2d") if feedback_params.rerun_class2d: first_class2d = _db.exec( @@ -339,6 +450,16 @@ def _release_2d_hold(message: dict, _db): machine_config = get_machine_config(instrument_name=instrument_name)[ instrument_name ] + if first_class2d.complete and not feedback_params.star_combination_job: + # The held batch is now complete and will run the autoselect Select + # plus the one-off shared combine Select job. The Class2D job re-uses + # its existing (already reserved) directory (message["job_dir"]), so + # reserve only the trailing jobs: combine goes at the end and the + # autoselect job is combine - 1 (see select_classes). + visit_name = _visit_name_for_session(message["session_id"], _db) + trailing = 3 if default_spa_parameters.do_icebreaker_jobs else 2 + base = _reserve_pipeline_job_numbers(visit_name, trailing) + feedback_params.star_combination_job = base + (trailing - 1) zocalo_message: dict = { "parameters": { "particles_file": first_class2d.particles_file, @@ -373,10 +494,6 @@ def _release_2d_hold(message: dict, _db): }, "recipes": [machine_config.recipes.get("em-spa-class2d", "em-spa-class2d")], } - if first_class2d.complete: - feedback_params.next_job += ( - 4 if default_spa_parameters.do_icebreaker_jobs else 3 - ) feedback_params.rerun_class2d = False _db.add(feedback_params) if first_class2d.complete: @@ -586,7 +703,12 @@ def _register_incomplete_2d_batch(message: dict, _db): _db.commit() _db.close() return - feedback_params.next_job = 10 if default_spa_parameters.do_icebreaker_jobs else 7 + # Reserve the single Class2D job this incomplete batch will create. An + # incomplete batch runs Class2D only (no autoselect/combine), so one job is + # enough; reserving advances the Pipeliner counter now so the next batch + # cannot be handed the same number before this job is registered. + visit_name = _visit_name_for_session(message["session_id"], _db) + feedback_params.next_job = _reserve_pipeline_job_numbers(visit_name, 1) feedback_params.hold_class2d = True relion_options = dict(relion_params) other_options = dict(feedback_params) @@ -729,19 +851,10 @@ def _register_complete_2d_batch(message: dict, _db): murfey_ids, class2d_message["particles_file"], _app_id(pj_id, _db), _db ) elif not feedback_params.class_selection_score: - # For the first batch, start a container and set the database to wait - job_number_after_first_batch = ( - 10 if default_spa_parameters.do_icebreaker_jobs else 7 - ) - if ( - feedback_params.next_job is not None - and feedback_params.next_job < job_number_after_first_batch - ): - feedback_params.next_job = job_number_after_first_batch - if not feedback_params.star_combination_job: - feedback_params.star_combination_job = feedback_params.next_job + ( - 3 if default_spa_parameters.do_icebreaker_jobs else 2 - ) + # Reserve Class2D + autoselect (+ combine on the first batch) up front so + # the numbers cannot be reused before the jobs are registered. + visit_name = _visit_name_for_session(message["session_id"], _db) + _reserve_2d_classification_jobs(visit_name, feedback_params) if _db.exec( select(func.count(db.Class2DParameters.particles_file)) .where(db.Class2DParameters.pj_id == pj_id) @@ -809,14 +922,14 @@ def _register_complete_2d_batch(message: dict, _db): "processing_recipe", zocalo_message, new_connection=True ) feedback_params.hold_class2d = True - feedback_params.next_job += ( - 4 if default_spa_parameters.do_icebreaker_jobs else 3 - ) _db.add(feedback_params) _db.commit() _db.close() else: - # Send all other messages on to a container + # star_combination_job is already set by now, so this reserves just the + # Class2D + autoselect jobs for this batch. + visit_name = _visit_name_for_session(message["session_id"], _db) + _reserve_2d_classification_jobs(visit_name, feedback_params) if _db.exec( select(func.count(db.Class2DParameters.particles_file)) .where(db.Class2DParameters.pj_id == pj_id) @@ -883,9 +996,7 @@ def _register_complete_2d_batch(message: dict, _db): murfey.server._transport_object.send( "processing_recipe", zocalo_message, new_connection=True ) - feedback_params.next_job += ( - 3 if default_spa_parameters.do_icebreaker_jobs else 2 - ) + feedback_params.hold_class2d = False _db.add(feedback_params) _db.commit() _db.close() @@ -930,17 +1041,14 @@ def _flush_class2d( .where(db.Class2DParameters.pj_id == pj_id) .where(db.Class2DParameters.complete) ).all() - if not feedback_params.next_job: - feedback_params.next_job = ( - 10 if default_spa_parameters.do_icebreaker_jobs else 7 - ) - if not feedback_params.star_combination_job: - feedback_params.star_combination_job = feedback_params.next_job + ( - 3 if default_spa_parameters.do_icebreaker_jobs else 2 - ) + # Reserve each batch's Class2D + autoselect jobs (and the shared combine + # job the first time) as it is dispatched, so queued batches never share a + # number and the counter is advanced under the project lock. + visit_name = _visit_name_for_session(session_id, _db) for saved_message in class2d_db: # Send all held Class2D messages on with the selection score added _db.expunge(saved_message) + _reserve_2d_classification_jobs(visit_name, feedback_params) zocalo_message: dict = { "parameters": { "particles_file": saved_message.particles_file, @@ -973,9 +1081,6 @@ def _flush_class2d( murfey.server._transport_object.send( "processing_recipe", zocalo_message, new_connection=True ) - feedback_params.next_job += ( - 3 if default_spa_parameters.do_icebreaker_jobs else 2 - ) _db.delete(saved_message) _db.add(feedback_params) _db.commit() @@ -1199,10 +1304,12 @@ def _register_3d_batch(message: dict, _db): ) feedback_params.initial_model = str(rescaled_initial_model_path) other_options["initial_model"] = str(rescaled_initial_model_path) + # Reserve the InitialModel (base) + Class3D (base + 1) job up front so + # the Class3D number cannot be reused before the job is registered. + feedback_params.next_job = _reserve_pipeline_job_numbers(visit_name, 2) class3d_dir = ( f"{class3d_message['class3d_dir']}{(feedback_params.next_job + 1):03}" ) - feedback_params.next_job += 1 _db.add(feedback_params) _db.commit() @@ -1236,8 +1343,9 @@ def _register_3d_batch(message: dict, _db): _db.commit() _db.close() elif not feedback_params.initial_model: - # For the first batch, start a container and set the database to wait - next_job = feedback_params.next_job + # For the first batch, start a container and set the database to wait. + # Reserve the InitialModel (base) + Class3D (base + 1) jobs. + feedback_params.next_job = _reserve_pipeline_job_numbers(visit_name, 2) class3d_dir = ( f"{class3d_message['class3d_dir']}{(feedback_params.next_job + 1):03}" ) @@ -1257,8 +1365,6 @@ def _register_3d_batch(message: dict, _db): ) feedback_params.hold_class3d = True - next_job += 2 - feedback_params.next_job = next_job zocalo_message: dict = { "parameters": { "particles_file": class3d_message["particles_file"], @@ -1544,7 +1650,11 @@ def _register_refinement(message: dict, _db): .where(db.RefineParameters.tag == "symmetry") ).one() except SQLAlchemyError: - next_job = feedback_params.next_job + # Reserve the contiguous refinement block: re-extraction + # Select (base) + Extract (base + 1), Refine3D (base + 2), + # MaskCreate (base + 3) and PostProcess (base + 4). + visit_name = _visit_name_for_session(message["session_id"], _db) + feedback_params.next_job = _reserve_pipeline_job_numbers(visit_name, 5) refine_dir = f"{message['refine_dir']}{(feedback_params.next_job + 2):03}" refined_grp_uuid = _murfey_id(message["program_id"], _db)[0] refined_class_uuid = _murfey_id(message["program_id"], _db)[0] @@ -1585,14 +1695,6 @@ def _register_refinement(message: dict, _db): _db=_db, ) - if relion_options["symmetry"] == "C1": - # Extra Refine, Mask, PostProcess beyond for determined symmetry - next_job += 8 - else: - # Select and Extract particles, then Refine, Mask, PostProcess - next_job += 5 - feedback_params.next_job = next_job - zocalo_message: dict = { "parameters": { "refine_job_dir": refine_params.refine_dir, @@ -2197,12 +2299,20 @@ def feedback_listen(): channel_hint="", callback=None, sub_id=None ) ) - murfey.server._transport_object._connection_callback = partial( - murfey.server._transport_object.transport.subscribe, - murfey.server._transport_object.feedback_queue, - feedback_callback, - acknowledgement=True, - ) + # Re-subscription callback invoked by send() after reconnect() replaces + # _transport_object.transport. Resolve the transport at call time rather + # than capturing it now: a partial() bound to today's transport would + # re-subscribe on the old, closed connection and raise + # "add_callback_threadsafe() called on closed or closing connection". + def _resubscribe_feedback(): + transport_manager = murfey.server._transport_object + transport_manager.transport.subscribe( + transport_manager.feedback_queue, + feedback_callback, + acknowledgement=True, + ) + + murfey.server._transport_object._connection_callback = _resubscribe_feedback murfey.server._transport_object.transport.subscribe( murfey.server._transport_object.feedback_queue, feedback_callback, diff --git a/src/murfey/server/ispyb.py b/src/murfey/server/ispyb.py index 815f06354..5349e44d0 100644 --- a/src/murfey/server/ispyb.py +++ b/src/murfey/server/ispyb.py @@ -41,18 +41,22 @@ log = logging.getLogger("murfey.server.ispyb") security_config = get_security_config() -try: - ISPyBSession = sessionmaker( - bind=create_engine( - url(credentials=security_config.ispyb_credentials), - connect_args={"use_pure": True}, - pool_recycle=250, +if security_config.ispyb_credentials: + try: + ISPyBSession = sessionmaker( + bind=create_engine( + url(credentials=security_config.ispyb_credentials), + connect_args={"use_pure": True}, + pool_recycle=250, + ) ) - ) - log.info("Loaded ISPyB database session") -# Catch all errors associated with loading ISPyB database -except Exception: - log.error("Error loading ISPyB session", exc_info=True) + log.info("Loaded ISPyB database session") + # Catch all errors associated with loading ISPyB database + except Exception: + log.error("Error loading ISPyB session", exc_info=True) + ISPyBSession = lambda: None +else: + log.info("No ISPyB credentials set, using local database") ISPyBSession = lambda: None diff --git a/src/murfey/server/murfey_db.py b/src/murfey/server/murfey_db.py index 2afdb8544..3efd156f3 100644 --- a/src/murfey/server/murfey_db.py +++ b/src/murfey/server/murfey_db.py @@ -1,16 +1,38 @@ from __future__ import annotations +import sqlite3 from functools import partial import yaml from cryptography.fernet import Fernet from fastapi import Depends +from sqlalchemy import event +from sqlalchemy.engine import Engine from sqlalchemy.pool import NullPool from sqlmodel import Session, create_engine from murfey.util.config import Security, get_security_config +@event.listens_for(Engine, "connect") +def _configure_sqlite_connection(dbapi_connection, connection_record): + """Tune every SQLite connection; a no-op for Postgres. + + Doppio's feedback thread and micrograph watcher write concurrently, so the + SQLite defaults (``busy_timeout=0``, ``DELETE`` journal) make the second + writer fail instantly with "database is locked". WAL lets readers run + alongside a single writer, ``busy_timeout`` makes writers wait for the lock + instead of erroring, and ``synchronous=NORMAL`` (safe under WAL) skips an + fsync per commit. + """ + if isinstance(dbapi_connection, sqlite3.Connection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.execute("PRAGMA busy_timeout=30000") + cursor.close() + + def url(security_config: Security | None = None) -> str: security_config = security_config or get_security_config() with open(security_config.murfey_db_credentials, "r") as stream: diff --git a/src/murfey/util/config.py b/src/murfey/util/config.py index fcf7ca243..dd879cc52 100644 --- a/src/murfey/util/config.py +++ b/src/murfey/util/config.py @@ -250,7 +250,7 @@ class Security(BaseModel): # Murfey server connection settings auth_url: str = "" - auth_type: Literal["password", "cookie"] = "password" + auth_type: Literal["password", "cookie", "none"] = "password" auth_algorithm: str = "" auth_key: str = "" cookie_key: str = "" diff --git a/src/murfey/util/db.py b/src/murfey/util/db.py index eb92f48fa..3771a4872 100644 --- a/src/murfey/util/db.py +++ b/src/murfey/util/db.py @@ -472,11 +472,6 @@ class MurfeyLedger(SQLModel, table=True): # type: ignore refine_parameters: Optional["RefineParameters"] = Relationship( back_populates="murfey_ledger", sa_relationship_kwargs={"cascade": "delete"} ) - classification_feedback_parameters: Optional["ClassificationFeedbackParameters"] = ( - Relationship( - back_populates="murfey_ledger", sa_relationship_kwargs={"cascade": "delete"} - ) - ) movies: Optional["Movie"] = Relationship( back_populates="murfey_ledger", sa_relationship_kwargs={"cascade": "delete"} ) @@ -708,13 +703,9 @@ class ClassificationFeedbackParameters(SQLModel, table=True): # type: ignore star_combination_job: int initial_model: str next_job: int - picker_murfey_id: Optional[int] = Field(default=None, foreign_key="murfeyledger.id") processing_job: Optional[ProcessingJob] = Relationship( back_populates="classification_feedback_parameters" ) - murfey_ledger: Optional[MurfeyLedger] = Relationship( - back_populates="classification_feedback_parameters" - ) class Class2DParameters(SQLModel, table=True): # type: ignore diff --git a/src/murfey/util/processing_params.py b/src/murfey/util/processing_params.py index 1ddb2b251..dbb660749 100644 --- a/src/murfey/util/processing_params.py +++ b/src/murfey/util/processing_params.py @@ -1,35 +1,94 @@ +import logging +import os from datetime import datetime from functools import lru_cache from pathlib import Path +from pipeliner.project_graph import ProjectGraph from pydantic import BaseModel from werkzeug.utils import secure_filename from murfey.util.config import MachineConfig, get_machine_config +logger = logging.getLogger("murfey.util.processing_params") + + +_DEFAULT_MOTIONCORR_FALLBACK = "job002" + + +@lru_cache(maxsize=16) +def _job_dir_for_alias_cached(visit_name: str, alias: str, mtime_ns: int) -> str | None: + """Read default_pipeline.star and return the jobNNN for the given alias. + + Returns None on any failure (missing file, graph read error, alias + not found). The mtime_ns argument is a cache key — when Pipeliner rewrites + default_pipeline.star its mtime changes and the next call falls through + to a fresh read. + """ + project_dir = Path(visit_name) + pipeline_file = project_dir / "default_pipeline.star" + if not pipeline_file.is_file(): + return None + try: + with ProjectGraph(pipeline_dir=project_dir, read_only=True) as graph: + for proc in graph.process_list: + proc_alias = getattr(proc, "alias", None) + if proc_alias and proc_alias.rstrip("/").endswith(alias): + # proc.name is e.g. "MotionCorr/job003/" + return Path(proc.name).name + except Exception: + logger.error( + "ProjectGraph read failed while looking up alias %r in %s", + alias, + pipeline_file, + exc_info=True, + ) + return None + return None + + +def _job_dir_for_alias(visit_name: str, alias: str) -> str: + """Return the Pipeliner jobNNN for alias in the given project. + + visit_name must be an path to the project directory. + Falls back to the positional default job002 and logs a warning so + drift from the live pipeline is visible in the logs instead of silent. + """ + project_dir = Path(visit_name).resolve() + pipeline_file = project_dir / "default_pipeline.star" + try: + mtime_ns = pipeline_file.stat().st_mtime_ns + except FileNotFoundError: + logger.warning( + "default_pipeline.star missing at %s — falling back to %s for alias %r", + pipeline_file, + _DEFAULT_MOTIONCORR_FALLBACK, + alias, + ) + return _DEFAULT_MOTIONCORR_FALLBACK + job_dir = _job_dir_for_alias_cached(str(project_dir), alias, mtime_ns) + if job_dir is None: + logger.warning( + "Alias %r not found in %s — falling back to %s", + alias, + pipeline_file, + _DEFAULT_MOTIONCORR_FALLBACK, + ) + return _DEFAULT_MOTIONCORR_FALLBACK + return job_dir + def motion_corrected_mrc( input_movie: Path, visit_name: str, machine_config: MachineConfig ): - parts = [secure_filename(p) for p in input_movie.parts] - visit_idx = parts.index(visit_name) - core = Path("/") / Path(*parts[: visit_idx + 1]) - ppath = Path("/") / Path(*parts) - if machine_config.process_multiple_datasets: - sub_dataset = ppath.relative_to(core).parts[0] - else: - sub_dataset = "" - extra_path = machine_config.processed_extra_directory + movie = os.path.basename(input_movie) + job_dir = _job_dir_for_alias(visit_name, "Live_motioncorr") mrc_out = ( - core - / machine_config.processed_directory_name - / sub_dataset - / extra_path + Path(visit_name) / "MotionCorr" - / "job002" + / job_dir / "Movies" - / ppath.parent.relative_to(core / sub_dataset) - / str(ppath.stem + "_motion_corrected.mrc") + / str(movie + "_motion_corrected.mrc") ) return Path("/".join(secure_filename(p) for p in mrc_out.parts)) diff --git a/src/murfey/workflows/register_data_collection.py b/src/murfey/workflows/register_data_collection.py index 0c61ba5d5..87de8a3f6 100644 --- a/src/murfey/workflows/register_data_collection.py +++ b/src/murfey/workflows/register_data_collection.py @@ -6,7 +6,7 @@ from sqlmodel.orm.session import Session as SQLModelSession import murfey.util.db as MurfeyDB -from murfey.server import _transport_object +import murfey.server from murfey.server.ispyb import ISPyBSession, get_session_id from murfey.util import sanitise @@ -15,7 +15,7 @@ def run(message: dict, murfey_db: SQLModelSession) -> dict[str, bool]: # Fail immediately if transport manager was not provided - if _transport_object is None: + if murfey.server._transport_object is None: logger.error("Unable to find transport manager") return {"success": False, "requeue": False} @@ -80,7 +80,7 @@ def run(message: dict, murfey_db: SQLModelSession) -> dict[str, bool]: axisEnd=message.get("axis_end"), numberOfImages=message.get("tilt_series_length"), ) - dcid = _transport_object.do_insert_data_collection( + dcid = murfey.server._transport_object.do_insert_data_collection( record, tag=( message.get("tag") diff --git a/src/murfey/workflows/register_data_collection_group.py b/src/murfey/workflows/register_data_collection_group.py index 0908e769b..dae2111a8 100644 --- a/src/murfey/workflows/register_data_collection_group.py +++ b/src/murfey/workflows/register_data_collection_group.py @@ -7,7 +7,7 @@ from sqlmodel import select from sqlmodel.orm.session import Session as SQLModelSession -from murfey.server import _transport_object +import murfey.server from murfey.server.ispyb import ISPyBSession, get_session_id from murfey.util.db import DataCollectionGroup @@ -16,7 +16,7 @@ def run(message: dict, murfey_db: SQLModelSession) -> dict[str, bool]: # Fail immediately if no transport wrapper is found - if _transport_object is None: + if murfey.server._transport_object is None: logger.error("Unable to find transport manager") return {"success": False, "requeue": False} @@ -50,9 +50,9 @@ def run(message: dict, murfey_db: SQLModelSession) -> dict[str, bool]: experimentTypeId=message["experiment_type_id"], ) - dcgid = _transport_object.do_insert_data_collection_group(record).get( - "return_value", None - ) + dcgid = murfey.server._transport_object.do_insert_data_collection_group( + record + ).get("return_value", None) if dcgid is None: time.sleep(2) @@ -75,9 +75,9 @@ def run(message: dict, murfey_db: SQLModelSession) -> dict[str, bool]: if color_flags := message.get("color_flags", {}): for col_name, value in color_flags.items(): setattr(atlas_record, col_name, value) - atlas_id = _transport_object.do_insert_atlas(atlas_record).get( - "return_value", None - ) + atlas_id = murfey.server._transport_object.do_insert_atlas( + atlas_record + ).get("return_value", None) murfey_dcg = DataCollectionGroup( id=dcgid, diff --git a/src/murfey/workflows/register_processing_job.py b/src/murfey/workflows/register_processing_job.py index 1bb2d5f52..b6ab7b5a5 100644 --- a/src/murfey/workflows/register_processing_job.py +++ b/src/murfey/workflows/register_processing_job.py @@ -7,7 +7,7 @@ import murfey.server.prometheus as prom import murfey.util.db as MurfeyDB -from murfey.server import _transport_object +import murfey.server from murfey.server.ispyb import ISPyBSession from murfey.util import sanitise @@ -16,7 +16,7 @@ def run(message: dict, murfey_db: SQLModelSession): # Faill immediately if not transport manager is set - if _transport_object is None: + if murfey.server._transport_object is None: logger.error("Unable to find transport manager") return {"success": False, "requeue": False} @@ -56,11 +56,11 @@ def run(message: dict, murfey_db: SQLModelSession): ISPyBDB.ProcessingJobParameter(parameterKey=k, parameterValue=v) for k, v in message["job_parameters"].items() ] - pid = _transport_object.do_create_ispyb_job( + pid = murfey.server._transport_object.do_create_ispyb_job( record, params=job_parameters ).get("return_value", None) else: - pid = _transport_object.do_create_ispyb_job(record).get( + pid = murfey.server._transport_object.do_create_ispyb_job(record).get( "return_value", None ) if pid is None: @@ -86,7 +86,7 @@ def run(message: dict, murfey_db: SQLModelSession): record = ISPyBDB.AutoProcProgram( processingJobId=pid, processingStartTime=datetime.now() ) - appid = _transport_object.do_update_processing_status(record).get( + appid = murfey.server._transport_object.do_update_processing_status(record).get( "return_value", None ) if appid is None: diff --git a/src/murfey/workflows/spa/flush_spa_preprocess.py b/src/murfey/workflows/spa/flush_spa_preprocess.py index e90a74bc1..27847aea0 100644 --- a/src/murfey/workflows/spa/flush_spa_preprocess.py +++ b/src/murfey/workflows/spa/flush_spa_preprocess.py @@ -18,13 +18,12 @@ except ImportError: SMARTEM_ACTIVE = False -from murfey.server import _transport_object +import murfey.server from murfey.server.feedback import _murfey_id from murfey.util import sanitise, secure_path from murfey.util.config import get_machine_config, get_microscope from murfey.util.db import ( AutoProcProgram, - ClassificationFeedbackParameters, DataCollection, DataCollectionGroup, FoilHole, @@ -96,17 +95,19 @@ def register_grid_square( ) grid_square.pixel_size = grid_square_params.pixel_size or grid_square.pixel_size grid_square.image = grid_square_params.image or grid_square.image - if _transport_object: - _transport_object.do_update_grid_square(grid_square.id, grid_square_params) + if murfey.server._transport_object: + murfey.server._transport_object.do_update_grid_square( + grid_square.id, grid_square_params + ) else: # No existing grid square in the murfey database - if _transport_object: + if murfey.server._transport_object: dcg = murfey_db.exec( select(DataCollectionGroup) .where(DataCollectionGroup.session_id == session_id) .where(DataCollectionGroup.tag == grid_square_params.tag) ).one() - gs_ispyb_response = _transport_object.do_insert_grid_square( + gs_ispyb_response = murfey.server._transport_object.do_insert_grid_square( dcg.atlas_id, gsid, grid_square_params ) else: @@ -274,14 +275,14 @@ def register_foil_hole( foil_hole_params.thumbnail_size_y or foil_hole.thumbnail_size_y ) or jpeg_size[1] foil_hole.pixel_size = foil_hole_params.pixel_size or foil_hole.pixel_size - if _transport_object and gs.readout_area_x: - _transport_object.do_update_foil_hole( + if murfey.server._transport_object and gs.readout_area_x: + murfey.server._transport_object.do_update_foil_hole( foil_hole.id, gs.thumbnail_size_x / gs.readout_area_x, foil_hole_params ) else: # No existing foil hole in the murfey database - if _transport_object: - fh_ispyb_response = _transport_object.do_insert_foil_hole( + if murfey.server._transport_object: + fh_ispyb_response = murfey.server._transport_object.do_insert_foil_hole( gs.id, gs.thumbnail_size_x / gs.readout_area_x if gs.readout_area_x else None, foil_hole_params, @@ -495,18 +496,11 @@ def flush_spa_preprocess(message: dict, murfey_db: Session) -> dict[str, bool]: .where(AutoProcProgram.pj_id == ProcessingJob.id) .where(ProcessingJob.recipe == recipe_name) ).one() - params = murfey_db.exec( - select(SPARelionParameters, ClassificationFeedbackParameters) - .where(SPARelionParameters.pj_id == collected_ids[2].id) - .where(ClassificationFeedbackParameters.pj_id == SPARelionParameters.pj_id) - ).one() - proc_params = params[0] - feedback_params = params[1] - if not proc_params: - logger.warning( - f"No SPA processing parameters found for client processing job ID {collected_ids[2].id}" + proc_params = murfey_db.exec( + select(SPARelionParameters).where( + SPARelionParameters.pj_id == collected_ids[2].id ) - return {"success": False, "requeue": False} + ).one() murfey_ids = _murfey_id( collected_ids[3].id, @@ -514,10 +508,6 @@ def flush_spa_preprocess(message: dict, murfey_db: Session) -> dict[str, bool]: number=2 * len(stashed_files), close=False, ) - if feedback_params.picker_murfey_id is None: - feedback_params.picker_murfey_id = murfey_ids[1] - murfey_db.add(feedback_params) - for i, f in enumerate(stashed_files): try: foil_hole_id = None @@ -586,11 +576,11 @@ def flush_spa_preprocess(message: dict, murfey_db: Session) -> dict[str, bool]: "foil_hole_id": foil_hole_id, }, } - if _transport_object: + if murfey.server._transport_object: zocalo_message["parameters"]["feedback_queue"] = ( - _transport_object.feedback_queue + murfey.server._transport_object.feedback_queue ) - _transport_object.send( + murfey.server._transport_object.send( "processing_recipe", zocalo_message, new_connection=True ) murfey_db.delete(f) diff --git a/src/murfey/workflows/spa/picking.py b/src/murfey/workflows/spa/picking.py index f7c915d56..da542d17a 100644 --- a/src/murfey/workflows/spa/picking.py +++ b/src/murfey/workflows/spa/picking.py @@ -5,8 +5,8 @@ from sqlalchemy import func from sqlmodel import Session, select +import murfey.server import murfey.server.prometheus as prom -from murfey.server import _transport_object from murfey.server.feedback import ( _app_id, _pj_id, @@ -127,11 +127,11 @@ def _register_picked_particles_use_diameter(message: dict, _db: Session): }, "recipes": ["em-spa-extract"], } - if _transport_object: + if murfey.server._transport_object: zocalo_message["parameters"]["feedback_queue"] = ( - _transport_object.feedback_queue + murfey.server._transport_object.feedback_queue ) - _transport_object.send( + murfey.server._transport_object.send( "processing_recipe", zocalo_message, new_connection=True ) else: @@ -167,11 +167,11 @@ def _register_picked_particles_use_diameter(message: dict, _db: Session): }, "recipes": ["em-spa-extract"], } - if _transport_object: + if murfey.server._transport_object: zocalo_message["parameters"]["feedback_queue"] = ( - _transport_object.feedback_queue + murfey.server._transport_object.feedback_queue ) - _transport_object.send( + murfey.server._transport_object.send( "processing_recipe", zocalo_message, new_connection=True ) @@ -249,11 +249,13 @@ def _register_picked_particles_use_boxsize(message: dict, _db: Session): }, "recipes": ["em-spa-extract"], } - if _transport_object: + if murfey.server._transport_object: zocalo_message["parameters"]["feedback_queue"] = ( - _transport_object.feedback_queue + murfey.server._transport_object.feedback_queue + ) + murfey.server._transport_object.send( + "processing_recipe", zocalo_message, new_connection=True ) - _transport_object.send("processing_recipe", zocalo_message, new_connection=True) _db.close() @@ -266,8 +268,8 @@ def _request_email( config = get_machine_config(instrument_name=session.instrument_name)[ session.instrument_name ] - if _transport_object: - _transport_object.send( + if murfey.server._transport_object: + murfey.server._transport_object.send( config.notifications_queue, { "groupId": dcg_id, diff --git a/tests/server/test_feedback_job_reservation.py b/tests/server/test_feedback_job_reservation.py new file mode 100644 index 000000000..8c701d939 --- /dev/null +++ b/tests/server/test_feedback_job_reservation.py @@ -0,0 +1,140 @@ +""" +Tests for the CCP-EM Pipeliner job-number reservation helpers in murfey.server.feedback + +These guard the duplicate-job-number fix: SPA feedback used to read the +Pipeliner job counter without advancing it, so two jobs scheduled before the +first had been registered by the node creator reused the same number. The +helpers below now reserve (advance) the counter under the project lock at +schedule time. +""" + +from types import SimpleNamespace +from unittest import mock + +import pytest + + +@pytest.fixture +def feedback(): + """Import murfey.server.feedback with its module-level DB setup stubbed out.""" + with ( + mock.patch( + "murfey.util.config.get_security_config", return_value=mock.MagicMock() + ), + mock.patch("murfey.server.murfey_db.url", return_value=mock.MagicMock()), + mock.patch("sqlmodel.create_engine", return_value=mock.MagicMock()), + ): + import murfey.server.feedback as feedback_module + + yield feedback_module + + +def _make_pipeline(pipeline_dir, job_counter: int) -> None: + """Write a valid default_pipeline.star carrying the given job counter.""" + from pipeliner.project_graph import ProjectGraph + + with ProjectGraph( + create_new=True, + read_only=False, + pipeline_dir=str(pipeline_dir), + name="default", + ) as project: + project.job_counter = job_counter + + +def _read_counter(pipeline_dir) -> int: + from pipeliner.project_graph import ProjectGraph + + with ProjectGraph( + read_only=True, pipeline_dir=str(pipeline_dir), name="default" + ) as project: + return project.job_counter + + +def test_reserve_advances_counter_and_returns_base(feedback, tmp_path): + _make_pipeline(tmp_path, job_counter=12) + + base = feedback._reserve_pipeline_job_numbers(str(tmp_path), 3) + + assert base == 12 + # The counter is consumed now, not when the job later registers. + assert _read_counter(tmp_path) == 15 + + +def test_reserve_blocks_are_contiguous_and_non_overlapping(feedback, tmp_path): + _make_pipeline(tmp_path, job_counter=12) + + first = feedback._reserve_pipeline_job_numbers(str(tmp_path), 2) + second = feedback._reserve_pipeline_job_numbers(str(tmp_path), 2) + + assert first == 12 + assert second == 14 # strictly after the first block — never reused + assert _read_counter(tmp_path) == 16 + + +def test_reserve_floors_at_first_feedback_job(feedback, tmp_path): + # Counter still in the preprocessing range (Extract=5/Select=6 not yet + # registered). A Class2D job must never be handed 5 or 6. + _make_pipeline(tmp_path, job_counter=5) + + base = feedback._reserve_pipeline_job_numbers(str(tmp_path), 2) + + assert base == feedback.FIRST_FEEDBACK_JOB == 7 + assert _read_counter(tmp_path) == 9 + + +def test_reserve_missing_pipeline_falls_back_without_creating(feedback, tmp_path): + base = feedback._reserve_pipeline_job_numbers(str(tmp_path), 2) + + assert base == feedback.FIRST_FEEDBACK_JOB + assert not (tmp_path / "default_pipeline.star").exists() + + +def test_reserve_rejects_non_positive_block(feedback, tmp_path): + _make_pipeline(tmp_path, job_counter=10) + + with pytest.raises(ValueError): + feedback._reserve_pipeline_job_numbers(str(tmp_path), 0) + + +@pytest.mark.parametrize( + "icebreaker, first_block, combine_offset, subsequent_block", + [ + # no icebreaker: Class2D + autoselect + (combine on first batch only) + (False, 3, 2, 2), + # icebreaker: Class2D + IceBreaker + autoselect + (combine on first batch) + (True, 4, 3, 3), + ], +) +def test_reserve_2d_classification_block( + feedback, + tmp_path, + icebreaker, + first_block, + combine_offset, + subsequent_block, +): + _make_pipeline(tmp_path, job_counter=10) + fp = SimpleNamespace(star_combination_job=0, next_job=0) + + with mock.patch.object( + feedback.default_spa_parameters, "do_icebreaker_jobs", icebreaker + ): + # First batch reserves Class2D (+IceBreaker) + autoselect + shared combine. + base1 = feedback._reserve_2d_classification_jobs(str(tmp_path), fp) + assert base1 == 10 + assert fp.next_job == 10 + assert fp.star_combination_job == 10 + combine_offset + assert _read_counter(tmp_path) == 10 + first_block + + # select_classes places autoselect at class2d + (2 if icebreaker else 1), + # which must equal combine - 1 so the layout stays contiguous. + autoselect = base1 + (2 if icebreaker else 1) + assert autoselect == fp.star_combination_job - 1 + + # Second batch: combine already exists (shared) → only Class2D + autoselect. + combine_before = fp.star_combination_job + base2 = feedback._reserve_2d_classification_jobs(str(tmp_path), fp) + assert base2 == 10 + first_block # strictly after the first block + assert fp.star_combination_job == combine_before # combine unchanged + assert _read_counter(tmp_path) == base2 + subsequent_block